From bec1db79645f0aba3469def8d9cd9d7b0cc54db5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 16 Oct 2012 22:17:15 -0700 Subject: [PATCH 0001/1502] Added README. --- README | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 README diff --git a/README b/README new file mode 100644 index 00000000..12674266 --- /dev/null +++ b/README @@ -0,0 +1,10 @@ +Tulip is the codename for my attempt at understanding PEP-380 style +coroutines (i.e. those using generators and 'yield from'). + +For reference, see many threads in python-ideas@python.org started in +October 2012, especially those with "The async API of the Future" in +their subject, and the various spin-off threads. + +Copyright/license: Open source, Apache 2.0. Enjoy. + +--Guido van Rossum From 45a772b617556712a35276dd0ceb54434204aac9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 16 Oct 2012 22:27:35 -0700 Subject: [PATCH 0002/1502] Add .hgignore. --- .hgignore | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .hgignore diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..e3600e75 --- /dev/null +++ b/.hgignore @@ -0,0 +1,6 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.DS_Store$ From c565b7e268bc469c48822ecab71026293cd52c82 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 07:38:10 -0700 Subject: [PATCH 0003/1502] Note Python version. --- README | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README b/README index 12674266..3a463354 100644 --- a/README +++ b/README @@ -5,6 +5,8 @@ For reference, see many threads in python-ideas@python.org started in October 2012, especially those with "The async API of the Future" in their subject, and the various spin-off threads. +Python version: 3.3. + Copyright/license: Open source, Apache 2.0. Enjoy. --Guido van Rossum From 72c4cb5c7c4db4250e35be1e8f53a94e1c462a9c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 07:43:51 -0700 Subject: [PATCH 0004/1502] Update TODO. --- main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 06e1414d..ecad2fe5 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,10 @@ Some incomplete laundry lists: TODO: -- Use poll() or better; need to figure out how to keep fds registered. +- Use poll() etc.; need to figure out how to keep fds registered. +- Support an external event loop. - Separate scheduler and event loop. +- Make coroutines an optinal part of the API. - A more varied set of test URLs. - A Hg repo. - Profiling. From e9f46715e52a5b96d64c8506cae6c9b102a76a6d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 07:51:57 -0700 Subject: [PATCH 0005/1502] Take "hg repo" from TODO list. --- main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/main.py b/main.py index ecad2fe5..34cff3d8 100644 --- a/main.py +++ b/main.py @@ -13,7 +13,6 @@ - Separate scheduler and event loop. - Make coroutines an optinal part of the API. - A more varied set of test URLs. -- A Hg repo. - Profiling. - Unittests. From 71328fc30e84c596737bbd04535f4585b291c0a8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 09:59:47 -0700 Subject: [PATCH 0006/1502] Switch to polling; separate coroutine scheduler from pollster. --- main.py | 143 ++++++++++++++++++++++++++++------------------------- polling.py | 117 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 192 insertions(+), 68 deletions(-) create mode 100644 polling.py diff --git a/main.py b/main.py index 34cff3d8..6ab4d0e1 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,7 @@ Some incomplete laundry lists: TODO: -- Use poll() etc.; need to figure out how to keep fds registered. -- Support an external event loop. -- Separate scheduler and event loop. -- Make coroutines an optinal part of the API. +- Cancellation. - A more varied set of test URLs. - Profiling. - Unittests. @@ -32,6 +29,7 @@ __author__ = 'Guido van Rossum ' +# Standard library imports (keep in alphabetic order). import collections import errno import logging @@ -40,70 +38,70 @@ import socket import time +# Local imports (keep in alphabetic order). +import polling + class Scheduler: - def __init__(self): - self.runnable = collections.deque() - self.current = None - self.readers = {} - self.writers = {} - - def run(self, task): - self.runnable.append(task) - - def loop(self): - while self.runnable or self.readers or self.writers: - self.loop1() - - def loop1(self): -## print('loop1') - while self.runnable: - self.current = self.runnable.popleft() - try: - next(self.current) - except StopIteration: - self.current = None - except Exception: - self.current = None - logging.exception('Exception in task') - else: - if self.current is not None: - self.runnable.append(self.current) - self.current = None - if self.readers or self.writers: - # TODO: Schedule timed calls as well. - # TODO: Use poll() or better. - t0 = time.time() - ready_r, ready_w, _ = select.select(self.readers, self.writers, []) - t1 = time.time() -## print('select({}, {}) took {:.3f} secs to return {}, {}' -## .format(list(self.readers), list(self.writers), -## t1 - t0, ready_r, ready_w)) - for fd in ready_r: - self.unblock(self.readers.pop(fd)) - for fd in ready_w: - self.unblock(self.writers.pop(fd)) - - def unblock(self, task): - assert task - self.runnable.append(task) - - def block(self, queue, fd): - assert isinstance(fd, int) - assert fd not in queue - assert self.current is not None - queue[fd] = self.current + def __init__(self, ioloop): + self.ioloop = ioloop self.current = None + self.current_name = None + + def run(self): + self.ioloop.run() + + def start(self, task, name): + self.ioloop.call_soon(self.run_task, task, name) + + def run_task(self, task, name): + try: + self.current = task + self.current_name = name + next(self.current) + except StopIteration: + pass + except Exception: + logging.exception('Exception in task %r', name) + else: + if self.current is not None: + self.start(task, name) + finally: + self.current = None + self.current_name = None + def block_r(self, fd): - self.block(self.readers, fd) + self.block(fd, 'r') def block_w(self, fd): - self.block(self.writers, fd) + self.block(fd, 'w') + + def block(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + assert self.current is not None + task = self.current + self.current = None + if flag == 'r': + method = self.ioloop.add_reader + callback = self.unblock_r + else: + method = self.ioloop.add_writer + callback = self.unblock_w + method(fd, callback, fd, task, self.current_name) + + def unblock_r(self, fd, task, name): + self.ioloop.remove_reader(fd) + self.start(task, name) + + def unblock_w(self, fd, task, name): + self.ioloop.remove_writer(fd) + self.start(task, name) -sched = Scheduler() +sched = Scheduler(polling.best_pollster()) class RawReader: @@ -274,17 +272,26 @@ def urlfetch(host, port=80, method='GET', path='/', def doit(): + # This references NDB's default test service. + # (Sadly the service is single-threaded.) gen1 = urlfetch('127.0.0.1', 8080, path='/', hdrs={'host': 'localhost'}) - gen2 = urlfetch('82.94.164.162', 80, path='/', hdrs={'host': 'python.org'}) - sched.run(gen1) - sched.run(gen2) - for x in '123': - for y in '0123456789': - g = urlfetch('82.94.164.162', 80, - path='/{}.{}'.format(x, y), - hdrs={'host': 'python.org'}) - sched.run(g) - sched.loop() + sched.start(gen1, 'gen1') + gen2 = urlfetch('127.0.0.1', 8080, path='/home', hdrs={'host': 'localhost'}) + sched.start(gen2, 'gen2') + +## # Fetch python.org home page. +## gen3 = urlfetch('82.94.164.162', 80, path='/', +## hdrs={'host': 'python.org'}) +## sched.run(gen3, 'gen3') + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## sched.run(g, path) + sched.run() def main(): diff --git a/polling.py b/polling.py new file mode 100644 index 00000000..1e866daa --- /dev/null +++ b/polling.py @@ -0,0 +1,117 @@ +"""I/O loop based on poll(). + +TODO: +- Docstrings. +- Use _ for non-public methods and instance variables. +- Support some of the other POLL* flags. +- Use kpoll(), epoll() etc. in preference. +- Fall back on select() if no poll() variant at all. +- Keyword args to callbacks. +""" + +import collections +import heapq +import logging +import select + + +class Pollster: + + def __init__(self): + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + self.readers = {} # {fd: (callback, args), ...}. + self.writers = {} # {fd: (callback, args), ...}. + self.pollster = select.poll() + + def update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self.pollster.register(fd, flags) + else: + self.pollster.unregister(fd) + + def add_reader(self, fd, callback, *args): + self.readers[fd] = (callback, args) + self.update(fd) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = (callback, args) + self.update(fd) + + def remove_reader(self, fd): + del self.readers[fd] + self.update(fd) + + def remove_writer(self, fd): + del self.writers[fd] + self.update(fd) + + def call_soon(self, callback, *args): + self.ready.append((callback, args)) + + def call_later(self, when, callback, *args): + if when < 10000000: + when += time.time() + heapq.heappush(self.scheduled, (when, callback, args)) + + def rawpoll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. :-( + msecs = None if timeout is None else int(1000 * timeout) + quads = [] + for fd, flags in self.pollster.poll(): + if flags & select.POLLIN: + if fd in self.readers: + callback, args = self.readers[fd] + quads.append((fd, select.POLLIN, callback, args)) + if flags & select.POLLOUT: + if fd in self.writers: + callback, args = self.writers[fd] + quads.append((fd, select.POLLOUT, callback, args)) + return quads + + def run_once(self): + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + while self.ready: + callback, args = self.ready.popleft() + try: + callback(*args) + except Exception: + logging.exception('Exception in callback %s %r', callback, args) + + # Inspect the poll queue. + if self.readers or self.writers: + if self.scheduled: + when, _, _ = self.scheduled[0] + timeout = max(0, when - time.time()) + else: + timeout = None + quads = self.rawpoll(timeout) + for fd, flag, callback, args in quads: + self.call_soon(callback, *args) + + # Handle 'later' callbacks that are ready. + while self.scheduled: + when, _, _ = self.scheduled[0] + if when > time.time(): + break + when, callback, args = heapq.heappop(self.scheduled) + self.call_soon(callback, *args) + + def run(self): + while self.ready or self.readers or self.writers or self.scheduled: + self.run_once() + + +def best_pollster(): + return Pollster() From e2e1072bc118bf38ab189e183cf579dbffc08336 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 11:01:53 -0700 Subject: [PATCH 0007/1502] Link to Greg Ewing's tutorial. --- README | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README b/README index 3a463354..35541bcb 100644 --- a/README +++ b/README @@ -5,6 +5,9 @@ For reference, see many threads in python-ideas@python.org started in October 2012, especially those with "The async API of the Future" in their subject, and the various spin-off threads. +A particularly influential tutorial by Greg Ewing: +http://www.cosc.canterbury.ac.nz/greg.ewing/python/generators/yf_current/Examples/Scheduler/scheduler.txt + Python version: 3.3. Copyright/license: Open source, Apache 2.0. Enjoy. From f4d3e92baadc2525c440bea55bc0b4573748d9b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 11:25:11 -0700 Subject: [PATCH 0008/1502] Make name optional; generators have names. --- main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 6ab4d0e1..9d2ebdfb 100644 --- a/main.py +++ b/main.py @@ -52,7 +52,9 @@ def __init__(self, ioloop): def run(self): self.ioloop.run() - def start(self, task, name): + def start(self, task, name=None): + if name is None: + name = task.__name__ # If it doesn't have one, pass one. self.ioloop.call_soon(self.run_task, task, name) def run_task(self, task, name): @@ -279,10 +281,10 @@ def doit(): gen2 = urlfetch('127.0.0.1', 8080, path='/home', hdrs={'host': 'localhost'}) sched.start(gen2, 'gen2') -## # Fetch python.org home page. -## gen3 = urlfetch('82.94.164.162', 80, path='/', -## hdrs={'host': 'python.org'}) -## sched.run(gen3, 'gen3') + # Fetch python.org home page. + gen3 = urlfetch('82.94.164.162', 80, path='/', + hdrs={'host': 'python.org'}) + sched.start(gen3, 'gen3') ## # Fetch many links from python.org (/x.y.z). ## for x in '123': @@ -291,6 +293,7 @@ def doit(): ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) ## sched.run(g, path) + sched.run() From 2e38f359943044eccb6def1294a32ef013d0c0a5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 19:19:31 -0700 Subject: [PATCH 0009/1502] Add call_in_thread(), and use it to call getaddrinfo(). --- main.py | 78 +++++++++++++++++++++++++++++++++++++++++++----------- polling.py | 20 +++++++++++--- 2 files changed, 79 insertions(+), 19 deletions(-) diff --git a/main.py b/main.py index 9d2ebdfb..457df788 100644 --- a/main.py +++ b/main.py @@ -30,6 +30,7 @@ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). +import concurrent.futures import collections import errno import logging @@ -75,24 +76,34 @@ def run_task(self, task, name): def block_r(self, fd): - self.block(fd, 'r') + self.block_io(fd, 'r') def block_w(self, fd): - self.block(fd, 'w') + self.block_io(fd, 'w') - def block(self, fd, flag): + def block_io(self, fd, flag): assert isinstance(fd, int), repr(fd) assert flag in ('r', 'w'), repr(flag) - assert self.current is not None - task = self.current - self.current = None + task, name = self.block() if flag == 'r': method = self.ioloop.add_reader callback = self.unblock_r else: method = self.ioloop.add_writer callback = self.unblock_w - method(fd, callback, fd, task, self.current_name) + method(fd, callback, fd, task, name) + + def block_future(self, future): + task, name = self.block() + self.ioloop.add_future(future) + # TODO: Don't use closures or lambdas. + future.add_done_callback(lambda unused_future: self.start(task, name)) + + def block(self): + assert self.current + task = self.current + self.current = None + return task, self.current_name def unblock_r(self, fd, task, name): self.ioloop.remove_reader(fd) @@ -106,7 +117,23 @@ def unblock_w(self, fd, task, name): sched = Scheduler(polling.best_pollster()) +max_workers = 5 +threadpool = None # Thread pool, lazily initialized. + +def call_in_thread(func, *args, **kwds): + # TODO: Timeout? + global threadpool + if threadpool is None: + threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) + future = threadpool.submit(func, *args, **kwds) + sched.block_future(future) + yield + assert future.done() + return future.result() + + class RawReader: + # TODO: Merge with send() and newsocket() functions. def __init__(self, sock): self.sock = sock @@ -188,8 +215,8 @@ def send(sock, data): data = data[n:] -def newsocket(): - sock = socket.socket() +def newsocket(af, socktype, proto): + sock = socket.socket(af, socktype, proto) sock.setblocking(False) return sock @@ -212,9 +239,26 @@ def urlfetch(host, port=80, method='GET', path='/', t0 = time.time() # Must pass in an IP address. Later we'll call getaddrinfo() # using a thread pool. We'll also support IPv6. - assert re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host), repr(host) - sock = newsocket() - yield from connect(sock, (host, port)) + if not re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host): + infos = yield from call_in_thread(socket.getaddrinfo, + host, port, socket.AF_INET, + socket.SOCK_STREAM, + socket.SOL_TCP) + else: + infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', + (host, port))] + assert infos, 'No address info for (%r, %r)' % (host, port) + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = newsocket(af, socktype, proto) + yield from connect(sock, address) + break + except socket.error: + if sock is not None: + sock.close() + else: + raise yield from send(sock, method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') @@ -222,6 +266,8 @@ def urlfetch(host, port=80, method='GET', path='/', kwds = dict(hdrs) else: kwds = {} + if 'host' not in kwds: + kwds['host'] = host if body is not None: kwds['content_length'] = len(body) for header, value in kwds.items(): @@ -276,14 +322,14 @@ def urlfetch(host, port=80, method='GET', path='/', def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) - gen1 = urlfetch('127.0.0.1', 8080, path='/', hdrs={'host': 'localhost'}) + gen1 = urlfetch('localhost', 8080, path='/') sched.start(gen1, 'gen1') - gen2 = urlfetch('127.0.0.1', 8080, path='/home', hdrs={'host': 'localhost'}) + + gen2 = urlfetch('localhost', 8080, path='/home') sched.start(gen2, 'gen2') # Fetch python.org home page. - gen3 = urlfetch('82.94.164.162', 80, path='/', - hdrs={'host': 'python.org'}) + gen3 = urlfetch('python.org', 80, path='/') sched.start(gen3, 'gen3') ## # Fetch many links from python.org (/x.y.z). diff --git a/polling.py b/polling.py index 1e866daa..9b983d00 100644 --- a/polling.py +++ b/polling.py @@ -13,6 +13,7 @@ import heapq import logging import select +import time class Pollster: @@ -22,6 +23,7 @@ def __init__(self): self.scheduled = [] # [(when, callback, args), ...] self.readers = {} # {fd: (callback, args), ...}. self.writers = {} # {fd: (callback, args), ...}. + self.futures = set() # {concurrent.futures.Future(), ...}. self.pollster = select.poll() def update(self, fd): @@ -52,6 +54,10 @@ def remove_writer(self, fd): del self.writers[fd] self.update(fd) + def add_future(self, future): + self.futures.add(future) + future.add_done_callback(self.futures.remove) + def call_soon(self, callback, *args): self.ready.append((callback, args)) @@ -64,7 +70,7 @@ def rawpoll(self, timeout=None): # Timeout is in seconds, but poll() takes milliseconds. :-( msecs = None if timeout is None else int(1000 * timeout) quads = [] - for fd, flags in self.pollster.poll(): + for fd, flags in self.pollster.poll(msecs): if flags & select.POLLIN: if fd in self.readers: callback, args = self.readers[fd] @@ -90,12 +96,19 @@ def run_once(self): logging.exception('Exception in callback %s %r', callback, args) # Inspect the poll queue. - if self.readers or self.writers: + if self.readers or self.writers or self.futures: if self.scheduled: when, _, _ = self.scheduled[0] timeout = max(0, when - time.time()) else: timeout = None + if self.futures: + # When there's a pending future, wait no more than 100 msec. + # TODO: Find a more reasonable way to wait for Futures. + if timeout == None: + timeout = 0.1 + else: + timeout = min(timeout, 0.1) quads = self.rawpoll(timeout) for fd, flag, callback, args in quads: self.call_soon(callback, *args) @@ -109,7 +122,8 @@ def run_once(self): self.call_soon(callback, *args) def run(self): - while self.ready or self.readers or self.writers or self.scheduled: + while (self.ready or self.readers or self.writers or + self.scheduled or self.futures): self.run_once() From 989d686ee30f9a541673d5913e2bad4a1e1af17f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 17 Oct 2012 23:47:34 -0700 Subject: [PATCH 0010/1502] Correct typos. --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 457df788..255d6696 100644 --- a/main.py +++ b/main.py @@ -288,7 +288,7 @@ def urlfetch(host, port=80, method='GET', path='/', m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', resp) if not m: sock.close() - raise IOError('No valid HTTP response: %r' % response) + raise IOError('No valid HTTP response: %r' % resp) http_version, status, message = m.groups() # Read HTTP headers. @@ -338,7 +338,7 @@ def doit(): ## path = '/{}.{}'.format(x, y) ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) -## sched.run(g, path) +## sched.start(g, path) sched.run() From 56aafb644f465aec3294c91fae637edcae764bf8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 08:58:10 -0700 Subject: [PATCH 0011/1502] Rewrite call_in_thread() using a pipe for signalling. --- main.py | 25 +++++++------------------ polling.py | 48 +++++++++++++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/main.py b/main.py index 255d6696..9762b727 100644 --- a/main.py +++ b/main.py @@ -23,14 +23,12 @@ - Chunked encoding (request and response). - Pipelining, e.g. zlib (request and response). - Automatic encoding/decoding. -- Thread pool and getaddrinfo() calling. - A write() call that isn't a generator. """ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). -import concurrent.futures import collections import errno import logging @@ -93,12 +91,6 @@ def block_io(self, fd, flag): callback = self.unblock_w method(fd, callback, fd, task, name) - def block_future(self, future): - task, name = self.block() - self.ioloop.add_future(future) - # TODO: Don't use closures or lambdas. - future.add_done_callback(lambda unused_future: self.start(task, name)) - def block(self): assert self.current task = self.current @@ -114,19 +106,16 @@ def unblock_w(self, fd, task, name): self.start(task, name) -sched = Scheduler(polling.best_pollster()) - +ioloop = polling.Pollster() +trunner = polling.ThreadRunner(ioloop) +sched = Scheduler(ioloop) -max_workers = 5 -threadpool = None # Thread pool, lazily initialized. def call_in_thread(func, *args, **kwds): - # TODO: Timeout? - global threadpool - if threadpool is None: - threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) - future = threadpool.submit(func, *args, **kwds) - sched.block_future(future) + # TODO: Prove there is no race condition here. + task, name = sched.block() + future = trunner.submit(func, *args, **kwds) + future.add_done_callback(lambda _: sched.start(task, name)) yield assert future.done() return future.result() diff --git a/polling.py b/polling.py index 9b983d00..a4db88f0 100644 --- a/polling.py +++ b/polling.py @@ -10,8 +10,10 @@ """ import collections +import concurrent.futures import heapq import logging +import os import select import time @@ -23,7 +25,6 @@ def __init__(self): self.scheduled = [] # [(when, callback, args), ...] self.readers = {} # {fd: (callback, args), ...}. self.writers = {} # {fd: (callback, args), ...}. - self.futures = set() # {concurrent.futures.Future(), ...}. self.pollster = select.poll() def update(self, fd): @@ -54,10 +55,6 @@ def remove_writer(self, fd): del self.writers[fd] self.update(fd) - def add_future(self, future): - self.futures.add(future) - future.add_done_callback(self.futures.remove) - def call_soon(self, callback, *args): self.ready.append((callback, args)) @@ -96,19 +93,12 @@ def run_once(self): logging.exception('Exception in callback %s %r', callback, args) # Inspect the poll queue. - if self.readers or self.writers or self.futures: + if self.readers or self.writers: if self.scheduled: when, _, _ = self.scheduled[0] timeout = max(0, when - time.time()) else: timeout = None - if self.futures: - # When there's a pending future, wait no more than 100 msec. - # TODO: Find a more reasonable way to wait for Futures. - if timeout == None: - timeout = 0.1 - else: - timeout = min(timeout, 0.1) quads = self.rawpoll(timeout) for fd, flag, callback, args in quads: self.call_soon(callback, *args) @@ -122,10 +112,34 @@ def run_once(self): self.call_soon(callback, *args) def run(self): - while (self.ready or self.readers or self.writers or - self.scheduled or self.futures): + while self.ready or self.readers or self.writers or self.scheduled: self.run_once() -def best_pollster(): - return Pollster() +class ThreadRunner: + + def __init__(self, ioloop, max_workers=5): + self.ioloop = ioloop + self.threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + # Semi-permanent callback while at least one future is active. + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) + self.active_count -= len(data) + if self.active_count == 0: + self.ioloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, **kwds): + assert self.active_count >= 0, self.active_count + future = self.threadpool.submit(func, *args, **kwds) + if self.active_count == 0: + self.ioloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(future): + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future From 5b1b297e0e1d71894cbc4703ff3546b683b92ed7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 11:06:58 -0700 Subject: [PATCH 0012/1502] Fix exception reraising. Make connect() work on Linux. --- main.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 9762b727..1d535d3a 100644 --- a/main.py +++ b/main.py @@ -219,7 +219,7 @@ def connect(sock, address): err = sock.connect_ex(address) if err == errno.ECONNREFUSED: raise IOError('Connection refused') - if err != errno.EISCONN: + if err not in (0, errno.EISCONN): raise IOError('Connect error %d: %s' % (err, errno.errorcode.get(err))) @@ -237,17 +237,21 @@ def urlfetch(host, port=80, method='GET', path='/', infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', (host, port))] assert infos, 'No address info for (%r, %r)' % (host, port) + exc = None for af, socktype, proto, cname, address in infos: sock = None try: sock = newsocket(af, socktype, proto) yield from connect(sock, address) break - except socket.error: + except socket.error as err: if sock is not None: sock.close() + if exc is None: + exc = err else: - raise + if exc is not None: + raise exc yield from send(sock, method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') From 87839c22cdb10f5aecb2658e1a18074abcc88539 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 11:10:47 -0700 Subject: [PATCH 0013/1502] Test submit. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 1d535d3a..843ffe3c 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ Some incomplete laundry lists: TODO: -- Cancellation. +- Cancellation? - A more varied set of test URLs. - Profiling. - Unittests. From d027c345338ff548721d74ba9493b7123d1af2af Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 11:13:39 -0700 Subject: [PATCH 0014/1502] Test submit. --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 843ffe3c..0b6ef0d6 100644 --- a/main.py +++ b/main.py @@ -8,8 +8,8 @@ Some incomplete laundry lists: TODO: +- Take test urls from command line. - Cancellation? -- A more varied set of test URLs. - Profiling. - Unittests. From a1384c8b0a796e0834beb9abbd8c489c1ca13591 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 20:43:28 -0700 Subject: [PATCH 0015/1502] Recognize if poll() returns POLLHUP. --- polling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/polling.py b/polling.py index a4db88f0..e4b3d636 100644 --- a/polling.py +++ b/polling.py @@ -68,14 +68,14 @@ def rawpoll(self, timeout=None): msecs = None if timeout is None else int(1000 * timeout) quads = [] for fd, flags in self.pollster.poll(msecs): - if flags & select.POLLIN: + if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: callback, args = self.readers[fd] - quads.append((fd, select.POLLIN, callback, args)) - if flags & select.POLLOUT: + quads.append((fd, flags, callback, args)) + if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: callback, args = self.writers[fd] - quads.append((fd, select.POLLOUT, callback, args)) + quads.append((fd, flags, callback, args)) return quads def run_once(self): From ef538e16bbe472716021ed21422057e5a0d9b7f3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 18 Oct 2012 20:48:29 -0700 Subject: [PATCH 0016/1502] Drop obsolete TODO. --- main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/main.py b/main.py index 0b6ef0d6..5a6dd4c6 100644 --- a/main.py +++ b/main.py @@ -8,6 +8,7 @@ Some incomplete laundry lists: TODO: +- Refactor RawReader -> Connection, with read/write operations. - Take test urls from command line. - Cancellation? - Profiling. @@ -226,8 +227,6 @@ def connect(sock, address): def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8'): t0 = time.time() - # Must pass in an IP address. Later we'll call getaddrinfo() - # using a thread pool. We'll also support IPv6. if not re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host): infos = yield from call_in_thread(socket.getaddrinfo, host, port, socket.AF_INET, From 0c88caaa7bf2608644ef86e915e4f39a8bcb2dbf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 19 Oct 2012 14:33:29 -0700 Subject: [PATCH 0017/1502] A little benchmark (also posted to python-ideas). --- p3time.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 p3time.py diff --git a/p3time.py b/p3time.py new file mode 100644 index 00000000..f1e5f5ab --- /dev/null +++ b/p3time.py @@ -0,0 +1,38 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def main(): + depth = 20 + + t0 = time.time() + k = plain(depth) + t1 = time.time() + print('plain', k, t1-t0) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro.', k, t1-t0) + +if __name__ == '__main__': + main() From 960a633f1df9d124d1f4fa4d79326907bbbac69d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 19 Oct 2012 15:37:03 -0700 Subject: [PATCH 0018/1502] Add Make rule to run p3time. --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index 89fb927d..43f15dbc 100644 --- a/Makefile +++ b/Makefile @@ -3,3 +3,6 @@ test: profile: python3.3 -m profile -s time main.py + +time: + python3.3 p3time.py From 3b246f5c1b1123341d96e896f67c907ca395313d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 19 Oct 2012 19:10:49 -0700 Subject: [PATCH 0019/1502] Christian Tismer's update to p3time, runs for various depths. --- p3time.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/p3time.py b/p3time.py index f1e5f5ab..f7c67b1f 100644 --- a/p3time.py +++ b/p3time.py @@ -16,13 +16,13 @@ def coroutine(n): r = yield from coroutine(n-1) return l + 1 + r -def main(): - depth = 20 - +def submain(depth): t0 = time.time() k = plain(depth) t1 = time.time() - print('plain', k, t1-t0) + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) t0 = time.time() try: @@ -31,8 +31,14 @@ def main(): next(g) except StopIteration as err: k = err.value - t1 = time.time() - print('coro.', k, t1-t0) + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + print(('relat' + fmt).format(depth, k, delta1/delta0)) +def main(reasonable=100): + for depth in range(reasonable): + submain(depth) + if __name__ == '__main__': main() From d37d27e3149e2af8922fffc4bbb592b0a8fcaf5b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 20 Oct 2012 08:48:35 -0700 Subject: [PATCH 0020/1502] Improve connect() after Richard Oudkerk's example. --- main.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 5a6dd4c6..d096d358 100644 --- a/main.py +++ b/main.py @@ -213,15 +213,16 @@ def newsocket(af, socktype, proto): def connect(sock, address): ## print('connect:', address) - err = sock.connect_ex(address) - assert err == errno.EINPROGRESS, err + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise sched.block_w(sock.fileno()) yield - err = sock.connect_ex(address) - if err == errno.ECONNREFUSED: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: raise IOError('Connection refused') - if err not in (0, errno.EISCONN): - raise IOError('Connect error %d: %s' % (err, errno.errorcode.get(err))) def urlfetch(host, port=80, method='GET', path='/', From 3bcea7ee2682a34e8fa14a662eeb0d503280e8b1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 20 Oct 2012 08:57:26 -0700 Subject: [PATCH 0021/1502] Clean up whitespace. --- main.py | 1 - p3time.py | 2 +- polling.py | 3 ++- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index d096d358..0a013035 100644 --- a/main.py +++ b/main.py @@ -72,7 +72,6 @@ def run_task(self, task, name): finally: self.current = None self.current_name = None - def block_r(self, fd): self.block_io(fd, 'r') diff --git a/p3time.py b/p3time.py index f7c67b1f..d0cef6a4 100644 --- a/p3time.py +++ b/p3time.py @@ -39,6 +39,6 @@ def submain(depth): def main(reasonable=100): for depth in range(reasonable): submain(depth) - + if __name__ == '__main__': main() diff --git a/polling.py b/polling.py index e4b3d636..f41145b1 100644 --- a/polling.py +++ b/polling.py @@ -90,7 +90,8 @@ def run_once(self): try: callback(*args) except Exception: - logging.exception('Exception in callback %s %r', callback, args) + logging.exception('Exception in callback %s %r', + callback, args) # Inspect the poll queue. if self.readers or self.writers: From 15180f156edec6eed9559d48149010548b51b56e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 21 Oct 2012 17:48:15 -0700 Subject: [PATCH 0022/1502] Refactor to separate Pollster and Eventloop mixin classes. --- polling.py | 40 +++++++++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/polling.py b/polling.py index f41145b1..4c524247 100644 --- a/polling.py +++ b/polling.py @@ -18,15 +18,17 @@ import time -class Pollster: +class PollsterMixin: def __init__(self): - self.ready = collections.deque() # [(callback, args), ...] - self.scheduled = [] # [(when, callback, args), ...] + super().__init__() self.readers = {} # {fd: (callback, args), ...}. self.writers = {} # {fd: (callback, args), ...}. self.pollster = select.poll() + def pollable(self): + return bool(self.readers or self.writers) + def update(self, fd): assert isinstance(fd, int), fd flags = 0 @@ -55,14 +57,6 @@ def remove_writer(self, fd): del self.writers[fd] self.update(fd) - def call_soon(self, callback, *args): - self.ready.append((callback, args)) - - def call_later(self, when, callback, *args): - if when < 10000000: - when += time.time() - heapq.heappush(self.scheduled, (when, callback, args)) - def rawpoll(self, timeout=None): # Timeout is in seconds, but poll() takes milliseconds. :-( msecs = None if timeout is None else int(1000 * timeout) @@ -78,6 +72,22 @@ def rawpoll(self, timeout=None): quads.append((fd, flags, callback, args)) return quads + +class EventLoopMixin: + + def __init__(self): + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + super().__init__() + + def call_soon(self, callback, *args): + self.ready.append((callback, args)) + + def call_later(self, when, callback, *args): + if when < 10000000: + when += time.time() + heapq.heappush(self.scheduled, (when, callback, args)) + def run_once(self): # TODO: Break each of these into smaller pieces. # TODO: Pass in a timeout or deadline or something. @@ -94,7 +104,7 @@ def run_once(self): callback, args) # Inspect the poll queue. - if self.readers or self.writers: + if self.pollable(): if self.scheduled: when, _, _ = self.scheduled[0] timeout = max(0, when - time.time()) @@ -113,10 +123,14 @@ def run_once(self): self.call_soon(callback, *args) def run(self): - while self.ready or self.readers or self.writers or self.scheduled: + while self.ready or self.scheduled or self.pollable(): self.run_once() +class Pollster(EventLoopMixin, PollsterMixin): + pass + + class ThreadRunner: def __init__(self, ioloop, max_workers=5): From edab9583059c68b06223765cb9ff6bbf5ee0e6bf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 21 Oct 2012 19:23:02 -0700 Subject: [PATCH 0023/1502] Add kqueue polling mixin. --- polling.py | 94 ++++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 74 insertions(+), 20 deletions(-) diff --git a/polling.py b/polling.py index 4c524247..ffba21ff 100644 --- a/polling.py +++ b/polling.py @@ -18,18 +18,18 @@ import time -class PollsterMixin: +class PollMixin: def __init__(self): super().__init__() self.readers = {} # {fd: (callback, args), ...}. self.writers = {} # {fd: (callback, args), ...}. - self.pollster = select.poll() + self._pollster = select.poll() def pollable(self): return bool(self.readers or self.writers) - def update(self, fd): + def _update(self, fd): assert isinstance(fd, int), fd flags = 0 if fd in self.readers: @@ -37,48 +37,96 @@ def update(self, fd): if fd in self.writers: flags |= select.POLLOUT if flags: - self.pollster.register(fd, flags) + self._pollster.register(fd, flags) else: - self.pollster.unregister(fd) + self._pollster.unregister(fd) def add_reader(self, fd, callback, *args): self.readers[fd] = (callback, args) - self.update(fd) + self._update(fd) def add_writer(self, fd, callback, *args): self.writers[fd] = (callback, args) - self.update(fd) + self._update(fd) def remove_reader(self, fd): del self.readers[fd] - self.update(fd) + self._update(fd) def remove_writer(self, fd): del self.writers[fd] - self.update(fd) + self._update(fd) - def rawpoll(self, timeout=None): + def poll(self, timeout=None): # Timeout is in seconds, but poll() takes milliseconds. :-( msecs = None if timeout is None else int(1000 * timeout) - quads = [] - for fd, flags in self.pollster.poll(msecs): + events = [] # TODO: Do we need fd and flags in events? + for fd, flags in self._pollster.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: callback, args = self.readers[fd] - quads.append((fd, flags, callback, args)) + events.append((fd, flags, callback, args)) if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: callback, args = self.writers[fd] - quads.append((fd, flags, callback, args)) - return quads + events.append((fd, flags, callback, args)) + return events + + +class KqueueMixin: + + def __init__(self): + super().__init__() + self.readers = {} + self.writers = {} + self.kqueue = select.kqueue() + + def pollable(self): + return bool(self.readers or self.writers) + + def add_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self.kqueue.control([kev], 0, 0) + self.readers[fd] = (callback, args) + + def add_writer(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self.kqueue.control([kev], 0, 0) + self.writers[fd] = (callback, args) + + def remove_reader(self, fd): + del self.readers[fd] + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self.kqueue.control([kev], 0, 0) + + def remove_writer(self, fd): + del self.writers[fd] + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self.kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self.kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + callback, args = self.readers[fd] + events.append((fd, flag, callback, args)) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + callback, args = self.writers[fd] + events.append((fd, flag, callback, args)) + return events class EventLoopMixin: def __init__(self): + super().__init__() self.ready = collections.deque() # [(callback, args), ...] self.scheduled = [] # [(when, callback, args), ...] - super().__init__() def call_soon(self, callback, *args): self.ready.append((callback, args)) @@ -110,8 +158,8 @@ def run_once(self): timeout = max(0, when - time.time()) else: timeout = None - quads = self.rawpoll(timeout) - for fd, flag, callback, args in quads: + events = self.poll(timeout) + for fd, flag, callback, args in events: self.call_soon(callback, *args) # Handle 'later' callbacks that are ready. @@ -127,8 +175,14 @@ def run(self): self.run_once() -class Pollster(EventLoopMixin, PollsterMixin): - pass +if hasattr(select, 'kqueue'): + class Pollster(EventLoopMixin, KqueueMixin): + pass +elif hasattr(select, 'poll'): + class Pollster(EventLoopMixin, PollMixin): + pass +else: + raise ImportError('Neither poll() not kqueue() supported') class ThreadRunner: From 790c62201c3543cfe3a0445db57b945fb40a0438 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 21 Oct 2012 19:34:26 -0700 Subject: [PATCH 0024/1502] Refactor poll/kqueue classes a bit more. --- polling.py | 61 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 22 deletions(-) diff --git a/polling.py b/polling.py index ffba21ff..721de297 100644 --- a/polling.py +++ b/polling.py @@ -1,10 +1,11 @@ -"""I/O loop based on poll(). +"""I/O loop implementations based on kqueue() and poll(). + +If both exist, kqueue() is preferred. If neither exists, raise +ImportError. TODO: -- Docstrings. -- Use _ for non-public methods and instance variables. -- Support some of the other POLL* flags. -- Use kpoll(), epoll() etc. in preference. +- Docstrings, unittests. +- Support epoll(). - Fall back on select() if no poll() variant at all. - Keyword args to callbacks. """ @@ -18,17 +19,38 @@ import time -class PollMixin: +class PollsterBase: def __init__(self): super().__init__() self.readers = {} # {fd: (callback, args), ...}. self.writers = {} # {fd: (callback, args), ...}. - self._pollster = select.poll() def pollable(self): return bool(self.readers or self.writers) + def add_reader(self, fd, callback, *args): + self.readers[fd] = (callback, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = (callback, args) + + def remove_reader(self, fd): + del self.readers[fd] + + def remove_writer(self, fd): + del self.writers[fd] + + def poll(self, timeout=None): + raise NotImplementedError + + +class PollMixin(PollsterBase): + + def __init__(self): + super().__init__() + self._pollster = select.poll() + def _update(self, fd): assert isinstance(fd, int), fd flags = 0 @@ -42,19 +64,19 @@ def _update(self, fd): self._pollster.unregister(fd) def add_reader(self, fd, callback, *args): - self.readers[fd] = (callback, args) + super().add_reader(fd, callback, *args) self._update(fd) def add_writer(self, fd, callback, *args): - self.writers[fd] = (callback, args) + super().add_writer(fd, callback, *args) self._update(fd) def remove_reader(self, fd): - del self.readers[fd] + super().remove_reader(fd) self._update(fd) def remove_writer(self, fd): - del self.writers[fd] + super().remove_writer(fd) self._update(fd) def poll(self, timeout=None): @@ -73,36 +95,31 @@ def poll(self, timeout=None): return events -class KqueueMixin: +class KqueueMixin(PollsterBase): def __init__(self): super().__init__() - self.readers = {} - self.writers = {} self.kqueue = select.kqueue() - def pollable(self): - return bool(self.readers or self.writers) - def add_reader(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self.kqueue.control([kev], 0, 0) - self.readers[fd] = (callback, args) + super().add_reader(fd, callback, *args) def add_writer(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self.kqueue.control([kev], 0, 0) - self.writers[fd] = (callback, args) + super().add_writer(fd, callback, *args) def remove_reader(self, fd): - del self.readers[fd] + super().remove_reader(fd) kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) self.kqueue.control([kev], 0, 0) def remove_writer(self, fd): - del self.writers[fd] + super().remove_writer(fd) kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) self.kqueue.control([kev], 0, 0) @@ -121,7 +138,7 @@ def poll(self, timeout=None): return events -class EventLoopMixin: +class EventLoopMixin(PollsterBase): def __init__(self): super().__init__() From 6bb079a75d7f87e453f4f210617bda79ba644aed Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 22 Oct 2012 09:32:34 -0700 Subject: [PATCH 0025/1502] Add a fallback class using select() -- always there. --- polling.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/polling.py b/polling.py index 721de297..fbf008f0 100644 --- a/polling.py +++ b/polling.py @@ -45,6 +45,17 @@ def poll(self, timeout=None): raise NotImplementedError +class SelectMixin(PollsterBase): + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += ((fd, 'r') + self.readers[fd] for fd in readable) + events += ((fd, 'w') + self.writers[fd] for fd in writable) + return events + + class PollMixin(PollsterBase): def __init__(self): @@ -99,34 +110,34 @@ class KqueueMixin(PollsterBase): def __init__(self): super().__init__() - self.kqueue = select.kqueue() + self._kqueue = select.kqueue() def add_reader(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) - self.kqueue.control([kev], 0, 0) + self._kqueue.control([kev], 0, 0) super().add_reader(fd, callback, *args) def add_writer(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) - self.kqueue.control([kev], 0, 0) + self._kqueue.control([kev], 0, 0) super().add_writer(fd, callback, *args) def remove_reader(self, fd): super().remove_reader(fd) kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) - self.kqueue.control([kev], 0, 0) + self._kqueue.control([kev], 0, 0) def remove_writer(self, fd): super().remove_writer(fd) kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) - self.kqueue.control([kev], 0, 0) + self._kqueue.control([kev], 0, 0) def poll(self, timeout=None): events = [] max_ev = len(self.readers) + len(self.writers) - for kev in self.kqueue.control(None, max_ev, timeout): + for kev in self._kqueue.control(None, max_ev, timeout): fd = kev.ident flag = kev.filter if flag == select.KQ_FILTER_READ and fd in self.readers: @@ -199,7 +210,8 @@ class Pollster(EventLoopMixin, KqueueMixin): class Pollster(EventLoopMixin, PollMixin): pass else: - raise ImportError('Neither poll() not kqueue() supported') + class Pollster(EventLoopMixin, SelectMixin): + pass class ThreadRunner: From fdbc5a17d227b277592fd85aff5200ccab3f0f77 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 22 Oct 2012 09:50:37 -0700 Subject: [PATCH 0026/1502] Add an epoll() mixin class. --- polling.py | 93 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/polling.py b/polling.py index fbf008f0..28cc03cc 100644 --- a/polling.py +++ b/polling.py @@ -4,9 +4,8 @@ ImportError. TODO: +- Do we need fd and flags in event tuples? - Docstrings, unittests. -- Support epoll(). -- Fall back on select() if no poll() variant at all. - Keyword args to callbacks. """ @@ -60,7 +59,7 @@ class PollMixin(PollsterBase): def __init__(self): super().__init__() - self._pollster = select.poll() + self._poll = select.poll() def _update(self, fd): assert isinstance(fd, int), fd @@ -70,9 +69,9 @@ def _update(self, fd): if fd in self.writers: flags |= select.POLLOUT if flags: - self._pollster.register(fd, flags) + self._poll.register(fd, flags) else: - self._pollster.unregister(fd) + self._poll.unregister(fd) def add_reader(self, fd, callback, *args): super().add_reader(fd, callback, *args) @@ -92,9 +91,9 @@ def remove_writer(self, fd): def poll(self, timeout=None): # Timeout is in seconds, but poll() takes milliseconds. :-( - msecs = None if timeout is None else int(1000 * timeout) - events = [] # TODO: Do we need fd and flags in events? - for fd, flags in self._pollster.poll(msecs): + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: callback, args = self.readers[fd] @@ -106,6 +105,59 @@ def poll(self, timeout=None): return events +class EPollMixin(PollsterBase): + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def add_reader(self, fd, callback, *args): + super().add_reader(fd, callback, *args) + self._update(fd) + + def add_writer(self, fd, callback, *args): + super().add_writer(fd, callback, *args) + self._update(fd) + + def remove_reader(self, fd): + super().remove_reader(fd) + self._update(fd) + + def remove_writer(self, fd): + super().remove_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & (select.EPOLLIN | select.EPOLLHUP): + if fd in self.readers: + callback, args = self.readers[fd] + events.append((fd, eventmask, callback, args)) + if eventmask & (select.EPOLLOUT | select.EPOLLHUP): + if fd in self.writers: + callback, args = self.writers[fd] + events.append((fd, eventmask, callback, args)) + return events + + class KqueueMixin(PollsterBase): def __init__(self): @@ -160,6 +212,7 @@ def call_soon(self, callback, *args): self.ready.append((callback, args)) def call_later(self, when, callback, *args): + # If when is small enough (~11 days), it's a relative time. if when < 10000000: when += time.time() heapq.heappush(self.scheduled, (when, callback, args)) @@ -203,15 +256,19 @@ def run(self): self.run_once() -if hasattr(select, 'kqueue'): - class Pollster(EventLoopMixin, KqueueMixin): - pass -elif hasattr(select, 'poll'): - class Pollster(EventLoopMixin, PollMixin): - pass -else: - class Pollster(EventLoopMixin, SelectMixin): - pass +# Select most appropriate base class for platform. +if hasattr(select, 'kqueue'): # Most BSD + poll_base = KqueueMixin +elif hasattr(select, 'epoll'): # Linux 2.5 and later + poll_base = EPollMixin +elif hasattr(select, 'poll'): # Newer UNIX + poll_base = PollMixin +else: # All UNIX; Windows (for sockets only) + poll_base = SelectMixin + + +class Pollster(EventLoopMixin, poll_base): + pass class ThreadRunner: @@ -225,7 +282,7 @@ def __init__(self, ioloop, max_workers=5): def read_callback(self): # Semi-permanent callback while at least one future is active. assert self.active_count > 0, self.active_count - data = os.read(self.pipe_read_fd, 8192) + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. self.active_count -= len(data) if self.active_count == 0: self.ioloop.remove_reader(self.pipe_read_fd) From bbf25ded90e0af14d15a310be8921cd261c80deb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 22 Oct 2012 10:44:53 -0700 Subject: [PATCH 0027/1502] Apparently epoll() does not need to check for EPOLLHUP. --- main.py | 3 +++ polling.py | 6 ++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 0a013035..52a0fef6 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,9 @@ import socket import time +# Initialize logging before we import polling. +logging.basicConfig(level=logging.INFO) + # Local imports (keep in alphabetic order). import polling diff --git a/polling.py b/polling.py index 28cc03cc..a85d8af3 100644 --- a/polling.py +++ b/polling.py @@ -147,11 +147,11 @@ def poll(self, timeout=None): timeout = -1 # epoll.poll() uses -1 to mean "wait forever". events = [] for fd, eventmask in self._epoll.poll(timeout): - if eventmask & (select.EPOLLIN | select.EPOLLHUP): + if eventmask & select.EPOLLIN: if fd in self.readers: callback, args = self.readers[fd] events.append((fd, eventmask, callback, args)) - if eventmask & (select.EPOLLOUT | select.EPOLLHUP): + if eventmask & select.EPOLLOUT: if fd in self.writers: callback, args = self.writers[fd] events.append((fd, eventmask, callback, args)) @@ -266,6 +266,8 @@ def run(self): else: # All UNIX; Windows (for sockets only) poll_base = SelectMixin +logging.info('Using Pollster base class %r', poll_base.__name__) + class Pollster(EventLoopMixin, poll_base): pass From 0b489ae8392638f96bc6bcf60a5c573b630b3e14 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 09:07:51 -0700 Subject: [PATCH 0028/1502] Made a start with docstrings. Renamed Pollster -> EventLoop. --- main.py | 2 +- polling.py | 92 +++++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 79 insertions(+), 15 deletions(-) diff --git a/main.py b/main.py index 52a0fef6..604766f6 100644 --- a/main.py +++ b/main.py @@ -109,7 +109,7 @@ def unblock_w(self, fd, task, name): self.start(task, name) -ioloop = polling.Pollster() +ioloop = polling.EventLoop() trunner = polling.ThreadRunner(ioloop) sched = Scheduler(ioloop) diff --git a/polling.py b/polling.py index a85d8af3..a33d31db 100644 --- a/polling.py +++ b/polling.py @@ -1,12 +1,31 @@ -"""I/O loop implementations based on kqueue() and poll(). +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which adds functionality for scheduling callbacks, immediately +or at a given time in the future. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select -If both exist, kqueue() is preferred. If neither exists, raise -ImportError. TODO: -- Do we need fd and flags in event tuples? -- Docstrings, unittests. -- Keyword args to callbacks. +- Optimize the various pollster. +- Unittests. """ import collections @@ -19,6 +38,16 @@ class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers (defined as a callback plus optional positional arguments) + for specific file descriptors, and an interface to get a list of + events. There's also an interface to check whether any readers or + writers are currently registered. The readers and writers + attributes are public -- they are simply mappings of file + descriptors to tuples of (callback, args). + """ def __init__(self): super().__init__() @@ -26,25 +55,49 @@ def __init__(self): self.writers = {} # {fd: (callback, args), ...}. def pollable(self): + """Return True if any readers or writers are currently registered.""" return bool(self.readers or self.writers) + # Subclasses are expected to extend the add/remove methods. + def add_reader(self, fd, callback, *args): + """Add or update a reader for a file descriptor.""" self.readers[fd] = (callback, args) def add_writer(self, fd, callback, *args): + """Add or update a writer for a file descriptor.""" self.writers[fd] = (callback, args) def remove_reader(self, fd): + """Remove the reader for a file descriptor.""" del self.readers[fd] def remove_writer(self, fd): + """Remove the writer for a file descriptor.""" del self.writers[fd] def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (an int of float in seconds) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a tuple of the form (fd, flag, callback, args): + fd: the file descriptor + flag: 'r' or 'w' (to distinguish readers from writers) + callback: callback function + args: arguments tuple for callback + """ raise NotImplementedError class SelectMixin(PollsterBase): + """Pollster implementation using select.""" def poll(self, timeout=None): readable, writable, _ = select.select(self.readers, self.writers, @@ -56,6 +109,7 @@ def poll(self, timeout=None): class PollMixin(PollsterBase): + """Pollster implementation using poll.""" def __init__(self): super().__init__() @@ -90,22 +144,23 @@ def remove_writer(self, fd): self._update(fd) def poll(self, timeout=None): - # Timeout is in seconds, but poll() takes milliseconds. :-( + # Timeout is in seconds, but poll() takes milliseconds. msecs = None if timeout is None else int(round(1000 * timeout)) events = [] for fd, flags in self._poll.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: callback, args = self.readers[fd] - events.append((fd, flags, callback, args)) + events.append((fd, 'r', callback, args)) if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: callback, args = self.writers[fd] - events.append((fd, flags, callback, args)) + events.append((fd, 'w', callback, args)) return events class EPollMixin(PollsterBase): + """Pollster implementation using epoll.""" def __init__(self): super().__init__() @@ -150,15 +205,16 @@ def poll(self, timeout=None): if eventmask & select.EPOLLIN: if fd in self.readers: callback, args = self.readers[fd] - events.append((fd, eventmask, callback, args)) + events.append((fd, 'r', callback, args)) if eventmask & select.EPOLLOUT: if fd in self.writers: callback, args = self.writers[fd] - events.append((fd, eventmask, callback, args)) + events.append((fd, 'w', callback, args)) return events class KqueueMixin(PollsterBase): + """Pollster implementation using kqueue.""" def __init__(self): super().__init__() @@ -194,14 +250,22 @@ def poll(self, timeout=None): flag = kev.filter if flag == select.KQ_FILTER_READ and fd in self.readers: callback, args = self.readers[fd] - events.append((fd, flag, callback, args)) + events.append((fd, 'r', callback, args)) elif flag == select.KQ_FILTER_WRITE and fd in self.writers: callback, args = self.writers[fd] - events.append((fd, flag, callback, args)) + events.append((fd, 'w', callback, args)) return events class EventLoopMixin(PollsterBase): + """Event loop functionality. + + This defines call_soon(), call_later(), run_once() and run(). + + This is an abstract class, inheriting from the abstract class + PollsterBase. A concrete class can be formed trivially by + inheriting from any of the pollster mixin classes. + """ def __init__(self): super().__init__() @@ -269,7 +333,7 @@ def run(self): logging.info('Using Pollster base class %r', poll_base.__name__) -class Pollster(EventLoopMixin, poll_base): +class EventLoop(EventLoopMixin, poll_base): pass From 79090b84d07a890d1ef040965f3e2bc3049440a8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 11:53:41 -0700 Subject: [PATCH 0029/1502] Finish docstrings for polling.py. Rename ioloop to eventloop. --- main.py | 26 +++++++------- polling.py | 103 ++++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 99 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index 604766f6..c9c469b3 100644 --- a/main.py +++ b/main.py @@ -47,18 +47,18 @@ class Scheduler: - def __init__(self, ioloop): - self.ioloop = ioloop + def __init__(self, eventloop): + self.eventloop = eventloop self.current = None self.current_name = None def run(self): - self.ioloop.run() + self.eventloop.run() def start(self, task, name=None): if name is None: name = task.__name__ # If it doesn't have one, pass one. - self.ioloop.call_soon(self.run_task, task, name) + self.eventloop.call_soon(self.run_task, task, name) def run_task(self, task, name): try: @@ -87,10 +87,10 @@ def block_io(self, fd, flag): assert flag in ('r', 'w'), repr(flag) task, name = self.block() if flag == 'r': - method = self.ioloop.add_reader + method = self.eventloop.add_reader callback = self.unblock_r else: - method = self.ioloop.add_writer + method = self.eventloop.add_writer callback = self.unblock_w method(fd, callback, fd, task, name) @@ -101,23 +101,23 @@ def block(self): return task, self.current_name def unblock_r(self, fd, task, name): - self.ioloop.remove_reader(fd) + self.eventloop.remove_reader(fd) self.start(task, name) def unblock_w(self, fd, task, name): - self.ioloop.remove_writer(fd) + self.eventloop.remove_writer(fd) self.start(task, name) -ioloop = polling.EventLoop() -trunner = polling.ThreadRunner(ioloop) -sched = Scheduler(ioloop) +eventloop = polling.EventLoop() +threadrunner = polling.ThreadRunner(eventloop) +sched = Scheduler(eventloop) -def call_in_thread(func, *args, **kwds): +def call_in_thread(func, *args): # TODO: Prove there is no race condition here. task, name = sched.block() - future = trunner.submit(func, *args, **kwds) + future = threadrunner.submit(func, *args) future.add_done_callback(lambda _: sched.start(task, name)) yield assert future.done() diff --git a/polling.py b/polling.py index a33d31db..e85a0ecb 100644 --- a/polling.py +++ b/polling.py @@ -5,6 +5,13 @@ proper, which adds functionality for scheduling callbacks, immediately or at a given time in the future. +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + There are several implementations of the pollster part, several using esoteric system calls that exist only on some platforms. These are: @@ -260,11 +267,19 @@ def poll(self, timeout=None): class EventLoopMixin(PollsterBase): """Event loop functionality. - This defines call_soon(), call_later(), run_once() and run(). - This is an abstract class, inheriting from the abstract class PollsterBase. A concrete class can be formed trivially by - inheriting from any of the pollster mixin classes. + inheriting from any of the pollster mixin classes; the concrete + class EventLoop is such a concrete class using the preferred mixin + given the platform. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also inherits public APIs add_reader(), add_writer(), + remove_reader(), remove_writer() from the mixin class. The APIs + pollable() and poll(), implemented by the mix-in, are not part of + the public API. + + This class's instance variables are not part of its API. """ def __init__(self): @@ -273,21 +288,56 @@ def __init__(self): self.scheduled = [] # [(when, callback, args), ...] def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ self.ready.append((callback, args)) def call_later(self, when, callback, *args): - # If when is small enough (~11 days), it's a relative time. + """Arrange for a callback to be called at a given time. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ if when < 10000000: when += time.time() heapq.heappush(self.scheduled, (when, callback, args)) def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ # TODO: Break each of these into smaller pieces. # TODO: Pass in a timeout or deadline or something. # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. # This is the only place where callbacks are actually *called*. # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. while self.ready: callback, args = self.ready.popleft() try: @@ -316,49 +366,68 @@ def run_once(self): self.call_soon(callback, *args) def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ while self.ready or self.scheduled or self.pollable(): self.run_once() -# Select most appropriate base class for platform. -if hasattr(select, 'kqueue'): # Most BSD +# Select the most appropriate base class for the platform. +if hasattr(select, 'kqueue'): poll_base = KqueueMixin -elif hasattr(select, 'epoll'): # Linux 2.5 and later +elif hasattr(select, 'epoll'): poll_base = EPollMixin -elif hasattr(select, 'poll'): # Newer UNIX +elif hasattr(select, 'poll'): poll_base = PollMixin -else: # All UNIX; Windows (for sockets only) +else: poll_base = SelectMixin logging.info('Using Pollster base class %r', poll_base.__name__) class EventLoop(EventLoopMixin, poll_base): - pass + """Event loop implementation using the optimal pollster mixin.""" class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. - def __init__(self, ioloop, max_workers=5): - self.ioloop = ioloop + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, max_workers=5): + self.eventloop = eventloop self.threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) self.pipe_read_fd, self.pipe_write_fd = os.pipe() self.active_count = 0 def read_callback(self): - # Semi-permanent callback while at least one future is active. + """Semi-permanent callback while at least one future is active.""" assert self.active_count > 0, self.active_count data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. self.active_count -= len(data) if self.active_count == 0: - self.ioloop.remove_reader(self.pipe_read_fd) + self.eventloop.remove_reader(self.pipe_read_fd) assert self.active_count >= 0, self.active_count - def submit(self, func, *args, **kwds): + def submit(self, func, *args): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather add a callback to it. + """ assert self.active_count >= 0, self.active_count - future = self.threadpool.submit(func, *args, **kwds) + future = self.threadpool.submit(func, *args) if self.active_count == 0: - self.ioloop.add_reader(self.pipe_read_fd, self.read_callback) + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) self.active_count += 1 def done_callback(future): os.write(self.pipe_write_fd, b'x') From ee072b517b356150e463b45c0150335d31520de7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 17:29:55 -0700 Subject: [PATCH 0030/1502] Move Scheduler and call_in_thread into scheduling.py. --- main.py | 113 ++++++++------------------------------------------ scheduling.py | 97 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 95 deletions(-) create mode 100644 scheduling.py diff --git a/main.py b/main.py index c9c469b3..bc7c6fcf 100644 --- a/main.py +++ b/main.py @@ -8,33 +8,29 @@ Some incomplete laundry lists: TODO: +- Make nice transport and protocol abstractions. - Refactor RawReader -> Connection, with read/write operations. - Take test urls from command line. - Cancellation? - Profiling. +- Docstrings. - Unittests. -PATTERNS TO TRY: -- Wait for all, collate results. -- Wait for first N that are ready. -- Wait until some predicate becomes true. - FUNCTIONALITY: - Connection pool (keep connection open). - Chunked encoding (request and response). - Pipelining, e.g. zlib (request and response). - Automatic encoding/decoding. -- A write() call that isn't a generator. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). """ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). -import collections import errno import logging import re -import select import socket import time @@ -43,85 +39,12 @@ # Local imports (keep in alphabetic order). import polling - - -class Scheduler: - - def __init__(self, eventloop): - self.eventloop = eventloop - self.current = None - self.current_name = None - - def run(self): - self.eventloop.run() - - def start(self, task, name=None): - if name is None: - name = task.__name__ # If it doesn't have one, pass one. - self.eventloop.call_soon(self.run_task, task, name) - - def run_task(self, task, name): - try: - self.current = task - self.current_name = name - next(self.current) - except StopIteration: - pass - except Exception: - logging.exception('Exception in task %r', name) - else: - if self.current is not None: - self.start(task, name) - finally: - self.current = None - self.current_name = None - - def block_r(self, fd): - self.block_io(fd, 'r') - - def block_w(self, fd): - self.block_io(fd, 'w') - - def block_io(self, fd, flag): - assert isinstance(fd, int), repr(fd) - assert flag in ('r', 'w'), repr(flag) - task, name = self.block() - if flag == 'r': - method = self.eventloop.add_reader - callback = self.unblock_r - else: - method = self.eventloop.add_writer - callback = self.unblock_w - method(fd, callback, fd, task, name) - - def block(self): - assert self.current - task = self.current - self.current = None - return task, self.current_name - - def unblock_r(self, fd, task, name): - self.eventloop.remove_reader(fd) - self.start(task, name) - - def unblock_w(self, fd, task, name): - self.eventloop.remove_writer(fd) - self.start(task, name) +import scheduling eventloop = polling.EventLoop() threadrunner = polling.ThreadRunner(eventloop) -sched = Scheduler(eventloop) - - -def call_in_thread(func, *args): - # TODO: Prove there is no race condition here. - task, name = sched.block() - future = threadrunner.submit(func, *args) - future.add_done_callback(lambda _: sched.start(task, name)) - yield - assert future.done() - return future.result() +scheduler = scheduling.Scheduler(eventloop, threadrunner) class RawReader: @@ -133,7 +56,7 @@ def __init__(self, sock): def read(self, n): """Read up to n bytes, blocking at most once.""" assert n >= 0, n - sched.block_r(self.sock.fileno()) + scheduler.block_r(self.sock.fileno()) yield return self.sock.recv(n) @@ -198,7 +121,7 @@ def fillbuffer(self, n): def send(sock, data): ## print('send:', repr(data)) while data: - sched.block_w(sock.fileno()) + scheduler.block_w(sock.fileno()) yield n = sock.send(data) assert 0 <= n <= len(data), (n, len(data)) @@ -220,7 +143,7 @@ def connect(sock, address): except socket.error as err: if err.errno != errno.EINPROGRESS: raise - sched.block_w(sock.fileno()) + scheduler.block_w(sock.fileno()) yield err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: @@ -231,10 +154,10 @@ def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8'): t0 = time.time() if not re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host): - infos = yield from call_in_thread(socket.getaddrinfo, - host, port, socket.AF_INET, - socket.SOCK_STREAM, - socket.SOL_TCP) + infos = yield from scheduler.call_in_thread(socket.getaddrinfo, + host, port, socket.AF_INET, + socket.SOCK_STREAM, + socket.SOL_TCP) else: infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', (host, port))] @@ -318,14 +241,14 @@ def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) gen1 = urlfetch('localhost', 8080, path='/') - sched.start(gen1, 'gen1') + scheduler.start(gen1, 'gen1') gen2 = urlfetch('localhost', 8080, path='/home') - sched.start(gen2, 'gen2') + scheduler.start(gen2, 'gen2') # Fetch python.org home page. gen3 = urlfetch('python.org', 80, path='/') - sched.start(gen3, 'gen3') + scheduler.start(gen3, 'gen3') ## # Fetch many links from python.org (/x.y.z). ## for x in '123': @@ -333,9 +256,9 @@ def doit(): ## path = '/{}.{}'.format(x, y) ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) -## sched.start(g, path) +## scheduler.start(g, path) - sched.run() + scheduler.run() def main(): diff --git a/scheduling.py b/scheduling.py new file mode 100644 index 00000000..c97e35b8 --- /dev/null +++ b/scheduling.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Wait for all, collate results. +- Wait for first N that are ready. +- Wait until some predicate becomes true. +- Run with timeout. + +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging + + +class Scheduler: + + def __init__(self, eventloop, threadrunner): + self.eventloop = eventloop # polling.EventLoop instance. + self.threadrunner = threadrunner # polling.Threadrunner instance. + self.current = None # Current generator. + self.current_name = None # Current generator's name. + + def run(self): + self.eventloop.run() + + def start(self, task, name=None): + if name is None: + name = task.__name__ # If it doesn't have one, pass one. + self.eventloop.call_soon(self.run_task, task, name) + + def run_task(self, task, name): + try: + self.current = task + self.current_name = name + next(self.current) + except StopIteration: + pass + except Exception: + logging.exception('Exception in task %r', name) + else: + if self.current is not None: + self.start(task, name) + finally: + self.current = None + self.current_name = None + + def block_r(self, fd): + self.block_io(fd, 'r') + + def block_w(self, fd): + self.block_io(fd, 'w') + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + task, name = self.block() + if flag == 'r': + method = self.eventloop.add_reader + callback = self.unblock_r + else: + method = self.eventloop.add_writer + callback = self.unblock_w + method(fd, callback, fd, task, name) + + def block(self): + assert self.current + task = self.current + self.current = None + return task, self.current_name + + def unblock_r(self, fd, task, name): + self.eventloop.remove_reader(fd) + self.start(task, name) + + def unblock_w(self, fd, task, name): + self.eventloop.remove_writer(fd) + self.start(task, name) + + def call_in_thread(self, func, *args): + # TODO: Prove there is no race condition here. + task, name = self.block() + future = self.threadrunner.submit(func, *args) + future.add_done_callback(lambda _: self.start(task, name)) + yield + assert future.done() + return future.result() From fef963eb5ea26347f775f6238fe70335f2de3014 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 19:09:47 -0700 Subject: [PATCH 0031/1502] Move socket ops to their own file. Very coarse. --- main.py | 135 ++++++----------------------------------------------- sockets.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+), 121 deletions(-) create mode 100644 sockets.py diff --git a/main.py b/main.py index bc7c6fcf..11b278d1 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,7 @@ Some incomplete laundry lists: TODO: -- Make nice transport and protocol abstractions. -- Refactor RawReader -> Connection, with read/write operations. - Take test urls from command line. -- Cancellation? - Profiling. - Docstrings. - Unittests. @@ -21,14 +18,11 @@ - Chunked encoding (request and response). - Pipelining, e.g. zlib (request and response). - Automatic encoding/decoding. -- A write() call that isn't a generator (needed so you can substitute it - for sys.stderr, pass it to logging.StreamHandler, etc.). """ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). -import errno import logging import re import socket @@ -40,114 +34,13 @@ # Local imports (keep in alphabetic order). import polling import scheduling - +import sockets eventloop = polling.EventLoop() threadrunner = polling.ThreadRunner(eventloop) scheduler = scheduling.Scheduler(eventloop, threadrunner) - -class RawReader: - # TODO: Merge with send() and newsocket() functions. - - def __init__(self, sock): - self.sock = sock - - def read(self, n): - """Read up to n bytes, blocking at most once.""" - assert n >= 0, n - scheduler.block_r(self.sock.fileno()) - yield - return self.sock.recv(n) - - -class BufferedReader: - - def __init__(self, raw, limit=8192): - self.raw = raw - self.limit = limit - self.buffer = b'' - self.eof = False - - def read(self, n): - """Read up to n bytes, blocking at most once.""" - assert n >= 0, n - if not self.buffer and not self.eof: - yield from self.fillbuffer(max(n, self.limit)) - return self.getfrombuffer(n) - - def readexactly(self, n): - """Read exactly n bytes, or until EOF.""" - blocks = [] - count = 0 - while n > count: - block = yield from self.read(n - count) - blocks.append(block) - count += len(block) - return b''.join(blocks) - - def readline(self): - """Read up to newline or limit, whichever comes first.""" - end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. - while not end and not self.eof and len(self.buffer) < self.limit: - anchor = len(self.buffer) - yield from self.fillbuffer(self.limit) - end = self.buffer.find(b'\n', anchor) + 1 - if not end: - end = len(self.buffer) - if end > self.limit: - end = self.limit - return self.getfrombuffer(end) - - def getfrombuffer(self, n): - """Read up to n bytes without blocking.""" - if n >= len(self.buffer): - result, self.buffer = self.buffer, b'' - else: - result, self.buffer = self.buffer[:n], self.buffer[n:] - return result - - def fillbuffer(self, n): - """Fill buffer with one (up to) n bytes from raw reader.""" - assert not self.eof, 'fillbuffer called at eof' - data = yield from self.raw.read(n) -## print('fillbuffer:', repr(data)[:100]) - if data: - self.buffer += data - else: - self.eof = True - - -def send(sock, data): -## print('send:', repr(data)) - while data: - scheduler.block_w(sock.fileno()) - yield - n = sock.send(data) - assert 0 <= n <= len(data), (n, len(data)) - if n == len(data): - break - data = data[n:] - - -def newsocket(af, socktype, proto): - sock = socket.socket(af, socktype, proto) - sock.setblocking(False) - return sock - - -def connect(sock, address): -## print('connect:', address) - try: - sock.connect(address) - except socket.error as err: - if err.errno != errno.EINPROGRESS: - raise - scheduler.block_w(sock.fileno()) - yield - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - raise IOError('Connection refused') +sockets.scheduler = scheduler # TODO: Find a better way. def urlfetch(host, port=80, method='GET', path='/', @@ -166,8 +59,8 @@ def urlfetch(host, port=80, method='GET', path='/', for af, socktype, proto, cname, address in infos: sock = None try: - sock = newsocket(af, socktype, proto) - yield from connect(sock, address) + sock = sockets.newsocket(af, socktype, proto) + yield from sockets.connect(sock, address) break except socket.error as err: if sock is not None: @@ -177,9 +70,9 @@ def urlfetch(host, port=80, method='GET', path='/', else: if exc is not None: raise exc - yield from send(sock, - method.encode(encoding) + b' ' + - path.encode(encoding) + b' HTTP/1.0\r\n') + yield from sockets.send(sock, + method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: kwds = dict(hdrs) else: @@ -189,18 +82,18 @@ def urlfetch(host, port=80, method='GET', path='/', if body is not None: kwds['content_length'] = len(body) for header, value in kwds.items(): - yield from send(sock, - header.replace('_', '-').encode(encoding) + b': ' + - value.encode(encoding) + b'\r\n') + yield from sockets.send(sock, + header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') - yield from send(sock, b'\r\n') + yield from sockets.send(sock, b'\r\n') if body is not None: - yield from send(sock, body) + yield from sockets.send(sock, body) ##sock.shutdown(1) # Close the writing end of the socket. # Read HTTP response line. - raw = RawReader(sock) - buf = BufferedReader(raw) + raw = sockets.RawReader(sock) + buf = sockets.BufferedReader(raw) resp = yield from buf.readline() ## print('resp =', repr(resp)) m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', resp) diff --git a/sockets.py b/sockets.py new file mode 100644 index 00000000..7a6222d5 --- /dev/null +++ b/sockets.py @@ -0,0 +1,120 @@ +"""Socket wrappers to go with scheduling.py. + +TODO: +- Make nice transport and protocol abstractions. +- Refactor RawReader -> Connection, with read/write operations. +- Docstrings. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +- Move getaddrinfo() call here. +""" + +__author__ = 'Guido van Rossum ' + +import errno +import socket + + +class RawReader: + # TODO: Merge with send() and newsocket() functions. + + def __init__(self, sock): + self.sock = sock + + def read(self, n): + """Read up to n bytes, blocking at most once.""" + assert n >= 0, n + scheduler.block_r(self.sock.fileno()) + yield + return self.sock.recv(n) + + +class BufferedReader: + + def __init__(self, raw, limit=8192): + self.raw = raw + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self.fillbuffer(max(n, self.limit)) + return self.getfrombuffer(n) + + def readexactly(self, n): + """Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while n > count: + block = yield from self.read(n - count) + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self.fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self.getfrombuffer(end) + + def getfrombuffer(self, n): + """Read up to n bytes without blocking.""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def fillbuffer(self, n): + """Fill buffer with one (up to) n bytes from raw reader.""" + assert not self.eof, 'fillbuffer called at eof' + data = yield from self.raw.read(n) +## print('fillbuffer:', repr(data)[:100]) + if data: + self.buffer += data + else: + self.eof = True + + +def send(sock, data): +## print('send:', repr(data)) + while data: + scheduler.block_w(sock.fileno()) + yield + n = sock.send(data) + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + + +def newsocket(af, socktype, proto): + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + return sock + + +def connect(sock, address): +## print('connect:', address) + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + scheduler.block_w(sock.fileno()) + yield + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError('Connection refused') + From ad84b824ed18d72679fedcdc64e689d4d3c89944 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 19:36:18 -0700 Subject: [PATCH 0032/1502] Move getaddrinfo wrapper into sockets.py. --- main.py | 26 +------------------------- sockets.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index 11b278d1..d1a187e1 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,6 @@ # Standard library imports (keep in alphabetic order). import logging import re -import socket import time # Initialize logging before we import polling. @@ -46,30 +45,7 @@ def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8'): t0 = time.time() - if not re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host): - infos = yield from scheduler.call_in_thread(socket.getaddrinfo, - host, port, socket.AF_INET, - socket.SOCK_STREAM, - socket.SOL_TCP) - else: - infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', - (host, port))] - assert infos, 'No address info for (%r, %r)' % (host, port) - exc = None - for af, socktype, proto, cname, address in infos: - sock = None - try: - sock = sockets.newsocket(af, socktype, proto) - yield from sockets.connect(sock, address) - break - except socket.error as err: - if sock is not None: - sock.close() - if exc is None: - exc = err - else: - if exc is not None: - raise exc + sock = yield from sockets.create_connection((host, port)) yield from sockets.send(sock, method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') diff --git a/sockets.py b/sockets.py index 7a6222d5..824e8075 100644 --- a/sockets.py +++ b/sockets.py @@ -13,6 +13,7 @@ __author__ = 'Guido van Rossum ' import errno +import re import socket @@ -118,3 +119,37 @@ def connect(sock, address): if err != 0: raise IOError('Connection refused') + +def create_connection(address): + host, port = address + match = re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host) + if match: + d1, d2, d3, d4 = map(int, match.groups()) + if not (0 <= d1 <= 255 and 0 <= d2 <= 255 and + 0 <= d3 <= 255 and 0 <= d4 <= 255): + match = None + if not match: + infos = yield from scheduler.call_in_thread(socket.getaddrinfo, + host, port, socket.AF_INET, + socket.SOCK_STREAM, + socket.SOL_TCP) + else: + infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', + (host, port))] + assert infos, 'No address info for (%r, %r)' % (host, port) + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = newsocket(af, socktype, proto) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + if exc is not None: + raise exc + return sock From e1f7c38294136befd8cd392c25de952a707a8403 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 23 Oct 2012 21:40:29 -0700 Subject: [PATCH 0033/1502] Clean up sockets.py. --- main.py | 33 ++++++------- sockets.py | 140 +++++++++++++++++++++++++++++++++++------------------ 2 files changed, 106 insertions(+), 67 deletions(-) diff --git a/main.py b/main.py index d1a187e1..c7c96916 100644 --- a/main.py +++ b/main.py @@ -45,10 +45,9 @@ def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8'): t0 = time.time() - sock = yield from sockets.create_connection((host, port)) - yield from sockets.send(sock, - method.encode(encoding) + b' ' + - path.encode(encoding) + b' HTTP/1.0\r\n') + trans = yield from sockets.create_transport((host, port)) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: kwds = dict(hdrs) else: @@ -58,23 +57,20 @@ def urlfetch(host, port=80, method='GET', path='/', if body is not None: kwds['content_length'] = len(body) for header, value in kwds.items(): - yield from sockets.send(sock, - header.replace('_', '-').encode(encoding) + - b': ' + value.encode(encoding) + b'\r\n') + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') - yield from sockets.send(sock, b'\r\n') + yield from trans.send(b'\r\n') if body is not None: - yield from sockets.send(sock, body) - ##sock.shutdown(1) # Close the writing end of the socket. + yield from trans.send(body) + trans.shutdown('w') # Close the writing end of the socket. # Read HTTP response line. - raw = sockets.RawReader(sock) - buf = sockets.BufferedReader(raw) - resp = yield from buf.readline() -## print('resp =', repr(resp)) + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', resp) if not m: - sock.close() + trans.close() raise IOError('No valid HTTP response: %r' % resp) http_version, status, message = m.groups() @@ -82,7 +78,7 @@ def urlfetch(host, port=80, method='GET', path='/', headers = [] hdict = {} while True: - line = yield from buf.readline() + line = yield from rdr.readline() if not line.strip(): break m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) @@ -99,10 +95,9 @@ def urlfetch(host, port=80, method='GET', path='/', assert size >= 0, size else: size = 2**20 # Protective limit (1 MB). - data = yield from buf.readexactly(size) - sock.close() # Can this block? + data = yield from rdr.readexactly(size) + trans.close() # Can this block? t1 = time.time() -## print(http_version, status, message, headers, hdict, len(data)) print(host, port, path, status, len(data), '{:.3}'.format(t1-t0)) diff --git a/sockets.py b/sockets.py index 824e8075..4b780ae3 100644 --- a/sockets.py +++ b/sockets.py @@ -1,9 +1,20 @@ """Socket wrappers to go with scheduling.py. +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up an address and return a connected socket for it. +- create_transport(): look up an address and return a connected trahsport. + TODO: -- Make nice transport and protocol abstractions. -- Refactor RawReader -> Connection, with read/write operations. -- Docstrings. +- Improve transport abstraction. +- Make a nice protocol abstraction. - Unittests. - A write() call that isn't a generator (needed so you can substitute it for sys.stderr, pass it to logging.StreamHandler, etc.). @@ -17,37 +28,73 @@ import socket -class RawReader: - # TODO: Merge with send() and newsocket() functions. +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected. + """ def __init__(self, sock): self.sock = sock - def read(self, n): - """Read up to n bytes, blocking at most once.""" + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" assert n >= 0, n scheduler.block_r(self.sock.fileno()) yield return self.sock.recv(n) + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written.""" + while data: + scheduler.block_w(self.sock.fileno()) + yield + n = self.sock.send(data) + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + + def shutdown(self, flag): + """Call shutdown() on the socket. (Not a coroutine.) + + This is like closing one direction. + + The argument must be 'r', 'w' or 'rw'. + """ + if flag == 'r': + flag = socket.SHUT_RD + elif flag == 'w': + flag = socket.SHUT_WR + elif flag == 'rw': + flag = socket.SHUT_RDWR + else: + raise ValueError('flag must be r, w or rw, not %s' % flag) + self.sock.shutdown(flag) + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + class BufferedReader: + """A buffered reader wrapping a transport.""" - def __init__(self, raw, limit=8192): - self.raw = raw + def __init__(self, trans, limit=8192): + self.trans = trans self.limit = limit self.buffer = b'' self.eof = False def read(self, n): - """Read up to n bytes, blocking at most once.""" + """COROUTINE: Read up to n bytes, blocking at most once.""" assert n >= 0, n if not self.buffer and not self.eof: - yield from self.fillbuffer(max(n, self.limit)) - return self.getfrombuffer(n) + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) def readexactly(self, n): - """Read exactly n bytes, or until EOF.""" + """COUROUTINE: Read exactly n bytes, or until EOF.""" blocks = [] count = 0 while n > count: @@ -57,57 +104,38 @@ def readexactly(self, n): return b''.join(blocks) def readline(self): - """Read up to newline or limit, whichever comes first.""" + """COROUTINE: Read up to newline or limit, whichever comes first.""" end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. while not end and not self.eof and len(self.buffer) < self.limit: anchor = len(self.buffer) - yield from self.fillbuffer(self.limit) + yield from self._fillbuffer(self.limit) end = self.buffer.find(b'\n', anchor) + 1 if not end: end = len(self.buffer) if end > self.limit: end = self.limit - return self.getfrombuffer(end) + return self._getfrombuffer(end) - def getfrombuffer(self, n): - """Read up to n bytes without blocking.""" + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" if n >= len(self.buffer): result, self.buffer = self.buffer, b'' else: result, self.buffer = self.buffer[:n], self.buffer[n:] return result - def fillbuffer(self, n): - """Fill buffer with one (up to) n bytes from raw reader.""" - assert not self.eof, 'fillbuffer called at eof' - data = yield from self.raw.read(n) -## print('fillbuffer:', repr(data)[:100]) + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) if data: self.buffer += data else: self.eof = True -def send(sock, data): -## print('send:', repr(data)) - while data: - scheduler.block_w(sock.fileno()) - yield - n = sock.send(data) - assert 0 <= n <= len(data), (n, len(data)) - if n == len(data): - break - data = data[n:] - - -def newsocket(af, socktype, proto): - sock = socket.socket(af, socktype, proto) - sock.setblocking(False) - return sock - - def connect(sock, address): -## print('connect:', address) + """COROUTINE: Connect a socket to an address.""" try: sock.connect(address) except socket.error as err: @@ -117,10 +145,22 @@ def connect(sock, address): yield err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: - raise IOError('Connection refused') + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduler.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos def create_connection(address): + """COROUTINE: Look up address and create a socket connected to it.""" host, port = address match = re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host) if match: @@ -129,10 +169,7 @@ def create_connection(address): 0 <= d3 <= 255 and 0 <= d4 <= 255): match = None if not match: - infos = yield from scheduler.call_in_thread(socket.getaddrinfo, - host, port, socket.AF_INET, - socket.SOCK_STREAM, - socket.SOL_TCP) + infos = yield from getaddrinfo(host, port, socktype=socket.SOCK_STREAM) else: infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', (host, port))] @@ -141,7 +178,8 @@ def create_connection(address): for af, socktype, proto, cname, address in infos: sock = None try: - sock = newsocket(af, socktype, proto) + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) yield from connect(sock, address) break except socket.error as err: @@ -153,3 +191,9 @@ def create_connection(address): if exc is not None: raise exc return sock + + +def create_transport(address): + """COROUTINE: Look up address and create a transport connected to it.""" + sock = yield from create_connection(address) + return SocketTransport(sock) From 906e4bcbf82d12f82bb079a9bde54dbec8ea60af Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 06:54:28 -0700 Subject: [PATCH 0034/1502] Add sleep() to scheduler. --- scheduling.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/scheduling.py b/scheduling.py index c97e35b8..d7721f50 100644 --- a/scheduling.py +++ b/scheduling.py @@ -95,3 +95,8 @@ def call_in_thread(self, func, *args): yield assert future.done() return future.result() + + def sleep(self, secs): + task, name = self.block() + self.eventloop.call_later(secs, self.start, task, name) + yield From 07a579429127e76bfe86f72da7b4d2ffad94231a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 08:32:06 -0700 Subject: [PATCH 0035/1502] Add a timeout feature to start(). A bit messy. --- main.py | 6 +-- scheduling.py | 115 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 79 insertions(+), 42 deletions(-) diff --git a/main.py b/main.py index c7c96916..b23223ef 100644 --- a/main.py +++ b/main.py @@ -105,14 +105,14 @@ def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) gen1 = urlfetch('localhost', 8080, path='/') - scheduler.start(gen1, 'gen1') + scheduler.start(gen1, 'gen1', timeout=2) gen2 = urlfetch('localhost', 8080, path='/home') - scheduler.start(gen2, 'gen2') + scheduler.start(gen2, 'gen2', timeout=2) # Fetch python.org home page. gen3 = urlfetch('python.org', 80, path='/') - scheduler.start(gen3, 'gen3') + scheduler.start(gen3, 'gen3', timeout=2) ## # Fetch many links from python.org (/x.y.z). ## for x in '123': diff --git a/scheduling.py b/scheduling.py index d7721f50..3f00e26d 100644 --- a/scheduling.py +++ b/scheduling.py @@ -21,39 +21,64 @@ # Standard library imports (keep in alphabetic order). import logging +import time -class Scheduler: +class TimeoutExpired(Exception): + pass - def __init__(self, eventloop, threadrunner): - self.eventloop = eventloop # polling.EventLoop instance. - self.threadrunner = threadrunner # polling.Threadrunner instance. - self.current = None # Current generator. - self.current_name = None # Current generator's name. - def run(self): - self.eventloop.run() +class Task: + """Lightweight wrapper around a generator.""" - def start(self, task, name=None): - if name is None: - name = task.__name__ # If it doesn't have one, pass one. - self.eventloop.call_soon(self.run_task, task, name) + def __init__(self, sched, gen, name=None, *, timeout=None): + self.sched = sched + self.gen = gen + self.name = name or gen.__name__ + if timeout is not None and timeout < 1000000: + timeout += time.time() + self.timeout = timeout + self.alive = True - def run_task(self, task, name): + def run(self): + if not self.alive: + return + self.sched.current = self try: - self.current = task - self.current_name = name - next(self.current) + if self.timeout is not None and self.timeout < time.time(): + self.gen.throw(TimeoutExpired) + else: + next(self.gen) except StopIteration: - pass + self.alive = False except Exception: - logging.exception('Exception in task %r', name) + self.alive = False + logging.exception('Uncaught exception in task %r', self.name) + except BaseException: + self.alive = False + raise else: - if self.current is not None: - self.start(task, name) + if self.sched.current is not None: + self.start() finally: - self.current = None - self.current_name = None + self.sched.current = None + + def start(self): + self.sched.eventloop.call_soon(self.run) + + +class Scheduler: + + def __init__(self, eventloop, threadrunner): + self.eventloop = eventloop # polling.EventLoop instance. + self.threadrunner = threadrunner # polling.Threadrunner instance. + self.current = None # Current Task. + + def run(self): + self.eventloop.run() + + def start(self, gen, name=None, *, timeout=None): + Task(self, gen, name, timeout=timeout).start() def block_r(self, fd): self.block_io(fd, 'r') @@ -64,39 +89,51 @@ def block_w(self, fd): def block_io(self, fd, flag): assert isinstance(fd, int), repr(fd) assert flag in ('r', 'w'), repr(flag) - task, name = self.block() + task = self.block() if flag == 'r': - method = self.eventloop.add_reader - callback = self.unblock_r + self.eventloop.add_reader(fd, self.unblock_io, fd, flag, task) else: - method = self.eventloop.add_writer - callback = self.unblock_w - method(fd, callback, fd, task, name) + self.eventloop.add_writer(fd, self.unblock_io, fd, flag, task) + if task.timeout: + self.eventloop.call_later(task.timeout, + self.unblock_timeout, fd, flag, task) def block(self): assert self.current task = self.current self.current = None - return task, self.current_name + return task - def unblock_r(self, fd, task, name): - self.eventloop.remove_reader(fd) - self.start(task, name) + def unblock_io(self, fd, flag, task): + if flag == 'r': + self.eventloop.remove_reader(fd) + else: + self.eventloop.remove_writer(fd) + task.start() - def unblock_w(self, fd, task, name): - self.eventloop.remove_writer(fd) - self.start(task, name) + def unblock_timeout(self, fd, flag, task): + if not task.alive: + return + if flag == 'r': + if fd in self.eventloop.readers: + self.eventloop.remove_reader(fd) + else: + if fd in self.eventloop.writers: + self.eventloop.remove_writer(fd) + if task.alive: + task.timeout = 0 # Force it to cancel + task.start() def call_in_thread(self, func, *args): # TODO: Prove there is no race condition here. - task, name = self.block() + task = self.block() future = self.threadrunner.submit(func, *args) - future.add_done_callback(lambda _: self.start(task, name)) + future.add_done_callback(lambda _: task.start()) yield assert future.done() return future.result() def sleep(self, secs): - task, name = self.block() - self.eventloop.call_later(secs, self.start, task, name) + task = self.block() + self.eventloop.call_later(secs, task.start) yield From e1f62d99d81a5c306e4a6192a51639a84718321d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 08:40:56 -0700 Subject: [PATCH 0036/1502] Add timeout support for call_in_thread(). --- scheduling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scheduling.py b/scheduling.py index 3f00e26d..3e0ad328 100644 --- a/scheduling.py +++ b/scheduling.py @@ -129,7 +129,11 @@ def call_in_thread(self, func, *args): task = self.block() future = self.threadrunner.submit(func, *args) future.add_done_callback(lambda _: task.start()) - yield + try: + yield + except TimeoutExpired: + future.cancel() + raise assert future.done() return future.result() From 82e77d30fef2e7d74f508d860fcceaa2f876a3ba Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 08:43:17 -0700 Subject: [PATCH 0037/1502] Reuse concurrent.futures.TimeoutError. --- scheduling.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/scheduling.py b/scheduling.py index 3e0ad328..8a58254b 100644 --- a/scheduling.py +++ b/scheduling.py @@ -20,14 +20,11 @@ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). +from concurrent.futures import TimeoutError import logging import time -class TimeoutExpired(Exception): - pass - - class Task: """Lightweight wrapper around a generator.""" @@ -46,7 +43,7 @@ def run(self): self.sched.current = self try: if self.timeout is not None and self.timeout < time.time(): - self.gen.throw(TimeoutExpired) + self.gen.throw(TimeoutError) else: next(self.gen) except StopIteration: @@ -131,7 +128,7 @@ def call_in_thread(self, func, *args): future.add_done_callback(lambda _: task.start()) try: yield - except TimeoutExpired: + except TimeoutError: future.cancel() raise assert future.done() From 1a3237e5206c0cc044dea4171f82107876956d77 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 09:23:40 -0700 Subject: [PATCH 0038/1502] Add cancellation option to call_later(). Use it for task timeouts. --- polling.py | 43 +++++++++++++++++++++++++++++++++++++------ scheduling.py | 31 +++++++++++++++++++------------ 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/polling.py b/polling.py index e85a0ecb..6078fd43 100644 --- a/polling.py +++ b/polling.py @@ -264,6 +264,28 @@ def poll(self, timeout=None): return events +class DelayedCall: + """Object returned by call_later(); can be used to cancel the call.""" + + def __init__(self, when, callback, args): + self.when = when + self.callback = callback + self.args = args + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + class EventLoopMixin(PollsterBase): """Event loop functionality. @@ -302,6 +324,9 @@ def call_soon(self, callback, *args): def call_later(self, when, callback, *args): """Arrange for a callback to be called at a given time. + Return an object with a cancel() method that can be used to + cancel the call. + The time can be an int or float, expressed in seconds. If when is small enough (~11 days), it's assumed to be a @@ -318,7 +343,9 @@ def call_later(self, when, callback, *args): """ if when < 10000000: when += time.time() - heapq.heappush(self.scheduled, (when, callback, args)) + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall def run_once(self): """Run one full iteration of the event loop. @@ -346,10 +373,14 @@ def run_once(self): logging.exception('Exception in callback %s %r', callback, args) + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + # Inspect the poll queue. if self.pollable(): if self.scheduled: - when, _, _ = self.scheduled[0] + when = self.scheduled[0].when timeout = max(0, when - time.time()) else: timeout = None @@ -359,11 +390,11 @@ def run_once(self): # Handle 'later' callbacks that are ready. while self.scheduled: - when, _, _ = self.scheduled[0] - if when > time.time(): + dcall = self.scheduled[0] + if dcall.when > time.time(): break - when, callback, args = heapq.heappop(self.scheduled) - self.call_soon(callback, *args) + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) def run(self): """Run the event loop until there is no work left to do. diff --git a/scheduling.py b/scheduling.py index 8a58254b..c35c8cea 100644 --- a/scheduling.py +++ b/scheduling.py @@ -61,7 +61,8 @@ def run(self): self.sched.current = None def start(self): - self.sched.eventloop.call_soon(self.run) + if self.alive: + self.sched.eventloop.call_soon(self.run) class Scheduler: @@ -87,13 +88,17 @@ def block_io(self, fd, flag): assert isinstance(fd, int), repr(fd) assert flag in ('r', 'w'), repr(flag) task = self.block() + dcall = None + if task.timeout: + dcall = self.eventloop.call_later(task.timeout, + self.unblock_timeout, + fd, flag, task) if flag == 'r': - self.eventloop.add_reader(fd, self.unblock_io, fd, flag, task) + self.eventloop.add_reader(fd, self.unblock_io, + fd, flag, task, dcall) else: - self.eventloop.add_writer(fd, self.unblock_io, fd, flag, task) - if task.timeout: - self.eventloop.call_later(task.timeout, - self.unblock_timeout, fd, flag, task) + self.eventloop.add_writer(fd, self.unblock_io, + fd, flag, task, dcall) def block(self): assert self.current @@ -101,7 +106,9 @@ def block(self): self.current = None return task - def unblock_io(self, fd, flag, task): + def unblock_io(self, fd, flag, task, dcall): + if dcall is not None: + dcall.cancel() if flag == 'r': self.eventloop.remove_reader(fd) else: @@ -109,17 +116,17 @@ def unblock_io(self, fd, flag, task): task.start() def unblock_timeout(self, fd, flag, task): - if not task.alive: - return + # NOTE: Due to the call_soon() semantics, we can't guarantee + # that unblock_timeout() isn't called *after* unblock_io() has + # already been called. So we must write this defensively. if flag == 'r': if fd in self.eventloop.readers: self.eventloop.remove_reader(fd) else: if fd in self.eventloop.writers: self.eventloop.remove_writer(fd) - if task.alive: - task.timeout = 0 # Force it to cancel - task.start() + task.timeout = 0 # Force it to cancel. + task.start() def call_in_thread(self, func, *args): # TODO: Prove there is no race condition here. From fbb304d5eea6203823f90cf03a44336bfaf4dea3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 10:20:10 -0700 Subject: [PATCH 0039/1502] Use time -p for friendlier output. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 43f15dbc..2e30d535 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ test: - time python3.3 main.py + time -p python3.3 main.py profile: python3.3 -m profile -s time main.py From a9c1218a110bf103b43fd1002096474b8350fe2c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Oct 2012 17:14:41 -0700 Subject: [PATCH 0040/1502] Add TODO about checking for races. --- scheduling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scheduling.py b/scheduling.py index c35c8cea..86e96030 100644 --- a/scheduling.py +++ b/scheduling.py @@ -119,6 +119,7 @@ def unblock_timeout(self, fd, flag, task): # NOTE: Due to the call_soon() semantics, we can't guarantee # that unblock_timeout() isn't called *after* unblock_io() has # already been called. So we must write this defensively. + # TODO: Analyse this further for race conditions etc. if flag == 'r': if fd in self.eventloop.readers: self.eventloop.remove_reader(fd) From f9a45bc0e349aba78591e87ace25017b15b0f616 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:05:21 -0700 Subject: [PATCH 0041/1502] Add temporary debug prints around poll() call. --- polling.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/polling.py b/polling.py index 6078fd43..e1dab505 100644 --- a/polling.py +++ b/polling.py @@ -384,7 +384,11 @@ def run_once(self): timeout = max(0, when - time.time()) else: timeout = None + t0 = time.time() events = self.poll(timeout) + t1 = time.time() + print('poll' if timeout is None else 'poll {:.3f}'.format(timeout), + 'took {:.3f} seconds'.format(t1-t0)) for fd, flag, callback, args in events: self.call_soon(callback, *args) From 9ccd89c20b0921bed1721c2ed7b734d0819229d3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:15:37 -0700 Subject: [PATCH 0042/1502] Fix the regex to match IPv4 addresses. --- sockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sockets.py b/sockets.py index 4b780ae3..1429e3ee 100644 --- a/sockets.py +++ b/sockets.py @@ -162,7 +162,7 @@ def getaddrinfo(host, port, af=0, socktype=0, proto=0): def create_connection(address): """COROUTINE: Look up address and create a socket connected to it.""" host, port = address - match = re.match(r'(\d+)(\.\d+)(\.\d+)(\.\d+)\Z', host) + match = re.match(r'(\d+)\.(\d+)\.(\d+)\.(\d+)\Z', host) if match: d1, d2, d3, d4 = map(int, match.groups()) if not (0 <= d1 <= 255 and 0 <= d2 <= 255 and From 2c4e3a989244d584e9bab896de72aefc21eea8c4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:33:15 -0700 Subject: [PATCH 0043/1502] Use debug level logging to print poll timing. --- polling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polling.py b/polling.py index e1dab505..1f3b161c 100644 --- a/polling.py +++ b/polling.py @@ -387,8 +387,8 @@ def run_once(self): t0 = time.time() events = self.poll(timeout) t1 = time.time() - print('poll' if timeout is None else 'poll {:.3f}'.format(timeout), - 'took {:.3f} seconds'.format(t1-t0)) + argstr = '' if timeout is None else ' %.3f' % timeout + logging.debug('poll%s took %.3f seconds', argstr, t1-t0) for fd, flag, callback, args in events: self.call_soon(callback, *args) From 43fafe713e7d399e65e6ff29264ea39d49ea6df3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:33:47 -0700 Subject: [PATCH 0044/1502] Evolve Task more towards a future-like object. --- main.py | 38 ++++++++++++++++++++++++++++---------- scheduling.py | 29 +++++++++++++++++++++++++---- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/main.py b/main.py index b23223ef..7eb88127 100644 --- a/main.py +++ b/main.py @@ -26,9 +26,19 @@ import logging import re import time +import sys # Initialize logging before we import polling. -logging.basicConfig(level=logging.INFO) +# TODO: Change polling.py so we can do this in main(). +if '-d' in sys.argv: + level = logging.DEBUG +elif '-v' in sys.argv: + level = logging.INFO +elif '-q' in sys.argv: + level = logging.ERROR +else: + level = logging.WARN +logging.basicConfig(level=level) # Local imports (keep in alphabetic order). import polling @@ -98,21 +108,22 @@ def urlfetch(host, port=80, method='GET', path='/', data = yield from rdr.readexactly(size) trans.close() # Can this block? t1 = time.time() - print(host, port, path, status, len(data), '{:.3}'.format(t1-t0)) + return (host, port, path, status, len(data), '{:.3}'.format(t1-t0)) def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) - gen1 = urlfetch('localhost', 8080, path='/') - scheduler.start(gen1, 'gen1', timeout=2) - - gen2 = urlfetch('localhost', 8080, path='/home') - scheduler.start(gen2, 'gen2', timeout=2) + task1 = scheduler.newtask(urlfetch('localhost', 8080, path='/'), + 'task1', timeout=2) + task2 = scheduler.newtask(urlfetch('localhost', 8080, path='/home'), + 'task2', timeout=2) # Fetch python.org home page. - gen3 = urlfetch('python.org', 80, path='/') - scheduler.start(gen3, 'gen3', timeout=2) + task3 = scheduler.newtask(urlfetch('python.org', 80, path='/'), + 'task3', timeout=2) + + tasks = {task1, task2, task3} ## # Fetch many links from python.org (/x.y.z). ## for x in '123': @@ -120,9 +131,16 @@ def doit(): ## path = '/{}.{}'.format(x, y) ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) -## scheduler.start(g, path) +## t = scheduler.newtask(g, path, timeout=2) +## tasks.add(t) +## print(tasks) + for t in tasks: + t.start() scheduler.run() +## print(tasks) + for t in tasks: + print(t.name + ':', t.exception or t.result) def main(): diff --git a/scheduling.py b/scheduling.py index 86e96030..d43a3530 100644 --- a/scheduling.py +++ b/scheduling.py @@ -26,7 +26,14 @@ class Task: - """Lightweight wrapper around a generator.""" + """Lightweight wrapper around a generator. + + This is a bit like a Future, but with a different interface. + + TODO: + - cancellation. + - wait for result. + """ def __init__(self, sched, gen, name=None, *, timeout=None): self.sched = sched @@ -36,6 +43,12 @@ def __init__(self, sched, gen, name=None, *, timeout=None): timeout += time.time() self.timeout = timeout self.alive = True + self.result = None + self.exception = None + + def __repr__(self): + return 'Task<%r, timeout=%s>(alive=%r, result=%r, exception=%r)' % ( + self.name, self.timeout, self.alive, self.result, self.exception) def run(self): if not self.alive: @@ -46,12 +59,15 @@ def run(self): self.gen.throw(TimeoutError) else: next(self.gen) - except StopIteration: + except StopIteration as exc: + self.result = exc.value self.alive = False - except Exception: + except Exception as exc: + self.exception = exc self.alive = False logging.exception('Uncaught exception in task %r', self.name) except BaseException: + self.exception = exc self.alive = False raise else: @@ -75,8 +91,13 @@ def __init__(self, eventloop, threadrunner): def run(self): self.eventloop.run() + def newtask(self, gen, name=None, *, timeout=None): + return Task(self, gen, name, timeout=timeout) + def start(self, gen, name=None, *, timeout=None): - Task(self, gen, name, timeout=timeout).start() + task = self.newtask(gen, name, timeout=timeout) + task.start() + return task def block_r(self, fd): self.block_io(fd, 'r') From 14fe2b930113ea5c9b7168a16c5a361f49160b58 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:40:35 -0700 Subject: [PATCH 0045/1502] Generalize call_in_thread: optionally pass in an executor. --- polling.py | 6 ++++-- scheduling.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/polling.py b/polling.py index 1f3b161c..c8220349 100644 --- a/polling.py +++ b/polling.py @@ -453,14 +453,16 @@ def read_callback(self): self.eventloop.remove_reader(self.pipe_read_fd) assert self.active_count >= 0, self.active_count - def submit(self, func, *args): + def submit(self, func, *args, executor=None): """Submit a function to the thread pool. This returns a concurrent.futures.Future instance. The caller should not wait for that, but rather add a callback to it. """ + if executor is None: + executor = self.threadpool assert self.active_count >= 0, self.active_count - future = self.threadpool.submit(func, *args) + future = executor.submit(func, *args) if self.active_count == 0: self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) self.active_count += 1 diff --git a/scheduling.py b/scheduling.py index d43a3530..638ef90f 100644 --- a/scheduling.py +++ b/scheduling.py @@ -150,10 +150,10 @@ def unblock_timeout(self, fd, flag, task): task.timeout = 0 # Force it to cancel. task.start() - def call_in_thread(self, func, *args): + def call_in_thread(self, func, *args, executor=None): # TODO: Prove there is no race condition here. task = self.block() - future = self.threadrunner.submit(func, *args) + future = self.threadrunner.submit(func, *args, executor=executor) future.add_done_callback(lambda _: task.start()) try: yield From 9825748f419aa5f98ec902ef0d46cbdb0e88f0d1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 08:48:44 -0700 Subject: [PATCH 0046/1502] Use os.utimes to print real/user/system time. --- Makefile | 2 +- main.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 2e30d535..db8ce1f6 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ test: - time -p python3.3 main.py + python3.3 main.py -v profile: python3.3 -m profile -s time main.py diff --git a/main.py b/main.py index 7eb88127..4a9c1e48 100644 --- a/main.py +++ b/main.py @@ -24,6 +24,7 @@ # Standard library imports (keep in alphabetic order). import logging +import os import re import time import sys @@ -143,8 +144,18 @@ def doit(): print(t.name + ':', t.exception or t.result) +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + def main(): + t0 = time.time() doit() + t1 = time.time() + logtimes(t1-t0) if __name__ == '__main__': From eac438105d977bedb3e0a20ae2387a62cfafdfd5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 09:14:35 -0700 Subject: [PATCH 0047/1502] Tweak assignment order in except clauses. --- scheduling.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scheduling.py b/scheduling.py index 638ef90f..d4c3cb6b 100644 --- a/scheduling.py +++ b/scheduling.py @@ -60,15 +60,15 @@ def run(self): else: next(self.gen) except StopIteration as exc: - self.result = exc.value self.alive = False + self.result = exc.value except Exception as exc: - self.exception = exc self.alive = False + self.exception = exc logging.exception('Uncaught exception in task %r', self.name) except BaseException: - self.exception = exc self.alive = False + self.exception = exc raise else: if self.sched.current is not None: From 9c8104bd57b3c79766dc5fac477b853f6a2dce70 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 14:10:00 -0700 Subject: [PATCH 0048/1502] Add ssl transport. Kill shutdown() call. --- main.py | 26 +++++++++++----- sockets.py | 88 +++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 89 insertions(+), 25 deletions(-) diff --git a/main.py b/main.py index 4a9c1e48..d8e9d7b2 100644 --- a/main.py +++ b/main.py @@ -54,9 +54,14 @@ def urlfetch(host, port=80, method='GET', path='/', - body=None, hdrs=None, encoding='utf-8'): + body=None, hdrs=None, encoding='utf-8', ssl=None): t0 = time.time() - trans = yield from sockets.create_transport((host, port)) + if ssl is None: + ssl = (port == 443) + if ssl: + trans = yield from sockets.create_ssl_transport((host, port)) + else: + trans = yield from sockets.create_transport((host, port)) yield from trans.send(method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: @@ -74,12 +79,12 @@ def urlfetch(host, port=80, method='GET', path='/', yield from trans.send(b'\r\n') if body is not None: yield from trans.send(body) - trans.shutdown('w') # Close the writing end of the socket. # Read HTTP response line. rdr = sockets.BufferedReader(trans) resp = yield from rdr.readline() - m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', resp) + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) if not m: trans.close() raise IOError('No valid HTTP response: %r' % resp) @@ -109,23 +114,28 @@ def urlfetch(host, port=80, method='GET', path='/', data = yield from rdr.readexactly(size) trans.close() # Can this block? t1 = time.time() - return (host, port, path, status, len(data), '{:.3}'.format(t1-t0)) + return (host, port, path, int(status), len(data), round(t1-t0, 3)) def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) task1 = scheduler.newtask(urlfetch('localhost', 8080, path='/'), - 'task1', timeout=2) + 'root', timeout=2) task2 = scheduler.newtask(urlfetch('localhost', 8080, path='/home'), - 'task2', timeout=2) + 'home', timeout=2) # Fetch python.org home page. task3 = scheduler.newtask(urlfetch('python.org', 80, path='/'), - 'task3', timeout=2) + 'python', timeout=2) tasks = {task1, task2, task3} + # Fetch XKCD home page using SSL. + task4 = scheduler.newtask(urlfetch('xkcd.com', 443, path='/'), + 'xkcd', timeout=2) + tasks.add(task4) + ## # Fetch many links from python.org (/x.y.z). ## for x in '123': ## for y in '0123456789': diff --git a/sockets.py b/sockets.py index 1429e3ee..ac3748e8 100644 --- a/sockets.py +++ b/sockets.py @@ -26,12 +26,13 @@ import errno import re import socket +import ssl class SocketTransport: """Transport wrapping a socket. - The socket must already be connected. + The socket must already be connected in non-blocking mode. """ def __init__(self, sock): @@ -55,26 +56,70 @@ def send(self, data): break data = data[n:] - def shutdown(self, flag): - """Call shutdown() on the socket. (Not a coroutine.) + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SSLTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + scheduler.block_r(self.sslsock.fileno()) + yield + except ssl.SSLWantWriteError: + scheduler.block_w(self.sslsock.fileno()) + yield + else: + break - This is like closing one direction. + def recv(self, n): + """COROUTINE: Read up to n bytes. - The argument must be 'r', 'w' or 'rw'. + This blocks until at least one byte is read, or until EOF. """ - if flag == 'r': - flag = socket.SHUT_RD - elif flag == 'w': - flag = socket.SHUT_WR - elif flag == 'rw': - flag = socket.SHUT_RDWR - else: - raise ValueError('flag must be r, w or rw, not %s' % flag) - self.sock.shutdown(flag) + while True: + try: + return self.sslsock.recv(n) + except socket.error as err: + scheduler.block_r(self.sslsock.fileno()) + yield + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except socket.error as err: + scheduler.block_w(self.sslsock.fileno()) + yield + if n == len(data): + break + data = data[n:] def close(self): - """Close the socket. (Not a coroutine.)""" - self.sock.close() + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... class BufferedReader: @@ -169,7 +214,8 @@ def create_connection(address): 0 <= d3 <= 255 and 0 <= d4 <= 255): match = None if not match: - infos = yield from getaddrinfo(host, port, socktype=socket.SOCK_STREAM) + infos = yield from getaddrinfo(host, port, + socktype=socket.SOCK_STREAM) else: infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', (host, port))] @@ -197,3 +243,11 @@ def create_transport(address): """COROUTINE: Look up address and create a transport connected to it.""" sock = yield from create_connection(address) return SocketTransport(sock) + + +def create_ssl_transport(address): + """COROUTINE: Look up address and create an SSL transport connected.""" + rawsock = yield from create_connection(address) + trans = SSLTransport(rawsock) + yield from trans.do_handshake() + return trans From 32d0fc30e751c06a87e0616f060659b441fdcf17 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 14:32:26 -0700 Subject: [PATCH 0049/1502] Improve transport abstraction. Add address family option. --- main.py | 15 ++++++--------- sockets.py | 32 +++++++++++++++----------------- 2 files changed, 21 insertions(+), 26 deletions(-) diff --git a/main.py b/main.py index d8e9d7b2..f4e64d04 100644 --- a/main.py +++ b/main.py @@ -27,6 +27,7 @@ import os import re import time +import socket import sys # Initialize logging before we import polling. @@ -54,14 +55,9 @@ def urlfetch(host, port=80, method='GET', path='/', - body=None, hdrs=None, encoding='utf-8', ssl=None): + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): t0 = time.time() - if ssl is None: - ssl = (port == 443) - if ssl: - trans = yield from sockets.create_ssl_transport((host, port)) - else: - trans = yield from sockets.create_transport((host, port)) + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) yield from trans.send(method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: @@ -131,8 +127,9 @@ def doit(): tasks = {task1, task2, task3} - # Fetch XKCD home page using SSL. - task4 = scheduler.newtask(urlfetch('xkcd.com', 443, path='/'), + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduler.newtask(urlfetch('xkcd.com', 443, path='/', + af=socket.AF_INET), 'xkcd', timeout=2) tasks.add(task4) diff --git a/sockets.py b/sockets.py index ac3748e8..85b70052 100644 --- a/sockets.py +++ b/sockets.py @@ -3,14 +3,15 @@ Classes: - SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. - BufferedReader: a buffer wrapping the read end of a transport. Functions (all coroutines): - connect(): connect a socket. - getaddrinfo(): look up an address. -- create_connection(): look up an address and return a connected socket for it. -- create_transport(): look up an address and return a connected trahsport. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. TODO: - Improve transport abstraction. @@ -18,7 +19,6 @@ - Unittests. - A write() call that isn't a generator (needed so you can substitute it for sys.stderr, pass it to logging.StreamHandler, etc.). -- Move getaddrinfo() call here. """ __author__ = 'Guido van Rossum ' @@ -61,7 +61,7 @@ def close(self): self.sock.close() -class SSLTransport: +class SslTransport: """Transport wrapping a socket in SSL. The socket must already be connected at the TCP level in @@ -204,9 +204,8 @@ def getaddrinfo(host, port, af=0, socktype=0, proto=0): return infos -def create_connection(address): +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM): """COROUTINE: Look up address and create a socket connected to it.""" - host, port = address match = re.match(r'(\d+)\.(\d+)\.(\d+)\.(\d+)\Z', host) if match: d1, d2, d3, d4 = map(int, match.groups()) @@ -215,7 +214,7 @@ def create_connection(address): match = None if not match: infos = yield from getaddrinfo(host, port, - socktype=socket.SOCK_STREAM) + af=af, socktype=socket.SOCK_STREAM) else: infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', (host, port))] @@ -239,15 +238,14 @@ def create_connection(address): return sock -def create_transport(address): +def create_transport(host, port, af=0, ssl=None): """COROUTINE: Look up address and create a transport connected to it.""" - sock = yield from create_connection(address) - return SocketTransport(sock) - - -def create_ssl_transport(address): - """COROUTINE: Look up address and create an SSL transport connected.""" - rawsock = yield from create_connection(address) - trans = SSLTransport(rawsock) - yield from trans.do_handshake() + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af=af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) return trans From 7f0d5f4804160a1f3b78e9e2b66ee6c652e6de30 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Oct 2012 16:11:43 -0700 Subject: [PATCH 0050/1502] Add longlines.py and "make check" target. --- Makefile | 11 ++++++++--- longlines.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 longlines.py diff --git a/Makefile b/Makefile index db8ce1f6..8613d101 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,13 @@ +PYTHON=python3.3 + test: - python3.3 main.py -v + $(PYTHON) main.py -v profile: - python3.3 -m profile -s time main.py + $(PYTHON) -m profile -s time main.py time: - python3.3 p3time.py + $(PYTHON) p3time.py + +check: + $(PYTHON) longlines.py diff --git a/longlines.py b/longlines.py new file mode 100644 index 00000000..f0aa9a66 --- /dev/null +++ b/longlines.py @@ -0,0 +1,40 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() From 245aacf40c19560264e3bc311d2af2105b232765 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 11:22:47 -0700 Subject: [PATCH 0051/1502] Radical refactoring: no more Scheduler class; use thread-local Context. --- main.py | 65 ++++++++-------- polling.py | 20 +++-- scheduling.py | 208 ++++++++++++++++++++++++++++---------------------- sockets.py | 24 +++--- 4 files changed, 178 insertions(+), 139 deletions(-) diff --git a/main.py b/main.py index f4e64d04..829e1f42 100644 --- a/main.py +++ b/main.py @@ -30,29 +30,10 @@ import socket import sys -# Initialize logging before we import polling. -# TODO: Change polling.py so we can do this in main(). -if '-d' in sys.argv: - level = logging.DEBUG -elif '-v' in sys.argv: - level = logging.INFO -elif '-q' in sys.argv: - level = logging.ERROR -else: - level = logging.WARN -logging.basicConfig(level=level) - # Local imports (keep in alphabetic order). -import polling import scheduling import sockets -eventloop = polling.EventLoop() -threadrunner = polling.ThreadRunner(eventloop) -scheduler = scheduling.Scheduler(eventloop, threadrunner) - -sockets.scheduler = scheduler # TODO: Find a better way. - def urlfetch(host, port=80, method='GET', path='/', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): @@ -114,23 +95,27 @@ def urlfetch(host, port=80, method='GET', path='/', def doit(): + TIMEOUT = 2 + tasks = set() + # This references NDB's default test service. # (Sadly the service is single-threaded.) - task1 = scheduler.newtask(urlfetch('localhost', 8080, path='/'), - 'root', timeout=2) - task2 = scheduler.newtask(urlfetch('localhost', 8080, path='/home'), - 'home', timeout=2) + task1 = scheduling.Task(urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(urlfetch('localhost', 8080, path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) # Fetch python.org home page. - task3 = scheduler.newtask(urlfetch('python.org', 80, path='/'), - 'python', timeout=2) - - tasks = {task1, task2, task3} + task3 = scheduling.Task(urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduler.newtask(urlfetch('xkcd.com', 443, path='/', - af=socket.AF_INET), - 'xkcd', timeout=2) + task4 = scheduling.Task(urlfetch('xkcd.com', 443, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) tasks.add(task4) ## # Fetch many links from python.org (/x.y.z). @@ -139,14 +124,14 @@ def doit(): ## path = '/{}.{}'.format(x, y) ## g = urlfetch('82.94.164.162', 80, ## path=path, hdrs={'host': 'python.org'}) -## t = scheduler.newtask(g, path, timeout=2) +## t = scheduling.Task(g, path, timeout=2) ## tasks.add(t) -## print(tasks) +## print(tasks) for t in tasks: t.start() - scheduler.run() -## print(tasks) + scheduling.run() +## print(tasks) for t in tasks: print(t.name + ':', t.exception or t.result) @@ -160,6 +145,18 @@ def logtimes(real): def main(): t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + doit() t1 = time.time() logtimes(t1-t0) diff --git a/polling.py b/polling.py index c8220349..96a3f110 100644 --- a/polling.py +++ b/polling.py @@ -421,12 +421,17 @@ def run(self): else: poll_base = SelectMixin -logging.info('Using Pollster base class %r', poll_base.__name__) - class EventLoop(EventLoopMixin, poll_base): """Event loop implementation using the optimal pollster mixin.""" + def __init__(self): + super().__init__() + logging.info('Using Pollster base class %r', poll_base.__name__) + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + class ThreadRunner: """Helper to submit work to a thread pool and wait for it. @@ -438,9 +443,9 @@ class ThreadRunner: The only public API is submit(). """ - def __init__(self, eventloop, max_workers=5): + def __init__(self, eventloop, executor=None): self.eventloop = eventloop - self.threadpool = concurrent.futures.ThreadPoolExecutor(max_workers) + self.executor = executor # Will be constructed lazily. self.pipe_read_fd, self.pipe_write_fd = os.pipe() self.active_count = 0 @@ -460,7 +465,12 @@ def submit(self, func, *args, executor=None): should not wait for that, but rather add a callback to it. """ if executor is None: - executor = self.threadpool + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor assert self.active_count >= 0, self.active_count future = executor.submit(func, *args) if self.active_count == 0: diff --git a/scheduling.py b/scheduling.py index d4c3cb6b..27339d0f 100644 --- a/scheduling.py +++ b/scheduling.py @@ -14,7 +14,8 @@ - Wait for first N that are ready. - Wait until some predicate becomes true. - Run with timeout. - +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). """ __author__ = 'Guido van Rossum ' @@ -22,11 +23,47 @@ # Standard library imports (keep in alphabetic order). from concurrent.futures import TimeoutError import logging +import threading import time +import polling + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = polling.EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = polling.ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! + class Task: - """Lightweight wrapper around a generator. + """Wrapper around a stack of generators. This is a bit like a Future, but with a different interface. @@ -35,8 +72,7 @@ class Task: - wait for result. """ - def __init__(self, sched, gen, name=None, *, timeout=None): - self.sched = sched + def __init__(self, gen, name=None, *, timeout=None): self.gen = gen self.name = name or gen.__name__ if timeout is not None and timeout < 1000000: @@ -53,7 +89,7 @@ def __repr__(self): def run(self): if not self.alive: return - self.sched.current = self + context.current_task = self try: if self.timeout is not None and self.timeout < time.time(): self.gen.throw(TimeoutError) @@ -71,99 +107,91 @@ def run(self): self.exception = exc raise else: - if self.sched.current is not None: + if context.current_task is not None: self.start() finally: - self.sched.current = None + context.current_task = None def start(self): if self.alive: - self.sched.eventloop.call_soon(self.run) + context.eventloop.call_soon(self.run) -class Scheduler: +def run(): + context.eventloop.run() - def __init__(self, eventloop, threadrunner): - self.eventloop = eventloop # polling.EventLoop instance. - self.threadrunner = threadrunner # polling.Threadrunner instance. - self.current = None # Current Task. - def run(self): - self.eventloop.run() - - def newtask(self, gen, name=None, *, timeout=None): - return Task(self, gen, name, timeout=timeout) - - def start(self, gen, name=None, *, timeout=None): - task = self.newtask(gen, name, timeout=timeout) - task.start() - return task - - def block_r(self, fd): - self.block_io(fd, 'r') - - def block_w(self, fd): - self.block_io(fd, 'w') - - def block_io(self, fd, flag): - assert isinstance(fd, int), repr(fd) - assert flag in ('r', 'w'), repr(flag) - task = self.block() - dcall = None - if task.timeout: - dcall = self.eventloop.call_later(task.timeout, - self.unblock_timeout, - fd, flag, task) - if flag == 'r': - self.eventloop.add_reader(fd, self.unblock_io, - fd, flag, task, dcall) - else: - self.eventloop.add_writer(fd, self.unblock_io, - fd, flag, task, dcall) - - def block(self): - assert self.current - task = self.current - self.current = None - return task - - def unblock_io(self, fd, flag, task, dcall): - if dcall is not None: - dcall.cancel() - if flag == 'r': - self.eventloop.remove_reader(fd) - else: - self.eventloop.remove_writer(fd) - task.start() - - def unblock_timeout(self, fd, flag, task): - # NOTE: Due to the call_soon() semantics, we can't guarantee - # that unblock_timeout() isn't called *after* unblock_io() has - # already been called. So we must write this defensively. - # TODO: Analyse this further for race conditions etc. - if flag == 'r': - if fd in self.eventloop.readers: - self.eventloop.remove_reader(fd) - else: - if fd in self.eventloop.writers: - self.eventloop.remove_writer(fd) - task.timeout = 0 # Force it to cancel. - task.start() - - def call_in_thread(self, func, *args, executor=None): - # TODO: Prove there is no race condition here. - task = self.block() - future = self.threadrunner.submit(func, *args, executor=executor) - future.add_done_callback(lambda _: task.start()) - try: - yield - except TimeoutError: - future.cancel() - raise - assert future.done() - return future.result() +def sleep(secs): + task = block() + context.eventloop.call_later(secs, task.start) + yield + + +def block_r(fd): + block_io(fd, 'r') + + +def block_w(fd): + block_io(fd, 'w') + + +def block_io(fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + task = block() + dcall = None + if task.timeout: + dcall = context.eventloop.call_later(task.timeout, unblock_timeout, + fd, flag, task) + if flag == 'r': + context.eventloop.add_reader(fd, unblock_io, fd, flag, task, dcall) + else: + context.eventloop.add_writer(fd, unblock_io, fd, flag, task, dcall) + + +def block(): + assert context.current_task + task = context.current_task + context.current_task = None + return task + + +def unblock_io(fd, flag, task, dcall): + if dcall is not None: + dcall.cancel() + if flag == 'r': + context.eventloop.remove_reader(fd) + else: + context.eventloop.remove_writer(fd) + task.start() + + +def unblock_timeout(fd, flag, task): + # NOTE: Due to the call_soon() semantics, we can't guarantee + # that unblock_timeout() isn't called *after* unblock_io() has + # already been called. So we must write this defensively. + # TODO: Analyse this further for race conditions etc. + if flag == 'r': + if fd in context.eventloop.readers: + context.eventloop.remove_reader(fd) + else: + if fd in context.eventloop.writers: + context.eventloop.remove_writer(fd) + task.timeout = 0 # Force it to cancel. + task.start() + - def sleep(self, secs): - task = self.block() - self.eventloop.call_later(secs, task.start) +def call_in_thread(func, *args, executor=None): + # TODO: Prove there is no race condition here. + task = block() + future = context.threadrunner.submit(func, *args, executor=executor) + # Don't reference context in the lambda! It is called in another thread. + this_eventloop = context.eventloop + future.add_done_callback(lambda _: this_eventloop.call_soon(task.run)) + try: yield + except TimeoutError: + future.cancel() + raise + assert future.done() + return future.result() diff --git a/sockets.py b/sockets.py index 85b70052..9f569e46 100644 --- a/sockets.py +++ b/sockets.py @@ -23,11 +23,15 @@ __author__ = 'Guido van Rossum ' +# Stdlib imports. import errno import re import socket import ssl +# Local imports. +import scheduling + class SocketTransport: """Transport wrapping a socket. @@ -41,14 +45,14 @@ def __init__(self, sock): def recv(self, n): """COROUTINE: Read up to n bytes, blocking at most once.""" assert n >= 0, n - scheduler.block_r(self.sock.fileno()) + scheduling.block_r(self.sock.fileno()) yield return self.sock.recv(n) def send(self, data): """COROUTINE; Send data to the socket, blocking until all written.""" while data: - scheduler.block_w(self.sock.fileno()) + scheduling.block_w(self.sock.fileno()) yield n = self.sock.send(data) assert 0 <= n <= len(data), (n, len(data)) @@ -80,10 +84,10 @@ def do_handshake(self): try: self.sslsock.do_handshake() except ssl.SSLWantReadError: - scheduler.block_r(self.sslsock.fileno()) + scheduling.block_r(self.sslsock.fileno()) yield except ssl.SSLWantWriteError: - scheduler.block_w(self.sslsock.fileno()) + scheduling.block_w(self.sslsock.fileno()) yield else: break @@ -97,7 +101,7 @@ def recv(self, n): try: return self.sslsock.recv(n) except socket.error as err: - scheduler.block_r(self.sslsock.fileno()) + scheduling.block_r(self.sslsock.fileno()) yield def send(self, data): @@ -106,7 +110,7 @@ def send(self, data): try: n = self.sslsock.send(data) except socket.error as err: - scheduler.block_w(self.sslsock.fileno()) + scheduling.block_w(self.sslsock.fileno()) yield if n == len(data): break @@ -186,7 +190,7 @@ def connect(sock, address): except socket.error as err: if err.errno != errno.EINPROGRESS: raise - scheduler.block_w(sock.fileno()) + scheduling.block_w(sock.fileno()) yield err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: @@ -198,9 +202,9 @@ def getaddrinfo(host, port, af=0, socktype=0, proto=0): Each info is a tuple (af, socktype, protocol, canonname, address). """ - infos = yield from scheduler.call_in_thread(socket.getaddrinfo, - host, port, af, - socktype, proto) + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) return infos From 6a8bbb729fca2136e81acfa49cb8e12100b3c6bb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 12:05:16 -0700 Subject: [PATCH 0052/1502] Default port depends on ssl flag. --- main.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 829e1f42..202cdce8 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,7 @@ TODO: - Take test urls from command line. +- Move urlfetch to a separate module. - Profiling. - Docstrings. - Unittests. @@ -35,9 +36,14 @@ import sockets -def urlfetch(host, port=80, method='GET', path='/', +def urlfetch(host, port=None, method='GET', path='/', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) yield from trans.send(method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') @@ -113,7 +119,7 @@ def doit(): tasks.add(task3) # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduling.Task(urlfetch('xkcd.com', 443, path='/', + task4 = scheduling.Task(urlfetch('xkcd.com', ssl=True, path='/', af=socket.AF_INET), 'xkcd', timeout=TIMEOUT) tasks.add(task4) @@ -133,7 +139,7 @@ def doit(): scheduling.run() ## print(tasks) for t in tasks: - print(t.name + ':', t.exception or t.result) + print(t.name + ':', repr(t.exception) if t.exception else t.result) def logtimes(real): From b9cde113428189fe7692312982a4898269c00dac Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 15:39:32 -0700 Subject: [PATCH 0053/1502] Another radical task refactor. Happier with timeouts. --- scheduling.py | 150 +++++++++++++++++++++++++------------------------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/scheduling.py b/scheduling.py index 27339d0f..6498c418 100644 --- a/scheduling.py +++ b/scheduling.py @@ -13,7 +13,6 @@ - Wait for all, collate results. - Wait for first N that are ready. - Wait until some predicate becomes true. -- Run with timeout. - Various synchronization primitives (Lock, RLock, Event, Condition, Semaphore, BoundedSemaphore, Barrier). """ @@ -21,11 +20,13 @@ __author__ = 'Guido van Rossum ' # Standard library imports (keep in alphabetic order). -from concurrent.futures import TimeoutError +from concurrent.futures import CancelledError, TimeoutError import logging import threading import time +import types +# Local imports (keep in alphabetic order). import polling @@ -68,16 +69,22 @@ class Task: This is a bit like a Future, but with a different interface. TODO: - - cancellation. - wait for result. """ def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) self.gen = gen self.name = name or gen.__name__ - if timeout is not None and timeout < 1000000: - timeout += time.time() self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False self.alive = True self.result = None self.exception = None @@ -86,13 +93,19 @@ def __repr__(self): return 'Task<%r, timeout=%s>(alive=%r, result=%r, exception=%r)' % ( self.name, self.timeout, self.alive, self.result, self.exception) - def run(self): - if not self.alive: - return - context.current_task = self + def cancel(self): + if self.alive: + self.must_cancel = True + self.unblock() + + def step(self): + assert self.alive try: - if self.timeout is not None and self.timeout < time.time(): - self.gen.throw(TimeoutError) + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) else: next(self.gen) except StopIteration as exc: @@ -101,20 +114,56 @@ def run(self): except Exception as exc: self.alive = False self.exception = exc - logging.exception('Uncaught exception in task %r', self.name) + logging.debug('Uncaught exception in task %r', self.name, + exc_info=True, stack_info=True) except BaseException: self.alive = False self.exception = exc raise else: - if context.current_task is not None: - self.start() + if not self.blocked: + self.eventloop.call_soon(self.step) finally: context.current_task = None + # Cancel timeout callback if set. + if not self.alive and self.canceleer is not None: + self.canceleer.cancel() def start(self): - if self.alive: - context.eventloop.call_soon(self.run) + assert self.alive + self.eventloop.call_soon(self.step) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task + assert self.alive + assert not self.blocked + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock(self): + assert self.alive + assert self.blocked + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) def run(): @@ -122,76 +171,27 @@ def run(): def sleep(secs): - task = block() - context.eventloop.call_later(secs, task.start) + """COROUTINE: Sleep for some time (a float in seconds).""" + context.current_task.block() + context.eventloop.call_later(secs, self.unblock) yield def block_r(fd): - block_io(fd, 'r') + context.current_task.block_io(fd, 'r') def block_w(fd): - block_io(fd, 'w') - - -def block_io(fd, flag): - assert isinstance(fd, int), repr(fd) - assert flag in ('r', 'w'), repr(flag) - task = block() - dcall = None - if task.timeout: - dcall = context.eventloop.call_later(task.timeout, unblock_timeout, - fd, flag, task) - if flag == 'r': - context.eventloop.add_reader(fd, unblock_io, fd, flag, task, dcall) - else: - context.eventloop.add_writer(fd, unblock_io, fd, flag, task, dcall) - - -def block(): - assert context.current_task - task = context.current_task - context.current_task = None - return task - - -def unblock_io(fd, flag, task, dcall): - if dcall is not None: - dcall.cancel() - if flag == 'r': - context.eventloop.remove_reader(fd) - else: - context.eventloop.remove_writer(fd) - task.start() - - -def unblock_timeout(fd, flag, task): - # NOTE: Due to the call_soon() semantics, we can't guarantee - # that unblock_timeout() isn't called *after* unblock_io() has - # already been called. So we must write this defensively. - # TODO: Analyse this further for race conditions etc. - if flag == 'r': - if fd in context.eventloop.readers: - context.eventloop.remove_reader(fd) - else: - if fd in context.eventloop.writers: - context.eventloop.remove_writer(fd) - task.timeout = 0 # Force it to cancel. - task.start() + context.current_task.block_io(fd, 'w') def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" # TODO: Prove there is no race condition here. - task = block() future = context.threadrunner.submit(func, *args, executor=executor) - # Don't reference context in the lambda! It is called in another thread. - this_eventloop = context.eventloop - future.add_done_callback(lambda _: this_eventloop.call_soon(task.run)) - try: - yield - except TimeoutError: - future.cancel() - raise + task = context.current_task + task.block(future.cancel) + future.add_done_callback(lambda _: task.unblock()) + yield assert future.done() return future.result() From 074c514f66cc823de8eea136f0bdfec738ee3bb8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 16:57:41 -0700 Subject: [PATCH 0054/1502] Added wait_any() and wait_all(). --- main.py | 25 ++++++++++++---- scheduling.py | 79 +++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 92 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 202cdce8..697866a7 100644 --- a/main.py +++ b/main.py @@ -97,7 +97,9 @@ def urlfetch(host, port=None, method='GET', path='/', data = yield from rdr.readexactly(size) trans.close() # Can this block? t1 = time.time() - return (host, port, path, int(status), len(data), round(t1-t0, 3)) + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result def doit(): @@ -136,10 +138,10 @@ def doit(): ## print(tasks) for t in tasks: t.start() - scheduling.run() -## print(tasks) - for t in tasks: - print(t.name + ':', repr(t.exception) if t.exception else t.result) + winner = yield from scheduling.wait_any(tasks) + print('The winner is:', winner) + tasks = yield from scheduling.wait_all(tasks) + return tasks def logtimes(real): @@ -163,7 +165,18 @@ def main(): level = logging.WARN logging.basicConfig(level=level) - doit() + # Run doit() as a task. + task = scheduling.Task(doit()) + task.start() + scheduling.run() + if task.exception: + print('Exception:', repr(task.exception)) + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. t1 = time.time() logtimes(t1-t0) diff --git a/scheduling.py b/scheduling.py index 6498c418..9c876fb6 100644 --- a/scheduling.py +++ b/scheduling.py @@ -59,7 +59,7 @@ def threadrunner(self): self._threadrunner = polling.ThreadRunner(self.eventloop) return self._threadrunner - + context = Context() # Thread-local! @@ -88,6 +88,12 @@ def __init__(self, gen, name=None, *, timeout=None): self.alive = True self.result = None self.exception = None + self.done_callbacks = [] + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + self.done_callbacks.append(done_callback) def __repr__(self): return 'Task<%r, timeout=%s>(alive=%r, result=%r, exception=%r)' % ( @@ -125,9 +131,13 @@ def step(self): self.eventloop.call_soon(self.step) finally: context.current_task = None - # Cancel timeout callback if set. - if not self.alive and self.canceleer is not None: - self.canceleer.cancel() + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for callback in self.done_callbacks: + self.eventloop.call_soon(callback, self) def start(self): assert self.alive @@ -140,7 +150,8 @@ def block(self, unblock_callback=None, *unblock_args): self.blocked = True self.unblocker = (unblock_callback, unblock_args) - def unblock(self): + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. assert self.alive assert self.blocked self.blocked = False @@ -165,8 +176,19 @@ def block_io(self, fd, flag): self.block(self.eventloop.remove_writer, fd) self.eventloop.add_writer(fd, self.unblock) + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + def run(): + """Run the event loop until it's out of work.""" context.eventloop.run() @@ -178,10 +200,12 @@ def sleep(secs): def block_r(fd): + """Helper to call block_io() for reading.""" context.current_task.block_io(fd, 'r') def block_w(fd): + """Helper to call block_io() for writing.""" context.current_task.block_io(fd, 'w') @@ -191,7 +215,50 @@ def call_in_thread(func, *args, executor=None): future = context.threadrunner.submit(func, *args, executor=executor) task = context.current_task task.block(future.cancel) - future.add_done_callback(lambda _: task.unblock()) + future.add_done_callback(task.unblock) yield assert future.done() return future.result() + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + assert tasks + current_task = context.current_task + assert all(task is not current_task for task in tasks) + for task in tasks: + if not task.alive: + return task + winner = None + def wait_any_callback(task): + nonlocal winner, current_task + if not winner: + winner = task + current_task.unblock() + # TODO: Avoid adding N callbacks. + for task in tasks: + task.add_done_callback(wait_any_callback) + current_task.block() + yield + return winner + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + assert tasks + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + def wait_all_callback(task): + nonlocal todo, current_task + todo.remove(task) + if not todo: + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + task.add_done_callback(wait_all_callback) + if todo: + current_task.block() + yield + return tasks # Not redundant: handy if called with a comprehension. From 054207ef3dae60ae4a22d36c0d699389b7017af1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 17:31:42 -0700 Subject: [PATCH 0055/1502] Replace wait_any() and wait_all() with wait_for(). --- main.py | 6 +++-- scheduling.py | 64 ++++++++++++++++++++++++++------------------------- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/main.py b/main.py index 697866a7..7c9ee7fa 100644 --- a/main.py +++ b/main.py @@ -138,9 +138,11 @@ def doit(): ## print(tasks) for t in tasks: t.start() - winner = yield from scheduling.wait_any(tasks) - print('The winner is:', winner) + yield from scheduling.sleep(0.2) + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) return tasks diff --git a/scheduling.py b/scheduling.py index 9c876fb6..82840b6e 100644 --- a/scheduling.py +++ b/scheduling.py @@ -194,8 +194,9 @@ def run(): def sleep(secs): """COROUTINE: Sleep for some time (a float in seconds).""" - context.current_task.block() - context.eventloop.call_later(secs, self.unblock) + current_task = context.current_task + current_task.block() + context.eventloop.call_later(secs, current_task.unblock) yield @@ -221,44 +222,45 @@ def call_in_thread(func, *args, executor=None): return future.result() -def wait_any(tasks): - """COROUTINE: Wait for the first of a set of tasks to complete.""" - assert tasks - current_task = context.current_task - assert all(task is not current_task for task in tasks) - for task in tasks: - if not task.alive: - return task - winner = None - def wait_any_callback(task): - nonlocal winner, current_task - if not winner: - winner = task - current_task.unblock() - # TODO: Avoid adding N callbacks. - for task in tasks: - task.add_done_callback(wait_any_callback) - current_task.block() - yield - return winner +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + May return more than N if more than N are immediately ready. -def wait_all(tasks): - """COROUTINE: Wait for all of a set of tasks to complete.""" + NOTE: Tasks that were cancelled or raised are also considered ready. + """ assert tasks + tasks = set(tasks) + assert 1 <= count <= len(tasks) current_task = context.current_task assert all(task is not current_task for task in tasks) todo = set() - def wait_all_callback(task): - nonlocal todo, current_task + done = set() + def wait_for_callback(task): + nonlocal todo, done, current_task, count todo.remove(task) - if not todo: - current_task.unblock() + if len(done) < count: + done.add(task) + if len(done) == count: + current_task.unblock() for task in tasks: if task.alive: todo.add(task) - task.add_done_callback(wait_all_callback) - if todo: + else: + done.add(task) + if len(done) < count: + for task in todo: + task.add_done_callback(wait_for_callback) current_task.block() yield - return tasks # Not redundant: handy if called with a comprehension. + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) From 5e8683b47a2963c07988663833a9d7757ab8a2f5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Oct 2012 17:50:45 -0700 Subject: [PATCH 0056/1502] Add with_timeout(). Fix bug in cancelled sleeps. --- main.py | 4 ++-- scheduling.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 7c9ee7fa..f2dce8d9 100644 --- a/main.py +++ b/main.py @@ -138,7 +138,7 @@ def doit(): ## print(tasks) for t in tasks: t.start() - yield from scheduling.sleep(0.2) + yield from scheduling.with_timeout(0.2, scheduling.sleep(1)) winners = yield from scheduling.wait_any(tasks) print('And the winners are:', [w.name for w in winners]) tasks = yield from scheduling.wait_all(tasks) @@ -168,7 +168,7 @@ def main(): logging.basicConfig(level=level) # Run doit() as a task. - task = scheduling.Task(doit()) + task = scheduling.Task(doit(), timeout=2.1) task.start() scheduling.run() if task.exception: diff --git a/scheduling.py b/scheduling.py index 82840b6e..34078f23 100644 --- a/scheduling.py +++ b/scheduling.py @@ -10,9 +10,6 @@ - Unittests. PATTERNS TO TRY: -- Wait for all, collate results. -- Wait for first N that are ready. -- Wait until some predicate becomes true. - Various synchronization primitives (Lock, RLock, Event, Condition, Semaphore, BoundedSemaphore, Barrier). """ @@ -195,8 +192,8 @@ def run(): def sleep(secs): """COROUTINE: Sleep for some time (a float in seconds).""" current_task = context.current_task - current_task.block() - context.eventloop.call_later(secs, current_task.unblock) + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) yield @@ -264,3 +261,11 @@ def wait_any(tasks): def wait_all(tasks): """COROUTINE: Wait for all of a set of tasks to complete.""" return wait_for(len(tasks), tasks) + + +def with_timeout(timeout, gen, name=None): + """COROUTINE: Run generator synchronously with a timeout.""" + assert timeout is not None + task = Task(gen, name, timeout=timeout) + task.start() + return (yield from task.wait()) From 51525ccc0d745a323996a610b8017be91040ad86 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 27 Oct 2012 11:32:04 -0700 Subject: [PATCH 0057/1502] Use DelayedCall objects for all types of callbacks. --- polling.py | 88 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/polling.py b/polling.py index 96a3f110..590b2a99 100644 --- a/polling.py +++ b/polling.py @@ -58,8 +58,8 @@ class PollsterBase: def __init__(self): super().__init__() - self.readers = {} # {fd: (callback, args), ...}. - self.writers = {} # {fd: (callback, args), ...}. + self.readers = {} # {fd: , ...}. + self.writers = {} # {fd: , ...}. def pollable(self): """Return True if any readers or writers are currently registered.""" @@ -69,11 +69,15 @@ def pollable(self): def add_reader(self, fd, callback, *args): """Add or update a reader for a file descriptor.""" - self.readers[fd] = (callback, args) + dcall = DelayedCall(None, callback, args) + self.readers[fd] = dcall + return dcall def add_writer(self, fd, callback, *args): """Add or update a writer for a file descriptor.""" - self.writers[fd] = (callback, args) + dcall = DelayedCall(None, callback, args) + self.writers[fd] = dcall + return dcall def remove_reader(self, fd): """Remove the reader for a file descriptor.""" @@ -110,8 +114,8 @@ def poll(self, timeout=None): readable, writable, _ = select.select(self.readers, self.writers, [], timeout) events = [] - events += ((fd, 'r') + self.readers[fd] for fd in readable) - events += ((fd, 'w') + self.writers[fd] for fd in writable) + events += ((fd, 'r', self.readers[fd]) for fd in readable) + events += ((fd, 'w', self.writers[fd]) for fd in writable) return events @@ -135,12 +139,14 @@ def _update(self, fd): self._poll.unregister(fd) def add_reader(self, fd, callback, *args): - super().add_reader(fd, callback, *args) + dcall = super().add_reader(fd, callback, *args) self._update(fd) + return dcall def add_writer(self, fd, callback, *args): - super().add_writer(fd, callback, *args) + dcall = super().add_writer(fd, callback, *args) self._update(fd) + return dcall def remove_reader(self, fd): super().remove_reader(fd) @@ -157,12 +163,12 @@ def poll(self, timeout=None): for fd, flags in self._poll.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: - callback, args = self.readers[fd] - events.append((fd, 'r', callback, args)) + dcall = self.readers[fd] + events.append((fd, 'r', dcall)) if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: - callback, args = self.writers[fd] - events.append((fd, 'w', callback, args)) + dcall = self.writers[fd] + events.append((fd, 'w', dcall)) return events @@ -189,12 +195,14 @@ def _update(self, fd): self._epoll.unregister(fd) def add_reader(self, fd, callback, *args): - super().add_reader(fd, callback, *args) + dcall = super().add_reader(fd, callback, *args) self._update(fd) + return dcall def add_writer(self, fd, callback, *args): - super().add_writer(fd, callback, *args) + dcall = super().add_writer(fd, callback, *args) self._update(fd) + return dcall def remove_reader(self, fd): super().remove_reader(fd) @@ -211,12 +219,12 @@ def poll(self, timeout=None): for fd, eventmask in self._epoll.poll(timeout): if eventmask & select.EPOLLIN: if fd in self.readers: - callback, args = self.readers[fd] - events.append((fd, 'r', callback, args)) + dcall = self.readers[fd] + events.append((fd, 'r', dcall)) if eventmask & select.EPOLLOUT: if fd in self.writers: - callback, args = self.writers[fd] - events.append((fd, 'w', callback, args)) + dcall = self.writers[fd] + events.append((fd, 'w', dcall)) return events @@ -231,13 +239,13 @@ def add_reader(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - super().add_reader(fd, callback, *args) + return super().add_reader(fd, callback, *args) def add_writer(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - super().add_writer(fd, callback, *args) + return super().add_writer(fd, callback, *args) def remove_reader(self, fd): super().remove_reader(fd) @@ -256,16 +264,16 @@ def poll(self, timeout=None): fd = kev.ident flag = kev.filter if flag == select.KQ_FILTER_READ and fd in self.readers: - callback, args = self.readers[fd] - events.append((fd, 'r', callback, args)) + dcall = self.readers[fd] + events.append((fd, 'r', dcall)) elif flag == select.KQ_FILTER_WRITE and fd in self.writers: - callback, args = self.writers[fd] - events.append((fd, 'w', callback, args)) + dcall = self.writers[fd] + events.append((fd, 'w', dcall)) return events class DelayedCall: - """Object returned by call_later(); can be used to cancel the call.""" + """Object returned by call_soon/later(), add_reader/writer().""" def __init__(self, when, callback, args): self.when = when @@ -309,6 +317,15 @@ def __init__(self): self.ready = collections.deque() # [(callback, args), ...] self.scheduled = [] # [(when, callback, args), ...] + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. @@ -319,7 +336,9 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - self.ready.append((callback, args)) + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall def call_later(self, when, callback, *args): """Arrange for a callback to be called at a given time. @@ -366,12 +385,13 @@ def run_once(self): # TODO: Ensure this loop always finishes, even if some # callbacks keeps registering more callbacks. while self.ready: - callback, args = self.ready.popleft() - try: - callback(*args) - except Exception: - logging.exception('Exception in callback %s %r', - callback, args) + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + callback, args) # Remove delayed calls that were cancelled from head of queue. while self.scheduled and self.scheduled[0].cancelled: @@ -389,8 +409,8 @@ def run_once(self): t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout logging.debug('poll%s took %.3f seconds', argstr, t1-t0) - for fd, flag, callback, args in events: - self.call_soon(callback, *args) + for fd, flag, dcall in events: + self.add_callback(dcall) # Handle 'later' callbacks that are ready. while self.scheduled: From 5226f4ffefcb5eb5e8343bddad8d2cf1f9f0b53f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 27 Oct 2012 18:55:30 -0700 Subject: [PATCH 0058/1502] Fix code in except clause. --- polling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polling.py b/polling.py index 590b2a99..9e8673e2 100644 --- a/polling.py +++ b/polling.py @@ -391,7 +391,7 @@ def run_once(self): dcall.callback(*dcall.args) except Exception: logging.exception('Exception in callback %s %r', - callback, args) + dcall.callback, dcall.args) # Remove delayed calls that were cancelled from head of queue. while self.scheduled and self.scheduled[0].cancelled: From 28868471402c43a37fa588f67d37b80a8adf522d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 28 Oct 2012 17:25:29 -0700 Subject: [PATCH 0059/1502] Add crude time comparison between yield and yield from. --- Makefile | 3 +++ yyftime.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 yyftime.py diff --git a/Makefile b/Makefile index 8613d101..04117ce0 100644 --- a/Makefile +++ b/Makefile @@ -9,5 +9,8 @@ profile: time: $(PYTHON) p3time.py +ytime: + $(PYTHON) yyftime.py + check: $(PYTHON) longlines.py diff --git a/yyftime.py b/yyftime.py new file mode 100644 index 00000000..c7ab5406 --- /dev/null +++ b/yyftime.py @@ -0,0 +1,72 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + for depth in range(30): + tc = run_coro(depth) + to = run_olds(depth) + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() From a492cf8505aa63ec92a0a20681255bffc785530c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 28 Oct 2012 17:37:08 -0700 Subject: [PATCH 0060/1502] Disable gc in benchmarks. --- p3time.py | 7 +++++-- yyftime.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/p3time.py b/p3time.py index d0cef6a4..35e14c96 100644 --- a/p3time.py +++ b/p3time.py @@ -1,5 +1,6 @@ """Compare timing of plain vs. yield-from calls.""" +import gc import time def plain(n): @@ -34,9 +35,11 @@ def submain(depth): t1 = time.time() delta1 = t1-t0 print(('coro.' + fmt).format(depth, k, delta1)) - print(('relat' + fmt).format(depth, k, delta1/delta0)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) -def main(reasonable=100): +def main(reasonable=16): + gc.disable() for depth in range(reasonable): submain(depth) diff --git a/yyftime.py b/yyftime.py index c7ab5406..a0c2cad0 100644 --- a/yyftime.py +++ b/yyftime.py @@ -1,5 +1,6 @@ """Compare timing of yield-from vs. yield calls.""" +import gc import time def coroutine(n): @@ -63,10 +64,12 @@ def run_olds(depth): return t1-t0 def main(): - for depth in range(30): + gc.disable() + for depth in range(16): tc = run_coro(depth) to = run_olds(depth) - print('ratio', round(to/tc, 2)) + if tc: + print('ratio', round(to/tc, 2)) if __name__ == '__main__': main() From 54de22d7c4796a364f88601ee4c91c0f359cff77 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 10:41:25 -0700 Subject: [PATCH 0061/1502] Update README, add TODO and xkcd.py. --- README | 29 ++++++++++++++++++-- TODO | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ xkcd.py | 18 ++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 TODO create mode 100755 xkcd.py diff --git a/README b/README index 35541bcb..1339d065 100644 --- a/README +++ b/README @@ -1,5 +1,7 @@ Tulip is the codename for my attempt at understanding PEP-380 style -coroutines (i.e. those using generators and 'yield from'). +coroutines (i.e. those using generators and 'yield from'). + +*** This requires Python 3.3 or later! *** For reference, see many threads in python-ideas@python.org started in October 2012, especially those with "The async API of the Future" in @@ -8,8 +10,31 @@ their subject, and the various spin-off threads. A particularly influential tutorial by Greg Ewing: http://www.cosc.canterbury.ac.nz/greg.ewing/python/generators/yf_current/Examples/Scheduler/scheduler.txt -Python version: 3.3. +A message I posted with some explanation of the design: +http://mail.python.org/pipermail/python-ideas/2012-October/017501.html + +Essential files here: + +- main.py: the main program for testing, and a rough HTTP client +- sockets.py: transports for sockets and SSL, and a buffering layer +- scheduling.py: a Task class and related stuff; this is where the PEP + 380 scheduler is implemented +- polling.py: an event loop and basic polling implementations for: + select(), poll(), epoll(), kqueue() + +Secondary files: + +- .hgignore: files I don't care about +- Makefile: various quick shell commands +- README: this file +- TODO: longer list of TODO items and general thoughts +- longlines.py: stupid style checker +- p3time.py: benchmark yield from vs. plain functions +- xkcd.py: *synchronous* ssl example +- yyftime.py: benchmark yield from vs. yield Copyright/license: Open source, Apache 2.0. Enjoy. +Master Mercurial repo: http://code.google.com/p/tulip/ + --Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..5c194bb3 --- /dev/null +++ b/TODO @@ -0,0 +1,85 @@ +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Make pollster take an abstract token and return it. + +- Make pollster a sub-object instead of a superclass of the eventloop. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Write up a tutorial for the scheduling API. + +- Change block_r/w into COROUTINE style APIs. + +- Do we need _async suffixes to all async APIs? + +- Do we need synchronous parallel APIs for all async APIs? + +- Add a decorator just for documenting a coroutine? + +- Fix recv(), send() to catch EAGAIN. + +- Fix ssh recv(), send() to catch SSLWantReadError and SSLWantWriteError. + +- Could BufferedReader reuse the standard io module's readers??? + +[From older list] + +- Is it better to have separate add_{reader,writer} methods, vs. one + add_thingie method taking a fd and a r/w flag? + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Should poll() return a list of tokens or a list of (fd, flag, token)? + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Should block() use a queue? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Global functions vs. Task methods? + +- Is the Task design good? + +- Make Task more like Future? (Or less???) + +- Implement various lock styles a la threading.py. + +- Handle disconnect errors from send() (and from recv()???). + +- Add write() calls that don't require yield from. + +- Add simple non-async APIs, for simple apps? + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1).) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? diff --git a/xkcd.py b/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() From dabc066ff37545eb9d5760279343d4a079621c86 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 11:11:29 -0700 Subject: [PATCH 0062/1502] Update TODO. --- TODO | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/TODO b/TODO index 5c194bb3..ac309f53 100644 --- a/TODO +++ b/TODO @@ -22,7 +22,9 @@ - Do we need synchronous parallel APIs for all async APIs? -- Add a decorator just for documenting a coroutine? +- Add a decorator just for documenting a coroutine. It should set a + flag on the function. It should not interfere with methods, + staticmethod, classmethod and the like. - Fix recv(), send() to catch EAGAIN. @@ -30,7 +32,10 @@ - Could BufferedReader reuse the standard io module's readers??? -[From older list] +- Support ZeroMQ "sockets" which are user objects. + + +FROM OLDER LIST - Is it better to have separate add_{reader,writer} methods, vs. one add_thingie method taking a fd and a r/w flag? From d9543edf90ec649fc5b982c2ba48bb1fd2e9b1a1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 11:45:58 -0700 Subject: [PATCH 0063/1502] Update TODO. --- TODO | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/TODO b/TODO index ac309f53..b73a8c1f 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,13 @@ +- Need more examples, e.g. simple client, simple server (echo). + +- Benchmarkable HTTP server? + +- Do we need call_every()? + +- What to do about callbacks with keyword-only arguments? + +- Example of using UDP. + - Ensure multiple tasks can do atomic writes to the same pipe (since UNIX guarantees that short writes to pipes are atomic). From 08fd919eb52f0e52aa056d67cfb209f40fddde44 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 12:54:45 -0700 Subject: [PATCH 0064/1502] Expand 0mq section. --- TODO | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/TODO b/TODO index b73a8c1f..af6ff4b1 100644 --- a/TODO +++ b/TODO @@ -42,7 +42,11 @@ - Could BufferedReader reuse the standard io module's readers??? -- Support ZeroMQ "sockets" which are user objects. +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py FROM OLDER LIST From 800f1ab642967406d36b9b346e27ecd2b24e8815 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 13:11:32 -0700 Subject: [PATCH 0065/1502] Robustify recv() and send() methods (both transports). --- TODO | 6 ---- sockets.py | 92 +++++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/TODO b/TODO index af6ff4b1..9c23712a 100644 --- a/TODO +++ b/TODO @@ -36,10 +36,6 @@ flag on the function. It should not interfere with methods, staticmethod, classmethod and the like. -- Fix recv(), send() to catch EAGAIN. - -- Fix ssh recv(), send() to catch SSLWantReadError and SSLWantWriteError. - - Could BufferedReader reuse the standard io module's readers??? - Support ZeroMQ "sockets" which are user objects. Though possibly @@ -78,8 +74,6 @@ FROM OLDER LIST - Implement various lock styles a la threading.py. -- Handle disconnect errors from send() (and from recv()???). - - Add write() calls that don't require yield from. - Add simple non-async APIs, for simple apps? diff --git a/sockets.py b/sockets.py index 9f569e46..e170d2a0 100644 --- a/sockets.py +++ b/sockets.py @@ -32,6 +32,18 @@ # Local imports. import scheduling +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + class SocketTransport: """Transport wrapping a socket. @@ -43,22 +55,49 @@ def __init__(self, sock): self.sock = sock def recv(self, n): - """COROUTINE: Read up to n bytes, blocking at most once.""" + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ assert n >= 0, n - scheduling.block_r(self.sock.fileno()) - yield - return self.sock.recv(n) + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + scheduling.block_r(self.sock.fileno()) + yield + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. def send(self, data): - """COROUTINE; Send data to the socket, blocking until all written.""" + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ while data: - scheduling.block_w(self.sock.fileno()) - yield - n = self.sock.send(data) - assert 0 <= n <= len(data), (n, len(data)) - if n == len(data): - break - data = data[n:] + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + scheduling.block_w(self.sock.fileno()) + yield + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + + return True def close(self): """Close the socket. (Not a coroutine.)""" @@ -100,22 +139,47 @@ def recv(self, n): while True: try: return self.sslsock.recv(n) - except socket.error as err: + except ssl.SSLWantReadError: scheduling.block_r(self.sslsock.fileno()) yield + except ssl.SSLWantWriteError: + scheduling.block_w(self.sslsock.fileno()) + yield + except socket.error as err: + if err.errno in _TRYAGAIN: + scheduling.block_r(self.sock.fileno()) + yield + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. def send(self, data): """COROUTINE: Send data to the socket, blocking as needed.""" while data: try: n = self.sslsock.send(data) - except socket.error as err: + except ssl.SSLWantReadError: + scheduling.block_r(self.sslsock.fileno()) + yield + except ssl.SSLWantWriteError: scheduling.block_w(self.sslsock.fileno()) yield + except socket.error as err: + if err.errno in _TRYAGAIN: + scheduling.block_w(self.sock.fileno()) + yield + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. if n == len(data): break data = data[n:] + return True + def close(self): """Close the socket. (Not a coroutine.) From abc740276155225e1fe4838d9f00f2db0155ee6f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 13:16:31 -0700 Subject: [PATCH 0066/1502] TODO refactor. --- TODO | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/TODO b/TODO index 9c23712a..7bced3f5 100644 --- a/TODO +++ b/TODO @@ -1,22 +1,36 @@ +TO DO SMALLER TASKS + +- Make pollster take an abstract token and return it. + +- Make pollster a sub-object instead of a superclass of the eventloop. + +- Change block_r/w into COROUTINE style APIs. + +- Should poll() return a list of tokens or a list of (fd, flag, token)? + + +TO DO LARGER TASKS + - Need more examples, e.g. simple client, simple server (echo). - Benchmarkable HTTP server? +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + + +TO DO LATER + - Do we need call_every()? - What to do about callbacks with keyword-only arguments? -- Example of using UDP. - - Ensure multiple tasks can do atomic writes to the same pipe (since UNIX guarantees that short writes to pipes are atomic). - Ensure some easy way of distributing accepted connections across tasks. -- Make pollster take an abstract token and return it. - -- Make pollster a sub-object instead of a superclass of the eventloop. - - Be wary of thread-local storage. There should be a standard API to get the current Context (which holds current task, event loop, and maybe more) and a standard meta-API to change how that standard API @@ -24,10 +38,6 @@ - See how much of asyncore I've already replaced. -- Write up a tutorial for the scheduling API. - -- Change block_r/w into COROUTINE style APIs. - - Do we need _async suffixes to all async APIs? - Do we need synchronous parallel APIs for all async APIs? @@ -53,8 +63,6 @@ FROM OLDER LIST - Multiple readers/writers per socket? (At which level? pollster, eventloop, or scheduler?) -- Should poll() return a list of tokens or a list of (fd, flag, token)? - - Could poll() usefully be an iterator? - Do we need to support more epoll and/or kqueue modes/flags/options/etc.? @@ -78,6 +86,9 @@ FROM OLDER LIST - Add simple non-async APIs, for simple apps? +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + MISTAKES I MADE From e5a1b8f6459ae88a313c44e9535cc6665ea76e71 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 16:04:21 -0700 Subject: [PATCH 0067/1502] Fix bad code in except handler. --- scheduling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scheduling.py b/scheduling.py index 34078f23..e809b33f 100644 --- a/scheduling.py +++ b/scheduling.py @@ -119,7 +119,7 @@ def step(self): self.exception = exc logging.debug('Uncaught exception in task %r', self.name, exc_info=True, stack_info=True) - except BaseException: + except BaseException as exc: self.alive = False self.exception = exc raise From f4e7eaec782c38be9fa8b01c288c106d345c5b8d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 16:08:00 -0700 Subject: [PATCH 0068/1502] Add an echo server demo. --- Makefile | 3 +++ TODO | 12 +++++++-- echosvr.py | 64 +++++++++++++++++++++++++++++++++++++++++++++++ main.py | 2 +- sockets.py | 73 ++++++++++++++++++++++++++++++++++++++++-------------- 5 files changed, 133 insertions(+), 21 deletions(-) create mode 100644 echosvr.py diff --git a/Makefile b/Makefile index 04117ce0..533e7427 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,9 @@ PYTHON=python3.3 test: $(PYTHON) main.py -v +echo: + $(PYTHON) echosvr.py -v + profile: $(PYTHON) -m profile -s time main.py diff --git a/TODO b/TODO index 7bced3f5..d6843db2 100644 --- a/TODO +++ b/TODO @@ -11,7 +11,7 @@ TO DO SMALLER TASKS TO DO LARGER TASKS -- Need more examples, e.g. simple client, simple server (echo). +- Need more examples. - Benchmarkable HTTP server? @@ -92,7 +92,7 @@ FROM OLDER LIST MISTAKES I MADE -- Forgetting yield from. (E.g.: scheduler.sleep(1).) +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) - Forgot to add bare yield at end of internal function, after block(). @@ -107,3 +107,11 @@ MISTAKES I MADE - Nasty race: eventloop.ready may contain both an I/O callback and a cancel callback. How to avoid? Keep the DelayedCall in ready. Is that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. diff --git a/echosvr.py b/echosvr.py new file mode 100644 index 00000000..060a764d --- /dev/null +++ b/echosvr.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + t.start() + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Start the task that starts the listener. + t = scheduling.Task(doit()) + t.start() + + # Run the main loop. + scheduling.run() + + +if __name__ == '__main__': + main() diff --git a/main.py b/main.py index f2dce8d9..c561d6cd 100644 --- a/main.py +++ b/main.py @@ -111,7 +111,7 @@ def doit(): task1 = scheduling.Task(urlfetch('localhost', 8080, path='/'), 'root', timeout=TIMEOUT) tasks.add(task1) - task2 = scheduling.Task(urlfetch('localhost', 8080, path='/home'), + task2 = scheduling.Task(urlfetch('127.0.0.1', 8080, path='/home'), 'home', timeout=TIMEOUT) tasks.add(task2) diff --git a/sockets.py b/sockets.py index e170d2a0..2d250625 100644 --- a/sockets.py +++ b/sockets.py @@ -25,7 +25,6 @@ # Stdlib imports. import errno -import re import socket import ssl @@ -272,21 +271,11 @@ def getaddrinfo(host, port, af=0, socktype=0, proto=0): return infos -def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM): +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): """COROUTINE: Look up address and create a socket connected to it.""" - match = re.match(r'(\d+)\.(\d+)\.(\d+)\.(\d+)\Z', host) - if match: - d1, d2, d3, d4 = map(int, match.groups()) - if not (0 <= d1 <= 255 and 0 <= d2 <= 255 and - 0 <= d3 <= 255 and 0 <= d4 <= 255): - match = None - if not match: - infos = yield from getaddrinfo(host, port, - af=af, socktype=socket.SOCK_STREAM) - else: - infos = [(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', - (host, port))] - assert infos, 'No address info for (%r, %r)' % (host, port) + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') exc = None for af, socktype, proto, cname, address in infos: sock = None @@ -301,8 +290,7 @@ def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM): if exc is None: exc = err else: - if exc is not None: - raise exc + raise exc return sock @@ -310,10 +298,59 @@ def create_transport(host, port, af=0, ssl=None): """COROUTINE: Look up address and create a transport connected to it.""" if ssl is None: ssl = (port == 443) - sock = yield from create_connection(host, port, af=af) + sock = yield from create_connection(host, port, af) if ssl: trans = SslTransport(sock) yield from trans.do_handshake() else: trans = SocketTransport(sock) return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + scheduling.block_r(self.sock.fileno()) + yield + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) From b594d2e97d10e2ab932efbff8ca9ef7b3f1da25b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 16:51:47 -0700 Subject: [PATCH 0069/1502] Make block_{r,w} COROUTINEs. --- TODO | 2 -- scheduling.py | 6 ++++-- sockets.py | 36 ++++++++++++------------------------ 3 files changed, 16 insertions(+), 28 deletions(-) diff --git a/TODO b/TODO index d6843db2..1eefde7f 100644 --- a/TODO +++ b/TODO @@ -4,8 +4,6 @@ TO DO SMALLER TASKS - Make pollster a sub-object instead of a superclass of the eventloop. -- Change block_r/w into COROUTINE style APIs. - - Should poll() return a list of tokens or a list of (fd, flag, token)? diff --git a/scheduling.py b/scheduling.py index e809b33f..c49976a3 100644 --- a/scheduling.py +++ b/scheduling.py @@ -198,13 +198,15 @@ def sleep(secs): def block_r(fd): - """Helper to call block_io() for reading.""" + """COROUTINE: Block until a file descriptor is ready for reading.""" context.current_task.block_io(fd, 'r') + yield def block_w(fd): - """Helper to call block_io() for writing.""" + """COROUTINE: Block until a file descriptor is ready for writing.""" context.current_task.block_io(fd, 'w') + yield def call_in_thread(func, *args, executor=None): diff --git a/sockets.py b/sockets.py index 2d250625..ebc85221 100644 --- a/sockets.py +++ b/sockets.py @@ -66,8 +66,7 @@ def recv(self, n): return self.sock.recv(n) except socket.error as err: if err.errno in _TRYAGAIN: - scheduling.block_r(self.sock.fileno()) - yield + yield from scheduling.block_r(self.sock.fileno()) elif err.errno in _DISCONNECTED: # Can this happen? return b'' @@ -84,8 +83,7 @@ def send(self, data): n = self.sock.send(data) except socket.error as err: if err.errno in _TRYAGAIN: - scheduling.block_w(self.sock.fileno()) - yield + yield from scheduling.block_w(self.sock.fileno()) elif err.errno in _DISCONNECTED: return False else: @@ -122,11 +120,9 @@ def do_handshake(self): try: self.sslsock.do_handshake() except ssl.SSLWantReadError: - scheduling.block_r(self.sslsock.fileno()) - yield + yield from scheduling.block_r(self.sslsock.fileno()) except ssl.SSLWantWriteError: - scheduling.block_w(self.sslsock.fileno()) - yield + yield from scheduling.block_w(self.sslsock.fileno()) else: break @@ -139,15 +135,12 @@ def recv(self, n): try: return self.sslsock.recv(n) except ssl.SSLWantReadError: - scheduling.block_r(self.sslsock.fileno()) - yield + yield from scheduling.block_r(self.sslsock.fileno()) except ssl.SSLWantWriteError: - scheduling.block_w(self.sslsock.fileno()) - yield + yield from scheduling.block_w(self.sslsock.fileno()) except socket.error as err: if err.errno in _TRYAGAIN: - scheduling.block_r(self.sock.fileno()) - yield + yield from scheduling.block_r(self.sock.fileno()) elif err.errno in _DISCONNECTED: # Can this happen? return b'' @@ -160,15 +153,12 @@ def send(self, data): try: n = self.sslsock.send(data) except ssl.SSLWantReadError: - scheduling.block_r(self.sslsock.fileno()) - yield + yield from scheduling.block_r(self.sslsock.fileno()) except ssl.SSLWantWriteError: - scheduling.block_w(self.sslsock.fileno()) - yield + yield from scheduling.block_w(self.sslsock.fileno()) except socket.error as err: if err.errno in _TRYAGAIN: - scheduling.block_w(self.sock.fileno()) - yield + yield from scheduling.block_w(self.sock.fileno()) elif err.errno in _DISCONNECTED: return False else: @@ -253,8 +243,7 @@ def connect(sock, address): except socket.error as err: if err.errno != errno.EINPROGRESS: raise - scheduling.block_w(sock.fileno()) - yield + yield from scheduling.block_w(sock.fileno()) err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: raise IOError(err, 'Connection refused') @@ -320,8 +309,7 @@ def accept(self): conn, addr = self.sock.accept() except socket.error as err: if err.errno in _TRYAGAIN: - scheduling.block_r(self.sock.fileno()) - yield + yield from scheduling.block_r(self.sock.fileno()) else: raise # Unexpected, propagate. else: From 4d99906857e2c965c4f3373c102d53f4db835739 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 17:35:16 -0700 Subject: [PATCH 0070/1502] Make EventLoop wrap a Pollster instead of inherit from it. --- TODO | 4 -- polling.py | 197 +++++++++++++++++++++++++---------------------------- 2 files changed, 93 insertions(+), 108 deletions(-) diff --git a/TODO b/TODO index 1eefde7f..43efc1ee 100644 --- a/TODO +++ b/TODO @@ -1,9 +1,5 @@ TO DO SMALLER TASKS -- Make pollster take an abstract token and return it. - -- Make pollster a sub-object instead of a superclass of the eventloop. - - Should poll() return a list of tokens or a list of (fd, flag, token)? diff --git a/polling.py b/polling.py index 9e8673e2..5c2f3717 100644 --- a/polling.py +++ b/polling.py @@ -2,8 +2,8 @@ The event loop can be broken up into a pollster (the part responsible for telling us when file descriptors are ready) and the event loop -proper, which adds functionality for scheduling callbacks, immediately -or at a given time in the future. +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. Whenever a public API takes a callback, subsequent positional arguments will be passed to the callback if/when it is called. This @@ -29,9 +29,8 @@ 2. poll 3. select - TODO: -- Optimize the various pollster. +- Optimize the various pollsters. - Unittests. """ @@ -48,18 +47,15 @@ class PollsterBase: """Base class for all polling implementations. This defines an interface to register and unregister readers and - writers (defined as a callback plus optional positional arguments) - for specific file descriptors, and an interface to get a list of - events. There's also an interface to check whether any readers or - writers are currently registered. The readers and writers - attributes are public -- they are simply mappings of file - descriptors to tuples of (callback, args). + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. """ def __init__(self): super().__init__() - self.readers = {} # {fd: , ...}. - self.writers = {} # {fd: , ...}. + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. def pollable(self): """Return True if any readers or writers are currently registered.""" @@ -67,23 +63,19 @@ def pollable(self): # Subclasses are expected to extend the add/remove methods. - def add_reader(self, fd, callback, *args): + def register_reader(self, fd, token): """Add or update a reader for a file descriptor.""" - dcall = DelayedCall(None, callback, args) - self.readers[fd] = dcall - return dcall + self.readers[fd] = token - def add_writer(self, fd, callback, *args): + def register_writer(self, fd, token): """Add or update a writer for a file descriptor.""" - dcall = DelayedCall(None, callback, args) - self.writers[fd] = dcall - return dcall + self.writers[fd] = token - def remove_reader(self, fd): + def unregister_reader(self, fd): """Remove the reader for a file descriptor.""" del self.readers[fd] - def remove_writer(self, fd): + def unregister_writer(self, fd): """Remove the writer for a file descriptor.""" del self.writers[fd] @@ -98,16 +90,15 @@ def poll(self, timeout=None): The return value is a list of events; it is empty when the timeout expired before any events were ready. Each event - is a tuple of the form (fd, flag, callback, args): + is a tuple of the form (fd, flag, token): fd: the file descriptor flag: 'r' or 'w' (to distinguish readers from writers) - callback: callback function - args: arguments tuple for callback + token: whatever you passed to register_reader/writer(). """ raise NotImplementedError -class SelectMixin(PollsterBase): +class SelectPollster(PollsterBase): """Pollster implementation using select.""" def poll(self, timeout=None): @@ -119,7 +110,7 @@ def poll(self, timeout=None): return events -class PollMixin(PollsterBase): +class PollPollster(PollsterBase): """Pollster implementation using poll.""" def __init__(self): @@ -138,22 +129,20 @@ def _update(self, fd): else: self._poll.unregister(fd) - def add_reader(self, fd, callback, *args): - dcall = super().add_reader(fd, callback, *args) + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) self._update(fd) - return dcall - def add_writer(self, fd, callback, *args): - dcall = super().add_writer(fd, callback, *args) + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) self._update(fd) - return dcall - def remove_reader(self, fd): - super().remove_reader(fd) + def unregister_reader(self, fd): + super().unregister_reader(fd) self._update(fd) - def remove_writer(self, fd): - super().remove_writer(fd) + def unregister_writer(self, fd): + super().unregister_writer(fd) self._update(fd) def poll(self, timeout=None): @@ -163,16 +152,14 @@ def poll(self, timeout=None): for fd, flags in self._poll.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: - dcall = self.readers[fd] - events.append((fd, 'r', dcall)) + events.append((fd, 'r', self.readers[fd])) if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: - dcall = self.writers[fd] - events.append((fd, 'w', dcall)) + events.append((fd, 'w', self.writers[fd])) return events -class EPollMixin(PollsterBase): +class EPollPollster(PollsterBase): """Pollster implementation using epoll.""" def __init__(self): @@ -194,22 +181,20 @@ def _update(self, fd): else: self._epoll.unregister(fd) - def add_reader(self, fd, callback, *args): - dcall = super().add_reader(fd, callback, *args) + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) self._update(fd) - return dcall - def add_writer(self, fd, callback, *args): - dcall = super().add_writer(fd, callback, *args) + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) self._update(fd) - return dcall - def remove_reader(self, fd): - super().remove_reader(fd) + def unregister_reader(self, fd): + super().unregister_reader(fd) self._update(fd) - def remove_writer(self, fd): - super().remove_writer(fd) + def unregister_writer(self, fd): + super().unregister_writer(fd) self._update(fd) def poll(self, timeout=None): @@ -219,41 +204,39 @@ def poll(self, timeout=None): for fd, eventmask in self._epoll.poll(timeout): if eventmask & select.EPOLLIN: if fd in self.readers: - dcall = self.readers[fd] - events.append((fd, 'r', dcall)) + events.append((fd, 'r', self.readers[fd])) if eventmask & select.EPOLLOUT: if fd in self.writers: - dcall = self.writers[fd] - events.append((fd, 'w', dcall)) + events.append((fd, 'w', self.writers[fd])) return events -class KqueueMixin(PollsterBase): +class KqueuePollster(PollsterBase): """Pollster implementation using kqueue.""" def __init__(self): super().__init__() self._kqueue = select.kqueue() - def add_reader(self, fd, callback, *args): + def register_reader(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().add_reader(fd, callback, *args) + return super().register_reader(fd, callback, *args) - def add_writer(self, fd, callback, *args): + def register_writer(self, fd, callback, *args): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().add_writer(fd, callback, *args) + return super().register_writer(fd, callback, *args) - def remove_reader(self, fd): - super().remove_reader(fd) + def unregister_reader(self, fd): + super().unregister_reader(fd) kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) - def remove_writer(self, fd): - super().remove_writer(fd) + def unregister_writer(self, fd): + super().unregister_writer(fd) kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) @@ -264,16 +247,25 @@ def poll(self, timeout=None): fd = kev.ident flag = kev.filter if flag == select.KQ_FILTER_READ and fd in self.readers: - dcall = self.readers[fd] - events.append((fd, 'r', dcall)) + events.append((fd, 'r', self.readers[fd])) elif flag == select.KQ_FILTER_WRITE and fd in self.writers: - dcall = self.writers[fd] - events.append((fd, 'w', dcall)) + events.append((fd, 'w', self.writers[fd])) return events +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + class DelayedCall: - """Object returned by call_soon/later(), add_reader/writer().""" + """Object returned by callback registration methods.""" def __init__(self, when, callback, args): self.when = when @@ -294,29 +286,45 @@ def __eq__(self, other): return self.when == other.when -class EventLoopMixin(PollsterBase): +class EventLoop: """Event loop functionality. - This is an abstract class, inheriting from the abstract class - PollsterBase. A concrete class can be formed trivially by - inheriting from any of the pollster mixin classes; the concrete - class EventLoop is such a concrete class using the preferred mixin - given the platform. - This defines public APIs call_soon(), call_later(), run_once() and - run(). It also inherits public APIs add_reader(), add_writer(), - remove_reader(), remove_writer() from the mixin class. The APIs - pollable() and poll(), implemented by the mix-in, are not part of - the public API. + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. This class's instance variables are not part of its API. """ - def __init__(self): + def __init__(self, pollster=None): super().__init__() + if pollster is None: + pollster = best_pollster() + self.pollster = pollster self.ready = collections.deque() # [(callback, args), ...] self.scheduled = [] # [(when, callback, args), ...] + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + def add_callback(self, dcall): """Add a DelayedCall to ready or scheduled.""" if dcall.cancelled: @@ -398,14 +406,14 @@ def run_once(self): heapq.heappop(self.scheduled) # Inspect the poll queue. - if self.pollable(): + if self.pollster.pollable(): if self.scheduled: when = self.scheduled[0].when timeout = max(0, when - time.time()) else: timeout = None t0 = time.time() - events = self.poll(timeout) + events = self.pollster.poll(timeout) t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout logging.debug('poll%s took %.3f seconds', argstr, t1-t0) @@ -427,29 +435,10 @@ def run(self): writable file descriptors, or scheduled callbacks (of either variety). """ - while self.ready or self.scheduled or self.pollable(): + while self.ready or self.scheduled or self.pollster.pollable(): self.run_once() -# Select the most appropriate base class for the platform. -if hasattr(select, 'kqueue'): - poll_base = KqueueMixin -elif hasattr(select, 'epoll'): - poll_base = EPollMixin -elif hasattr(select, 'poll'): - poll_base = PollMixin -else: - poll_base = SelectMixin - - -class EventLoop(EventLoopMixin, poll_base): - """Event loop implementation using the optimal pollster mixin.""" - - def __init__(self): - super().__init__() - logging.info('Using Pollster base class %r', poll_base.__name__) - - MAX_WORKERS = 5 # Default max workers when creating an executor. From 2ac001f2e8e1ac57cfe368ace774900042347f6d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Oct 2012 17:41:27 -0700 Subject: [PATCH 0071/1502] Change poll() to return just a list of tokens. Support keyword args through DelayedCall. --- TODO | 6 ++---- polling.py | 31 ++++++++++++++++--------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/TODO b/TODO index 43efc1ee..2b541dc1 100644 --- a/TODO +++ b/TODO @@ -1,6 +1,6 @@ TO DO SMALLER TASKS -- Should poll() return a list of tokens or a list of (fd, flag, token)? +- Echo client demo. TO DO LARGER TASKS @@ -16,9 +16,7 @@ TO DO LARGER TASKS TO DO LATER -- Do we need call_every()? - -- What to do about callbacks with keyword-only arguments? +- Do we need call_every()? (Easily emulated with a loop and sleep().) - Ensure multiple tasks can do atomic writes to the same pipe (since UNIX guarantees that short writes to pipes are atomic). diff --git a/polling.py b/polling.py index 5c2f3717..3897bf7b 100644 --- a/polling.py +++ b/polling.py @@ -90,10 +90,7 @@ def poll(self, timeout=None): The return value is a list of events; it is empty when the timeout expired before any events were ready. Each event - is a tuple of the form (fd, flag, token): - fd: the file descriptor - flag: 'r' or 'w' (to distinguish readers from writers) - token: whatever you passed to register_reader/writer(). + is a token previously passed to register_reader/writer(). """ raise NotImplementedError @@ -105,8 +102,8 @@ def poll(self, timeout=None): readable, writable, _ = select.select(self.readers, self.writers, [], timeout) events = [] - events += ((fd, 'r', self.readers[fd]) for fd in readable) - events += ((fd, 'w', self.writers[fd]) for fd in writable) + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) return events @@ -152,10 +149,10 @@ def poll(self, timeout=None): for fd, flags in self._poll.poll(msecs): if flags & (select.POLLIN | select.POLLHUP): if fd in self.readers: - events.append((fd, 'r', self.readers[fd])) + events.append(self.readers[fd]) if flags & (select.POLLOUT | select.POLLHUP): if fd in self.writers: - events.append((fd, 'w', self.writers[fd])) + events.append(self.writers[fd]) return events @@ -204,10 +201,10 @@ def poll(self, timeout=None): for fd, eventmask in self._epoll.poll(timeout): if eventmask & select.EPOLLIN: if fd in self.readers: - events.append((fd, 'r', self.readers[fd])) + events.append(self.readers[fd]) if eventmask & select.EPOLLOUT: if fd in self.writers: - events.append((fd, 'w', self.writers[fd])) + events.append(self.writers[fd]) return events @@ -247,9 +244,9 @@ def poll(self, timeout=None): fd = kev.ident flag = kev.filter if flag == select.KQ_FILTER_READ and fd in self.readers: - events.append((fd, 'r', self.readers[fd])) + events.append(self.readers[fd]) elif flag == select.KQ_FILTER_WRITE and fd in self.writers: - events.append((fd, 'w', self.writers[fd])) + events.append(self.writers[fd]) return events @@ -267,10 +264,11 @@ def poll(self, timeout=None): class DelayedCall: """Object returned by callback registration methods.""" - def __init__(self, when, callback, args): + def __init__(self, when, callback, args, kwds=None): self.when = when self.callback = callback self.args = args + self.kwds = kwds self.cancelled = False def cancel(self): @@ -396,7 +394,10 @@ def run_once(self): dcall = self.ready.popleft() if not dcall.cancelled: try: - dcall.callback(*dcall.args) + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) except Exception: logging.exception('Exception in callback %s %r', dcall.callback, dcall.args) @@ -417,7 +418,7 @@ def run_once(self): t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout logging.debug('poll%s took %.3f seconds', argstr, t1-t0) - for fd, flag, dcall in events: + for dcall in events: self.add_callback(dcall) # Handle 'later' callbacks that are ready. From 2a94688320c0f5d276b52a4699d9cc10594796ae Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 11:12:32 -0700 Subject: [PATCH 0072/1502] Add TODO items. --- TODO | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/TODO b/TODO index 2b541dc1..b5f68de6 100644 --- a/TODO +++ b/TODO @@ -1,5 +1,8 @@ TO DO SMALLER TASKS +- Move accept loop into Listener class. (Windows works better if you + make many AcceptEx() calls in parallel.) + - Echo client demo. @@ -16,6 +19,17 @@ TO DO LARGER TASKS TO DO LATER +- Wrap select(), epoll() etc. in try/except checking for EINTR. + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + - Do we need call_every()? (Easily emulated with a loop and sleep().) - Ensure multiple tasks can do atomic writes to the same pipe (since From bb5c8704b12fb21c06f37dda207704d0c9bc1bc8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 11:26:26 -0700 Subject: [PATCH 0073/1502] Housekeeping. --- TODO | 2 ++ yyftime.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/TODO b/TODO index b5f68de6..c1b30801 100644 --- a/TODO +++ b/TODO @@ -60,6 +60,8 @@ TO DO LATER OTOH see https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py +- Study goroutines (again). + FROM OLDER LIST diff --git a/yyftime.py b/yyftime.py index a0c2cad0..f55234b9 100644 --- a/yyftime.py +++ b/yyftime.py @@ -36,8 +36,8 @@ def wait(self): value = f.value except StopIteration as err: self.value = err.value - - + + def task(func): # Decorator def wrapper(*args): From b232f293020f2c3693af8f2c0028f0228952cbae Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 11:47:36 -0700 Subject: [PATCH 0074/1502] Move urlfetch() to a new file, http_client.py. --- http_client.py | 78 +++++++++++++++++++++++++++++++++++++++++++++ main.py | 85 ++++++-------------------------------------------- 2 files changed, 87 insertions(+), 76 deletions(-) create mode 100644 http_client.py diff --git a/http_client.py b/http_client.py new file mode 100644 index 00000000..a51ce310 --- /dev/null +++ b/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, method='GET', path='/', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/main.py b/main.py index c561d6cd..7f23752f 100644 --- a/main.py +++ b/main.py @@ -26,81 +26,13 @@ # Standard library imports (keep in alphabetic order). import logging import os -import re import time import socket import sys # Local imports (keep in alphabetic order). import scheduling -import sockets - - -def urlfetch(host, port=None, method='GET', path='/', - body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): - t0 = time.time() - if port is None: - if ssl: - port = 443 - else: - port = 80 - trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) - yield from trans.send(method.encode(encoding) + b' ' + - path.encode(encoding) + b' HTTP/1.0\r\n') - if hdrs: - kwds = dict(hdrs) - else: - kwds = {} - if 'host' not in kwds: - kwds['host'] = host - if body is not None: - kwds['content_length'] = len(body) - for header, value in kwds.items(): - yield from trans.send(header.replace('_', '-').encode(encoding) + - b': ' + value.encode(encoding) + b'\r\n') - - yield from trans.send(b'\r\n') - if body is not None: - yield from trans.send(body) - - # Read HTTP response line. - rdr = sockets.BufferedReader(trans) - resp = yield from rdr.readline() - m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', - resp) - if not m: - trans.close() - raise IOError('No valid HTTP response: %r' % resp) - http_version, status, message = m.groups() - - # Read HTTP headers. - headers = [] - hdict = {} - while True: - line = yield from rdr.readline() - if not line.strip(): - break - m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) - if not m: - raise IOError('Invalid header: %r' % line) - header, value = m.groups() - headers.append((header, value)) - hdict[header.decode(encoding).lower()] = value.decode(encoding) - - # Read response body. - content_length = hdict.get('content-length') - if content_length is not None: - size = int(content_length) # TODO: Catch errors. - assert size >= 0, size - else: - size = 2**20 # Protective limit (1 MB). - data = yield from rdr.readexactly(size) - trans.close() # Can this block? - t1 = time.time() - result = (host, port, path, int(status), len(data), round(t1-t0, 3)) -## print(result) - return result - +import http_client def doit(): TIMEOUT = 2 @@ -108,21 +40,22 @@ def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) - task1 = scheduling.Task(urlfetch('localhost', 8080, path='/'), + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), 'root', timeout=TIMEOUT) tasks.add(task1) - task2 = scheduling.Task(urlfetch('127.0.0.1', 8080, path='/home'), + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), 'home', timeout=TIMEOUT) tasks.add(task2) # Fetch python.org home page. - task3 = scheduling.Task(urlfetch('python.org', 80, path='/'), + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), 'python', timeout=TIMEOUT) tasks.add(task3) # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduling.Task(urlfetch('xkcd.com', ssl=True, path='/', - af=socket.AF_INET), + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), 'xkcd', timeout=TIMEOUT) tasks.add(task4) @@ -130,8 +63,8 @@ def doit(): ## for x in '123': ## for y in '0123456789': ## path = '/{}.{}'.format(x, y) -## g = urlfetch('82.94.164.162', 80, -## path=path, hdrs={'host': 'python.org'}) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) ## t = scheduling.Task(g, path, timeout=2) ## tasks.add(t) From 27a792134336ed1f0a6aa9925496e0ec002f9766 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 12:51:39 -0700 Subject: [PATCH 0075/1502] Add a convenience feature to scheduling.run(). --- main.py | 6 ++---- scheduling.py | 17 +++++++++++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 7f23752f..aff05a92 100644 --- a/main.py +++ b/main.py @@ -100,10 +100,8 @@ def main(): level = logging.WARN logging.basicConfig(level=level) - # Run doit() as a task. - task = scheduling.Task(doit(), timeout=2.1) - task.start() - scheduling.run() + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) if task.exception: print('Exception:', repr(task.exception)) else: diff --git a/scheduling.py b/scheduling.py index c49976a3..f6f48c0b 100644 --- a/scheduling.py +++ b/scheduling.py @@ -184,9 +184,22 @@ def wait(self): yield -def run(): - """Run the event loop until it's out of work.""" +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + If you pass a Task, it will be started for you. + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + t.start() context.eventloop.run() + return t def sleep(secs): From 9fe3bb54aec80d84b41968b90dd13c65a4bf588c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 13:13:26 -0700 Subject: [PATCH 0076/1502] Log an error if the startup task has an exception. --- scheduling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scheduling.py b/scheduling.py index f6f48c0b..9042a87a 100644 --- a/scheduling.py +++ b/scheduling.py @@ -199,6 +199,9 @@ def run(arg=None): t = Task(arg) t.start() context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) return t From 97e65e652a8b96a31260066e2a257d22b0a33992 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 13:13:51 -0700 Subject: [PATCH 0077/1502] Stupid http server to tweak benchmarks. --- http_server.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 http_server.py diff --git a/http_server.py b/http_server.py new file mode 100644 index 00000000..b387f0d6 --- /dev/null +++ b/http_server.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + t.start() + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() From f86846201971cddfbf00d3dcd5f1e90bfe9d1167 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 14:44:00 -0700 Subject: [PATCH 0078/1502] Add benchmarks link. --- TODO | 2 ++ 1 file changed, 2 insertions(+) diff --git a/TODO b/TODO index c1b30801..cdb50c19 100644 --- a/TODO +++ b/TODO @@ -62,6 +62,8 @@ TO DO LATER - Study goroutines (again). +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + FROM OLDER LIST From 120718795cb4fc4aca67ea8781c233f46952dd10 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 15:47:42 -0700 Subject: [PATCH 0079/1502] Move work out of except clauses, for better tracebacks. --- sockets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sockets.py b/sockets.py index ebc85221..6c1be4bb 100644 --- a/sockets.py +++ b/sockets.py @@ -66,12 +66,12 @@ def recv(self, n): return self.sock.recv(n) except socket.error as err: if err.errno in _TRYAGAIN: - yield from scheduling.block_r(self.sock.fileno()) + pass elif err.errno in _DISCONNECTED: - # Can this happen? return b'' else: raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) def send(self, data): """COROUTINE; Send data to the socket, blocking until all written. @@ -83,7 +83,7 @@ def send(self, data): n = self.sock.send(data) except socket.error as err: if err.errno in _TRYAGAIN: - yield from scheduling.block_w(self.sock.fileno()) + pass elif err.errno in _DISCONNECTED: return False else: @@ -93,6 +93,8 @@ def send(self, data): if n == len(data): break data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) return True From 363641051c8d3b2b34cda811a28e3b064bc12807 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 15:49:58 -0700 Subject: [PATCH 0080/1502] Subtle logging improvements. --- polling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/polling.py b/polling.py index 3897bf7b..8e160b3a 100644 --- a/polling.py +++ b/polling.py @@ -298,6 +298,7 @@ class EventLoop: def __init__(self, pollster=None): super().__init__() if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) pollster = best_pollster() self.pollster = pollster self.ready = collections.deque() # [(callback, args), ...] @@ -417,7 +418,11 @@ def run_once(self): events = self.pollster.poll(timeout) t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout - logging.debug('poll%s took %.3f seconds', argstr, t1-t0) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) for dcall in events: self.add_callback(dcall) From ffd2ec7f5fa7a3e422b6da7514c4596d5c3fab3c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Oct 2012 15:54:30 -0700 Subject: [PATCH 0081/1502] Add a simple echo client. (Though it has problems.) --- echoclt.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 echoclt.py diff --git a/echoclt.py b/echoclt.py new file mode 100644 index 00000000..782e21ae --- /dev/null +++ b/echoclt.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + trans = yield from sockets.create_transport(host, port, af=socket.AF_INET) + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + trans.close() + return response == testdata.upper() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + t.start() + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + yield from t.wait() + assert not t.alive + if t.result: + ok += 1 + else: + bad += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run the main loop. + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() From edb33eeebca6b5c2313c65f0303446739daa4ab6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 09:13:37 -0700 Subject: [PATCH 0082/1502] Raise listen() backlog to 100. --- echosvr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/echosvr.py b/echosvr.py index 060a764d..3c236288 100644 --- a/echosvr.py +++ b/echosvr.py @@ -30,7 +30,8 @@ def doit(): """COROUTINE: Set the wheels in motion.""" # Set up listener. listener = yield from sockets.create_listener('localhost', 1111, - af=socket.AF_INET) + af=socket.AF_INET, + backlog=100) logging.info('Listening on %r', listener.sock.getsockname()) # Loop accepting connections. From 4a54ec84c3dcc0a9ffbf097432111bf4139c0a78 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 09:14:55 -0700 Subject: [PATCH 0083/1502] Prettier Task repr(). --- scheduling.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/scheduling.py b/scheduling.py index 9042a87a..057613cb 100644 --- a/scheduling.py +++ b/scheduling.py @@ -93,8 +93,19 @@ def add_done_callback(self, done_callback): self.done_callbacks.append(done_callback) def __repr__(self): - return 'Task<%r, timeout=%s>(alive=%r, result=%r, exception=%r)' % ( - self.name, self.timeout, self.alive, self.result, self.exception) + parts = [self.name] + if self.alive: + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + else: + parts.append('running' if is_current else 'runnable') + else: + if self.exception: + parts.append('exception=%r' % self.exception) + else: + parts.append('result=%r' % self.result) + return 'Task<' + ', '.join(parts) + '>' def cancel(self): if self.alive: From 4191d55fe24c3b77f32189e4c4adc2adf0cb31ee Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 09:16:19 -0700 Subject: [PATCH 0084/1502] Improvements. --- echoclt.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/echoclt.py b/echoclt.py index 782e21ae..32c2b87c 100644 --- a/echoclt.py +++ b/echoclt.py @@ -15,12 +15,19 @@ def echoclient(host, port): """COROUTINE""" testdata = b'hi hi hi ha ha ha\n' - trans = yield from sockets.create_transport(host, port, af=socket.AF_INET) - ok = yield from trans.send(testdata) - if ok: - response = yield from trans.recv(100) - trans.close() - return response == testdata.upper() + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() def doit(n): From 45882e471ea3556ed99981f0d09edbbde54176a4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 09:16:32 -0700 Subject: [PATCH 0085/1502] Cleanup TODO. --- TODO | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/TODO b/TODO index cdb50c19..8444c2b0 100644 --- a/TODO +++ b/TODO @@ -1,16 +1,13 @@ TO DO SMALLER TASKS -- Move accept loop into Listener class. (Windows works better if you - make many AcceptEx() calls in parallel.) - -- Echo client demo. +- Make Task more like Future; getting result() should re-raise exception. TO DO LARGER TASKS - Need more examples. -- Benchmarkable HTTP server? +- Benchmarkable but more realistic HTTP server? - Example of using UDP. @@ -21,6 +18,10 @@ TO DO LATER - Wrap select(), epoll() etc. in try/except checking for EINTR. +- Move accept loop into Listener class? (Windows is said to work + better if you make many AcceptEx() calls in parallel.) OTOH we can + already accept many incoming connections without suspending. + - When multiple tasks are accessing the same socket, they should either get interleaved I/O or an immediate exception; it should not compromise the integrity of the scheduler or the app or leave a task From eab0bafc76eadc7bb41d616098611baeba0fa479 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 13:48:43 -0700 Subject: [PATCH 0086/1502] Fix race condition in call_in_thread(). --- TODO | 12 ++++++++++++ scheduling.py | 54 +++++++++++++++++++++++++++++---------------------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/TODO b/TODO index 8444c2b0..163f5d34 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,5 @@ +# -*- Mode: text -*- + TO DO SMALLER TASKS - Make Task more like Future; getting result() should re-raise exception. @@ -13,6 +15,9 @@ TO DO LARGER TASKS - Write up a tutorial for the scheduling API. +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + TO DO LATER @@ -126,3 +131,10 @@ MISTAKES I MADE - Forgot to set the connection socket returned by accept() in nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. diff --git a/scheduling.py b/scheduling.py index 057613cb..48481655 100644 --- a/scheduling.py +++ b/scheduling.py @@ -94,17 +94,21 @@ def add_done_callback(self, done_callback): def __repr__(self): parts = [self.name] - if self.alive: - is_current = (self is context.current_task) - if self.blocked: - parts.append('blocking' if is_current else 'blocked') - else: - parts.append('running' if is_current else 'runnable') - else: - if self.exception: - parts.append('exception=%r' % self.exception) - else: - parts.append('result=%r' % self.result) + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % self.result) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) return 'Task<' + ', '.join(parts) + '>' def cancel(self): @@ -113,7 +117,7 @@ def cancel(self): self.unblock() def step(self): - assert self.alive + assert self.alive, self try: context.current_task = self if self.must_cancel: @@ -128,7 +132,7 @@ def step(self): except Exception as exc: self.alive = False self.exception = exc - logging.debug('Uncaught exception in task %r', self.name, + logging.debug('Uncaught exception in %s', self, exc_info=True, stack_info=True) except BaseException as exc: self.alive = False @@ -148,20 +152,20 @@ def step(self): self.eventloop.call_soon(callback, self) def start(self): - assert self.alive + assert self.alive, self self.eventloop.call_soon(self.step) def block(self, unblock_callback=None, *unblock_args): - assert self is context.current_task - assert self.alive - assert not self.blocked + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self self.blocked = True self.unblocker = (unblock_callback, unblock_args) def unblock(self, unused=None): # Ignore optional argument so we can be a Future's done_callback. - assert self.alive - assert self.blocked + assert self.alive, self + assert self.blocked, self self.blocked = False unblock_callback, unblock_args = self.unblocker if unblock_callback is not None: @@ -187,7 +191,7 @@ def block_io(self, fd, flag): def wait(self): """COROUTINE: Wait until this task is finished.""" current_task = context.current_task - assert self is not current_task # How confusing! + assert self is not current_task, (self, current_task) # How confusing! if not self.alive: return current_task.block() @@ -238,11 +242,15 @@ def block_w(fd): def call_in_thread(func, *args, executor=None): """COROUTINE: Run a function in a thread.""" - # TODO: Prove there is no race condition here. - future = context.threadrunner.submit(func, *args, executor=executor) task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, executor=executor) task.block(future.cancel) - future.add_done_callback(task.unblock) + # If the thread managed to complete before we get here, + # add_done_callback() will call the callback right now. Make sure + # the unblock() call doesn't happen until later. + # TODO: Make unblock() robust so this doesn't hurt? + future.add_done_callback(lambda _: eventloop.call_soon(task.unblock)) yield assert future.done() return future.result() From a89f73ae2157c02922245aeeab82717ac3b7954b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 31 Oct 2012 13:54:50 -0700 Subject: [PATCH 0087/1502] Housecleaning; added another benchmark. --- README | 9 ++++++--- TODO | 2 ++ tulip_bench.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) create mode 100644 tulip_bench.py diff --git a/README b/README index 1339d065..abfad624 100644 --- a/README +++ b/README @@ -13,10 +13,11 @@ http://www.cosc.canterbury.ac.nz/greg.ewing/python/generators/yf_current/Example A message I posted with some explanation of the design: http://mail.python.org/pipermail/python-ideas/2012-October/017501.html -Essential files here: +Essential files here (in top-to-bottom ordering): -- main.py: the main program for testing, and a rough HTTP client -- sockets.py: transports for sockets and SSL, and a buffering layer +- main.py: the main program for testing +- http_client.py: a rough HTTP/1.0 client +- sockets.py: transports for sockets and (client) SSL, and a buffering layer - scheduling.py: a Task class and related stuff; this is where the PEP 380 scheduler is implemented - polling.py: an event loop and basic polling implementations for: @@ -28,8 +29,10 @@ Secondary files: - Makefile: various quick shell commands - README: this file - TODO: longer list of TODO items and general thoughts +- http_server.py: enough of an HTTP server to point 'ab' at - longlines.py: stupid style checker - p3time.py: benchmark yield from vs. plain functions +- tulip_bench.py: yet another benchmark (like p3time.py and yyftime.py) - xkcd.py: *synchronous* ssl example - yyftime.py: benchmark yield from vs. yield diff --git a/TODO b/TODO index 163f5d34..1f2d1f07 100644 --- a/TODO +++ b/TODO @@ -18,6 +18,8 @@ TO DO LARGER TASKS - More systematic approach to logging. Logger objects? What about heavy-duty logging, tracking essentially all task state changes? +- Restructure directory, move demos and benchmarks to subdirectories. + TO DO LATER diff --git a/tulip_bench.py b/tulip_bench.py new file mode 100644 index 00000000..b55b8495 --- /dev/null +++ b/tulip_bench.py @@ -0,0 +1,31 @@ +'''Example app using `file_async` and cancellations.''' + +__author__ = 'Guido van Rossum ' + +import time + +import scheduling + +def binary(n): + if n <= 0: + return 1 + l = yield from binary(n-1) + r = yield from binary(n-1) + return l + 1 + r + +def doit(depth): + t0 = time.time() + k = yield from binary(depth) + t1 = time.time() + print(depth, k, round(t1-t0, 6)) + return (depth, k, round(t1-t0, 6)) + +def main(): + for depth in range(20): + yield from doit(depth) + +import logging +logging.basicConfig(level=logging.DEBUG) + +scheduling.Task(main()).start() +scheduling.run() From dd1beadc18657b77bad3e33af8fa16ea11502843 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 1 Nov 2012 09:39:48 -0700 Subject: [PATCH 0088/1502] Get rid of start(). --- TODO | 10 ++++++---- echoclt.py | 1 - echosvr.py | 2 -- http_server.py | 1 - main.py | 2 -- scheduling.py | 10 +++------- tulip_bench.py | 2 +- 7 files changed, 10 insertions(+), 18 deletions(-) diff --git a/TODO b/TODO index 1f2d1f07..ab7b53f3 100644 --- a/TODO +++ b/TODO @@ -4,6 +4,12 @@ TO DO SMALLER TASKS - Make Task more like Future; getting result() should re-raise exception. +- Add a decorator just for documenting a coroutine. It should set a + flag on the function. It should not interfere with methods, + staticmethod, classmethod and the like. The Task constructor should + check the flag. The decorator should notice if the wrapped function + is not a generator. + TO DO LARGER TASKS @@ -56,10 +62,6 @@ TO DO LATER - Do we need synchronous parallel APIs for all async APIs? -- Add a decorator just for documenting a coroutine. It should set a - flag on the function. It should not interfere with methods, - staticmethod, classmethod and the like. - - Could BufferedReader reuse the standard io module's readers??? - Support ZeroMQ "sockets" which are user objects. Though possibly diff --git a/echoclt.py b/echoclt.py index 32c2b87c..a46e86f9 100644 --- a/echoclt.py +++ b/echoclt.py @@ -36,7 +36,6 @@ def doit(n): tasks = set() for i in range(n): t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) - t.start() tasks.add(t) ok = 0 bad = 0 diff --git a/echosvr.py b/echosvr.py index 3c236288..4756daed 100644 --- a/echosvr.py +++ b/echosvr.py @@ -38,7 +38,6 @@ def doit(): while True: conn, addr = yield from listener.accept() t = scheduling.Task(handler(conn, addr)) - t.start() def main(): @@ -55,7 +54,6 @@ def main(): # Start the task that starts the listener. t = scheduling.Task(doit()) - t.start() # Run the main loop. scheduling.run() diff --git a/http_server.py b/http_server.py index b387f0d6..2b1e3dd6 100644 --- a/http_server.py +++ b/http_server.py @@ -46,7 +46,6 @@ def doit(): while True: conn, addr = yield from listener.accept() t = scheduling.Task(handler(conn, addr)) - t.start() def main(): diff --git a/main.py b/main.py index aff05a92..4bcc59a9 100644 --- a/main.py +++ b/main.py @@ -69,8 +69,6 @@ def doit(): ## tasks.add(t) ## print(tasks) - for t in tasks: - t.start() yield from scheduling.with_timeout(0.2, scheduling.sleep(1)) winners = yield from scheduling.wait_any(tasks) print('And the winners are:', [w.name for w in winners]) diff --git a/scheduling.py b/scheduling.py index 48481655..463cb499 100644 --- a/scheduling.py +++ b/scheduling.py @@ -86,6 +86,8 @@ def __init__(self, gen, name=None, *, timeout=None): self.result = None self.exception = None self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) def add_done_callback(self, done_callback): # For better or for worse, the callback will always be called @@ -151,10 +153,6 @@ def step(self): for callback in self.done_callbacks: self.eventloop.call_soon(callback, self) - def start(self): - assert self.alive, self - self.eventloop.call_soon(self.step) - def block(self, unblock_callback=None, *unblock_args): assert self is context.current_task, self assert self.alive, self @@ -203,7 +201,7 @@ def run(arg=None): """Run the event loop until it's out of work. If you pass a generator, it will be spawned for you. - If you pass a Task, it will be started for you. + You can also pass a task (already started). Returns the task. """ t = None @@ -212,7 +210,6 @@ def run(arg=None): t = arg else: t = Task(arg) - t.start() context.eventloop.run() if t is not None and t.exception is not None: logging.error('Uncaught exception in startup task: %r', @@ -304,5 +301,4 @@ def with_timeout(timeout, gen, name=None): """COROUTINE: Run generator synchronously with a timeout.""" assert timeout is not None task = Task(gen, name, timeout=timeout) - task.start() return (yield from task.wait()) diff --git a/tulip_bench.py b/tulip_bench.py index b55b8495..0dd8ad34 100644 --- a/tulip_bench.py +++ b/tulip_bench.py @@ -27,5 +27,5 @@ def main(): import logging logging.basicConfig(level=logging.DEBUG) -scheduling.Task(main()).start() +scheduling.Task(main()) scheduling.run() From 6dd0ecd84f21a7816788b181bf6e0e3429456a27 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 1 Nov 2012 11:59:47 -0700 Subject: [PATCH 0089/1502] Put path before method argument. --- http_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/http_client.py b/http_client.py index a51ce310..8937ba20 100644 --- a/http_client.py +++ b/http_client.py @@ -11,7 +11,7 @@ import sockets -def urlfetch(host, port=None, method='GET', path='/', +def urlfetch(host, port=None, path='/', method='GET', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): """COROUTINE: Make an HTTP 1.0 request.""" t0 = time.time() From b8a0bf2407bc8277eee748d0009bd58181e2d743 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 1 Nov 2012 12:01:17 -0700 Subject: [PATCH 0090/1502] Add map_over() operation. Improve cancel(). --- main.py | 18 ++++++++++++++++++ scheduling.py | 46 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index 4bcc59a9..99de19c7 100644 --- a/main.py +++ b/main.py @@ -34,6 +34,22 @@ import scheduling import http_client + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + def doit(): TIMEOUT = 2 tasks = set() @@ -102,6 +118,8 @@ def main(): task = scheduling.run(doit()) if task.exception: print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception else: for t in task.result: print(t.name + ':', diff --git a/scheduling.py b/scheduling.py index 463cb499..069bcf98 100644 --- a/scheduling.py +++ b/scheduling.py @@ -115,8 +115,10 @@ def __repr__(self): def cancel(self): if self.alive: - self.must_cancel = True - self.unblock() + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() def step(self): assert self.alive, self @@ -160,6 +162,10 @@ def block(self, unblock_callback=None, *unblock_args): self.blocked = True self.unblocker = (unblock_callback, unblock_args) + def unblock_if_alive(self): + if self.alive: + self.unblock() + def unblock(self, unused=None): # Ignore optional argument so we can be a Future's done_callback. assert self.alive, self @@ -245,9 +251,13 @@ def call_in_thread(func, *args, executor=None): task.block(future.cancel) # If the thread managed to complete before we get here, # add_done_callback() will call the callback right now. Make sure - # the unblock() call doesn't happen until later. - # TODO: Make unblock() robust so this doesn't hurt? - future.add_done_callback(lambda _: eventloop.call_soon(task.unblock)) + # the unblock() call doesn't happen until later. But then, the + # task may already have been cancelled (and it may have been too + # late to cancel the Future) so it should be okay if this call + # finds the task deceased. For that purpose we have + # unblock_if_alive(). + future.add_done_callback( + lambda _: eventloop.call_soon(task.unblock_if_alive)) yield assert future.done() return future.result() @@ -302,3 +312,29 @@ def with_timeout(timeout, gen, name=None): assert timeout is not None task = Task(gen, name, timeout=timeout) return (yield from task.wait()) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + todo = set(tasks) + while todo: + ts = yield from wait_for(1, todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] From fa2eae5e0521fa42bbe39362fe5fba8c03723a56 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 2 Nov 2012 09:18:22 -0700 Subject: [PATCH 0091/1502] Make sure something is printed if main task fails. --- .hgignore | 6 + Makefile | 19 ++ README | 43 +++++ TODO | 144 ++++++++++++++ echoclt.py | 79 ++++++++ echosvr.py | 60 ++++++ http_client.py | 78 ++++++++ http_server.py | 68 +++++++ longlines.py | 40 ++++ main.py | 134 +++++++++++++ p3time.py | 47 +++++ polling.py | 497 +++++++++++++++++++++++++++++++++++++++++++++++++ scheduling.py | 340 +++++++++++++++++++++++++++++++++ sockets.py | 346 ++++++++++++++++++++++++++++++++++ tulip_bench.py | 30 +++ xkcd.py | 18 ++ yyftime.py | 75 ++++++++ 17 files changed, 2024 insertions(+) create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 README create mode 100644 TODO create mode 100644 echoclt.py create mode 100644 echosvr.py create mode 100644 http_client.py create mode 100644 http_server.py create mode 100644 longlines.py create mode 100644 main.py create mode 100644 p3time.py create mode 100644 polling.py create mode 100644 scheduling.py create mode 100644 sockets.py create mode 100644 tulip_bench.py create mode 100755 xkcd.py create mode 100644 yyftime.py diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..e3600e75 --- /dev/null +++ b/.hgignore @@ -0,0 +1,6 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..533e7427 --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +PYTHON=python3.3 + +test: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py + +check: + $(PYTHON) longlines.py diff --git a/README b/README new file mode 100644 index 00000000..abfad624 --- /dev/null +++ b/README @@ -0,0 +1,43 @@ +Tulip is the codename for my attempt at understanding PEP-380 style +coroutines (i.e. those using generators and 'yield from'). + +*** This requires Python 3.3 or later! *** + +For reference, see many threads in python-ideas@python.org started in +October 2012, especially those with "The async API of the Future" in +their subject, and the various spin-off threads. + +A particularly influential tutorial by Greg Ewing: +http://www.cosc.canterbury.ac.nz/greg.ewing/python/generators/yf_current/Examples/Scheduler/scheduler.txt + +A message I posted with some explanation of the design: +http://mail.python.org/pipermail/python-ideas/2012-October/017501.html + +Essential files here (in top-to-bottom ordering): + +- main.py: the main program for testing +- http_client.py: a rough HTTP/1.0 client +- sockets.py: transports for sockets and (client) SSL, and a buffering layer +- scheduling.py: a Task class and related stuff; this is where the PEP + 380 scheduler is implemented +- polling.py: an event loop and basic polling implementations for: + select(), poll(), epoll(), kqueue() + +Secondary files: + +- .hgignore: files I don't care about +- Makefile: various quick shell commands +- README: this file +- TODO: longer list of TODO items and general thoughts +- http_server.py: enough of an HTTP server to point 'ab' at +- longlines.py: stupid style checker +- p3time.py: benchmark yield from vs. plain functions +- tulip_bench.py: yet another benchmark (like p3time.py and yyftime.py) +- xkcd.py: *synchronous* ssl example +- yyftime.py: benchmark yield from vs. yield + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..ab7b53f3 --- /dev/null +++ b/TODO @@ -0,0 +1,144 @@ +# -*- Mode: text -*- + +TO DO SMALLER TASKS + +- Make Task more like Future; getting result() should re-raise exception. + +- Add a decorator just for documenting a coroutine. It should set a + flag on the function. It should not interfere with methods, + staticmethod, classmethod and the like. The Task constructor should + check the flag. The decorator should notice if the wrapped function + is not a generator. + + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- Wrap select(), epoll() etc. in try/except checking for EINTR. + +- Move accept loop into Listener class? (Windows is said to work + better if you make many AcceptEx() calls in parallel.) OTOH we can + already accept many incoming connections without suspending. + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Do we need call_every()? (Easily emulated with a loop and sleep().) + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Do we need _async suffixes to all async APIs? + +- Do we need synchronous parallel APIs for all async APIs? + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Is it better to have separate add_{reader,writer} methods, vs. one + add_thingie method taking a fd and a r/w flag? + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Should block() use a queue? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Global functions vs. Task methods? + +- Is the Task design good? + +- Make Task more like Future? (Or less???) + +- Implement various lock styles a la threading.py. + +- Add write() calls that don't require yield from. + +- Add simple non-async APIs, for simple apps? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. diff --git a/echoclt.py b/echoclt.py new file mode 100644 index 00000000..368856bc --- /dev/null +++ b/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + yield from t.wait() + assert not t.alive + if t.result: + ok += 1 + else: + bad += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/echosvr.py b/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/http_client.py b/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/http_server.py b/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/longlines.py b/longlines.py new file mode 100644 index 00000000..f0aa9a66 --- /dev/null +++ b/longlines.py @@ -0,0 +1,40 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/main.py b/main.py new file mode 100644 index 00000000..99de19c7 --- /dev/null +++ b/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.with_timeout(0.2, scheduling.sleep(1)) + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/p3time.py b/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/polling.py b/polling.py new file mode 100644 index 00000000..8e160b3a --- /dev/null +++ b/polling.py @@ -0,0 +1,497 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (an int of float in seconds) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > time.time(): + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather add a callback to it. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(future): + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future diff --git a/scheduling.py b/scheduling.py new file mode 100644 index 00000000..069bcf98 --- /dev/null +++ b/scheduling.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import threading +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = polling.EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = polling.ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + self.done_callbacks.append(done_callback) + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % self.result) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for callback in self.done_callbacks: + self.eventloop.call_soon(callback, self) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self): + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, executor=executor) + task.block(future.cancel) + # If the thread managed to complete before we get here, + # add_done_callback() will call the callback right now. Make sure + # the unblock() call doesn't happen until later. But then, the + # task may already have been cancelled (and it may have been too + # late to cancel the Future) so it should be okay if this call + # finds the task deceased. For that purpose we have + # unblock_if_alive(). + future.add_done_callback( + lambda _: eventloop.call_soon(task.unblock_if_alive)) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + def wait_for_callback(task): + nonlocal todo, done, current_task, count + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + task.add_done_callback(wait_for_callback) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def with_timeout(timeout, gen, name=None): + """COROUTINE: Run generator synchronously with a timeout.""" + assert timeout is not None + task = Task(gen, name, timeout=timeout) + return (yield from task.wait()) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + todo = set(tasks) + while todo: + ts = yield from wait_for(1, todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/sockets.py b/sockets.py new file mode 100644 index 00000000..6c1be4bb --- /dev/null +++ b/sockets.py @@ -0,0 +1,346 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while n > count: + block = yield from self.read(n - count) + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/tulip_bench.py b/tulip_bench.py new file mode 100644 index 00000000..64937881 --- /dev/null +++ b/tulip_bench.py @@ -0,0 +1,30 @@ +'''Example app using `file_async` and cancellations.''' + +__author__ = 'Guido van Rossum ' + +import time + +import scheduling + +def binary(n): + if n <= 0: + return 1 + l = yield from binary(n-1) + r = yield from binary(n-1) + return l + 1 + r + +def doit(depth): + t0 = time.time() + k = yield from binary(depth) + t1 = time.time() + print(depth, k, round(t1-t0, 6)) + return (depth, k, round(t1-t0, 6)) + +def main(): + for depth in range(20): + yield from doit(depth) + +import logging +logging.basicConfig(level=logging.DEBUG) + +scheduling.run(main()) diff --git a/xkcd.py b/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/yyftime.py b/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() From ad1bf5a845a5d0b80953bba5ba215e7a63dc9678 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 2 Nov 2012 17:44:10 -0700 Subject: [PATCH 0092/1502] Make Task.add_done_callback() return a DelayedCall, and use it. --- scheduling.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/scheduling.py b/scheduling.py index 069bcf98..3c468b52 100644 --- a/scheduling.py +++ b/scheduling.py @@ -92,7 +92,11 @@ def __init__(self, gen, name=None, *, timeout=None): def add_done_callback(self, done_callback): # For better or for worse, the callback will always be called # with the task as an argument, like concurrent.futures.Future. - self.done_callbacks.append(done_callback) + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall def __repr__(self): parts = [self.name] @@ -105,10 +109,10 @@ def __repr__(self): parts.append('must_cancel') if self.cancelled: parts.append('cancelled') - if self.exception: + if self.exception is not None: parts.append('exception=%r' % self.exception) elif not self.alive: - parts.append('result=%r' % self.result) + parts.append('result=%r' % (self.result,)) if self.timeout is not None: parts.append('timeout=%.3f' % self.timeout) return 'Task<' + ', '.join(parts) + '>' @@ -152,8 +156,8 @@ def step(self): if self.canceleer is not None: self.canceleer.cancel() # Schedule done_callbacks. - for callback in self.done_callbacks: - self.eventloop.call_soon(callback, self) + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) def block(self, unblock_callback=None, *unblock_args): assert self is context.current_task, self @@ -277,12 +281,15 @@ def wait_for(count, tasks): assert all(task is not current_task for task in tasks) todo = set() done = set() + dcalls = [] def wait_for_callback(task): - nonlocal todo, done, current_task, count + nonlocal todo, done, current_task, count, dcalls todo.remove(task) if len(done) < count: done.add(task) if len(done) == count: + for dcall in dcalls: + dcall.cancel() current_task.unblock() for task in tasks: if task.alive: @@ -291,7 +298,8 @@ def wait_for_callback(task): done.add(task) if len(done) < count: for task in todo: - task.add_done_callback(wait_for_callback) + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) current_task.block() yield return done From 1e88c7c3b1615b54e400ef57936b2edc708ea81a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 2 Nov 2012 17:45:24 -0700 Subject: [PATCH 0093/1502] Add TODO to add_done_call(). --- scheduling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scheduling.py b/scheduling.py index 3c468b52..dc387b68 100644 --- a/scheduling.py +++ b/scheduling.py @@ -92,6 +92,7 @@ def __init__(self, gen, name=None, *, timeout=None): def add_done_callback(self, done_callback): # For better or for worse, the callback will always be called # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. dcall = polling.DelayedCall(None, done_callback, (self,)) self.done_callbacks.append(dcall) self.done_callbacks = [dc for dc in self.done_callbacks From 85982033cacac20a72869660363def37c9bd611a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 3 Nov 2012 13:35:32 -0700 Subject: [PATCH 0094/1502] Make it so you can just yield a Task to wait for its result (or exception). --- echoclt.py | 10 +++++----- main.py | 2 +- scheduling.py | 27 ++++++++++++++++++++------- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/echoclt.py b/echoclt.py index 368856bc..c24c573e 100644 --- a/echoclt.py +++ b/echoclt.py @@ -40,12 +40,12 @@ def doit(n): ok = 0 bad = 0 for t in tasks: - yield from t.wait() - assert not t.alive - if t.result: - ok += 1 - else: + try: + yield from t + except Exception: bad += 1 + else: + ok += 1 t1 = time.time() print('ok: ', ok) print('bad:', bad) diff --git a/main.py b/main.py index 99de19c7..c1f9d0a8 100644 --- a/main.py +++ b/main.py @@ -85,7 +85,7 @@ def doit(): ## tasks.add(t) ## print(tasks) - yield from scheduling.with_timeout(0.2, scheduling.sleep(1)) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() winners = yield from scheduling.wait_any(tasks) print('And the winners are:', [w.name for w in winners]) tasks = yield from scheduling.wait_all(tasks) diff --git a/scheduling.py b/scheduling.py index dc387b68..6fa01d08 100644 --- a/scheduling.py +++ b/scheduling.py @@ -207,6 +207,26 @@ def wait(self): self.add_done_callback(current_task.unblock) yield + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + def run(arg=None): """Run the event loop until it's out of work. @@ -316,13 +336,6 @@ def wait_all(tasks): return wait_for(len(tasks), tasks) -def with_timeout(timeout, gen, name=None): - """COROUTINE: Run generator synchronously with a timeout.""" - assert timeout is not None - task = Task(gen, name, timeout=timeout) - return (yield from task.wait()) - - def map_over(gen, *args, timeout=None): """COROUTINE: map a generator over one or more iterables. From 84b59833a6efd4c2694a14560ab46b00cb071fef Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 3 Nov 2012 13:54:14 -0700 Subject: [PATCH 0095/1502] Add par(). --- scheduling.py | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/scheduling.py b/scheduling.py index 6fa01d08..ce649f5e 100644 --- a/scheduling.py +++ b/scheduling.py @@ -296,6 +296,7 @@ def wait_for(count, tasks): NOTE: Tasks that were cancelled or raised are also considered ready. """ assert tasks + assert all(isinstance(task, Task) for task in tasks) tasks = set(tasks) assert 1 <= count <= len(tasks) current_task = context.current_task @@ -341,7 +342,7 @@ def map_over(gen, *args, timeout=None): E.g. map_over(foo, xs, ys) runs - Task(foo(x, y) for x, y in zip(xs, ys) + Task(foo(x, y)) for x, y in zip(xs, ys) and returns a list of all results (in that order). However if any task raises an exception, the remaining tasks are cancelled and @@ -349,9 +350,37 @@ def map_over(gen, *args, timeout=None): """ # gen is a generator function. tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ todo = set(tasks) while todo: - ts = yield from wait_for(1, todo) + ts = yield from wait_any(todo) for t in ts: assert not t.alive, t todo.remove(t) From 2cf437b13b447b4773e64819d5488d02c602761d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 10 Nov 2012 16:21:21 -0800 Subject: [PATCH 0096/1502] readexactly() should stop on EOF. --- sockets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sockets.py b/sockets.py index 6c1be4bb..13eeaac0 100644 --- a/sockets.py +++ b/sockets.py @@ -201,8 +201,10 @@ def readexactly(self, n): """COUROUTINE: Read exactly n bytes, or until EOF.""" blocks = [] count = 0 - while n > count: + while n > count and not self.eof: block = yield from self.read(n - count) + if not block: + break blocks.append(block) count += len(block) return b''.join(blocks) From fe55aa864c91e7a5b4f8868496b645949f34b6b4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 10 Nov 2012 17:03:35 -0800 Subject: [PATCH 0097/1502] Fix bug in previous fix. --- sockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sockets.py b/sockets.py index 13eeaac0..4d64ee46 100644 --- a/sockets.py +++ b/sockets.py @@ -201,7 +201,7 @@ def readexactly(self, n): """COUROUTINE: Read exactly n bytes, or until EOF.""" blocks = [] count = 0 - while n > count and not self.eof: + while count < n: block = yield from self.read(n - count) if not block: break From 53fb18cb69bcc0d279a563303c2a41d9bbbe6849 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 10 Nov 2012 20:37:29 -0800 Subject: [PATCH 0098/1502] Add another mistake. --- TODO | 3 +++ 1 file changed, 3 insertions(+) diff --git a/TODO b/TODO index ab7b53f3..52be4ec3 100644 --- a/TODO +++ b/TODO @@ -142,3 +142,6 @@ MISTAKES I MADE add_done_callback() was called, and this screwed over the task state. Solution: wrap the callback in eventloop.call_later(). Ironically, I had a comment stating there might be a race condition. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) From 1d35c3523bb30d1242c36a39682d8cffe2cffbe7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 10 Nov 2012 20:40:03 -0800 Subject: [PATCH 0099/1502] Add another mistake. --- TODO | 3 +++ 1 file changed, 3 insertions(+) diff --git a/TODO b/TODO index 52be4ec3..eb3f9bbb 100644 --- a/TODO +++ b/TODO @@ -143,5 +143,8 @@ MISTAKES I MADE state. Solution: wrap the callback in eventloop.call_later(). Ironically, I had a comment stating there might be a race condition. +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + - readexactly() wasn't checking for EOF, so could be looping. (Worse, the first fix I attempted was wrong.) From 051d9cab1270e176738b30555421f5e8371e6618 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 11 Nov 2012 09:47:39 -0800 Subject: [PATCH 0100/1502] Post-mortem of yesterday's disaster. --- Makefile | 2 +- TODO | 14 ++++++++ http_client.py | 21 +++++++----- main.py | 26 ++++++++------- polling.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++-- scheduling.py | 61 ++++++++++++++++++++++++++++++--- sockets.py | 12 +------ 7 files changed, 187 insertions(+), 40 deletions(-) diff --git a/Makefile b/Makefile index 533e7427..1da4dbbc 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ PYTHON=python3.3 test: - $(PYTHON) main.py -v + $(PYTHON) main.py -d echo: $(PYTHON) echosvr.py -v diff --git a/TODO b/TODO index eb3f9bbb..9dc556fa 100644 --- a/TODO +++ b/TODO @@ -148,3 +148,17 @@ MISTAKES I MADE - readexactly() wasn't checking for EOF, so could be looping. (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going diff --git a/http_client.py b/http_client.py index 8937ba20..b828660c 100644 --- a/http_client.py +++ b/http_client.py @@ -4,6 +4,7 @@ """ # Stdlib. +import logging import re import time @@ -13,14 +14,15 @@ def urlfetch(host, port=None, path='/', method='GET', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): - """COROUTINE: Make an HTTP 1.0 request.""" - t0 = time.time() - if port is None: - if ssl: - port = 443 - else: - port = 80 - trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + try: yield from trans.send(method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: @@ -76,3 +78,6 @@ def urlfetch(host, port=None, path='/', method='GET', result = (host, port, path, int(status), len(data), round(t1-t0, 3)) ## print(result) return result + except: + logging.exception('********** Exception in urlfetch **********') + trans.close() diff --git a/main.py b/main.py index c1f9d0a8..d2dac196 100644 --- a/main.py +++ b/main.py @@ -56,24 +56,24 @@ def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) - task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), - 'root', timeout=TIMEOUT) - tasks.add(task1) +## task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), +## 'root', timeout=TIMEOUT) +## tasks.add(task1) task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, path='/home'), 'home', timeout=TIMEOUT) tasks.add(task2) - # Fetch python.org home page. - task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), - 'python', timeout=TIMEOUT) - tasks.add(task3) +## # Fetch python.org home page. +## task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), +## 'python', timeout=TIMEOUT) +## tasks.add(task3) - # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', - af=socket.AF_INET), - 'xkcd', timeout=TIMEOUT) - tasks.add(task4) +## # Fetch XKCD home page using SSL. (Doesn't like IPv6.) +## task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', +## af=socket.AF_INET), +## 'xkcd', timeout=TIMEOUT) +## tasks.add(task4) ## # Fetch many links from python.org (/x.y.z). ## for x in '123': @@ -90,6 +90,8 @@ def doit(): print('And the winners are:', [w.name for w in winners]) tasks = yield from scheduling.wait_all(tasks) print('And the players were:', [t.name for t in tasks]) + print('readers =', scheduling.context.eventloop.pollster.readers) + print('tasks =', tasks) return tasks diff --git a/polling.py b/polling.py index 8e160b3a..33da01e6 100644 --- a/polling.py +++ b/polling.py @@ -36,12 +36,25 @@ import collections import concurrent.futures +import errno import heapq import logging import os import select import time +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + class PollsterBase: """Base class for all polling implementations. @@ -411,11 +424,16 @@ def run_once(self): if self.pollster.pollable(): if self.scheduled: when = self.scheduled[0].when - timeout = max(0, when - time.time()) + timeout = when - time.time() else: - timeout = None + timeout = 10 + timeout = max(0, min(1, timeout)) t0 = time.time() - events = self.pollster.poll(timeout) + events = [] + try: + events = self.pollster.poll(timeout) + except KeyboardInterrupt: + import pdb; pdb.set_trace() t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout if t1-t0 >= 1: @@ -423,6 +441,11 @@ def run_once(self): else: level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + if level == logging.INFO: + logging.info(' ready = %r', self.ready) + logging.info(' scheduled = %r', self.scheduled) + logging.info(' readers = %r', self.pollster.readers) + logging.info(' writers = %r', self.pollster.writers) for dcall in events: self.add_callback(dcall) @@ -444,6 +467,68 @@ def run(self): while self.ready or self.scheduled or self.pollster.pollable(): self.run_once() + # Experiment with I/O in the event loop. + def recv(self, sock, n, callback): + """Receive up to n bytes from a socket. + + Call callback(data) when we have some data; 1 <= len(data) <= n. + Call callback(b'') when the connection is dropped/closed. + + If data is available (or the connection is dropped/closed) + immediately the callback is called before this function + returns. + + Otherwise, register the socket's fd for reading with a helper + callback that ensures that eventually the callback is called + and the registration removed. + """ + def try_recv(direct=True): + # Helper: Either call the callback and return True, or + # else return False. On errors call the callback with b'' + # and return True. + try: + value = sock.recv(n) + except IOError as err: + if err.errno in _TRYAGAIN: + return False + elif err.errno in _DISCONNECTED: + value = b'' + else: + logging.exception('[a] Unexpected error from recv()') + raise + except Exception: + logging.exception('[b] Unexpected error from recv()') + raise + logging.info('recv(%s, %s) returned %d bytes', + sock, n, len(value)) + if direct: + try: + callback(value) + except Exception: + logging.exception('Error in callback for recv()') + else: + self.call_soon(callback, value) + return True + + if try_recv(direct=False): + logging.info('Early return from recv()') + return + + fd = sock.fileno() + + def read_callback(): + # Helper: Callback for add_reader. + done = True + try: + done = try_recv() + finally: + logging.info('recv(%d, %d): done=%s', fd, n, done) + if done: + logging.info('!!!!!!!!!!!!! removing reader %d', fd) + self.remove_reader(fd) + + self.add_reader(fd, read_callback) + MAX_WORKERS = 5 # Default max workers when creating an executor. diff --git a/scheduling.py b/scheduling.py index ce649f5e..eaccce08 100644 --- a/scheduling.py +++ b/scheduling.py @@ -78,6 +78,7 @@ def __init__(self, gen, name=None, *, timeout=None): self.canceleer = None if timeout is not None: self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.next_value = None self.blocked = False self.unblocker = None self.cancelled = False @@ -132,9 +133,15 @@ def step(self): if self.must_cancel: self.must_cancel = False self.cancelled = True + logging.debug('Throwing CancelledError into %s', self.gen) self.gen.throw(CancelledError()) else: - next(self.gen) + next_value = self.next_value + self.next_value = None + if next_value is None: + next(self.gen) + else: + self.gen.send(next_value) except StopIteration as exc: self.alive = False self.result = exc.value @@ -167,15 +174,15 @@ def block(self, unblock_callback=None, *unblock_args): self.blocked = True self.unblocker = (unblock_callback, unblock_args) - def unblock_if_alive(self): + def unblock_if_alive(self, value=None): if self.alive: - self.unblock() + self.unblock(value) - def unblock(self, unused=None): - # Ignore optional argument so we can be a Future's done_callback. + def unblock(self, value=None): assert self.alive, self assert self.blocked, self self.blocked = False + self.next_value = value unblock_callback, unblock_args = self.unblocker if unblock_callback is not None: try: @@ -389,3 +396,47 @@ def par_tasks(tasks): other.cancel() raise t.exception return [t.result for t in tasks] + + +################################################################################ + +_DOC = """Pattern on how to use this: + + def my_thing(sock): + "COROUTINE: ..." + blah blah blah + data = yield scheduling.recv(sock, n) + + def my_other_thing(arg): + "COROUTINE: ..." + ... + conn = yield scheduling.make_connection(address, [family etc.]) + + def my_blah(arg): + "COROUTINE: ..." + ... + lsnr = yield scheduling.make_listener(address, [family etc.]) + conn = yield scheduling.accept(lsnr) + +""" + + +def recv(sock, n): + """COROUTINE""" + try: + task = context.current_task + # This can't be right. The add_reader call isn't here, + # so why should the remove_reader call be here? + # However, without this here, if recv() is cancelled, + # the exception interrupts the yield below, and somehow + # EventLoop.recv() in polling.py never gets called. + # Hmmm... Maybe that's the bug: we need an errback + # or something. + task.block(context.eventloop.remove_reader, sock.fileno()) + context.eventloop.recv(sock, n, task.unblock_if_alive) + value = yield + return value + except: + logging.exception('Exception in scheduling.recv(%r, %r)', + sock, n) + raise diff --git a/sockets.py b/sockets.py index 4d64ee46..3a23d94d 100644 --- a/sockets.py +++ b/sockets.py @@ -61,17 +61,7 @@ def recv(self, n): returns b''. """ assert n >= 0, n - while True: - try: - return self.sock.recv(n) - except socket.error as err: - if err.errno in _TRYAGAIN: - pass - elif err.errno in _DISCONNECTED: - return b'' - else: - raise # Unexpected, propagate. - yield from scheduling.block_r(self.sock.fileno()) + return (yield from scheduling.recv(self.sock, n)) def send(self, data): """COROUTINE; Send data to the socket, blocking until all written. From 2b6425909076e5e92f9a3a26b4c626a1f806dfb8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 5 Dec 2012 11:47:19 -0800 Subject: [PATCH 0101/1502] Roll back last change -- it was an experiment accidentally committed. --- Makefile | 2 +- http_client.py | 21 +++++------- main.py | 26 +++++++-------- polling.py | 91 ++------------------------------------------------ scheduling.py | 61 +++------------------------------ sockets.py | 12 ++++++- 6 files changed, 40 insertions(+), 173 deletions(-) diff --git a/Makefile b/Makefile index 1da4dbbc..533e7427 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ PYTHON=python3.3 test: - $(PYTHON) main.py -d + $(PYTHON) main.py -v echo: $(PYTHON) echosvr.py -v diff --git a/http_client.py b/http_client.py index b828660c..8937ba20 100644 --- a/http_client.py +++ b/http_client.py @@ -4,7 +4,6 @@ """ # Stdlib. -import logging import re import time @@ -14,15 +13,14 @@ def urlfetch(host, port=None, path='/', method='GET', body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): - """COROUTINE: Make an HTTP 1.0 request.""" - t0 = time.time() - if port is None: - if ssl: - port = 443 - else: - port = 80 - trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) - try: + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) yield from trans.send(method.encode(encoding) + b' ' + path.encode(encoding) + b' HTTP/1.0\r\n') if hdrs: @@ -78,6 +76,3 @@ def urlfetch(host, port=None, path='/', method='GET', result = (host, port, path, int(status), len(data), round(t1-t0, 3)) ## print(result) return result - except: - logging.exception('********** Exception in urlfetch **********') - trans.close() diff --git a/main.py b/main.py index d2dac196..c1f9d0a8 100644 --- a/main.py +++ b/main.py @@ -56,24 +56,24 @@ def doit(): # This references NDB's default test service. # (Sadly the service is single-threaded.) -## task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), -## 'root', timeout=TIMEOUT) -## tasks.add(task1) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, path='/home'), 'home', timeout=TIMEOUT) tasks.add(task2) -## # Fetch python.org home page. -## task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), -## 'python', timeout=TIMEOUT) -## tasks.add(task3) + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) -## # Fetch XKCD home page using SSL. (Doesn't like IPv6.) -## task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', -## af=socket.AF_INET), -## 'xkcd', timeout=TIMEOUT) -## tasks.add(task4) + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) ## # Fetch many links from python.org (/x.y.z). ## for x in '123': @@ -90,8 +90,6 @@ def doit(): print('And the winners are:', [w.name for w in winners]) tasks = yield from scheduling.wait_all(tasks) print('And the players were:', [t.name for t in tasks]) - print('readers =', scheduling.context.eventloop.pollster.readers) - print('tasks =', tasks) return tasks diff --git a/polling.py b/polling.py index 33da01e6..8e160b3a 100644 --- a/polling.py +++ b/polling.py @@ -36,25 +36,12 @@ import collections import concurrent.futures -import errno import heapq import logging import os import select import time -# Errno values indicating the connection was disconnected. -_DISCONNECTED = frozenset((errno.ECONNRESET, - errno.ENOTCONN, - errno.ESHUTDOWN, - errno.ECONNABORTED, - errno.EPIPE, - errno.EBADF, - )) - -# Errno values indicating the socket isn't ready for I/O just yet. -_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) - class PollsterBase: """Base class for all polling implementations. @@ -424,16 +411,11 @@ def run_once(self): if self.pollster.pollable(): if self.scheduled: when = self.scheduled[0].when - timeout = when - time.time() + timeout = max(0, when - time.time()) else: - timeout = 10 - timeout = max(0, min(1, timeout)) + timeout = None t0 = time.time() - events = [] - try: - events = self.pollster.poll(timeout) - except KeyboardInterrupt: - import pdb; pdb.set_trace() + events = self.pollster.poll(timeout) t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout if t1-t0 >= 1: @@ -441,11 +423,6 @@ def run_once(self): else: level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - if level == logging.INFO: - logging.info(' ready = %r', self.ready) - logging.info(' scheduled = %r', self.scheduled) - logging.info(' readers = %r', self.pollster.readers) - logging.info(' writers = %r', self.pollster.writers) for dcall in events: self.add_callback(dcall) @@ -467,68 +444,6 @@ def run(self): while self.ready or self.scheduled or self.pollster.pollable(): self.run_once() - # Experiment with I/O in the event loop. - def recv(self, sock, n, callback): - """Receive up to n bytes from a socket. - - Call callback(data) when we have some data; 1 <= len(data) <= n. - Call callback(b'') when the connection is dropped/closed. - - If data is available (or the connection is dropped/closed) - immediately the callback is called before this function - returns. - - Otherwise, register the socket's fd for reading with a helper - callback that ensures that eventually the callback is called - and the registration removed. - """ - def try_recv(direct=True): - # Helper: Either call the callback and return True, or - # else return False. On errors call the callback with b'' - # and return True. - try: - value = sock.recv(n) - except IOError as err: - if err.errno in _TRYAGAIN: - return False - elif err.errno in _DISCONNECTED: - value = b'' - else: - logging.exception('[a] Unexpected error from recv()') - raise - except Exception: - logging.exception('[b] Unexpected error from recv()') - raise - logging.info('recv(%s, %s) returned %d bytes', - sock, n, len(value)) - if direct: - try: - callback(value) - except Exception: - logging.exception('Error in callback for recv()') - else: - self.call_soon(callback, value) - return True - - if try_recv(direct=False): - logging.info('Early return from recv()') - return - - fd = sock.fileno() - - def read_callback(): - # Helper: Callback for add_reader. - done = True - try: - done = try_recv() - finally: - logging.info('recv(%d, %d): done=%s', fd, n, done) - if done: - logging.info('!!!!!!!!!!!!! removing reader %d', fd) - self.remove_reader(fd) - - self.add_reader(fd, read_callback) - MAX_WORKERS = 5 # Default max workers when creating an executor. diff --git a/scheduling.py b/scheduling.py index eaccce08..ce649f5e 100644 --- a/scheduling.py +++ b/scheduling.py @@ -78,7 +78,6 @@ def __init__(self, gen, name=None, *, timeout=None): self.canceleer = None if timeout is not None: self.canceleer = self.eventloop.call_later(timeout, self.cancel) - self.next_value = None self.blocked = False self.unblocker = None self.cancelled = False @@ -133,15 +132,9 @@ def step(self): if self.must_cancel: self.must_cancel = False self.cancelled = True - logging.debug('Throwing CancelledError into %s', self.gen) self.gen.throw(CancelledError()) else: - next_value = self.next_value - self.next_value = None - if next_value is None: - next(self.gen) - else: - self.gen.send(next_value) + next(self.gen) except StopIteration as exc: self.alive = False self.result = exc.value @@ -174,15 +167,15 @@ def block(self, unblock_callback=None, *unblock_args): self.blocked = True self.unblocker = (unblock_callback, unblock_args) - def unblock_if_alive(self, value=None): + def unblock_if_alive(self): if self.alive: - self.unblock(value) + self.unblock() - def unblock(self, value=None): + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. assert self.alive, self assert self.blocked, self self.blocked = False - self.next_value = value unblock_callback, unblock_args = self.unblocker if unblock_callback is not None: try: @@ -396,47 +389,3 @@ def par_tasks(tasks): other.cancel() raise t.exception return [t.result for t in tasks] - - -################################################################################ - -_DOC = """Pattern on how to use this: - - def my_thing(sock): - "COROUTINE: ..." - blah blah blah - data = yield scheduling.recv(sock, n) - - def my_other_thing(arg): - "COROUTINE: ..." - ... - conn = yield scheduling.make_connection(address, [family etc.]) - - def my_blah(arg): - "COROUTINE: ..." - ... - lsnr = yield scheduling.make_listener(address, [family etc.]) - conn = yield scheduling.accept(lsnr) - -""" - - -def recv(sock, n): - """COROUTINE""" - try: - task = context.current_task - # This can't be right. The add_reader call isn't here, - # so why should the remove_reader call be here? - # However, without this here, if recv() is cancelled, - # the exception interrupts the yield below, and somehow - # EventLoop.recv() in polling.py never gets called. - # Hmmm... Maybe that's the bug: we need an errback - # or something. - task.block(context.eventloop.remove_reader, sock.fileno()) - context.eventloop.recv(sock, n, task.unblock_if_alive) - value = yield - return value - except: - logging.exception('Exception in scheduling.recv(%r, %r)', - sock, n) - raise diff --git a/sockets.py b/sockets.py index 3a23d94d..4d64ee46 100644 --- a/sockets.py +++ b/sockets.py @@ -61,7 +61,17 @@ def recv(self, n): returns b''. """ assert n >= 0, n - return (yield from scheduling.recv(self.sock, n)) + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) def send(self, data): """COROUTINE; Send data to the socket, blocking until all written. From 48c41a52e758577c564bf0887285495c217944ce Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 11:30:13 -0800 Subject: [PATCH 0102/1502] Add feedback from Ben Darnell. --- TODO | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/TODO b/TODO index 9dc556fa..df868cca 100644 --- a/TODO +++ b/TODO @@ -27,6 +27,26 @@ TO DO LARGER TASKS - Restructure directory, move demos and benchmarks to subdirectories. +FROM BEN DARNELL (Tornado) + +- The waker pipe in ThreadRunner should really be in EventLoop itself + - we need to be able to call call_soon (or some equivalent) from + threads that were not created by ThreadRunner. In Tornado I ended + up needing two functions, add_callback (usable from any thread) and + add_callback_from_signal (usable only from signal handlers). + +- Timeouts should ideally be based on time.monotonic, although this + requires some extra complexity to deal with the cases where you + actually do want time.time. (in tornado, the clock used is + configurable on a per-ioloop basis, which is not ideal but is + workable) + +- I'm sure you've heard this from the twisted guys by now, but to + properly support completion-based event loops like IOCP you need to + be able to swap out most of sockets.py (the layers below + BufferedReader) for an alternative implementation. + + TO DO LATER - Wrap select(), epoll() etc. in try/except checking for EINTR. From c24aead78084342fe4ac9f091f48c01c3adf3a5c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 14:46:49 -0800 Subject: [PATCH 0103/1502] Checkpoint: Initial draft transports.py. Does not work. --- transports.py | 138 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 transports.py diff --git a/transports.py b/transports.py new file mode 100644 index 00000000..ce4d2b30 --- /dev/null +++ b/transports.py @@ -0,0 +1,138 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. +""" + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + + def write(self, data): + XXX + + def writelines(self, list_of_data): # Not just for lines. + XXX + + def close(self): + XXX + + def abort(self): + XXX + + def half_close(self): # Closes the write end after flushing. + XXX + + def pause(self): + XXX + + def resume(self): + XXX + + +class Protocol: + + def connection_made(self, transport): + XXX + + def data_received(self, data): + XXX + + def connection_lost(self, exc): # Also when connect() failed. + XXX + + +# XXX The rest is platform specific and should move elsewhere. + +class SocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = None + self._protocol = protocol + self._sock = sock + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._eventloop.remove_reader(sock.fileno()) + sock.close() + self._protocol.connection_lost(exc) # XXX calL_soon()? + else: + self._protocol.data_received(data) # XXX call_soon()? + + # XXX implement write buffering. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + eventloop = scheduling.context.eventloop + + # XXX Move all this to private methods on SocketTransport. + + def on_addrinfo(infos, exc): + # XXX Make infos into an iterator, to avoid pop()? + if not infos: + if exc is not None: + protocol.connection_lost(exc) + return + protocol.connection_lost(IOError(0, 'No more infos to try')) + return + + af, socktype, proto, cname, address = infos.pop(0) + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + + try: + sock.connect(address) + except socket.error as exc: + if exc.errno != errno.EINPROGRESS: + sock.close() + on_addrinfo(infos, exc) # XXX Use eventloop.call_soon()? + return + + def on_writable(): + eventloop.remove_writer(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + sock.close() + # Try the next getaddrinfo() return value. + on_addrinfo(infos, None) + return + + if use_ssl: + XXX + else: + transport = SocketTransport(eventloop, protocol, sock) + protocol.connection_made(transport) # XXX call_soon()? + eventloop.add_reader(sock.fileno(), transport._on_readable) + + eventloop.add_writer(sock.fileno(), on_writable) + + # XXX Implement EventLoop.call_in_thread(). + eventloop.call_in_thread(socket.getaddrinfo, + host, port, af, socktype, proto, + callback=on_addrinfo) From 42b068a13efcf91b2263da5f0cca6fb5bb253dc0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 15:01:30 -0800 Subject: [PATCH 0104/1502] Move Context from scheduling.py to polling.py. --- polling.py | 34 ++++++++++++++++++++++++++++++++++ scheduling.py | 33 +-------------------------------- transports.py | 4 ++-- 3 files changed, 37 insertions(+), 34 deletions(-) diff --git a/polling.py b/polling.py index 8e160b3a..c9e630d3 100644 --- a/polling.py +++ b/polling.py @@ -40,6 +40,7 @@ import logging import os import select +import threading import time @@ -495,3 +496,36 @@ def done_callback(future): os.write(self.pipe_write_fd, b'x') future.add_done_callback(done_callback) return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/scheduling.py b/scheduling.py index ce649f5e..bca14dcf 100644 --- a/scheduling.py +++ b/scheduling.py @@ -19,7 +19,6 @@ # Standard library imports (keep in alphabetic order). from concurrent.futures import CancelledError, TimeoutError import logging -import threading import time import types @@ -27,37 +26,7 @@ import polling -class Context(threading.local): - """Thread-local context. - - We use this to avoid having to explicitly pass around an event loop - or something to hold the current task. - - TODO: Add an API so frameworks can substitute a different notion - of context more easily. - """ - - def __init__(self, eventloop=None, threadrunner=None): - # Default event loop and thread runner are lazily constructed - # when first accessed. - self._eventloop = eventloop - self._threadrunner = threadrunner - self.current_task = None - - @property - def eventloop(self): - if self._eventloop is None: - self._eventloop = polling.EventLoop() - return self._eventloop - - @property - def threadrunner(self): - if self._threadrunner is None: - self._threadrunner = polling.ThreadRunner(self.eventloop) - return self._threadrunner - - -context = Context() # Thread-local! +context = polling.context class Task: diff --git a/transports.py b/transports.py index ce4d2b30..7513df8e 100644 --- a/transports.py +++ b/transports.py @@ -9,7 +9,7 @@ import ssl # Local imports. -import scheduling +import polling # Errno values indicating the connection was disconnected. _DISCONNECTED = frozenset((errno.ECONNRESET, @@ -89,7 +89,7 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, port = 443 if use_ssl else 80 if use_ssl is None: use_ssl = (port == 443) - eventloop = scheduling.context.eventloop + eventloop = polling.context.eventloop # XXX Move all this to private methods on SocketTransport. From 7b17732753b6125888515693b3e776919b3f6ffb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 15:28:24 -0800 Subject: [PATCH 0105/1502] Use threadrunner.submit() to call getaddrinfo(). --- transports.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/transports.py b/transports.py index 7513df8e..1b3a279d 100644 --- a/transports.py +++ b/transports.py @@ -90,6 +90,7 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, if use_ssl is None: use_ssl = (port == 443) eventloop = polling.context.eventloop + threadrunner = polling.context.threadrunner # XXX Move all this to private methods on SocketTransport. @@ -132,7 +133,13 @@ def on_writable(): eventloop.add_writer(sock.fileno(), on_writable) - # XXX Implement EventLoop.call_in_thread(). - eventloop.call_in_thread(socket.getaddrinfo, - host, port, af, socktype, proto, - callback=on_addrinfo) + future = threadrunner.submit(socket.getaddrinfo, + host, port, af, socktype, proto) + def on_future_done(fut): + exc = fut.exception() + if exc is None: + infos = fut.result() + else: + infos = None + on_addrinfo(infos, exc) + future.add_done_callback(on_future_done) From dfb29e4f59adfb037cb41f640df215949faa3c8f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 16:27:12 -0800 Subject: [PATCH 0106/1502] Transport works a little bit. --- transports.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 52 insertions(+), 5 deletions(-) diff --git a/transports.py b/transports.py index 1b3a279d..f85fdec0 100644 --- a/transports.py +++ b/transports.py @@ -5,8 +5,10 @@ # Stdlib imports. import errno +import logging import socket import ssl +import sys # Local imports. import polling @@ -65,7 +67,7 @@ def connection_lost(self, exc): # Also when connect() failed. class SocketTransport(Transport): def __init__(self, eventloop, protocol, sock): - self._eventloop = None + self._eventloop = eventloop self._protocol = protocol self._sock = sock @@ -78,9 +80,16 @@ def _on_readable(self): sock.close() self._protocol.connection_lost(exc) # XXX calL_soon()? else: - self._protocol.data_received(data) # XXX call_soon()? + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? - # XXX implement write buffering. + def write(self, data): + # XXX implement write buffering. + self._sock.sendall(data) def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, @@ -95,6 +104,7 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, # XXX Move all this to private methods on SocketTransport. def on_addrinfo(infos, exc): + logging.debug('on_addrinfo(, %r)', len(infos), exc) # XXX Make infos into an iterator, to avoid pop()? if not infos: if exc is not None: @@ -136,10 +146,47 @@ def on_writable(): future = threadrunner.submit(socket.getaddrinfo, host, port, af, socktype, proto) def on_future_done(fut): + logging.debug('Future done.') exc = fut.exception() if exc is None: infos = fut.result() else: infos = None - on_addrinfo(infos, exc) - future.add_done_callback(on_future_done) + eventloop.call_soon(on_addrinfo, infos, exc) + future.add_done_callback(lambda fut: eventloop.call_soon(on_future_done, fut)) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.debug('Connection made.') + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\n\r\n') + ## self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes: %r', len(data), data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + + tp = TestProtocol() + logging.info('tp = %r', tp) + make_connection(tp, 'python.org') + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() From 6f4b1e8d65eee81c014e0c0568dd2a0cbc301de5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 16:35:57 -0800 Subject: [PATCH 0107/1502] Make transport work a little better. --- transports.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transports.py b/transports.py index f85fdec0..427f501a 100644 --- a/transports.py +++ b/transports.py @@ -76,8 +76,8 @@ def _on_readable(self): data = self._sock.recv(8192) except socket.error as exc: if exc.errno not in _TRYAGAIN: - self._eventloop.remove_reader(sock.fileno()) - sock.close() + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() self._protocol.connection_lost(exc) # XXX calL_soon()? else: if not data: @@ -98,6 +98,8 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, port = 443 if use_ssl else 80 if use_ssl is None: use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM eventloop = polling.context.eventloop threadrunner = polling.context.threadrunner @@ -173,7 +175,7 @@ class TestProtocol(Protocol): def connection_made(self, transport): logging.debug('Connection made.') self.transport = transport - self.transport.write(b'GET / HTTP/1.0\r\n\r\n') + self.transport.write(b'GET / HTTP/1.0\r\nHost: python.org\r\n\r\n') ## self.transport.half_close() def data_received(self, data): logging.info('Received %d bytes: %r', len(data), data) From eb2924be86405af3c0686987cfd4e5a61769e326 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 6 Dec 2012 16:57:03 -0800 Subject: [PATCH 0108/1502] Add a callback to ThreadRunner.submit(). --- polling.py | 9 ++++++--- scheduling.py | 16 +++++----------- transports.py | 11 ++++++----- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/polling.py b/polling.py index c9e630d3..07ae3d8e 100644 --- a/polling.py +++ b/polling.py @@ -474,11 +474,11 @@ def read_callback(self): self.eventloop.remove_reader(self.pipe_read_fd) assert self.active_count >= 0, self.active_count - def submit(self, func, *args, executor=None): + def submit(self, func, *args, executor=None, callback=None): """Submit a function to the thread pool. This returns a concurrent.futures.Future instance. The caller - should not wait for that, but rather add a callback to it. + should not wait for that, but rather use the callback argument.. """ if executor is None: executor = self.executor @@ -492,7 +492,10 @@ def submit(self, func, *args, executor=None): if self.active_count == 0: self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) self.active_count += 1 - def done_callback(future): + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? os.write(self.pipe_write_fd, b'x') future.add_done_callback(done_callback) return future diff --git a/scheduling.py b/scheduling.py index bca14dcf..3864571d 100644 --- a/scheduling.py +++ b/scheduling.py @@ -136,7 +136,8 @@ def block(self, unblock_callback=None, *unblock_args): self.blocked = True self.unblocker = (unblock_callback, unblock_args) - def unblock_if_alive(self): + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. if self.alive: self.unblock() @@ -241,17 +242,10 @@ def call_in_thread(func, *args, executor=None): """COROUTINE: Run a function in a thread.""" task = context.current_task eventloop = context.eventloop - future = context.threadrunner.submit(func, *args, executor=executor) + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) task.block(future.cancel) - # If the thread managed to complete before we get here, - # add_done_callback() will call the callback right now. Make sure - # the unblock() call doesn't happen until later. But then, the - # task may already have been cancelled (and it may have been too - # late to cancel the Future) so it should be okay if this call - # finds the task deceased. For that purpose we have - # unblock_if_alive(). - future.add_done_callback( - lambda _: eventloop.call_soon(task.unblock_if_alive)) yield assert future.done() return future.result() diff --git a/transports.py b/transports.py index 427f501a..5ed365b4 100644 --- a/transports.py +++ b/transports.py @@ -145,8 +145,6 @@ def on_writable(): eventloop.add_writer(sock.fileno(), on_writable) - future = threadrunner.submit(socket.getaddrinfo, - host, port, af, socktype, proto) def on_future_done(fut): logging.debug('Future done.') exc = fut.exception() @@ -155,7 +153,10 @@ def on_future_done(fut): else: infos = None eventloop.call_soon(on_addrinfo, infos, exc) - future.add_done_callback(lambda fut: eventloop.call_soon(on_future_done, fut)) + + future = threadrunner.submit(socket.getaddrinfo, + host, port, af, socktype, proto, + callback=on_future_done) def main(): # Testing... @@ -178,12 +179,12 @@ def connection_made(self, transport): self.transport.write(b'GET / HTTP/1.0\r\nHost: python.org\r\n\r\n') ## self.transport.half_close() def data_received(self, data): - logging.info('Received %d bytes: %r', len(data), data) + logging.debug('Received %d bytes: %r', len(data), data) def connection_lost(self, exc): logging.debug('Connection lost: %r', exc) tp = TestProtocol() - logging.info('tp = %r', tp) + logging.debug('tp = %r', tp) make_connection(tp, 'python.org') logging.info('Running...') polling.context.eventloop.run() From 30a8d299366683707e0a5d65a923c645b8e8cd1f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 10:55:12 -0800 Subject: [PATCH 0109/1502] Add some docstrings to Transport and Protocol ABCs. --- transports.py | 101 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 90 insertions(+), 11 deletions(-) diff --git a/transports.py b/transports.py index 5ed365b4..c74f2142 100644 --- a/transports.py +++ b/transports.py @@ -1,6 +1,8 @@ """Transports and Protocols, actually. Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. """ # Stdlib imports. @@ -27,39 +29,116 @@ class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + """ def write(self, data): - XXX + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplemented def writelines(self, list_of_data): # Not just for lines. - XXX + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) def close(self): - XXX + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ def abort(self): - XXX + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ def half_close(self): # Closes the write end after flushing. - XXX + """Closes the write end after flushing buffered data. + + Data may still be received. + """ def pause(self): - XXX + """Pause the receiving end. + + No data will be received until resume end is called. + """ def resume(self): - XXX + """Resume the receiving end. + + Cancels a pause() call. + """ class Protocol: + """ABC representing a protocol. + + The user should implement this interface. + + When the user requests a transport, they pass it a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ def connection_made(self, transport): - XXX + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ def data_received(self, data): - XXX + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all. - def connection_lost(self, exc): # Also when connect() failed. - XXX + The argument is an exception object or None (the latter meaning + a regular EOF is received or the connection was aborted). + """ # XXX The rest is platform specific and should move elsewhere. From 7a39a517a9ce9dc2a73ef78ba8812fb529068380 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 15:17:05 -0800 Subject: [PATCH 0110/1502] More about socket transports. --- transports.py | 167 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 142 insertions(+), 25 deletions(-) diff --git a/transports.py b/transports.py index c74f2142..b9bafb94 100644 --- a/transports.py +++ b/transports.py @@ -6,11 +6,13 @@ """ # Stdlib imports. +import collections import errno import logging import socket import ssl import sys +import time # Local imports. import polling @@ -37,6 +39,9 @@ class Transport: connection_made() method with a transport (or it will call connection_lost() with an exception if it fails to create the desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. """ def write(self, data): @@ -47,7 +52,7 @@ def write(self, data): """ raise NotImplemented - def writelines(self, list_of_data): # Not just for lines. + def writelines(self, list_of_data): """Write a list (or any iterable) of data (bytes) to the transport. The default implementation just calls write() for each item in @@ -63,6 +68,7 @@ def close(self): be received. When all buffered data is flushed, the protocol's connection_lost() method is called with None as its argument. """ + raise NotImplemented def abort(self): """Closes the transport immediately. @@ -71,32 +77,42 @@ def abort(self): The protocol's connection_lost() method is called with None as its argument. """ + raise NotImplemented - def half_close(self): # Closes the write end after flushing. + def half_close(self): """Closes the write end after flushing buffered data. Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? """ + raise NotImplemented def pause(self): """Pause the receiving end. - No data will be received until resume end is called. + No data will be received until resume() is called. """ + raise NotImplemented def resume(self): """Resume the receiving end. - Cancels a pause() call. + Cancels a pause() call, resumes receiving data. """ + raise NotImplemented class Protocol: """ABC representing a protocol. - The user should implement this interface. + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. - When the user requests a transport, they pass it a protocol + When the user wants to requests a transport, they pass a protocol instance to a utility function. When the connection is made successfully, connection_made() is @@ -134,30 +150,32 @@ def data_received(self, data): def connection_lost(self, exc): """Called when the connection is lost or closed. - Also called when we fail to make a connection at all. + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). - The argument is an exception object or None (the latter meaning - a regular EOF is received or the connection was aborted). + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). """ -# XXX The rest is platform specific and should move elsewhere. +# TODO: The rest is platform specific and should move elsewhere. -class SocketTransport(Transport): +class UnixSocketTransport(Transport): def __init__(self, eventloop, protocol, sock): self._eventloop = eventloop self._protocol = protocol self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False def _on_readable(self): try: data = self._sock.recv(8192) except socket.error as exc: if exc.errno not in _TRYAGAIN: - self._eventloop.remove_reader(self._sock.fileno()) - self._sock.close() - self._protocol.connection_lost(exc) # XXX calL_soon()? + self._bad_error(exc) else: if not data: self._eventloop.remove_reader(self._sock.fileno()) @@ -167,8 +185,88 @@ def _on_readable(self): self._protocol.data_received(data) # XXX call_soon()? def write(self, data): - # XXX implement write buffering. - self._sock.sendall(data) + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, @@ -182,11 +280,18 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, eventloop = polling.context.eventloop threadrunner = polling.context.threadrunner - # XXX Move all this to private methods on SocketTransport. + # TODO: Maybe use scheduling.Task(sockets.create_connection()). + # (That sounds insane, but it's purely an internal detail. Once + # we have a socket, we can switch to the coroutine-free + # transport+protocol API.) + + # TODO: Move this to private methods on UnixSocketTransport? + # (But then how to handle sharing code between socket and ssl + # transports?) def on_addrinfo(infos, exc): logging.debug('on_addrinfo(, %r)', len(infos), exc) - # XXX Make infos into an iterator, to avoid pop()? + # TODO: Make infos into an iterator, to avoid pop()? if not infos: if exc is not None: protocol.connection_lost(exc) @@ -203,7 +308,7 @@ def on_addrinfo(infos, exc): except socket.error as exc: if exc.errno != errno.EINPROGRESS: sock.close() - on_addrinfo(infos, exc) # XXX Use eventloop.call_soon()? + on_addrinfo(infos, exc) # XXX call_soon()? return def on_writable(): @@ -218,7 +323,7 @@ def on_writable(): if use_ssl: XXX else: - transport = SocketTransport(eventloop, protocol, sock) + transport = UnixSocketTransport(eventloop, protocol, sock) protocol.connection_made(transport) # XXX call_soon()? eventloop.add_reader(sock.fileno(), transport._on_readable) @@ -251,20 +356,32 @@ def main(): # Testing... level = logging.WARN logging.basicConfig(level=level) + host = 'xkcd.com' + if '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + class TestProtocol(Protocol): def connection_made(self, transport): - logging.debug('Connection made.') + logging.info('Connection made at %.3f secs', time.time() - t0) self.transport = transport - self.transport.write(b'GET / HTTP/1.0\r\nHost: python.org\r\n\r\n') - ## self.transport.half_close() + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() def data_received(self, data): - logging.debug('Received %d bytes: %r', len(data), data) + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Receved %r', data) def connection_lost(self, exc): logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) tp = TestProtocol() logging.debug('tp = %r', tp) - make_connection(tp, 'python.org') + make_connection(tp, host) logging.info('Running...') polling.context.eventloop.run() logging.info('Done.') From a1593c3fb26f226b8c3a206862f346a89141ec08 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 16:36:42 -0800 Subject: [PATCH 0111/1502] Fix what must be typos in the ssl transport error handling. --- sockets.py | 4 +- transports.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 129 insertions(+), 10 deletions(-) diff --git a/sockets.py b/sockets.py index 4d64ee46..a5005dc3 100644 --- a/sockets.py +++ b/sockets.py @@ -142,7 +142,7 @@ def recv(self, n): yield from scheduling.block_w(self.sslsock.fileno()) except socket.error as err: if err.errno in _TRYAGAIN: - yield from scheduling.block_r(self.sock.fileno()) + yield from scheduling.block_r(self.sslsock.fileno()) elif err.errno in _DISCONNECTED: # Can this happen? return b'' @@ -160,7 +160,7 @@ def send(self, data): yield from scheduling.block_w(self.sslsock.fileno()) except socket.error as err: if err.errno in _TRYAGAIN: - yield from scheduling.block_w(self.sock.fileno()) + yield from scheduling.block_w(self.sslsock.fileno()) elif err.errno in _DISCONNECTED: return False else: diff --git a/transports.py b/transports.py index b9bafb94..fb7353a3 100644 --- a/transports.py +++ b/transports.py @@ -50,7 +50,7 @@ def write(self, data): This does not block; it buffers the data and arranges for it to be sent out asynchronously. """ - raise NotImplemented + raise NotImplementedError def writelines(self, list_of_data): """Write a list (or any iterable) of data (bytes) to the transport. @@ -68,7 +68,7 @@ def close(self): be received. When all buffered data is flushed, the protocol's connection_lost() method is called with None as its argument. """ - raise NotImplemented + raise NotImplementedError def abort(self): """Closes the transport immediately. @@ -77,7 +77,7 @@ def abort(self): The protocol's connection_lost() method is called with None as its argument. """ - raise NotImplemented + raise NotImplementedError def half_close(self): """Closes the write end after flushing buffered data. @@ -88,21 +88,21 @@ def half_close(self): Should it call shutdown(SHUT_WR) after all the data is flushed? Is there no use case for closing the other half first? """ - raise NotImplemented + raise NotImplementedError def pause(self): """Pause the receiving end. No data will be received until resume() is called. """ - raise NotImplemented + raise NotImplementedError def resume(self): """Resume the receiving end. Cancels a pause() call, resumes receiving data. """ - raise NotImplemented + raise NotImplementedError class Protocol: @@ -269,6 +269,116 @@ def half_close(self): self._write_closed = True +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + import pdb; pdb.set_trace() + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + fd = self._sslsock.fileno() + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + if not self._buffer: + self._sslsock.shutdown(socket.SHUT_WR) + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + + def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, use_ssl=None): if port is None: @@ -312,6 +422,7 @@ def on_addrinfo(infos, exc): return def on_writable(): + # connect() makes the socket writable when it is connected. eventloop.remove_writer(sock.fileno()) err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: @@ -321,9 +432,17 @@ def on_writable(): return if use_ssl: - XXX + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, + sslcontext) else: transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? protocol.connection_made(transport) # XXX call_soon()? eventloop.add_reader(sock.fileno(), transport._on_readable) @@ -381,7 +500,7 @@ def connection_lost(self, exc): tp = TestProtocol() logging.debug('tp = %r', tp) - make_connection(tp, host) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) logging.info('Running...') polling.context.eventloop.run() logging.info('Done.') From 2e4f5e03d5e8a437e79d8de32d50bceb8da4245b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 17:01:47 -0800 Subject: [PATCH 0112/1502] Fix bug in KqueuePollster.register_writer(). --- polling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polling.py b/polling.py index 07ae3d8e..4790ea0a 100644 --- a/polling.py +++ b/polling.py @@ -223,7 +223,7 @@ def register_reader(self, fd, callback, *args): return super().register_reader(fd, callback, *args) def register_writer(self, fd, callback, *args): - if fd not in self.readers: + if fd not in self.writers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return super().register_writer(fd, callback, *args) From 58e65b6364516ebf19e9463350a34b54aeb65fa8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 17:03:25 -0800 Subject: [PATCH 0113/1502] Report some more errors. --- TODO | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/TODO b/TODO index df868cca..5dc2fcb5 100644 --- a/TODO +++ b/TODO @@ -182,3 +182,14 @@ MISTAKES I MADE * I didn't build the whole infrastructure, just played with recv() * I don't have unittests * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). From ec1261fbd2fa65d716dbeebf20b24018ffa07ec4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 17:20:57 -0800 Subject: [PATCH 0114/1502] Fix transport -- shutdown() on an ssl socket does not work. --- TODO | 4 ++++ transports.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/TODO b/TODO index 5dc2fcb5..ed03e1f7 100644 --- a/TODO +++ b/TODO @@ -193,3 +193,7 @@ MISTAKES I MADE - Submitted some changes prematurely (forgot to pass the filename on hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. diff --git a/transports.py b/transports.py index fb7353a3..f5b0b661 100644 --- a/transports.py +++ b/transports.py @@ -316,7 +316,6 @@ def _on_handshake(self): self._eventloop.add_writer(fd, self._on_ready) def _on_ready(self): - import pdb; pdb.set_trace() # Because of renegotiations (?), there's no difference between # readable and writable. We just try both. XXX This may be # incorrect; we probably need to keep state about what we @@ -337,11 +336,14 @@ def _on_ready(self): if data: self._protocol.data_received(data) else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer fd = self._sslsock.fileno() self._eventloop.remove_reader(fd) self._eventloop.remove_writer(fd) self._sslsock.close() self._protocol.connection_lost(None) + return # Now try writing, if there's anything to write. if not self._buffer: @@ -362,8 +364,6 @@ def _on_ready(self): if n == len(data): self._buffer.popleft() # Could try again, but let's just have the next callback do it. - if not self._buffer: - self._sslsock.shutdown(socket.SHUT_WR) else: self._buffer[0] = data[n:] @@ -377,6 +377,8 @@ def write(self, data): def half_close(self): self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, From acda7bef301c4ef5b4f8dcf5d79d26556517552c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 17:22:29 -0800 Subject: [PATCH 0115/1502] Delete trailing whitespace. --- transports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transports.py b/transports.py index f5b0b661..c910dd0b 100644 --- a/transports.py +++ b/transports.py @@ -92,7 +92,7 @@ def half_close(self): def pause(self): """Pause the receiving end. - + No data will be received until resume() is called. """ raise NotImplementedError From 4f3bc9e85d1ae964464092daf6be079065eeea25 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Dec 2012 17:24:01 -0800 Subject: [PATCH 0116/1502] Fix logging if infos is None. --- transports.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transports.py b/transports.py index c910dd0b..24b0d7fa 100644 --- a/transports.py +++ b/transports.py @@ -402,7 +402,10 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, # transports?) def on_addrinfo(infos, exc): - logging.debug('on_addrinfo(, %r)', len(infos), exc) + if infos is None: + logging.debug('on_addrinfo(None, %r)', exc) + else: + logging.debug('on_addrinfo(, %r)', len(infos), exc) # TODO: Make infos into an iterator, to avoid pop()? if not infos: if exc is not None: From 2d884ec5d060ff24f6b1b01123701413656fe9f1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 11 Dec 2012 17:11:42 -0800 Subject: [PATCH 0117/1502] Fix bug in argv processing. --- transports.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transports.py b/transports.py index 24b0d7fa..51b2772f 100644 --- a/transports.py +++ b/transports.py @@ -481,7 +481,7 @@ def main(): # Testing... logging.basicConfig(level=level) host = 'xkcd.com' - if '.' in sys.argv[-1]: + if sys.argv[1:] and '.' in sys.argv[-1]: host = sys.argv[-1] t0 = time.time() @@ -497,7 +497,7 @@ def connection_made(self, transport): def data_received(self, data): logging.info('Received %d bytes at t=%.3f', len(data), time.time() - t0) - logging.debug('Receved %r', data) + logging.debug('Received %r', data) def connection_lost(self, exc): logging.debug('Connection lost: %r', exc) self.t1 = time.time() From ca6a1b5a59d57f56ad12185441a65b1331558150 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 11 Dec 2012 20:03:32 -0800 Subject: [PATCH 0118/1502] Fix another bug. Add pdb-invoking _bad_error(). --- TODO | 4 ++++ transports.py | 10 +++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/TODO b/TODO index ed03e1f7..aca24e32 100644 --- a/TODO +++ b/TODO @@ -197,3 +197,7 @@ MISTAKES I MADE - Forgot again that shutdown(SHUT_WR) on an ssl socket does not work as I expected. I ran into this with the origininal sockets.py and again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. diff --git a/transports.py b/transports.py index 51b2772f..257aded8 100644 --- a/transports.py +++ b/transports.py @@ -292,6 +292,10 @@ def __init__(self, eventloop, protocol, rawsock, sslcontext=None): # will take care of registering the appropriate callback. self._on_handshake() + def _bad_error(self, exc): + import pdb; pdb.set_trace() + logging.error('Exception: %s', exc) + def _on_handshake(self): fd = self._sslsock.fileno() try: @@ -321,6 +325,11 @@ def _on_ready(self): # incorrect; we probably need to keep state about what we # should do next. + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + # First try reading. try: data = self._sslsock.recv(8192) @@ -338,7 +347,6 @@ def _on_ready(self): else: # TODO: Don't close when self._buffer is non-empty. assert not self._buffer - fd = self._sslsock.fileno() self._eventloop.remove_reader(fd) self._eventloop.remove_writer(fd) self._sslsock.close() From 1e5d6f628c8062de5bb7608ad518a5c514e055c8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 11 Dec 2012 20:32:59 -0800 Subject: [PATCH 0119/1502] Add notes from Twisted/Tulip meetup. --- NOTES | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 NOTES diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..e955d5d2 --- /dev/null +++ b/NOTES @@ -0,0 +1,109 @@ +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- This means revisiting the Tulip proactor branch (IOCP). + +What else? +---------- + +- I think we discussed some other topics but I've forgotten these. From 152787a6bea8ebcb2aef142cf9a0e435f1866db8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 11 Dec 2012 22:34:54 -0800 Subject: [PATCH 0120/1502] Updated notes. --- NOTES | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/NOTES b/NOTES index e955d5d2..01427cc9 100644 --- a/NOTES +++ b/NOTES @@ -103,7 +103,25 @@ Futures - This means revisiting the Tulip proactor branch (IOCP). -What else? ----------- +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. -- I think we discussed some other topics but I've forgotten these. +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. From 4f73976c3621436317436f831570316fba495688 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 12 Dec 2012 10:36:51 -0800 Subject: [PATCH 0121/1502] Use sockets.create_connection() to connect a socket. --- transports.py | 106 ++++++++++++++++---------------------------------- 1 file changed, 34 insertions(+), 72 deletions(-) diff --git a/transports.py b/transports.py index 257aded8..dcbbb752 100644 --- a/transports.py +++ b/transports.py @@ -16,6 +16,8 @@ # Local imports. import polling +import scheduling +import sockets # Errno values indicating the connection was disconnected. _DISCONNECTED = frozenset((errno.ECONNRESET, @@ -391,6 +393,14 @@ def half_close(self): def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? if port is None: port = 443 if use_ssl else 80 if use_ssl is None: @@ -398,81 +408,33 @@ def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, if not socktype: socktype = socket.SOCK_STREAM eventloop = polling.context.eventloop - threadrunner = polling.context.threadrunner - # TODO: Maybe use scheduling.Task(sockets.create_connection()). - # (That sounds insane, but it's purely an internal detail. Once - # we have a socket, we can switch to the coroutine-free - # transport+protocol API.) - - # TODO: Move this to private methods on UnixSocketTransport? - # (But then how to handle sharing code between socket and ssl - # transports?) - - def on_addrinfo(infos, exc): - if infos is None: - logging.debug('on_addrinfo(None, %r)', exc) - else: - logging.debug('on_addrinfo(, %r)', len(infos), exc) - # TODO: Make infos into an iterator, to avoid pop()? - if not infos: - if exc is not None: - protocol.connection_lost(exc) - return - protocol.connection_lost(IOError(0, 'No more infos to try')) - return - - af, socktype, proto, cname, address = infos.pop(0) - sock = socket.socket(af, socktype, proto) - sock.setblocking(False) - - try: - sock.connect(address) - except socket.error as exc: - if exc.errno != errno.EINPROGRESS: - sock.close() - on_addrinfo(infos, exc) # XXX call_soon()? - return - - def on_writable(): - # connect() makes the socket writable when it is connected. - eventloop.remove_writer(sock.fileno()) - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - sock.close() - # Try the next getaddrinfo() return value. - on_addrinfo(infos, None) - return - - if use_ssl: - # You can pass an ssl.SSLContext object as use_ssl, - # or a bool. - if isinstance(use_ssl, bool): - sslcontext = None - else: - sslcontext = use_ssl - transport = UnixSslTransport(eventloop, protocol, sock, - sslcontext) + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None else: - transport = UnixSocketTransport(eventloop, protocol, sock) - # TODO: Should the ransport make the following calls? - protocol.connection_made(transport) # XXX call_soon()? - eventloop.add_reader(sock.fileno(), transport._on_readable) - - eventloop.add_writer(sock.fileno(), on_writable) - - def on_future_done(fut): - logging.debug('Future done.') - exc = fut.exception() - if exc is None: - infos = fut.result() + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) else: - infos = None - eventloop.call_soon(on_addrinfo, infos, exc) - - future = threadrunner.submit(socket.getaddrinfo, - host, port, af, socktype, proto, - callback=on_future_done) + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) def main(): # Testing... From 3ca341d05d841039c2ba006aa369dca1275fe20d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 12 Dec 2012 10:46:26 -0800 Subject: [PATCH 0122/1502] Add note about make_connection() returning a Future. --- NOTES | 3 +++ 1 file changed, 3 insertions(+) diff --git a/NOTES b/NOTES index 01427cc9..6f41578e 100644 --- a/NOTES +++ b/NOTES @@ -101,6 +101,9 @@ Futures - For the basics, however, (recv/send, mostly), don't use Futures but use basic callbacks, transport/protocol style. +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + - This means revisiting the Tulip proactor branch (IOCP). - The semantics of add_done_callback() are fuzzy about in which thread From 2ea79f2319740f4be40ae8a2077b9a5e7399652f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 12 Dec 2012 14:10:02 -0800 Subject: [PATCH 0123/1502] Finish UnixSslSocket._bad_error(). --- transports.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/transports.py b/transports.py index dcbbb752..19095bf4 100644 --- a/transports.py +++ b/transports.py @@ -295,8 +295,19 @@ def __init__(self, eventloop, protocol, rawsock, sslcontext=None): self._on_handshake() def _bad_error(self, exc): - import pdb; pdb.set_trace() - logging.error('Exception: %s', exc) + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? def _on_handshake(self): fd = self._sslsock.fileno() From 5e1d15118de167f8fb6c9a8fa7c80effbbf89063 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 16 Dec 2012 16:59:34 -0800 Subject: [PATCH 0124/1502] Checkpoint. --- tulip/__init__.py | 14 ++ tulip/events.py | 208 +++++++++++++++++ tulip/futures.py | 175 +++++++++++++++ tulip/protocols.py | 9 + tulip/tasks.py | 23 ++ tulip/transports.py | 9 + tulip/unix_events.py | 514 +++++++++++++++++++++++++++++++++++++++++++ 7 files changed, 952 insertions(+) create mode 100644 tulip/__init__.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/protocols.py create mode 100644 tulip/tasks.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..185fe3fe --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,14 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .transports import * +from .protocols import * +from .tasks import * + +__all__ = (futures.__all__ + + events.__all__ + + transports.__all__ + + protocols.__all__ + + tasks.__all__) diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..67da8a84 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,208 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +- init_event_loop() (re-)initializes the event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'EventLoop', 'DelayedCall', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', + 'init_event_loop', + ] + +import threading + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Abstract event loop.""" + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + # TODO: stop()? + + # Methods returning DelayedCalls for scheduling callbacks. + + def call_later(self, when, callback, *args): + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, function, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_transport(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a DelayedCall. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def init_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self.init_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, EventLoop) + self._event_loop = event_loop + + def init_event_loop(self): + """(Re-)initialize the event loop. + + This is calls set_event_loop() with a freshly created event + loop suitable for the platform. + """ + # TODO: Do something else for Windows. + from . import unix_events + self.set_event_loop(unix_events.UnixEventLoop()) + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def init_event_loop(): + """XXX""" + get_event_loop_policy().init_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..9e0fd767 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,175 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['Future', 'InvalidStateError', 'InvalidTimeoutError', 'sleep'] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' # Next states: _CANCELLED, _RUNNING. +_CANCELLED = 'CANCELLED' # End state. +_RUNNING = 'RUNNING' # Next state: _FINISHED. +_FINISHED = 'FINISHED' # End state. +_DONE_STATES = (_CANCELLED, _FINISHED) + + +Error = concurrent.futures._base.Error + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # TODO: PEP 3148 seems to say that cancel() does not call the + # callbacks, but set_running_or_notify_cancel() does (if cancel() + # was called). Here, cancel() schedules the callbacks, and + # set_running_or_notify_cancel() just sets the state. + + # Class variables serving to as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + def __init__(self): + """XXX""" + self._callbacks = [] + + def __repr__(self): + """XXX""" + if self._callbacks: + # TODO: Maybe limit the list of callbacks if there are many? + return 'Future<{}, {}>'.format(self._state, self._callbacks) + else: + return 'Future<{}>'.format(self._state) + + def cancel(self): + """XXX""" + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """XXX""" + callbacks = self._callbacks[:] + if not callbacks: + return + # Is it worth emptying the callbacks? It may reduce the + # usefulness of repr(). + self._callbacks[:] = [] + event_loop = events.get_event_loop() + for callback in self._callbacks: + event_loop.call_soon(callback, self) + + def cancelled(self): + """XXX""" + return self._state == _CANCELLED + + def running(self): + """XXX""" + return self._state == _RUNNING + + def done(self): + """XXX""" + return self._state in _DONE_STATES + + def result(self, timeout=0): + """XXX""" + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """XXX""" + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, function): + """XXX""" + if self._state in _DONE_STATES: + events.get_event_loop().call_soon(function, self) + else: + self._callbacks.append(function) + + # So-called internal methods. + + def set_running_or_notify_cancel(self): + """XXX""" + if self._state == _CANCELLED: + return False + if self._state != _PENDING: + raise InvalidStateError + self._state = _RUNNING + return True + + def set_result(self, result): + """XXX""" + if self._state != _RUNNING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """XXX""" + if self._state != _RUNNING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + +# TODO: Is this the right module for sleep()? +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + event_loop = events.get_event_loop() + future = Future() + event_loop.call_later(when, _done_sleeping, future, result) + return future + + +def _done_sleeping(future, result=None): + """Helper for sleep() to set the result.""" + if future.set_running_or_notify_cancel(): + future.set_result(result) diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..72f6abb1 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,9 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol'] + + +class Protocol: + """Abstract protocol.""" + + ... diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..7a5b8f22 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,23 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task'] + +import inspect + +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True + return func + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro): + super().__init__() + assert inspect.isgenerator(coro) # Must be a coroutine *object* + # XXX Now what? diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..2d1a84e0 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,9 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """Abstract transport.""" + + ... diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..6f6c1a9f --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,514 @@ +"""UNIX event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP (best for each platform) + 2. poll (linear in number of file descriptors polled) + 3. select (linear in max number of file descriptors supported) + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + +from . import events + +DelayedCall = events.DelayedCall # TODO: Use the module name. + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (an int of float in seconds) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class UnixEventLoop(events.EventLoop): + """Unix event loop. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > time.time(): + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! From ad2c49f5c99e5088e24ef94797977fa523e54ff6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 16 Dec 2012 20:00:03 -0800 Subject: [PATCH 0125/1502] Checkpoint. --- tulip/events.py | 61 ++++++-- tulip/futures.py | 25 +++- tulip/unix_events.py | 333 ++++++++++++++++++++----------------------- 3 files changed, 224 insertions(+), 195 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 67da8a84..bb05d047 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -8,8 +8,7 @@ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', 'EventLoop', 'DelayedCall', 'get_event_loop_policy', 'set_event_loop_policy', - 'get_event_loop', 'set_event_loop', - 'init_event_loop', + 'get_event_loop', 'set_event_loop', 'init_event_loop', ] import threading @@ -19,23 +18,63 @@ class DelayedCall: """Object returned by callback registration methods.""" def __init__(self, when, callback, args, kwds=None): - self.when = when - self.callback = callback - self.args = args - self.kwds = kwds - self.cancelled = False + self._when = when + self._callback = callback + self._args = args + self._kwds = kwds + self._cancelled = False + + def __repr__(self): + if self.kwds: + res = 'DelayedCall({}, {}, {}, kwds={})'.format(self._when, + self._callback, + self._args, + self._kwds) + else: + res = 'DelayedCall({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def kwds(self): + return self._kwds + + @property + def cancelled(self): + return self._cancelled def cancel(self): - self.cancelled = True + self._cancelled = True def __lt__(self, other): - return self.when < other.when + return self._when < other._when def __le__(self, other): - return self.when <= other.when + return self._when <= other._when + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + return self._when >= other._when def __eq__(self, other): - return self.when == other.when + return self._when == other._when class EventLoop: diff --git a/tulip/futures.py b/tulip/futures.py index 9e0fd767..2dfb6e13 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -120,12 +120,12 @@ def exception(self, timeout=0): raise InvalidStateError return self._exception - def add_done_callback(self, function): + def add_done_callback(self, fn): """XXX""" if self._state in _DONE_STATES: - events.get_event_loop().call_soon(function, self) + events.get_event_loop().call_soon(fn, self) else: - self._callbacks.append(function) + self._callbacks.append(fn) # So-called internal methods. @@ -154,6 +154,23 @@ def set_exception(self, exception): self._state = _FINISHED self._schedule_callbacks() + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + elif self.set_running_or_notify_cancel(): + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + # TODO: Is this the right module for sleep()? def sleep(when, result=None): @@ -170,6 +187,6 @@ def sleep(when, result=None): def _done_sleeping(future, result=None): - """Helper for sleep() to set the result.""" + """Internal helper for sleep() to set the result.""" if future.set_running_or_notify_cancel(): future.set_result(result) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 6f6c1a9f..23ad62f9 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,5 +1,7 @@ """UNIX event loop and related classes. +NOTE: The Pollster classes are not part of the published API. + The event loop can be broken up into a pollster (the part responsible for telling us when file descriptors are ready) and the event loop proper, which wraps a pollster with functionality for scheduling @@ -27,10 +29,6 @@ 1. kqueue, epoll, IOCP (best for each platform) 2. poll (linear in number of file descriptors polled) 3. select (linear in max number of file descriptors supported) - -TODO: -- Optimize the various pollsters. -- Unittests. """ import collections @@ -39,12 +37,14 @@ import logging import os import select +import socket import threading import time from . import events +from . import futures -DelayedCall = events.DelayedCall # TODO: Use the module name. +_MAX_WORKERS = 5 class PollsterBase: @@ -62,8 +62,10 @@ def __init__(self): self.writers = {} # {fd: token, ...}. def pollable(self): - """Return True if any readers or writers are currently registered.""" - return bool(self.readers or self.writers) + """Return the number readers and writers currently registered.""" + # The event loop needs the number since it must subtract one for + # the self-pipe. + return len(self.readers) + len(self.writers) # Subclasses are expected to extend the add/remove methods. @@ -130,12 +132,12 @@ def _update(self, fd): else: self._poll.unregister(fd) - def register_reader(self, fd, callback, *args): - super().register_reader(fd, callback, *args) + def register_reader(self, fd, token): + super().register_reader(fd, token) self._update(fd) - def register_writer(self, fd, callback, *args): - super().register_writer(fd, callback, *args) + def register_writer(self, fd, token): + super().register_writer(fd, token) self._update(fd) def unregister_reader(self, fd): @@ -182,12 +184,12 @@ def _update(self, fd): else: self._epoll.unregister(fd) - def register_reader(self, fd, callback, *args): - super().register_reader(fd, callback, *args) + def register_reader(self, fd, token): + super().register_reader(fd, token) self._update(fd) - def register_writer(self, fd, callback, *args): - super().register_writer(fd, callback, *args) + def register_writer(self, fd, token): + super().register_writer(fd, token) self._update(fd) def unregister_reader(self, fd): @@ -219,17 +221,17 @@ def __init__(self): super().__init__() self._kqueue = select.kqueue() - def register_reader(self, fd, callback, *args): + def register_reader(self, fd, token): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().register_reader(fd, callback, *args) + return super().register_reader(fd, token) - def register_writer(self, fd, callback, *args): + def register_writer(self, fd, token): if fd not in self.writers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().register_writer(fd, callback, *args) + return super().register_writer(fd, token) def unregister_reader(self, fd): super().unregister_reader(fd) @@ -268,12 +270,7 @@ def poll(self, timeout=None): class UnixEventLoop(events.EventLoop): """Unix event loop. - This defines public APIs call_soon(), call_later(), run_once() and - run(). It also wraps Pollster APIs register_reader(), - register_writer(), remove_reader(), remove_writer() with - add_reader() etc. - - This class's instance variables are not part of its API. + See events.EventLoop for API specification. """ def __init__(self, pollster=None): @@ -281,52 +278,28 @@ def __init__(self, pollster=None): if pollster is None: logging.info('Using pollster: %s', best_pollster.__name__) pollster = best_pollster() - self.pollster = pollster - self.ready = collections.deque() # [(callback, args), ...] - self.scheduled = [] # [(when, callback, args), ...] - - def add_reader(self, fd, callback, *args): - """Add a reader callback. Return a DelayedCall instance.""" - dcall = DelayedCall(None, callback, args) - self.pollster.register_reader(fd, dcall) - return dcall - - def remove_reader(self, fd): - """Remove a reader callback.""" - self.pollster.unregister_reader(fd) - - def add_writer(self, fd, callback, *args): - """Add a writer callback. Return a DelayedCall instance.""" - dcall = DelayedCall(None, callback, args) - self.pollster.register_writer(fd, dcall) - return dcall - - def remove_writer(self, fd): - """Remove a writer callback.""" - self.pollster.unregister_writer(fd) - - def add_callback(self, dcall): - """Add a DelayedCall to ready or scheduled.""" - if dcall.cancelled: - return - if dcall.when is None: - self.ready.append(dcall) - else: - heapq.heappush(self.scheduled, dcall) + self._pollster = pollster + self._ready = collections.deque() # [(callback, args), ...] + self._scheduled = [] # [(when, callback, args), ...] + self._pipe_read_fd, self._pipe_write_fd = os.pipe() # Self-pipe. + self._pollster.register_reader(self._pipe_read_fd, + self._read_from_pipe) + self._default_executor = None - def call_soon(self, callback, *args): - """Arrange for a callback to be called as soon as possible. + def _read_from_pipe(self): + os.read(self._pipe_read_fd, 1) - This operates as a FIFO queue, callbacks are called in the - order in which they are registered. Each callback will be - called exactly once. + def run(self): + """Run the event loop until there is no work left to do. - Any positional arguments after the callback will be passed to - the callback when it is called. + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). """ - dcall = DelayedCall(None, callback, args) - self.ready.append(dcall) - return dcall + while self._ready or self._scheduled or self._pollster.pollable() > 1: + self._run_once() + + # TODO: stop()? def call_later(self, when, callback, *args): """Arrange for a callback to be called at a given time. @@ -345,16 +318,114 @@ def call_later(self, when, callback, *args): are scheduled for exactly the same time, it undefined which will be called first. + Events scheduled in the past are passed on to call_soon(), so + these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + Any positional arguments after the callback will be passed to the callback when it is called. """ + if when <= 0: + return self.call_soon(callback, *args) if when < 10000000: when += time.time() - dcall = DelayedCall(when, callback, args) - heapq.heappush(self.scheduled, dcall) + dcall = events.DelayedCall(when, callback, args) + heapq.heappush(self._scheduled, dcall) + return dcall + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = events.DelayedCall(None, callback, args) + self._ready.append(dcall) + return dcall + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + dcall = self.call_soon(callback, *args) + os.write(self._pipe_write_fd, b'x') return dcall - def run_once(self): + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def run_in_executor(self, executor, function, *args): + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(function, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + # XXX create_transport() + + # XXX start_serving() + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = events.DelayedCall(None, callback, args) + self._pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self._pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = events.DelayedCall(None, callback, args) + self._pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self._pollster.unregister_writer(fd) + + # XXX sock_recv() + + # XXX sock_send[all]() + + # XXX sock_connect() + + # XXX sock_accept() + + def _add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self._ready.append(dcall) + else: + heapq.heappush(self._scheduled, dcall) + + def _run_once(self): """Run one full iteration of the event loop. This calls all currently ready callbacks, polls for I/O, @@ -372,8 +443,8 @@ def run_once(self): # All other places just add them to ready. # TODO: Ensure this loop always finishes, even if some # callbacks keeps registering more callbacks. - while self.ready: - dcall = self.ready.popleft() + while self._ready: + dcall = self._ready.popleft() if not dcall.cancelled: try: if dcall.kwds: @@ -385,18 +456,18 @@ def run_once(self): dcall.callback, dcall.args) # Remove delayed calls that were cancelled from head of queue. - while self.scheduled and self.scheduled[0].cancelled: - heapq.heappop(self.scheduled) + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) # Inspect the poll queue. - if self.pollster.pollable(): - if self.scheduled: - when = self.scheduled[0].when + if self._pollster.pollable() > 1: + if self._scheduled: + when = self._scheduled[0].when timeout = max(0, when - time.time()) else: timeout = None t0 = time.time() - events = self.pollster.poll(timeout) + events = self._pollster.poll(timeout) t1 = time.time() argstr = '' if timeout is None else ' %.3f' % timeout if t1-t0 >= 1: @@ -405,110 +476,12 @@ def run_once(self): level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) for dcall in events: - self.add_callback(dcall) + self._add_callback(dcall) # Handle 'later' callbacks that are ready. - while self.scheduled: - dcall = self.scheduled[0] + while self._scheduled: + dcall = self._scheduled[0] if dcall.when > time.time(): break - dcall = heapq.heappop(self.scheduled) + dcall = heapq.heappop(self._scheduled) self.call_soon(dcall.callback, *dcall.args) - - def run(self): - """Run the event loop until there is no work left to do. - - This keeps going as long as there are either readable and - writable file descriptors, or scheduled callbacks (of either - variety). - """ - while self.ready or self.scheduled or self.pollster.pollable(): - self.run_once() - - -MAX_WORKERS = 5 # Default max workers when creating an executor. - - -class ThreadRunner: - """Helper to submit work to a thread pool and wait for it. - - This is the glue between the single-threaded callback-based async - world and the threaded world. Use it to call functions that must - block and don't have an async alternative (e.g. getaddrinfo()). - - The only public API is submit(). - """ - - def __init__(self, eventloop, executor=None): - self.eventloop = eventloop - self.executor = executor # Will be constructed lazily. - self.pipe_read_fd, self.pipe_write_fd = os.pipe() - self.active_count = 0 - - def read_callback(self): - """Semi-permanent callback while at least one future is active.""" - assert self.active_count > 0, self.active_count - data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. - self.active_count -= len(data) - if self.active_count == 0: - self.eventloop.remove_reader(self.pipe_read_fd) - assert self.active_count >= 0, self.active_count - - def submit(self, func, *args, executor=None, callback=None): - """Submit a function to the thread pool. - - This returns a concurrent.futures.Future instance. The caller - should not wait for that, but rather use the callback argument.. - """ - if executor is None: - executor = self.executor - if executor is None: - # Lazily construct a default executor. - # TODO: Should this be shared between threads? - executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) - self.executor = executor - assert self.active_count >= 0, self.active_count - future = executor.submit(func, *args) - if self.active_count == 0: - self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) - self.active_count += 1 - def done_callback(fut): - if callback is not None: - self.eventloop.call_soon(callback, fut) - # TODO: Wake up the pipe in call_soon()? - os.write(self.pipe_write_fd, b'x') - future.add_done_callback(done_callback) - return future - - -class Context(threading.local): - """Thread-local context. - - We use this to avoid having to explicitly pass around an event loop - or something to hold the current task. - - TODO: Add an API so frameworks can substitute a different notion - of context more easily. - """ - - def __init__(self, eventloop=None, threadrunner=None): - # Default event loop and thread runner are lazily constructed - # when first accessed. - self._eventloop = eventloop - self._threadrunner = threadrunner - self.current_task = None # For the benefit of scheduling.py. - - @property - def eventloop(self): - if self._eventloop is None: - self._eventloop = EventLoop() - return self._eventloop - - @property - def threadrunner(self): - if self._threadrunner is None: - self._threadrunner = ThreadRunner(self.eventloop) - return self._threadrunner - - -context = Context() # Thread-local! From b3e7656e1b623641a30a67116fdd8dc4051596fa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 16 Dec 2012 20:08:52 -0800 Subject: [PATCH 0126/1502] Add short TODO for Tulip v2. --- tulip/TODO | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tulip/TODO diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..7175caa7 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,32 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Docstrings + +- Unittests + +- stop()? + +- create_transport() + +- start_serving() + +- Make DelayedCall() callable? + +- Recognize DelayedCall passed to add_reader(), call_soon(), etc.? + +- sock_recv() + +- sock_send[all]() + +- sock_connect() + +- sock_accept() + +- Transports + +- Protocols + +- Task class From 290e8a460305c5f913559264ee3da4514c626483 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 08:47:44 -0800 Subject: [PATCH 0127/1502] Avoid calling time.time() in a while loop. --- polling.py | 3 ++- tulip/unix_events.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/polling.py b/polling.py index 4790ea0a..cffed2af 100644 --- a/polling.py +++ b/polling.py @@ -428,9 +428,10 @@ def run_once(self): self.add_callback(dcall) # Handle 'later' callbacks that are ready. + now = time.time() while self.scheduled: dcall = self.scheduled[0] - if dcall.when > time.time(): + if dcall.when > now: break dcall = heapq.heappop(self.scheduled) self.call_soon(dcall.callback, *dcall.args) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 23ad62f9..1a7c4de0 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -479,9 +479,10 @@ def _run_once(self): self._add_callback(dcall) # Handle 'later' callbacks that are ready. + now = time.time() while self._scheduled: dcall = self._scheduled[0] - if dcall.when > time.time(): + if dcall.when > now: break dcall = heapq.heappop(self._scheduled) self.call_soon(dcall.callback, *dcall.args) From 1a84ac3b1bd70dfc5eec7a5caa664f70bc43b5d4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 09:54:44 -0800 Subject: [PATCH 0128/1502] Use time.monotonic(). Only relative delays. --- tulip/TODO | 2 ++ tulip/events.py | 2 +- tulip/unix_events.py | 26 +++++++++++--------------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index 7175caa7..d77429f7 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -7,6 +7,8 @@ TODO in tulip v2 (tulip/ package directory) - Unittests +- Use time.monotonic(); don't allow absolute times + - stop()? - create_transport() diff --git a/tulip/events.py b/tulip/events.py index bb05d047..1d5feb77 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -88,7 +88,7 @@ def run(self): # Methods returning DelayedCalls for scheduling callbacks. - def call_later(self, when, callback, *args): + def call_later(self, delay, callback, *args): raise NotImplementedError def call_soon(self, callback, *args): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 1a7c4de0..b9f73ccf 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -301,18 +301,14 @@ def run(self): # TODO: stop()? - def call_later(self, when, callback, *args): + def call_later(self, delay, callback, *args): """Arrange for a callback to be called at a given time. Return an object with a cancel() method that can be used to cancel the call. - The time can be an int or float, expressed in seconds. - - If when is small enough (~11 days), it's assumed to be a - relative time, meaning the call will be scheduled that many - seconds in the future; otherwise it's assumed to be a posix - timestamp as returned by time.time(). + The delay can be an int or float, expressed in seconds. It is + always a relative time. Each callback will be called exactly once. If two callbacks are scheduled for exactly the same time, it undefined which @@ -327,11 +323,9 @@ def call_later(self, when, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - if when <= 0: + if delay <= 0: return self.call_soon(callback, *args) - if when < 10000000: - when += time.time() - dcall = events.DelayedCall(when, callback, args) + dcall = events.DelayedCall(time.monotonic() + delay, callback, args) heapq.heappush(self._scheduled, dcall) return dcall @@ -425,6 +419,8 @@ def _add_callback(self, dcall): else: heapq.heappush(self._scheduled, dcall) + # TODO: Make this public? + # TODO: Guarantee ready queue is empty on exit? def _run_once(self): """Run one full iteration of the event loop. @@ -463,12 +459,12 @@ def _run_once(self): if self._pollster.pollable() > 1: if self._scheduled: when = self._scheduled[0].when - timeout = max(0, when - time.time()) + timeout = max(0, when - time.monotonic()) else: timeout = None - t0 = time.time() + t0 = time.monotonic() events = self._pollster.poll(timeout) - t1 = time.time() + t1 = time.monotonic() argstr = '' if timeout is None else ' %.3f' % timeout if t1-t0 >= 1: level = logging.INFO @@ -479,7 +475,7 @@ def _run_once(self): self._add_callback(dcall) # Handle 'later' callbacks that are ready. - now = time.time() + now = time.monotonic() while self._scheduled: dcall = self._scheduled[0] if dcall.when > now: From c7d3ba1809bf198b0940c3bd879861a3c4649f0f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 11:32:00 -0800 Subject: [PATCH 0129/1502] Better __repr__(). --- tulip/futures.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index 2dfb6e13..0c494240 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -60,9 +60,19 @@ def __init__(self): def __repr__(self): """XXX""" - if self._callbacks: - # TODO: Maybe limit the list of callbacks if there are many? - return 'Future<{}, {}>'.format(self._state, self._callbacks) + if self._state == _FINISHED: + if self._exception is not None: + return 'Future'.format(self._exception) + else: + return 'Future'.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + return 'Future<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + return 'Future<{}, {}>'.format(self._state, self._callbacks) else: return 'Future<{}>'.format(self._state) @@ -129,6 +139,10 @@ def add_done_callback(self, fn): # So-called internal methods. + # TODO: set_running_or_notify_cancel() is not really needed since + # we're not in a threaded environment. Consider allowing + # transitions directly from PENDING to FINISHED. + def set_running_or_notify_cancel(self): """XXX""" if self._state == _CANCELLED: From 2527c6a4f09200f722c26d279f9e5188ade5b217 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 11:32:23 -0800 Subject: [PATCH 0130/1502] Add sock_recv() and friends. --- tulip/TODO | 12 +--- tulip/unix_events.py | 137 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 16 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index d77429f7..3171a1af 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -7,10 +7,10 @@ TODO in tulip v2 (tulip/ package directory) - Unittests -- Use time.monotonic(); don't allow absolute times - - stop()? +- better run_once() behavior? + - create_transport() - start_serving() @@ -19,14 +19,6 @@ TODO in tulip v2 (tulip/ package directory) - Recognize DelayedCall passed to add_reader(), call_soon(), etc.? -- sock_recv() - -- sock_send[all]() - -- sock_connect() - -- sock_accept() - - Transports - Protocols diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b9f73ccf..83d0d177 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -33,6 +33,7 @@ import collections import concurrent.futures +import errno import heapq import logging import os @@ -44,6 +45,19 @@ from . import events from . import futures +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) + +# Argument for default thread pool executor creation. _MAX_WORKERS = 5 @@ -390,7 +404,8 @@ def add_reader(self, fd, callback, *args): def remove_reader(self, fd): """Remove a reader callback.""" - self._pollster.unregister_reader(fd) + if fd in self._pollster.readers: + self._pollster.unregister_reader(fd) def add_writer(self, fd, callback, *args): """Add a writer callback. Return a DelayedCall instance.""" @@ -400,15 +415,125 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): """Remove a writer callback.""" - self._pollster.unregister_writer(fd) + if fd in self._pollster.writers: + self._pollster.unregister_writer(fd) - # XXX sock_recv() + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # pollster says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + if fut.set_running_or_notify_cancel(): + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + if fut.set_running_or_notify_cancel(): + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) - # XXX sock_send[all]() + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + self._sock_sendall(fut, False, sock, data) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + n = 0 + try: + if data: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + if fut.set_running_or_notify_cancel(): + fut.set_exception(exc) + return + if n == len(data): + if fut.set_running_or_notify_cancel(): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) - # XXX sock_connect() + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + if fut.set_running_or_notify_cancel(): + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + if fut.set_running_or_notify_cancel(): + fut.set_exception(exc) + else: + self.add_writer(fd, self._sock_connect, + fut, True, sock, address) - # XXX sock_accept() + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + if fut.set_running_or_notify_cancel(): + conn.setblocking(False) + fut.set_result((conn, address)) + else: + conn.close() + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + if fut.set_running_or_notify_cancel(): + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) def _add_callback(self, dcall): """Add a DelayedCall to ready or scheduled.""" From 25de213599c693d2ffcb74b528a37c89bf4d3bd1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 12:09:35 -0800 Subject: [PATCH 0131/1502] Added abstract Transport and Protocol classes. --- tulip/TODO | 7 ++-- tulip/protocols.py | 53 ++++++++++++++++++++++++++-- tulip/transports.py | 85 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 139 insertions(+), 6 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index 3171a1af..5b6d8f1b 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -19,8 +19,11 @@ TODO in tulip v2 (tulip/ package directory) - Recognize DelayedCall passed to add_reader(), call_soon(), etc.? -- Transports +- Transport implementations -- Protocols +- Protocol implementations - Task class + +- Change Future to get rid of set_running_or_notify_cancel(). + We can just do "if not f.cancelled(): f.set_result(x)" diff --git a/tulip/protocols.py b/tulip/protocols.py index 72f6abb1..3a00e2ee 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -4,6 +4,55 @@ class Protocol: - """Abstract protocol.""" + """ABC representing a protocol. - ... + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_transport()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ diff --git a/tulip/transports.py b/tulip/transports.py index 2d1a84e0..a1eace56 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -4,6 +4,87 @@ class Transport: - """Abstract transport.""" + """ABC representing a transport. - ... + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_transport() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError From 10abc453716597d09ecb6fcc820b3bffc28cab9e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 17:16:33 -0800 Subject: [PATCH 0132/1502] Checkpoint. Got rid of _RUNNING state for Futures. --- tulip/TODO | 7 +++-- tulip/futures.py | 46 ++++++++-------------------- tulip/tasks.py | 71 ++++++++++++++++++++++++++++++++++++++++++-- tulip/unix_events.py | 38 +++++++++++------------- 4 files changed, 102 insertions(+), 60 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index 5b6d8f1b..76ad9cf4 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -15,6 +15,8 @@ TODO in tulip v2 (tulip/ package directory) - start_serving() +- Rename DelayedCall to Handler? + - Make DelayedCall() callable? - Recognize DelayedCall passed to add_reader(), call_soon(), etc.? @@ -23,7 +25,6 @@ TODO in tulip v2 (tulip/ package directory) - Protocol implementations -- Task class +- Task class: I/O blocking -- Change Future to get rid of set_running_or_notify_cancel(). - We can just do "if not f.cancelled(): f.set_result(x)" +- Primitives like par() and wait_one() diff --git a/tulip/futures.py b/tulip/futures.py index 0c494240..4eb4e78d 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -7,14 +7,13 @@ from . import events # States for Future. -_PENDING = 'PENDING' # Next states: _CANCELLED, _RUNNING. -_CANCELLED = 'CANCELLED' # End state. -_RUNNING = 'RUNNING' # Next state: _FINISHED. -_FINISHED = 'FINISHED' # End state. -_DONE_STATES = (_CANCELLED, _FINISHED) - +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' +# TODO: Do we really want to depend on concurrent.futures internals? Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError class InvalidStateError(Error): @@ -47,7 +46,7 @@ class Future: # TODO: PEP 3148 seems to say that cancel() does not call the # callbacks, but set_running_or_notify_cancel() does (if cancel() # was called). Here, cancel() schedules the callbacks, and - # set_running_or_notify_cancel() just sets the state. + # set_running_or_notify_cancel() is not supported. # Class variables serving to as defaults for instance variables. _state = _PENDING @@ -102,11 +101,11 @@ def cancelled(self): def running(self): """XXX""" - return self._state == _RUNNING + return False # We don't have a running state. def done(self): """XXX""" - return self._state in _DONE_STATES + return self._state != _PENDING def result(self, timeout=0): """XXX""" @@ -132,29 +131,16 @@ def exception(self, timeout=0): def add_done_callback(self, fn): """XXX""" - if self._state in _DONE_STATES: + if self._state != _PENDING: events.get_event_loop().call_soon(fn, self) else: self._callbacks.append(fn) # So-called internal methods. - # TODO: set_running_or_notify_cancel() is not really needed since - # we're not in a threaded environment. Consider allowing - # transitions directly from PENDING to FINISHED. - - def set_running_or_notify_cancel(self): - """XXX""" - if self._state == _CANCELLED: - return False - if self._state != _PENDING: - raise InvalidStateError - self._state = _RUNNING - return True - def set_result(self, result): """XXX""" - if self._state != _RUNNING: + if self._state != _PENDING: raise InvalidStateError self._result = result self._state = _FINISHED @@ -162,7 +148,7 @@ def set_result(self, result): def set_exception(self, exception): """XXX""" - if self._state != _RUNNING: + if self._state != _PENDING: raise InvalidStateError self._exception = exception self._state = _FINISHED @@ -177,7 +163,7 @@ def _copy_state(self, other): assert not self.done() if other.cancelled(): self.cancel() - elif self.set_running_or_notify_cancel(): + else: exception = other.exception() if exception is not None: self.set_exception(exception) @@ -196,11 +182,5 @@ def sleep(when, result=None): """ event_loop = events.get_event_loop() future = Future() - event_loop.call_later(when, _done_sleeping, future, result) + event_loop.call_later(when, future.set_result, result) return future - - -def _done_sleeping(future, result=None): - """Internal helper for sleep() to set the result.""" - if future.set_running_or_notify_cancel(): - future.set_result(result) diff --git a/tulip/tasks.py b/tulip/tasks.py index 7a5b8f22..0c07bee0 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -4,20 +4,85 @@ import inspect +from . import events from . import futures def coroutine(func): """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. assert inspect.isgeneratorfunction(func) - func._is_coroutine = True + func._is_coroutine = True # Not sure who can use this. return func +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + class Task(futures.Future): """A coroutine wrapped in a Future.""" def __init__(self, coro): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__() - assert inspect.isgenerator(coro) # Must be a coroutine *object* - # XXX Now what? + self._event_loop = events.get_event_loop() + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step(self): + if self.done(): + return + # We'll call either coro.throw(exc) or coro.send(value). + # TODO: Set these from the result of the Future on which we waited. + exc = None + value = None + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + coro.throw(exc) + elif value is not None: + coro.send(value) + else: + next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + raise + else: + # XXX What if blocked for I/O? + self._event_loop.call_soon(self._step) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 83d0d177..46a88663 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -392,9 +392,15 @@ def getaddrinfo(self, host, port, *, def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - # XXX create_transport() + # TODO: Or create_connection()? + def create_transport(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + """XXX""" + - # XXX start_serving() + def start_serving(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + """XXX""" def add_reader(self, fd, callback, *args): """Add a reader callback. Return a DelayedCall instance.""" @@ -436,12 +442,10 @@ def _sock_recv(self, fut, registered, sock, n): return try: data = sock.recv(n) - if fut.set_running_or_notify_cancel(): - fut.set_result(data) + fut.set_result(data) except socket.error as exc: if exc.errno not in _TRYAGAIN: - if fut.set_running_or_notify_cancel(): - fut.set_exception(exc) + fut.set_exception(exc) else: self.add_reader(fd, self._sock_recv, fut, True, sock, n) @@ -463,12 +467,10 @@ def _sock_sendall(self, fut, registered, sock, data): n = sock.send(data) except socket.error as exc: if exc.errno not in _TRYAGAIN: - if fut.set_running_or_notify_cancel(): - fut.set_exception(exc) + fut.set_exception(exc) return if n == len(data): - if fut.set_running_or_notify_cancel(): - fut.set_result(None) + fut.set_result(None) else: if n: data = data[n:] @@ -499,12 +501,10 @@ def _sock_connect(self, fut, registered, sock, address): if err != 0: # Jump to the except clause below. raise socket.error(err, 'Connect call failed') - if fut.set_running_or_notify_cancel(): - fut.set_result(None) + fut.set_result(None) except socket.error as exc: if exc.errno not in _TRYAGAIN: - if fut.set_running_or_notify_cancel(): - fut.set_exception(exc) + fut.set_exception(exc) else: self.add_writer(fd, self._sock_connect, fut, True, sock, address) @@ -523,15 +523,11 @@ def _sock_accept(self, fut, registered, sock): return try: conn, address = sock.accept() - if fut.set_running_or_notify_cancel(): - conn.setblocking(False) - fut.set_result((conn, address)) - else: - conn.close() + conn.setblocking(False) + fut.set_result((conn, address)) except socket.error as exc: if exc.errno not in _TRYAGAIN: - if fut.set_running_or_notify_cancel(): - fut.set_exception(exc) + fut.set_exception(exc) else: self.add_reader(fd, self._sock_accept, fut, True, sock) From 1f7299b9eeea21a4830fb9cd1da31a97fdacca52 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 17:46:08 -0800 Subject: [PATCH 0133/1502] Checkpoint: First unittests. --- Makefile | 3 +++ runtests.py | 39 +++++++++++++++++++++++++++++++++++++ tulip/futures_test.py | 45 +++++++++++++++++++++++++++++++++++++++++++ tulip_bench.py | 30 ----------------------------- 4 files changed, 87 insertions(+), 30 deletions(-) create mode 100644 runtests.py create mode 100644 tulip/futures_test.py delete mode 100644 tulip_bench.py diff --git a/Makefile b/Makefile index 533e7427..dc64a639 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,9 @@ PYTHON=python3.3 test: + $(PYTHON) runtests.py -v + +main: $(PYTHON) main.py -v echo: diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..fa7638a7 --- /dev/null +++ b/runtests.py @@ -0,0 +1,39 @@ +"""Run all unittests.""" + +# Originally written by Beech Horn (for NDB). + +import sys +import unittest + + +def load_tests(): + mods = ['futures'] + test_mods = ['%s_test' % name for name in mods] + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + suite.addTests(tests) + + return suite + + +def main(): + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + result = unittest.TextTestRunner(verbosity=v).run(load_tests()) + sys.exit(not result.wasSuccessful()) + + +if __name__ == '__main__': + main() diff --git a/tulip/futures_test.py b/tulip/futures_test.py new file mode 100644 index 00000000..f7212878 --- /dev/null +++ b/tulip/futures_test.py @@ -0,0 +1,45 @@ +import unittest + + +from . import futures + + +class FutureTests(unittest.TestCase): + + def testInitialState(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def testCancel(self): + f = futures.Future() + f.cancel() + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + + def testResult(self): + f = futures.Future() + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + + def testException(self): + exc = RuntimeError() + f = futures.Future() + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip_bench.py b/tulip_bench.py deleted file mode 100644 index 64937881..00000000 --- a/tulip_bench.py +++ /dev/null @@ -1,30 +0,0 @@ -'''Example app using `file_async` and cancellations.''' - -__author__ = 'Guido van Rossum ' - -import time - -import scheduling - -def binary(n): - if n <= 0: - return 1 - l = yield from binary(n-1) - r = yield from binary(n-1) - return l + 1 + r - -def doit(depth): - t0 = time.time() - k = yield from binary(depth) - t1 = time.time() - print(depth, k, round(t1-t0, 6)) - return (depth, k, round(t1-t0, 6)) - -def main(): - for depth in range(20): - yield from doit(depth) - -import logging -logging.basicConfig(level=logging.DEBUG) - -scheduling.run(main()) From e8c4cc1cde047514b6b98c1f56e067679e41b4da Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 20:21:56 -0800 Subject: [PATCH 0134/1502] Set up coverage. --- .hgignore | 2 ++ Makefile | 11 ++++++++++- runtests.py | 2 +- tulip/events_test.py | 27 +++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 2 deletions(-) create mode 100644 tulip/events_test.py diff --git a/.hgignore b/.hgignore index e3600e75..42309f0c 100644 --- a/.hgignore +++ b/.hgignore @@ -3,4 +3,6 @@ .*\.orig$ .*\#.*$ .*@.*$ +\.coverage$ +htmlcov$ \.DS_Store$ diff --git a/Makefile b/Makefile index dc64a639..95cb0711 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,17 @@ -PYTHON=python3.3 +PYTHON=python3 +COVERAGE=coverage3 +NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` test: $(PYTHON) runtests.py -v +cov coverage: + $(COVERAGE) run runtests.py -v + $(COVERAGE) html $(NONTESTS) + $(COVERAGE) report -m $(NONTESTS) + echo "open file://`pwd`/htmlcov/index.html" + + main: $(PYTHON) main.py -v diff --git a/runtests.py b/runtests.py index fa7638a7..c6ffa3bc 100644 --- a/runtests.py +++ b/runtests.py @@ -7,7 +7,7 @@ def load_tests(): - mods = ['futures'] + mods = ['events', 'futures'] test_mods = ['%s_test' % name for name in mods] tulip = __import__('tulip', fromlist=test_mods) diff --git a/tulip/events_test.py b/tulip/events_test.py new file mode 100644 index 00000000..0c29aa62 --- /dev/null +++ b/tulip/events_test.py @@ -0,0 +1,27 @@ +"""Tests for events.py.""" + +import unittest + +from . import events + + +class EventLoopTests(unittest.TestCase): + + def testEventLoop(self): + pass + + +class DelayedCallTests(unittest.TestCase): + + def testDelayedCall(self): + pass + + +class PolicyTests(unittest.TestCase): + + def testPolicy(self): + pass + + +if __name__ == '__main__': + unittest.main() From 936af5f9092680779b5d39823190e041ae8aa8e8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 17 Dec 2012 20:30:02 -0800 Subject: [PATCH 0135/1502] Update README --- README | 53 ++++++++++++++++++++--------------------------------- 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/README b/README index abfad624..94210f16 100644 --- a/README +++ b/README @@ -1,43 +1,30 @@ -Tulip is the codename for my attempt at understanding PEP-380 style -coroutines (i.e. those using generators and 'yield from'). +Tulip is the codename for my reference implementation of PEP 3156. -*** This requires Python 3.3 or later! *** - -For reference, see many threads in python-ideas@python.org started in -October 2012, especially those with "The async API of the Future" in -their subject, and the various spin-off threads. +PEP 3156: http://www.python.org/dev/peps/pep-3156/ -A particularly influential tutorial by Greg Ewing: -http://www.cosc.canterbury.ac.nz/greg.ewing/python/generators/yf_current/Examples/Scheduler/scheduler.txt - -A message I posted with some explanation of the design: -http://mail.python.org/pipermail/python-ideas/2012-October/017501.html +*** This requires Python 3.3 or later! *** -Essential files here (in top-to-bottom ordering): +Copyright/license: Open source, Apache 2.0. Enjoy. -- main.py: the main program for testing -- http_client.py: a rough HTTP/1.0 client -- sockets.py: transports for sockets and (client) SSL, and a buffering layer -- scheduling.py: a Task class and related stuff; this is where the PEP - 380 scheduler is implemented -- polling.py: an event loop and basic polling implementations for: - select(), poll(), epoll(), kqueue() +Master Mercurial repo: http://code.google.com/p/tulip/ -Secondary files: +The old code lives at the toplevel directory; the new code (conforming +to PEP 3156, under construction) lives in the tulip subdirectory. -- .hgignore: files I don't care about -- Makefile: various quick shell commands -- README: this file -- TODO: longer list of TODO items and general thoughts -- http_server.py: enough of an HTTP server to point 'ab' at -- longlines.py: stupid style checker -- p3time.py: benchmark yield from vs. plain functions -- tulip_bench.py: yet another benchmark (like p3time.py and yyftime.py) -- xkcd.py: *synchronous* ssl example -- yyftime.py: benchmark yield from vs. yield +To run tests: + - make test -Copyright/license: Open source, Apache 2.0. Enjoy. +To run coverage (after installing coverage3, see below): + - make coverage -Master Mercurial repo: http://code.google.com/p/tulip/ +To install coverage3 (coverage.py for Python 3), you need: + - Distribute (http://packages.python.org/distribute/) + - Coverage (http://nedbatchelder.com/code/coverage/) + What worked for me: + - curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - + - cd coveragepy + - python3 setup.py install --Guido van Rossum From 32c9d3a0c69b52c78c2d5e307fe7467c41a5b8e0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 09:10:40 -0800 Subject: [PATCH 0136/1502] A socket pair for Windows, public domain by Geert Jansen. --- tulip/winsocketpair.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tulip/winsocketpair.py diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..48263c95 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,30 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error, e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From 052b068028a1a6febfff50cb8843b5b06dccce77 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 09:24:07 -0800 Subject: [PATCH 0137/1502] Hopeful code to make unix_events.py work on Windows. --- tulip/events_test.py | 5 +++-- tulip/unix_events.py | 31 ++++++++++++++++++++++++++----- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 0c29aa62..cc5925a8 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -7,8 +7,9 @@ class EventLoopTests(unittest.TestCase): - def testEventLoop(self): - pass + def testRun(self): + el = events.get_event_loop() + el.run() # Returns immediately. class DelayedCallTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 46a88663..561ce6ed 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -39,6 +39,7 @@ import os import select import socket +import sys import threading import time @@ -295,14 +296,34 @@ def __init__(self, pollster=None): self._pollster = pollster self._ready = collections.deque() # [(callback, args), ...] self._scheduled = [] # [(when, callback, args), ...] - self._pipe_read_fd, self._pipe_write_fd = os.pipe() # Self-pipe. - self._pollster.register_reader(self._pipe_read_fd, - self._read_from_pipe) self._default_executor = None + self._make_self_pipe_or_sock() + + def _make_self_pipe_or_sock(self): + if sys.platform == 'win32': + from . import winsocketpair + self._ssock, self._csock = winsocketpair.socketpair() + self._pollster.register_reader(self._ssock.fileno(), + self._read_from_self_sock) + self._write_to_self = self._write_to_self_sock + else: + self._pipe_read_fd, self._pipe_write_fd = os.pipe() # Self-pipe. + self._pollster.register_reader(self._pipe_read_fd, + self._read_from_self_pipe) + self._write_to_self = self._write_to_self_pipe + + def _read_from_self_sock(self): + self._ssock.recv(1) + + def _write_to_self_sock(self): + self._csock.send(b'x') - def _read_from_pipe(self): + def _read_from_self_pipe(self): os.read(self._pipe_read_fd, 1) + def _write_to_self_pipe(self): + os.write(self._pipe_write_fd, b'x') + def run(self): """Run the event loop until there is no work left to do. @@ -360,7 +381,7 @@ def call_soon(self, callback, *args): def call_soon_threadsafe(self, callback, *args): """XXX""" dcall = self.call_soon(callback, *args) - os.write(self._pipe_write_fd, b'x') + self._write_to_self() return dcall def wrap_future(self, future): From 33d6d1d04ac093b029191d7ec67beacaf8d000d3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 10:14:01 -0800 Subject: [PATCH 0138/1502] Fix py3 syntax error. --- tulip/winsocketpair.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index 48263c95..b9177fe1 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -19,7 +19,7 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): csock.setblocking(False) try: csock.connect((addr, port)) - except socket.error, e: + except socket.error as e: if e.errno != errno.WSAEWOULDBLOCK: lsock.close() csock.close() From cbda765fc3f8388b9bdd869dbc4720878e5010ad Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 10:23:53 -0800 Subject: [PATCH 0139/1502] Move old stuff into old/. --- Makefile | 27 ++++++++++----------------- README | 4 ++-- longlines.py => check.py | 0 old/Makefile | 16 ++++++++++++++++ echoclt.py => old/echoclt.py | 0 echosvr.py => old/echosvr.py | 0 http_client.py => old/http_client.py | 0 http_server.py => old/http_server.py | 0 main.py => old/main.py | 0 p3time.py => old/p3time.py | 0 polling.py => old/polling.py | 0 scheduling.py => old/scheduling.py | 0 sockets.py => old/sockets.py | 0 transports.py => old/transports.py | 0 xkcd.py => old/xkcd.py | 0 yyftime.py => old/yyftime.py | 0 tulip/futures_test.py | 3 ++- tulip/unix_events.py | 1 - tulip/winsocketpair.py | 2 +- 19 files changed, 31 insertions(+), 22 deletions(-) rename longlines.py => check.py (100%) create mode 100644 old/Makefile rename echoclt.py => old/echoclt.py (100%) rename echosvr.py => old/echosvr.py (100%) rename http_client.py => old/http_client.py (100%) rename http_server.py => old/http_server.py (100%) rename main.py => old/main.py (100%) rename p3time.py => old/p3time.py (100%) rename polling.py => old/polling.py (100%) rename scheduling.py => old/scheduling.py (100%) rename sockets.py => old/sockets.py (100%) rename transports.py => old/transports.py (100%) rename xkcd.py => old/xkcd.py (100%) rename yyftime.py => old/yyftime.py (100%) diff --git a/Makefile b/Makefile index 95cb0711..3e74a160 100644 --- a/Makefile +++ b/Makefile @@ -11,21 +11,14 @@ cov coverage: $(COVERAGE) report -m $(NONTESTS) echo "open file://`pwd`/htmlcov/index.html" - -main: - $(PYTHON) main.py -v - -echo: - $(PYTHON) echosvr.py -v - -profile: - $(PYTHON) -m profile -s time main.py - -time: - $(PYTHON) p3time.py - -ytime: - $(PYTHON) yyftime.py - check: - $(PYTHON) longlines.py + $(PYTHON) check.py + +clean: + rm -f *.py[co] */*.py[co] + rm -f *~ */*~ + rm -f .*~ */.*~ + rm -f @* */@* + rm -f '#'*'#' */'#'*'#' + rm -f .coverage + rm -rf htmlcov diff --git a/README b/README index 94210f16..c1c86a54 100644 --- a/README +++ b/README @@ -8,8 +8,8 @@ Copyright/license: Open source, Apache 2.0. Enjoy. Master Mercurial repo: http://code.google.com/p/tulip/ -The old code lives at the toplevel directory; the new code (conforming -to PEP 3156, under construction) lives in the tulip subdirectory. +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. To run tests: - make test diff --git a/longlines.py b/check.py similarity index 100% rename from longlines.py rename to check.py diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/echoclt.py b/old/echoclt.py similarity index 100% rename from echoclt.py rename to old/echoclt.py diff --git a/echosvr.py b/old/echosvr.py similarity index 100% rename from echosvr.py rename to old/echosvr.py diff --git a/http_client.py b/old/http_client.py similarity index 100% rename from http_client.py rename to old/http_client.py diff --git a/http_server.py b/old/http_server.py similarity index 100% rename from http_server.py rename to old/http_server.py diff --git a/main.py b/old/main.py similarity index 100% rename from main.py rename to old/main.py diff --git a/p3time.py b/old/p3time.py similarity index 100% rename from p3time.py rename to old/p3time.py diff --git a/polling.py b/old/polling.py similarity index 100% rename from polling.py rename to old/polling.py diff --git a/scheduling.py b/old/scheduling.py similarity index 100% rename from scheduling.py rename to old/scheduling.py diff --git a/sockets.py b/old/sockets.py similarity index 100% rename from sockets.py rename to old/sockets.py diff --git a/transports.py b/old/transports.py similarity index 100% rename from transports.py rename to old/transports.py diff --git a/xkcd.py b/old/xkcd.py similarity index 100% rename from xkcd.py rename to old/xkcd.py diff --git a/yyftime.py b/old/yyftime.py similarity index 100% rename from yyftime.py rename to old/yyftime.py diff --git a/tulip/futures_test.py b/tulip/futures_test.py index f7212878..f7d93ba2 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -1,5 +1,6 @@ -import unittest +"""Tests for futures.py.""" +import unittest from . import futures diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 561ce6ed..62e72567 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -417,7 +417,6 @@ def getnameinfo(self, sockaddr, flags=0): def create_transport(self, protocol_factory, host, port, *, family=0, type=0, proto=0, flags=0): """XXX""" - def start_serving(self, protocol_factory, host, port, *, family=0, type=0, proto=0, flags=0): diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index b9177fe1..87d54c91 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -12,7 +12,7 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): # We create a connected TCP socket. Note the trick with setblocking(0) # that prevents us from having to create a thread. lsock = socket.socket(family, type, proto) - lsock.bind(('localhost', 0)) + lsock.bind(('localhost', 0)) lsock.listen(1) addr, port = lsock.getsockname() csock = socket.socket(family, type, proto) From da4992635d1151bddd82f6df6ff822aa8387264d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 15:31:15 -0800 Subject: [PATCH 0140/1502] Add tests for call_{soon,later}*. --- Makefile | 3 +++ tulip/events_test.py | 40 ++++++++++++++++++++++++++++++++++++++++ tulip/unix_events.py | 3 +++ 3 files changed, 46 insertions(+) diff --git a/Makefile b/Makefile index 3e74a160..9e749ab4 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` test: $(PYTHON) runtests.py -v +testloop: + while sleep 1; do $(PYTHON) runtests.py -v; done + cov coverage: $(COVERAGE) run runtests.py -v $(COVERAGE) html $(NONTESTS) diff --git a/tulip/events_test.py b/tulip/events_test.py index cc5925a8..3f2361eb 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,5 +1,7 @@ """Tests for events.py.""" +import threading +import time import unittest from . import events @@ -11,6 +13,44 @@ def testRun(self): el = events.get_event_loop() el.run() # Returns immediately. + def testCallLater(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + el.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.1) + + def testCallSoon(self): + el = events.get_event_loop() + results = [] + def callback(arg1, arg2): + results.append((arg1, arg2)) + el.call_soon(callback, 'hello', 'world') + el.run() + self.assertEqual(results, [('hello', 'world')]) + + def testCallSoonThreadsafe(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + def run(): + el.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.1) + class DelayedCallTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 62e72567..88a5c8d0 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -120,6 +120,8 @@ class SelectPollster(PollsterBase): """Pollster implementation using select.""" def poll(self, timeout=None): + # TODO: Add connections to the third list since "connection + # failed" doesn't make the socket writable on Windows. readable, writable, _ = select.select(self.readers, self.writers, [], timeout) events = [] @@ -300,6 +302,7 @@ def __init__(self, pollster=None): self._make_self_pipe_or_sock() def _make_self_pipe_or_sock(self): + # TODO: Just always use socketpair(). See proactor branch. if sys.platform == 'win32': from . import winsocketpair self._ssock, self._csock = winsocketpair.socketpair() From 47a1ed6f13ede3750b8f1a6e298739fa4975e8a1 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 19 Dec 2012 00:26:03 +0000 Subject: [PATCH 0141/1502] Workaround for quirks in Windows' select() --- tulip/unix_events.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 88a5c8d0..b2cd7d57 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -57,6 +57,8 @@ # Errno values indicating the socket isn't ready for I/O just yet. _TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) # Argument for default thread pool executor creation. _MAX_WORKERS = 5 @@ -116,18 +118,33 @@ def poll(self, timeout=None): raise NotImplementedError -class SelectPollster(PollsterBase): - """Pollster implementation using select.""" +if sys.platform != 'win32': - def poll(self, timeout=None): - # TODO: Add connections to the third list since "connection - # failed" doesn't make the socket writable on Windows. - readable, writable, _ = select.select(self.readers, self.writers, - [], timeout) - events = [] - events += (self.readers[fd] for fd in readable) - events += (self.writers[fd] for fd in writable) - return events + class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + +else: + + class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + # failed connections are reported as exceptional but not writable + readable, writable, exceptional = select.select( + self.readers, self.writers, self.writers, timeout) + writable = set(writable).union(exceptional) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events class PollPollster(PollsterBase): From a807d3b49c8796e5f15d5809cb0cce19671ff4d8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 16:59:48 -0800 Subject: [PATCH 0142/1502] Tasks can wait for Futures. --- TODO | 15 +++++++++++++++ runtests.py | 2 +- tulip/TODO | 2 -- tulip/events_test.py | 2 +- tulip/futures.py | 31 +++++++++++++++++++------------ tulip/tasks.py | 44 +++++++++++++++++++++++++++++++++----------- tulip/tasks_test.py | 37 +++++++++++++++++++++++++++++++++++++ tulip/unix_events.py | 2 +- 8 files changed, 107 insertions(+), 28 deletions(-) create mode 100644 tulip/tasks_test.py diff --git a/TODO b/TODO index aca24e32..8250cc36 100644 --- a/TODO +++ b/TODO @@ -201,3 +201,18 @@ MISTAKES I MADE - Having the same callback for both reading and writing has a problem: it may be scheduled twice, and if the first call closes the socket, the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). diff --git a/runtests.py b/runtests.py index c6ffa3bc..fa784572 100644 --- a/runtests.py +++ b/runtests.py @@ -7,7 +7,7 @@ def load_tests(): - mods = ['events', 'futures'] + mods = ['events', 'futures', 'tasks'] test_mods = ['%s_test' % name for name in mods] tulip = __import__('tulip', fromlist=test_mods) diff --git a/tulip/TODO b/tulip/TODO index 76ad9cf4..94700c84 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -25,6 +25,4 @@ TODO in tulip v2 (tulip/ package directory) - Protocol implementations -- Task class: I/O blocking - - Primitives like par() and wait_one() diff --git a/tulip/events_test.py b/tulip/events_test.py index 3f2361eb..9f4ffde9 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -23,7 +23,7 @@ def callback(arg): el.run() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 >= 0.09) def testCallSoon(self): el = events.get_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py index 4eb4e78d..9a70efa9 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -55,25 +55,28 @@ class Future: def __init__(self): """XXX""" + self._event_loop = events.get_event_loop() self._callbacks = [] def __repr__(self): """XXX""" + res = self.__class__.__name__ if self._state == _FINISHED: if self._exception is not None: - return 'Future'.format(self._exception) + res += ''.format(self._exception) else: - return 'Future'.format(self._result) + res += ''.format(self._result) elif self._callbacks: size = len(self._callbacks) if size > 2: - return 'Future<{}, [{}, <{} more>, {}]>'.format( + res += '<{}, [{}, <{} more>, {}]>'.format( self._state, self._callbacks[0], size-2, self._callbacks[-1]) else: - return 'Future<{}, {}>'.format(self._state, self._callbacks) + res += '<{}, {}>'.format(self._state, self._callbacks) else: - return 'Future<{}>'.format(self._state) + res +='<{}>'.format(self._state) + return res def cancel(self): """XXX""" @@ -91,9 +94,8 @@ def _schedule_callbacks(self): # Is it worth emptying the callbacks? It may reduce the # usefulness of repr(). self._callbacks[:] = [] - event_loop = events.get_event_loop() - for callback in self._callbacks: - event_loop.call_soon(callback, self) + for callback in callbacks: + self._event_loop.call_soon(callback, self) def cancelled(self): """XXX""" @@ -132,11 +134,11 @@ def exception(self, timeout=0): def add_done_callback(self, fn): """XXX""" if self._state != _PENDING: - events.get_event_loop().call_soon(fn, self) + self._event_loop.call_soon(fn, self) else: self._callbacks.append(fn) - # So-called internal methods. + # So-called internal methods (note: no set_running_or_notify_cancel()). def set_result(self, result): """XXX""" @@ -154,6 +156,8 @@ def set_exception(self, exception): self._state = _FINISHED self._schedule_callbacks() + # Truly internal methods. + def _copy_state(self, other): """Internal helper to copy state from another Future. @@ -171,6 +175,10 @@ def _copy_state(self, other): result = other.result() self.set_result(result) + def __iter__(self): + yield self # This tells Task to wait for completion. + return self.result() # May raise too. + # TODO: Is this the right module for sleep()? def sleep(when, result=None): @@ -180,7 +188,6 @@ def sleep(when, result=None): Undocumented feature: sleep(when, x) sets the Future's result to x. """ - event_loop = events.get_event_loop() future = Future() - event_loop.call_later(when, future.set_result, result) + future._event_loop.call_later(when, future.set_result, result) return future diff --git a/tulip/tasks.py b/tulip/tasks.py index 0c07bee0..5fbd9335 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -2,6 +2,7 @@ __all__ = ['coroutine', 'Task'] +import concurrent.futures import inspect from . import events @@ -34,12 +35,19 @@ class Task(futures.Future): def __init__(self, coro): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__() - self._event_loop = events.get_event_loop() + super().__init__() # Sets self._event_loop. self._coro = coro self._must_cancel = False self._event_loop.call_soon(self._step) + def __repr__(self): + res = super().__repr__() + i = res.find('<') + if i < 0: + i = len(res) + res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:] + return res + def cancel(self): if self.done(): return False @@ -50,23 +58,21 @@ def cancel(self): def cancelled(self): return self._must_cancel or super().cancelled() - def _step(self): + def _step(self, value=None, exc=None): if self.done(): + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) return # We'll call either coro.throw(exc) or coro.send(value). - # TODO: Set these from the result of the Future on which we waited. - exc = None - value = None if self._must_cancel: exc = futures.CancelledError coro = self._coro try: if exc is not None: - coro.throw(exc) + result = coro.throw(exc) elif value is not None: - coro.send(value) + result = coro.send(value) else: - next(coro) + result = next(coro) except StopIteration as exc: if self._must_cancel: super().cancel() @@ -84,5 +90,21 @@ def _step(self): self.set_exception(exc) raise else: - # XXX What if blocked for I/O? - self._event_loop.call_soon(self._step) + def _wakeup(future): + value = None + exc = future.exception() + if exc is None: + value = future.result() + self._step(value, exc) + if isinstance(result, futures.Future): + result.add_done_callback(_wakeup) + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe(_wakeup, future)) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + self._event_loop.call_soon(self._step) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py new file mode 100644 index 00000000..74552007 --- /dev/null +++ b/tulip/tasks_test.py @@ -0,0 +1,37 @@ +"""Tests for tasks.py.""" + +import time +import unittest + +from . import futures +from . import tasks + + +class TaskTests(unittest.TestCase): + + def testTask(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + t._event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + def testSleep(self): + @tasks.coroutine + def sleeper(dt, arg): + res = yield from futures.sleep(dt, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + t._event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b2cd7d57..5a5c9d19 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -137,7 +137,7 @@ class SelectPollster(PollsterBase): """Pollster implementation using select.""" def poll(self, timeout=None): - # failed connections are reported as exceptional but not writable + # Failed connections are reported as exceptional but not writable. readable, writable, exceptional = select.select( self.readers, self.writers, self.writers, timeout) writable = set(writable).union(exceptional) From 2bde5bcc0bf0be7aeebbd8ad60d610f77d4ec4ea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 20:17:06 -0800 Subject: [PATCH 0143/1502] Fix amusing bug in _make_self_pipe_or_sock(). --- TODO | 5 +++++ tulip/unix_events.py | 6 ++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/TODO b/TODO index 8250cc36..309bcd30 100644 --- a/TODO +++ b/TODO @@ -216,3 +216,8 @@ MISTAKES I MADE IN TULIP V2 The good news is that I found it with a unittest (albeit not the unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 5a5c9d19..b9770fca 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -323,13 +323,11 @@ def _make_self_pipe_or_sock(self): if sys.platform == 'win32': from . import winsocketpair self._ssock, self._csock = winsocketpair.socketpair() - self._pollster.register_reader(self._ssock.fileno(), - self._read_from_self_sock) + self.add_reader(self._ssock.fileno(), self._read_from_self_sock) self._write_to_self = self._write_to_self_sock else: self._pipe_read_fd, self._pipe_write_fd = os.pipe() # Self-pipe. - self._pollster.register_reader(self._pipe_read_fd, - self._read_from_self_pipe) + self.add_reader(self._pipe_read_fd, self._read_from_self_pipe) self._write_to_self = self._write_to_self_pipe def _read_from_self_sock(self): From 5b4863612ad030a7aa8362d7ff9ff5b5afffd69f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 20:17:39 -0800 Subject: [PATCH 0144/1502] Add tests for wrap_future() and run_in_executor(). --- tulip/TODO | 2 ++ tulip/events_test.py | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tulip/TODO b/tulip/TODO index 94700c84..3c213e0f 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -9,6 +9,8 @@ TODO in tulip v2 (tulip/ package directory) - stop()? +- repeated callback + - better run_once() behavior? - create_transport() diff --git a/tulip/events_test.py b/tulip/events_test.py index 9f4ffde9..b4371370 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,5 +1,7 @@ """Tests for events.py.""" +import concurrent.futures +import os import threading import time import unittest @@ -7,8 +9,23 @@ from . import events +def run_while_future(event_loop, future): + r, w = os.pipe() + def cleanup(): + event_loop.remove_reader(r) + os.close(w) + os.close(r) + event_loop.add_reader(r, cleanup) + # Closing the write end makes the read end readable. + future.add_done_callback(lambda _: os.write(w, b'x')) + event_loop.run() + + class EventLoopTests(unittest.TestCase): + def setUp(self): + events.init_event_loop() + def testRun(self): el = events.get_event_loop() el.run() # Returns immediately. @@ -49,7 +66,29 @@ def run(): t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) - self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 >= 0.09) + + def testWrapFuture(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = el.wrap_future(f1) + run_while_future(el, f2) + self.assertTrue(f2.done()) + self.assertEqual(f2.result(), 'oi') + + def testRunInExecutor(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + f2 = el.run_in_executor(None, run, 'yo') + run_while_future(el, f2) + self.assertTrue(f2.done()) + self.assertEqual(f2.result(), 'yo') class DelayedCallTests(unittest.TestCase): From db1ddd1687f8061c6c52a5699d44e31432175b01 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 20:35:56 -0800 Subject: [PATCH 0145/1502] Make run_while_future() even more compact. --- tulip/events_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index b4371370..c7ec10a5 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -13,11 +13,10 @@ def run_while_future(event_loop, future): r, w = os.pipe() def cleanup(): event_loop.remove_reader(r) - os.close(w) os.close(r) event_loop.add_reader(r, cleanup) # Closing the write end makes the read end readable. - future.add_done_callback(lambda _: os.write(w, b'x')) + future.add_done_callback(lambda _: os.close(w)) event_loop.run() From 957015bd40751a25d23aa1d097e9855460c870ed Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 20:47:31 -0800 Subject: [PATCH 0146/1502] Tests for {add,remove}_{reader,writer}. --- tulip/events_test.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index c7ec10a5..0fcf4f15 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -89,6 +89,35 @@ def run(arg): self.assertTrue(f2.done()) self.assertEqual(f2.result(), 'yo') + def test_reader_callback(self): + el = events.get_event_loop() + r, w = os.pipe() + bytes_read = [] + def reader(): + data = os.read(r, 1024) + if data: + bytes_read.append(data) + else: + el.remove_reader(r) + os.close(r) + el.add_reader(r, reader) + el.call_later(0.05, os.write, w, b'abc') + el.call_later(0.1, os.write, w, b'def') + el.call_later(0.15, os.close, w) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + el = events.get_event_loop() + r, w = os.pipe() + el.add_writer(w, os.write, w, b'x'*100) + el.call_later(0.1, el.remove_writer, w) + el.run() + os.close(w) + data = os.read(r, 32*1024) + os.close(r) + self.assertTrue(len(data) >= 200) + class DelayedCallTests(unittest.TestCase): From 3a6011fa4a3a3a06fac73d832bdc24146285468f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Dec 2012 21:01:51 -0800 Subject: [PATCH 0147/1502] Tests for sock_{connect,recv,sendall}. --- tulip/events_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 0fcf4f15..17158cf5 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -2,6 +2,7 @@ import concurrent.futures import os +import socket import threading import time import unittest @@ -10,6 +11,7 @@ def run_while_future(event_loop, future): + # TODO: Use socketpair(). r, w = os.pipe() def cleanup(): event_loop.remove_reader(r) @@ -18,6 +20,7 @@ def cleanup(): # Closing the write end makes the read end readable. future.add_done_callback(lambda _: os.close(w)) event_loop.run() + return future.result() # May raise. class EventLoopTests(unittest.TestCase): @@ -91,6 +94,7 @@ def run(arg): def test_reader_callback(self): el = events.get_event_loop() + # TODO: Use socketpair(). r, w = os.pipe() bytes_read = [] def reader(): @@ -109,6 +113,7 @@ def reader(): def test_writer_callback(self): el = events.get_event_loop() + # TODO: Use socketpair(). r, w = os.pipe() el.add_writer(w, os.write, w, b'x'*100) el.call_later(0.1, el.remove_writer, w) @@ -118,6 +123,17 @@ def test_writer_callback(self): os.close(r) self.assertTrue(len(data) >= 200) + def test_sock_client_ops(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + run_while_future(el, el.sock_connect(sock, ('python.org', 80))) + run_while_future(el, el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = run_while_future(el, el.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + class DelayedCallTests(unittest.TestCase): From 650ad9889ddae15ceb649054d94c629614cfff08 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 07:56:33 -0800 Subject: [PATCH 0148/1502] Test for sock_accept(). --- tulip/events_test.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 17158cf5..0a807960 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -78,9 +78,8 @@ def run(arg): ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = el.wrap_future(f1) - run_while_future(el, f2) - self.assertTrue(f2.done()) - self.assertEqual(f2.result(), 'oi') + res = run_while_future(el, f2) + self.assertEqual(res, 'oi') def testRunInExecutor(self): el = events.get_event_loop() @@ -88,9 +87,8 @@ def run(arg): time.sleep(0.1) return arg f2 = el.run_in_executor(None, run, 'yo') - run_while_future(el, f2) - self.assertTrue(f2.done()) - self.assertEqual(f2.result(), 'yo') + res = run_while_future(el, f2) + self.assertEqual(res, 'yo') def test_reader_callback(self): el = events.get_event_loop() @@ -134,6 +132,23 @@ def test_sock_client_ops(self): sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + el = events.get_event_loop() + f = el.sock_accept(listener) + conn, addr = run_while_future(el, f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + class DelayedCallTests(unittest.TestCase): From 12a0276a55722bff8130855e1e2c877b4e4f0b2d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 10:00:52 -0800 Subject: [PATCH 0149/1502] Rename DelayedCall to Handler. Add new abstract run/stop and call_repeatedly. --- tulip/events.py | 43 +++++++++++++++++++++++++++++++------------ tulip/events_test.py | 4 ++-- tulip/unix_events.py | 14 +++++++------- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 1d5feb77..ea5334ce 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -6,7 +6,7 @@ """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', - 'EventLoop', 'DelayedCall', + 'EventLoop', 'Handler', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'init_event_loop', ] @@ -14,7 +14,7 @@ import threading -class DelayedCall: +class Handler: """Object returned by callback registration methods.""" def __init__(self, when, callback, args, kwds=None): @@ -26,14 +26,14 @@ def __init__(self, when, callback, args, kwds=None): def __repr__(self): if self.kwds: - res = 'DelayedCall({}, {}, {}, kwds={})'.format(self._when, - self._callback, - self._args, - self._kwds) + res = 'Handler({}, {}, {}, kwds={})'.format(self._when, + self._callback, + self._args, + self._kwds) else: - res = 'DelayedCall({}, {}, {})'.format(self._when, - self._callback, - self._args) + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) if self._cancelled: res += '' return res @@ -84,13 +84,32 @@ def run(self): """Run the event loop. Block until there is nothing left to do.""" raise NotImplementedError - # TODO: stop()? + def run_until_complete(self, future, timeout=None): + """Run the event loop until a Future is done. - # Methods returning DelayedCalls for scheduling callbacks. + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + + # Methods returning Handlers for scheduling callbacks. def call_later(self, delay, callback, *args): raise NotImplementedError + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementdError + def call_soon(self, callback, *args): return self.call_later(0, callback, *args) @@ -122,7 +141,7 @@ def start_serving(self, protocol_factory, host, port, *, raise NotImplementedError # Ready-based callback registration methods. - # The add_*() methods return a DelayedCall. + # The add_*() methods return a Handler. # The remove_*() methods return True if something was removed, # False if there was nothing to delete. diff --git a/tulip/events_test.py b/tulip/events_test.py index 0a807960..6fcaa451 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -150,9 +150,9 @@ def test_sock_accept(self): listener.close() -class DelayedCallTests(unittest.TestCase): +class HandlerTests(unittest.TestCase): - def testDelayedCall(self): + def testHandler(self): pass diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b9770fca..ce78e43a 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -378,7 +378,7 @@ def call_later(self, delay, callback, *args): """ if delay <= 0: return self.call_soon(callback, *args) - dcall = events.DelayedCall(time.monotonic() + delay, callback, args) + dcall = events.Handler(time.monotonic() + delay, callback, args) heapq.heappush(self._scheduled, dcall) return dcall @@ -392,7 +392,7 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - dcall = events.DelayedCall(None, callback, args) + dcall = events.Handler(None, callback, args) self._ready.append(dcall) return dcall @@ -441,8 +441,8 @@ def start_serving(self, protocol_factory, host, port, *, """XXX""" def add_reader(self, fd, callback, *args): - """Add a reader callback. Return a DelayedCall instance.""" - dcall = events.DelayedCall(None, callback, args) + """Add a reader callback. Return a Handler instance.""" + dcall = events.Handler(None, callback, args) self._pollster.register_reader(fd, dcall) return dcall @@ -452,8 +452,8 @@ def remove_reader(self, fd): self._pollster.unregister_reader(fd) def add_writer(self, fd, callback, *args): - """Add a writer callback. Return a DelayedCall instance.""" - dcall = events.DelayedCall(None, callback, args) + """Add a writer callback. Return a Handler instance.""" + dcall = events.Handler(None, callback, args) self._pollster.register_writer(fd, dcall) return dcall @@ -570,7 +570,7 @@ def _sock_accept(self, fut, registered, sock): self.add_reader(fd, self._sock_accept, fut, True, sock) def _add_callback(self, dcall): - """Add a DelayedCall to ready or scheduled.""" + """Add a Handler to ready or scheduled.""" if dcall.cancelled: return if dcall.when is None: From 65f273c541fd9062f2907d3c5032f512e786e92a Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 19 Dec 2012 19:40:23 +0000 Subject: [PATCH 0150/1502] Stop ignoring HUP and errors indicated by poll() and epoll(). --- tulip/unix_events.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index ce78e43a..fb1da9b3 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -187,10 +187,10 @@ def poll(self, timeout=None): msecs = None if timeout is None else int(round(1000 * timeout)) events = [] for fd, flags in self._poll.poll(msecs): - if flags & (select.POLLIN | select.POLLHUP): + if flags & ~select.POLLOUT: if fd in self.readers: events.append(self.readers[fd]) - if flags & (select.POLLOUT | select.POLLHUP): + if flags & ~select.POLLIN: if fd in self.writers: events.append(self.writers[fd]) return events @@ -239,10 +239,10 @@ def poll(self, timeout=None): timeout = -1 # epoll.poll() uses -1 to mean "wait forever". events = [] for fd, eventmask in self._epoll.poll(timeout): - if eventmask & select.EPOLLIN: + if eventmask & ~select.EPOLLOUT: if fd in self.readers: events.append(self.readers[fd]) - if eventmask & select.EPOLLOUT: + if eventmask & ~select.EPOLLIN: if fd in self.writers: events.append(self.writers[fd]) return events From b57e43af0bf0ad6e720de81f9cb048dd1f59633c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 16:19:54 -0800 Subject: [PATCH 0151/1502] Add stop() and run_until_complete(). --- tulip/events.py | 7 +- tulip/events_test.py | 25 ++----- tulip/futures.py | 6 +- tulip/unix_events.py | 159 +++++++++++++++++++++++++++---------------- 4 files changed, 117 insertions(+), 80 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index ea5334ce..7f7fb543 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -84,7 +84,11 @@ def run(self): """Run the event loop. Block until there is nothing left to do.""" raise NotImplementedError - def run_until_complete(self, future, timeout=None): + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! """Run the event loop until a Future is done. Return the Future's result, or raise its exception. @@ -101,6 +105,7 @@ def stop(self): # NEW! Exactly how soon that is may depend on the implementation, but no more I/O callbacks should be scheduled. """ + raise NotImplementedError # Methods returning Handlers for scheduling callbacks. diff --git a/tulip/events_test.py b/tulip/events_test.py index 6fcaa451..e76aa78d 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -10,19 +10,6 @@ from . import events -def run_while_future(event_loop, future): - # TODO: Use socketpair(). - r, w = os.pipe() - def cleanup(): - event_loop.remove_reader(r) - os.close(r) - event_loop.add_reader(r, cleanup) - # Closing the write end makes the read end readable. - future.add_done_callback(lambda _: os.close(w)) - event_loop.run() - return future.result() # May raise. - - class EventLoopTests(unittest.TestCase): def setUp(self): @@ -78,7 +65,7 @@ def run(arg): ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = el.wrap_future(f1) - res = run_while_future(el, f2) + res = el.run_until_complete(f2) self.assertEqual(res, 'oi') def testRunInExecutor(self): @@ -87,7 +74,7 @@ def run(arg): time.sleep(0.1) return arg f2 = el.run_in_executor(None, run, 'yo') - res = run_while_future(el, f2) + res = el.run_until_complete(f2) self.assertEqual(res, 'yo') def test_reader_callback(self): @@ -126,9 +113,9 @@ def test_sock_client_ops(self): sock = socket.socket() sock.setblocking(False) # TODO: This depends on python.org behavior! - run_while_future(el, el.sock_connect(sock, ('python.org', 80))) - run_while_future(el, el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = run_while_future(el, el.sock_recv(sock, 1024)) + el.run_until_complete(el.sock_connect(sock, ('python.org', 80))) + el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = el.run_until_complete(el.sock_recv(sock, 1024)) sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) @@ -141,7 +128,7 @@ def test_sock_accept(self): client.connect(listener.getsockname()) el = events.get_event_loop() f = el.sock_accept(listener) - conn, addr = run_while_future(el, f) + conn, addr = el.run_until_complete(f) self.assertEqual(conn.gettimeout(), 0) self.assertEqual(addr, client.getsockname()) self.assertEqual(client.getpeername(), listener.getsockname()) diff --git a/tulip/futures.py b/tulip/futures.py index 9a70efa9..616a6ac3 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -1,6 +1,9 @@ """A Future class similar to the one in PEP 3148.""" -__all__ = ['Future', 'InvalidStateError', 'InvalidTimeoutError', 'sleep'] +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', 'sleep', + ] import concurrent.futures._base @@ -14,6 +17,7 @@ # TODO: Do we really want to depend on concurrent.futures internals? Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.CancelledError class InvalidStateError(Error): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fb1da9b3..181187d2 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -75,8 +75,8 @@ class PollsterBase: def __init__(self): super().__init__() - self.readers = {} # {fd: token, ...}. - self.writers = {} # {fd: token, ...}. + self.readers = {} # {fd: handler, ...}. + self.writers = {} # {fd: handler, ...}. def pollable(self): """Return the number readers and writers currently registered.""" @@ -86,13 +86,13 @@ def pollable(self): # Subclasses are expected to extend the add/remove methods. - def register_reader(self, fd, token): + def register_reader(self, fd, handler): """Add or update a reader for a file descriptor.""" - self.readers[fd] = token + self.readers[fd] = handler - def register_writer(self, fd, token): + def register_writer(self, fd, handler): """Add or update a writer for a file descriptor.""" - self.writers[fd] = token + self.writers[fd] = handler def unregister_reader(self, fd): """Remove the reader for a file descriptor.""" @@ -113,7 +113,7 @@ def poll(self, timeout=None): The return value is a list of events; it is empty when the timeout expired before any events were ready. Each event - is a token previously passed to register_reader/writer(). + is a handler previously passed to register_reader/writer(). """ raise NotImplementedError @@ -166,12 +166,12 @@ def _update(self, fd): else: self._poll.unregister(fd) - def register_reader(self, fd, token): - super().register_reader(fd, token) + def register_reader(self, fd, handler): + super().register_reader(fd, handler) self._update(fd) - def register_writer(self, fd, token): - super().register_writer(fd, token) + def register_writer(self, fd, handler): + super().register_writer(fd, handler) self._update(fd) def unregister_reader(self, fd): @@ -218,12 +218,12 @@ def _update(self, fd): else: self._epoll.unregister(fd) - def register_reader(self, fd, token): - super().register_reader(fd, token) + def register_reader(self, fd, handler): + super().register_reader(fd, handler) self._update(fd) - def register_writer(self, fd, token): - super().register_writer(fd, token) + def register_writer(self, fd, handler): + super().register_writer(fd, handler) self._update(fd) def unregister_reader(self, fd): @@ -255,17 +255,17 @@ def __init__(self): super().__init__() self._kqueue = select.kqueue() - def register_reader(self, fd, token): + def register_reader(self, fd, handler): if fd not in self.readers: kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().register_reader(fd, token) + return super().register_reader(fd, handler) - def register_writer(self, fd, token): + def register_writer(self, fd, handler): if fd not in self.writers: kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - return super().register_writer(fd, token) + return super().register_writer(fd, handler) def unregister_reader(self, fd): super().unregister_reader(fd) @@ -301,6 +301,14 @@ def poll(self, timeout=None): best_pollster = SelectPollster +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + class UnixEventLoop(events.EventLoop): """Unix event loop. @@ -343,16 +351,52 @@ def _write_to_self_pipe(self): os.write(self._pipe_write_fd, b'x') def run(self): - """Run the event loop until there is no work left to do. + """Run the event loop until nothing left to do or stop() called. This keeps going as long as there are either readable and writable file descriptors, or scheduled callbacks (of either variety). + + TODO: Give this a timeout too? """ while self._ready or self._scheduled or self._pollster.pollable() > 1: - self._run_once() + try: + self._run_once() + except _StopError: + break + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once.""" + try: + self._run_once(timeout) + except _StopError: + pass - # TODO: stop()? + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 1e8 # Over 3 years; kqueue doesn't like it larger. + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, self.stop) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) def call_later(self, delay, callback, *args): """Arrange for a callback to be called at a given time. @@ -375,12 +419,14 @@ def call_later(self, delay, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? """ if delay <= 0: return self.call_soon(callback, *args) - dcall = events.Handler(time.monotonic() + delay, callback, args) - heapq.heappush(self._scheduled, dcall) - return dcall + handler = events.Handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. @@ -392,15 +438,15 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - dcall = events.Handler(None, callback, args) - self._ready.append(dcall) - return dcall + handler = events.Handler(None, callback, args) + self._ready.append(handler) + return handler def call_soon_threadsafe(self, callback, *args): """XXX""" - dcall = self.call_soon(callback, *args) + handler = self.call_soon(callback, *args) self._write_to_self() - return dcall + return handler def wrap_future(self, future): """XXX""" @@ -442,9 +488,9 @@ def start_serving(self, protocol_factory, host, port, *, def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" - dcall = events.Handler(None, callback, args) - self._pollster.register_reader(fd, dcall) - return dcall + handler = events.Handler(None, callback, args) + self._pollster.register_reader(fd, handler) + return handler def remove_reader(self, fd): """Remove a reader callback.""" @@ -453,9 +499,9 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" - dcall = events.Handler(None, callback, args) - self._pollster.register_writer(fd, dcall) - return dcall + handler = events.Handler(None, callback, args) + self._pollster.register_writer(fd, handler) + return handler def remove_writer(self, fd): """Remove a writer callback.""" @@ -569,18 +615,16 @@ def _sock_accept(self, fut, registered, sock): else: self.add_reader(fd, self._sock_accept, fut, True, sock) - def _add_callback(self, dcall): + def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" - if dcall.cancelled: + if handler.cancelled: return - if dcall.when is None: - self._ready.append(dcall) + if handler.when is None: + self._ready.append(handler) else: - heapq.heappush(self._scheduled, dcall) + heapq.heappush(self._scheduled, handler) - # TODO: Make this public? - # TODO: Guarantee ready queue is empty on exit? - def _run_once(self): + def _run_once(self, timeout=None): """Run one full iteration of the event loop. This calls all currently ready callbacks, polls for I/O, @@ -588,7 +632,6 @@ def _run_once(self): 'call_later' callbacks. """ # TODO: Break each of these into smaller pieces. - # TODO: Pass in a timeout or deadline or something. # TODO: Refactor to separate the callbacks from the readers/writers. # TODO: As step 4, run everything scheduled by steps 1-3. # TODO: An alternative API would be to do the *minimal* amount @@ -599,16 +642,16 @@ def _run_once(self): # TODO: Ensure this loop always finishes, even if some # callbacks keeps registering more callbacks. while self._ready: - dcall = self._ready.popleft() - if not dcall.cancelled: + handler = self._ready.popleft() + if not handler.cancelled: try: - if dcall.kwds: - dcall.callback(*dcall.args, **dcall.kwds) + if handler.kwds: + handler.callback(*handler.args, **handler.kwds) else: - dcall.callback(*dcall.args) + handler.callback(*handler.args) except Exception: logging.exception('Exception in callback %s %r', - dcall.callback, dcall.args) + handler.callback, handler.args) # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: @@ -619,8 +662,6 @@ def _run_once(self): if self._scheduled: when = self._scheduled[0].when timeout = max(0, when - time.monotonic()) - else: - timeout = None t0 = time.monotonic() events = self._pollster.poll(timeout) t1 = time.monotonic() @@ -630,14 +671,14 @@ def _run_once(self): else: level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - for dcall in events: - self._add_callback(dcall) + for handler in events: + self._add_callback(handler) # Handle 'later' callbacks that are ready. now = time.monotonic() while self._scheduled: - dcall = self._scheduled[0] - if dcall.when > now: + handler = self._scheduled[0] + if handler.when > now: break - dcall = heapq.heappop(self._scheduled) - self.call_soon(dcall.callback, *dcall.args) + handler = heapq.heappop(self._scheduled) + self.call_soon(handler.callback, *handler.args) From fcd3a7beed171203ab81467c7317f8d029348bfa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 17:48:32 -0800 Subject: [PATCH 0152/1502] Add call_repeatedly(). --- tulip/events_test.py | 10 ++++++++++ tulip/unix_events.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index e76aa78d..39ad744d 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -31,6 +31,16 @@ def callback(arg): self.assertEqual(results, ['hello world']) self.assertTrue(t1-t0 >= 0.09) + def testCallRepeatedly(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_repeatedly(0.03, callback, 'ho') + el.call_later(0.1, el.stop) + el.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + def testCallSoon(self): el = events.get_event_loop() results = [] diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 181187d2..63d54b51 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -428,6 +428,16 @@ def call_later(self, delay, callback, *args): heapq.heappush(self._scheduled, handler) return handler + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.Handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. From 17e09cebf92c7f9c6cb23353560fe6060fc3fe11 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 18:12:51 -0800 Subject: [PATCH 0153/1502] Use socketpair() instead of os.pipe() everywhere. --- tulip/TODO | 10 ++-------- tulip/events_test.py | 32 +++++++++++++++----------------- tulip/unix_events.py | 36 ++++++++++++++---------------------- 3 files changed, 31 insertions(+), 47 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index 3c213e0f..c0eca195 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -7,21 +7,15 @@ TODO in tulip v2 (tulip/ package directory) - Unittests -- stop()? - -- repeated callback - - better run_once() behavior? - create_transport() - start_serving() -- Rename DelayedCall to Handler? - -- Make DelayedCall() callable? +- Make Handler() callable? -- Recognize DelayedCall passed to add_reader(), call_soon(), etc.? +- Recognize Handler passed to add_reader(), call_soon(), etc.? - Transport implementations diff --git a/tulip/events_test.py b/tulip/events_test.py index 39ad744d..45ddbf0e 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,13 +1,13 @@ """Tests for events.py.""" import concurrent.futures -import os import socket import threading import time import unittest from . import events +from . import unix_events class EventLoopTests(unittest.TestCase): @@ -89,33 +89,31 @@ def run(arg): def test_reader_callback(self): el = events.get_event_loop() - # TODO: Use socketpair(). - r, w = os.pipe() + r, w = unix_events.socketpair() bytes_read = [] def reader(): - data = os.read(r, 1024) + data = r.recv(1024) if data: bytes_read.append(data) else: - el.remove_reader(r) - os.close(r) - el.add_reader(r, reader) - el.call_later(0.05, os.write, w, b'abc') - el.call_later(0.1, os.write, w, b'def') - el.call_later(0.15, os.close, w) + el.remove_reader(r.fileno()) + r.close() + el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): el = events.get_event_loop() - # TODO: Use socketpair(). - r, w = os.pipe() - el.add_writer(w, os.write, w, b'x'*100) - el.call_later(0.1, el.remove_writer, w) + r, w = unix_events.socketpair() + el.add_writer(w.fileno(), w.send, b'x'*100) + el.call_later(0.1, el.remove_writer, w.fileno()) el.run() - os.close(w) - data = os.read(r, 32*1024) - os.close(r) + w.close() + data = r.recv(32*1024) + r.close() self.assertTrue(len(data) >= 200) def test_sock_client_ops(self): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 63d54b51..0c00957e 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -36,7 +36,6 @@ import errno import heapq import logging -import os import select import socket import sys @@ -46,6 +45,12 @@ from . import events from . import futures +try: + from socket import socketpair +except ImportError: + assert sys.platform == 'win32' + from winsocketpair import socketpair + # Errno values indicating the connection was disconnected. _DISCONNECTED = frozenset((errno.ECONNRESET, errno.ENOTCONN, @@ -324,32 +329,19 @@ def __init__(self, pollster=None): self._ready = collections.deque() # [(callback, args), ...] self._scheduled = [] # [(when, callback, args), ...] self._default_executor = None - self._make_self_pipe_or_sock() - - def _make_self_pipe_or_sock(self): - # TODO: Just always use socketpair(). See proactor branch. - if sys.platform == 'win32': - from . import winsocketpair - self._ssock, self._csock = winsocketpair.socketpair() - self.add_reader(self._ssock.fileno(), self._read_from_self_sock) - self._write_to_self = self._write_to_self_sock - else: - self._pipe_read_fd, self._pipe_write_fd = os.pipe() # Self-pipe. - self.add_reader(self._pipe_read_fd, self._read_from_self_pipe) - self._write_to_self = self._write_to_self_pipe + self._make_self_pipe() + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self.add_reader(self._ssock.fileno(), self._read_from_self) - def _read_from_self_sock(self): + def _read_from_self(self): self._ssock.recv(1) - def _write_to_self_sock(self): + def _write_to_self(self): self._csock.send(b'x') - def _read_from_self_pipe(self): - os.read(self._pipe_read_fd, 1) - - def _write_to_self_pipe(self): - os.write(self._pipe_write_fd, b'x') - def run(self): """Run the event loop until nothing left to do or stop() called. From bd8a71260b2368aee4476d2f02bc49302fa431c6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 19 Dec 2012 18:26:52 -0800 Subject: [PATCH 0154/1502] Avoid extra indirection when scheduling a stop. --- tulip/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 0c00957e..a544e20b 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -373,7 +373,7 @@ def run_until_complete(self, future, timeout=None): if timeout is None: timeout = 1e8 # Over 3 years; kqueue doesn't like it larger. future.add_done_callback(lambda _: self.stop()) - handler = self.call_later(timeout, self.stop) + handler = self.call_later(timeout, _raise_stop_error) self.run() handler.cancel() if future.done(): From e99793339ffc95e1d28aa8221e02f3aa49be968b Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 20 Dec 2012 10:32:22 +0000 Subject: [PATCH 0155/1502] Fix import of socketpair on Windows. --- tulip/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index a544e20b..274952d4 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -49,7 +49,7 @@ from socket import socketpair except ImportError: assert sys.platform == 'win32' - from winsocketpair import socketpair + from .winsocketpair import socketpair # Errno values indicating the connection was disconnected. _DISCONNECTED = frozenset((errno.ECONNRESET, From 552e404ef725b9da7fe11ab0c5b3409b5921e3f5 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 20 Dec 2012 15:22:07 +0000 Subject: [PATCH 0156/1502] Change default timeout in run_until_complete() to avoid overflow. --- tulip/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 274952d4..fb87320c 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -371,7 +371,7 @@ def run_until_complete(self, future, timeout=None): timeout is reached or stop() is called, raise TimeoutError. """ if timeout is None: - timeout = 1e8 # Over 3 years; kqueue doesn't like it larger. + timeout = 0x7fffffff/1000.0 # 24 days future.add_done_callback(lambda _: self.stop()) handler = self.call_later(timeout, _raise_stop_error) self.run() From e630e8cd53d14e2e60ade2ba1fcea6db42978c62 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 20 Dec 2012 15:39:34 +0000 Subject: [PATCH 0157/1502] Add support for select.poll() on Windows versions which support it. This depends on the patch for Python 3.x at http://bugs.python.org/issue16507 It works around the WSAPoll bug discussed at daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ --- tulip/events.py | 6 ++ tulip/events_test.py | 9 +++ tulip/unix_events.py | 129 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 140 insertions(+), 4 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 7f7fb543..d843c946 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -162,6 +162,12 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): raise NotImplementedError + def add_connector(self, fd, callback, *args): + raise NotImplementedError + + def remove_connector(self, fd): + raise NotImplementedError + # Completion based I/O methods returning Futures. def sock_recv(self, sock, nbytes): diff --git a/tulip/events_test.py b/tulip/events_test.py index 45ddbf0e..2dc21f6c 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -127,6 +127,15 @@ def test_sock_client_ops(self): sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + def test_sock_client_fail(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + with self.assertRaises(ConnectionRefusedError): + el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) + sock.close() + def test_sock_accept(self): listener = socket.socket() listener.setblocking(False) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fb87320c..bf8163c4 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -107,6 +107,16 @@ def unregister_writer(self, fd): """Remove the writer for a file descriptor.""" del self.writers[fd] + def register_connector(self, fd, handler): + """Add or update a connector for a file descriptor.""" + # On Unix a connector is the same as a writer. + self.register_writer(fd, handler) + + def unregister_connector(self, fd): + """Remove the connector for a file descriptor.""" + # On Unix a connector is the same as a writer. + self.unregister_writer(fd) + def poll(self, timeout=None): """Poll for events. A subclass must implement this. @@ -141,16 +151,32 @@ def poll(self, timeout=None): class SelectPollster(PollsterBase): """Pollster implementation using select.""" + def __init__(self): + super().__init__() + self.exceptionals = {} + def poll(self, timeout=None): # Failed connections are reported as exceptional but not writable. readable, writable, exceptional = select.select( - self.readers, self.writers, self.writers, timeout) + self.readers, self.writers, self.exceptionals, timeout) writable = set(writable).union(exceptional) events = [] events += (self.readers[fd] for fd in readable) events += (self.writers[fd] for fd in writable) return events + def register_connector(self, fd, token): + self.register_writer(fd, token) + self.exceptionals[fd] = token + + def unregister_connector(self, fd): + self.unregister_writer(fd) + try: + del self.exceptionals[fd] + except KeyError: + # remove_connector() does not check fd in self.exceptionals. + pass + class PollPollster(PollsterBase): """Pollster implementation using poll.""" @@ -201,6 +227,89 @@ def poll(self, timeout=None): return events +if sys.platform == 'win32': + + class WindowsPollPollster(PollPollster): + """Pollster implementation using WSAPoll. + + WSAPoll is only available on Windows Vista and later. Python + does not currently support WSAPoll, but there is a patch + available at http://bugs.python.org/issue16507. + """ + + # REAP_PERIOD is the maximum wait before checking for failed + # connections. This is necessary because WSAPoll() does notify us + # of failed connections. See + # daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ + REAP_PERIOD = 5.0 + + # FD_SETSIZE is maximum number of sockets in an fd_set + FD_SETSIZE = 512 + + def __init__(self): + super().__init__() + self.exceptionals = {} + + def register_connector(self, fd, token): + self.register_writer(fd, token) + self.exceptionals[fd] = token + + def unregister_connector(self, fd): + self.unregister_writer(fd) + try: + del self.exceptionals[fd] + except KeyError: + # remove_connector() does not check fd in self.exceptionals. + pass + + def _get_failed_connector_events(self): + fds = [] + remaining = list(self.exceptionals) + while remaining: + fds += select.select([], [], remaining[:self.FD_SETSIZE], 0)[2] + del remaining[:self.FD_SETSIZE] + return [(fd, select.POLLOUT) for fd in fds] + + def poll(self, timeout=None): + if not self.exceptionals: + msecs = None if timeout is None else int(round(1000 * timeout)) + polled = self._poll.poll(msecs) + + elif timeout is None: + polled = None + while not polled: + polled = (self._get_failed_connector_events() or + self._poll.poll(self.REAP_PERIOD)) + + elif timeout == 0: + polled = (self._get_failed_connector_events() or + self._poll.poll(0)) + + else: + start = time.monotonic() + deadline = start + timeout + polled = None + while timeout >= 0: + msecs = int(round(1000 * min(self.REAP_PERIOD, timeout))) + polled = (self._get_failed_connector_events() or + self._poll.poll(self.REAP_PERIOD)) + if polled: + break + timemout = deadline - time.monotonic() + + events = [] + for fd, flags in polled: + if flags & ~select.POLLOUT: + if fd in self.readers: + events.append(self.readers[fd]) + if flags & ~select.POLLIN: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + PollPollster = WindowsPollPollster + + class EPollPollster(PollsterBase): """Pollster implementation using epoll.""" @@ -510,6 +619,18 @@ def remove_writer(self, fd): if fd in self._pollster.writers: self._pollster.unregister_writer(fd) + def add_connector(self, fd, callback, *args): + """Add a connector callback. Return a Handler instance.""" + dcall = events.Handler(None, callback, args) + self._pollster.register_connector(fd, dcall) + return dcall + + def remove_connector(self, fd): + """Remove a connector callback.""" + # Every connector fd is in self._pollsters.writers. + if fd in self._pollster.writers: + self._pollster.unregister_connector(fd) + def sock_recv(self, sock, n): """XXX""" fut = futures.Future() @@ -575,7 +696,7 @@ def sock_connect(self, sock, address): def _sock_connect(self, fut, registered, sock, address): fd = sock.fileno() if registered: - self.remove_writer(fd) + self.remove_connector(fd) if fut.cancelled(): return try: @@ -592,8 +713,8 @@ def _sock_connect(self, fut, registered, sock, address): if exc.errno not in _TRYAGAIN: fut.set_exception(exc) else: - self.add_writer(fd, self._sock_connect, - fut, True, sock, address) + self.add_connector(fd, self._sock_connect, + fut, True, sock, address) def sock_accept(self, sock): """XXX""" From 50f04831dd7689056d8323e2e6e7f7673022a31f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 20 Dec 2012 13:36:22 -0800 Subject: [PATCH 0158/1502] Add @task decorator and test for it. (This wraps the coro in a Task.) --- tulip/tasks.py | 10 +++++++++- tulip/tasks_test.py | 12 +++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 5fbd9335..7c7828d8 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'Task'] +__all__ = ['coroutine', 'task', 'Task'] import concurrent.futures import inspect @@ -30,6 +30,14 @@ def iscoroutine(obj): return inspect.isgenerator(obj) # TODO: And what? +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + class Task(futures.Future): """A coroutine wrapped in a Future.""" diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 74552007..c9c184c1 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -9,7 +9,7 @@ class TaskTests(unittest.TestCase): - def testTask(self): + def testTaskClass(self): @tasks.coroutine def notmuch(): yield from [] @@ -19,6 +19,16 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') + def testTaskDecorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + t._event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + def testSleep(self): @tasks.coroutine def sleeper(dt, arg): From 32bebb3cf7c11bbed6f5b6febd25c6f4fc64c51c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 20 Dec 2012 13:40:03 -0800 Subject: [PATCH 0159/1502] Add run_forever(). Add UNIX socket transport. --- tulip/TODO | 14 ++-- tulip/events_test.py | 32 +++++++++ tulip/unix_events.py | 161 ++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 191 insertions(+), 16 deletions(-) diff --git a/tulip/TODO b/tulip/TODO index c0eca195..b3a9302e 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -3,22 +3,24 @@ TODO in tulip v2 (tulip/ package directory) - See also TBD and Open Issues in PEP 3156 +- Refactor unix_events.py (it's getting too long) + - Docstrings - Unittests -- better run_once() behavior? - -- create_transport() +- better run_once() behavior? (Run ready list last.) - start_serving() -- Make Handler() callable? +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? - Recognize Handler passed to add_reader(), call_soon(), etc.? -- Transport implementations +- SSL support -- Protocol implementations +- buffered stream implementation - Primitives like par() and wait_one() diff --git a/tulip/events_test.py b/tulip/events_test.py index 2dc21f6c..9ee02896 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -7,6 +7,8 @@ import unittest from . import events +from . import transports +from . import protocols from . import unix_events @@ -153,6 +155,36 @@ def test_sock_accept(self): conn.close() listener.close() + def testCreateTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + el.stop() + f = el.create_transport(MyProto, 'xkcd.com', 80) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + el.run_forever() # Really, until connection_lost() calls el.stop(). + self.assertTrue(pr.nbytes > 0) + class HandlerTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index bf8163c4..5c7e2cfa 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -44,6 +44,9 @@ from . import events from . import futures +from . import protocols +from . import tasks +from . import transports try: from socket import socketpair @@ -118,7 +121,7 @@ def unregister_connector(self, fd): self.unregister_writer(fd) def poll(self, timeout=None): - """Poll for events. A subclass must implement this. + """Poll for I/O events. A subclass must implement this. If timeout is omitted or None, this blocks until at least one event is ready. Otherwise, timeout gives a maximum time to @@ -466,8 +469,23 @@ def run(self): except _StopError: break + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + def run_once(self, timeout=None): - """Run through all callbacks and all I/O polls once.""" + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ try: self._run_once(timeout) except _StopError: @@ -512,8 +530,8 @@ def call_later(self, delay, callback, *args): are scheduled for exactly the same time, it undefined which will be called first. - Events scheduled in the past are passed on to call_soon(), so - these will be called in the order in which they were + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were registered rather than by time due. This is so you can't cheat and insert yourself at the front of the ready queue by using a negative time. @@ -588,11 +606,37 @@ def getaddrinfo(self, host, port, *, def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - # TODO: Or create_connection()? + # TODO: Or create_connection()? Or create_client()? + @tasks.task def create_transport(self, protocol_factory, host, port, *, - family=0, type=0, proto=0, flags=0): + family=0, type=socket.SOCK_STREAM, proto=0, flags=0): """XXX""" - + infos = yield from self.getaddrinfo(host, port, + family=family, type=type, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + # TODO: Use a small timeout here and overlap connect attempts. + try: + yield self.sock_connect(sock, address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + protocol = protocol_factory() + # TODO: SSL. + transport = UnixSocketTransport(self, sock, protocol) + self.call_soon(protocol.connection_made, transport) + return transport, protocol + + # TODO: Or create_server()? def start_serving(self, protocol_factory, host, port, *, family=0, type=0, proto=0, flags=0): """XXX""" @@ -780,11 +824,19 @@ def _run_once(self, timeout=None): while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) - # Inspect the poll queue. - if self._pollster.pollable() > 1: + # Inspect the poll queue. If there's exactly one pollable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._pollster.pollable() > 1 or self._scheduled: if self._scheduled: + # Compute the desired timeout. when = self._scheduled[0].when - timeout = max(0, when - time.monotonic()) + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + t0 = time.monotonic() events = self._pollster.poll(timeout) t1 = time.monotonic() @@ -805,3 +857,92 @@ def _run_once(self, timeout=None): break handler = heapq.heappop(self._scheduled) self.call_soon(handler.callback, *handler.args) + + +class UnixSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when closed() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._protocol.connection_lost, + None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) From 62b6a5d4d16986c06291a7059fea0030fc11cb9e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 20 Dec 2012 13:41:49 -0800 Subject: [PATCH 0160/1502] No need to use run_forever() and stop() in transport test. --- tulip/events_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 9ee02896..4c163a0b 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -177,12 +177,11 @@ def eof_received(self): def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' - el.stop() f = el.create_transport(MyProto, 'xkcd.com', 80) tr, pr = el.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - el.run_forever() # Really, until connection_lost() calls el.stop(). + el.run() self.assertTrue(pr.nbytes > 0) From df6d5a882c8077e4d36b8139a6e2e7289e35824b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 20 Dec 2012 14:31:46 -0800 Subject: [PATCH 0161/1502] Add ssl transport. --- tulip/events_test.py | 51 ++++++++++------- tulip/unix_events.py | 129 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 25 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 4c163a0b..09992edd 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -12,6 +12,27 @@ from . import unix_events +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + class EventLoopTests(unittest.TestCase): def setUp(self): @@ -158,25 +179,6 @@ def test_sock_accept(self): def testCreateTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - class MyProto(protocols.Protocol): - def __init__(self): - self.state = 'INITIAL' - self.nbytes = 0 - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' - transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') - def data_received(self, data): - assert self.state == 'CONNECTED', self.state - self.nbytes += len(data) - def eof_received(self): - assert self.state == 'CONNECTED', self.state - self.state = 'EOF' - self.transport.close() - def connection_lost(self, exc): - assert self.state in ('CONNECTED', 'EOF'), self.state - self.state = 'CLOSED' f = el.create_transport(MyProto, 'xkcd.com', 80) tr, pr = el.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) @@ -184,6 +186,17 @@ def connection_lost(self, exc): el.run() self.assertTrue(pr.nbytes > 0) + def testCreateSslTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_transport(MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + el.run() + self.assertTrue(pr.nbytes > 0) + class HandlerTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 5c7e2cfa..486c2521 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -38,6 +38,7 @@ import logging import select import socket +import ssl import sys import threading import time @@ -608,7 +609,7 @@ def getnameinfo(self, sockaddr, flags=0): # TODO: Or create_connection()? Or create_client()? @tasks.task - def create_transport(self, protocol_factory, host, port, *, + def create_transport(self, protocol_factory, host, port, *, ssl=False, family=0, type=socket.SOCK_STREAM, proto=0, flags=0): """XXX""" infos = yield from self.getaddrinfo(host, port, @@ -631,9 +632,11 @@ def create_transport(self, protocol_factory, host, port, *, else: raise exceptions[0] protocol = protocol_factory() - # TODO: SSL. - transport = UnixSocketTransport(self, sock, protocol) - self.call_soon(protocol.connection_made, transport) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = _UnixSslTransport(self, sock, protocol, sslcontext) + else: + transport = _UnixSocketTransport(self, sock, protocol) return transport, protocol # TODO: Or create_server()? @@ -859,15 +862,16 @@ def _run_once(self, timeout=None): self.call_soon(handler.callback, *handler.args) -class UnixSocketTransport(transports.Transport): +class _UnixSocketTransport(transports.Transport): def __init__(self, event_loop, sock, protocol): self._event_loop = event_loop self._sock = sock self._protocol = protocol self._buffer = [] - self._closing = False # Set when closed() called. + self._closing = False # Set when close() called. self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) def _read_ready(self): try: @@ -946,3 +950,116 @@ def _fatal_error(self, exc): self._event_loop.remove_reader(self._sock.fileno()) self._buffer = [] self._event_loop.call_soon(self._protocol.connection_lost, exc) + + +class _UnixSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext): + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if n < len(data): + self._buffer.append(data[n:]) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) From a943cc609643be3e0d08b24841f2498792e13fea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 20 Dec 2012 18:02:59 -0800 Subject: [PATCH 0162/1502] Define separate tests for each supported pollster class. --- tulip/events_test.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 09992edd..9a43acf9 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,6 +1,7 @@ """Tests for events.py.""" import concurrent.futures +import select import socket import threading import time @@ -33,10 +34,12 @@ def connection_lost(self, exc): self.state = 'CLOSED' -class EventLoopTests(unittest.TestCase): +class EventLoopTestsMixin: def setUp(self): - events.init_event_loop() + pollster = self.POLLSTER_CLASS() + event_loop = unix_events.UnixEventLoop(pollster) + events.set_event_loop(event_loop) def testRun(self): el = events.get_event_loop() @@ -198,6 +201,26 @@ def testCreateSslTransport(self): self.assertTrue(pr.nbytes > 0) +if hasattr(select, 'kqueue'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + POLLSTER_CLASS = unix_events.KqueuePollster + + +if hasattr(select, 'epoll'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + POLLSTER_CLASS = unix_events.EPollPollster + + +if hasattr(select, 'poll'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + POLLSTER_CLASS = unix_events.PollPollster + + +# Should always exist. +class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + POLLSTER_CLASS = unix_events.SelectPollster + + class HandlerTests(unittest.TestCase): def testHandler(self): From cc7b71ef5c57a45b85acd31f5a33dd428c28328c Mon Sep 17 00:00:00 2001 From: Geert Jansen Date: Fri, 21 Dec 2012 23:01:05 +0100 Subject: [PATCH 0163/1502] Fix typo --- tulip/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/events.py b/tulip/events.py index d843c946..0a642f19 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -113,7 +113,7 @@ def call_later(self, delay, callback, *args): raise NotImplementedError def call_repeatedly(self, interval, callback, *args): # NEW! - raise NotImplementdError + raise NotImplementedError def call_soon(self, callback, *args): return self.call_later(0, callback, *args) From 8b89d80d629c25c6a32c021193465e07c4009666 Mon Sep 17 00:00:00 2001 From: Geert Jansen Date: Fri, 21 Dec 2012 23:03:00 +0100 Subject: [PATCH 0164/1502] Remove out-of-date comments. --- tulip/unix_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 486c2521..fe9288b2 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -439,8 +439,8 @@ def __init__(self, pollster=None): logging.info('Using pollster: %s', best_pollster.__name__) pollster = best_pollster() self._pollster = pollster - self._ready = collections.deque() # [(callback, args), ...] - self._scheduled = [] # [(when, callback, args), ...] + self._ready = collections.deque() + self._scheduled = [] self._default_executor = None self._make_self_pipe() From 540ec85393bdc663e090f15fb0328ef768a7c87f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 21 Dec 2012 17:09:14 -0800 Subject: [PATCH 0165/1502] Add start_serving(). Only half tested. --- tulip/events_test.py | 7 +++++++ tulip/unix_events.py | 47 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 9a43acf9..63ad8f67 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -200,6 +200,13 @@ def testCreateSslTransport(self): el.run() self.assertTrue(pr.nbytes > 0) + def testStartServing(self): + el = events.get_event_loop() + f = el.start_serving(MyProto, '0.0.0.0', 0) + sock = el.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + if hasattr(select, 'kqueue'): class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fe9288b2..b2cade6e 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -630,6 +630,8 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, else: break else: + # TODO: What to do if there are multiple exceptions? We + # can't raise them all. Arbitrarily pick the first one. raise exceptions[0] protocol = protocol_factory() if ssl: @@ -640,9 +642,52 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, return transport, protocol # TODO: Or create_server()? + @tasks.task def start_serving(self, protocol_factory, host, port, *, - family=0, type=0, proto=0, flags=0): + family=0, type=socket.SOCK_STREAM, proto=0, flags=0, + backlog=100): """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, type=type, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + sock.listen(backlog) + sock.setblocking(False) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + return sock + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + protocol = protocol_factory() + transport = _UnixSocketTransport(self, conn, protocol) + # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" From dade206d548bb80a63150109affb9d25adff7d30 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 21 Dec 2012 17:14:35 -0800 Subject: [PATCH 0166/1502] Exercise _accept_connection() a bit. --- tulip/events_test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 63ad8f67..94dd2e97 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -206,6 +206,12 @@ def testStartServing(self): sock = el.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect((host, port)) + client.send(b'xxx') + client.close() + el.run_once() + el.run_once() if hasattr(select, 'kqueue'): From 82d83675b49d7cbd440c4894bcb343ae5a0b02af Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 21 Dec 2012 17:32:08 -0800 Subject: [PATCH 0167/1502] Clean more thoroughly. --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 9e749ab4..274df966 100644 --- a/Makefile +++ b/Makefile @@ -18,10 +18,12 @@ check: $(PYTHON) check.py clean: + rm -rf __pycache__ */__pycache__ rm -f *.py[co] */*.py[co] rm -f *~ */*~ rm -f .*~ */.*~ rm -f @* */@* rm -f '#'*'#' */'#'*'#' + rm -f *.orig */*.orig rm -f .coverage rm -rf htmlcov From 35de35f483722d25e8c4037f64cc19ca522c580c Mon Sep 17 00:00:00 2001 From: Geert Jansen Date: Fri, 21 Dec 2012 23:04:32 +0100 Subject: [PATCH 0168/1502] Run file descriptors callbacks in the same iteration the fd got ready. --- tulip/unix_events.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b2cade6e..de08bd51 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -848,26 +848,9 @@ def _run_once(self, timeout=None): """ # TODO: Break each of these into smaller pieces. # TODO: Refactor to separate the callbacks from the readers/writers. - # TODO: As step 4, run everything scheduled by steps 1-3. # TODO: An alternative API would be to do the *minimal* amount # of work, e.g. one callback or one I/O poll. - # This is the only place where callbacks are actually *called*. - # All other places just add them to ready. - # TODO: Ensure this loop always finishes, even if some - # callbacks keeps registering more callbacks. - while self._ready: - handler = self._ready.popleft() - if not handler.cancelled: - try: - if handler.kwds: - handler.callback(*handler.args, **handler.kwds) - else: - handler.callback(*handler.args) - except Exception: - logging.exception('Exception in callback %s %r', - handler.callback, handler.args) - # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) @@ -876,7 +859,9 @@ def _run_once(self, timeout=None): # file descriptor, it's the self-pipe, and if there's nothing # scheduled, we should ignore it. if self._pollster.pollable() > 1 or self._scheduled: - if self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0].when deadline = max(0, when - time.monotonic()) @@ -906,6 +891,22 @@ def _run_once(self, timeout=None): handler = heapq.heappop(self._scheduled) self.call_soon(handler.callback, *handler.args) + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self._ready: + handler = self._ready.popleft() + if not handler.cancelled: + try: + if handler.kwds: + handler.callback(*handler.args, **handler.kwds) + else: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + class _UnixSocketTransport(transports.Transport): From db3ebf888b74e2a562d4145504e6f1276753f42c Mon Sep 17 00:00:00 2001 From: Geert Jansen Date: Fri, 21 Dec 2012 23:06:22 +0100 Subject: [PATCH 0169/1502] Add call_every_iteration(). --- tulip/events_test.py | 15 +++++++++++++++ tulip/unix_events.py | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 94dd2e97..b69479ff 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -93,6 +93,21 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) + def testCallEveryIteration(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handle = el.call_every_iteration(callback, 'ho') + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handle.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + def testWrapFuture(self): el = events.get_event_loop() def run(arg): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index de08bd51..04c31255 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -441,6 +441,7 @@ def __init__(self, pollster=None): self._pollster = pollster self._ready = collections.deque() self._scheduled = [] + self._everytime = [] self._default_executor = None self._make_self_pipe() @@ -578,6 +579,15 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handler + def call_every_iteration(self, callback, *args): + """Call a callback just before the loop blocks. + + The callback is called for every iteration of the loop. + """ + handler = events.Handler(None, callback, args) + self._everytime.append(handler) + return handler + def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): @@ -851,6 +861,17 @@ def _run_once(self, timeout=None): # TODO: An alternative API would be to do the *minimal* amount # of work, e.g. one callback or one I/O poll. + # Add everytime handlers. + any_cancelled = False + for handler in self._everytime: + self._add_callback(handler) + any_cancelled = any_cancelled or handler.cancelled + + # Remove cancelled handlers if there are any. + if any_cancelled: + self._everytime = [handler for handler in self._everytime + if not handler.cancelled] + # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) From 42480b0d3e18bfd55217a4c5c16249e58c252f9b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 22 Dec 2012 10:03:04 -0800 Subject: [PATCH 0170/1502] In Future.__iter__(), don't yield self if already done. --- tulip/futures.py | 3 ++- tulip/futures_test.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/tulip/futures.py b/tulip/futures.py index 616a6ac3..b079c5fc 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -180,7 +180,8 @@ def _copy_state(self, other): self.set_result(result) def __iter__(self): - yield self # This tells Task to wait for completion. + if not self.done(): + yield self # This tells Task to wait for completion. return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py index f7d93ba2..4704849a 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -41,6 +41,22 @@ def testException(self): self.assertRaises(RuntimeError, f.result) self.assertEqual(f.exception(), exc) + def testYieldFromTwice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + if __name__ == '__main__': unittest.main() From fea13726c2ff6c153b77b61ee469e75d63471a26 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 10:42:32 -0800 Subject: [PATCH 0171/1502] Make test_writer_callback() more robust. --- tulip/events_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index b69479ff..eac4c414 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -149,11 +149,12 @@ def reader(): def test_writer_callback(self): el = events.get_event_loop() r, w = unix_events.socketpair() - el.add_writer(w.fileno(), w.send, b'x'*100) + w.setblocking(False) + el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) el.call_later(0.1, el.remove_writer, w.fileno()) el.run() w.close() - data = r.recv(32*1024) + data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) From 773ab655ce897bf3dac48daa9226ec8b9c0a7f37 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 14:26:56 -0800 Subject: [PATCH 0172/1502] Use Selector classes by Charles-Francois Natali instead of pollsters. See http://bugs.python.org/issue16853. The selectors.py file is a slightly modified clone of the select.py file from that issue. --- tulip/events_test.py | 19 +- tulip/selectors.py | 363 +++++++++++++++++++++++++++++++++ tulip/unix_events.py | 472 ++++++++----------------------------------- 3 files changed, 455 insertions(+), 399 deletions(-) create mode 100644 tulip/selectors.py diff --git a/tulip/events_test.py b/tulip/events_test.py index eac4c414..1a6652fc 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -10,6 +10,7 @@ from . import events from . import transports from . import protocols +from . import selectors from . import unix_events @@ -37,8 +38,8 @@ def connection_lost(self, exc): class EventLoopTestsMixin: def setUp(self): - pollster = self.POLLSTER_CLASS() - event_loop = unix_events.UnixEventLoop(pollster) + selector = self.SELECTOR_CLASS() + event_loop = unix_events.UnixEventLoop(selector) events.set_event_loop(event_loop) def testRun(self): @@ -230,24 +231,24 @@ def testStartServing(self): el.run_once() -if hasattr(select, 'kqueue'): +if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - POLLSTER_CLASS = unix_events.KqueuePollster + SELECTOR_CLASS = selectors.KqueueSelector -if hasattr(select, 'epoll'): +if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - POLLSTER_CLASS = unix_events.EPollPollster + SELECTOR_CLASS = selectors.EpollSelector -if hasattr(select, 'poll'): +if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - POLLSTER_CLASS = unix_events.PollPollster + SELECTOR_CLASS = selectors.PollSelector # Should always exist. class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - POLLSTER_CLASS = unix_events.SelectPollster + SELECTOR_CLASS = selectors.SelectSelector class HandlerTests(unittest.TestCase): diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..53e2210c --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,363 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +SELECT_IN = (1 << 0) +# write event +SELECT_OUT = (1 << 1) +# connect event +SELECT_CONNECT = SELECT_OUT + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class _Key: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + data -- attached data + """ + if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = _Key(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + data -- attached data + """ + self.unregister(fileobj) + self.register(fileobj, events, data) + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of SELECT_IN|SELECT_OUT + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + raise RuntimeError("No key found for fd {}".format(fd)) + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & SELECT_IN: + self._readers.add(key.fd) + if events & SELECT_OUT: + self._writers.add(key.fd) + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + + def select(self, timeout=None): + r, w, _ = select(self._readers, self._writers, [], timeout) + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= SELECT_IN + if fd in w: + events |= SELECT_OUT + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + +if 'poll' in globals(): + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & SELECT_IN: + poll_events |= POLLIN + if events & SELECT_OUT: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + for fd, event in self._poll.poll(timeout): + events = 0 + if event & ~POLLIN: + events |= SELECT_OUT + if event & ~POLLOUT: + events |= SELECT_IN + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & SELECT_IN: + epoll_events |= EPOLLIN + if events & SELECT_OUT: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + for fd, event in self._epoll.poll(timeout, max_ev): + events = 0 + if event & ~EPOLLIN: + events |= SELECT_OUT + if event & ~EPOLLOUT: + events |= SELECT_IN + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & SELECT_IN: + mask |= KQ_FILTER_READ + if key.events & SELECT_OUT: + mask |= KQ_FILTER_WRITE + kev = kevent(key.fd, mask, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & SELECT_IN: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & SELECT_OUT: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= SELECT_IN + if flag == KQ_FILTER_WRITE: + events |= SELECT_OUT + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 04c31255..575bde10 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,10 +1,8 @@ """UNIX event loop and related classes. -NOTE: The Pollster classes are not part of the published API. - -The event loop can be broken up into a pollster (the part responsible +The event loop can be broken up into a selector (the part responsible for telling us when file descriptors are ready) and the event loop -proper, which wraps a pollster with functionality for scheduling +proper, which wraps a selector with functionality for scheduling callbacks, immediately or at a given time in the future. Whenever a public API takes a callback, subsequent positional @@ -13,22 +11,6 @@ Keyword arguments for the callback are not supported; this is a conscious design decision, leaving the door open for keyword arguments to modify the meaning of the API call itself. - -There are several implementations of the pollster part, several using -esoteric system calls that exist only on some platforms. These are: - -- kqueue (most BSD systems) -- epoll (newer Linux systems) -- poll (most UNIX systems) -- select (all UNIX systems, and Windows) - -NOTE: We don't use select on systems where any of the others is -available, because select performs poorly as the number of file -descriptors goes up. The ranking is roughly: - - 1. kqueue, epoll, IOCP (best for each platform) - 2. poll (linear in number of file descriptors polled) - 3. select (linear in max number of file descriptors supported) """ import collections @@ -46,6 +28,7 @@ from . import events from . import futures from . import protocols +from . import selectors from . import tasks from . import transports @@ -73,352 +56,6 @@ _MAX_WORKERS = 5 -class PollsterBase: - """Base class for all polling implementations. - - This defines an interface to register and unregister readers and - writers for specific file descriptors, and an interface to get a - list of events. There's also an interface to check whether any - readers or writers are currently registered. - """ - - def __init__(self): - super().__init__() - self.readers = {} # {fd: handler, ...}. - self.writers = {} # {fd: handler, ...}. - - def pollable(self): - """Return the number readers and writers currently registered.""" - # The event loop needs the number since it must subtract one for - # the self-pipe. - return len(self.readers) + len(self.writers) - - # Subclasses are expected to extend the add/remove methods. - - def register_reader(self, fd, handler): - """Add or update a reader for a file descriptor.""" - self.readers[fd] = handler - - def register_writer(self, fd, handler): - """Add or update a writer for a file descriptor.""" - self.writers[fd] = handler - - def unregister_reader(self, fd): - """Remove the reader for a file descriptor.""" - del self.readers[fd] - - def unregister_writer(self, fd): - """Remove the writer for a file descriptor.""" - del self.writers[fd] - - def register_connector(self, fd, handler): - """Add or update a connector for a file descriptor.""" - # On Unix a connector is the same as a writer. - self.register_writer(fd, handler) - - def unregister_connector(self, fd): - """Remove the connector for a file descriptor.""" - # On Unix a connector is the same as a writer. - self.unregister_writer(fd) - - def poll(self, timeout=None): - """Poll for I/O events. A subclass must implement this. - - If timeout is omitted or None, this blocks until at least one - event is ready. Otherwise, timeout gives a maximum time to - wait (an int of float in seconds) -- the method returns as - soon as at least one event is ready or when the timeout is - expired. For a non-blocking poll, pass 0. - - The return value is a list of events; it is empty when the - timeout expired before any events were ready. Each event - is a handler previously passed to register_reader/writer(). - """ - raise NotImplementedError - - -if sys.platform != 'win32': - - class SelectPollster(PollsterBase): - """Pollster implementation using select.""" - - def poll(self, timeout=None): - readable, writable, _ = select.select(self.readers, self.writers, - [], timeout) - events = [] - events += (self.readers[fd] for fd in readable) - events += (self.writers[fd] for fd in writable) - return events - -else: - - class SelectPollster(PollsterBase): - """Pollster implementation using select.""" - - def __init__(self): - super().__init__() - self.exceptionals = {} - - def poll(self, timeout=None): - # Failed connections are reported as exceptional but not writable. - readable, writable, exceptional = select.select( - self.readers, self.writers, self.exceptionals, timeout) - writable = set(writable).union(exceptional) - events = [] - events += (self.readers[fd] for fd in readable) - events += (self.writers[fd] for fd in writable) - return events - - def register_connector(self, fd, token): - self.register_writer(fd, token) - self.exceptionals[fd] = token - - def unregister_connector(self, fd): - self.unregister_writer(fd) - try: - del self.exceptionals[fd] - except KeyError: - # remove_connector() does not check fd in self.exceptionals. - pass - - -class PollPollster(PollsterBase): - """Pollster implementation using poll.""" - - def __init__(self): - super().__init__() - self._poll = select.poll() - - def _update(self, fd): - assert isinstance(fd, int), fd - flags = 0 - if fd in self.readers: - flags |= select.POLLIN - if fd in self.writers: - flags |= select.POLLOUT - if flags: - self._poll.register(fd, flags) - else: - self._poll.unregister(fd) - - def register_reader(self, fd, handler): - super().register_reader(fd, handler) - self._update(fd) - - def register_writer(self, fd, handler): - super().register_writer(fd, handler) - self._update(fd) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - self._update(fd) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - self._update(fd) - - def poll(self, timeout=None): - # Timeout is in seconds, but poll() takes milliseconds. - msecs = None if timeout is None else int(round(1000 * timeout)) - events = [] - for fd, flags in self._poll.poll(msecs): - if flags & ~select.POLLOUT: - if fd in self.readers: - events.append(self.readers[fd]) - if flags & ~select.POLLIN: - if fd in self.writers: - events.append(self.writers[fd]) - return events - - -if sys.platform == 'win32': - - class WindowsPollPollster(PollPollster): - """Pollster implementation using WSAPoll. - - WSAPoll is only available on Windows Vista and later. Python - does not currently support WSAPoll, but there is a patch - available at http://bugs.python.org/issue16507. - """ - - # REAP_PERIOD is the maximum wait before checking for failed - # connections. This is necessary because WSAPoll() does notify us - # of failed connections. See - # daniel.haxx.se/blog/2012/10/10/wsapoll-is-broken/ - REAP_PERIOD = 5.0 - - # FD_SETSIZE is maximum number of sockets in an fd_set - FD_SETSIZE = 512 - - def __init__(self): - super().__init__() - self.exceptionals = {} - - def register_connector(self, fd, token): - self.register_writer(fd, token) - self.exceptionals[fd] = token - - def unregister_connector(self, fd): - self.unregister_writer(fd) - try: - del self.exceptionals[fd] - except KeyError: - # remove_connector() does not check fd in self.exceptionals. - pass - - def _get_failed_connector_events(self): - fds = [] - remaining = list(self.exceptionals) - while remaining: - fds += select.select([], [], remaining[:self.FD_SETSIZE], 0)[2] - del remaining[:self.FD_SETSIZE] - return [(fd, select.POLLOUT) for fd in fds] - - def poll(self, timeout=None): - if not self.exceptionals: - msecs = None if timeout is None else int(round(1000 * timeout)) - polled = self._poll.poll(msecs) - - elif timeout is None: - polled = None - while not polled: - polled = (self._get_failed_connector_events() or - self._poll.poll(self.REAP_PERIOD)) - - elif timeout == 0: - polled = (self._get_failed_connector_events() or - self._poll.poll(0)) - - else: - start = time.monotonic() - deadline = start + timeout - polled = None - while timeout >= 0: - msecs = int(round(1000 * min(self.REAP_PERIOD, timeout))) - polled = (self._get_failed_connector_events() or - self._poll.poll(self.REAP_PERIOD)) - if polled: - break - timemout = deadline - time.monotonic() - - events = [] - for fd, flags in polled: - if flags & ~select.POLLOUT: - if fd in self.readers: - events.append(self.readers[fd]) - if flags & ~select.POLLIN: - if fd in self.writers: - events.append(self.writers[fd]) - return events - - PollPollster = WindowsPollPollster - - -class EPollPollster(PollsterBase): - """Pollster implementation using epoll.""" - - def __init__(self): - super().__init__() - self._epoll = select.epoll() - - def _update(self, fd): - assert isinstance(fd, int), fd - eventmask = 0 - if fd in self.readers: - eventmask |= select.EPOLLIN - if fd in self.writers: - eventmask |= select.EPOLLOUT - if eventmask: - try: - self._epoll.register(fd, eventmask) - except IOError: - self._epoll.modify(fd, eventmask) - else: - self._epoll.unregister(fd) - - def register_reader(self, fd, handler): - super().register_reader(fd, handler) - self._update(fd) - - def register_writer(self, fd, handler): - super().register_writer(fd, handler) - self._update(fd) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - self._update(fd) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - self._update(fd) - - def poll(self, timeout=None): - if timeout is None: - timeout = -1 # epoll.poll() uses -1 to mean "wait forever". - events = [] - for fd, eventmask in self._epoll.poll(timeout): - if eventmask & ~select.EPOLLOUT: - if fd in self.readers: - events.append(self.readers[fd]) - if eventmask & ~select.EPOLLIN: - if fd in self.writers: - events.append(self.writers[fd]) - return events - - -class KqueuePollster(PollsterBase): - """Pollster implementation using kqueue.""" - - def __init__(self): - super().__init__() - self._kqueue = select.kqueue() - - def register_reader(self, fd, handler): - if fd not in self.readers: - kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return super().register_reader(fd, handler) - - def register_writer(self, fd, handler): - if fd not in self.writers: - kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return super().register_writer(fd, handler) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - - def poll(self, timeout=None): - events = [] - max_ev = len(self.readers) + len(self.writers) - for kev in self._kqueue.control(None, max_ev, timeout): - fd = kev.ident - flag = kev.filter - if flag == select.KQ_FILTER_READ and fd in self.readers: - events.append(self.readers[fd]) - elif flag == select.KQ_FILTER_WRITE and fd in self.writers: - events.append(self.writers[fd]) - return events - - -# Pick the best pollster class for the platform. -if hasattr(select, 'kqueue'): - best_pollster = KqueuePollster -elif hasattr(select, 'epoll'): - best_pollster = EPollPollster -elif hasattr(select, 'poll'): - best_pollster = PollPollster -else: - best_pollster = SelectPollster - - class _StopError(BaseException): """Raised to stop the event loop.""" @@ -433,12 +70,13 @@ class UnixEventLoop(events.EventLoop): See events.EventLoop for API specification. """ - def __init__(self, pollster=None): + def __init__(self, selector=None): super().__init__() - if pollster is None: - logging.info('Using pollster: %s', best_pollster.__name__) - pollster = best_pollster() - self._pollster = pollster + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.info('Using selector: %s', selector.__name__) + self._selector = selector self._ready = collections.deque() self._scheduled = [] self._everytime = [] @@ -465,7 +103,9 @@ def run(self): TODO: Give this a timeout too? """ - while self._ready or self._scheduled or self._pollster.pollable() > 1: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): try: self._run_once() except _StopError: @@ -702,36 +342,83 @@ def _accept_connection(self, protocol_factory, sock): def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" handler = events.Handler(None, callback, args) - self._pollster.register_reader(fd, handler) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_IN, + (handler, None, None)) + else: + self._selector.modify(fd, mask | selectors.SELECT_IN, + (handler, writer, connector)) + return handler def remove_reader(self, fd): """Remove a reader callback.""" - if fd in self._pollster.readers: - self._pollster.unregister_reader(fd) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_IN + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer, connector)) def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" handler = events.Handler(None, callback, args) - self._pollster.register_writer(fd, handler) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_OUT, + (None, handler, None)) + else: + self._selector.modify(fd, mask | selectors.SELECT_OUT, + (reader, handler, connector)) return handler def remove_writer(self, fd): """Remove a writer callback.""" - if fd in self._pollster.writers: - self._pollster.unregister_writer(fd) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_OUT + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, connector)) def add_connector(self, fd, callback, *args): """Add a connector callback. Return a Handler instance.""" - dcall = events.Handler(None, callback, args) - self._pollster.register_connector(fd, dcall) - return dcall + # XXX As long as SELECT_CONNECT == SELECT_OUT, set the handler + # as both writer and connector. + handler = events.Handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_CONNECT, + (None, handler, handler)) + else: + self._selector.modify(fd, mask | selectors.SELECT_CONNECT, + (reader, handler, handler)) + return handler def remove_connector(self, fd): """Remove a connector callback.""" - # Every connector fd is in self._pollsters.writers. - if fd in self._pollster.writers: - self._pollster.unregister_connector(fd) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_CONNECT + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) def sock_recv(self, sock, n): """XXX""" @@ -743,7 +430,7 @@ def _sock_recv(self, fut, registered, sock, n): fd = sock.fileno() if registered: # Remove the callback early. It should be rare that the - # pollster says the fd is ready but the call still returns + # selector says the fd is ready but the call still returns # EAGAIN, and I am willing to take a hit in that case in # order to simplify the common case. self.remove_reader(fd) @@ -876,10 +563,10 @@ def _run_once(self, timeout=None): while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) - # Inspect the poll queue. If there's exactly one pollable + # Inspect the poll queue. If there's exactly one selectable # file descriptor, it's the self-pipe, and if there's nothing # scheduled, we should ignore it. - if self._pollster.pollable() > 1 or self._scheduled: + if self._selector.registered_count() > 1 or self._scheduled: if self._ready: timeout = 0 elif self._scheduled: @@ -892,7 +579,7 @@ def _run_once(self, timeout=None): timeout = min(timeout, deadline) t0 = time.monotonic() - events = self._pollster.poll(timeout) + event_list = self._selector.select(timeout) t1 = time.monotonic() argstr = '' if timeout is None else ' %.3f' % timeout if t1-t0 >= 1: @@ -900,8 +587,13 @@ def _run_once(self, timeout=None): else: level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - for handler in events: - self._add_callback(handler) + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.SELECT_IN and reader is not None: + self._add_callback(reader) + if mask & selectors.SELECT_OUT and writer is not None: + self._add_callback(writer) + elif mask & selectors.SELECT_CONNECT and connector is not None: + self._add_callback(connector) # Handle 'later' callbacks that are ready. now = time.monotonic() From fbfd4e3d07458b46e82e70bc88e782b0d329268f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 14:48:26 -0800 Subject: [PATCH 0173/1502] Make register()/unregister() in subclasses return a _Key. --- tulip/selectors.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tulip/selectors.py b/tulip/selectors.py index 53e2210c..f1ee8c3a 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -80,6 +80,9 @@ def register(self, fileobj, events, data=None): fileobj -- file object events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) data -- attached data + + Returns: + _Key instance """ if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): raise ValueError("Invalid events: {}".format(events)) @@ -97,6 +100,9 @@ def unregister(self, fileobj): Parameters: fileobj -- file object + + Returns: + _Key instance """ try: key = self._fileobj_to_key[fileobj] @@ -200,11 +206,13 @@ def register(self, fileobj, events, data=None): self._readers.add(key.fd) if events & SELECT_OUT: self._writers.add(key.fd) + return key def unregister(self, fileobj): key = super().unregister(fileobj) self._readers.discard(key.fd) self._writers.discard(key.fd) + return key def select(self, timeout=None): r, w, _ = select(self._readers, self._writers, [], timeout) @@ -240,10 +248,12 @@ def register(self, fileobj, events, data=None): if events & SELECT_OUT: poll_events |= POLLOUT self._poll.register(key.fd, poll_events) + return key def unregister(self, fileobj): key = super().unregister(fileobj) self._poll.unregister(key.fd) + return key def select(self, timeout=None): timeout = None if timeout is None else int(1000 * timeout) @@ -277,10 +287,12 @@ def register(self, fileobj, events, data=None): if events & SELECT_OUT: epoll_events |= EPOLLOUT self._epoll.register(key.fd, epoll_events) + return key def unregister(self, fileobj): key = super().unregister(fileobj) self._epoll.unregister(key.fd) + return key def select(self, timeout=None): timeout = -1 if timeout is None else timeout @@ -320,6 +332,7 @@ def unregister(self, fileobj): mask |= KQ_FILTER_WRITE kev = kevent(key.fd, mask, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) + return key def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) @@ -329,6 +342,7 @@ def register(self, fileobj, events, data=None): if events & SELECT_OUT: kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) + return key def select(self, timeout=None): max_ev = self.registered_count() From 5f7436abd9672ac6963d851b1a663578277495e1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 15:45:02 -0800 Subject: [PATCH 0174/1502] Instead of init_event_loop(), define new_event_loop(). --- tulip/events.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 0a642f19..5049c20c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -2,13 +2,12 @@ Beyond the PEP: - Only the main thread has a default event loop. -- init_event_loop() (re-)initializes the event loop. """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', 'EventLoop', 'Handler', 'get_event_loop_policy', 'set_event_loop_policy', - 'get_event_loop', 'set_event_loop', 'init_event_loop', + 'get_event_loop', 'set_event_loop', 'new_event_loop', ] import threading @@ -194,7 +193,7 @@ def set_event_loop(self, event_loop): """XXX""" raise NotImplementedError - def init_event_loop(self): + def new_event_loop(self): """XXX""" raise NotImplementedError @@ -221,7 +220,7 @@ def get_event_loop(self): """ if (self._event_loop is None and threading.current_thread().name == 'MainThread'): - self.init_event_loop() + self._event_loop = self.new_event_loop() return self._event_loop def set_event_loop(self, event_loop): @@ -229,15 +228,15 @@ def set_event_loop(self, event_loop): assert event_loop is None or isinstance(event_loop, EventLoop) self._event_loop = event_loop - def init_event_loop(self): - """(Re-)initialize the event loop. + def new_event_loop(self): + """Create a new event loop. - This is calls set_event_loop() with a freshly created event - loop suitable for the platform. + You must call set_event_loop() to make this the current event + loop. """ # TODO: Do something else for Windows. from . import unix_events - self.set_event_loop(unix_events.UnixEventLoop()) + return unix_events.UnixEventLoop() # Event loop policy. The policy itself is always global, even if the @@ -272,6 +271,6 @@ def set_event_loop(event_loop): get_event_loop_policy().set_event_loop(event_loop) -def init_event_loop(): +def new_event_loop(): """XXX""" - get_event_loop_policy().init_event_loop() + return get_event_loop_policy().new_event_loop() From c210c76a1477e40cda255f657a0d01fa9861d245 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 17:15:40 -0800 Subject: [PATCH 0175/1502] Add signal handling APIs. --- tulip/events.py | 14 +++++++++ tulip/events_test.py | 37 ++++++++++++++++++++++ tulip/unix_events.py | 75 ++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 124 insertions(+), 2 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 5049c20c..ac3473d2 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -79,10 +79,16 @@ def __eq__(self, other): class EventLoop: """Abstract event loop.""" + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + def run(self): """Run the event loop. Block until there is nothing left to do.""" raise NotImplementedError + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + def run_once(self, timeout=None): # NEW! """Run one complete cycle of the event loop.""" raise NotImplementedError @@ -181,6 +187,14 @@ def sock_connect(self, sock, address): def sock_accept(self, sock): raise NotImplementedError + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + class EventLoopPolicy: """Abstract policy for accessing the event loop.""" diff --git a/tulip/events_test.py b/tulip/events_test.py index 1a6652fc..e5cdab5a 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,7 +1,9 @@ """Tests for events.py.""" import concurrent.futures +import os import select +import signal import socket import threading import time @@ -42,6 +44,8 @@ def setUp(self): event_loop = unix_events.UnixEventLoop(selector) events.set_event_loop(event_loop) + # TODO: Add tearDown() which closes the selector and event loop. + def testRun(self): el = events.get_event_loop() el.run() # Returns immediately. @@ -196,6 +200,39 @@ def test_sock_accept(self): conn.close() listener.close() + def testAddSignalHandler(self): + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + # Check error behavior first. + self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) + self.assertRaises(TypeError, el.remove_signal_handler, 'boom') + self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) + self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, 0) + self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, -1) + self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + el.add_signal_handler(signal.SIGINT, my_handler) + el.run_once() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + def testCreateTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 575bde10..b9002710 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -25,6 +25,11 @@ import threading import time +try: + import signal +except ImportError: + signal = None + from . import events from . import futures from . import protocols @@ -81,6 +86,7 @@ def __init__(self, selector=None): self._scheduled = [] self._everytime = [] self._default_executor = None + self._signal_handlers = {} self._make_self_pipe() def _make_self_pipe(self): @@ -350,7 +356,7 @@ def add_reader(self, fd, callback, *args): else: self._selector.modify(fd, mask | selectors.SELECT_IN, (handler, writer, connector)) - + return handler def remove_reader(self, fd): @@ -527,6 +533,71 @@ def _sock_accept(self, fut, registered, sock): else: self.add_reader(fd, self._sock_accept, fut, True, sock) + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + handler = events.Handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" if handler.cancelled: @@ -644,7 +715,7 @@ def _read_ready(self): else: self._event_loop.remove_reader(self._sock.fileno()) self._event_loop.call_soon(self._protocol.eof_received) - + def write(self, data): assert isinstance(data, bytes) From da335f4bcbe98bfbe214d24831a00e14840103ad Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 17:48:04 -0800 Subject: [PATCH 0176/1502] Add test for cancelling a signal handler. --- tulip/events_test.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index e5cdab5a..5133e960 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -233,6 +233,19 @@ def my_handler(): # Removing again returns False. self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + def testCancelSignalHandler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 0) + def testCreateTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! From a8786eb7e8eb63c9fd7f21db4b3aa5004f004d8d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 19:22:37 -0800 Subject: [PATCH 0177/1502] Allow Handler instances instead of callbacks everywhere. --- tulip/events.py | 31 ++++++------- tulip/events_test.py | 107 ++++++++++++++++++++++++++++++++++++++++--- tulip/unix_events.py | 33 +++++++------ 3 files changed, 133 insertions(+), 38 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index ac3473d2..a5c3e383 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,7 @@ """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', - 'EventLoop', 'Handler', + 'EventLoop', 'Handler', 'make_handler', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -16,23 +16,16 @@ class Handler: """Object returned by callback registration methods.""" - def __init__(self, when, callback, args, kwds=None): + def __init__(self, when, callback, args): self._when = when self._callback = callback self._args = args - self._kwds = kwds self._cancelled = False def __repr__(self): - if self.kwds: - res = 'Handler({}, {}, {}, kwds={})'.format(self._when, - self._callback, - self._args, - self._kwds) - else: - res = 'Handler({}, {}, {})'.format(self._when, - self._callback, - self._args) + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) if self._cancelled: res += '' return res @@ -49,10 +42,6 @@ def callback(self): def args(self): return self._args - @property - def kwds(self): - return self._kwds - @property def cancelled(self): return self._cancelled @@ -76,6 +65,14 @@ def __eq__(self, other): return self._when == other._when +def make_handler(when, callback, args): + if isinstance(callback, Handler): + assert not args + assert when is None + return callback + return Handler(when, callback, args) + + class EventLoop: """Abstract event loop.""" @@ -131,7 +128,7 @@ def call_soon_threadsafe(self, callback, *args): def wrap_future(self, future): raise NotImplementedError - def run_in_executor(self, executor, function, *args): + def run_in_executor(self, executor, callback, *args): raise NotImplementedError # Network I/O methods returning Futures. diff --git a/tulip/events_test.py b/tulip/events_test.py index 5133e960..2cc4dc25 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -81,6 +81,16 @@ def callback(arg1, arg2): el.run() self.assertEqual(results, [('hello', 'world')]) + def testCallSoonWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(): + results.append('yeah') + handler = events.Handler(None, callback, ()) + self.assertEqual(el.call_soon(handler), handler) + el.run() + self.assertEqual(results, ['yeah']) + def testCallSoonThreadsafe(self): el = events.get_event_loop() results = [] @@ -98,18 +108,52 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) + def testCallSoonThreadsafeWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('hello',)) + def run(): + self.assertEqual(el.call_soon_threadsafe(handler), handler) + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + def testCallEveryIteration(self): el = events.get_event_loop() results = [] def callback(arg): results.append(arg) - handle = el.call_every_iteration(callback, 'ho') + handler = el.call_every_iteration(callback, 'ho') el.run_once() self.assertEqual(results, ['ho']) el.run_once() el.run_once() self.assertEqual(results, ['ho', 'ho', 'ho']) - handle.cancel() + handler.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallEveryIterationWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('ho',)) + self.assertEqual(el.call_every_iteration(handler), handler) + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handler.cancel() el.run_once() self.assertEqual(results, ['ho', 'ho', 'ho']) @@ -133,7 +177,17 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'yo') - def test_reader_callback(self): + def testRunInExecutorWithHandler(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(None, run, ('yo',)) + f2 = el.run_in_executor(None, handler) + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testReaderCallback(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -151,7 +205,26 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def test_writer_callback(self): + def testReaderCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + else: + el.remove_reader(r.fileno()) + r.close() + handler = events.Handler(None, reader, ()) + self.assertEqual(el.add_reader(r.fileno(), handler), handler) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testWriterCallback(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -163,7 +236,20 @@ def test_writer_callback(self): r.close() self.assertTrue(len(data) >= 200) - def test_sock_client_ops(self): + def testWriterCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + self.assertEqual(el.add_writer(w.fileno(), handler), handler) + el.call_later(0.1, el.remove_writer, w.fileno()) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testSockClientOps(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -174,7 +260,7 @@ def test_sock_client_ops(self): sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) - def test_sock_client_fail(self): + def testSockClientFail(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -183,7 +269,7 @@ def test_sock_client_fail(self): el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) sock.close() - def test_sock_accept(self): + def testSockAccept(self): listener = socket.socket() listener.setblocking(False) listener.bind(('127.0.0.1', 0)) @@ -306,6 +392,13 @@ class HandlerTests(unittest.TestCase): def testHandler(self): pass + def testMakeHandler(self): + def callback(*args): + return args + h1 = events.Handler(None, callback, ()) + h2 = events.make_handler(None, h1, ()) + self.assertEqual(h1, h2) + class PolicyTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b9002710..184661f3 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -191,7 +191,7 @@ def call_later(self, delay, callback, *args): """ if delay <= 0: return self.call_soon(callback, *args) - handler = events.Handler(time.monotonic() + delay, callback, args) + handler = events.make_handler(time.monotonic() + delay, callback, args) heapq.heappush(self._scheduled, handler) return handler @@ -201,7 +201,7 @@ def wrapper(): callback(*args) # If this fails, the chain is broken. handler._when = time.monotonic() + interval heapq.heappush(self._scheduled, handler) - handler = events.Handler(time.monotonic() + interval, wrapper, ()) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) heapq.heappush(self._scheduled, handler) return handler @@ -215,7 +215,7 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) self._ready.append(handler) return handler @@ -230,7 +230,7 @@ def call_every_iteration(self, callback, *args): The callback is called for every iteration of the loop. """ - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) self._everytime.append(handler) return handler @@ -244,13 +244,21 @@ def wrap_future(self, future): self.call_soon_threadsafe(new_future._copy_state, future)) return new_future - def run_in_executor(self, executor, function, *args): + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args if executor is None: executor = self._default_executor if executor is None: executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) self._default_executor = executor - return self.wrap_future(executor.submit(function, *args)) + return self.wrap_future(executor.submit(callback, *args)) def set_default_executor(self, executor): self._default_executor = executor @@ -347,7 +355,7 @@ def _accept_connection(self, protocol_factory, sock): def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: @@ -374,7 +382,7 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: @@ -402,7 +410,7 @@ def add_connector(self, fd, callback, *args): """Add a connector callback. Return a Handler instance.""" # XXX As long as SELECT_CONNECT == SELECT_OUT, set the handler # as both writer and connector. - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: @@ -540,7 +548,7 @@ def add_signal_handler(self, sig, callback, *args): Raise RuntimeError if there is a problem setting up the handler. """ self._check_signal(sig) - handler = events.Handler(None, callback, args) + handler = events.make_handler(None, callback, args) self._signal_handlers[sig] = handler try: signal.signal(sig, self._handle_signal) @@ -683,10 +691,7 @@ def _run_once(self, timeout=None): handler = self._ready.popleft() if not handler.cancelled: try: - if handler.kwds: - handler.callback(*handler.args, **handler.kwds) - else: - handler.callback(*handler.args) + handler.callback(*handler.args) except Exception: logging.exception('Exception in callback %s %r', handler.callback, handler.args) From f090d1eb3638c9652b67ece5b935f298fd4a496f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 19:37:15 -0800 Subject: [PATCH 0178/1502] Explicitly close the event loop after each test. --- tulip/events_test.py | 9 +++++---- tulip/tasks_test.py | 8 ++++++++ tulip/unix_events.py | 7 ++++++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 2cc4dc25..42a26a93 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -40,11 +40,12 @@ def connection_lost(self, exc): class EventLoopTestsMixin: def setUp(self): - selector = self.SELECTOR_CLASS() - event_loop = unix_events.UnixEventLoop(selector) - events.set_event_loop(event_loop) + self.selector = self.SELECTOR_CLASS() + self.event_loop = unix_events.UnixEventLoop(self.selector) + events.set_event_loop(self.event_loop) - # TODO: Add tearDown() which closes the selector and event loop. + def tearDown(self): + self.event_loop.close() def testRun(self): el = events.get_event_loop() diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index c9c184c1..4182056c 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -3,12 +3,20 @@ import time import unittest +from . import events from . import futures from . import tasks class TaskTests(unittest.TestCase): + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + def testTaskClass(self): @tasks.coroutine def notmuch(): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 184661f3..46f01f49 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -80,7 +80,7 @@ def __init__(self, selector=None): if selector is None: # pick the best selector class for the platform selector = selectors.Selector() - logging.info('Using selector: %s', selector.__name__) + logging.info('Using selector: %s', selector.__class__.__name__) self._selector = selector self._ready = collections.deque() self._scheduled = [] @@ -89,6 +89,11 @@ def __init__(self, selector=None): self._signal_handlers = {} self._make_self_pipe() + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + def _make_self_pipe(self): # A self-socket, really. :-) self._ssock, self._csock = socketpair() From 36cf2e5a0b66a36e190d8a48e0fdb7205c983c11 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Jan 2013 20:23:55 -0800 Subject: [PATCH 0179/1502] Close all sockets properly. --- tulip/events_test.py | 4 ++++ tulip/unix_events.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 42a26a93..b45c286d 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,6 +1,7 @@ """Tests for events.py.""" import concurrent.futures +import gc import os import select import signal @@ -46,6 +47,7 @@ def setUp(self): def tearDown(self): self.event_loop.close() + gc.collect() def testRun(self): el = events.get_event_loop() @@ -363,9 +365,11 @@ def testStartServing(self): client = socket.socket() client.connect((host, port)) client.send(b'xxx') + el.run_once() # This is quite mysterious, but necessary. client.close() el.run_once() el.run_once() + sock.close() if hasattr(selectors, 'KqueueSelector'): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 46f01f49..116cd5a5 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -766,8 +766,7 @@ def _write_ready(self): if n == len(data): self._event_loop.remove_writer(self._sock.fileno()) if self._closing: - self._event_loop.call_soon(self._protocol.connection_lost, - None) + self._event_loop.call_soon(self._call_connection_lost, None) return if n: data = data[n:] @@ -782,14 +781,20 @@ def close(self): self._closing = True self._event_loop.remove_reader(self._sock.fileno()) if not self._buffer: - self._event_loop.call_soon(self._protocol.connection_lost, None) + self._event_loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): logging.exception('Fatal error for %s', self) self._event_loop.remove_writer(self._sock.fileno()) self._event_loop.remove_reader(self._sock.fileno()) self._buffer = [] - self._event_loop.call_soon(self._protocol.connection_lost, exc) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() class _UnixSslTransport(transports.Transport): From 084d0b070e141b3596135ea92c4c773b0dcdc354 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 09:31:23 -0800 Subject: [PATCH 0180/1502] Use separate r/w kevents in KqueueSelector.unregister(). --- tulip/selectors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index f1ee8c3a..ed3e19f8 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -327,11 +327,11 @@ def unregister(self, fileobj): key = super().unregister(fileobj) mask = 0 if key.events & SELECT_IN: - mask |= KQ_FILTER_READ + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) if key.events & SELECT_OUT: - mask |= KQ_FILTER_WRITE - kev = kevent(key.fd, mask, KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) return key def register(self, fileobj, events, data=None): From dead0c0e3cae24bff5c87e1765140f0506cd77fd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 10:05:53 -0800 Subject: [PATCH 0181/1502] Rename _Key to Key. --- tulip/selectors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index ed3e19f8..95c51366 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -36,7 +36,7 @@ def _fileobj_to_fd(fileobj): return fd -class _Key: +class Key: """Object used internally to associate a file object to its backing file descriptor, selected event mask and attached data.""" @@ -82,7 +82,7 @@ def register(self, fileobj, events, data=None): data -- attached data Returns: - _Key instance + Key instance """ if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): raise ValueError("Invalid events: {}".format(events)) @@ -90,7 +90,7 @@ def register(self, fileobj, events, data=None): if fileobj in self._fileobj_to_key: raise ValueError("{!r} is already registered".format(fileobj)) - key = _Key(fileobj, events, data) + key = Key(fileobj, events, data) self._fd_to_key[key.fd] = key self._fileobj_to_key[fileobj] = key return key @@ -102,7 +102,7 @@ def unregister(self, fileobj): fileobj -- file object Returns: - _Key instance + Key instance """ try: key = self._fileobj_to_key[fileobj] From 888b317432696d1c2bfa7fb320f947704da82e9f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 10:20:51 -0800 Subject: [PATCH 0182/1502] Use signal.set_wakeup_fd() to avoid race condition. --- tulip/unix_events.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 116cd5a5..2595368b 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -97,13 +97,25 @@ def close(self): def _make_self_pipe(self): # A self-socket, really. :-) self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) self.add_reader(self._ssock.fileno(), self._read_from_self) def _read_from_self(self): - self._ssock.recv(1) + try: + self._ssock.recv(1) + except socket.error as exc: + if exc in _TRYAGAIN: + return + raise # Halp! def _write_to_self(self): - self._csock.send(b'x') + try: + self._csock.send(b'x') + except socket.error as exc: + if exc in _TRYAGAIN: + return + raise # Halp! def run(self): """Run the event loop until nothing left to do or stop() called. @@ -553,12 +565,25 @@ def add_signal_handler(self, sig, callback, *args): Raise RuntimeError if there is a problem setting up the handler. """ self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) handler = events.make_handler(None, callback, args) self._signal_handlers[sig] = handler try: signal.signal(sig, self._handle_signal) except OSError as exc: del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: raise RuntimeError('sig {} cannot be caught'.format(sig)) else: @@ -595,6 +620,11 @@ def remove_signal_handler(self, sig): raise RuntimeError('sig {} cannot be caught'.format(sig)) else: raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) return True def _check_signal(self, sig): From 291d26c5563307e33f7a4aaee406b75c4b8c591a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 11:24:35 -0800 Subject: [PATCH 0183/1502] Test for sleep(dt) without extra arg. --- tulip/tasks_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 4182056c..0285e91a 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -40,7 +40,8 @@ def notmuch(): def testSleep(self): @tasks.coroutine def sleeper(dt, arg): - res = yield from futures.sleep(dt, arg) + yield from futures.sleep(dt/2) + res = yield from futures.sleep(dt/2, arg) return res t = tasks.Task(sleeper(0.1, 'yeah')) t0 = time.monotonic() From e77a14274ae037d4f9382b28311a5b2e0444436f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 12:00:41 -0800 Subject: [PATCH 0184/1502] Fix definition of TimeoutError. --- tulip/futures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/futures.py b/tulip/futures.py index b079c5fc..62b9e4c5 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -17,7 +17,7 @@ # TODO: Do we really want to depend on concurrent.futures internals? Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError -TimeoutError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError class InvalidStateError(Error): From f08bbcaf8f51ba83a65a3c3edd90241bde046500 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 13:42:30 -0800 Subject: [PATCH 0185/1502] Handle signals arriving during a select() (etc.) call. --- tulip/events_test.py | 13 +++++++++++++ tulip/selectors.py | 27 +++++++++++++++++++++++---- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index b45c286d..f8b718e0 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -335,6 +335,19 @@ def my_handler(): el.run_once() self.assertEqual(caught, 0) + def testSignalHandlingWhileSelecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + el.call_later(0.15, el.stop) + el.run_forever() + self.assertEqual(caught, 1) + def testCreateTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! diff --git a/tulip/selectors.py b/tulip/selectors.py index 95c51366..3bc95b42 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -215,7 +215,11 @@ def unregister(self, fileobj): return key def select(self, timeout=None): - r, w, _ = select(self._readers, self._writers, [], timeout) + try: + r, w, _ = select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] r = set(r) w = set(w) ready = [] @@ -258,7 +262,12 @@ def unregister(self, fileobj): def select(self, timeout=None): timeout = None if timeout is None else int(1000 * timeout) ready = [] - for fd, event in self._poll.poll(timeout): + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: events = 0 if event & ~POLLIN: events |= SELECT_OUT @@ -298,7 +307,12 @@ def select(self, timeout=None): timeout = -1 if timeout is None else timeout max_ev = self.registered_count() ready = [] - for fd, event in self._epoll.poll(timeout, max_ev): + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: events = 0 if event & ~EPOLLIN: events |= SELECT_OUT @@ -347,7 +361,12 @@ def register(self, fileobj, events, data=None): def select(self, timeout=None): max_ev = self.registered_count() ready = [] - for kev in self._kqueue.control(None, max_ev, timeout): + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: fd = kev.ident flag = kev.filter events = 0 From 12859db42680a9f473ad23bdddef4658322f0382 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 13:44:50 -0800 Subject: [PATCH 0186/1502] Rename Key to SelectorKey. --- tulip/selectors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index 3bc95b42..7923686a 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -36,7 +36,7 @@ def _fileobj_to_fd(fileobj): return fd -class Key: +class SelectorKey: """Object used internally to associate a file object to its backing file descriptor, selected event mask and attached data.""" @@ -82,7 +82,7 @@ def register(self, fileobj, events, data=None): data -- attached data Returns: - Key instance + SelectorKey instance """ if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): raise ValueError("Invalid events: {}".format(events)) @@ -90,7 +90,7 @@ def register(self, fileobj, events, data=None): if fileobj in self._fileobj_to_key: raise ValueError("{!r} is already registered".format(fileobj)) - key = Key(fileobj, events, data) + key = SelectorKey(fileobj, events, data) self._fd_to_key[key.fd] = key self._fileobj_to_key[fileobj] = key return key @@ -102,7 +102,7 @@ def unregister(self, fileobj): fileobj -- file object Returns: - Key instance + SelectorKey instance """ try: key = self._fileobj_to_key[fileobj] From 0fabd025f947074733a58b10511253f33775c2e0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 16:43:32 -0800 Subject: [PATCH 0187/1502] Implement wait() and as_completed() (partially). --- tulip/futures.py | 29 ++++------ tulip/tasks.py | 134 +++++++++++++++++++++++++++++++++++++++++++- tulip/tasks_test.py | 58 +++++++++++++++++-- 3 files changed, 198 insertions(+), 23 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index 62b9e4c5..c682dd4b 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -2,7 +2,7 @@ __all__ = ['CancelledError', 'TimeoutError', 'InvalidStateError', 'InvalidTimeoutError', - 'Future', 'sleep', + 'Future', ] import concurrent.futures._base @@ -52,7 +52,7 @@ class Future: # was called). Here, cancel() schedules the callbacks, and # set_running_or_notify_cancel() is not supported. - # Class variables serving to as defaults for instance variables. + # Class variables serving as defaults for instance variables. _state = _PENDING _result = None _exception = None @@ -95,8 +95,6 @@ def _schedule_callbacks(self): callbacks = self._callbacks[:] if not callbacks: return - # Is it worth emptying the callbacks? It may reduce the - # usefulness of repr(). self._callbacks[:] = [] for callback in callbacks: self._event_loop.call_soon(callback, self) @@ -142,6 +140,16 @@ def add_done_callback(self, fn): else: self._callbacks.append(fn) + # New method not in PPE 3148. + + def remove_done_callback(self, fn): + """XXX""" + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + # So-called internal methods (note: no set_running_or_notify_cancel()). def set_result(self, result): @@ -183,16 +191,3 @@ def __iter__(self): if not self.done(): yield self # This tells Task to wait for completion. return self.result() # May raise too. - - -# TODO: Is this the right module for sleep()? -def sleep(when, result=None): - """Return a Future that completes after a given time (in seconds). - - It's okay to cancel the Future. - - Undocumented feature: sleep(when, x) sets the Future's result to x. - """ - future = Future() - future._event_loop.call_later(when, future.set_result, result) - return future diff --git a/tulip/tasks.py b/tulip/tasks.py index 7c7828d8..3f151cd5 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,9 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'task', 'Task'] +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] import concurrent.futures import inspect @@ -116,3 +119,132 @@ def _wakeup(future): if result is not None: logging.warn('_step(): bad yield: %r', result) self._event_loop.call_soon(self._step) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: _wait() but does not wrap coroutines.""" + done, pending = set(), set() + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + if (not pending or + timeout != None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + bail = futures.Future() + timeout_handler = None + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.set_result, None) + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + if not bail.done(): + bail.set_result(None) + try: + for f in pending: + f.add_done_callback(_on_completion) + yield from bail + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + assert timeout is None, 'timeout not yet supported' + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, return_when=FIRST_COMPLETED) + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 0285e91a..e620bd6e 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -23,7 +23,7 @@ def notmuch(): yield from [] return 'ok' t = tasks.Task(notmuch()) - t._event_loop.run() + self.event_loop.run() self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') @@ -33,19 +33,67 @@ def notmuch(): yield from [] return 'ko' t = notmuch() - t._event_loop.run() + self.event_loop.run() self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') + def testWait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0, 0.01) + # TODO: Test with timeout. + + def testAsCompleted(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0, 0.01) + # TODO: Test with timeout. + def testSleep(self): @tasks.coroutine def sleeper(dt, arg): - yield from futures.sleep(dt/2) - res = yield from futures.sleep(dt/2, arg) + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) return res t = tasks.Task(sleeper(0.1, 'yeah')) t0 = time.monotonic() - t._event_loop.run() + self.event_loop.run() t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.09) self.assertTrue(t.done()) From 6522d61779bf0fda9094d168cde41bba26ea9527 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 17:41:49 -0800 Subject: [PATCH 0188/1502] More tests for wait(): exception, timeout. --- tulip/tasks_test.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index e620bd6e..f1f17d10 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -56,7 +56,42 @@ def foo(): res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0, 0.01) - # TODO: Test with timeout. + # TODO: Test different return_when values. + + def testWaitWithException(self): + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testWaitWithTimeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.12) def testAsCompleted(self): @tasks.coroutine From 78402f04bcf224a9e47576d26eadfe3e16585b52 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 21:35:55 -0800 Subject: [PATCH 0189/1502] Use repr() of result/exception in repr() of Future. --- tulip/futures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index c682dd4b..d679e31f 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -67,9 +67,9 @@ def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: if self._exception is not None: - res += ''.format(self._exception) + res += ''.format(self._exception) else: - res += ''.format(self._result) + res += ''.format(self._result) elif self._callbacks: size = len(self._callbacks) if size > 2: From 18acbde487c6e4578680a3da9a9ee6b09e1401b9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jan 2013 21:44:13 -0800 Subject: [PATCH 0190/1502] Add timeout support to as_completed(). --- tulip/tasks.py | 25 +++++++++++++++++++++---- tulip/tasks_test.py | 26 +++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 3f151cd5..55377db1 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -7,6 +7,7 @@ import concurrent.futures import inspect +import time from . import events from . import futures @@ -149,7 +150,7 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): @coroutine def _wait(fs, timeout=None, return_when=ALL_COMPLETED): - """Internal helper: _wait() but does not wrap coroutines.""" + """Internal helper: Like wait() but does not wrap coroutines.""" done, pending = set(), set() errors = 0 for f in fs: @@ -168,7 +169,15 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): timeout_handler = None if timeout is not None: loop = events.get_event_loop() - timeout_handler = loop.call_later(timeout, bail.set_result, None) + # We need a helper to check that bail isn't already done, + # because somehow it is possible that this gets called after + # bail is already complete. (I tried using bail.cancel() and + # catching CancelledError, but that didn't work out. Maybe + # there's a bug with cancellation?) + def _bail_out(): + if not bail.done(): + bail.set_result(None) + timeout_handler = loop.call_later(timeout, _bail_out) def _on_completion(f): pending.remove(f) done.add(f) @@ -210,14 +219,22 @@ def as_completed(fs, timeout=None): Note: The futures 'f' are not necessarily members of fs. """ - assert timeout is None, 'timeout not yet supported' + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout done = None # Make nonlocal happy. fs = _wrap_coroutines(fs) while fs: + if deadline is not None: + timeout = deadline - time.monotonic() @coroutine def _wait_for_some(): nonlocal done, fs - done, fs = yield from _wait(fs, return_when=FIRST_COMPLETED) + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() return done.pop().result() # May raise. yield Task(_wait_for_some()) for f in done: diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index f1f17d10..4b25939c 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -65,6 +65,7 @@ def sleeper(): yield from tasks.sleep(0.15) raise ZeroDivisionError('really') b = tasks.Task(sleeper()) + @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a]) self.assertEqual(len(done), 2) @@ -83,6 +84,7 @@ def foo(): def testWaitWithTimeout(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) + @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a], timeout=0.11) self.assertEqual(done, set([a])) @@ -101,6 +103,7 @@ def sleeper(dt, x): a = sleeper(0.1, 'a') b = sleeper(0.1, 'b') c = sleeper(0.15, 'c') + @tasks.coroutine def foo(): values = [] for f in tasks.as_completed([b, c, a]): @@ -118,7 +121,28 @@ def foo(): res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0, 0.01) - # TODO: Test with timeout. + + def testAsCompletedWithTimeout(self): + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) def testSleep(self): @tasks.coroutine From 00196004e53b8b8368b34199d1f7b6025b65a9a8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 08:49:04 -0800 Subject: [PATCH 0191/1502] More state checks for Futures. --- tulip/futures_test.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tulip/futures_test.py b/tulip/futures_test.py index 4704849a..d45b4410 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -15,12 +15,15 @@ def testInitialState(self): def testCancel(self): f = futures.Future() - f.cancel() + self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) self.assertFalse(f.running()) self.assertTrue(f.done()) self.assertRaises(futures.CancelledError, f.result) self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) def testResult(self): f = futures.Future() @@ -30,6 +33,9 @@ def testResult(self): self.assertTrue(f.done()) self.assertEqual(f.result(), 42) self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) def testException(self): exc = RuntimeError() @@ -40,6 +46,9 @@ def testException(self): self.assertTrue(f.done()) self.assertRaises(RuntimeError, f.result) self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) def testYieldFromTwice(self): f = futures.Future() From 334f927f4bc8c92a4d5a66bf09f8d2907c374247 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 09:52:38 -0800 Subject: [PATCH 0192/1502] Delete trailing whitespace. --- tulip/selectors.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index 7923686a..b762e7e0 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -126,7 +126,7 @@ def modify(self, fileobj, events, data=None): def select(self, timeout=None): """Perform the actual selection, until some monitored file objects are ready or a timeout expires. - + Parameters: timeout -- if timeout > 0, this specifies the maximum wait time, in seconds @@ -239,11 +239,11 @@ def select(self, timeout=None): class PollSelector(_BaseSelector): """Poll-based selector.""" - + def __init__(self): super().__init__() self._poll = poll() - + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) poll_events = 0 @@ -253,12 +253,12 @@ def register(self, fileobj, events, data=None): poll_events |= POLLOUT self._poll.register(key.fd, poll_events) return key - + def unregister(self, fileobj): key = super().unregister(fileobj) self._poll.unregister(key.fd) return key - + def select(self, timeout=None): timeout = None if timeout is None else int(1000 * timeout) ready = [] @@ -273,7 +273,7 @@ def select(self, timeout=None): events |= SELECT_OUT if event & ~POLLOUT: events |= SELECT_IN - + key = self._key_from_fd(fd) ready.append((key.fileobj, events, key.data)) return ready @@ -283,11 +283,11 @@ def select(self, timeout=None): class EpollSelector(_BaseSelector): """Epoll-based selector.""" - + def __init__(self): super().__init__() self._epoll = epoll() - + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) epoll_events = 0 @@ -297,12 +297,12 @@ def register(self, fileobj, events, data=None): epoll_events |= EPOLLOUT self._epoll.register(key.fd, epoll_events) return key - + def unregister(self, fileobj): key = super().unregister(fileobj) self._epoll.unregister(key.fd) return key - + def select(self, timeout=None): timeout = -1 if timeout is None else timeout max_ev = self.registered_count() @@ -318,11 +318,11 @@ def select(self, timeout=None): events |= SELECT_OUT if event & ~EPOLLOUT: events |= SELECT_IN - + key = self._key_from_fd(fd) ready.append((key.fileobj, events, key.data)) return ready - + def close(self): super().close() self._epoll.close() @@ -332,7 +332,7 @@ def close(self): class KqueueSelector(_BaseSelector): """Kqueue-based selector.""" - + def __init__(self): super().__init__() self._kqueue = kqueue() @@ -347,7 +347,7 @@ def unregister(self, fileobj): kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key - + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & SELECT_IN: @@ -357,7 +357,7 @@ def register(self, fileobj, events, data=None): kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return key - + def select(self, timeout=None): max_ev = self.registered_count() ready = [] @@ -378,7 +378,7 @@ def select(self, timeout=None): key = self._key_from_fd(fd) ready.append((key.fileobj, events, key.data)) return ready - + def close(self): super().close() self._kqueue.close() From 8379a059ef621fd92f1c8115221892ede420f553 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 09:57:35 -0800 Subject: [PATCH 0193/1502] Add test id pattern matching feature to runtests.py. --- Makefile | 7 ++++--- runtests.py | 28 +++++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 274df966..8db8ad5e 100644 --- a/Makefile +++ b/Makefile @@ -1,15 +1,16 @@ PYTHON=python3 COVERAGE=coverage3 NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` +FLAGS= test: - $(PYTHON) runtests.py -v + $(PYTHON) runtests.py -v $(FLAGS) testloop: - while sleep 1; do $(PYTHON) runtests.py -v; done + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done cov coverage: - $(COVERAGE) run runtests.py -v + $(COVERAGE) run runtests.py -v $(FLAGS) $(COVERAGE) html $(NONTESTS) $(COVERAGE) report -m $(NONTESTS) echo "open file://`pwd`/htmlcov/index.html" diff --git a/runtests.py b/runtests.py index fa784572..a4dd9378 100644 --- a/runtests.py +++ b/runtests.py @@ -1,12 +1,27 @@ -"""Run all unittests.""" +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tulip.events_test.PolicyTests.testPolicy'. +""" # Originally written by Beech Horn (for NDB). +import re import sys import unittest +assert sys.version.startswith('3'), 'Please use Python 3.3 or higher.' -def load_tests(): +def load_tests(patterns=()): mods = ['events', 'futures', 'tasks'] test_mods = ['%s_test' % name for name in mods] tulip = __import__('tulip', fromlist=test_mods) @@ -19,19 +34,26 @@ def load_tests(): if name.endswith('Tests'): test_module = getattr(mod, name) tests = loader.loadTestsFromTestCase(test_module) + if patterns: + tests = [test + for test in tests + if any(re.search(pat, test.id()) for pat in patterns)] suite.addTests(tests) return suite def main(): + patterns = [] v = 1 for arg in sys.argv[1:]: if arg.startswith('-v'): v += arg.count('v') elif arg == '-q': v = 0 - result = unittest.TextTestRunner(verbosity=v).run(load_tests()) + elif arg and not arg.startswith('-'): + patterns.append(arg) + result = unittest.TextTestRunner(verbosity=v).run(load_tests(patterns)) sys.exit(not result.wasSuccessful()) From cce571a031276382a64091d914ab9a3dd0eead30 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 10:52:18 -0800 Subject: [PATCH 0194/1502] Slight tweaks. --- tulip/unix_events.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 2595368b..7eae5bd5 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -662,13 +662,15 @@ def _run_once(self, timeout=None): # TODO: An alternative API would be to do the *minimal* amount # of work, e.g. one callback or one I/O poll. - # Add everytime handlers. + # Add everytime handlers, skipping cancelled ones. any_cancelled = False for handler in self._everytime: - self._add_callback(handler) - any_cancelled = any_cancelled or handler.cancelled + if handler.cancelled: + any_cancelled = True + else: + self._ready.append(handler) - # Remove cancelled handlers if there are any. + # Remove cancelled everytime handlers if there are any. if any_cancelled: self._everytime = [handler for handler in self._everytime if not handler.cancelled] @@ -716,7 +718,7 @@ def _run_once(self, timeout=None): if handler.when > now: break handler = heapq.heappop(self._scheduled) - self.call_soon(handler.callback, *handler.args) + self._ready.append(handler) # This is the only place where callbacks are actually *called*. # All other places just add them to ready. From 91c3144e1bbf7ad6274c4226244d02047254d00f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 11:08:10 -0800 Subject: [PATCH 0195/1502] Improve+test Task.repr(). Fix bug in yield from . --- tulip/tasks.py | 33 ++++++++++++++++----------------- tulip/tasks_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 55377db1..ee16f496 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -54,6 +54,8 @@ def __init__(self, coro): def __repr__(self): res = super().__repr__() + if self._must_cancel and res.endswith(''): + res = res[:-len('')] + '' i = res.find('<') if i < 0: i = len(res) @@ -103,11 +105,12 @@ def _step(self, value=None, exc=None): raise else: def _wakeup(future): - value = None - exc = future.exception() - if exc is None: + try: value = future.result() - self._step(value, exc) + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) if isinstance(result, futures.Future): result.add_done_callback(_wakeup) elif isinstance(result, concurrent.futures.Future): @@ -165,19 +168,12 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): return_when == FIRST_COMPLETED and done or return_when == FIRST_EXCEPTION and errors): return done, pending - bail = futures.Future() + bail = futures.Future() # Will always be cancelled eventually. timeout_handler = None + debugstuff = locals() if timeout is not None: loop = events.get_event_loop() - # We need a helper to check that bail isn't already done, - # because somehow it is possible that this gets called after - # bail is already complete. (I tried using bail.cancel() and - # catching CancelledError, but that didn't work out. Maybe - # there's a bug with cancellation?) - def _bail_out(): - if not bail.done(): - bail.set_result(None) - timeout_handler = loop.call_later(timeout, _bail_out) + timeout_handler = loop.call_later(timeout, bail.cancel) def _on_completion(f): pending.remove(f) done.add(f) @@ -186,12 +182,14 @@ def _on_completion(f): (return_when == FIRST_EXCEPTION and not f.cancelled() and f.exception() is not None)): - if not bail.done(): - bail.set_result(None) + bail.cancel() try: for f in pending: f.add_done_callback(_on_completion) - yield from bail + try: + yield from bail + except futures.CancelledError: + pass finally: for f in pending: f.remove_done_callback(_on_completion) @@ -199,6 +197,7 @@ def _on_completion(f): timeout_handler.cancel() really_done = set(f for f in pending if f.done()) if really_done: + # We don't expect this to ever happen. Or do we? done.update(really_done) pending.difference_update(really_done) return done, pending diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 4b25939c..c31213e7 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -37,6 +37,39 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') + def testTaskRepr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def testTaskWaiting(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + def testWait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) From f1259cfde23c2412dec57e80a09644d1f56d1a75 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 11:29:32 -0800 Subject: [PATCH 0196/1502] Specifically check for 3.3 or higher. --- runtests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index a4dd9378..fde87f08 100644 --- a/runtests.py +++ b/runtests.py @@ -19,7 +19,7 @@ import sys import unittest -assert sys.version.startswith('3'), 'Please use Python 3.3 or higher.' +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' def load_tests(patterns=()): mods = ['events', 'futures', 'tasks'] From 876e45c45f8adcece47257ffa96681544adad7d9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 12:08:04 -0800 Subject: [PATCH 0197/1502] Improve Task repr(). --- tulip/tasks.py | 6 ++++-- tulip/tasks_test.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index ee16f496..b29d7495 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -54,8 +54,10 @@ def __init__(self, coro): def __repr__(self): res = super().__repr__() - if self._must_cancel and res.endswith(''): - res = res[:-len('')] + '' + if (self._must_cancel and + self._state == futures._PENDING and + ')') + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') t.cancel() # Does not take immediate effect! - self.assertEqual(repr(t), 'Task()') + self.assertEqual(repr(t), 'Task()') self.assertRaises(futures.CancelledError, self.event_loop.run_until_complete, t) self.assertEqual(repr(t), 'Task()') @@ -53,7 +61,7 @@ def notmuch(): self.event_loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") - def testTaskWaiting(self): + def testTaskBasics(self): @tasks.task def outer(): a = yield from inner1() From 4392c15bca66cd65431656e8a7bcfea12d91ad6f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 12:23:17 -0800 Subject: [PATCH 0198/1502] Fix typo reported (long ago) by Brett Cannon. --- old/polling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/old/polling.py b/old/polling.py index cffed2af..6586efcc 100644 --- a/old/polling.py +++ b/old/polling.py @@ -85,7 +85,7 @@ def poll(self, timeout=None): If timeout is omitted or None, this blocks until at least one event is ready. Otherwise, timeout gives a maximum time to - wait (an int of float in seconds) -- the method returns as + wait (in seconds as an int or float) -- the method returns as soon as at least one event is ready or when the timeout is expired. For a non-blocking poll, pass 0. From 67b66a039da3ce86333b877a6a52bbcc58fc306b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 13:10:23 -0800 Subject: [PATCH 0199/1502] Tentatively support TCP_FASTOPEN (can't test though). --- tulip/unix_events.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 7eae5bd5..f297b092 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -326,7 +326,7 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, @tasks.task def start_serving(self, protocol_factory, host, port, *, family=0, type=socket.SOCK_STREAM, proto=0, flags=0, - backlog=100): + backlog=100, fastopen=5): """XXX""" infos = yield from self.getaddrinfo(host, port, family=family, type=type, @@ -347,6 +347,13 @@ def start_serving(self, protocol_factory, host, port, *, break else: raise exceptions[0] + if fastopen and hasattr(socket, 'TCP_FASTOPEN'): + try: + sock.setsockopt(socket.SOL_TCP, socket.TCP_FASTOPEN, fastopen) + except socket.error: + # Even if TCP_FASTOPEN is defined by glibc, it may + # still not be supported by the kernel. + logging.info('TCP_FASTOPEN(%r) failed', fastopen) sock.listen(backlog) sock.setblocking(False) self.add_reader(sock.fileno(), self._accept_connection, From 31a98a011065c1bceb7f6406ce6d87a03de6e8b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 13:52:03 -0800 Subject: [PATCH 0200/1502] Turn on branch coverage. --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 8db8ad5e..d11e9716 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ testloop: while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done cov coverage: - $(COVERAGE) run runtests.py -v $(FLAGS) + $(COVERAGE) run --branch runtests.py -v $(FLAGS) $(COVERAGE) html $(NONTESTS) $(COVERAGE) report -m $(NONTESTS) echo "open file://`pwd`/htmlcov/index.html" From a737b81831152f22dc1aaa6d3de6b442df97e835 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 13:54:17 -0800 Subject: [PATCH 0201/1502] Fix cancelling a sleeping task. --- tulip/tasks.py | 31 ++++++++++++++++++++----------- tulip/tasks_test.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index b29d7495..91c32076 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -7,6 +7,7 @@ import concurrent.futures import inspect +import logging import time from . import events @@ -60,7 +61,7 @@ def __repr__(self): res = res.replace(')'.format(self._coro.__name__) + res[i:] return res @@ -69,13 +70,19 @@ def cancel(self): return False self._must_cancel = True # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) return True def cancelled(self): return self._must_cancel or super().cancelled() + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + def _step(self, value=None, exc=None): - if self.done(): + if self.done(): # pragma: no cover logging.warn('_step(): already done: %r, %r, %r', self, value, exc) return # We'll call either coro.throw(exc) or coro.send(value). @@ -106,15 +113,9 @@ def _step(self, value=None, exc=None): self.set_exception(exc) raise else: - def _wakeup(future): - try: - value = future.result() - except Exception as exc: - self._step(None, exc) - else: - self._step(value, None) + # XXX No check for self._must_cancel here? if isinstance(result, futures.Future): - result.add_done_callback(_wakeup) + result.add_done_callback(self._wakeup) elif isinstance(result, concurrent.futures.Future): # This ought to be more efficient than wrap_future(), # because we don't create an extra Future. @@ -126,6 +127,14 @@ def _wakeup(future): logging.warn('_step(): bad yield: %r', result) self._event_loop.call_soon(self._step) + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + # wait() and as_completed() similar to those in PEP 3148. @@ -198,7 +207,7 @@ def _on_completion(f): if timeout_handler is not None: timeout_handler.cancel() really_done = set(f for f in pending if f.done()) - if really_done: + if really_done: # pragma: no cover # We don't expect this to ever happen. Or do we? done.update(really_done) pending.difference_update(really_done) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index f9e19094..a214468e 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -199,6 +199,35 @@ def sleeper(dt, arg): self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') + def testTaskCancelSleepingTask(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.11, (t1-t0, sleepfut, doer)) + if __name__ == '__main__': unittest.main() From 5839b7b4ad0ea065758e9a1e9860359ff3f03f71 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 16:50:48 -0800 Subject: [PATCH 0202/1502] Checkpoint: HTTP client (needs a buffered stream badly). --- curl.py | 19 +++++ tulip/http_client.py | 185 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+) create mode 100755 curl.py create mode 100644 tulip/http_client.py diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..127616a8 --- /dev/null +++ b/curl.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python3 + +import sys + +import tulip +from tulip import http_client + + +def main(): + url = sys.argv[1] + p = http_client.HttpClientProtocol(url) + f = p.connect() + t = p.event_loop.run_until_complete(tulip.Task(f)) + print('transport =', t) + p.event_loop.run() + + +if __name__ == '__main__': + main() diff --git a/tulip/http_client.py b/tulip/http_client.py new file mode 100644 index 00000000..76eb85a8 --- /dev/null +++ b/tulip/http_client.py @@ -0,0 +1,185 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +import urllib.parse # For urlparse(). + +import tulip +from . import events + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + f = p.connect() # Returns a Future + ...now what?... + """ + + def __init__(self, url, method='GET', headers=None, make_body=None, + encoding='utf-8', version='1.1', chunked=False): + self.url = self.validate(url, 'url') + self.method = self.validate(method, 'method') + self.headers = {} + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + assert key not in self.headers, \ + '{} header is a duplicate'.format(key) + self.headers[key.lower()] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.split_url = urllib.parse.urlsplit(url) + (self.scheme, self.netloc, self.path, + self.query, self.fragment) = self.split_url + if not self.path: + self.path = '/' + self.ssl = self.scheme == 'https' + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['content-length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'transfer-encoding' not in self.headers: + self.headers['transfer-encoding'] = 'chunked' + else: + assert self.headers['transfer-encoding'] == 'chunked' + if ':' in self.netloc: + self.host, port_str = self.netloc.split(':', 1) + self.port = int(port_str) + else: + self.host = self.netloc + if self.ssl: + self.port = 443 + else: + self.port = 80 + if 'host' not in self.headers: + self.headers['host'] = self.host + self.event_loop = events.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces'.format(name) + else: + assert parts == [value], \ + '{} cannot contain whitespace'.format(name) + return value + + @tulip.coroutine + def connect(self): + t, p = yield from self.event_loop.create_transport(lambda: self, + self.host, + self.port, + ssl=self.ssl) + return t # Since p is self. + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, s): + if not s: + return + data = self.encode(s) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') + + def connection_made(self, transport): + self.transport = transport + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.write_str(line) + for key, value in self.headers.items(): + self.write_str('{}: {}\r\n'.format(key, value)) + self.transport.write(b'\r\n') + if self.make_body is not None: + if self.chunked: + self.make_body(self.write_chunked, self.write_chunked_eof) + else: + self.make_body(self.write_str, self.transport.write_eof) + self.lines_received = [] + self.incomplete_line = b'' + self.body_bytes_received = None + + def data_received(self, data): + if self.body_bytes_received is not None: # State: reading body. + print('body data received:', data) + self.body_bytes_received.append(data) + self.body_byte_count += len(data) + return + self.incomplete_line += data + parts = self.incomplete_line.splitlines(True) + self.incomplete_line = b'' + done = False + for part in parts: + if not done: + if not part.endswith(b'\n'): + self.incomplete_line = part + break + self.lines_received.append(part) + if part in (b'\r\n', b'\n'): + done = True + self.body_bytes_received = [] + self.body_byte_count = 0 + else: + self.body_bytes_received.append(part) + self.body_byte_count += len(part) + if done: + print('headers received:', str(self.lines_received).replace(', ', ',\n ')) + for data in self.body_bytes_received: + print('more data received:', data) + + def eof_received(self): + print('received EOF') + + def connection_lost(self, exc): + print('connection lost:', exc) From f6cbc1aaf64c310e0ce7d0ae574a561ec002eec3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Jan 2013 21:23:42 -0800 Subject: [PATCH 0203/1502] Add and use StreamReader class. --- curl.py | 4 +- tulip/http_client.py | 165 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 131 insertions(+), 38 deletions(-) diff --git a/curl.py b/curl.py index 127616a8..3b09fe69 100755 --- a/curl.py +++ b/curl.py @@ -10,9 +10,7 @@ def main(): url = sys.argv[1] p = http_client.HttpClientProtocol(url) f = p.connect() - t = p.event_loop.run_until_complete(tulip.Task(f)) - print('transport =', t) - p.event_loop.run() + p.event_loop.run_until_complete(tulip.Task(f)) if __name__ == '__main__': diff --git a/tulip/http_client.py b/tulip/http_client.py index 76eb85a8..319316e5 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -23,12 +23,113 @@ TODO: How do we do connection keep alive? Pooling? """ +import collections import urllib.parse # For urlparse(). import tulip from . import events +from . import futures +from . import tasks +# TODO: Move to another module. +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + assert data + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + # TODO: Limit line length for security. + while not self.line_count and not self.eof: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + continue + parts = [] + while self.buffer: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + parts.append(head) + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + break + return b''.join(parts) + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + return (yield from self.read(n)) + class HttpClientProtocol: """This Protocol class is also used to initiate the connection. @@ -102,11 +203,32 @@ def validate(self, value, name, embedded_spaces_okay=False): @tulip.coroutine def connect(self): - t, p = yield from self.event_loop.create_transport(lambda: self, - self.host, - self.port, - ssl=self.ssl) - return t # Since p is self. + yield from self.event_loop.create_transport(lambda: self, + self.host, + self.port, + ssl=self.ssl) + status_line = yield from self.stream.readline() + print('status line:', status_line) + headers = [] + content_length = None + while True: + header = yield from self.stream.readline() + if not header.strip(): + break + headers.append(header) + if header.lower().startswith(b'content-length:'): + parts = header.split(None, 1) + if len(parts) == 2: + try: + content_length = int(parts[1]) + except ValueError: + pass + print('headers:', repr(headers).replace(', ', ',\n ')) + if content_length is None: + body = yield from self.stream.read() + else: + body = yield from self.stream.readexactly(content_length) + print('body:', body) def encode(self, s): if isinstance(s, bytes): @@ -141,45 +263,18 @@ def connection_made(self, transport): for key, value in self.headers.items(): self.write_str('{}: {}\r\n'.format(key, value)) self.transport.write(b'\r\n') + self.stream = StreamReader() if self.make_body is not None: if self.chunked: self.make_body(self.write_chunked, self.write_chunked_eof) else: self.make_body(self.write_str, self.transport.write_eof) - self.lines_received = [] - self.incomplete_line = b'' - self.body_bytes_received = None def data_received(self, data): - if self.body_bytes_received is not None: # State: reading body. - print('body data received:', data) - self.body_bytes_received.append(data) - self.body_byte_count += len(data) - return - self.incomplete_line += data - parts = self.incomplete_line.splitlines(True) - self.incomplete_line = b'' - done = False - for part in parts: - if not done: - if not part.endswith(b'\n'): - self.incomplete_line = part - break - self.lines_received.append(part) - if part in (b'\r\n', b'\n'): - done = True - self.body_bytes_received = [] - self.body_byte_count = 0 - else: - self.body_bytes_received.append(part) - self.body_byte_count += len(part) - if done: - print('headers received:', str(self.lines_received).replace(', ', ',\n ')) - for data in self.body_bytes_received: - print('more data received:', data) + self.stream.feed_data(data) def eof_received(self): - print('received EOF') + self.stream.feed_eof() def connection_lost(self, exc): print('connection lost:', exc) From 016e04a62095be63f8c36929c6f6b5e78930892f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 08:43:40 -0800 Subject: [PATCH 0204/1502] Move urlsplit() to curl.py. --- curl.py | 10 +++++++++- tulip/http_client.py | 40 ++++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/curl.py b/curl.py index 3b09fe69..2d409b58 100755 --- a/curl.py +++ b/curl.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import sys +import urllib.parse import tulip from tulip import http_client @@ -8,7 +9,14 @@ def main(): url = sys.argv[1] - p = http_client.HttpClientProtocol(url) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) f = p.connect() p.event_loop.run_until_complete(tulip.Task(f)) diff --git a/tulip/http_client.py b/tulip/http_client.py index 319316e5..b59f77f2 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -24,7 +24,6 @@ """ import collections -import urllib.parse # For urlparse(). import tulip from . import events @@ -139,9 +138,24 @@ class HttpClientProtocol: ...now what?... """ - def __init__(self, url, method='GET', headers=None, make_body=None, - encoding='utf-8', version='1.1', chunked=False): - self.url = self.validate(url, 'url') + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') self.method = self.validate(method, 'method') self.headers = {} if headers: @@ -155,12 +169,7 @@ def __init__(self, url, method='GET', headers=None, make_body=None, self.version = self.validate(version, 'version') self.make_body = make_body self.chunked = chunked - self.split_url = urllib.parse.urlsplit(url) - (self.scheme, self.netloc, self.path, - self.query, self.fragment) = self.split_url - if not self.path: - self.path = '/' - self.ssl = self.scheme == 'https' + self.ssl = ssl if 'content-length' not in self.headers: if self.make_body is None: self.headers['content-length'] = '0' @@ -170,16 +179,7 @@ def __init__(self, url, method='GET', headers=None, make_body=None, if 'transfer-encoding' not in self.headers: self.headers['transfer-encoding'] = 'chunked' else: - assert self.headers['transfer-encoding'] == 'chunked' - if ':' in self.netloc: - self.host, port_str = self.netloc.split(':', 1) - self.port = int(port_str) - else: - self.host = self.netloc - if self.ssl: - self.port = 443 - else: - self.port = 80 + assert self.headers['transfer-encoding'].lower() == 'chunked' if 'host' not in self.headers: self.headers['host'] = self.host self.event_loop = events.get_event_loop() From ffbec8611fd9a1e6602bd98329a3f69d72cde5e1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 08:58:23 -0800 Subject: [PATCH 0205/1502] Use email.message.Message() for headers. --- tulip/http_client.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index b59f77f2..18549572 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -24,6 +24,7 @@ """ import collections +import email.message import tulip from . import events @@ -157,14 +158,12 @@ def __init__(self, host, port=None, *, self.port = port self.path = self.validate(path, 'path') self.method = self.validate(method, 'method') - self.headers = {} + self.headers = email.message.Message() if headers: for key, value in headers.items(): self.validate(key, 'header key') self.validate(value, 'header value', True) - assert key not in self.headers, \ - '{} header is a duplicate'.format(key) - self.headers[key.lower()] = value + self.headers[key] = value self.encoding = self.validate(encoding, 'encoding') self.version = self.validate(version, 'version') self.make_body = make_body @@ -172,16 +171,16 @@ def __init__(self, host, port=None, *, self.ssl = ssl if 'content-length' not in self.headers: if self.make_body is None: - self.headers['content-length'] = '0' + self.headers['Content-Length'] = '0' else: self.chunked = True if self.chunked: - if 'transfer-encoding' not in self.headers: - self.headers['transfer-encoding'] = 'chunked' + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' else: - assert self.headers['transfer-encoding'].lower() == 'chunked' + assert self.headers['Transfer-Encoding'].lower() == 'chunked' if 'host' not in self.headers: - self.headers['host'] = self.host + self.headers['Host'] = self.host self.event_loop = events.get_event_loop() self.transport = None @@ -229,6 +228,7 @@ def connect(self): else: body = yield from self.stream.readexactly(content_length) print('body:', body) + self.transport.close() def encode(self, s): if isinstance(s, bytes): From 54a24b508478be9f1d9d125c3c4945bab7ba5395 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 10:07:42 -0800 Subject: [PATCH 0206/1502] Make create_transport(..., ssl=True) wait until connected. --- tulip/unix_events.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index f297b092..0ffbb6e4 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -317,7 +317,10 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, protocol = protocol_factory() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = _UnixSslTransport(self, sock, protocol, sslcontext) + waiter = futures.Future() + transport = _UnixSslTransport(self, sock, protocol, sslcontext, + waiter) + yield from waiter else: transport = _UnixSocketTransport(self, sock, protocol) return transport, protocol @@ -838,12 +841,13 @@ def _call_connection_lost(self, exc): class _UnixSslTransport(transports.Transport): - def __init__(self, event_loop, rawsock, protocol, sslcontext): + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): self._event_loop = event_loop self._rawsock = rawsock self._protocol = protocol sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) self._sslcontext = sslcontext + self._waiter = waiter sslsock = sslcontext.wrap_socket(rawsock, do_handshake_on_connect=False) self._sslsock = sslsock @@ -861,12 +865,20 @@ def _on_handshake(self): except ssl.SSLWantWriteError: self._event_loop.add_writable(fd, self._on_handshake) return - # TODO: What if it raises another error? + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise self._event_loop.remove_reader(fd) self._event_loop.remove_writer(fd) self._event_loop.add_reader(fd, self._on_ready) self._event_loop.add_writer(fd, self._on_ready) self._event_loop.call_soon(self._protocol.connection_made, self) + self._waiter.set_result(None) def _on_ready(self): # Because of renegotiations (?), there's no difference between From 3ae72d864266cdec52f15004b016f7fe4e0e045c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 12:16:59 -0800 Subject: [PATCH 0207/1502] More useful connect() return. Add a crawler. --- crawl.py | 107 +++++++++++++++++++++++++++++++++++++++++++ curl.py | 5 +- tulip/http_client.py | 43 ++++++++++------- 3 files changed, 138 insertions(+), 17 deletions(-) create mode 100755 crawl.py diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..801b19ce --- /dev/null +++ b/crawl.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.add(self.rooturl) # Set initial work. + self.run() # Kick off work. + + def add(self, url): + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + for task in complete: + try: + yield from task + except Exception as exc: + logging.warn('Exception in task: %s', exc) + while self.todo: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + try: + print('processing', url) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + status, headers, stream = yield from p.connect() + if status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8') + urls = re.findall(r'(?i)href=["\']?([^ "\'<>]+)', + line) + for u in urls: + u, frag = urllib.parse.urldefrag(u) + u = urllib.parse.urljoin(self.rooturl, u) + if u.startswith(self.rooturl): + if self.add(u): + print(' ', url, '->', u) + ok = True + finally: + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info()) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + loop.add_signal_handler(signal.SIGINT, loop.stop) + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + main() diff --git a/curl.py b/curl.py index 2d409b58..c6259711 100755 --- a/curl.py +++ b/curl.py @@ -18,7 +18,10 @@ def main(): p = http_client.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) f = p.connect() - p.event_loop.run_until_complete(tulip.Task(f)) + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000))) + print(data) if __name__ == '__main__': diff --git a/tulip/http_client.py b/tulip/http_client.py index 18549572..094f7d5b 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -25,6 +25,8 @@ import collections import email.message +import email.parser +import re import tulip from . import events @@ -51,7 +53,8 @@ def feed_eof(self): waiter.set_result(True) def feed_data(self, data): - assert data + if not data: + return self.buffer.append(data) self.line_count += data.count(b'\n') self.byte_count += len(data) @@ -206,29 +209,37 @@ def connect(self): self.host, self.port, ssl=self.ssl) + # TODO: A better mechanism to return all info from the + # status line, all headers, and the buffer, without having + # an N-tuple return value. status_line = yield from self.stream.readline() - print('status line:', status_line) - headers = [] - content_length = None + m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', + status_line) + if not m: + raise 'Invalid HTTP status line ({!r})'.format(status_line) + version, status, message = m.groups() + raw_headers = [] while True: header = yield from self.stream.readline() if not header.strip(): break - headers.append(header) - if header.lower().startswith(b'content-length:'): - parts = header.split(None, 1) - if len(parts) == 2: - try: - content_length = int(parts[1]) - except ValueError: - pass - print('headers:', repr(headers).replace(', ', ',\n ')) + raw_headers.append(header) + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(raw_headers)) + content_length = headers.get('content-length') + if content_length: + content_length = int(content_length) # May raise. if content_length is None: - body = yield from self.stream.read() + stream = self.stream else: + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. body = yield from self.stream.readexactly(content_length) - print('body:', body) - self.transport.close() + stream = StreamReader() + stream.feed_data(body) + stream.feed_eof() + sts = '{} {}'.format(self.decode(status), self.decode(message)) + return (sts, headers, stream) def encode(self, s): if isinstance(s, bytes): From 9b65d447db9ce33755dae14ab883a697b10a75d5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:10:16 -0800 Subject: [PATCH 0208/1502] Fix href-finding pattern to exclude all whitespace. --- crawl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crawl.py b/crawl.py index 801b19ce..b4c9eac6 100755 --- a/crawl.py +++ b/crawl.py @@ -71,7 +71,7 @@ def process(self, url): if not line: break line = line.decode('utf-8') - urls = re.findall(r'(?i)href=["\']?([^ "\'<>]+)', + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', line) for u in urls: u, frag = urllib.parse.urldefrag(u) From 3187f3fb06e5caf3d2be95ab1935f1e5d1b198dd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:11:48 -0800 Subject: [PATCH 0209/1502] Improve error message for value containing whitespace. --- tulip/http_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index 094f7d5b..3ebdff0f 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -197,10 +197,11 @@ def validate(self, value, name, embedded_spaces_okay=False): assert parts, '{} should not be empty'.format(name) if embedded_spaces_okay: assert ' '.join(parts) == value, \ - '{} can only contain embedded single spaces'.format(name) + '{} can only contain embedded single spaces ({!r})'.format( + name, value) else: assert parts == [value], \ - '{} cannot contain whitespace'.format(name) + '{} cannot contain whitespace ({!r})'.format(name, value) return value @tulip.coroutine From ff32223a4243bad7b22e16ec80f682df5d01242b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:12:33 -0800 Subject: [PATCH 0210/1502] Tighten connect() error handling. --- tulip/unix_events.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 0ffbb6e4..71c52b38 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -300,20 +300,29 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, raise socket.error('getaddrinfo() returned empty list') exceptions = [] for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - # TODO: Use a small timeout here and overlap connect attempts. + sock = None try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) yield self.sock_connect(sock, address) except socket.error as exc: - sock.close() + if sock is not None: + sock.close() exceptions.append(exc) else: break else: - # TODO: What to do if there are multiple exceptions? We - # can't raise them all. Arbitrarily pick the first one. - raise exceptions[0] + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) protocol = protocol_factory() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl From f6f3814be9ec613e133ce48f43afa5d8637a5891 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:32:46 -0800 Subject: [PATCH 0211/1502] Add retry with exponential back-off to process(). --- crawl.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crawl.py b/crawl.py index b4c9eac6..88cd9c21 100755 --- a/crawl.py +++ b/crawl.py @@ -3,6 +3,7 @@ import logging import re import signal +import socket import sys import urllib.parse @@ -62,7 +63,18 @@ def process(self, url): path = '?'.join([path, query]) p = http_client.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) - status, headers, stream = yield from p.connect() + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...') + yield from tulip.sleep(delay) + delay *= 2 if status.startswith('200'): ctype = headers.get_content_type() if ctype == 'text/html': From 58d2b934d06482b2be008564a07971f0fd01ebdd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:38:32 -0800 Subject: [PATCH 0212/1502] Print headers. Decode body. --- curl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/curl.py b/curl.py index c6259711..c5283bfb 100755 --- a/curl.py +++ b/curl.py @@ -20,8 +20,11 @@ def main(): f = p.connect() sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000))) - print(data) + print(data.decode('utf-8')) if __name__ == '__main__': From e1f590c87db33809c870c9251545278d34eea09b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:45:53 -0800 Subject: [PATCH 0213/1502] Support 301/302 redirects. --- crawl.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/crawl.py b/crawl.py index 88cd9c21..970152fe 100755 --- a/crawl.py +++ b/crawl.py @@ -24,6 +24,10 @@ def __init__(self, rooturl): self.run() # Kick off work. def add(self, url): + url = urllib.parse.urljoin(self.rooturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False if url in self.busy or url in self.done or url in self.todo: return False self.todo.add(url) @@ -75,7 +79,12 @@ def process(self, url): 'retrying after sleep', delay, '...') yield from tulip.sleep(delay) delay *= 2 - if status.startswith('200'): + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.add(u): + print(' ', url, status[:3], 'redirect to', u) + elif status.startswith('200'): ctype = headers.get_content_type() if ctype == 'text/html': while True: @@ -86,11 +95,8 @@ def process(self, url): urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', line) for u in urls: - u, frag = urllib.parse.urldefrag(u) - u = urllib.parse.urljoin(self.rooturl, u) - if u.startswith(self.rooturl): - if self.add(u): - print(' ', url, '->', u) + if self.add(u): + print(' ', url, 'href to', u) ok = True finally: self.done[url] = ok From aaa6fd348380e0b3460d9f846b86f942e6257885 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 14:49:59 -0800 Subject: [PATCH 0214/1502] Fix typos in _TRYAGAIN tests. --- tulip/unix_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 71c52b38..a9d297f8 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -105,7 +105,7 @@ def _read_from_self(self): try: self._ssock.recv(1) except socket.error as exc: - if exc in _TRYAGAIN: + if exc.errno in _TRYAGAIN: return raise # Halp! @@ -113,7 +113,7 @@ def _write_to_self(self): try: self._csock.send(b'x') except socket.error as exc: - if exc in _TRYAGAIN: + if exc.errno in _TRYAGAIN: return raise # Halp! @@ -376,7 +376,7 @@ def _accept_connection(self, protocol_factory, sock): try: conn, addr = sock.accept() except socket.error as exc: - if exc in _TRYAGAIN: + if exc.errno in _TRYAGAIN: return # False alarm. # Bad error. Stop serving. self.remove_reader(sock.fileno()) From 00209220ccf8ecfdc62186a1103891b2807c3854 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 15:15:33 -0800 Subject: [PATCH 0215/1502] Use more lenient decoding practice. --- crawl.py | 2 +- curl.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crawl.py b/crawl.py index 970152fe..47121587 100755 --- a/crawl.py +++ b/crawl.py @@ -91,7 +91,7 @@ def process(self, url): line = yield from stream.readline() if not line: break - line = line.decode('utf-8') + line = line.decode('utf-8', 'replace') urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', line) for u in urls: diff --git a/curl.py b/curl.py index c5283bfb..1a73c194 100755 --- a/curl.py +++ b/curl.py @@ -23,8 +23,8 @@ def main(): for k, v in headers.items(): print('{}: {}'.format(k, v)) print() - data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000))) - print(data.decode('utf-8')) + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + print(data.decode('utf-8', 'replace')) if __name__ == '__main__': From cf5cc4aea0390fb680ccf9bba15473d136067f04 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 15:20:11 -0800 Subject: [PATCH 0216/1502] Remove some TODOs that are done. --- TODO | 58 ---------------------------------------------------------- 1 file changed, 58 deletions(-) diff --git a/TODO b/TODO index 309bcd30..b9559ef0 100644 --- a/TODO +++ b/TODO @@ -1,15 +1,4 @@ # -*- Mode: text -*- - -TO DO SMALLER TASKS - -- Make Task more like Future; getting result() should re-raise exception. - -- Add a decorator just for documenting a coroutine. It should set a - flag on the function. It should not interfere with methods, - staticmethod, classmethod and the like. The Task constructor should - check the flag. The decorator should notice if the wrapped function - is not a generator. - TO DO LARGER TASKS @@ -27,34 +16,8 @@ TO DO LARGER TASKS - Restructure directory, move demos and benchmarks to subdirectories. -FROM BEN DARNELL (Tornado) - -- The waker pipe in ThreadRunner should really be in EventLoop itself - - we need to be able to call call_soon (or some equivalent) from - threads that were not created by ThreadRunner. In Tornado I ended - up needing two functions, add_callback (usable from any thread) and - add_callback_from_signal (usable only from signal handlers). - -- Timeouts should ideally be based on time.monotonic, although this - requires some extra complexity to deal with the cases where you - actually do want time.time. (in tornado, the clock used is - configurable on a per-ioloop basis, which is not ideal but is - workable) - -- I'm sure you've heard this from the twisted guys by now, but to - properly support completion-based event loops like IOCP you need to - be able to swap out most of sockets.py (the layers below - BufferedReader) for an alternative implementation. - - TO DO LATER -- Wrap select(), epoll() etc. in try/except checking for EINTR. - -- Move accept loop into Listener class? (Windows is said to work - better if you make many AcceptEx() calls in parallel.) OTOH we can - already accept many incoming connections without suspending. - - When multiple tasks are accessing the same socket, they should either get interleaved I/O or an immediate exception; it should not compromise the integrity of the scheduler or the app or leave a task @@ -64,8 +27,6 @@ TO DO LATER - Add the simplest API possible to run a generator with a timeout. -- Do we need call_every()? (Easily emulated with a loop and sleep().) - - Ensure multiple tasks can do atomic writes to the same pipe (since UNIX guarantees that short writes to pipes are atomic). @@ -78,10 +39,6 @@ TO DO LATER - See how much of asyncore I've already replaced. -- Do we need _async suffixes to all async APIs? - -- Do we need synchronous parallel APIs for all async APIs? - - Could BufferedReader reuse the standard io module's readers??? - Support ZeroMQ "sockets" which are user objects. Though possibly @@ -97,9 +54,6 @@ TO DO LATER FROM OLDER LIST -- Is it better to have separate add_{reader,writer} methods, vs. one - add_thingie method taking a fd and a r/w flag? - - Multiple readers/writers per socket? (At which level? pollster, eventloop, or scheduler?) @@ -109,23 +63,11 @@ FROM OLDER LIST - Optimize register/unregister calls away if they cancel each other out? -- Should block() use a queue? - - Add explicit wait queue to wait for Task's completion, instead of callbacks? -- Global functions vs. Task methods? - -- Is the Task design good? - -- Make Task more like Future? (Or less???) - - Implement various lock styles a la threading.py. -- Add write() calls that don't require yield from. - -- Add simple non-async APIs, for simple apps? - - Look at pyfdpdlib's ioloop.py: http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py From 24872cfbf224f9bcfa472d40eaea2c09ace6a4fa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 16:17:06 -0800 Subject: [PATCH 0217/1502] Kill fastopen code for now. It feels too experimental. --- tulip/unix_events.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index a9d297f8..69d67f82 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -338,7 +338,7 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, @tasks.task def start_serving(self, protocol_factory, host, port, *, family=0, type=socket.SOCK_STREAM, proto=0, flags=0, - backlog=100, fastopen=5): + backlog=100): """XXX""" infos = yield from self.getaddrinfo(host, port, family=family, type=type, @@ -359,13 +359,6 @@ def start_serving(self, protocol_factory, host, port, *, break else: raise exceptions[0] - if fastopen and hasattr(socket, 'TCP_FASTOPEN'): - try: - sock.setsockopt(socket.SOL_TCP, socket.TCP_FASTOPEN, fastopen) - except socket.error: - # Even if TCP_FASTOPEN is defined by glibc, it may - # still not be supported by the kernel. - logging.info('TCP_FASTOPEN(%r) failed', fastopen) sock.listen(backlog) sock.setblocking(False) self.add_reader(sock.fileno(), self._accept_connection, From 401aa2db1111698f3f619e350ab5b51d60a1f0dc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 17:25:38 -0800 Subject: [PATCH 0218/1502] Always close the transport when done processing. --- crawl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/crawl.py b/crawl.py index 47121587..99cd04d5 100755 --- a/crawl.py +++ b/crawl.py @@ -58,6 +58,7 @@ def run(self): @tulip.task def process(self, url): ok = False + p = None try: print('processing', url) scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) @@ -99,6 +100,8 @@ def process(self, url): print(' ', url, 'href to', u) ok = True finally: + if p is not None: + p.transport.close() self.done[url] = ok self.busy.remove(url) if not ok: From 41bbcfdbcefc7c2eb73e40a564d685f1d06eb159 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 17:26:46 -0800 Subject: [PATCH 0219/1502] Don't print when connection is closed. --- tulip/http_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index 3ebdff0f..d5f2e57c 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -289,4 +289,4 @@ def eof_received(self): self.stream.feed_eof() def connection_lost(self, exc): - print('connection lost:', exc) + pass From 1716b0400d33beb112c52453982fd7bfbff2a3df Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Jan 2013 17:30:00 -0800 Subject: [PATCH 0220/1502] Limit #tasks. Also make end of line char configurable. --- crawl.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/crawl.py b/crawl.py index 99cd04d5..a1303fcb 100755 --- a/crawl.py +++ b/crawl.py @@ -10,6 +10,9 @@ import tulip from tulip import http_client +END = '\n' +MAXTASKS = 100 + class Crawler: @@ -41,12 +44,14 @@ def add(self, url): def run(self): while self.todo or self.busy or self.tasks: complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) for task in complete: try: yield from task except Exception as exc: - logging.warn('Exception in task: %s', exc) - while self.todo: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: url = self.todo.pop() self.busy.add(url) self.tasks.add(self.process(url)) # Async task. @@ -60,7 +65,7 @@ def process(self, url): ok = False p = None try: - print('processing', url) + print('processing', url, end=END) scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) if not path: path = '/' @@ -77,14 +82,14 @@ def process(self, url): if delay >= 60: raise print('...', url, 'has error', repr(str(exc)), - 'retrying after sleep', delay, '...') + 'retrying after sleep', delay, '...', end=END) yield from tulip.sleep(delay) delay *= 2 if status[:3] in ('301', '302'): # Redirect. u = headers.get('location') or headers.get('uri') if self.add(u): - print(' ', url, status[:3], 'redirect to', u) + print(' ', url, status[:3], 'redirect to', u, end=END) elif status.startswith('200'): ctype = headers.get_content_type() if ctype == 'text/html': @@ -97,7 +102,7 @@ def process(self, url): line) for u in urls: if self.add(u): - print(' ', url, 'href to', u) + print(' ', url, 'href to', u, end=END) ok = True finally: if p is not None: @@ -105,7 +110,7 @@ def process(self, url): self.done[url] = ok self.busy.remove(url) if not ok: - print('failure for', url, sys.exc_info()) + print('failure for', url, sys.exc_info(), end=END) waiter = self.waiter if waiter is not None: self.waiter = None From 0626c1ddcc3d6b9fa0f269d3e539347569998e62 Mon Sep 17 00:00:00 2001 From: Charles-Fran?ois Natali Date: Sat, 12 Jan 2013 12:45:53 +0100 Subject: [PATCH 0221/1502] Fix a typo In SelectorKey.__repr__(). --- .hgignore | 8 + Makefile | 30 ++ NOTES | 130 ++++++ README | 30 ++ TODO | 165 +++++++ check.py | 40 ++ crawl.py | 125 ++++++ curl.py | 31 ++ old/Makefile | 16 + old/echoclt.py | 79 ++++ old/echosvr.py | 60 +++ old/http_client.py | 78 ++++ old/http_server.py | 68 +++ old/main.py | 134 ++++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++++++++++++ old/scheduling.py | 354 +++++++++++++++ old/sockets.py | 348 +++++++++++++++ old/transports.py | 496 +++++++++++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++++ runtests.py | 61 +++ tulip/TODO | 26 ++ tulip/__init__.py | 14 + tulip/events.py | 287 ++++++++++++ tulip/events_test.py | 428 ++++++++++++++++++ tulip/futures.py | 193 +++++++++ tulip/futures_test.py | 71 +++ tulip/http_client.py | 292 +++++++++++++ tulip/protocols.py | 58 +++ tulip/selectors.py | 396 +++++++++++++++++ tulip/tasks.py | 277 ++++++++++++ tulip/tasks_test.py | 233 ++++++++++ tulip/transports.py | 90 ++++ tulip/unix_events.py | 963 +++++++++++++++++++++++++++++++++++++++++ tulip/winsocketpair.py | 30 ++ 36 files changed, 6286 insertions(+) create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 runtests.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/events.py create mode 100644 tulip/events_test.py create mode 100644 tulip/futures.py create mode 100644 tulip/futures_test.py create mode 100644 tulip/http_client.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selectors.py create mode 100644 tulip/tasks.py create mode 100644 tulip/tasks_test.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..42309f0c --- /dev/null +++ b/.hgignore @@ -0,0 +1,8 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..d11e9716 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +PYTHON=python3 +COVERAGE=coverage3 +NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +cov coverage: + $(COVERAGE) run --branch runtests.py -v $(FLAGS) + $(COVERAGE) html $(NONTESTS) + $(COVERAGE) report -m $(NONTESTS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf __pycache__ */__pycache__ + rm -f *.py[co] */*.py[co] + rm -f *~ */*~ + rm -f .*~ */.*~ + rm -f @* */@* + rm -f '#'*'#' */'#'*'#' + rm -f *.orig */*.orig + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..6f41578e --- /dev/null +++ b/NOTES @@ -0,0 +1,130 @@ +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..c1c86a54 --- /dev/null +++ b/README @@ -0,0 +1,30 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (after installing coverage3, see below): + - make coverage + +To install coverage3 (coverage.py for Python 3), you need: + - Distribute (http://packages.python.org/distribute/) + - Coverage (http://nedbatchelder.com/code/coverage/) + What worked for me: + - curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - + - cd coveragepy + - python3 setup.py install + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..b9559ef0 --- /dev/null +++ b/TODO @@ -0,0 +1,165 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Implement various lock styles a la threading.py. + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..f0aa9a66 --- /dev/null +++ b/check.py @@ -0,0 +1,40 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..47121587 --- /dev/null +++ b/crawl.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.add(self.rooturl) # Set initial work. + self.run() # Kick off work. + + def add(self, url): + url = urllib.parse.urljoin(self.rooturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + for task in complete: + try: + yield from task + except Exception as exc: + logging.warn('Exception in task: %s', exc) + while self.todo: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + try: + print('processing', url) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...') + yield from tulip.sleep(delay) + delay *= 2 + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.add(u): + print(' ', url, status[:3], 'redirect to', u) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.add(u): + print(' ', url, 'href to', u) + ok = True + finally: + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info()) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + loop.add_signal_handler(signal.SIGINT, loop.stop) + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..1a73c194 --- /dev/null +++ b/curl.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + main() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..fde87f08 --- /dev/null +++ b/runtests.py @@ -0,0 +1,61 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tulip.events_test.PolicyTests.testPolicy'. +""" + +# Originally written by Beech Horn (for NDB). + +import re +import sys +import unittest + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +def load_tests(patterns=()): + mods = ['events', 'futures', 'tasks'] + test_mods = ['%s_test' % name for name in mods] + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if patterns: + tests = [test + for test in tests + if any(re.search(pat, test.id()) for pat in patterns)] + suite.addTests(tests) + + return suite + + +def main(): + patterns = [] + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg and not arg.startswith('-'): + patterns.append(arg) + result = unittest.TextTestRunner(verbosity=v).run(load_tests(patterns)) + sys.exit(not result.wasSuccessful()) + + +if __name__ == '__main__': + main() diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..185fe3fe --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,14 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .transports import * +from .protocols import * +from .tasks import * + +__all__ = (futures.__all__ + + events.__all__ + + transports.__all__ + + protocols.__all__ + + tasks.__all__) diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..a5c3e383 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,287 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'EventLoop', 'Handler', 'make_handler', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import threading + + +class Handler: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args): + self._when = when + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + return self._when <= other._when + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + return self._when >= other._when + + def __eq__(self, other): + return self._when == other._when + + +def make_handler(when, callback, args): + if isinstance(callback, Handler): + assert not args + assert when is None + return callback + return Handler(when, callback, args) + + +class EventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handlers for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_transport(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving(self, protocol_factory, host, port, *, + family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handler. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + def add_connector(self, fd, callback, *args): + raise NotImplementedError + + def remove_connector(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, EventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + # TODO: Do something else for Windows. + from . import unix_events + return unix_events.UnixEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/events_test.py b/tulip/events_test.py new file mode 100644 index 00000000..f8b718e0 --- /dev/null +++ b/tulip/events_test.py @@ -0,0 +1,428 @@ +"""Tests for events.py.""" + +import concurrent.futures +import gc +import os +import select +import signal +import socket +import threading +import time +import unittest + +from . import events +from . import transports +from . import protocols +from . import selectors +from . import unix_events + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + self.selector = self.SELECTOR_CLASS() + self.event_loop = unix_events.UnixEventLoop(self.selector) + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + + def testRun(self): + el = events.get_event_loop() + el.run() # Returns immediately. + + def testCallLater(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + el.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallRepeatedly(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_repeatedly(0.03, callback, 'ho') + el.call_later(0.1, el.stop) + el.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallSoon(self): + el = events.get_event_loop() + results = [] + def callback(arg1, arg2): + results.append((arg1, arg2)) + el.call_soon(callback, 'hello', 'world') + el.run() + self.assertEqual(results, [('hello', 'world')]) + + def testCallSoonWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(): + results.append('yeah') + handler = events.Handler(None, callback, ()) + self.assertEqual(el.call_soon(handler), handler) + el.run() + self.assertEqual(results, ['yeah']) + + def testCallSoonThreadsafe(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + def run(): + el.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallSoonThreadsafeWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('hello',)) + def run(): + self.assertEqual(el.call_soon_threadsafe(handler), handler) + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallEveryIteration(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = el.call_every_iteration(callback, 'ho') + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handler.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallEveryIterationWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('ho',)) + self.assertEqual(el.call_every_iteration(handler), handler) + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handler.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testWrapFuture(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = el.wrap_future(f1) + res = el.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def testRunInExecutor(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + f2 = el.run_in_executor(None, run, 'yo') + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testRunInExecutorWithHandler(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(None, run, ('yo',)) + f2 = el.run_in_executor(None, handler) + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testReaderCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + else: + el.remove_reader(r.fileno()) + r.close() + el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testReaderCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + else: + el.remove_reader(r.fileno()) + r.close() + handler = events.Handler(None, reader, ()) + self.assertEqual(el.add_reader(r.fileno(), handler), handler) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testWriterCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + el.call_later(0.1, el.remove_writer, w.fileno()) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testWriterCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + self.assertEqual(el.add_writer(w.fileno(), handler), handler) + el.call_later(0.1, el.remove_writer, w.fileno()) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testSockClientOps(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + el.run_until_complete(el.sock_connect(sock, ('python.org', 80))) + el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = el.run_until_complete(el.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + + def testSockClientFail(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + with self.assertRaises(ConnectionRefusedError): + el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) + sock.close() + + def testSockAccept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + el = events.get_event_loop() + f = el.sock_accept(listener) + conn, addr = el.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + def testAddSignalHandler(self): + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + # Check error behavior first. + self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) + self.assertRaises(TypeError, el.remove_signal_handler, 'boom') + self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) + self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, 0) + self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, -1) + self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + el.add_signal_handler(signal.SIGINT, my_handler) + el.run_once() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + + def testCancelSignalHandler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 0) + + def testSignalHandlingWhileSelecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + el.call_later(0.15, el.stop) + el.run_forever() + self.assertEqual(caught, 1) + + def testCreateTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_transport(MyProto, 'xkcd.com', 80) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + el.run() + self.assertTrue(pr.nbytes > 0) + + def testCreateSslTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_transport(MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + el.run() + self.assertTrue(pr.nbytes > 0) + + def testStartServing(self): + el = events.get_event_loop() + f = el.start_serving(MyProto, '0.0.0.0', 0) + sock = el.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect((host, port)) + client.send(b'xxx') + el.run_once() # This is quite mysterious, but necessary. + client.close() + el.run_once() + el.run_once() + sock.close() + + +if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.KqueueSelector + + +if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.EpollSelector + + +if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.PollSelector + + +# Should always exist. +class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.SelectSelector + + +class HandlerTests(unittest.TestCase): + + def testHandler(self): + pass + + def testMakeHandler(self): + def callback(*args): + return args + h1 = events.Handler(None, callback, ()) + h2 = events.make_handler(None, h1, ()) + self.assertEqual(h1, h2) + + +class PolicyTests(unittest.TestCase): + + def testPolicy(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..d679e31f --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,193 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # TODO: PEP 3148 seems to say that cancel() does not call the + # callbacks, but set_running_or_notify_cancel() does (if cancel() + # was called). Here, cancel() schedules the callbacks, and + # set_running_or_notify_cancel() is not supported. + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + def __init__(self): + """XXX""" + self._event_loop = events.get_event_loop() + self._callbacks = [] + + def __repr__(self): + """XXX""" + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res +='<{}>'.format(self._state) + return res + + def cancel(self): + """XXX""" + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """XXX""" + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """XXX""" + return self._state == _CANCELLED + + def running(self): + """XXX""" + return False # We don't have a running state. + + def done(self): + """XXX""" + return self._state != _PENDING + + def result(self, timeout=0): + """XXX""" + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """XXX""" + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """XXX""" + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PPE 3148. + + def remove_done_callback(self, fn): + """XXX""" + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """XXX""" + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """XXX""" + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + yield self # This tells Task to wait for completion. + return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py new file mode 100644 index 00000000..d45b4410 --- /dev/null +++ b/tulip/futures_test.py @@ -0,0 +1,71 @@ +"""Tests for futures.py.""" + +import unittest + +from . import futures + + +class FutureTests(unittest.TestCase): + + def testInitialState(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def testCancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testResult(self): + f = futures.Future() + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testException(self): + exc = RuntimeError() + f = futures.Future() + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testYieldFromTwice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/http_client.py b/tulip/http_client.py new file mode 100644 index 00000000..3ebdff0f --- /dev/null +++ b/tulip/http_client.py @@ -0,0 +1,292 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +import collections +import email.message +import email.parser +import re + +import tulip +from . import events +from . import futures +from . import tasks + + +# TODO: Move to another module. +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + # TODO: Limit line length for security. + while not self.line_count and not self.eof: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + continue + parts = [] + while self.buffer: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + parts.append(head) + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + break + return b''.join(parts) + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + return (yield from self.read(n)) + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + f = p.connect() # Returns a Future + ...now what?... + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = events.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_transport(lambda: self, + self.host, + self.port, + ssl=self.ssl) + # TODO: A better mechanism to return all info from the + # status line, all headers, and the buffer, without having + # an N-tuple return value. + status_line = yield from self.stream.readline() + m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', + status_line) + if not m: + raise 'Invalid HTTP status line ({!r})'.format(status_line) + version, status, message = m.groups() + raw_headers = [] + while True: + header = yield from self.stream.readline() + if not header.strip(): + break + raw_headers.append(header) + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(raw_headers)) + content_length = headers.get('content-length') + if content_length: + content_length = int(content_length) # May raise. + if content_length is None: + stream = self.stream + else: + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. + body = yield from self.stream.readexactly(content_length) + stream = StreamReader() + stream.feed_data(body) + stream.feed_eof() + sts = '{} {}'.format(self.decode(status), self.decode(message)) + return (sts, headers, stream) + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, s): + if not s: + return + data = self.encode(s) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') + + def connection_made(self, transport): + self.transport = transport + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.write_str(line) + for key, value in self.headers.items(): + self.write_str('{}: {}\r\n'.format(key, value)) + self.transport.write(b'\r\n') + self.stream = StreamReader() + if self.make_body is not None: + if self.chunked: + self.make_body(self.write_chunked, self.write_chunked_eof) + else: + self.make_body(self.write_str, self.transport.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('connection lost:', exc) diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..3a00e2ee --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,58 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_transport()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..08762f86 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,396 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +SELECT_IN = (1 << 0) +# write event +SELECT_OUT = (1 << 1) +# connect event +SELECT_CONNECT = SELECT_OUT + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + data -- attached data + """ + self.unregister(fileobj) + self.register(fileobj, events, data) + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of SELECT_IN|SELECT_OUT + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + raise RuntimeError("No key found for fd {}".format(fd)) + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & SELECT_IN: + self._readers.add(key.fd) + if events & SELECT_OUT: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= SELECT_IN + if fd in w: + events |= SELECT_OUT + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + +if 'poll' in globals(): + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & SELECT_IN: + poll_events |= POLLIN + if events & SELECT_OUT: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= SELECT_OUT + if event & ~POLLOUT: + events |= SELECT_IN + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & SELECT_IN: + epoll_events |= EPOLLIN + if events & SELECT_OUT: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= SELECT_OUT + if event & ~EPOLLOUT: + events |= SELECT_IN + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & SELECT_IN: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & SELECT_OUT: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & SELECT_IN: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & SELECT_OUT: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= SELECT_IN + if flag == KQ_FILTER_WRITE: + events |= SELECT_OUT + + key = self._key_from_fd(fd) + ready.append((key.fileobj, events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..91c32076 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,277 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__() # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): # pragma: no cover + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + result.add_done_callback(self._wakeup) + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe(_wakeup, future)) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + if (not pending or + timeout != None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + bail = futures.Future() # Will always be cancelled eventually. + timeout_handler = None + debugstuff = locals() + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.cancel) + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + really_done = set(f for f in pending if f.done()) + if really_done: # pragma: no cover + # We don't expect this to ever happen. Or do we? + done.update(really_done) + pending.difference_update(really_done) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py new file mode 100644 index 00000000..a214468e --- /dev/null +++ b/tulip/tasks_test.py @@ -0,0 +1,233 @@ +"""Tests for tasks.py.""" + +import time +import unittest + +from . import events +from . import futures +from . import tasks + + +class Dummy: + def __repr__(self): + return 'Dummy()' + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testTaskClass(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + def testTaskDecorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def testTaskRepr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def testTaskBasics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def testWait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0, 0.01) + # TODO: Test different return_when values. + + def testWaitWithException(self): + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testWaitWithTimeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.12) + + def testAsCompleted(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0, 0.01) + + def testAsCompletedWithTimeout(self): + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def testSleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def testTaskCancelSleepingTask(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.11, (t1-t0, sleepfut, doer)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..a1eace56 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,90 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_transport() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..69d67f82 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,963 @@ +"""UNIX event loop and related classes. + +The event loop can be broken up into a selector (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a selector with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + +import collections +import concurrent.futures +import errno +import heapq +import logging +import select +import socket +import ssl +import sys +import threading +import time + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import futures +from . import protocols +from . import selectors +from . import tasks +from . import transports + +try: + from socket import socketpair +except ImportError: + assert sys.platform == 'win32' + from .winsocketpair import socketpair + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + +class UnixEventLoop(events.EventLoop): + """Unix event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.info('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._ready = collections.deque() + self._scheduled = [] + self._everytime = [] + self._default_executor = None + self._signal_handlers = {} + self._make_self_pipe() + + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 0x7fffffff/1000.0 # 24 days + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, _raise_stop_error) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.make_handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(None, callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def call_every_iteration(self, callback, *args): + """Call a callback just before the loop blocks. + + The callback is called for every iteration of the loop. + """ + handler = events.make_handler(None, callback, args) + self._everytime.append(handler) + return handler + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + # TODO: Or create_connection()? Or create_client()? + @tasks.task + def create_transport(self, protocol_factory, host, port, *, ssl=False, + family=0, type=socket.SOCK_STREAM, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, type=type, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + protocol = protocol_factory() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + waiter = futures.Future() + transport = _UnixSslTransport(self, sock, protocol, sslcontext, + waiter) + yield from waiter + else: + transport = _UnixSocketTransport(self, sock, protocol) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host, port, *, + family=0, type=socket.SOCK_STREAM, proto=0, flags=0, + backlog=100): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, type=type, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + sock.listen(backlog) + sock.setblocking(False) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + return sock + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + protocol = protocol_factory() + transport = _UnixSocketTransport(self, conn, protocol) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_IN, + (handler, None, None)) + else: + self._selector.modify(fd, mask | selectors.SELECT_IN, + (handler, writer, connector)) + + return handler + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_IN + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer, connector)) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_OUT, + (None, handler, None)) + else: + self._selector.modify(fd, mask | selectors.SELECT_OUT, + (reader, handler, connector)) + return handler + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_OUT + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, connector)) + + def add_connector(self, fd, callback, *args): + """Add a connector callback. Return a Handler instance.""" + # XXX As long as SELECT_CONNECT == SELECT_OUT, set the handler + # as both writer and connector. + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.SELECT_CONNECT, + (None, handler, handler)) + else: + self._selector.modify(fd, mask | selectors.SELECT_CONNECT, + (reader, handler, handler)) + return handler + + def remove_connector(self, fd): + """Remove a connector callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + pass + else: + mask &= ~selectors.SELECT_CONNECT + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + self._sock_sendall(fut, False, sock, data) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + n = 0 + try: + if data: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_connector(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_connector(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Add everytime handlers, skipping cancelled ones. + any_cancelled = False + for handler in self._everytime: + if handler.cancelled: + any_cancelled = True + else: + self._ready.append(handler) + + # Remove cancelled everytime handlers if there are any. + if any_cancelled: + self._everytime = [handler for handler in self._everytime + if not handler.cancelled] + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.SELECT_IN and reader is not None: + self._add_callback(reader) + if mask & selectors.SELECT_OUT and writer is not None: + self._add_callback(writer) + elif mask & selectors.SELECT_CONNECT and connector is not None: + self._add_callback(connector) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self._ready: + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + + +class _UnixSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _UnixSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writable(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._waiter.set_result(None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if n < len(data): + self._buffer.append(data[n:]) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..87d54c91 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,30 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From c7ac172a06b4dc7fbb1ba69334962898b86f6a4a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 13 Jan 2013 19:07:51 -0800 Subject: [PATCH 0222/1502] New policy for running ready callbacks: alternate callbacks and I/O poll. (See python-ideas discussion.) --- tulip/events_test.py | 2 +- tulip/unix_events.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index f8b718e0..3b803836 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -379,10 +379,10 @@ def testStartServing(self): client.connect((host, port)) client.send(b'xxx') el.run_once() # This is quite mysterious, but necessary. - client.close() el.run_once() el.run_once() sock.close() + client.close() if hasattr(selectors, 'KqueueSelector'): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 69d67f82..12bb9b36 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -734,9 +734,12 @@ def _run_once(self, timeout=None): # This is the only place where callbacks are actually *called*. # All other places just add them to ready. - # TODO: Ensure this loop always finishes, even if some - # callbacks keeps registering more callbacks. - while self._ready: + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): handler = self._ready.popleft() if not handler.cancelled: try: @@ -762,6 +765,7 @@ def _read_ready(self): data = self._sock.recv(16*1024) except socket.error as exc: if exc.errno not in _TRYAGAIN: + import pdb; pdb.set_trace() self._fatal_error(exc) else: if data: From 82075f5a36cdca74c1239af63310b4658f8bf5c6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 13 Jan 2013 19:10:06 -0800 Subject: [PATCH 0223/1502] Remove pdb call accidentally checked in. --- tulip/unix_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 12bb9b36..4bbeb7c7 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -765,7 +765,6 @@ def _read_ready(self): data = self._sock.recv(16*1024) except socket.error as exc: if exc.errno not in _TRYAGAIN: - import pdb; pdb.set_trace() self._fatal_error(exc) else: if data: From d8c14fb19ae697fbd44f0bb64e031c7f3c2ea431 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Jan 2013 17:00:04 -0800 Subject: [PATCH 0224/1502] Cancelling a read/write/connect handler now removes the handler. However, the effect is delayed until more I/O happens on the FD. So don't use this if you plan to close the FD -- it will crash or hang (depending on which selector is used). --- tulip/events_test.py | 33 +++++++++++++++++++++++++++++++++ tulip/unix_events.py | 17 ++++++++++++++--- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 3b803836..6efa23f6 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -227,6 +227,25 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') + def testReaderCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handler.cancel() + if not data: + r.close() + handler = el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + def testWriterCallback(self): el = events.get_event_loop() r, w = unix_events.socketpair() @@ -252,6 +271,20 @@ def testWriterCallbackWithHandler(self): r.close() self.assertTrue(len(data) >= 200) + def testWriterCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + def sender(): + w.send(b'x'*256) + handler.cancel() + handler = el.add_writer(w.fileno(), sender) + el.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + def testSockClientOps(self): el = events.get_event_loop() sock = socket.socket() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 4bbeb7c7..e221ee73 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -717,11 +717,22 @@ def _run_once(self, timeout=None): logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) for fileobj, mask, (reader, writer, connector) in event_list: if mask & selectors.SELECT_IN and reader is not None: - self._add_callback(reader) + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) if mask & selectors.SELECT_OUT and writer is not None: - self._add_callback(writer) + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + # XXX The next elif is unreachable until selector.py + # changes to implement SELECT_CONNECT != SELECTOR_OUT. elif mask & selectors.SELECT_CONNECT and connector is not None: - self._add_callback(connector) + if connector.cancelled: + self.remove_connector(fileobj) + else: + self._add_callback(connector) # Handle 'later' callbacks that are ready. now = time.monotonic() From f0cb9958f3bf81dd1d9a0c778fb11fee54fd10ea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Jan 2013 17:05:16 -0800 Subject: [PATCH 0225/1502] Return True/False from remove_reader/writer/connector. --- tulip/events_test.py | 12 ++++++++---- tulip/unix_events.py | 9 ++++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 6efa23f6..5d635c31 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -199,7 +199,7 @@ def reader(): if data: bytes_read.append(data) else: - el.remove_reader(r.fileno()) + self.assertTrue(el.remove_reader(r.fileno())) r.close() el.add_reader(r.fileno(), reader) el.call_later(0.05, w.send, b'abc') @@ -217,7 +217,7 @@ def reader(): if data: bytes_read.append(data) else: - el.remove_reader(r.fileno()) + self.assertTrue(el.remove_reader(r.fileno())) r.close() handler = events.Handler(None, reader, ()) self.assertEqual(el.add_reader(r.fileno(), handler), handler) @@ -251,7 +251,9 @@ def testWriterCallback(self): r, w = unix_events.socketpair() w.setblocking(False) el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) - el.call_later(0.1, el.remove_writer, w.fileno()) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) el.run() w.close() data = r.recv(256*1024) @@ -264,7 +266,9 @@ def testWriterCallbackWithHandler(self): w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) self.assertEqual(el.add_writer(w.fileno(), handler), handler) - el.call_later(0.1, el.remove_writer, w.fileno()) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) el.run() w.close() data = r.recv(256*1024) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index e221ee73..5f7415b3 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -401,13 +401,14 @@ def remove_reader(self, fd): try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - pass + return False else: mask &= ~selectors.SELECT_IN if not mask: self._selector.unregister(fd) else: self._selector.modify(fd, mask, (None, writer, connector)) + return True def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" @@ -427,13 +428,14 @@ def remove_writer(self, fd): try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - pass + return False else: mask &= ~selectors.SELECT_OUT if not mask: self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, connector)) + return True def add_connector(self, fd, callback, *args): """Add a connector callback. Return a Handler instance.""" @@ -455,13 +457,14 @@ def remove_connector(self, fd): try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - pass + return False else: mask &= ~selectors.SELECT_CONNECT if not mask: self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, None)) + return True def sock_recv(self, sock, n): """XXX""" From c2d667484c360d4050b40f96071783aa0fd600dd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Jan 2013 12:01:29 -0800 Subject: [PATCH 0226/1502] Rename create_transport() to create_connection(). Remove type argument. --- tulip/events.py | 6 +++--- tulip/events_test.py | 4 ++-- tulip/http_client.py | 8 ++++---- tulip/protocols.py | 2 +- tulip/transports.py | 2 +- tulip/unix_events.py | 12 +++++++----- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index a5c3e383..f39ddb79 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -139,12 +139,12 @@ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError - def create_transport(self, protocol_factory, host, port, *, - family=0, type=0, proto=0, flags=0): + def create_connection(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): raise NotImplementedError def start_serving(self, protocol_factory, host, port, *, - family=0, type=0, proto=0, flags=0): + family=0, proto=0, flags=0): raise NotImplementedError # Ready-based callback registration methods. diff --git a/tulip/events_test.py b/tulip/events_test.py index 5d635c31..c2dd5e0b 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -388,7 +388,7 @@ def my_handler(): def testCreateTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_transport(MyProto, 'xkcd.com', 80) + f = el.create_connection(MyProto, 'xkcd.com', 80) tr, pr = el.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -398,7 +398,7 @@ def testCreateTransport(self): def testCreateSslTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_transport(MyProto, 'xkcd.com', 443, ssl=True) + f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) tr, pr = el.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) diff --git a/tulip/http_client.py b/tulip/http_client.py index d5f2e57c..d658be51 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -206,10 +206,10 @@ def validate(self, value, name, embedded_spaces_okay=False): @tulip.coroutine def connect(self): - yield from self.event_loop.create_transport(lambda: self, - self.host, - self.port, - ssl=self.ssl) + yield from self.event_loop.create_connection(lambda: self, + self.host, + self.port, + ssl=self.ssl) # TODO: A better mechanism to return all info from the # status line, all headers, and the buffer, without having # an N-tuple return value. diff --git a/tulip/protocols.py b/tulip/protocols.py index 3a00e2ee..ad294f3a 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -11,7 +11,7 @@ class Protocol: nothing (they don't raise exceptions). When the user wants to requests a transport, they pass a protocol - factory to a utility function (e.g., EventLoop.create_transport()). + factory to a utility function (e.g., EventLoop.create_connection()). When the connection is made successfully, connection_made() is called with a suitable transport object. Then data_received() diff --git a/tulip/transports.py b/tulip/transports.py index a1eace56..4aaae3c7 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -14,7 +14,7 @@ class Transport: The user never instantiates a transport directly; they call a utility function, passing it a protocol factory and other information necessary to create the transport and protocol. (E.g. - EventLoop.create_transport() or EventLoop.start_serving().) + EventLoop.create_connection() or EventLoop.start_serving().) The utility function will asynchronously create a transport and a protocol and hook them up by calling the protocol's diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 5f7415b3..3a8c2e5f 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -290,11 +290,12 @@ def getnameinfo(self, sockaddr, flags=0): # TODO: Or create_connection()? Or create_client()? @tasks.task - def create_transport(self, protocol_factory, host, port, *, ssl=False, - family=0, type=socket.SOCK_STREAM, proto=0, flags=0): + def create_connection(self, protocol_factory, host, port, *, ssl=False, + family=0, proto=0, flags=0): """XXX""" infos = yield from self.getaddrinfo(host, port, - family=family, type=type, + family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) if not infos: raise socket.error('getaddrinfo() returned empty list') @@ -337,11 +338,12 @@ def create_transport(self, protocol_factory, host, port, *, ssl=False, # TODO: Or create_server()? @tasks.task def start_serving(self, protocol_factory, host, port, *, - family=0, type=socket.SOCK_STREAM, proto=0, flags=0, + family=0, proto=0, flags=0, backlog=100): """XXX""" infos = yield from self.getaddrinfo(host, port, - family=family, type=type, + family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) if not infos: raise socket.error('getaddrinfo() returned empty list') From 03b4263b141114f02fb6696201c8cd4694800d69 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Jan 2013 20:15:22 -0800 Subject: [PATCH 0227/1502] Fail later when ssl cannot be imported. --- .hgeol | 2 ++ tulip/unix_events.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 .hgeol diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 3a8c2e5f..ccce71f0 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -20,7 +20,10 @@ import logging import select import socket -import ssl +try: + import ssl +except ImportError: + ssl = None import sys import threading import time From 546a192e772d2b5a55d963f34a1d644159e23aaf Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 18 Jan 2013 12:01:53 +0000 Subject: [PATCH 0228/1502] Fix inequalities in assertions --- tulip/tasks_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index a214468e..c99cb8f1 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -96,7 +96,7 @@ def foo(): t0 = time.monotonic() res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() - self.assertTrue(t1-t0, 0.01) + self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. def testWaitWithException(self): @@ -161,7 +161,7 @@ def foo(): t0 = time.monotonic() res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() - self.assertTrue(t1-t0, 0.01) + self.assertTrue(t1-t0 <= 0.01) def testAsCompletedWithTimeout(self): a = tasks.sleep(0.1, 'a') From 84124f3d725c9931249d083e78f43fcda91c383a Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 18 Jan 2013 13:03:09 +0000 Subject: [PATCH 0229/1502] Minimal chages to make tests pass on Windows --- tulip/events_test.py | 6 +++++- tulip/selectors.py | 10 +++++++++- tulip/tasks_test.py | 4 ++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index c2dd5e0b..acb15ba8 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -6,6 +6,7 @@ import select import signal import socket +import sys import threading import time import unittest @@ -326,6 +327,7 @@ def testSockAccept(self): conn.close() listener.close() + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') def testAddSignalHandler(self): caught = 0 def my_handler(): @@ -359,6 +361,7 @@ def my_handler(): # Removing again returns False. self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + @unittest.skipIf(sys.platform == 'win32', 'Unix only') def testCancelSignalHandler(self): # Cancelling the handler should remove it (eventually). caught = 0 @@ -372,6 +375,7 @@ def my_handler(): el.run_once() self.assertEqual(caught, 0) + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def testSignalHandlingWhileSelecting(self): # Test with a signal actually arriving during a select() call. caught = 0 @@ -413,7 +417,7 @@ def testStartServing(self): host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() - client.connect((host, port)) + client.connect(('127.0.0.1', port)) client.send(b'xxx') el.run_once() # This is quite mysterious, but necessary. el.run_once() diff --git a/tulip/selectors.py b/tulip/selectors.py index 08762f86..158876cf 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -4,6 +4,7 @@ """ import logging +import sys from select import * @@ -216,7 +217,7 @@ def unregister(self, fileobj): def select(self, timeout=None): try: - r, w, _ = select(self._readers, self._writers, [], timeout) + r, w, _ = self._select(self._readers, self._writers, [], timeout) except InterruptedError: # A signal arrived. Don't die, just return no events. return [] @@ -234,6 +235,13 @@ def select(self, timeout=None): ready.append((key.fileobj, events, key.data)) return ready + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + if 'poll' in globals(): diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index c99cb8f1..a7381450 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -134,7 +134,7 @@ def foo(): res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.1) - self.assertTrue(t1-t0 <= 0.12) + self.assertTrue(t1-t0 <= 0.13) def testAsCompleted(self): @tasks.coroutine @@ -226,7 +226,7 @@ def doit(): doer = doit() self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') t1 = time.monotonic() - self.assertTrue(0.09 <= t1-t0 <= 0.11, (t1-t0, sleepfut, doer)) + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) if __name__ == '__main__': From 42a19fac06ea01e549ad29da2388a243bc95bb0e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Jan 2013 13:51:48 -0800 Subject: [PATCH 0230/1502] Fix typo in _UnixSocketTransport.write(). --- tulip/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index ccce71f0..685a576f 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -812,7 +812,7 @@ def write(self, data): return if n: data = data[n:] - self.add_writer(self._sock.fileno(), self._write_ready) + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) self._buffer.append(data) def _write_ready(self): From 928304d94deb0557b0cdf3edf7ba672e876ec4b0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Jan 2013 14:17:55 -0800 Subject: [PATCH 0231/1502] CHECKPOINT: subprocess transport. --- tulip/subprocess_test.py | 48 ++++++++++++ tulip/subprocess_transport.py | 133 ++++++++++++++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 tulip/subprocess_test.py create mode 100644 tulip/subprocess_transport.py diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py new file mode 100644 index 00000000..4eb24e41 --- /dev/null +++ b/tulip/subprocess_test.py @@ -0,0 +1,48 @@ +"""Tests for subprocess_transport.py.""" + +import unittest + +from . import events +from . import protocols +from . import subprocess_transport + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + def data_received(self, data): + print('received:', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testUnixSubprocess(self): + p = MyProto() + t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..bcf859b1 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,133 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) From 878d692a9612ba15ddc6b3c123d629151dfb38dd Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Sat, 19 Jan 2013 05:16:48 -0800 Subject: [PATCH 0232/1502] Add docstrings to Future, as well as a test for add_done_callback. For the latter, a new keyword argument - event_loop - is added to the constructor. --- tulip/futures.py | 84 ++++++++++++++++++++++++++++++++++--------- tulip/futures_test.py | 24 +++++++++++++ 2 files changed, 92 insertions(+), 16 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index d679e31f..6b90727f 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -57,13 +57,20 @@ class Future: _result = None _exception = None - def __init__(self): - """XXX""" - self._event_loop = events.get_event_loop() + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop self._callbacks = [] def __repr__(self): - """XXX""" res = self.__class__.__name__ if self._state == _FINISHED: if self._exception is not None: @@ -83,7 +90,12 @@ def __repr__(self): return res def cancel(self): - """XXX""" + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ if self._state != _PENDING: return False self._state = _CANCELLED @@ -91,7 +103,11 @@ def cancel(self): return True def _schedule_callbacks(self): - """XXX""" + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ callbacks = self._callbacks[:] if not callbacks: return @@ -100,19 +116,33 @@ def _schedule_callbacks(self): self._event_loop.call_soon(callback, self) def cancelled(self): - """XXX""" + """Return True if the future was cancelled.""" return self._state == _CANCELLED def running(self): - """XXX""" + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ return False # We don't have a running state. def done(self): - """XXX""" + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ return self._state != _PENDING def result(self, timeout=0): - """XXX""" + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: @@ -124,7 +154,13 @@ def result(self, timeout=0): return self._result def exception(self, timeout=0): - """XXX""" + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: @@ -134,16 +170,24 @@ def exception(self, timeout=0): return self._exception def add_done_callback(self, fn): - """XXX""" + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ if self._state != _PENDING: self._event_loop.call_soon(fn, self) else: self._callbacks.append(fn) - # New method not in PPE 3148. + # New method not in PEP 3148. def remove_done_callback(self, fn): - """XXX""" + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ filtered_callbacks = [f for f in self._callbacks if f != fn] removed_count = len(self._callbacks) - len(filtered_callbacks) if removed_count: @@ -153,7 +197,11 @@ def remove_done_callback(self, fn): # So-called internal methods (note: no set_running_or_notify_cancel()). def set_result(self, result): - """XXX""" + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ if self._state != _PENDING: raise InvalidStateError self._result = result @@ -161,7 +209,11 @@ def set_result(self, result): self._schedule_callbacks() def set_exception(self, exception): - """XXX""" + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ if self._state != _PENDING: raise InvalidStateError self._exception = exception diff --git a/tulip/futures_test.py b/tulip/futures_test.py index d45b4410..c4951772 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -13,6 +13,10 @@ def testInitialState(self): self.assertFalse(f.running()) self.assertFalse(f.done()) + def testInitEventLoopPositional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + def testCancel(self): f = futures.Future() self.assertTrue(f.cancel()) @@ -50,6 +54,26 @@ def testException(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) + def testDoneCallbacks(self): + class MyEventLoop: + def call_soon(self, fn, future): + fn(future) + + bag = [] + def make_callback(num): + def bag_appender(future): + bag.append(num) + return bag_appender + + f = futures.Future(event_loop=MyEventLoop()) + f.add_done_callback(make_callback(42)) + f.add_done_callback(make_callback(17)) + + self.assertEquals(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + def testYieldFromTwice(self): f = futures.Future() def fixture(): From 20ba0707720681b6e37881c3d45aed6b006cca46 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 19 Jan 2013 11:08:54 -0800 Subject: [PATCH 0233/1502] Change key_from_fd to log warning and return None if key not found. --- tulip/selectors.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index 158876cf..d3f16a30 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -190,7 +190,8 @@ def _key_from_fd(self, fd): try: return self._fd_to_key[fd] except KeyError: - raise RuntimeError("No key found for fd {}".format(fd)) + logging.warn('No key found for fd %r', fd) + return None class SelectSelector(_BaseSelector): @@ -232,7 +233,8 @@ def select(self, timeout=None): events |= SELECT_OUT key = self._key_from_fd(fd) - ready.append((key.fileobj, events, key.data)) + if key: + ready.append((key.fileobj, events, key.data)) return ready if sys.platform == 'win32': @@ -283,7 +285,8 @@ def select(self, timeout=None): events |= SELECT_IN key = self._key_from_fd(fd) - ready.append((key.fileobj, events, key.data)) + if key: + ready.append((key.fileobj, events, key.data)) return ready @@ -328,7 +331,8 @@ def select(self, timeout=None): events |= SELECT_IN key = self._key_from_fd(fd) - ready.append((key.fileobj, events, key.data)) + if key: + ready.append((key.fileobj, events, key.data)) return ready def close(self): @@ -384,7 +388,8 @@ def select(self, timeout=None): events |= SELECT_OUT key = self._key_from_fd(fd) - ready.append((key.fileobj, events, key.data)) + if key: + ready.append((key.fileobj, events, key.data)) return ready def close(self): From c786bce1e5d94b83876fcbedab3a1588a0279da8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 19 Jan 2013 11:19:12 -0800 Subject: [PATCH 0234/1502] Skip ssl test if ssl module not found. --- tulip/events_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index acb15ba8..d805c32a 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -6,6 +6,10 @@ import select import signal import socket +try: + import ssl +except ImportError: + ssl = None import sys import threading import time @@ -399,6 +403,7 @@ def testCreateTransport(self): el.run() self.assertTrue(pr.nbytes > 0) + @unittest.skipIf(ssl is None, 'No ssl module') def testCreateSslTransport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! From d9562247c60152e5100347201280fc830f63e638 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Sat, 19 Jan 2013 13:39:10 -0800 Subject: [PATCH 0235/1502] Add more unit tests for Future.(add|remove)_done_callback --- tulip/futures_test.py | 91 +++++++++++++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 20 deletions(-) diff --git a/tulip/futures_test.py b/tulip/futures_test.py index c4951772..6c59edc2 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -54,26 +54,6 @@ def testException(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testDoneCallbacks(self): - class MyEventLoop: - def call_soon(self, fn, future): - fn(future) - - bag = [] - def make_callback(num): - def bag_appender(future): - bag.append(num) - return bag_appender - - f = futures.Future(event_loop=MyEventLoop()) - f.add_done_callback(make_callback(42)) - f.add_done_callback(make_callback(17)) - - self.assertEquals(bag, []) - f.set_result('foo') - self.assertEqual(bag, [42, 17]) - self.assertEqual(f.result(), 'foo') - def testYieldFromTwice(self): f = futures.Future() def fixture(): @@ -91,5 +71,76 @@ def fixture(): self.assertEqual(next(g), ('C', 42)) # yield 'C', y. +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def testCallbacksInvokedOnSetResult(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def testCallbacksInvokedOnSetException(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def testRemoveDoneCallback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + if __name__ == '__main__': unittest.main() From f02fd00e09a79be339fbeff9f223c5fea768d068 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Sun, 20 Jan 2013 05:28:22 -0800 Subject: [PATCH 0236/1502] Increase test coverage of futures.py to 100% --- tulip/futures_test.py | 64 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tulip/futures_test.py b/tulip/futures_test.py index 6c59edc2..7834fec8 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -5,6 +5,10 @@ from . import futures +def _fakefunc(f): + return f + + class FutureTests(unittest.TestCase): def testInitialState(self): @@ -31,6 +35,9 @@ def testCancel(self): def testResult(self): f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + f.set_result(42) self.assertFalse(f.cancelled()) self.assertFalse(f.running()) @@ -44,6 +51,9 @@ def testResult(self): def testException(self): exc = RuntimeError() f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + f.set_exception(exc) self.assertFalse(f.cancelled()) self.assertFalse(f.running()) @@ -70,6 +80,60 @@ def fixture(): # The second "yield from f" does not yield f. self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + def testRepr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def testCopyState(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + # A fake event loop for tests. All it does is implement a call_soon method # that immediately invokes the given function. From e964a1fce1e9d18ea1f2e06682c9186b21e71fb1 Mon Sep 17 00:00:00 2001 From: Eli Bendersky Date: Sun, 20 Jan 2013 05:31:22 -0800 Subject: [PATCH 0237/1502] Remove unnecessary TODO comment --- tulip/futures.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index 6b90727f..e79999fc 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -47,11 +47,6 @@ class Future: (In Python 3.4 or later we may be able to unify the implementations.) """ - # TODO: PEP 3148 seems to say that cancel() does not call the - # callbacks, but set_running_or_notify_cancel() does (if cancel() - # was called). Here, cancel() schedules the callbacks, and - # set_running_or_notify_cancel() is not supported. - # Class variables serving as defaults for instance variables. _state = _PENDING _result = None From f47f3656aeae40dd53f20be83f43a083e2898046 Mon Sep 17 00:00:00 2001 From: Charles-Fran?ois Natali Date: Sun, 20 Jan 2013 19:04:56 +0100 Subject: [PATCH 0238/1502] Rename SELECT_(IN|OUT|CONNECT) to EVENT_(READ|WRITE|CONNECT). --- tulip/selectors.py | 50 ++++++++++++++++++++++---------------------- tulip/unix_events.py | 28 ++++++++++++------------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index d3f16a30..8af5e5fc 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -11,11 +11,11 @@ # generic events, that must be mapped to implementation-specific ones # read event -SELECT_IN = (1 << 0) +EVENT_READ = (1 << 0) # write event -SELECT_OUT = (1 << 1) +EVENT_WRITE = (1 << 1) # connect event -SELECT_CONNECT = SELECT_OUT +EVENT_CONNECT = EVENT_WRITE def _fileobj_to_fd(fileobj): @@ -79,13 +79,13 @@ def register(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data Returns: SelectorKey instance """ - if (not events) or (events & ~(SELECT_IN|SELECT_OUT)): + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): raise ValueError("Invalid events: {}".format(events)) if fileobj in self._fileobj_to_key: @@ -118,7 +118,7 @@ def modify(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of SELECT_IN|SELECT_OUT) + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data """ self.unregister(fileobj) @@ -138,7 +138,7 @@ def select(self, timeout=None): Returns: list of (fileobj, events, attached data) for ready file objects - `events` is a bitwise mask of SELECT_IN|SELECT_OUT + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE """ raise NotImplementedError() @@ -204,9 +204,9 @@ def __init__(self): def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) - if events & SELECT_IN: + if events & EVENT_READ: self._readers.add(key.fd) - if events & SELECT_OUT: + if events & EVENT_WRITE: self._writers.add(key.fd) return key @@ -228,9 +228,9 @@ def select(self, timeout=None): for fd in r | w: events = 0 if fd in r: - events |= SELECT_IN + events |= EVENT_READ if fd in w: - events |= SELECT_OUT + events |= EVENT_WRITE key = self._key_from_fd(fd) if key: @@ -257,9 +257,9 @@ def __init__(self): def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) poll_events = 0 - if events & SELECT_IN: + if events & EVENT_READ: poll_events |= POLLIN - if events & SELECT_OUT: + if events & EVENT_WRITE: poll_events |= POLLOUT self._poll.register(key.fd, poll_events) return key @@ -280,9 +280,9 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~POLLIN: - events |= SELECT_OUT + events |= EVENT_WRITE if event & ~POLLOUT: - events |= SELECT_IN + events |= EVENT_READ key = self._key_from_fd(fd) if key: @@ -302,9 +302,9 @@ def __init__(self): def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) epoll_events = 0 - if events & SELECT_IN: + if events & EVENT_READ: epoll_events |= EPOLLIN - if events & SELECT_OUT: + if events & EVENT_WRITE: epoll_events |= EPOLLOUT self._epoll.register(key.fd, epoll_events) return key @@ -326,9 +326,9 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~EPOLLIN: - events |= SELECT_OUT + events |= EVENT_WRITE if event & ~EPOLLOUT: - events |= SELECT_IN + events |= EVENT_READ key = self._key_from_fd(fd) if key: @@ -352,20 +352,20 @@ def __init__(self): def unregister(self, fileobj): key = super().unregister(fileobj) mask = 0 - if key.events & SELECT_IN: + if key.events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) - if key.events & SELECT_OUT: + if key.events & EVENT_WRITE: kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) - if events & SELECT_IN: + if events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - if events & SELECT_OUT: + if events & EVENT_WRITE: kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return key @@ -383,9 +383,9 @@ def select(self, timeout=None): flag = kev.filter events = 0 if flag == KQ_FILTER_READ: - events |= SELECT_IN + events |= EVENT_READ if flag == KQ_FILTER_WRITE: - events |= SELECT_OUT + events |= EVENT_WRITE key = self._key_from_fd(fd) if key: diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 685a576f..fc32cba3 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -393,10 +393,10 @@ def add_reader(self, fd, callback, *args): try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - self._selector.register(fd, selectors.SELECT_IN, + self._selector.register(fd, selectors.EVENT_READ, (handler, None, None)) else: - self._selector.modify(fd, mask | selectors.SELECT_IN, + self._selector.modify(fd, mask | selectors.EVENT_READ, (handler, writer, connector)) return handler @@ -408,7 +408,7 @@ def remove_reader(self, fd): except KeyError: return False else: - mask &= ~selectors.SELECT_IN + mask &= ~selectors.EVENT_READ if not mask: self._selector.unregister(fd) else: @@ -421,10 +421,10 @@ def add_writer(self, fd, callback, *args): try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - self._selector.register(fd, selectors.SELECT_OUT, + self._selector.register(fd, selectors.EVENT_WRITE, (None, handler, None)) else: - self._selector.modify(fd, mask | selectors.SELECT_OUT, + self._selector.modify(fd, mask | selectors.EVENT_WRITE, (reader, handler, connector)) return handler @@ -435,7 +435,7 @@ def remove_writer(self, fd): except KeyError: return False else: - mask &= ~selectors.SELECT_OUT + mask &= ~selectors.EVENT_WRITE if not mask: self._selector.unregister(fd) else: @@ -444,16 +444,16 @@ def remove_writer(self, fd): def add_connector(self, fd, callback, *args): """Add a connector callback. Return a Handler instance.""" - # XXX As long as SELECT_CONNECT == SELECT_OUT, set the handler + # XXX As long as EVENT_CONNECT == EVENT_WRITE, set the handler # as both writer and connector. handler = events.make_handler(None, callback, args) try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: - self._selector.register(fd, selectors.SELECT_CONNECT, + self._selector.register(fd, selectors.EVENT_CONNECT, (None, handler, handler)) else: - self._selector.modify(fd, mask | selectors.SELECT_CONNECT, + self._selector.modify(fd, mask | selectors.EVENT_CONNECT, (reader, handler, handler)) return handler @@ -464,7 +464,7 @@ def remove_connector(self, fd): except KeyError: return False else: - mask &= ~selectors.SELECT_CONNECT + mask &= ~selectors.EVENT_CONNECT if not mask: self._selector.unregister(fd) else: @@ -724,19 +724,19 @@ def _run_once(self, timeout=None): level = logging.DEBUG logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) for fileobj, mask, (reader, writer, connector) in event_list: - if mask & selectors.SELECT_IN and reader is not None: + if mask & selectors.EVENT_READ and reader is not None: if reader.cancelled: self.remove_reader(fileobj) else: self._add_callback(reader) - if mask & selectors.SELECT_OUT and writer is not None: + if mask & selectors.EVENT_WRITE and writer is not None: if writer.cancelled: self.remove_writer(fileobj) else: self._add_callback(writer) # XXX The next elif is unreachable until selector.py - # changes to implement SELECT_CONNECT != SELECTOR_OUT. - elif mask & selectors.SELECT_CONNECT and connector is not None: + # changes to implement EVENT_CONNECT != EVENT_WRITE. + elif mask & selectors.EVENT_CONNECT and connector is not None: if connector.cancelled: self.remove_connector(fileobj) else: From d1d51f327dfbba11f68a8e36598a8de9f8a45e6c Mon Sep 17 00:00:00 2001 From: Charles-Fran?ois Natali Date: Sun, 20 Jan 2013 19:17:51 +0100 Subject: [PATCH 0239/1502] The fix for the ECONNRESET error during testStartServing has been committed "accidently" in a earlier commit: add a comment explaining the logic behind it. --- tulip/events_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index d805c32a..d0061cf5 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -428,6 +428,8 @@ def testStartServing(self): el.run_once() el.run_once() sock.close() + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket client.close() From 69105d6769092459d2423f78b0f3b9ea6067afb4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Jan 2013 11:07:42 -0800 Subject: [PATCH 0240/1502] Whitespace fix. --- tulip/subprocess_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index bcf859b1..721013f8 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -21,7 +21,7 @@ def __init__(self, protocol, args): self._eof = False rstdin, self._wstdin = os.pipe() self._rstdout, wstdout = os.pipe() - + # TODO: This is incredibly naive. Should look at # subprocess.py for all the precautions around fork/exec. pid = os.fork() From 79e2d2306346b8fd1cb6d8441fe8c5e8cb85e40b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Jan 2013 11:17:13 -0800 Subject: [PATCH 0241/1502] Don't return from create_connection() until after connection_made() is called. --- tulip/unix_events.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fc32cba3..09544853 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -291,7 +291,6 @@ def getaddrinfo(self, host, port, *, def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - # TODO: Or create_connection()? Or create_client()? @tasks.task def create_connection(self, protocol_factory, host, port, *, ssl=False, family=0, proto=0, flags=0): @@ -328,14 +327,14 @@ def create_connection(self, protocol_factory, host, port, *, ssl=False, raise socket.error('Multiple exceptions: {}'.format( ', '.join(str(exc) for exc in exceptions))) protocol = protocol_factory() + waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - waiter = futures.Future() transport = _UnixSslTransport(self, sock, protocol, sslcontext, waiter) - yield from waiter else: - transport = _UnixSocketTransport(self, sock, protocol) + transport = _UnixSocketTransport(self, sock, protocol, waiter) + yield from waiter return transport, protocol # TODO: Or create_server()? @@ -770,7 +769,7 @@ def _run_once(self, timeout=None): class _UnixSocketTransport(transports.Transport): - def __init__(self, event_loop, sock, protocol): + def __init__(self, event_loop, sock, protocol, waiter=None): self._event_loop = event_loop self._sock = sock self._protocol = protocol @@ -778,6 +777,8 @@ def __init__(self, event_loop, sock, protocol): self._closing = False # Set when close() called. self._event_loop.add_reader(self._sock.fileno(), self._read_ready) self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) def _read_ready(self): try: @@ -902,7 +903,7 @@ def _on_handshake(self): self._event_loop.add_reader(fd, self._on_ready) self._event_loop.add_writer(fd, self._on_ready) self._event_loop.call_soon(self._protocol.connection_made, self) - self._waiter.set_result(None) + self._event_loop.call_soon(self._waiter.set_result, None) def _on_ready(self): # Because of renegotiations (?), there's no difference between From eb612331e9aa687b36ab3d5b397732f28c1ed461 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Jan 2013 13:29:27 -0800 Subject: [PATCH 0242/1502] Support -x flag to exclude tests. --- runtests.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/runtests.py b/runtests.py index fde87f08..1eef652b 100644 --- a/runtests.py +++ b/runtests.py @@ -21,7 +21,7 @@ assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' -def load_tests(patterns=()): +def load_tests(includes=(), excludes=()): mods = ['events', 'futures', 'tasks'] test_mods = ['%s_test' % name for name in mods] tulip = __import__('tulip', fromlist=test_mods) @@ -34,26 +34,38 @@ def load_tests(patterns=()): if name.endswith('Tests'): test_module = getattr(mod, name) tests = loader.loadTestsFromTestCase(test_module) - if patterns: + if includes: tests = [test for test in tests - if any(re.search(pat, test.id()) for pat in patterns)] + if any(re.search(pat, test.id()) for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) for pat in excludes)] suite.addTests(tests) return suite def main(): - patterns = [] + excludes = [] + includes = [] + patterns = includes # A reference. v = 1 for arg in sys.argv[1:]: if arg.startswith('-v'): v += arg.count('v') elif arg == '-q': v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes elif arg and not arg.startswith('-'): patterns.append(arg) - result = unittest.TextTestRunner(verbosity=v).run(load_tests(patterns)) + tests = load_tests(includes, excludes) + result = unittest.TextTestRunner(verbosity=v).run(tests) sys.exit(not result.wasSuccessful()) From 7a445d2aa32a7b28c2e5fe7ff7a65b93b5355f91 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Jan 2013 14:30:47 -0800 Subject: [PATCH 0243/1502] Separate EVENT_WRITE and EVENT_CONNECT. --- tulip/selectors.py | 61 +++++++++++++++++++++++++++++--------------- tulip/unix_events.py | 33 ++++++++++++++++-------- 2 files changed, 63 insertions(+), 31 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index 8af5e5fc..05434630 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -15,7 +15,12 @@ # write event EVENT_WRITE = (1 << 1) # connect event -EVENT_CONNECT = EVENT_WRITE +EVENT_CONNECT = (1 << 2) + +# In most cases we treat EVENT_WRITE and EVENT_CONNECT as aliases for +# each other, and in fact we return both flags when a FD is found +# either writable or connectable. The distinction is necessary +# only for poll() on Windows. def _fileobj_to_fd(fileobj): @@ -79,15 +84,20 @@ def register(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) data -- attached data Returns: SelectorKey instance """ - if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE|EVENT_CONNECT)): raise ValueError("Invalid events: {}".format(events)) + if events & (EVENT_WRITE|EVENT_CONNECT) == (EVENT_WRITE|EVENT_CONNECT): + raise ValueError("WRITE and CONNECT are mutually exclusive. " + "Invalid events: {}".format(events)) + if fileobj in self._fileobj_to_key: raise ValueError("{!r} is already registered".format(fileobj)) @@ -118,11 +128,18 @@ def modify(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) data -- attached data """ - self.unregister(fileobj) - self.register(fileobj, events, data) + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + self.register(fileobj, events, data) def select(self, timeout=None): """Perform the actual selection, until some monitored file objects are @@ -138,7 +155,7 @@ def select(self, timeout=None): Returns: list of (fileobj, events, attached data) for ready file objects - `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE|EVENT_CONNECT """ raise NotImplementedError() @@ -206,7 +223,7 @@ def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: self._readers.add(key.fd) - if events & EVENT_WRITE: + if events & (EVENT_WRITE|EVENT_CONNECT): self._writers.add(key.fd) return key @@ -230,11 +247,11 @@ def select(self, timeout=None): if fd in r: events |= EVENT_READ if fd in w: - events |= EVENT_WRITE + events |= EVENT_WRITE|EVENT_CONNECT key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events, key.data)) + ready.append((key.fileobj, events & key.events, key.data)) return ready if sys.platform == 'win32': @@ -247,6 +264,10 @@ def _select(self, r, w, _, timeout=None): if 'poll' in globals(): + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + class PollSelector(_BaseSelector): """Poll-based selector.""" @@ -259,7 +280,7 @@ def register(self, fileobj, events, data=None): poll_events = 0 if events & EVENT_READ: poll_events |= POLLIN - if events & EVENT_WRITE: + if events & (EVENT_WRITE|EVENT_CONNECT): poll_events |= POLLOUT self._poll.register(key.fd, poll_events) return key @@ -280,13 +301,13 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~POLLIN: - events |= EVENT_WRITE + events |= EVENT_WRITE|EVENT_CONNECT if event & ~POLLOUT: events |= EVENT_READ key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events, key.data)) + ready.append((key.fileobj, events & key.events, key.data)) return ready @@ -304,7 +325,7 @@ def register(self, fileobj, events, data=None): epoll_events = 0 if events & EVENT_READ: epoll_events |= EPOLLIN - if events & EVENT_WRITE: + if events & (EVENT_WRITE|EVENT_CONNECT): epoll_events |= EPOLLOUT self._epoll.register(key.fd, epoll_events) return key @@ -326,13 +347,13 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~EPOLLIN: - events |= EVENT_WRITE + events |= EVENT_WRITE|EVENT_CONNECT if event & ~EPOLLOUT: events |= EVENT_READ key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events, key.data)) + ready.append((key.fileobj, events & key.events, key.data)) return ready def close(self): @@ -355,7 +376,7 @@ def unregister(self, fileobj): if key.events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) - if key.events & EVENT_WRITE: + if key.events & (EVENT_WRITE|EVENT_CONNECT): kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key @@ -365,7 +386,7 @@ def register(self, fileobj, events, data=None): if events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - if events & EVENT_WRITE: + if events & (EVENT_WRITE|EVENT_CONNECT): kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return key @@ -385,11 +406,11 @@ def select(self, timeout=None): if flag == KQ_FILTER_READ: events |= EVENT_READ if flag == KQ_FILTER_WRITE: - events |= EVENT_WRITE + events |= EVENT_WRITE|EVENT_CONNECT key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events, key.data)) + ready.append((key.fileobj, events & key.events, key.data)) return ready def close(self): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 09544853..d23c0163 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -423,8 +423,10 @@ def add_writer(self, fd, callback, *args): self._selector.register(fd, selectors.EVENT_WRITE, (None, handler, None)) else: + # Remove connector. + mask &= ~selectors.EVENT_CONNECT self._selector.modify(fd, mask | selectors.EVENT_WRITE, - (reader, handler, connector)) + (reader, handler, None)) return handler def remove_writer(self, fd): @@ -434,26 +436,36 @@ def remove_writer(self, fd): except KeyError: return False else: - mask &= ~selectors.EVENT_WRITE + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) if not mask: self._selector.unregister(fd) else: - self._selector.modify(fd, mask, (reader, None, connector)) + self._selector.modify(fd, mask, (reader, None, None)) return True + # NOTE: add_connector() and add_writer() are mutually exclusive. + # While you can independently manipulate readers and writers, + # adding a connector for a particular FD automatically removes the + # writer for that FD, and vice versa, and removing a writer or a + # connector actually removes both writer and connector. This is + # because in most cases writers and connectors use the same mode + # for the platform polling function; the distinction is only + # important for PollSelector() on Windows. + def add_connector(self, fd, callback, *args): """Add a connector callback. Return a Handler instance.""" - # XXX As long as EVENT_CONNECT == EVENT_WRITE, set the handler - # as both writer and connector. handler = events.make_handler(None, callback, args) try: mask, (reader, writer, connector) = self._selector.get_info(fd) except KeyError: self._selector.register(fd, selectors.EVENT_CONNECT, - (None, handler, handler)) + (None, None, handler)) else: + # Remove writer. + mask &= ~selectors.EVENT_WRITE self._selector.modify(fd, mask | selectors.EVENT_CONNECT, - (reader, handler, handler)) + (reader, None, handler)) return handler def remove_connector(self, fd): @@ -463,7 +475,8 @@ def remove_connector(self, fd): except KeyError: return False else: - mask &= ~selectors.EVENT_CONNECT + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) if not mask: self._selector.unregister(fd) else: @@ -733,8 +746,6 @@ def _run_once(self, timeout=None): self.remove_writer(fileobj) else: self._add_callback(writer) - # XXX The next elif is unreachable until selector.py - # changes to implement EVENT_CONNECT != EVENT_WRITE. elif mask & selectors.EVENT_CONNECT and connector is not None: if connector.cancelled: self.remove_connector(fileobj) @@ -888,7 +899,7 @@ def _on_handshake(self): self._event_loop.add_reader(fd, self._on_handshake) return except ssl.SSLWantWriteError: - self._event_loop.add_writable(fd, self._on_handshake) + self._event_loop.add_writer(fd, self._on_handshake) return except Exception as exc: self._sslsock.close() From fb25d9012471e095db73bb5d19c39dc17a42669d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Jan 2013 08:32:21 -0800 Subject: [PATCH 0244/1502] Cancel handlers in remove_reader() etc. --- tulip/unix_events.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index d23c0163..bf58d22e 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -412,6 +412,8 @@ def remove_reader(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (None, writer, connector)) + if reader is not None: + reader.cancel() return True def add_writer(self, fd, callback, *args): @@ -442,6 +444,10 @@ def remove_writer(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() return True # NOTE: add_connector() and add_writer() are mutually exclusive. @@ -481,6 +487,10 @@ def remove_connector(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() return True def sock_recv(self, sock, n): From 98eb0c7155ec57bce37ac073f8fdfc091cc73cc5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Jan 2013 08:37:57 -0800 Subject: [PATCH 0245/1502] Close ssock/csock FDs upon event loop close(). --- tulip/unix_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index bf58d22e..cea3ba95 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -83,7 +83,7 @@ def __init__(self, selector=None): if selector is None: # pick the best selector class for the platform selector = selectors.Selector() - logging.info('Using selector: %s', selector.__class__.__name__) + logging.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._ready = collections.deque() self._scheduled = [] @@ -96,6 +96,8 @@ def close(self): if self._selector is not None: self._selector.close() self._selector = None + self._ssock.close() + self._csock.close() def _make_self_pipe(self): # A self-socket, really. :-) From 27c403531670f52cad8388aaa2a13a658f753fd5 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 21 Jan 2013 20:34:38 +0000 Subject: [PATCH 0246/1502] New experimental iocp branch. --- .hgeol | 2 + .hgignore | 8 + Makefile | 30 + NOTES | 130 +++++ README | 30 + TODO | 165 ++++++ check.py | 40 ++ crawl.py | 133 +++++ curl.py | 31 + old/Makefile | 16 + old/echoclt.py | 79 +++ old/echosvr.py | 60 ++ old/http_client.py | 78 +++ old/http_server.py | 68 +++ old/main.py | 134 +++++ old/p3time.py | 47 ++ old/polling.py | 535 ++++++++++++++++++ old/scheduling.py | 354 ++++++++++++ old/sockets.py | 348 ++++++++++++ old/transports.py | 496 ++++++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 +++ overlapped.c | 997 ++++++++++++++++++++++++++++++++ runtests.py | 73 +++ setup.cfg | 2 + setup.py | 4 + tulip/TODO | 26 + tulip/__init__.py | 14 + tulip/events.py | 287 ++++++++++ tulip/events_test.py | 512 +++++++++++++++++ tulip/futures.py | 240 ++++++++ tulip/futures_test.py | 210 +++++++ tulip/http_client.py | 292 ++++++++++ tulip/iocpsockets.py | 360 ++++++++++++ tulip/protocols.py | 58 ++ tulip/selectors.py | 434 ++++++++++++++ tulip/subprocess_test.py | 48 ++ tulip/subprocess_transport.py | 133 +++++ tulip/tasks.py | 277 +++++++++ tulip/tasks_test.py | 233 ++++++++ tulip/transports.py | 90 +++ tulip/unix_events.py | 1001 +++++++++++++++++++++++++++++++++ tulip/winsocketpair.py | 30 + 43 files changed, 8198 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/events.py create mode 100644 tulip/events_test.py create mode 100644 tulip/futures.py create mode 100644 tulip/futures_test.py create mode 100644 tulip/http_client.py create mode 100644 tulip/iocpsockets.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selectors.py create mode 100644 tulip/subprocess_test.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/tasks_test.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..42309f0c --- /dev/null +++ b/.hgignore @@ -0,0 +1,8 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..d11e9716 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +PYTHON=python3 +COVERAGE=coverage3 +NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +cov coverage: + $(COVERAGE) run --branch runtests.py -v $(FLAGS) + $(COVERAGE) html $(NONTESTS) + $(COVERAGE) report -m $(NONTESTS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf __pycache__ */__pycache__ + rm -f *.py[co] */*.py[co] + rm -f *~ */*~ + rm -f .*~ */.*~ + rm -f @* */@* + rm -f '#'*'#' */'#'*'#' + rm -f *.orig */*.orig + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..6f41578e --- /dev/null +++ b/NOTES @@ -0,0 +1,130 @@ +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..c1c86a54 --- /dev/null +++ b/README @@ -0,0 +1,30 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (after installing coverage3, see below): + - make coverage + +To install coverage3 (coverage.py for Python 3), you need: + - Distribute (http://packages.python.org/distribute/) + - Coverage (http://nedbatchelder.com/code/coverage/) + What worked for me: + - curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - + - cd coveragepy + - python3 setup.py install + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..b9559ef0 --- /dev/null +++ b/TODO @@ -0,0 +1,165 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Implement various lock styles a la threading.py. + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..f0aa9a66 --- /dev/null +++ b/check.py @@ -0,0 +1,40 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..a1303fcb --- /dev/null +++ b/crawl.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +from tulip import http_client + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.add(self.rooturl) # Set initial work. + self.run() # Kick off work. + + def add(self, url): + url = urllib.parse.urljoin(self.rooturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.add(u): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.add(u): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + loop.add_signal_handler(signal.SIGINT, loop.stop) + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..1a73c194 --- /dev/null +++ b/curl.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + main() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..1e7d6119 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + PyErr_SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + PyErr_SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return PyErr_SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD Milliseconds; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped = NULL; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return PyErr_SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return PyErr_SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Set notification mode for the handle + */ + +PyDoc_STRVAR( + SetFileCompletionNotificationModes_doc, + "SetFileCompletionNotificationModes(FileHandle, Flags) -> None\n\n" + "Set whether notification happens if operation succeeds without blocking"); + +static PyObject * +overlapped_SetFileCompletionNotificationModes(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + UCHAR Flags; + + if (!PyArg_ParseTuple(args, F_HANDLE F_BOOL, &FileHandle, &Flags)) + return NULL; + + if (!SetFileCompletionNotificationModes(FileHandle, Flags)) + return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0); + + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return PyErr_SetExcFromWindowsErr(PyExc_OSError, 0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_SetString( + PyExc_RuntimeError, + "I/O operations still in flight while destroying " + "Overlapped object, the process may crash"); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + case ERROR_BROKEN_PIPE: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + case ERROR_BROKEN_PIPE: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)PY_ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)PY_ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {"SetFileCompletionNotificationModes", + overlapped_SetFileCompletionNotificationModes, + METH_VARARGS, SetFileCompletionNotificationModes_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..1eef652b --- /dev/null +++ b/runtests.py @@ -0,0 +1,73 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tulip.events_test.PolicyTests.testPolicy'. +""" + +# Originally written by Beech Horn (for NDB). + +import re +import sys +import unittest + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +def load_tests(includes=(), excludes=()): + mods = ['events', 'futures', 'tasks'] + test_mods = ['%s_test' % name for name in mods] + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) for pat in excludes)] + suite.addTests(tests) + + return suite + + +def main(): + excludes = [] + includes = [] + patterns = includes # A reference. + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes + elif arg and not arg.startswith('-'): + patterns.append(arg) + tests = load_tests(includes, excludes) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +if __name__ == '__main__': + main() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..67b037cc --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from distutils.core import setup, Extension + +ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) +setup(name='_overlapped', ext_modules=[ext]) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..185fe3fe --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,14 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .transports import * +from .protocols import * +from .tasks import * + +__all__ = (futures.__all__ + + events.__all__ + + transports.__all__ + + protocols.__all__ + + tasks.__all__) diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..f39ddb79 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,287 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'EventLoop', 'Handler', 'make_handler', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import threading + + +class Handler: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args): + self._when = when + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + return self._when <= other._when + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + return self._when >= other._when + + def __eq__(self, other): + return self._when == other._when + + +def make_handler(when, callback, args): + if isinstance(callback, Handler): + assert not args + assert when is None + return callback + return Handler(when, callback, args) + + +class EventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handlers for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handler. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + def add_connector(self, fd, callback, *args): + raise NotImplementedError + + def remove_connector(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, EventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + # TODO: Do something else for Windows. + from . import unix_events + return unix_events.UnixEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/events_test.py b/tulip/events_test.py new file mode 100644 index 00000000..07f09dd4 --- /dev/null +++ b/tulip/events_test.py @@ -0,0 +1,512 @@ +"""Tests for events.py.""" + +import concurrent.futures +import gc +import os +import select +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest + +from . import events +from . import transports +from . import protocols +from . import selectors +from . import unix_events + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + self.selector = self.SELECTOR_CLASS() + self.event_loop = unix_events.UnixEventLoop(self.selector) + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + + def testRun(self): + el = events.get_event_loop() + el.run() # Returns immediately. + + def testCallLater(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + el.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallRepeatedly(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_repeatedly(0.03, callback, 'ho') + el.call_later(0.1, el.stop) + el.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallSoon(self): + el = events.get_event_loop() + results = [] + def callback(arg1, arg2): + results.append((arg1, arg2)) + el.call_soon(callback, 'hello', 'world') + el.run() + self.assertEqual(results, [('hello', 'world')]) + + def testCallSoonWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(): + results.append('yeah') + handler = events.Handler(None, callback, ()) + self.assertEqual(el.call_soon(handler), handler) + el.run() + self.assertEqual(results, ['yeah']) + + def testCallSoonThreadsafe(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + def run(): + el.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallSoonThreadsafeWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('hello',)) + def run(): + self.assertEqual(el.call_soon_threadsafe(handler), handler) + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallEveryIteration(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = el.call_every_iteration(callback, 'ho') + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handler.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallEveryIterationWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('ho',)) + self.assertEqual(el.call_every_iteration(handler), handler) + el.run_once() + self.assertEqual(results, ['ho']) + el.run_once() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + handler.cancel() + el.run_once() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testWrapFuture(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = el.wrap_future(f1) + res = el.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def testRunInExecutor(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + f2 = el.run_in_executor(None, run, 'yo') + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testRunInExecutorWithHandler(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(None, run, ('yo',)) + f2 = el.run_in_executor(None, handler) + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testReaderCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + r = el._selector.wrap_socket(r) + w = el._selector.wrap_socket(w) + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + else: + self.assertTrue(el.remove_reader(r.fileno())) + r.close() + el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testReaderCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + r = el._selector.wrap_socket(r) + w = el._selector.wrap_socket(w) + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + else: + self.assertTrue(el.remove_reader(r.fileno())) + r.close() + handler = events.Handler(None, reader, ()) + self.assertEqual(el.add_reader(r.fileno(), handler), handler) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testReaderCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + r = el._selector.wrap_socket(r) + w = el._selector.wrap_socket(w) + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handler.cancel() + if not data: + r.close() + handler = el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testWriterCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w = el._selector.wrap_socket(w) + w.setblocking(False) + el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testWriterCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w = el._selector.wrap_socket(w) + w.setblocking(False) + handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + self.assertEqual(el.add_writer(w.fileno(), handler), handler) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testWriterCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + r = el._selector.wrap_socket(r) + w = el._selector.wrap_socket(w) + w.setblocking(False) + def sender(): + w.send(b'x'*256) + handler.cancel() + handler = el.add_writer(w.fileno(), sender) + el.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def testSockClientOps(self): + el = events.get_event_loop() + sock = socket.socket() + sock = el._selector.wrap_socket(sock) + sock.setblocking(False) + # TODO: This depends on python.org behavior! + address = socket.getaddrinfo('python.org', 80, socket.AF_INET)[0][4] + el.run_until_complete(el.sock_connect(sock, address)) + el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = el.run_until_complete(el.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + + def testSockClientFail(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + sock = el._selector.wrap_socket(sock) + # TODO: This depends on python.org behavior! + address = socket.getaddrinfo('python.org', 12345, socket.AF_INET)[0][4] + with self.assertRaises(ConnectionRefusedError): + el.run_until_complete(el.sock_connect(sock, address)) + sock.close() + + def testSockAccept(self): + el = events.get_event_loop() + listener = socket.socket() + listener = el._selector.wrap_socket(listener) + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + f = el.sock_accept(listener) + conn, addr = el.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def testAddSignalHandler(self): + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + # Check error behavior first. + self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) + self.assertRaises(TypeError, el.remove_signal_handler, 'boom') + self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) + self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, 0) + self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, -1) + self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + el.add_signal_handler(signal.SIGINT, my_handler) + el.run_once() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def testCancelSignalHandler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def testSignalHandlingWhileSelecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + el.call_later(0.15, el.stop) + el.run_forever() + self.assertEqual(caught, 1) + + def testCreateTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + el.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def testCreateSslTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + el.run() + self.assertTrue(pr.nbytes > 0) + + def testStartServing(self): + el = events.get_event_loop() + f = el.start_serving(MyProto, '0.0.0.0', 0) + sock = el.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + el.run_once() # This is quite mysterious, but necessary. + el.run_once() + el.run_once() + sock.close() + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + +if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.KqueueSelector + + +if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.EpollSelector + + +if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.PollSelector + + +try: + from . import iocpsockets +except ImportError: + pass +else: + class IocpEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = iocpsockets.IocpSelector + + def testCreateSslTransport(self): + raise unittest.SkipTest("IocpSelector imcompatible with SSL") + + +# Should always exist. +class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.SelectSelector + + +class HandlerTests(unittest.TestCase): + + def testHandler(self): + pass + + def testMakeHandler(self): + def callback(*args): + return args + h1 = events.Handler(None, callback, ()) + h2 = events.make_handler(None, h1, ()) + self.assertEqual(h1, h2) + + +class PolicyTests(unittest.TestCase): + + def testPolicy(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..e79999fc --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,240 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res +='<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + yield self # This tells Task to wait for completion. + return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py new file mode 100644 index 00000000..7834fec8 --- /dev/null +++ b/tulip/futures_test.py @@ -0,0 +1,210 @@ +"""Tests for futures.py.""" + +import unittest + +from . import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def testInitialState(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def testInitEventLoopPositional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def testCancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testResult(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testException(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testYieldFromTwice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def testRepr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def testCopyState(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def testCallbacksInvokedOnSetResult(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def testCallbacksInvokedOnSetException(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def testRemoveDoneCallback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/http_client.py b/tulip/http_client.py new file mode 100644 index 00000000..d658be51 --- /dev/null +++ b/tulip/http_client.py @@ -0,0 +1,292 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +import collections +import email.message +import email.parser +import re + +import tulip +from . import events +from . import futures +from . import tasks + + +# TODO: Move to another module. +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + # TODO: Limit line length for security. + while not self.line_count and not self.eof: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + continue + parts = [] + while self.buffer: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + parts.append(head) + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + break + return b''.join(parts) + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + return (yield from self.read(n)) + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + f = p.connect() # Returns a Future + ...now what?... + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = events.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection(lambda: self, + self.host, + self.port, + ssl=self.ssl) + # TODO: A better mechanism to return all info from the + # status line, all headers, and the buffer, without having + # an N-tuple return value. + status_line = yield from self.stream.readline() + m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', + status_line) + if not m: + raise 'Invalid HTTP status line ({!r})'.format(status_line) + version, status, message = m.groups() + raw_headers = [] + while True: + header = yield from self.stream.readline() + if not header.strip(): + break + raw_headers.append(header) + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(raw_headers)) + content_length = headers.get('content-length') + if content_length: + content_length = int(content_length) # May raise. + if content_length is None: + stream = self.stream + else: + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. + body = yield from self.stream.readexactly(content_length) + stream = StreamReader() + stream.feed_data(body) + stream.feed_eof() + sts = '{} {}'.format(self.decode(status), self.decode(message)) + return (sts, headers, stream) + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, s): + if not s: + return + data = self.encode(s) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') + + def connection_made(self, transport): + self.transport = transport + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.write_str(line) + for key, value in self.headers.items(): + self.write_str('{}: {}\r\n'.format(key, value)) + self.transport.write(b'\r\n') + self.stream = StreamReader() + if self.make_body is not None: + if self.chunked: + self.make_body(self.write_chunked, self.write_chunked_eof) + else: + self.make_body(self.write_str, self.transport.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/iocpsockets.py b/tulip/iocpsockets.py new file mode 100644 index 00000000..67916050 --- /dev/null +++ b/tulip/iocpsockets.py @@ -0,0 +1,360 @@ +import collections +import errno +import socket +import weakref +import _winapi + +from .selectors import _BaseSelector, EVENT_READ, EVENT_WRITE, EVENT_CONNECT +from ._overlapped import * + + +class IocpSocket: + + _state = 'unknown' + + def __init__(self, selector, sock=None, family=socket.AF_INET, + type=socket.SOCK_STREAM, proto=0): + if sock is None: + sock = socket.socket(family, type, proto) + self._sock = sock + self._fd = self._sock.fileno() + self._selector = selector + self._pending = {EVENT_READ:False, EVENT_WRITE:False} + self._result = {EVENT_READ:None, EVENT_WRITE:None} + CreateIoCompletionPort(self._fd, selector._iocp, 0, 0) + # XXX SetFileCompletionNotificationModes() requires Vista or later + SetFileCompletionNotificationModes( + self._fd, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS) + selector._fd_to_fileobj[sock.fileno()] = self + + def __getattr__(self, name): + return getattr(self._sock, name) + + def listen(self, backlog): + self._state = 'listening' + self._sock.listen(backlog) + + def send(self, buf): + # if self._state != 'connected': + # raise ValueError('socket is in state %r' % self._state) + if self._pending[EVENT_WRITE]: + raise BlockingIOError(errno.EAGAIN, 'try again') + res = self._result[EVENT_WRITE] + if res and not res[0]: + raise res[1] + ov = Overlapped(0) + ov.WSASend(self._fd, buf, 0) + def callback(): + self._sock._decref_socketios() + if ov.getresult() < len(buf): + # partial writes only happen if something has broken + raise RuntimeError('partial write -- should not get here') + self._sock._io_refs += 1 + if ov.pending: + self._selector._defer(self, ov, EVENT_WRITE, callback) + else: + callback() + return len(buf) + + def recv(self, length): + # if self._state != 'connected': + # raise ValueError('socket is in state %r' % self._state) + if self._pending[EVENT_READ]: + raise BlockingIOError(errno.EAGAIN, 'try again') + if length <= 0: + if length < 0: + raise ValueError('negative length') + return b'' + res = self._result[EVENT_READ] + if res: + success, value = self._result[EVENT_READ] + if not success: + raise value + if length < len(value): + value = value[:length] + self._result[EVENT_READ] = (True, res[length:]) + else: + self._result[EVENT_READ] = None + return value + ov = Overlapped(0) + ov.WSARecv(self._fd, length, 0) + if ov.pending: + self._selector._defer(self, ov, EVENT_READ, ov.getresult) + raise BlockingIOError(errno.EAGAIN, 'try again') + else: + return ov.getresult() + + def connect(self, address): + if self._state != 'unknown': + raise ValueError('socket is in state %r' % self._state) + self._state = 'connecting' + BindLocal(self._fd, len(address)) + ov = Overlapped(0) + try: + ov.ConnectEx(self._fd, address) + except OSError as e: + if e.winerror == 10022: + raise ConnectionRefusedError( + errno.ECONNREFUSED, e.strerror, None, e.winerror) + def callback(): + try: + ov.getresult(False) + except OSError as e: + self._state = 'broken' + if e.winerror == 1225: + self._error = errno.ECONNREFUSED + else: + self._error = e.errno + raise + else: + self._state = 'connected' + if ov.pending: + self._selector._defer(self, ov, EVENT_WRITE, callback) + raise BlockingIOError(errno.EINPROGRESS, 'connect in progress') + else: + callback() + + def getsockopt(self, level, optname, buflen=None): + if ((level, optname) == (socket.SOL_SOCKET, socket.SO_ERROR)): + if self._state == 'connecting': + return errno.EINPROGRESS + elif self._state == 'broken': + return self._error + if buflen is None: + return self._sock.getsockopt(level, optname) + else: + return self._sock.getsockopt(level, optname, buflen) + + def accept(self): + if self._state != 'listening': + raise ValueError('socket is in state %r' % self._state) + if self._result[EVENT_READ]: + success, value = self._result[EVENT_READ] + self._result[EVENT_READ] = None + if success: + return value + else: + raise value + if self._pending[EVENT_READ]: + raise BlockingIOError(errno.EAGAIN, 'try again') + conn = socket.socket(self.family, self.type, self.proto) + conn = self._selector.wrap_socket(conn) + ov = Overlapped(0) + ov.AcceptEx(self._fd, conn.fileno()) + def callback(): + ov.getresult(False) + conn._sock.setsockopt( + socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, self._fd) + conn._state = 'connected' + return conn, conn.getpeername() # XXX + if ov.pending: + self._selector._defer(self, ov, EVENT_READ, callback) + raise BlockingIOError(errno.EAGAIN, 'try again') + else: + return callback() + + sendall = send + + # XXX how do we deal with shutdown? + # XXX connect_ex, makefile, ...? + + +class IocpSelector(_BaseSelector): + + def __init__(self, *, concurrency=0xffffffff): + super().__init__() + self._iocp = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._address_to_info = {} + self._fd_to_fileobj = weakref.WeakValueDictionary() + + def wrap_socket(self, sock): + return IocpSocket(self, sock) + + def _defer(self, sock, ov, flag, callback): + sock._pending[flag] = True + self._address_to_info[ov.address] = (sock, ov, flag, callback) + + def close(self): + super().close() + if self._iocp is not None: + try: + # cancel pending IO + for info in self._address_to_info.values(): + ov = info[1] + try: + ov.cancel() + except OSError as e: + # handle may have closed + pass # XXX check e.winerror + # wait for pending IO to stop + while self._address_to_info: + status = GetQueuedCompletionStatus(self._iocp, 1000) + if status is None: + print(self._address_to_info) + continue + self._address_to_info.pop(status[3], None) + finally: + _winapi.CloseHandle(self._iocp) + self._iocp = None + + def select(self, timeout=None): + results = {} + for fd, key in self._fd_to_key.items(): + fileobj = self._fd_to_fileobj[fd] + if ((key.events & EVENT_READ) + and not fileobj._pending[EVENT_READ]): + results[fd] = results.get(fd, 0) | EVENT_READ + if ((key.events & (EVENT_WRITE|EVENT_CONNECT)) + and not fileobj._pending[EVENT_WRITE]): + results[fd] = results.get(fd, 0) | EVENT_WRITE | EVENT_CONNECT + + if results: + ms = 0 + elif timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError('negative timeout') + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError('timeout too big') + + while True: + status = GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + break + try: + fobj, ov, flag, callback = self._address_to_info.pop(status[3]) + except KeyError: + continue + fobj._pending[flag] = False + try: + value = callback() + except OSError as e: + fobj._result[flag] = (False, e) + else: + fobj._result[flag] = (True, value) + key = self._fileobj_to_key.get(fobj) + if key and (key.events & flag): + results[fobj._fd] = results.get(fobj._fd, 0) | flag + ms = 0 + + tmp = [] + for fd, events in results.items(): + if events & EVENT_WRITE: + events |= EVENT_CONNECT + key = self._fd_to_key[fd] + tmp.append((key.fileobj, events, key.data)) + return tmp + + +def main(): + from .winsocketpair import socketpair + + selector = IocpSelector() + + # listen + listener = selector.wrap_socket(socket.socket()) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + + # connect + conn = selector.wrap_socket(socket.socket()) + try: + conn.connect(listener.getsockname()) + # conn.connect(('127.0.0.1', 7868)) + except BlockingIOError: + selector.register(conn, EVENT_WRITE) + res = selector.select(5) + # assert [(conn, EVENT_WRITE, None)] == res, res + error = conn.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + assert error == 0, error + selector.unregister(conn) + + # accept + selector.register(listener, EVENT_READ) + while True: + try: + a, addr = listener.accept() + break + except BlockingIOError: + res = selector.select(1) + assert [(listener, EVENT_READ, None)] == res + selector.unregister(listener) + + + selector.register(a, EVENT_WRITE) + selector.register(conn, EVENT_READ) + msgs = [b"hello"] * 100 + + while selector.registered_count() > 0: + for (f, event, data) in selector.select(): + if event & EVENT_READ: + try: + msg = f.recv(20) + except BlockingIOError: + print("READ BLOCKED") + else: + print("read %r" % msg) + if not msg: + print("UNREGISTER READER") + selector.unregister(f) + f.close() + if event & EVENT_WRITE: + try: + nbytes = f.send(msgs.pop()) + except BlockingIOError: + print("WRITE BLOCKED") + except IndexError: + print("UNREGISTER WRITER") + selector.unregister(f) + f.close() + else: + print("bytes sent %r" % nbytes) + + + a, b = socketpair() + a = selector.wrap_socket(a) + b = selector.wrap_socket(b) + selector.register(a, EVENT_READ) + selector.register(b, EVENT_WRITE) + + msg = b"x"*(1024*1024*16) + view = memoryview(msg) + res = [] + + while selector.registered_count() > 0: + for (f, event, data) in selector.select(): + if event & EVENT_READ: + try: + data = f.recv(8192) + except BlockingIOError: + print("READ BLOCKED") + else: + res.append(data) + if not data: + print("UNREGISTER READER") + selector.unregister(f) + f.close() + if event & EVENT_WRITE: + try: + nbytes = f.send(view) + except BlockingIOError: + print("WRITE BLOCKED") + else: + assert nbytes == len(view) + if nbytes == 0: + print("UNREGISTER WRITER") + selector.unregister(f) + f.close() + else: + view = view[nbytes:] + + print(len(msg), sum(len(frag) for frag in res)) + + selector.close() + + +if __name__ == '__main__': + main() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..ad294f3a --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,58 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..2ea92b38 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,434 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) +# connect event +EVENT_CONNECT = (1 << 2) + +# In most cases we treat EVENT_WRITE and EVENT_CONNECT as aliases for +# each other, and in fact we return both flags when a FD is found +# either writable or connectable. The distinction is necessary +# only for poll() on Windows. + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE|EVENT_CONNECT)): + raise ValueError("Invalid events: {}".format(events)) + + if events & (EVENT_WRITE|EVENT_CONNECT) == (EVENT_WRITE|EVENT_CONNECT): + raise ValueError("WRITE and CONNECT are mutually exclusive. " + "Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + self.register(fileobj, events, data) + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE|EVENT_CONNECT + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + def wrap_socket(self, sock): + """Return sock or a wrapper for sock compatible with selector""" + return sock + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & (EVENT_WRITE|EVENT_CONNECT): + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py new file mode 100644 index 00000000..4eb24e41 --- /dev/null +++ b/tulip/subprocess_test.py @@ -0,0 +1,48 @@ +"""Tests for subprocess_transport.py.""" + +import unittest + +from . import events +from . import protocols +from . import subprocess_transport + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + def data_received(self, data): + print('received:', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testUnixSubprocess(self): + p = MyProto() + t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..721013f8 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,133 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..91c32076 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,277 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__() # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): # pragma: no cover + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + result.add_done_callback(self._wakeup) + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe(_wakeup, future)) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + if (not pending or + timeout != None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + bail = futures.Future() # Will always be cancelled eventually. + timeout_handler = None + debugstuff = locals() + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.cancel) + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + really_done = set(f for f in pending if f.done()) + if really_done: # pragma: no cover + # We don't expect this to ever happen. Or do we? + done.update(really_done) + pending.difference_update(really_done) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py new file mode 100644 index 00000000..a7381450 --- /dev/null +++ b/tulip/tasks_test.py @@ -0,0 +1,233 @@ +"""Tests for tasks.py.""" + +import time +import unittest + +from . import events +from . import futures +from . import tasks + + +class Dummy: + def __repr__(self): + return 'Dummy()' + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testTaskClass(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + def testTaskDecorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def testTaskRepr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def testTaskBasics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def testWait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def testWaitWithException(self): + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testWaitWithTimeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def testAsCompleted(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testAsCompletedWithTimeout(self): + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def testSleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def testTaskCancelSleepingTask(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..4aaae3c7 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,90 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..4b3dc525 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,1001 @@ +"""UNIX event loop and related classes. + +The event loop can be broken up into a selector (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a selector with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + +import collections +import concurrent.futures +import errno +import heapq +import logging +import select +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import futures +from . import protocols +from . import selectors +from . import tasks +from . import transports + +try: + from socket import socketpair +except ImportError: + assert sys.platform == 'win32' + from .winsocketpair import socketpair + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + +class UnixEventLoop(events.EventLoop): + """Unix event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.info('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._ready = collections.deque() + self._scheduled = [] + self._everytime = [] + self._default_executor = None + self._signal_handlers = {} + self._make_self_pipe() + + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + + def _make_self_pipe(self): + # A self-socket, really. :-) + a, b = socketpair() + self._ssock = self._selector.wrap_socket(a) + self._csock = self._selector.wrap_socket(b) + self._ssock.setblocking(False) + self._csock.setblocking(False) + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 0x7fffffff/1000.0 # 24 days + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, _raise_stop_error) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.make_handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(None, callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def call_every_iteration(self, callback, *args): + """Call a callback just before the loop blocks. + + The callback is called for every iteration of the loop. + """ + handler = events.make_handler(None, callback, args) + self._everytime.append(handler) + return handler + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host, port, *, ssl=False, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock = self._selector.wrap_socket(sock) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = _UnixSslTransport(self, sock, protocol, sslcontext, + waiter) + else: + transport = _UnixSocketTransport(self, sock, protocol, waiter) + yield from waiter + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0, + backlog=100): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + sock = self._selector.wrap_socket(sock) + try: + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + sock.listen(backlog) + sock.setblocking(False) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + return sock + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + protocol = protocol_factory() + transport = _UnixSocketTransport(self, conn, protocol) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handler, None, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handler, writer, connector)) + + return handler + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer, connector)) + return True + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handler, None)) + else: + # Remove connector. + mask &= ~selectors.EVENT_CONNECT + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handler, None)) + return handler + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + return True + + # NOTE: add_connector() and add_writer() are mutually exclusive. + # While you can independently manipulate readers and writers, + # adding a connector for a particular FD automatically removes the + # writer for that FD, and vice versa, and removing a writer or a + # connector actually removes both writer and connector. This is + # because in most cases writers and connectors use the same mode + # for the platform polling function; the distinction is only + # important for PollSelector() on Windows. + + def add_connector(self, fd, callback, *args): + """Add a connector callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_CONNECT, + (None, None, handler)) + else: + # Remove writer. + mask &= ~selectors.EVENT_WRITE + self._selector.modify(fd, mask | selectors.EVENT_CONNECT, + (reader, None, handler)) + return handler + + def remove_connector(self, fd): + """Remove a connector callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + return True + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + self._sock_sendall(fut, False, sock, data) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + n = 0 + try: + if data: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_connector(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_connector(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Add everytime handlers, skipping cancelled ones. + any_cancelled = False + for handler in self._everytime: + if handler.cancelled: + any_cancelled = True + else: + self._ready.append(handler) + + # Remove cancelled everytime handlers if there are any. + if any_cancelled: + self._everytime = [handler for handler in self._everytime + if not handler.cancelled] + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + elif mask & selectors.EVENT_CONNECT and connector is not None: + if connector.cancelled: + self.remove_connector(fileobj) + else: + self._add_callback(connector) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + + +class _UnixSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _UnixSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if n < len(data): + self._buffer.append(data[n:]) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..87d54c91 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,30 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From df2b407ab6985a8c5e4075787a5177794ef4e4e6 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 21 Jan 2013 20:39:21 +0000 Subject: [PATCH 0247/1502] Merge. --- tulip/unix_events.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 4b3dc525..2b064f12 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -83,7 +83,7 @@ def __init__(self, selector=None): if selector is None: # pick the best selector class for the platform selector = selectors.Selector() - logging.info('Using selector: %s', selector.__class__.__name__) + logging.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._ready = collections.deque() self._scheduled = [] @@ -96,6 +96,8 @@ def close(self): if self._selector is not None: self._selector.close() self._selector = None + self._ssock.close() + self._csock.close() def _make_self_pipe(self): # A self-socket, really. :-) @@ -416,6 +418,8 @@ def remove_reader(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (None, writer, connector)) + if reader is not None: + reader.cancel() return True def add_writer(self, fd, callback, *args): @@ -446,6 +450,10 @@ def remove_writer(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() return True # NOTE: add_connector() and add_writer() are mutually exclusive. @@ -485,6 +493,10 @@ def remove_connector(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() return True def sock_recv(self, sock, n): From c2ac04343790abc6c46db9878ec17c29db3a0dee Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Jan 2013 18:55:29 -0800 Subject: [PATCH 0248/1502] Throwaway code: async server (srv.py), sync ssl server (sslsrv.py). --- srv.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ sslsrv.py | 56 ++++++++++++++++++++++++++++++++ tulip/unix_events.py | 1 + 3 files changed, 133 insertions(+) create mode 100644 srv.py create mode 100644 sslsrv.py diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..e49b8998 --- /dev/null +++ b/srv.py @@ -0,0 +1,76 @@ +"""Simple server written using an event loop.""" + +import email.message +import email.parser +import gc +import re + +import tulip +from tulip.http_client import StreamReader + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + line = yield from self.reader.readline() + print('request line', line) + match = re.match(rb'GET (\S+) HTTP/(1.\d)\r?\n\Z', line) + if not match: + self.transport.close() + return + lines = [] + while True: + line = yield from self.reader.readline() + print('header line', line) + if not line.strip(b' \t\r\n'): + break + lines.append(line) + if line == b'\r\n': + break + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(lines)) + self.transport.write(b'HTTP/1.0 200 Ok\r\n' + b'Content-type: text/plain\r\n' + b'\r\n' + b'Hello world.\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport._sock) + self.reader = StreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index cea3ba95..e281b8a9 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -357,6 +357,7 @@ def start_serving(self, protocol_factory, host, port, *, for family, type, proto, cname, address in infos: sock = socket.socket(family=family, type=type, proto=proto) try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(address) except socket.error as exc: sock.close() From 9a1f5215f2b0f94c8d10ef60606f3ce1b1900ec3 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 22 Jan 2013 15:20:26 +0000 Subject: [PATCH 0249/1502] Make tests allow for spurious readiness notifications. --- tulip/events_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index d0061cf5..76209bf5 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -200,7 +200,12 @@ def testReaderCallback(self): r, w = unix_events.socketpair() bytes_read = [] def reader(): - data = r.recv(1024) + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return if data: bytes_read.append(data) else: @@ -218,7 +223,12 @@ def testReaderCallbackWithHandler(self): r, w = unix_events.socketpair() bytes_read = [] def reader(): - data = r.recv(1024) + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return if data: bytes_read.append(data) else: From 652ddd8d59e53bb45fd160fc3ff3aea80ade8cd4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 22 Jan 2013 11:22:13 -0800 Subject: [PATCH 0250/1502] Get rid of call_every_iteration() (Tornado cannot implement it). --- tulip/events_test.py | 31 ------------------------------- tulip/unix_events.py | 23 ----------------------- 2 files changed, 54 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 76209bf5..76802579 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -134,37 +134,6 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testCallEveryIteration(self): - el = events.get_event_loop() - results = [] - def callback(arg): - results.append(arg) - handler = el.call_every_iteration(callback, 'ho') - el.run_once() - self.assertEqual(results, ['ho']) - el.run_once() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - handler.cancel() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - - def testCallEveryIterationWithHandler(self): - el = events.get_event_loop() - results = [] - def callback(arg): - results.append(arg) - handler = events.Handler(None, callback, ('ho',)) - self.assertEqual(el.call_every_iteration(handler), handler) - el.run_once() - self.assertEqual(results, ['ho']) - el.run_once() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - handler.cancel() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - def testWrapFuture(self): el = events.get_event_loop() def run(arg): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index e281b8a9..60ffda75 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -87,7 +87,6 @@ def __init__(self, selector=None): self._selector = selector self._ready = collections.deque() self._scheduled = [] - self._everytime = [] self._default_executor = None self._signal_handlers = {} self._make_self_pipe() @@ -247,15 +246,6 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handler - def call_every_iteration(self, callback, *args): - """Call a callback just before the loop blocks. - - The callback is called for every iteration of the loop. - """ - handler = events.make_handler(None, callback, args) - self._everytime.append(handler) - return handler - def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): @@ -707,19 +697,6 @@ def _run_once(self, timeout=None): # TODO: An alternative API would be to do the *minimal* amount # of work, e.g. one callback or one I/O poll. - # Add everytime handlers, skipping cancelled ones. - any_cancelled = False - for handler in self._everytime: - if handler.cancelled: - any_cancelled = True - else: - self._ready.append(handler) - - # Remove cancelled everytime handlers if there are any. - if any_cancelled: - self._everytime = [handler for handler in self._everytime - if not handler.cancelled] - # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) From 2e37b330e6660b7a116e8ef19543d64526e54fd4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 23 Jan 2013 08:53:16 -0800 Subject: [PATCH 0251/1502] StreamReader.readline does not maintain byte_count value --- tulip/http_client.py | 6 +++- tulip/http_client_test.py | 76 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 tulip/http_client_test.py diff --git a/tulip/http_client.py b/tulip/http_client.py index d658be51..a7df6ff9 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -85,7 +85,11 @@ def readline(self): self.buffer.appendleft(tail) self.line_count -= 1 break - return b''.join(parts) + + line = b''.join(parts) + self.byte_count -= len(line) + + return line @tasks.coroutine def read(self, n=-1): diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py new file mode 100644 index 00000000..6010c872 --- /dev/null +++ b/tulip/http_client_test.py @@ -0,0 +1,76 @@ +"""Tests for http_client.py.""" + +import unittest + +from . import events +from . import http_client +from . import tasks + + +class StreamReaderTest(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.event_loop = events.new_event_loop() + self.addCleanup(self.event_loop.close) + + events.set_event_loop(self.event_loop) + + def test_feed_empty_data(self): + stream = http_client.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.line_count) + self.assertEqual(0, stream.byte_count) + + def test_feed_data_line_byte_count(self): + stream = http_client.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + self.assertEqual(b'line1\n', line) + self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_read_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + @tasks.coroutine + def read(): + line = yield from stream.read(7) + return line + + data = self.event_loop.run_until_complete(tasks.Task(read())) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + 1, stream.line_count) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + +if __name__ == '__main__': + unittest.main() From 21ac0599f07aaf63d0007c57806cd8ea800cd39d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 23 Jan 2013 20:57:40 -0800 Subject: [PATCH 0252/1502] No signals on Windows, really. --- crawl.py | 5 ++++- tulip/unix_events.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/crawl.py b/crawl.py index a1303fcb..20527e70 100755 --- a/crawl.py +++ b/crawl.py @@ -121,7 +121,10 @@ def main(): rooturl = sys.argv[1] c = Crawler(rooturl) loop = tulip.get_event_loop() - loop.add_signal_handler(signal.SIGINT, loop.stop) + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass loop.run_forever() print('todo:', len(c.todo)) print('busy:', len(c.busy)) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 2b064f12..6bb1ce8f 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -688,6 +688,8 @@ def _check_signal(self, sig): if not (1 <= sig < signal.NSIG): raise ValueError('sig {} out of range(1, {})'.format(sig, signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" From f2c654750a7042b2156f2dd4edd117843123633f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 23 Jan 2013 20:59:21 -0800 Subject: [PATCH 0253/1502] No signals on Windows (port to main branch). --- crawl.py | 5 ++++- tulip/unix_events.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/crawl.py b/crawl.py index a1303fcb..20527e70 100755 --- a/crawl.py +++ b/crawl.py @@ -121,7 +121,10 @@ def main(): rooturl = sys.argv[1] c = Crawler(rooturl) loop = tulip.get_event_loop() - loop.add_signal_handler(signal.SIGINT, loop.stop) + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass loop.run_forever() print('todo:', len(c.todo)) print('busy:', len(c.busy)) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 60ffda75..fc2cb4a7 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -675,6 +675,8 @@ def _check_signal(self, sig): if not (1 <= sig < signal.NSIG): raise ValueError('sig {} out of range(1, {})'.format(sig, signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" From c58658d03667a9563b1d0c5a17a7309f7ca6aed9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 23 Jan 2013 21:11:14 -0800 Subject: [PATCH 0254/1502] Merge default - iocp. --- srv.py | 76 +++++++++++++++++++++++++++++++++++++++ sslsrv.py | 56 +++++++++++++++++++++++++++++ tulip/events_test.py | 35 +++--------------- tulip/http_client.py | 6 +++- tulip/http_client_test.py | 76 +++++++++++++++++++++++++++++++++++++++ tulip/unix_events.py | 24 +------------ 6 files changed, 218 insertions(+), 55 deletions(-) create mode 100644 srv.py create mode 100644 sslsrv.py create mode 100644 tulip/http_client_test.py diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..e49b8998 --- /dev/null +++ b/srv.py @@ -0,0 +1,76 @@ +"""Simple server written using an event loop.""" + +import email.message +import email.parser +import gc +import re + +import tulip +from tulip.http_client import StreamReader + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + line = yield from self.reader.readline() + print('request line', line) + match = re.match(rb'GET (\S+) HTTP/(1.\d)\r?\n\Z', line) + if not match: + self.transport.close() + return + lines = [] + while True: + line = yield from self.reader.readline() + print('header line', line) + if not line.strip(b' \t\r\n'): + break + lines.append(line) + if line == b'\r\n': + break + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(lines)) + self.transport.write(b'HTTP/1.0 200 Ok\r\n' + b'Content-type: text/plain\r\n' + b'\r\n' + b'Hello world.\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport._sock) + self.reader = StreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tulip/events_test.py b/tulip/events_test.py index 07f09dd4..28fe8abc 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -134,37 +134,6 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testCallEveryIteration(self): - el = events.get_event_loop() - results = [] - def callback(arg): - results.append(arg) - handler = el.call_every_iteration(callback, 'ho') - el.run_once() - self.assertEqual(results, ['ho']) - el.run_once() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - handler.cancel() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - - def testCallEveryIterationWithHandler(self): - el = events.get_event_loop() - results = [] - def callback(arg): - results.append(arg) - handler = events.Handler(None, callback, ('ho',)) - self.assertEqual(el.call_every_iteration(handler), handler) - el.run_once() - self.assertEqual(results, ['ho']) - el.run_once() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - handler.cancel() - el.run_once() - self.assertEqual(results, ['ho', 'ho', 'ho']) - def testWrapFuture(self): el = events.get_event_loop() def run(arg): @@ -205,6 +174,8 @@ def reader(): try: data = r.recv(1024) except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. return if data: bytes_read.append(data) @@ -228,6 +199,8 @@ def reader(): try: data = r.recv(1024) except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. return if data: bytes_read.append(data) diff --git a/tulip/http_client.py b/tulip/http_client.py index d658be51..a7df6ff9 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -85,7 +85,11 @@ def readline(self): self.buffer.appendleft(tail) self.line_count -= 1 break - return b''.join(parts) + + line = b''.join(parts) + self.byte_count -= len(line) + + return line @tasks.coroutine def read(self, n=-1): diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py new file mode 100644 index 00000000..6010c872 --- /dev/null +++ b/tulip/http_client_test.py @@ -0,0 +1,76 @@ +"""Tests for http_client.py.""" + +import unittest + +from . import events +from . import http_client +from . import tasks + + +class StreamReaderTest(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.event_loop = events.new_event_loop() + self.addCleanup(self.event_loop.close) + + events.set_event_loop(self.event_loop) + + def test_feed_empty_data(self): + stream = http_client.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.line_count) + self.assertEqual(0, stream.byte_count) + + def test_feed_data_line_byte_count(self): + stream = http_client.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + self.assertEqual(b'line1\n', line) + self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_read_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + @tasks.coroutine + def read(): + line = yield from stream.read(7) + return line + + data = self.event_loop.run_until_complete(tasks.Task(read())) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + 1, stream.line_count) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 6bb1ce8f..ef2f6f80 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -87,7 +87,6 @@ def __init__(self, selector=None): self._selector = selector self._ready = collections.deque() self._scheduled = [] - self._everytime = [] self._default_executor = None self._signal_handlers = {} self._make_self_pipe() @@ -249,15 +248,6 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handler - def call_every_iteration(self, callback, *args): - """Call a callback just before the loop blocks. - - The callback is called for every iteration of the loop. - """ - handler = events.make_handler(None, callback, args) - self._everytime.append(handler) - return handler - def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): @@ -361,6 +351,7 @@ def start_serving(self, protocol_factory, host, port, *, sock = socket.socket(family=family, type=type, proto=proto) sock = self._selector.wrap_socket(sock) try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(address) except socket.error as exc: sock.close() @@ -712,19 +703,6 @@ def _run_once(self, timeout=None): # TODO: An alternative API would be to do the *minimal* amount # of work, e.g. one callback or one I/O poll. - # Add everytime handlers, skipping cancelled ones. - any_cancelled = False - for handler in self._everytime: - if handler.cancelled: - any_cancelled = True - else: - self._ready.append(handler) - - # Remove cancelled everytime handlers if there are any. - if any_cancelled: - self._everytime = [handler for handler in self._everytime - if not handler.cancelled] - # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) From 706866c4ac1fe8fb95bcdd1d9f2a4b114f1ecf17 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 24 Jan 2013 16:48:57 +0000 Subject: [PATCH 0255/1502] Only use overlapped IO as a fallback when non-blocking read/write fails. Also use overlapped zero length reads to wait for readability. --- tulip/iocpsockets.py | 76 ++++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/tulip/iocpsockets.py b/tulip/iocpsockets.py index 67916050..2f918436 100644 --- a/tulip/iocpsockets.py +++ b/tulip/iocpsockets.py @@ -31,25 +31,32 @@ def __getattr__(self, name): return getattr(self._sock, name) def listen(self, backlog): - self._state = 'listening' self._sock.listen(backlog) + self._state = 'listening' def send(self, buf): - # if self._state != 'connected': - # raise ValueError('socket is in state %r' % self._state) if self._pending[EVENT_WRITE]: raise BlockingIOError(errno.EAGAIN, 'try again') + res = self._result[EVENT_WRITE] if res and not res[0]: + self._result[EVENT_WRITE] = None raise res[1] - ov = Overlapped(0) - ov.WSASend(self._fd, buf, 0) + + try: + return self._sock.send(buf) + except BlockingIOError: + pass + def callback(): self._sock._decref_socketios() if ov.getresult() < len(buf): # partial writes only happen if something has broken raise RuntimeError('partial write -- should not get here') - self._sock._io_refs += 1 + + ov = Overlapped(0) + ov.WSASend(self._fd, buf, 0) + self._sock._io_refs += 1 # prevent real close till send complete if ov.pending: self._selector._defer(self, ov, EVENT_WRITE, callback) else: @@ -57,45 +64,45 @@ def callback(): return len(buf) def recv(self, length): - # if self._state != 'connected': - # raise ValueError('socket is in state %r' % self._state) if self._pending[EVENT_READ]: raise BlockingIOError(errno.EAGAIN, 'try again') - if length <= 0: - if length < 0: - raise ValueError('negative length') - return b'' + res = self._result[EVENT_READ] - if res: - success, value = self._result[EVENT_READ] - if not success: - raise value - if length < len(value): - value = value[:length] - self._result[EVENT_READ] = (True, res[length:]) - else: - self._result[EVENT_READ] = None - return value + if res and not res[0]: + self._result[EVENT_READ] = None + raise res[1] + + try: + return self._sock.recv(length) + except BlockingIOError: + pass + + # a zero length read will block till socket is readable ov = Overlapped(0) - ov.WSARecv(self._fd, length, 0) + ov.WSARecv(self._fd, 0, 0) if ov.pending: self._selector._defer(self, ov, EVENT_READ, ov.getresult) raise BlockingIOError(errno.EAGAIN, 'try again') else: - return ov.getresult() + return self._sock.recv(length) def connect(self, address): if self._state != 'unknown': raise ValueError('socket is in state %r' % self._state) + self._state = 'connecting' BindLocal(self._fd, len(address)) ov = Overlapped(0) + try: ov.ConnectEx(self._fd, address) except OSError as e: if e.winerror == 10022: raise ConnectionRefusedError( errno.ECONNREFUSED, e.strerror, None, e.winerror) + else: + raise + def callback(): try: ov.getresult(False) @@ -108,6 +115,7 @@ def callback(): raise else: self._state = 'connected' + if ov.pending: self._selector._defer(self, ov, EVENT_WRITE, callback) raise BlockingIOError(errno.EINPROGRESS, 'connect in progress') @@ -128,25 +136,30 @@ def getsockopt(self, level, optname, buflen=None): def accept(self): if self._state != 'listening': raise ValueError('socket is in state %r' % self._state) - if self._result[EVENT_READ]: + + res = self._result[EVENT_READ] + if res: success, value = self._result[EVENT_READ] self._result[EVENT_READ] = None if success: return value else: raise value + if self._pending[EVENT_READ]: raise BlockingIOError(errno.EAGAIN, 'try again') - conn = socket.socket(self.family, self.type, self.proto) - conn = self._selector.wrap_socket(conn) - ov = Overlapped(0) - ov.AcceptEx(self._fd, conn.fileno()) + def callback(): ov.getresult(False) conn._sock.setsockopt( socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, self._fd) conn._state = 'connected' - return conn, conn.getpeername() # XXX + return conn, conn.getpeername() + + conn = socket.socket(self.family, self.type, self.proto) + conn = self._selector.wrap_socket(conn) + ov = Overlapped(0) + ov.AcceptEx(self._fd, conn.fileno()) if ov.pending: self._selector._defer(self, ov, EVENT_READ, callback) raise BlockingIOError(errno.EAGAIN, 'try again') @@ -169,6 +182,7 @@ def __init__(self, *, concurrency=0xffffffff): self._fd_to_fileobj = weakref.WeakValueDictionary() def wrap_socket(self, sock): + sock.setblocking(False) return IocpSocket(self, sock) def _defer(self, sock, ov, flag, callback): @@ -191,7 +205,6 @@ def close(self): while self._address_to_info: status = GetQueuedCompletionStatus(self._iocp, 1000) if status is None: - print(self._address_to_info) continue self._address_to_info.pop(status[3], None) finally: @@ -199,6 +212,7 @@ def close(self): self._iocp = None def select(self, timeout=None): + # XXX currently this is O(n) where n is number of registered fds results = {} for fd, key in self._fd_to_key.items(): fileobj = self._fd_to_fileobj[fd] From 92114cd313df1cc4cc2f5ec43611597d3c0f6842 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 24 Jan 2013 17:33:39 +0000 Subject: [PATCH 0256/1502] Prevent sendall() from doing partial write without error. --- overlapped.c | 4 ++-- tulip/iocpsockets.py | 16 +++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/overlapped.c b/overlapped.c index 1e7d6119..66c9e326 100644 --- a/overlapped.c +++ b/overlapped.c @@ -504,14 +504,14 @@ Overlapped_WSARecv(OverlappedObject *self, PyObject *args) { HANDLE handle; DWORD size; - DWORD flags; + DWORD flags = 0; DWORD nread; PyObject *buf; WSABUF wsabuf; int ret; DWORD err; - if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_DWORD, + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, &handle, &size, &flags)) return NULL; diff --git a/tulip/iocpsockets.py b/tulip/iocpsockets.py index 2f918436..73110d98 100644 --- a/tulip/iocpsockets.py +++ b/tulip/iocpsockets.py @@ -46,8 +46,20 @@ def send(self, buf): try: return self._sock.send(buf) except BlockingIOError: - pass + return self._send(buf) + + def sendall(self, buf): + if self._pending[EVENT_WRITE]: + raise BlockingIOError(errno.EAGAIN, 'try again') + res = self._result[EVENT_WRITE] + if res and not res[0]: + self._result[EVENT_WRITE] = None + raise res[1] + + return self._send(buf) + + def _send(self, buf): def callback(): self._sock._decref_socketios() if ov.getresult() < len(buf): @@ -166,8 +178,6 @@ def callback(): else: return callback() - sendall = send - # XXX how do we deal with shutdown? # XXX connect_ex, makefile, ...? From 86ce0444f44e13dfe9049ba8e26ca68e06cdd8e6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Jan 2013 15:05:44 -0800 Subject: [PATCH 0257/1502] Serve from the current filesystem. --- srv.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/srv.py b/srv.py index e49b8998..7cbedf24 100644 --- a/srv.py +++ b/srv.py @@ -2,7 +2,7 @@ import email.message import email.parser -import gc +import os import re import tulip @@ -21,10 +21,32 @@ def __init__(self): def handle_request(self): line = yield from self.reader.readline() print('request line', line) - match = re.match(rb'GET (\S+) HTTP/(1.\d)\r?\n\Z', line) + match = re.match(rb'([A-Z]+) (\S+) HTTP/(1.\d)\r?\n\Z', line) if not match: self.transport.close() return + method, path, version = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format(method, path, version)) + try: + path = path.decode('ascii') + except UnicodeError as exc: + print('not ascii', repr(path), exc) + path = None + else: + if not (path.isprintable() and path.startswith('/')) or '/.' in path: + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return lines = [] while True: line = yield from self.reader.readline() @@ -36,10 +58,33 @@ def handle_request(self): break parser = email.parser.BytesHeaderParser() headers = parser.parsebytes(b''.join(lines)) - self.transport.write(b'HTTP/1.0 200 Ok\r\n' - b'Content-type: text/plain\r\n' - b'\r\n' - b'Hello world.\r\n') + write = self.transport.write + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError as exc: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError as exc: + write(b'Cannot open\r\n') self.transport.close() def connection_made(self, transport): From 9d427a17e7e1874f3aba9244cd8b53b39cde099c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Jan 2013 15:13:43 -0800 Subject: [PATCH 0258/1502] Fix relative url following. --- crawl.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crawl.py b/crawl.py index 20527e70..b9881fcb 100755 --- a/crawl.py +++ b/crawl.py @@ -23,11 +23,11 @@ def __init__(self, rooturl): self.done = {} self.tasks = set() self.waiter = None - self.add(self.rooturl) # Set initial work. + self.addurl(self.rooturl, '') # Set initial work. self.run() # Kick off work. - def add(self, url): - url = urllib.parse.urljoin(self.rooturl, url) + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) url, frag = urllib.parse.urldefrag(url) if not url.startswith(self.rooturl): return False @@ -88,7 +88,7 @@ def process(self, url): if status[:3] in ('301', '302'): # Redirect. u = headers.get('location') or headers.get('uri') - if self.add(u): + if self.addurl(u, url): print(' ', url, status[:3], 'redirect to', u, end=END) elif status.startswith('200'): ctype = headers.get_content_type() @@ -101,7 +101,7 @@ def process(self, url): urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', line) for u in urls: - if self.add(u): + if self.addurl(u, url): print(' ', url, 'href to', u, end=END) ok = True finally: From 18c4c350ac3ff1dea46465cd6702a76af5007e73 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Jan 2013 15:34:17 -0800 Subject: [PATCH 0259/1502] Fix trailing whitespace. --- tulip/http_client.py | 2 +- tulip/http_client_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index a7df6ff9..f01ee631 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -85,7 +85,7 @@ def readline(self): self.buffer.appendleft(tail) self.line_count -= 1 break - + line = b''.join(parts) self.byte_count -= len(line) diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index 6010c872..b3e3414f 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -70,7 +70,7 @@ def read(): self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) - + if __name__ == '__main__': unittest.main() From ca98344e6950f21348a3c6c1ab21c69548e9122f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Jan 2013 15:35:13 -0800 Subject: [PATCH 0260/1502] Log exceptions when they occur. This is suboptimal, but logging them only when they aren't caught is trickier, and I've been bitten by buggy tasks too often. --- tulip/tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tulip/tasks.py b/tulip/tasks.py index 91c32076..af7eaaff 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -106,11 +106,13 @@ def _step(self, value=None, exc=None): super().cancel() else: self.set_exception(exc) + logging.exception('Exception in task') except BaseException as exc: if self._must_cancel: super().cancel() else: self.set_exception(exc) + logging.exception('BaseException in task') raise else: # XXX No check for self._must_cancel here? From 251e50b65c2a513fc9bceb42af0afbf1e77c1d27 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Jan 2013 15:35:38 -0800 Subject: [PATCH 0261/1502] Redirect if path to directory without trailing slash detected. --- srv.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/srv.py b/srv.py index 7cbedf24..bae9fdd8 100644 --- a/srv.py +++ b/srv.py @@ -25,12 +25,12 @@ def handle_request(self): if not match: self.transport.close() return - method, path, version = match.groups() - print('method = {!r}; path = {!r}; version = {!r}'.format(method, path, version)) + bmethod, bpath, bversion = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format(bmethod, bpath, bversion)) try: - path = path.decode('ascii') + path = bpath.decode('ascii') except UnicodeError as exc: - print('not ascii', repr(path), exc) + print('not ascii', repr(bpath), exc) path = None else: if not (path.isprintable() and path.startswith('/')) or '/.' in path: @@ -59,6 +59,12 @@ def handle_request(self): parser = email.parser.BytesHeaderParser() headers = parser.parsebytes(b''.join(lines)) write = self.transport.write + if isdir and not path.endswith('/'): + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return write(b'HTTP/1.0 200 Ok\r\n') if isdir: write(b'Content-type: text/html\r\n') From ad7e6bb443a9c9329b0ca38055aac19256dadc0b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 24 Jan 2013 20:54:56 -0800 Subject: [PATCH 0262/1502] load tests from *_test.py files --- runtests.py | 7 +++++-- tulip/http_client_test.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/runtests.py b/runtests.py index 1eef652b..e2598c9a 100644 --- a/runtests.py +++ b/runtests.py @@ -15,15 +15,18 @@ # Originally written by Beech Horn (for NDB). +import os import re import sys import unittest assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' +TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tulip') + + def load_tests(includes=(), excludes=()): - mods = ['events', 'futures', 'tasks'] - test_mods = ['%s_test' % name for name in mods] + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] tulip = __import__('tulip', fromlist=test_mods) loader = unittest.TestLoader() diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index b3e3414f..f712da94 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -7,7 +7,7 @@ from . import tasks -class StreamReaderTest(unittest.TestCase): +class StreamReaderTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' From aa2d5a4249bfa00ab5cb1bdf4db1b6b87c4cc334 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 25 Jan 2013 16:42:42 +0000 Subject: [PATCH 0263/1502] Replace IocpSelector with IocpEventLoop, and refactor UnixEventLoop. Note that those tests (6 of them) which use add_reader() or add_writer() fail with NotImplementedError. --- overlapped.c | 1 + tulip/events_test.py | 44 ++-- tulip/iocp_events.py | 318 ++++++++++++++++++++++++++++ tulip/iocpsockets.py | 384 ---------------------------------- tulip/unix_events.py | 484 ++++++++++++++++++++++--------------------- 5 files changed, 583 insertions(+), 648 deletions(-) create mode 100644 tulip/iocp_events.py delete mode 100644 tulip/iocpsockets.py diff --git a/overlapped.c b/overlapped.c index 66c9e326..a428cf2c 100644 --- a/overlapped.c +++ b/overlapped.c @@ -985,6 +985,7 @@ PyInit__overlapped(void) d = PyModule_GetDict(m); + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); WINAPI_CONSTANT(F_DWORD, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS); WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); diff --git a/tulip/events_test.py b/tulip/events_test.py index 28fe8abc..0a084800 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -46,8 +46,7 @@ def connection_lost(self, exc): class EventLoopTestsMixin: def setUp(self): - self.selector = self.SELECTOR_CLASS() - self.event_loop = unix_events.UnixEventLoop(self.selector) + self.event_loop = self.create_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): @@ -167,8 +166,6 @@ def run(arg): def testReaderCallback(self): el = events.get_event_loop() r, w = unix_events.socketpair() - r = el._selector.wrap_socket(r) - w = el._selector.wrap_socket(w) bytes_read = [] def reader(): try: @@ -192,8 +189,6 @@ def reader(): def testReaderCallbackWithHandler(self): el = events.get_event_loop() r, w = unix_events.socketpair() - r = el._selector.wrap_socket(r) - w = el._selector.wrap_socket(w) bytes_read = [] def reader(): try: @@ -218,8 +213,6 @@ def reader(): def testReaderCallbackCancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() - r = el._selector.wrap_socket(r) - w = el._selector.wrap_socket(w) bytes_read = [] def reader(): try: @@ -242,7 +235,6 @@ def reader(): def testWriterCallback(self): el = events.get_event_loop() r, w = unix_events.socketpair() - w = el._selector.wrap_socket(w) w.setblocking(False) el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): @@ -257,7 +249,6 @@ def remove_writer(): def testWriterCallbackWithHandler(self): el = events.get_event_loop() r, w = unix_events.socketpair() - w = el._selector.wrap_socket(w) w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) self.assertEqual(el.add_writer(w.fileno(), handler), handler) @@ -273,8 +264,6 @@ def remove_writer(): def testWriterCallbackCancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() - r = el._selector.wrap_socket(r) - w = el._selector.wrap_socket(w) w.setblocking(False) def sender(): w.send(b'x'*256) @@ -289,7 +278,6 @@ def sender(): def testSockClientOps(self): el = events.get_event_loop() sock = socket.socket() - sock = el._selector.wrap_socket(sock) sock.setblocking(False) # TODO: This depends on python.org behavior! address = socket.getaddrinfo('python.org', 80, socket.AF_INET)[0][4] @@ -303,7 +291,6 @@ def testSockClientFail(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) - sock = el._selector.wrap_socket(sock) # TODO: This depends on python.org behavior! address = socket.getaddrinfo('python.org', 12345, socket.AF_INET)[0][4] with self.assertRaises(ConnectionRefusedError): @@ -313,7 +300,6 @@ def testSockClientFail(self): def testSockAccept(self): el = events.get_event_loop() listener = socket.socket() - listener = el._selector.wrap_socket(listener) listener.setblocking(False) listener.bind(('127.0.0.1', 0)) listener.listen(1) @@ -432,34 +418,32 @@ def testStartServing(self): if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - SELECTOR_CLASS = selectors.KqueueSelector - + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - SELECTOR_CLASS = selectors.EpollSelector - + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - SELECTOR_CLASS = selectors.PollSelector + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.PollSelector()) +if sys.platform == 'win32': + from . import iocp_events -try: - from . import iocpsockets -except ImportError: - pass -else: class IocpEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - SELECTOR_CLASS = iocpsockets.IocpSelector - + def create_event_loop(self): + return iocp_events.IocpEventLoop() def testCreateSslTransport(self): - raise unittest.SkipTest("IocpSelector imcompatible with SSL") - + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") # Should always exist. class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - SELECTOR_CLASS = selectors.SelectSelector + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.SelectSelector()) class HandlerTests(unittest.TestCase): diff --git a/tulip/iocp_events.py b/tulip/iocp_events.py new file mode 100644 index 00000000..cfa09a23 --- /dev/null +++ b/tulip/iocp_events.py @@ -0,0 +1,318 @@ +# +# Module implementing the Proactor pattern +# +# A proactor is used to initiate asynchronous I/O, and to wait for +# completion of previously initiated operations. +# + +import errno +import logging +import os +import heapq +import sys +import socket +import time +import weakref + +from _winapi import CloseHandle + +from . import transports + +from .futures import Future +from .unix_events import BaseEventLoop, _StopError +from .winsocketpair import socketpair +from ._overlapped import * + + +_TRYAGAIN = frozenset() # XXX + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = CreateIoCompletionPort( + INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp, self._results = self._results, [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(): + addr = ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + SO_UPDATE_ACCEPT_CONTEXT, listener.fileno()) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + BindLocal(conn.fileno(), len(address)) + ov = Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(): + try: + ov.getresult() + except OSError as e: + if e.winerror == 1225: + raise ConnectionRefusedError(errno.ECONNREFUSED, + 'connection refused') + raise + conn.setsockopt(socket.SOL_SOCKET, + SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + SetFileCompletionNotificationModes(obj.fileno(), + FILE_SKIP_COMPLETION_PORT_ON_SUCCESS) + + def _register(self, ov, obj, callback): + f = Future() + if ov.error == ERROR_IO_PENDING: + # we must prevent ov and obj from being garbage collected + self._cache[ov.address] = (f, ov, obj, callback) + else: + try: + f.set_result(callback()) + except Exception as e: + f.set_exception(e) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + f, ov, obj, callback = self._cache.pop(status[3]) + try: + value = callback() + except OSError as e: + if f is None: + sys.excepthook(*sys.exc_info()) + continue + f.set_exception(e) + self._results.append(f) + else: + if f is None: + continue + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self, *, CloseHandle=CloseHandle): + for address, (f, ov, obj, callback) in list(self._cache.items()): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + status = GetQueuedCompletionStatus(self._iocp, 1000) + if status is None: + logging.debug('taking long time to close proactor') + continue + self._cache.pop(status[3]) + + if self._iocp is not None: + CloseHandle(self._iocp) + self._iocp = None + + +class IocpEventLoop(BaseEventLoop): + + @staticmethod + def SocketTransport(*args, **kwds): + return _IocpSocketTransport(*args, **kwds) + + @staticmethod + def SslTransport(*args, **kwds): + raise NotImplementedError + + def __init__(self, proactor=None): + super().__init__() + if proactor is None: + proactor = IocpProactor() + logging.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._readers = {} + self._make_self_pipe() + + def close(self): + if self._proactor is not None: + self._proactor.close() + self._proactor = None + self._ssock.close() + self._csock.close() + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + def loop(f=None): + if f and f.exception(): + self.close() + raise f.exception() + f = self._proactor.recv(self._ssock, 4096) + self.call_soon(f.add_done_callback, loop) + self.call_soon(loop) + + def _write_to_self(self): + self._proactor.send(self._csock, b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + transport = self.SocketTransport(self, conn, protocol) + f = self._proactor.accept(sock) + self.call_soon(f.add_done_callback, loop) + except OSError as exc: + if exc.errno in _TRYAGAIN: + self.call_soon(loop) + else: + sock.close() + logging.exception('Accept failed') + self.call_soon(loop) + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + +class _IocpSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, f=None): + try: + if f: + data = f.result() + if not data: + self._event_loop.call_soon(self._protocol.eof_received) + return + self._event_loop.call_soon(self._protocol.data_received, data) + f = self._event_loop._proactor.recv(self._sock, 4096) + self._event_loop.call_soon( + f.add_done_callback, self._loop_reading) + except OSError as exc: + self._fatal_error(exc) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._write_fut is not None: + self._buffer.append(data) + return + + def callback(f): + if f.exception(): + self._fatal_error(f.exception()) + # XXX should check for partial write + data = b''.join(self._buffer) + if data: + self._buffer = [] + self._write_fut = self._event_loop._proactor.send( + self._sock, data) + assert f is self._write_fut + self._write_fut = None + + self._write_fut = self._event_loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(callback) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + # if self._read_fut: # XXX + # self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/iocpsockets.py b/tulip/iocpsockets.py deleted file mode 100644 index 73110d98..00000000 --- a/tulip/iocpsockets.py +++ /dev/null @@ -1,384 +0,0 @@ -import collections -import errno -import socket -import weakref -import _winapi - -from .selectors import _BaseSelector, EVENT_READ, EVENT_WRITE, EVENT_CONNECT -from ._overlapped import * - - -class IocpSocket: - - _state = 'unknown' - - def __init__(self, selector, sock=None, family=socket.AF_INET, - type=socket.SOCK_STREAM, proto=0): - if sock is None: - sock = socket.socket(family, type, proto) - self._sock = sock - self._fd = self._sock.fileno() - self._selector = selector - self._pending = {EVENT_READ:False, EVENT_WRITE:False} - self._result = {EVENT_READ:None, EVENT_WRITE:None} - CreateIoCompletionPort(self._fd, selector._iocp, 0, 0) - # XXX SetFileCompletionNotificationModes() requires Vista or later - SetFileCompletionNotificationModes( - self._fd, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS) - selector._fd_to_fileobj[sock.fileno()] = self - - def __getattr__(self, name): - return getattr(self._sock, name) - - def listen(self, backlog): - self._sock.listen(backlog) - self._state = 'listening' - - def send(self, buf): - if self._pending[EVENT_WRITE]: - raise BlockingIOError(errno.EAGAIN, 'try again') - - res = self._result[EVENT_WRITE] - if res and not res[0]: - self._result[EVENT_WRITE] = None - raise res[1] - - try: - return self._sock.send(buf) - except BlockingIOError: - return self._send(buf) - - def sendall(self, buf): - if self._pending[EVENT_WRITE]: - raise BlockingIOError(errno.EAGAIN, 'try again') - - res = self._result[EVENT_WRITE] - if res and not res[0]: - self._result[EVENT_WRITE] = None - raise res[1] - - return self._send(buf) - - def _send(self, buf): - def callback(): - self._sock._decref_socketios() - if ov.getresult() < len(buf): - # partial writes only happen if something has broken - raise RuntimeError('partial write -- should not get here') - - ov = Overlapped(0) - ov.WSASend(self._fd, buf, 0) - self._sock._io_refs += 1 # prevent real close till send complete - if ov.pending: - self._selector._defer(self, ov, EVENT_WRITE, callback) - else: - callback() - return len(buf) - - def recv(self, length): - if self._pending[EVENT_READ]: - raise BlockingIOError(errno.EAGAIN, 'try again') - - res = self._result[EVENT_READ] - if res and not res[0]: - self._result[EVENT_READ] = None - raise res[1] - - try: - return self._sock.recv(length) - except BlockingIOError: - pass - - # a zero length read will block till socket is readable - ov = Overlapped(0) - ov.WSARecv(self._fd, 0, 0) - if ov.pending: - self._selector._defer(self, ov, EVENT_READ, ov.getresult) - raise BlockingIOError(errno.EAGAIN, 'try again') - else: - return self._sock.recv(length) - - def connect(self, address): - if self._state != 'unknown': - raise ValueError('socket is in state %r' % self._state) - - self._state = 'connecting' - BindLocal(self._fd, len(address)) - ov = Overlapped(0) - - try: - ov.ConnectEx(self._fd, address) - except OSError as e: - if e.winerror == 10022: - raise ConnectionRefusedError( - errno.ECONNREFUSED, e.strerror, None, e.winerror) - else: - raise - - def callback(): - try: - ov.getresult(False) - except OSError as e: - self._state = 'broken' - if e.winerror == 1225: - self._error = errno.ECONNREFUSED - else: - self._error = e.errno - raise - else: - self._state = 'connected' - - if ov.pending: - self._selector._defer(self, ov, EVENT_WRITE, callback) - raise BlockingIOError(errno.EINPROGRESS, 'connect in progress') - else: - callback() - - def getsockopt(self, level, optname, buflen=None): - if ((level, optname) == (socket.SOL_SOCKET, socket.SO_ERROR)): - if self._state == 'connecting': - return errno.EINPROGRESS - elif self._state == 'broken': - return self._error - if buflen is None: - return self._sock.getsockopt(level, optname) - else: - return self._sock.getsockopt(level, optname, buflen) - - def accept(self): - if self._state != 'listening': - raise ValueError('socket is in state %r' % self._state) - - res = self._result[EVENT_READ] - if res: - success, value = self._result[EVENT_READ] - self._result[EVENT_READ] = None - if success: - return value - else: - raise value - - if self._pending[EVENT_READ]: - raise BlockingIOError(errno.EAGAIN, 'try again') - - def callback(): - ov.getresult(False) - conn._sock.setsockopt( - socket.SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, self._fd) - conn._state = 'connected' - return conn, conn.getpeername() - - conn = socket.socket(self.family, self.type, self.proto) - conn = self._selector.wrap_socket(conn) - ov = Overlapped(0) - ov.AcceptEx(self._fd, conn.fileno()) - if ov.pending: - self._selector._defer(self, ov, EVENT_READ, callback) - raise BlockingIOError(errno.EAGAIN, 'try again') - else: - return callback() - - # XXX how do we deal with shutdown? - # XXX connect_ex, makefile, ...? - - -class IocpSelector(_BaseSelector): - - def __init__(self, *, concurrency=0xffffffff): - super().__init__() - self._iocp = CreateIoCompletionPort( - INVALID_HANDLE_VALUE, NULL, 0, concurrency) - self._address_to_info = {} - self._fd_to_fileobj = weakref.WeakValueDictionary() - - def wrap_socket(self, sock): - sock.setblocking(False) - return IocpSocket(self, sock) - - def _defer(self, sock, ov, flag, callback): - sock._pending[flag] = True - self._address_to_info[ov.address] = (sock, ov, flag, callback) - - def close(self): - super().close() - if self._iocp is not None: - try: - # cancel pending IO - for info in self._address_to_info.values(): - ov = info[1] - try: - ov.cancel() - except OSError as e: - # handle may have closed - pass # XXX check e.winerror - # wait for pending IO to stop - while self._address_to_info: - status = GetQueuedCompletionStatus(self._iocp, 1000) - if status is None: - continue - self._address_to_info.pop(status[3], None) - finally: - _winapi.CloseHandle(self._iocp) - self._iocp = None - - def select(self, timeout=None): - # XXX currently this is O(n) where n is number of registered fds - results = {} - for fd, key in self._fd_to_key.items(): - fileobj = self._fd_to_fileobj[fd] - if ((key.events & EVENT_READ) - and not fileobj._pending[EVENT_READ]): - results[fd] = results.get(fd, 0) | EVENT_READ - if ((key.events & (EVENT_WRITE|EVENT_CONNECT)) - and not fileobj._pending[EVENT_WRITE]): - results[fd] = results.get(fd, 0) | EVENT_WRITE | EVENT_CONNECT - - if results: - ms = 0 - elif timeout is None: - ms = INFINITE - elif timeout < 0: - raise ValueError('negative timeout') - else: - ms = int(timeout * 1000 + 0.5) - if ms >= INFINITE: - raise ValueError('timeout too big') - - while True: - status = GetQueuedCompletionStatus(self._iocp, ms) - if status is None: - break - try: - fobj, ov, flag, callback = self._address_to_info.pop(status[3]) - except KeyError: - continue - fobj._pending[flag] = False - try: - value = callback() - except OSError as e: - fobj._result[flag] = (False, e) - else: - fobj._result[flag] = (True, value) - key = self._fileobj_to_key.get(fobj) - if key and (key.events & flag): - results[fobj._fd] = results.get(fobj._fd, 0) | flag - ms = 0 - - tmp = [] - for fd, events in results.items(): - if events & EVENT_WRITE: - events |= EVENT_CONNECT - key = self._fd_to_key[fd] - tmp.append((key.fileobj, events, key.data)) - return tmp - - -def main(): - from .winsocketpair import socketpair - - selector = IocpSelector() - - # listen - listener = selector.wrap_socket(socket.socket()) - listener.bind(('127.0.0.1', 0)) - listener.listen(1) - - # connect - conn = selector.wrap_socket(socket.socket()) - try: - conn.connect(listener.getsockname()) - # conn.connect(('127.0.0.1', 7868)) - except BlockingIOError: - selector.register(conn, EVENT_WRITE) - res = selector.select(5) - # assert [(conn, EVENT_WRITE, None)] == res, res - error = conn.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - assert error == 0, error - selector.unregister(conn) - - # accept - selector.register(listener, EVENT_READ) - while True: - try: - a, addr = listener.accept() - break - except BlockingIOError: - res = selector.select(1) - assert [(listener, EVENT_READ, None)] == res - selector.unregister(listener) - - - selector.register(a, EVENT_WRITE) - selector.register(conn, EVENT_READ) - msgs = [b"hello"] * 100 - - while selector.registered_count() > 0: - for (f, event, data) in selector.select(): - if event & EVENT_READ: - try: - msg = f.recv(20) - except BlockingIOError: - print("READ BLOCKED") - else: - print("read %r" % msg) - if not msg: - print("UNREGISTER READER") - selector.unregister(f) - f.close() - if event & EVENT_WRITE: - try: - nbytes = f.send(msgs.pop()) - except BlockingIOError: - print("WRITE BLOCKED") - except IndexError: - print("UNREGISTER WRITER") - selector.unregister(f) - f.close() - else: - print("bytes sent %r" % nbytes) - - - a, b = socketpair() - a = selector.wrap_socket(a) - b = selector.wrap_socket(b) - selector.register(a, EVENT_READ) - selector.register(b, EVENT_WRITE) - - msg = b"x"*(1024*1024*16) - view = memoryview(msg) - res = [] - - while selector.registered_count() > 0: - for (f, event, data) in selector.select(): - if event & EVENT_READ: - try: - data = f.recv(8192) - except BlockingIOError: - print("READ BLOCKED") - else: - res.append(data) - if not data: - print("UNREGISTER READER") - selector.unregister(f) - f.close() - if event & EVENT_WRITE: - try: - nbytes = f.send(view) - except BlockingIOError: - print("WRITE BLOCKED") - else: - assert nbytes == len(view) - if nbytes == 0: - print("UNREGISTER WRITER") - selector.unregister(f) - f.close() - else: - view = view[nbytes:] - - print(len(msg), sum(len(frag) for frag in res)) - - selector.close() - - -if __name__ == '__main__': - main() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index ef2f6f80..1db6666e 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -72,56 +72,13 @@ def _raise_stop_error(): raise _StopError -class UnixEventLoop(events.EventLoop): - """Unix event loop. - - See events.EventLoop for API specification. - """ +class BaseEventLoop(events.EventLoop): - def __init__(self, selector=None): - super().__init__() - if selector is None: - # pick the best selector class for the platform - selector = selectors.Selector() - logging.debug('Using selector: %s', selector.__class__.__name__) - self._selector = selector + def __init__(self): self._ready = collections.deque() self._scheduled = [] self._default_executor = None self._signal_handlers = {} - self._make_self_pipe() - - def close(self): - if self._selector is not None: - self._selector.close() - self._selector = None - self._ssock.close() - self._csock.close() - - def _make_self_pipe(self): - # A self-socket, really. :-) - a, b = socketpair() - self._ssock = self._selector.wrap_socket(a) - self._csock = self._selector.wrap_socket(b) - self._ssock.setblocking(False) - self._csock.setblocking(False) - self.add_reader(self._ssock.fileno(), self._read_from_self) - - def _read_from_self(self): - try: - self._ssock.recv(1) - except socket.error as exc: - if exc.errno in _TRYAGAIN: - return - raise # Halp! - - def _write_to_self(self): - try: - self._csock.send(b'x') - except socket.error as exc: - if exc.errno in _TRYAGAIN: - return - raise # Halp! def run(self): """Run the event loop until nothing left to do or stop() called. @@ -248,16 +205,6 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handler - def wrap_future(self, future): - """XXX""" - if isinstance(future, futures.Future): - return future # Don't wrap our own type of Future. - new_future = futures.Future() - future.add_done_callback( - lambda future: - self.call_soon_threadsafe(new_future._copy_state, future)) - return new_future - def run_in_executor(self, executor, callback, *args): if isinstance(callback, events.Handler): assert not args @@ -300,7 +247,6 @@ def create_connection(self, protocol_factory, host, port, *, ssl=False, sock = None try: sock = socket.socket(family=family, type=type, proto=proto) - sock = self._selector.wrap_socket(sock) sock.setblocking(False) yield self.sock_connect(sock, address) except socket.error as exc: @@ -325,10 +271,10 @@ def create_connection(self, protocol_factory, host, port, *, ssl=False, waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = _UnixSslTransport(self, sock, protocol, sslcontext, + transport = self.SslTransport(self, sock, protocol, sslcontext, waiter) else: - transport = _UnixSocketTransport(self, sock, protocol, waiter) + transport = self.SocketTransport(self, sock, protocol, waiter) yield from waiter return transport, protocol @@ -349,7 +295,6 @@ def start_serving(self, protocol_factory, host, port, *, exceptions = [] for family, type, proto, cname, address in infos: sock = socket.socket(family=family, type=type, proto=proto) - sock = self._selector.wrap_socket(sock) try: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(address) @@ -362,9 +307,238 @@ def start_serving(self, protocol_factory, host, port, *, raise exceptions[0] sock.listen(backlog) sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + + +class UnixEventLoop(BaseEventLoop): + """Unix event loop. + + See events.EventLoop for API specification. + """ + + @staticmethod + def SocketTransport(event_loop, sock, protocol, waiter=None): + return _UnixSocketTransport(event_loop, sock, protocol, waiter) + + @staticmethod + def SslTransport(event_loop, rawsock, protocol, sslcontext, waiter): + return _UnixSslTransport(event_loop, rawsock, protocol, + sslcontext, waiter) + + def __init__(self, selector=None): + super().__init__() + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + self._ssock.close() + self._csock.close() + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _start_serving(self, protocol_factory, sock): self.add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock) - return sock def _accept_connection(self, protocol_factory, sock): try: @@ -597,181 +771,23 @@ def _sock_accept(self, fut, registered, sock): else: self.add_reader(fd, self._sock_accept, fut, True, sock) - def add_signal_handler(self, sig, callback, *args): - """Add a handler for a signal. UNIX only. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - self._check_signal(sig) - try: - # set_wakeup_fd() raises ValueError if this is not the - # main thread. By calling it early we ensure that an - # event loop running in another thread cannot add a signal - # handler. - signal.set_wakeup_fd(self._csock.fileno()) - except ValueError as exc: - raise RuntimeError(str(exc)) - handler = events.make_handler(None, callback, args) - self._signal_handlers[sig] = handler - try: - signal.signal(sig, self._handle_signal) - except OSError as exc: - del self._signal_handlers[sig] - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as nexc: - logging.info('set_wakeup_fd(-1) failed: %s', nexc) - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - return handler - - def _handle_signal(self, sig, arg): - """Internal helper that is the actual signal handler.""" - handler = self._signal_handlers.get(sig) - if handler is None: - return # Assume it's some race condition. - if handler.cancelled: - self.remove_signal_handler(sig) # Remove it properly. - else: - self.call_soon_threadsafe(handler.callback, *handler.args) - - def remove_signal_handler(self, sig): - """Remove a handler for a signal. UNIX only. - - Return True if a signal handler was removed, False if not.""" - self._check_signal(sig) - try: - del self._signal_handlers[sig] - except KeyError: - return False - if sig == signal.SIGINT: - handler = signal.default_int_handler - else: - handler = signal.SIG_DFL - try: - signal.signal(sig, handler) - except OSError as exc: - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as exc: - logging.info('set_wakeup_fd(-1) failed: %s', exc) - return True - - def _check_signal(self, sig): - """Internal helper to validate a signal. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) - if signal is None: - raise RuntimeError('Signals are not supported') - if not (1 <= sig < signal.NSIG): - raise ValueError('sig {} out of range(1, {})'.format(sig, - signal.NSIG)) - if sys.platform == 'win32': - raise RuntimeError('Signals are not really supported on Windows') - - def _add_callback(self, handler): - """Add a Handler to ready or scheduled.""" - if handler.cancelled: - return - if handler.when is None: - self._ready.append(handler) - else: - heapq.heappush(self._scheduled, handler) - - def _run_once(self, timeout=None): - """Run one full iteration of the event loop. - - This calls all currently ready callbacks, polls for I/O, - schedules the resulting callbacks, and finally schedules - 'call_later' callbacks. - """ - # TODO: Break each of these into smaller pieces. - # TODO: Refactor to separate the callbacks from the readers/writers. - # TODO: An alternative API would be to do the *minimal* amount - # of work, e.g. one callback or one I/O poll. - - # Remove delayed calls that were cancelled from head of queue. - while self._scheduled and self._scheduled[0].cancelled: - heapq.heappop(self._scheduled) - - # Inspect the poll queue. If there's exactly one selectable - # file descriptor, it's the self-pipe, and if there's nothing - # scheduled, we should ignore it. - if self._selector.registered_count() > 1 or self._scheduled: - if self._ready: - timeout = 0 - elif self._scheduled: - # Compute the desired timeout. - when = self._scheduled[0].when - deadline = max(0, when - time.monotonic()) - if timeout is None: - timeout = deadline + def _process_events(self, event_list): + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) else: - timeout = min(timeout, deadline) - - t0 = time.monotonic() - event_list = self._selector.select(timeout) - t1 = time.monotonic() - argstr = '' if timeout is None else ' %.3f' % timeout - if t1-t0 >= 1: - level = logging.INFO - else: - level = logging.DEBUG - logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - for fileobj, mask, (reader, writer, connector) in event_list: - if mask & selectors.EVENT_READ and reader is not None: - if reader.cancelled: - self.remove_reader(fileobj) - else: - self._add_callback(reader) - if mask & selectors.EVENT_WRITE and writer is not None: - if writer.cancelled: - self.remove_writer(fileobj) - else: - self._add_callback(writer) - elif mask & selectors.EVENT_CONNECT and connector is not None: - if connector.cancelled: - self.remove_connector(fileobj) - else: - self._add_callback(connector) - - # Handle 'later' callbacks that are ready. - now = time.monotonic() - while self._scheduled: - handler = self._scheduled[0] - if handler.when > now: - break - handler = heapq.heappop(self._scheduled) - self._ready.append(handler) - - # This is the only place where callbacks are actually *called*. - # All other places just add them to ready. - # Note: We run all currently scheduled callbacks, but not any - # callbacks scheduled by callbacks run this time around -- - # they will be run the next time (after another I/O poll). - # Use an idiom that is threadsafe without using locks. - ntodo = len(self._ready) - for i in range(ntodo): - handler = self._ready.popleft() - if not handler.cancelled: - try: - handler.callback(*handler.args) - except Exception: - logging.exception('Exception in callback %s %r', - handler.callback, handler.args) + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + elif mask & selectors.EVENT_CONNECT and connector is not None: + if connector.cancelled: + self.remove_connector(fileobj) + else: + self._add_callback(connector) class _UnixSocketTransport(transports.Transport): From 6a02d9349c7d238acb56498427010189da70eee2 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 25 Jan 2013 17:06:24 +0000 Subject: [PATCH 0264/1502] Merge. --- crawl.py | 10 +++---- runtests.py | 13 ++++++-- srv.py | 63 +++++++++++++++++++++++++++++++++++---- tulip/http_client.py | 2 +- tulip/http_client_test.py | 4 +-- tulip/tasks.py | 2 ++ 6 files changed, 78 insertions(+), 16 deletions(-) diff --git a/crawl.py b/crawl.py index 20527e70..b9881fcb 100755 --- a/crawl.py +++ b/crawl.py @@ -23,11 +23,11 @@ def __init__(self, rooturl): self.done = {} self.tasks = set() self.waiter = None - self.add(self.rooturl) # Set initial work. + self.addurl(self.rooturl, '') # Set initial work. self.run() # Kick off work. - def add(self, url): - url = urllib.parse.urljoin(self.rooturl, url) + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) url, frag = urllib.parse.urldefrag(url) if not url.startswith(self.rooturl): return False @@ -88,7 +88,7 @@ def process(self, url): if status[:3] in ('301', '302'): # Redirect. u = headers.get('location') or headers.get('uri') - if self.add(u): + if self.addurl(u, url): print(' ', url, status[:3], 'redirect to', u, end=END) elif status.startswith('200'): ctype = headers.get_content_type() @@ -101,7 +101,7 @@ def process(self, url): urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', line) for u in urls: - if self.add(u): + if self.addurl(u, url): print(' ', url, 'href to', u, end=END) ok = True finally: diff --git a/runtests.py b/runtests.py index 1eef652b..d18fe7d3 100644 --- a/runtests.py +++ b/runtests.py @@ -15,15 +15,24 @@ # Originally written by Beech Horn (for NDB). +import os import re import sys import unittest assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' +TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tulip') + + def load_tests(includes=(), excludes=()): - mods = ['events', 'futures', 'tasks'] - test_mods = ['%s_test' % name for name in mods] + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] + + if sys.platform == 'win32': + try: + test_mods.remove('subprocess_test') + except ValueError: + pass tulip = __import__('tulip', fromlist=test_mods) loader = unittest.TestLoader() diff --git a/srv.py b/srv.py index e49b8998..bae9fdd8 100644 --- a/srv.py +++ b/srv.py @@ -2,7 +2,7 @@ import email.message import email.parser -import gc +import os import re import tulip @@ -21,10 +21,32 @@ def __init__(self): def handle_request(self): line = yield from self.reader.readline() print('request line', line) - match = re.match(rb'GET (\S+) HTTP/(1.\d)\r?\n\Z', line) + match = re.match(rb'([A-Z]+) (\S+) HTTP/(1.\d)\r?\n\Z', line) if not match: self.transport.close() return + bmethod, bpath, bversion = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format(bmethod, bpath, bversion)) + try: + path = bpath.decode('ascii') + except UnicodeError as exc: + print('not ascii', repr(bpath), exc) + path = None + else: + if not (path.isprintable() and path.startswith('/')) or '/.' in path: + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return lines = [] while True: line = yield from self.reader.readline() @@ -36,10 +58,39 @@ def handle_request(self): break parser = email.parser.BytesHeaderParser() headers = parser.parsebytes(b''.join(lines)) - self.transport.write(b'HTTP/1.0 200 Ok\r\n' - b'Content-type: text/plain\r\n' - b'\r\n' - b'Hello world.\r\n') + write = self.transport.write + if isdir and not path.endswith('/'): + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError as exc: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError as exc: + write(b'Cannot open\r\n') self.transport.close() def connection_made(self, transport): diff --git a/tulip/http_client.py b/tulip/http_client.py index a7df6ff9..f01ee631 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -85,7 +85,7 @@ def readline(self): self.buffer.appendleft(tail) self.line_count -= 1 break - + line = b''.join(parts) self.byte_count -= len(line) diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index 6010c872..f712da94 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -7,7 +7,7 @@ from . import tasks -class StreamReaderTest(unittest.TestCase): +class StreamReaderTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' @@ -70,7 +70,7 @@ def read(): self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) - + if __name__ == '__main__': unittest.main() diff --git a/tulip/tasks.py b/tulip/tasks.py index 91c32076..af7eaaff 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -106,11 +106,13 @@ def _step(self, value=None, exc=None): super().cancel() else: self.set_exception(exc) + logging.exception('Exception in task') except BaseException as exc: if self._must_cancel: super().cancel() else: self.set_exception(exc) + logging.exception('BaseException in task') raise else: # XXX No check for self._must_cancel here? From dcce4ce5eb6f70c180dc01c6fc608eebc0f03539 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 25 Jan 2013 17:19:37 +0000 Subject: [PATCH 0265/1502] Disable subprocess_test on Windows. --- .hgeol | 2 + .hgignore | 8 + Makefile | 30 ++ NOTES | 130 +++++ README | 30 ++ TODO | 165 ++++++ check.py | 40 ++ crawl.py | 136 +++++ curl.py | 31 ++ old/Makefile | 16 + old/echoclt.py | 79 +++ old/echosvr.py | 60 +++ old/http_client.py | 78 +++ old/http_server.py | 68 +++ old/main.py | 134 +++++ old/p3time.py | 47 ++ old/polling.py | 535 ++++++++++++++++++ old/scheduling.py | 354 ++++++++++++ old/sockets.py | 348 ++++++++++++ old/transports.py | 496 +++++++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 +++ runtests.py | 82 +++ srv.py | 127 +++++ sslsrv.py | 56 ++ tulip/TODO | 26 + tulip/__init__.py | 14 + tulip/events.py | 287 ++++++++++ tulip/events_test.py | 455 ++++++++++++++++ tulip/futures.py | 240 +++++++++ tulip/futures_test.py | 210 ++++++++ tulip/http_client.py | 296 ++++++++++ tulip/http_client_test.py | 76 +++ tulip/protocols.py | 58 ++ tulip/selectors.py | 430 +++++++++++++++ tulip/subprocess_test.py | 48 ++ tulip/subprocess_transport.py | 133 +++++ tulip/tasks.py | 279 ++++++++++ tulip/tasks_test.py | 233 ++++++++ tulip/transports.py | 90 ++++ tulip/unix_events.py | 989 ++++++++++++++++++++++++++++++++++ tulip/winsocketpair.py | 30 ++ 42 files changed, 7039 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 runtests.py create mode 100644 srv.py create mode 100644 sslsrv.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/events.py create mode 100644 tulip/events_test.py create mode 100644 tulip/futures.py create mode 100644 tulip/futures_test.py create mode 100644 tulip/http_client.py create mode 100644 tulip/http_client_test.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selectors.py create mode 100644 tulip/subprocess_test.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/tasks_test.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..42309f0c --- /dev/null +++ b/.hgignore @@ -0,0 +1,8 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..d11e9716 --- /dev/null +++ b/Makefile @@ -0,0 +1,30 @@ +PYTHON=python3 +COVERAGE=coverage3 +NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +cov coverage: + $(COVERAGE) run --branch runtests.py -v $(FLAGS) + $(COVERAGE) html $(NONTESTS) + $(COVERAGE) report -m $(NONTESTS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf __pycache__ */__pycache__ + rm -f *.py[co] */*.py[co] + rm -f *~ */*~ + rm -f .*~ */.*~ + rm -f @* */@* + rm -f '#'*'#' */'#'*'#' + rm -f *.orig */*.orig + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..6f41578e --- /dev/null +++ b/NOTES @@ -0,0 +1,130 @@ +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..c1c86a54 --- /dev/null +++ b/README @@ -0,0 +1,30 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (after installing coverage3, see below): + - make coverage + +To install coverage3 (coverage.py for Python 3), you need: + - Distribute (http://packages.python.org/distribute/) + - Coverage (http://nedbatchelder.com/code/coverage/) + What worked for me: + - curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - + - cd coveragepy + - python3 setup.py install + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..b9559ef0 --- /dev/null +++ b/TODO @@ -0,0 +1,165 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Implement various lock styles a la threading.py. + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..f0aa9a66 --- /dev/null +++ b/check.py @@ -0,0 +1,40 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..b9881fcb --- /dev/null +++ b/crawl.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +from tulip import http_client + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..1a73c194 --- /dev/null +++ b/curl.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + main() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..d18fe7d3 --- /dev/null +++ b/runtests.py @@ -0,0 +1,82 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tulip.events_test.PolicyTests.testPolicy'. +""" + +# Originally written by Beech Horn (for NDB). + +import os +import re +import sys +import unittest + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tulip') + + +def load_tests(includes=(), excludes=()): + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] + + if sys.platform == 'win32': + try: + test_mods.remove('subprocess_test') + except ValueError: + pass + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) for pat in excludes)] + suite.addTests(tests) + + return suite + + +def main(): + excludes = [] + includes = [] + patterns = includes # A reference. + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes + elif arg and not arg.startswith('-'): + patterns.append(arg) + tests = load_tests(includes, excludes) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +if __name__ == '__main__': + main() diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..bae9fdd8 --- /dev/null +++ b/srv.py @@ -0,0 +1,127 @@ +"""Simple server written using an event loop.""" + +import email.message +import email.parser +import os +import re + +import tulip +from tulip.http_client import StreamReader + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + line = yield from self.reader.readline() + print('request line', line) + match = re.match(rb'([A-Z]+) (\S+) HTTP/(1.\d)\r?\n\Z', line) + if not match: + self.transport.close() + return + bmethod, bpath, bversion = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format(bmethod, bpath, bversion)) + try: + path = bpath.decode('ascii') + except UnicodeError as exc: + print('not ascii', repr(bpath), exc) + path = None + else: + if not (path.isprintable() and path.startswith('/')) or '/.' in path: + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return + lines = [] + while True: + line = yield from self.reader.readline() + print('header line', line) + if not line.strip(b' \t\r\n'): + break + lines.append(line) + if line == b'\r\n': + break + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(lines)) + write = self.transport.write + if isdir and not path.endswith('/'): + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError as exc: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError as exc: + write(b'Cannot open\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport._sock) + self.reader = StreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..185fe3fe --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,14 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .transports import * +from .protocols import * +from .tasks import * + +__all__ = (futures.__all__ + + events.__all__ + + transports.__all__ + + protocols.__all__ + + tasks.__all__) diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..f39ddb79 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,287 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'EventLoop', 'Handler', 'make_handler', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import threading + + +class Handler: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args): + self._when = when + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + return self._when <= other._when + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + return self._when >= other._when + + def __eq__(self, other): + return self._when == other._when + + +def make_handler(when, callback, args): + if isinstance(callback, Handler): + assert not args + assert when is None + return callback + return Handler(when, callback, args) + + +class EventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handlers for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handler. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + def add_connector(self, fd, callback, *args): + raise NotImplementedError + + def remove_connector(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, EventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + # TODO: Do something else for Windows. + from . import unix_events + return unix_events.UnixEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/events_test.py b/tulip/events_test.py new file mode 100644 index 00000000..76802579 --- /dev/null +++ b/tulip/events_test.py @@ -0,0 +1,455 @@ +"""Tests for events.py.""" + +import concurrent.futures +import gc +import os +import select +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest + +from . import events +from . import transports +from . import protocols +from . import selectors +from . import unix_events + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + self.selector = self.SELECTOR_CLASS() + self.event_loop = unix_events.UnixEventLoop(self.selector) + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + + def testRun(self): + el = events.get_event_loop() + el.run() # Returns immediately. + + def testCallLater(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + el.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallRepeatedly(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + el.call_repeatedly(0.03, callback, 'ho') + el.call_later(0.1, el.stop) + el.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def testCallSoon(self): + el = events.get_event_loop() + results = [] + def callback(arg1, arg2): + results.append((arg1, arg2)) + el.call_soon(callback, 'hello', 'world') + el.run() + self.assertEqual(results, [('hello', 'world')]) + + def testCallSoonWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(): + results.append('yeah') + handler = events.Handler(None, callback, ()) + self.assertEqual(el.call_soon(handler), handler) + el.run() + self.assertEqual(results, ['yeah']) + + def testCallSoonThreadsafe(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + def run(): + el.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testCallSoonThreadsafeWithHandler(self): + el = events.get_event_loop() + results = [] + def callback(arg): + results.append(arg) + handler = events.Handler(None, callback, ('hello',)) + def run(): + self.assertEqual(el.call_soon_threadsafe(handler), handler) + t = threading.Thread(target=run) + el.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + el.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def testWrapFuture(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = el.wrap_future(f1) + res = el.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def testRunInExecutor(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + f2 = el.run_in_executor(None, run, 'yo') + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testRunInExecutorWithHandler(self): + el = events.get_event_loop() + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(None, run, ('yo',)) + f2 = el.run_in_executor(None, handler) + res = el.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def testReaderCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(el.remove_reader(r.fileno())) + r.close() + el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testReaderCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(el.remove_reader(r.fileno())) + r.close() + handler = events.Handler(None, reader, ()) + self.assertEqual(el.add_reader(r.fileno(), handler), handler) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testReaderCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handler.cancel() + if not data: + r.close() + handler = el.add_reader(r.fileno(), reader) + el.call_later(0.05, w.send, b'abc') + el.call_later(0.1, w.send, b'def') + el.call_later(0.15, w.close) + el.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def testWriterCallback(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testWriterCallbackWithHandler(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + self.assertEqual(el.add_writer(w.fileno(), handler), handler) + def remove_writer(): + self.assertTrue(el.remove_writer(w.fileno())) + el.call_later(0.1, remove_writer) + el.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def testWriterCallbackCancel(self): + el = events.get_event_loop() + r, w = unix_events.socketpair() + w.setblocking(False) + def sender(): + w.send(b'x'*256) + handler.cancel() + handler = el.add_writer(w.fileno(), sender) + el.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def testSockClientOps(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + el.run_until_complete(el.sock_connect(sock, ('python.org', 80))) + el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = el.run_until_complete(el.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + + def testSockClientFail(self): + el = events.get_event_loop() + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + with self.assertRaises(ConnectionRefusedError): + el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) + sock.close() + + def testSockAccept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + el = events.get_event_loop() + f = el.sock_accept(listener) + conn, addr = el.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def testAddSignalHandler(self): + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + # Check error behavior first. + self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) + self.assertRaises(TypeError, el.remove_signal_handler, 'boom') + self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) + self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, 0) + self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) + self.assertRaises(ValueError, el.remove_signal_handler, -1) + self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + el.add_signal_handler(signal.SIGINT, my_handler) + el.run_once() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def testCancelSignalHandler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + el.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def testSignalHandlingWhileSelecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + el = events.get_event_loop() + handler = el.add_signal_handler(signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + el.call_later(0.15, el.stop) + el.run_forever() + self.assertEqual(caught, 1) + + def testCreateTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + el.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def testCreateSslTransport(self): + el = events.get_event_loop() + # TODO: This depends on xkcd.com behavior! + f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = el.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + el.run() + self.assertTrue(pr.nbytes > 0) + + def testStartServing(self): + el = events.get_event_loop() + f = el.start_serving(MyProto, '0.0.0.0', 0) + sock = el.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + el.run_once() # This is quite mysterious, but necessary. + el.run_once() + el.run_once() + sock.close() + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + +if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.KqueueSelector + + +if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.EpollSelector + + +if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.PollSelector + + +# Should always exist. +class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + SELECTOR_CLASS = selectors.SelectSelector + + +class HandlerTests(unittest.TestCase): + + def testHandler(self): + pass + + def testMakeHandler(self): + def callback(*args): + return args + h1 = events.Handler(None, callback, ()) + h2 = events.make_handler(None, h1, ()) + self.assertEqual(h1, h2) + + +class PolicyTests(unittest.TestCase): + + def testPolicy(self): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..e79999fc --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,240 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res +='<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + yield self # This tells Task to wait for completion. + return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py new file mode 100644 index 00000000..7834fec8 --- /dev/null +++ b/tulip/futures_test.py @@ -0,0 +1,210 @@ +"""Tests for futures.py.""" + +import unittest + +from . import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def testInitialState(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def testInitEventLoopPositional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def testCancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testResult(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testException(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def testYieldFromTwice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def testRepr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def testCopyState(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def testCallbacksInvokedOnSetResult(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def testCallbacksInvokedOnSetException(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def testRemoveDoneCallback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/http_client.py b/tulip/http_client.py new file mode 100644 index 00000000..f01ee631 --- /dev/null +++ b/tulip/http_client.py @@ -0,0 +1,296 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +import collections +import email.message +import email.parser +import re + +import tulip +from . import events +from . import futures +from . import tasks + + +# TODO: Move to another module. +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + # TODO: Limit line length for security. + while not self.line_count and not self.eof: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + continue + parts = [] + while self.buffer: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + parts.append(head) + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + break + + line = b''.join(parts) + self.byte_count -= len(line) + + return line + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + return (yield from self.read(n)) + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + f = p.connect() # Returns a Future + ...now what?... + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = events.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection(lambda: self, + self.host, + self.port, + ssl=self.ssl) + # TODO: A better mechanism to return all info from the + # status line, all headers, and the buffer, without having + # an N-tuple return value. + status_line = yield from self.stream.readline() + m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', + status_line) + if not m: + raise 'Invalid HTTP status line ({!r})'.format(status_line) + version, status, message = m.groups() + raw_headers = [] + while True: + header = yield from self.stream.readline() + if not header.strip(): + break + raw_headers.append(header) + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(raw_headers)) + content_length = headers.get('content-length') + if content_length: + content_length = int(content_length) # May raise. + if content_length is None: + stream = self.stream + else: + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. + body = yield from self.stream.readexactly(content_length) + stream = StreamReader() + stream.feed_data(body) + stream.feed_eof() + sts = '{} {}'.format(self.decode(status), self.decode(message)) + return (sts, headers, stream) + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, s): + if not s: + return + data = self.encode(s) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') + + def connection_made(self, transport): + self.transport = transport + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.write_str(line) + for key, value in self.headers.items(): + self.write_str('{}: {}\r\n'.format(key, value)) + self.transport.write(b'\r\n') + self.stream = StreamReader() + if self.make_body is not None: + if self.chunked: + self.make_body(self.write_chunked, self.write_chunked_eof) + else: + self.make_body(self.write_str, self.transport.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py new file mode 100644 index 00000000..f712da94 --- /dev/null +++ b/tulip/http_client_test.py @@ -0,0 +1,76 @@ +"""Tests for http_client.py.""" + +import unittest + +from . import events +from . import http_client +from . import tasks + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.event_loop = events.new_event_loop() + self.addCleanup(self.event_loop.close) + + events.set_event_loop(self.event_loop) + + def test_feed_empty_data(self): + stream = http_client.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.line_count) + self.assertEqual(0, stream.byte_count) + + def test_feed_data_line_byte_count(self): + stream = http_client.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + self.assertEqual(b'line1\n', line) + self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_read_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + @tasks.coroutine + def readline(): + line = yield from stream.readline() + return line + + line = self.event_loop.run_until_complete(tasks.Task(readline())) + + @tasks.coroutine + def read(): + line = yield from stream.read(7) + return line + + data = self.event_loop.run_until_complete(tasks.Task(read())) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + 1, stream.line_count) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..ad294f3a --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,58 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..05434630 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,430 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) +# connect event +EVENT_CONNECT = (1 << 2) + +# In most cases we treat EVENT_WRITE and EVENT_CONNECT as aliases for +# each other, and in fact we return both flags when a FD is found +# either writable or connectable. The distinction is necessary +# only for poll() on Windows. + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE|EVENT_CONNECT)): + raise ValueError("Invalid events: {}".format(events)) + + if events & (EVENT_WRITE|EVENT_CONNECT) == (EVENT_WRITE|EVENT_CONNECT): + raise ValueError("WRITE and CONNECT are mutually exclusive. " + "Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + self.register(fileobj, events, data) + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE|EVENT_CONNECT + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & (EVENT_WRITE|EVENT_CONNECT): + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py new file mode 100644 index 00000000..4eb24e41 --- /dev/null +++ b/tulip/subprocess_test.py @@ -0,0 +1,48 @@ +"""Tests for subprocess_transport.py.""" + +import unittest + +from . import events +from . import protocols +from . import subprocess_transport + + +class MyProto(protocols.Protocol): + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + def data_received(self, data): + print('received:', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testUnixSubprocess(self): + p = MyProto() + t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..721013f8 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,133 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..af7eaaff --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,279 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__() # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): # pragma: no cover + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + result.add_done_callback(self._wakeup) + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe(_wakeup, future)) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + if (not pending or + timeout != None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + bail = futures.Future() # Will always be cancelled eventually. + timeout_handler = None + debugstuff = locals() + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.cancel) + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + really_done = set(f for f in pending if f.done()) + if really_done: # pragma: no cover + # We don't expect this to ever happen. Or do we? + done.update(really_done) + pending.difference_update(really_done) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py new file mode 100644 index 00000000..a7381450 --- /dev/null +++ b/tulip/tasks_test.py @@ -0,0 +1,233 @@ +"""Tests for tasks.py.""" + +import time +import unittest + +from . import events +from . import futures +from . import tasks + + +class Dummy: + def __repr__(self): + return 'Dummy()' + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def testTaskClass(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + + def testTaskDecorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def testTaskRepr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def testTaskBasics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def testWait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def testWaitWithException(self): + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testWaitWithTimeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def testAsCompleted(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def testAsCompletedWithTimeout(self): + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def testSleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def testTaskCancelSleepingTask(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..4aaae3c7 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,90 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..fc2cb4a7 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,989 @@ +"""UNIX event loop and related classes. + +The event loop can be broken up into a selector (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a selector with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + +import collections +import concurrent.futures +import errno +import heapq +import logging +import select +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import futures +from . import protocols +from . import selectors +from . import tasks +from . import transports + +try: + from socket import socketpair +except ImportError: + assert sys.platform == 'win32' + from .winsocketpair import socketpair + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + +class UnixEventLoop(events.EventLoop): + """Unix event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._signal_handlers = {} + self._make_self_pipe() + + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + self._ssock.close() + self._csock.close() + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 0x7fffffff/1000.0 # 24 days + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, _raise_stop_error) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.make_handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(None, callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host, port, *, ssl=False, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = _UnixSslTransport(self, sock, protocol, sslcontext, + waiter) + else: + transport = _UnixSocketTransport(self, sock, protocol, waiter) + yield from waiter + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0, + backlog=100): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + sock.listen(backlog) + sock.setblocking(False) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + return sock + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + protocol = protocol_factory() + transport = _UnixSocketTransport(self, conn, protocol) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handler, None, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handler, writer, connector)) + + return handler + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer, connector)) + if reader is not None: + reader.cancel() + return True + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handler, None)) + else: + # Remove connector. + mask &= ~selectors.EVENT_CONNECT + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handler, None)) + return handler + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() + return True + + # NOTE: add_connector() and add_writer() are mutually exclusive. + # While you can independently manipulate readers and writers, + # adding a connector for a particular FD automatically removes the + # writer for that FD, and vice versa, and removing a writer or a + # connector actually removes both writer and connector. This is + # because in most cases writers and connectors use the same mode + # for the platform polling function; the distinction is only + # important for PollSelector() on Windows. + + def add_connector(self, fd, callback, *args): + """Add a connector callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_CONNECT, + (None, None, handler)) + else: + # Remove writer. + mask &= ~selectors.EVENT_WRITE + self._selector.modify(fd, mask | selectors.EVENT_CONNECT, + (reader, None, handler)) + return handler + + def remove_connector(self, fd): + """Remove a connector callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() + return True + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + self._sock_sendall(fut, False, sock, data) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + n = 0 + try: + if data: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_connector(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_connector(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + elif mask & selectors.EVENT_CONNECT and connector is not None: + if connector.cancelled: + self.remove_connector(fileobj) + else: + self._add_callback(connector) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + + +class _UnixSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _UnixSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if n < len(data): + self._buffer.append(data[n:]) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..87d54c91 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,30 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From 7d6fc4632247280e42e028d53562c16d2bd9b57a Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 25 Jan 2013 17:21:15 +0000 Subject: [PATCH 0266/1502] Dummy merge. From 52d3ebbea1a7652891b1e8022e02a6f32053b9cb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 25 Jan 2013 09:43:41 -0800 Subject: [PATCH 0267/1502] Logging tweaks for tests. --- runtests.py | 12 ++++++++++++ tulip/subprocess_test.py | 3 ++- tulip/tasks_test.py | 7 ++++++- tulip/test_utils.py | 21 +++++++++++++++++++++ 4 files changed, 41 insertions(+), 2 deletions(-) create mode 100644 tulip/test_utils.py diff --git a/runtests.py b/runtests.py index e2598c9a..d7de9743 100644 --- a/runtests.py +++ b/runtests.py @@ -15,6 +15,7 @@ # Originally written by Beech Horn (for NDB). +import logging import os import re import sys @@ -68,6 +69,17 @@ def main(): elif arg and not arg.startswith('-'): patterns.append(arg) tests = load_tests(includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) result = unittest.TextTestRunner(verbosity=v).run(tests) sys.exit(not result.wasSuccessful()) diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py index 4eb24e41..b4bcc26f 100644 --- a/tulip/subprocess_test.py +++ b/tulip/subprocess_test.py @@ -1,5 +1,6 @@ """Tests for subprocess_transport.py.""" +import logging import unittest from . import events @@ -17,7 +18,7 @@ def connection_made(self, transport): self.state = 'CONNECTED' transport.write_eof() def data_received(self, data): - print('received:', data) + logging.info('received: %r', data) assert self.state == 'CONNECTED', self.state self.nbytes += len(data) def eof_received(self): diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index a7381450..456cccfa 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -6,6 +6,7 @@ from . import events from . import futures from . import tasks +from . import test_utils class Dummy: @@ -15,14 +16,16 @@ def __call__(self, *args): pass -class TaskTests(unittest.TestCase): +class TaskTests(test_utils.LogTrackingTestCase): def setUp(self): + super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): self.event_loop.close() + super().tearDown() def testTaskClass(self): @tasks.coroutine @@ -100,6 +103,7 @@ def foo(): # TODO: Test different return_when values. def testWaitWithException(self): + self.suppress_log_errors() a = tasks.sleep(0.1) @tasks.coroutine def sleeper(): @@ -164,6 +168,7 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) def testAsCompletedWithTimeout(self): + self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') @tasks.coroutine diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..2d9e64bd --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,21 @@ +"""Utilities shared by tests.""" + +import logging +import unittest + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) From d521aed00732d4946d9c390dfe22f29955865da8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 25 Jan 2013 10:32:45 -0800 Subject: [PATCH 0268/1502] http_client.StreamReader tests --- tulip/http_client_test.py | 157 +++++++++++++++++++++++++++++++++----- tulip/test_utils.py | 12 +++ 2 files changed, 150 insertions(+), 19 deletions(-) diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index f712da94..65a8b69d 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -5,6 +5,7 @@ from . import events from . import http_client from . import tasks +from . import test_utils class StreamReaderTests(unittest.TestCase): @@ -31,38 +32,105 @@ def test_feed_data_line_byte_count(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) - def test_readline_line_byte_count(self): + @test_utils.sync + def test_read_zero(self): + """ Read zero bytes """ stream = http_client.StreamReader() stream.feed_data(self.DATA) - @tasks.coroutine - def readline(): - line = yield from stream.readline() - return line + data = yield from stream.read(0) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_read(self): + """ Read bytes """ + stream = http_client.StreamReader() + + res = stream.read(30) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) - line = self.event_loop.run_until_complete(tasks.Task(readline())) + data = yield from res + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_read_eof(self): + """ Read bytes, stop at eof """ + stream = http_client.StreamReader() + + read = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = yield from read + + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_read_until_eof(self): + """ Read all bytes until eof """ + stream = http_client.StreamReader() + + read = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = yield from read + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_readline(self): + """ Read one line """ + stream = http_client.StreamReader() + stream.feed_data(b'chunk1 ') + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = yield from stream.readline() + + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertFalse(stream.line_count) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + @test_utils.sync + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + line = yield from stream.readline() self.assertEqual(b'line1\n', line) self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + @test_utils.sync def test_readline_read_byte_count(self): stream = http_client.StreamReader() stream.feed_data(self.DATA) - @tasks.coroutine - def readline(): - line = yield from stream.readline() - return line - - line = self.event_loop.run_until_complete(tasks.Task(readline())) - - @tasks.coroutine - def read(): - line = yield from stream.read(7) - return line - - data = self.event_loop.run_until_complete(tasks.Task(read())) + line = yield from stream.readline() + data = yield from stream.read(7) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -71,6 +139,57 @@ def read(): len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) + @test_utils.sync + def test_readexactly_zero_or_less(self): + """ Read exact number of bytes (zero or less) """ + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + data = yield from stream.readexactly(0) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + data = yield from stream.readexactly(-1) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_readexactly(self): + """ Read exact number of bytes """ + stream = http_client.StreamReader() + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + n = 2*len(self.DATA) + data = yield from stream.readexactly(n) + + self.assertEqual(self.DATA+self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_readexactly_eof(self): + """ Read exact number of bytes (eof) """ + stream = http_client.StreamReader() + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + n = 2*len(self.DATA) + data = yield from stream.readexactly(n) + + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + if __name__ == '__main__': unittest.main() diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 2d9e64bd..65d0d421 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,8 +1,20 @@ """Utilities shared by tests.""" +import functools import logging import unittest +from . import events + +def sync(gen): + @functools.wraps(gen) + def wrapper(*args, **kw): + return events.get_event_loop().run_until_complete( + tasks.Task(gen(*args, **kw))) + + return wrapper + + class LogTrackingTestCase(unittest.TestCase): def setUp(self): From 2b30be372b5649420ee0d0c5f93d5cc6d73a883f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 25 Jan 2013 12:05:38 -0800 Subject: [PATCH 0269/1502] Fix missing import in test_utils.py. --- tulip/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 65d0d421..ac737c25 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -5,6 +5,7 @@ import unittest from . import events +from . import tasks def sync(gen): @functools.wraps(gen) From 2c69c899d50edf4fdfd5168102511d076407b9d8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 25 Jan 2013 13:20:54 -0800 Subject: [PATCH 0270/1502] Merge default -> iocp (mostly test improvements). --- runtests.py | 12 +++ tulip/http_client_test.py | 157 +++++++++++++++++++++++++++++++++----- tulip/subprocess_test.py | 3 +- tulip/tasks_test.py | 7 +- tulip/test_utils.py | 34 +++++++++ 5 files changed, 192 insertions(+), 21 deletions(-) create mode 100644 tulip/test_utils.py diff --git a/runtests.py b/runtests.py index d18fe7d3..275b23cb 100644 --- a/runtests.py +++ b/runtests.py @@ -15,6 +15,7 @@ # Originally written by Beech Horn (for NDB). +import logging import os import re import sys @@ -74,6 +75,17 @@ def main(): elif arg and not arg.startswith('-'): patterns.append(arg) tests = load_tests(includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) result = unittest.TextTestRunner(verbosity=v).run(tests) sys.exit(not result.wasSuccessful()) diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index f712da94..65a8b69d 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -5,6 +5,7 @@ from . import events from . import http_client from . import tasks +from . import test_utils class StreamReaderTests(unittest.TestCase): @@ -31,38 +32,105 @@ def test_feed_data_line_byte_count(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) - def test_readline_line_byte_count(self): + @test_utils.sync + def test_read_zero(self): + """ Read zero bytes """ stream = http_client.StreamReader() stream.feed_data(self.DATA) - @tasks.coroutine - def readline(): - line = yield from stream.readline() - return line + data = yield from stream.read(0) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_read(self): + """ Read bytes """ + stream = http_client.StreamReader() + + res = stream.read(30) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) - line = self.event_loop.run_until_complete(tasks.Task(readline())) + data = yield from res + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_read_eof(self): + """ Read bytes, stop at eof """ + stream = http_client.StreamReader() + + read = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = yield from read + + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_read_until_eof(self): + """ Read all bytes until eof """ + stream = http_client.StreamReader() + + read = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = yield from read + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + @test_utils.sync + def test_readline(self): + """ Read one line """ + stream = http_client.StreamReader() + stream.feed_data(b'chunk1 ') + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = yield from stream.readline() + + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertFalse(stream.line_count) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + @test_utils.sync + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + line = yield from stream.readline() self.assertEqual(b'line1\n', line) self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + @test_utils.sync def test_readline_read_byte_count(self): stream = http_client.StreamReader() stream.feed_data(self.DATA) - @tasks.coroutine - def readline(): - line = yield from stream.readline() - return line - - line = self.event_loop.run_until_complete(tasks.Task(readline())) - - @tasks.coroutine - def read(): - line = yield from stream.read(7) - return line - - data = self.event_loop.run_until_complete(tasks.Task(read())) + line = yield from stream.readline() + data = yield from stream.read(7) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -71,6 +139,57 @@ def read(): len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) + @test_utils.sync + def test_readexactly_zero_or_less(self): + """ Read exact number of bytes (zero or less) """ + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + data = yield from stream.readexactly(0) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + data = yield from stream.readexactly(-1) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_readexactly(self): + """ Read exact number of bytes """ + stream = http_client.StreamReader() + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + n = 2*len(self.DATA) + data = yield from stream.readexactly(n) + + self.assertEqual(self.DATA+self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + @test_utils.sync + def test_readexactly_eof(self): + """ Read exact number of bytes (eof) """ + stream = http_client.StreamReader() + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + n = 2*len(self.DATA) + data = yield from stream.readexactly(n) + + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + if __name__ == '__main__': unittest.main() diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py index 4eb24e41..b4bcc26f 100644 --- a/tulip/subprocess_test.py +++ b/tulip/subprocess_test.py @@ -1,5 +1,6 @@ """Tests for subprocess_transport.py.""" +import logging import unittest from . import events @@ -17,7 +18,7 @@ def connection_made(self, transport): self.state = 'CONNECTED' transport.write_eof() def data_received(self, data): - print('received:', data) + logging.info('received: %r', data) assert self.state == 'CONNECTED', self.state self.nbytes += len(data) def eof_received(self): diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index a7381450..456cccfa 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -6,6 +6,7 @@ from . import events from . import futures from . import tasks +from . import test_utils class Dummy: @@ -15,14 +16,16 @@ def __call__(self, *args): pass -class TaskTests(unittest.TestCase): +class TaskTests(test_utils.LogTrackingTestCase): def setUp(self): + super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): self.event_loop.close() + super().tearDown() def testTaskClass(self): @tasks.coroutine @@ -100,6 +103,7 @@ def foo(): # TODO: Test different return_when values. def testWaitWithException(self): + self.suppress_log_errors() a = tasks.sleep(0.1) @tasks.coroutine def sleeper(): @@ -164,6 +168,7 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) def testAsCompletedWithTimeout(self): + self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') @tasks.coroutine diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..ac737c25 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,34 @@ +"""Utilities shared by tests.""" + +import functools +import logging +import unittest + +from . import events +from . import tasks + +def sync(gen): + @functools.wraps(gen) + def wrapper(*args, **kw): + return events.get_event_loop().run_until_complete( + tasks.Task(gen(*args, **kw))) + + return wrapper + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) From 48651c6a96d94eb818500f3cadd3c95ca4183360 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 28 Jan 2013 15:39:21 +0000 Subject: [PATCH 0271/1502] Stop using SetFileCompletionNotificationModes() since not available on WinXP --- overlapped.c | 8 +- tulip/events_test.py | 12 +++ tulip/iocp_events.py | 231 +++++++++++++++++++++---------------------- 3 files changed, 127 insertions(+), 124 deletions(-) diff --git a/overlapped.c b/overlapped.c index a428cf2c..09ac4e68 100644 --- a/overlapped.c +++ b/overlapped.c @@ -122,11 +122,11 @@ PyDoc_STRVAR( static PyObject * overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) { - HANDLE CompletionPort; - DWORD Milliseconds; - DWORD NumberOfBytes; - ULONG_PTR CompletionKey; + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; DWORD err; BOOL ret; diff --git a/tulip/events_test.py b/tulip/events_test.py index 0a084800..1938e8e2 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -439,6 +439,18 @@ def create_event_loop(self): return iocp_events.IocpEventLoop() def testCreateSslTransport(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def testReaderCallback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def testReaderCallbackCancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def testReaderCallbackWithHandler(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def testWriterCallback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def testWriterCallbackCancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def testWriterCallbackWithHandler(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") # Should always exist. class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): diff --git a/tulip/iocp_events.py b/tulip/iocp_events.py index cfa09a23..8148ca27 100644 --- a/tulip/iocp_events.py +++ b/tulip/iocp_events.py @@ -8,7 +8,6 @@ import errno import logging import os -import heapq import sys import socket import time @@ -17,6 +16,7 @@ from _winapi import CloseHandle from . import transports +from . import events from .futures import Future from .unix_events import BaseEventLoop, _StopError @@ -92,19 +92,10 @@ def _register_with_iocp(self, obj): if obj not in self._registered: self._registered.add(obj) CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) - SetFileCompletionNotificationModes(obj.fileno(), - FILE_SKIP_COMPLETION_PORT_ON_SUCCESS) def _register(self, ov, obj, callback): f = Future() - if ov.error == ERROR_IO_PENDING: - # we must prevent ov and obj from being garbage collected - self._cache[ov.address] = (f, ov, obj, callback) - else: - try: - f.set_result(callback()) - except Exception as e: - f.set_exception(e) + self._cache[ov.address] = (f, ov, obj, callback) return f def _get_accept_socket(self): @@ -125,117 +116,35 @@ def _poll(self, timeout=None): status = GetQueuedCompletionStatus(self._iocp, ms) if status is None: return - f, ov, obj, callback = self._cache.pop(status[3]) + address = status[3] + f, ov, obj, callback = self._cache.pop(address) try: value = callback() except OSError as e: - if f is None: - sys.excepthook(*sys.exc_info()) - continue f.set_exception(e) self._results.append(f) else: - if f is None: - continue f.set_result(value) self._results.append(f) ms = 0 - def close(self, *, CloseHandle=CloseHandle): - for address, (f, ov, obj, callback) in list(self._cache.items()): + def close(self): + for (f, ov, obj, callback) in self._cache.values(): try: ov.cancel() except OSError: pass while self._cache: - status = GetQueuedCompletionStatus(self._iocp, 1000) - if status is None: + if not self._poll(1000): logging.debug('taking long time to close proactor') - continue - self._cache.pop(status[3]) + self._results = [] if self._iocp is not None: CloseHandle(self._iocp) self._iocp = None -class IocpEventLoop(BaseEventLoop): - - @staticmethod - def SocketTransport(*args, **kwds): - return _IocpSocketTransport(*args, **kwds) - - @staticmethod - def SslTransport(*args, **kwds): - raise NotImplementedError - - def __init__(self, proactor=None): - super().__init__() - if proactor is None: - proactor = IocpProactor() - logging.debug('Using proactor: %s', proactor.__class__.__name__) - self._proactor = proactor - self._selector = proactor # convenient alias - self._readers = {} - self._make_self_pipe() - - def close(self): - if self._proactor is not None: - self._proactor.close() - self._proactor = None - self._ssock.close() - self._csock.close() - - def _make_self_pipe(self): - # A self-socket, really. :-) - self._ssock, self._csock = socketpair() - self._ssock.setblocking(False) - self._csock.setblocking(False) - def loop(f=None): - if f and f.exception(): - self.close() - raise f.exception() - f = self._proactor.recv(self._ssock, 4096) - self.call_soon(f.add_done_callback, loop) - self.call_soon(loop) - - def _write_to_self(self): - self._proactor.send(self._csock, b'x') - - def _start_serving(self, protocol_factory, sock): - def loop(f=None): - try: - if f: - conn, addr = f.result() - protocol = protocol_factory() - transport = self.SocketTransport(self, conn, protocol) - f = self._proactor.accept(sock) - self.call_soon(f.add_done_callback, loop) - except OSError as exc: - if exc.errno in _TRYAGAIN: - self.call_soon(loop) - else: - sock.close() - logging.exception('Accept failed') - self.call_soon(loop) - - def sock_recv(self, sock, n): - return self._proactor.recv(sock, n) - - def sock_sendall(self, sock, data): - return self._proactor.send(sock, data) - - def sock_connect(self, sock, address): - return self._proactor.connect(sock, address) - - def sock_accept(self, sock): - return self._proactor.accept(sock) - - def _process_events(self, event_list): - pass # XXX hard work currently done in poll - - class _IocpSocketTransport(transports.Transport): def __init__(self, event_loop, sock, protocol, waiter=None): @@ -243,6 +152,7 @@ def __init__(self, event_loop, sock, protocol, waiter=None): self._sock = sock self._protocol = protocol self._buffer = [] + self._read_fut = None self._write_fut = None self._closing = False # Set when close() called. self._event_loop.call_soon(self._protocol.connection_made, self) @@ -252,42 +162,44 @@ def __init__(self, event_loop, sock, protocol, waiter=None): def _loop_reading(self, f=None): try: + assert f is self._read_fut if f: data = f.result() if not data: self._event_loop.call_soon(self._protocol.eof_received) + self._read_fut = None return self._event_loop.call_soon(self._protocol.data_received, data) - f = self._event_loop._proactor.recv(self._sock, 4096) - self._event_loop.call_soon( - f.add_done_callback, self._loop_reading) + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) except OSError as exc: self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) def write(self, data): assert isinstance(data, bytes) assert not self._closing if not data: return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() - if self._write_fut is not None: - self._buffer.append(data) - return - - def callback(f): - if f.exception(): - self._fatal_error(f.exception()) - # XXX should check for partial write + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() data = b''.join(self._buffer) - if data: - self._buffer = [] - self._write_fut = self._event_loop._proactor.send( - self._sock, data) - assert f is self._write_fut + self._buffer = [] + if not data: self._write_fut = None - - self._write_fut = self._event_loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(callback) + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) # TODO: write_eof(), can_write_eof(). @@ -305,8 +217,8 @@ def _fatal_error(self, exc): logging.exception('Fatal error for %s', self) if self._write_fut: self._write_fut.cancel() - # if self._read_fut: # XXX - # self._read_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() self._write_fut = self._read_fut = None self._buffer = [] self._event_loop.call_soon(self._call_connection_lost, exc) @@ -316,3 +228,82 @@ def _call_connection_lost(self, exc): self._protocol.connection_lost(exc) finally: self._sock.close() + + +class IocpEventLoop(BaseEventLoop): + + SocketTransport = _IocpSocketTransport + + @staticmethod + def SslTransport(*args, **kwds): + raise NotImplementedError + + def __init__(self, proactor=None): + super().__init__() + if proactor is None: + proactor = IocpProactor() + logging.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def close(self): + if self._proactor is not None: + self._proactor.close() + self._proactor = None + self._selector = None + self._ssock.close() + self._csock.close() + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + transport = self.SocketTransport(self, conn, protocol) + f = self._proactor.accept(sock) + except OSError as exc: + if exc.errno in _TRYAGAIN: + self.call_soon(loop) + else: + sock.close() + logging.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll From d7995049858ebdc72756743c356228c02d61f4ce Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 28 Jan 2013 07:50:49 -0800 Subject: [PATCH 0272/1502] Get rid of unneeded continue. --- tulip/http_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index f01ee631..0a03d81f 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -70,7 +70,6 @@ def readline(self): assert self.waiter is None self.waiter = futures.Future() yield from self.waiter - continue parts = [] while self.buffer: data = self.buffer.popleft() From 72c2880274dbda2fedf7ae88429fbfb998ea3bdd Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 28 Jan 2013 22:53:14 +0000 Subject: [PATCH 0273/1502] Ignore ERROR_CONNECTION_ABORTED if transport has been closed. Plus changes suggested by Guido. --- crawl.py | 3 + curl.py | 3 + tulip/iocp_events.py | 57 ++++++++------- tulip/selectors.py | 4 - tulip/unix_events.py | 171 +++++++++++++++++++++---------------------- 5 files changed, 121 insertions(+), 117 deletions(-) diff --git a/crawl.py b/crawl.py index b9881fcb..7272f18f 100755 --- a/crawl.py +++ b/crawl.py @@ -133,4 +133,7 @@ def main(): if __name__ == '__main__': + from tulip import events, iocp_events + el = iocp_events.IocpEventLoop() + events.set_event_loop(el) main() diff --git a/curl.py b/curl.py index 1a73c194..6620fa8c 100755 --- a/curl.py +++ b/curl.py @@ -28,4 +28,7 @@ def main(): if __name__ == '__main__': + from tulip import events, iocp_events + el = iocp_events.IocpEventLoop() + events.set_event_loop(el) main() diff --git a/tulip/iocp_events.py b/tulip/iocp_events.py index 8148ca27..8c63f03c 100644 --- a/tulip/iocp_events.py +++ b/tulip/iocp_events.py @@ -15,24 +15,26 @@ from _winapi import CloseHandle -from . import transports from . import events - -from .futures import Future -from .unix_events import BaseEventLoop, _StopError -from .winsocketpair import socketpair -from ._overlapped import * +from . import futures +from . import transports +from . import unix_events +from . import winsocketpair +from . import _overlapped -_TRYAGAIN = frozenset() # XXX +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 class IocpProactor: def __init__(self, concurrency=0xffffffff): self._results = [] - self._iocp = CreateIoCompletionPort( - INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) self._cache = {} self._registered = weakref.WeakSet() @@ -47,54 +49,56 @@ def select(self, timeout=None): def recv(self, conn, nbytes, flags=0): self._register_with_iocp(conn) - ov = Overlapped(NULL) + ov = _overlapped.Overlapped(NULL) ov.WSARecv(conn.fileno(), nbytes, flags) return self._register(ov, conn, ov.getresult) def send(self, conn, buf, flags=0): self._register_with_iocp(conn) - ov = Overlapped(NULL) + ov = _overlapped.Overlapped(NULL) ov.WSASend(conn.fileno(), buf, flags) return self._register(ov, conn, ov.getresult) def accept(self, listener): self._register_with_iocp(listener) conn = self._get_accept_socket() - ov = Overlapped(NULL) + ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(): addr = ov.getresult() conn.setsockopt(socket.SOL_SOCKET, - SO_UPDATE_ACCEPT_CONTEXT, listener.fileno()) + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + listener.fileno()) conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() return self._register(ov, listener, finish_accept) def connect(self, conn, address): self._register_with_iocp(conn) - BindLocal(conn.fileno(), len(address)) - ov = Overlapped(NULL) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) def finish_connect(): try: ov.getresult() except OSError as e: - if e.winerror == 1225: + if e.winerror == ERROR_CONNECTION_REFUSED: raise ConnectionRefusedError(errno.ECONNREFUSED, 'connection refused') raise conn.setsockopt(socket.SOL_SOCKET, - SO_UPDATE_CONNECT_CONTEXT, 0) + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) return conn return self._register(ov, conn, finish_connect) def _register_with_iocp(self, obj): if obj not in self._registered: self._registered.add(obj) - CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) def _register(self, ov, obj, callback): - f = Future() + f = futures.Future() self._cache[ov.address] = (f, ov, obj, callback) return f @@ -113,7 +117,7 @@ def _poll(self, timeout=None): if ms >= INFINITE: raise ValueError("timeout too big") while True: - status = GetQueuedCompletionStatus(self._iocp, ms) + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) if status is None: return address = status[3] @@ -172,6 +176,8 @@ def _loop_reading(self, f=None): self._event_loop.call_soon(self._protocol.data_received, data) self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) except OSError as exc: + if exc.winerror == ERROR_CONNECTION_ABORTED and self._closing: + return self._fatal_error(exc) else: self._read_fut.add_done_callback(self._loop_reading) @@ -230,7 +236,7 @@ def _call_connection_lost(self, exc): self._sock.close() -class IocpEventLoop(BaseEventLoop): +class IocpEventLoop(unix_events.BaseEventLoop): SocketTransport = _IocpSocketTransport @@ -269,7 +275,7 @@ def sock_accept(self, sock): def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = socketpair() + self._ssock, self._csock = winsocketpair.socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) def loop(f=None): @@ -296,11 +302,8 @@ def loop(f=None): transport = self.SocketTransport(self, conn, protocol) f = self._proactor.accept(sock) except OSError as exc: - if exc.errno in _TRYAGAIN: - self.call_soon(loop) - else: - sock.close() - logging.exception('Accept failed') + sock.close() + logging.exception('Accept failed') else: f.add_done_callback(loop) self.call_soon(loop) diff --git a/tulip/selectors.py b/tulip/selectors.py index 2ea92b38..05434630 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -210,10 +210,6 @@ def _key_from_fd(self, fd): logging.warn('No key found for fd %r', fd) return None - def wrap_socket(self, sock): - """Return sock or a wrapper for sock compatible with selector""" - return sock - class SelectSelector(_BaseSelector): """Select-based selector.""" diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 1db6666e..546e55d5 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -310,91 +310,6 @@ def start_serving(self, protocol_factory, host, port, *, self._start_serving(protocol_factory, sock) return sock - def add_signal_handler(self, sig, callback, *args): - """Add a handler for a signal. UNIX only. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - self._check_signal(sig) - try: - # set_wakeup_fd() raises ValueError if this is not the - # main thread. By calling it early we ensure that an - # event loop running in another thread cannot add a signal - # handler. - signal.set_wakeup_fd(self._csock.fileno()) - except ValueError as exc: - raise RuntimeError(str(exc)) - handler = events.make_handler(None, callback, args) - self._signal_handlers[sig] = handler - try: - signal.signal(sig, self._handle_signal) - except OSError as exc: - del self._signal_handlers[sig] - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as nexc: - logging.info('set_wakeup_fd(-1) failed: %s', nexc) - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - return handler - - def _handle_signal(self, sig, arg): - """Internal helper that is the actual signal handler.""" - handler = self._signal_handlers.get(sig) - if handler is None: - return # Assume it's some race condition. - if handler.cancelled: - self.remove_signal_handler(sig) # Remove it properly. - else: - self.call_soon_threadsafe(handler.callback, *handler.args) - - def remove_signal_handler(self, sig): - """Remove a handler for a signal. UNIX only. - - Return True if a signal handler was removed, False if not.""" - self._check_signal(sig) - try: - del self._signal_handlers[sig] - except KeyError: - return False - if sig == signal.SIGINT: - handler = signal.default_int_handler - else: - handler = signal.SIG_DFL - try: - signal.signal(sig, handler) - except OSError as exc: - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as exc: - logging.info('set_wakeup_fd(-1) failed: %s', exc) - return True - - def _check_signal(self, sig): - """Internal helper to validate a signal. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) - if signal is None: - raise RuntimeError('Signals are not supported') - if not (1 <= sig < signal.NSIG): - raise ValueError('sig {} out of range(1, {})'.format(sig, - signal.NSIG)) - if sys.platform == 'win32': - raise RuntimeError('Signals are not really supported on Windows') - def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" if handler.cancelled: @@ -789,6 +704,91 @@ def _process_events(self, event_list): else: self._add_callback(connector) + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') + class _UnixSocketTransport(transports.Transport): @@ -816,7 +816,6 @@ def _read_ready(self): self._event_loop.remove_reader(self._sock.fileno()) self._event_loop.call_soon(self._protocol.eof_received) - def write(self, data): assert isinstance(data, bytes) assert not self._closing From 56cb0b06f1d7636698b70dab39b26ce5ce0a4a99 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 28 Jan 2013 22:54:03 +0000 Subject: [PATCH 0274/1502] Merge. --- tulip/http_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index f01ee631..0a03d81f 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -70,7 +70,6 @@ def readline(self): assert self.waiter is None self.waiter = futures.Future() yield from self.waiter - continue parts = [] while self.buffer: data = self.buffer.popleft() From f99024a10970704b526345aa7b3220997dcd0760 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 29 Jan 2013 16:54:34 -0800 Subject: [PATCH 0275/1502] Use pep8 naming style for test names --- check.py | 59 +++++++++---------- runtests.py | 121 ++++++++++++++++++++------------------- srv.py | 14 +++-- tulip/events_test.py | 61 +++++++++++--------- tulip/futures_test.py | 22 +++---- tulip/subprocess_test.py | 7 ++- tulip/tasks_test.py | 22 +++---- 7 files changed, 162 insertions(+), 144 deletions(-) diff --git a/check.py b/check.py index f0aa9a66..64bc2cdd 100644 --- a/check.py +++ b/check.py @@ -3,38 +3,39 @@ import sys, os def main(): - args = sys.argv[1:] or os.curdir - for arg in args: - if os.path.isdir(arg): - for dn, dirs, files in os.walk(arg): - for fn in sorted(files): - if fn.endswith('.py'): - process(os.path.join(dn, fn)) - dirs[:] = [d for d in dirs if d[0] != '.'] - dirs.sort() - else: - process(arg) + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) def isascii(x): - try: - x.encode('ascii') - return True - except UnicodeError: - return False + try: + x.encode('ascii') + return True + except UnicodeError: + return False def process(fn): - try: - f = open(fn) - except IOError as err: - print(err) - return - try: - for i, line in enumerate(f): - line = line.rstrip('\n') - sline = line.rstrip() - if len(line) > 80 or line != sline or not isascii(line): - print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) - finally: - f.close() + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() main() diff --git a/runtests.py b/runtests.py index 275b23cb..6758c742 100644 --- a/runtests.py +++ b/runtests.py @@ -27,68 +27,71 @@ def load_tests(includes=(), excludes=()): - test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] - - if sys.platform == 'win32': - try: - test_mods.remove('subprocess_test') - except ValueError: - pass - tulip = __import__('tulip', fromlist=test_mods) - - loader = unittest.TestLoader() - suite = unittest.TestSuite() - - for mod in [getattr(tulip, name) for name in test_mods]: - for name in set(dir(mod)): - if name.endswith('Tests'): - test_module = getattr(mod, name) - tests = loader.loadTestsFromTestCase(test_module) - if includes: - tests = [test - for test in tests - if any(re.search(pat, test.id()) for pat in includes)] - if excludes: - tests = [test - for test in tests - if not any(re.search(pat, test.id()) for pat in excludes)] - suite.addTests(tests) - - return suite + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) + if f.endswith('_test.py')] + + if sys.platform == 'win32': + try: + test_mods.remove('subprocess_test') + except ValueError: + pass + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite def main(): - excludes = [] - includes = [] - patterns = includes # A reference. - v = 1 - for arg in sys.argv[1:]: - if arg.startswith('-v'): - v += arg.count('v') - elif arg == '-q': - v = 0 - elif arg == '-x': - if patterns is includes: - patterns = excludes - else: - patterns = includes - elif arg and not arg.startswith('-'): - patterns.append(arg) - tests = load_tests(includes, excludes) - logger = logging.getLogger() - if v == 0: - logger.setLevel(logging.CRITICAL) - elif v == 1: - logger.setLevel(logging.ERROR) - elif v == 2: - logger.setLevel(logging.WARNING) - elif v == 3: - logger.setLevel(logging.INFO) - elif v >= 4: - logger.setLevel(logging.DEBUG) - result = unittest.TextTestRunner(verbosity=v).run(tests) - sys.exit(not result.wasSuccessful()) + excludes = [] + includes = [] + patterns = includes # A reference. + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes + elif arg and not arg.startswith('-'): + patterns.append(arg) + tests = load_tests(includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) if __name__ == '__main__': - main() + main() diff --git a/srv.py b/srv.py index bae9fdd8..077279c1 100644 --- a/srv.py +++ b/srv.py @@ -26,14 +26,16 @@ def handle_request(self): self.transport.close() return bmethod, bpath, bversion = match.groups() - print('method = {!r}; path = {!r}; version = {!r}'.format(bmethod, bpath, bversion)) + print('method = {!r}; path = {!r}; version = {!r}'.format( + bmethod, bpath, bversion)) try: path = bpath.decode('ascii') except UnicodeError as exc: print('not ascii', repr(bpath), exc) path = None else: - if not (path.isprintable() and path.startswith('/')) or '/.' in path: + if (not (path.isprintable() and path.startswith('/')) or + '/.' in path): print('bad path', repr(path)) path = None else: @@ -69,7 +71,7 @@ def handle_request(self): if isdir: write(b'Content-type: text/html\r\n') else: - write(b'Content-type: text/plain\r\n') + write(b'Content-type: text/plain\r\n') write(b'\r\n') if isdir: write(b'') else: try: diff --git a/tulip/events_test.py b/tulip/events_test.py index 76802579..c0d53c22 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -23,21 +23,26 @@ class MyProto(protocols.Protocol): + def __init__(self): self.state = 'INITIAL' self.nbytes = 0 + def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) + def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' self.transport.close() + def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' @@ -54,11 +59,11 @@ def tearDown(self): self.event_loop.close() gc.collect() - def testRun(self): + def test_run(self): el = events.get_event_loop() el.run() # Returns immediately. - def testCallLater(self): + def test_call_later(self): el = events.get_event_loop() results = [] def callback(arg): @@ -70,7 +75,7 @@ def callback(arg): self.assertEqual(results, ['hello world']) self.assertTrue(t1-t0 >= 0.09) - def testCallRepeatedly(self): + def test_call_repeatedly(self): el = events.get_event_loop() results = [] def callback(arg): @@ -80,7 +85,7 @@ def callback(arg): el.run() self.assertEqual(results, ['ho', 'ho', 'ho']) - def testCallSoon(self): + def test_call_soon(self): el = events.get_event_loop() results = [] def callback(arg1, arg2): @@ -89,7 +94,7 @@ def callback(arg1, arg2): el.run() self.assertEqual(results, [('hello', 'world')]) - def testCallSoonWithHandler(self): + def test_call_soon_with_handler(self): el = events.get_event_loop() results = [] def callback(): @@ -99,7 +104,7 @@ def callback(): el.run() self.assertEqual(results, ['yeah']) - def testCallSoonThreadsafe(self): + def test_call_soon_threadsafe(self): el = events.get_event_loop() results = [] def callback(arg): @@ -116,7 +121,7 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testCallSoonThreadsafeWithHandler(self): + def test_call_soon_threadsafe_with_handler(self): el = events.get_event_loop() results = [] def callback(arg): @@ -134,7 +139,7 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testWrapFuture(self): + def test_wrap_future(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -145,7 +150,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'oi') - def testRunInExecutor(self): + def test_run_in_executor(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -154,7 +159,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'yo') - def testRunInExecutorWithHandler(self): + def test_run_in_executor_with_handler(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -164,7 +169,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'yo') - def testReaderCallback(self): + def test_reader_callback(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -187,7 +192,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testReaderCallbackWithHandler(self): + def test_reader_callback_with_handler(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -211,7 +216,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testReaderCallbackCancel(self): + def test_reader_callback_cancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -230,7 +235,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testWriterCallback(self): + def test_writer_callback(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -244,7 +249,7 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def testWriterCallbackWithHandler(self): + def test_writer_callback_with_handler(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -259,7 +264,7 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def testWriterCallbackCancel(self): + def test_writer_callback_cancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -273,7 +278,7 @@ def sender(): r.close() self.assertTrue(data == b'x'*256) - def testSockClientOps(self): + def test_sock_client_ops(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -284,7 +289,7 @@ def testSockClientOps(self): sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) - def testSockClientFail(self): + def test_sock_client_fail(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -293,7 +298,7 @@ def testSockClientFail(self): el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) sock.close() - def testSockAccept(self): + def test_sock_accept(self): listener = socket.socket() listener.setblocking(False) listener.bind(('127.0.0.1', 0)) @@ -311,7 +316,7 @@ def testSockAccept(self): listener.close() @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') - def testAddSignalHandler(self): + def test_add_signal_handler(self): caught = 0 def my_handler(): nonlocal caught @@ -345,7 +350,7 @@ def my_handler(): self.assertFalse(el.remove_signal_handler(signal.SIGINT)) @unittest.skipIf(sys.platform == 'win32', 'Unix only') - def testCancelSignalHandler(self): + def test_cancel_signal_handler(self): # Cancelling the handler should remove it (eventually). caught = 0 def my_handler(): @@ -359,7 +364,7 @@ def my_handler(): self.assertEqual(caught, 0) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') - def testSignalHandlingWhileSelecting(self): + def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. caught = 0 def my_handler(): @@ -372,7 +377,7 @@ def my_handler(): el.run_forever() self.assertEqual(caught, 1) - def testCreateTransport(self): + def test_create_transport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! f = el.create_connection(MyProto, 'xkcd.com', 80) @@ -383,7 +388,7 @@ def testCreateTransport(self): self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') - def testCreateSslTransport(self): + def test_create_ssl_transport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) @@ -394,7 +399,7 @@ def testCreateSslTransport(self): el.run() self.assertTrue(pr.nbytes > 0) - def testStartServing(self): + def test_start_serving(self): el = events.get_event_loop() f = el.start_serving(MyProto, '0.0.0.0', 0) sock = el.run_until_complete(f) @@ -434,10 +439,10 @@ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): class HandlerTests(unittest.TestCase): - def testHandler(self): + def test_handler(self): pass - def testMakeHandler(self): + def test_make_handler(self): def callback(*args): return args h1 = events.Handler(None, callback, ()) @@ -447,7 +452,7 @@ def callback(*args): class PolicyTests(unittest.TestCase): - def testPolicy(self): + def test_policy(self): pass diff --git a/tulip/futures_test.py b/tulip/futures_test.py index 7834fec8..2610b5be 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -11,17 +11,17 @@ def _fakefunc(f): class FutureTests(unittest.TestCase): - def testInitialState(self): + def test_initial_state(self): f = futures.Future() self.assertFalse(f.cancelled()) self.assertFalse(f.running()) self.assertFalse(f.done()) - def testInitEventLoopPositional(self): + def test_init_event_loop_positional(self): # Make sure Future does't accept a positional argument self.assertRaises(TypeError, futures.Future, 42) - def testCancel(self): + def test_cancel(self): f = futures.Future() self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) @@ -33,7 +33,7 @@ def testCancel(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testResult(self): + def test_result(self): f = futures.Future() self.assertRaises(futures.InvalidStateError, f.result) self.assertRaises(futures.InvalidTimeoutError, f.result, 10) @@ -48,7 +48,7 @@ def testResult(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testException(self): + def test_exception(self): exc = RuntimeError() f = futures.Future() self.assertRaises(futures.InvalidStateError, f.exception) @@ -64,7 +64,7 @@ def testException(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testYieldFromTwice(self): + def test_yield_from_twice(self): f = futures.Future() def fixture(): yield 'A' @@ -80,7 +80,7 @@ def fixture(): # The second "yield from f" does not yield f. self.assertEqual(next(g), ('C', 42)) # yield 'C', y. - def testRepr(self): + def test_repr(self): f_pending = futures.Future() self.assertEqual(repr(f_pending), 'Future') @@ -108,7 +108,7 @@ def testRepr(self): self.assertIn('Future', r) - def testCopyState(self): + def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. f = futures.Future() @@ -153,7 +153,7 @@ def bag_appender(future): def _new_future(self): return futures.Future(event_loop=_FakeEventLoop()) - def testCallbacksInvokedOnSetResult(self): + def test_callbacks_invoked_on_set_result(self): bag = [] f = self._new_future() f.add_done_callback(self._make_callback(bag, 42)) @@ -164,7 +164,7 @@ def testCallbacksInvokedOnSetResult(self): self.assertEqual(bag, [42, 17]) self.assertEqual(f.result(), 'foo') - def testCallbacksInvokedOnSetException(self): + def test_callbacks_invoked_on_set_exception(self): bag = [] f = self._new_future() f.add_done_callback(self._make_callback(bag, 100)) @@ -175,7 +175,7 @@ def testCallbacksInvokedOnSetException(self): self.assertEqual(bag, [100]) self.assertEqual(f.exception(), exc) - def testRemoveDoneCallback(self): + def test_remove_done_callback(self): bag = [] f = self._new_future() cb1 = self._make_callback(bag, 1) diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py index b4bcc26f..3d996f6a 100644 --- a/tulip/subprocess_test.py +++ b/tulip/subprocess_test.py @@ -9,22 +9,27 @@ class MyProto(protocols.Protocol): + def __init__(self): self.state = 'INITIAL' self.nbytes = 0 + def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' transport.write_eof() + def data_received(self, data): logging.info('received: %r', data) assert self.state == 'CONNECTED', self.state self.nbytes += len(data) + def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' self.transport.close() + def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' @@ -39,7 +44,7 @@ def setUp(self): def tearDown(self): self.event_loop.close() - def testUnixSubprocess(self): + def test_unix_subprocess(self): p = MyProto() t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) self.event_loop.run() diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 456cccfa..bc6fa531 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -27,7 +27,7 @@ def tearDown(self): self.event_loop.close() super().tearDown() - def testTaskClass(self): + def test_task_class(self): @tasks.coroutine def notmuch(): yield from [] @@ -37,7 +37,7 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') - def testTaskDecorator(self): + def test_task_decorator(self): @tasks.task def notmuch(): yield from [] @@ -47,7 +47,7 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') - def testTaskRepr(self): + def test_task_repr(self): @tasks.task def notmuch(): yield from [] @@ -64,7 +64,7 @@ def notmuch(): self.event_loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") - def testTaskBasics(self): + def test_task_basics(self): @tasks.task def outer(): a = yield from inner1() @@ -81,7 +81,7 @@ def inner2(): t = outer() self.assertEqual(self.event_loop.run_until_complete(t), 1042) - def testWait(self): + def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) @tasks.coroutine @@ -102,7 +102,7 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. - def testWaitWithException(self): + def test_wait_with_exception(self): self.suppress_log_errors() a = tasks.sleep(0.1) @tasks.coroutine @@ -126,7 +126,7 @@ def foo(): t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) - def testWaitWithTimeout(self): + def test_wait_with_timeout(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) @tasks.coroutine @@ -140,7 +140,7 @@ def foo(): self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) - def testAsCompleted(self): + def test_as_completed(self): @tasks.coroutine def sleeper(dt, x): yield from tasks.sleep(dt) @@ -167,7 +167,7 @@ def foo(): t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) - def testAsCompletedWithTimeout(self): + def test_as_completed_with_timeout(self): self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') @@ -190,7 +190,7 @@ def foo(): self.assertEqual(res[1][0], 2) self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) - def testSleep(self): + def test_sleep(self): @tasks.coroutine def sleeper(dt, arg): yield from tasks.sleep(dt/2) @@ -204,7 +204,7 @@ def sleeper(dt, arg): self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') - def testTaskCancelSleepingTask(self): + def test_task_cancel_sleeping_task(self): sleepfut = None @tasks.task def sleep(dt): From 159c6080bafa1969cdc3c03103d36be897f3a2f6 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 11:15:43 +0000 Subject: [PATCH 0276/1502] In curl.py and crawl.py, only use IocpEventLoop if available. --- crawl.py | 10 +++++++--- curl.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/crawl.py b/crawl.py index 7272f18f..8d404f55 100755 --- a/crawl.py +++ b/crawl.py @@ -133,7 +133,11 @@ def main(): if __name__ == '__main__': - from tulip import events, iocp_events - el = iocp_events.IocpEventLoop() - events.set_event_loop(el) + try: + from tulip import events, iocp_events + except ImportError: + pass + else: + el = iocp_events.IocpEventLoop() + events.set_event_loop(el) main() diff --git a/curl.py b/curl.py index 6620fa8c..af3725e8 100755 --- a/curl.py +++ b/curl.py @@ -28,7 +28,11 @@ def main(): if __name__ == '__main__': - from tulip import events, iocp_events - el = iocp_events.IocpEventLoop() - events.set_event_loop(el) + try: + from tulip import events, iocp_events + except ImportError: + pass + else: + el = iocp_events.IocpEventLoop() + events.set_event_loop(el) main() From 3a2a1c7bd6902f6288195777206454fba131b39b Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 11:46:45 +0000 Subject: [PATCH 0277/1502] Merge. --- check.py | 59 +++++++++---------- runtests.py | 121 ++++++++++++++++++++------------------- srv.py | 14 +++-- tulip/events_test.py | 75 +++++++++++++----------- tulip/futures_test.py | 22 +++---- tulip/subprocess_test.py | 7 ++- tulip/tasks_test.py | 22 +++---- 7 files changed, 169 insertions(+), 151 deletions(-) diff --git a/check.py b/check.py index f0aa9a66..64bc2cdd 100644 --- a/check.py +++ b/check.py @@ -3,38 +3,39 @@ import sys, os def main(): - args = sys.argv[1:] or os.curdir - for arg in args: - if os.path.isdir(arg): - for dn, dirs, files in os.walk(arg): - for fn in sorted(files): - if fn.endswith('.py'): - process(os.path.join(dn, fn)) - dirs[:] = [d for d in dirs if d[0] != '.'] - dirs.sort() - else: - process(arg) + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) def isascii(x): - try: - x.encode('ascii') - return True - except UnicodeError: - return False + try: + x.encode('ascii') + return True + except UnicodeError: + return False def process(fn): - try: - f = open(fn) - except IOError as err: - print(err) - return - try: - for i, line in enumerate(f): - line = line.rstrip('\n') - sline = line.rstrip() - if len(line) > 80 or line != sline or not isascii(line): - print('%s:%d:%s%s' % (fn, i+1, sline, '_' * (len(line) - len(sline)))) - finally: - f.close() + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() main() diff --git a/runtests.py b/runtests.py index 275b23cb..6758c742 100644 --- a/runtests.py +++ b/runtests.py @@ -27,68 +27,71 @@ def load_tests(includes=(), excludes=()): - test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] - - if sys.platform == 'win32': - try: - test_mods.remove('subprocess_test') - except ValueError: - pass - tulip = __import__('tulip', fromlist=test_mods) - - loader = unittest.TestLoader() - suite = unittest.TestSuite() - - for mod in [getattr(tulip, name) for name in test_mods]: - for name in set(dir(mod)): - if name.endswith('Tests'): - test_module = getattr(mod, name) - tests = loader.loadTestsFromTestCase(test_module) - if includes: - tests = [test - for test in tests - if any(re.search(pat, test.id()) for pat in includes)] - if excludes: - tests = [test - for test in tests - if not any(re.search(pat, test.id()) for pat in excludes)] - suite.addTests(tests) - - return suite + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) + if f.endswith('_test.py')] + + if sys.platform == 'win32': + try: + test_mods.remove('subprocess_test') + except ValueError: + pass + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite def main(): - excludes = [] - includes = [] - patterns = includes # A reference. - v = 1 - for arg in sys.argv[1:]: - if arg.startswith('-v'): - v += arg.count('v') - elif arg == '-q': - v = 0 - elif arg == '-x': - if patterns is includes: - patterns = excludes - else: - patterns = includes - elif arg and not arg.startswith('-'): - patterns.append(arg) - tests = load_tests(includes, excludes) - logger = logging.getLogger() - if v == 0: - logger.setLevel(logging.CRITICAL) - elif v == 1: - logger.setLevel(logging.ERROR) - elif v == 2: - logger.setLevel(logging.WARNING) - elif v == 3: - logger.setLevel(logging.INFO) - elif v >= 4: - logger.setLevel(logging.DEBUG) - result = unittest.TextTestRunner(verbosity=v).run(tests) - sys.exit(not result.wasSuccessful()) + excludes = [] + includes = [] + patterns = includes # A reference. + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes + elif arg and not arg.startswith('-'): + patterns.append(arg) + tests = load_tests(includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) if __name__ == '__main__': - main() + main() diff --git a/srv.py b/srv.py index bae9fdd8..077279c1 100644 --- a/srv.py +++ b/srv.py @@ -26,14 +26,16 @@ def handle_request(self): self.transport.close() return bmethod, bpath, bversion = match.groups() - print('method = {!r}; path = {!r}; version = {!r}'.format(bmethod, bpath, bversion)) + print('method = {!r}; path = {!r}; version = {!r}'.format( + bmethod, bpath, bversion)) try: path = bpath.decode('ascii') except UnicodeError as exc: print('not ascii', repr(bpath), exc) path = None else: - if not (path.isprintable() and path.startswith('/')) or '/.' in path: + if (not (path.isprintable() and path.startswith('/')) or + '/.' in path): print('bad path', repr(path)) path = None else: @@ -69,7 +71,7 @@ def handle_request(self): if isdir: write(b'Content-type: text/html\r\n') else: - write(b'Content-type: text/plain\r\n') + write(b'Content-type: text/plain\r\n') write(b'\r\n') if isdir: write(b'') else: try: diff --git a/tulip/events_test.py b/tulip/events_test.py index 1938e8e2..c9b11d4c 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -23,21 +23,26 @@ class MyProto(protocols.Protocol): + def __init__(self): self.state = 'INITIAL' self.nbytes = 0 + def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) + def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' self.transport.close() + def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' @@ -53,11 +58,11 @@ def tearDown(self): self.event_loop.close() gc.collect() - def testRun(self): + def test_run(self): el = events.get_event_loop() el.run() # Returns immediately. - def testCallLater(self): + def test_call_later(self): el = events.get_event_loop() results = [] def callback(arg): @@ -69,7 +74,7 @@ def callback(arg): self.assertEqual(results, ['hello world']) self.assertTrue(t1-t0 >= 0.09) - def testCallRepeatedly(self): + def test_call_repeatedly(self): el = events.get_event_loop() results = [] def callback(arg): @@ -79,7 +84,7 @@ def callback(arg): el.run() self.assertEqual(results, ['ho', 'ho', 'ho']) - def testCallSoon(self): + def test_call_soon(self): el = events.get_event_loop() results = [] def callback(arg1, arg2): @@ -88,7 +93,7 @@ def callback(arg1, arg2): el.run() self.assertEqual(results, [('hello', 'world')]) - def testCallSoonWithHandler(self): + def test_call_soon_with_handler(self): el = events.get_event_loop() results = [] def callback(): @@ -98,7 +103,7 @@ def callback(): el.run() self.assertEqual(results, ['yeah']) - def testCallSoonThreadsafe(self): + def test_call_soon_threadsafe(self): el = events.get_event_loop() results = [] def callback(arg): @@ -115,7 +120,7 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testCallSoonThreadsafeWithHandler(self): + def test_call_soon_threadsafe_with_handler(self): el = events.get_event_loop() results = [] def callback(arg): @@ -133,7 +138,7 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) - def testWrapFuture(self): + def test_wrap_future(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -144,7 +149,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'oi') - def testRunInExecutor(self): + def test_run_in_executor(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -153,7 +158,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'yo') - def testRunInExecutorWithHandler(self): + def test_run_in_executor_with_handler(self): el = events.get_event_loop() def run(arg): time.sleep(0.1) @@ -163,7 +168,7 @@ def run(arg): res = el.run_until_complete(f2) self.assertEqual(res, 'yo') - def testReaderCallback(self): + def test_reader_callback(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -186,7 +191,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testReaderCallbackWithHandler(self): + def test_reader_callback_with_handler(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -210,7 +215,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testReaderCallbackCancel(self): + def test_reader_callback_cancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] @@ -232,7 +237,7 @@ def reader(): el.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def testWriterCallback(self): + def test_writer_callback(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -246,7 +251,7 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def testWriterCallbackWithHandler(self): + def test_writer_callback_with_handler(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -261,7 +266,7 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def testWriterCallbackCancel(self): + def test_writer_callback_cancel(self): el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) @@ -275,7 +280,7 @@ def sender(): r.close() self.assertTrue(data == b'x'*256) - def testSockClientOps(self): + def test_sock_client_ops(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -287,7 +292,7 @@ def testSockClientOps(self): sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) - def testSockClientFail(self): + def test_sock_client_fail(self): el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) @@ -297,7 +302,7 @@ def testSockClientFail(self): el.run_until_complete(el.sock_connect(sock, address)) sock.close() - def testSockAccept(self): + def test_sock_accept(self): el = events.get_event_loop() listener = socket.socket() listener.setblocking(False) @@ -315,7 +320,7 @@ def testSockAccept(self): listener.close() @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') - def testAddSignalHandler(self): + def test_add_signal_handler(self): caught = 0 def my_handler(): nonlocal caught @@ -349,7 +354,7 @@ def my_handler(): self.assertFalse(el.remove_signal_handler(signal.SIGINT)) @unittest.skipIf(sys.platform == 'win32', 'Unix only') - def testCancelSignalHandler(self): + def test_cancel_signal_handler(self): # Cancelling the handler should remove it (eventually). caught = 0 def my_handler(): @@ -363,7 +368,7 @@ def my_handler(): self.assertEqual(caught, 0) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') - def testSignalHandlingWhileSelecting(self): + def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. caught = 0 def my_handler(): @@ -376,7 +381,7 @@ def my_handler(): el.run_forever() self.assertEqual(caught, 1) - def testCreateTransport(self): + def test_create_transport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! f = el.create_connection(MyProto, 'xkcd.com', 80) @@ -387,7 +392,7 @@ def testCreateTransport(self): self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') - def testCreateSslTransport(self): + def test_create_ssl_transport(self): el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) @@ -398,7 +403,7 @@ def testCreateSslTransport(self): el.run() self.assertTrue(pr.nbytes > 0) - def testStartServing(self): + def test_start_serving(self): el = events.get_event_loop() f = el.start_serving(MyProto, '0.0.0.0', 0) sock = el.run_until_complete(f) @@ -437,19 +442,19 @@ def create_event_loop(self): class IocpEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): return iocp_events.IocpEventLoop() - def testCreateSslTransport(self): + def test_create_ssl_transport(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") - def testReaderCallback(self): + def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") - def testReaderCallbackCancel(self): + def test_reader_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") - def testReaderCallbackWithHandler(self): + def test_reader_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") - def testWriterCallback(self): + def test_writer_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def testWriterCallbackCancel(self): + def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def testWriterCallbackWithHandler(self): + def test_writer_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") # Should always exist. @@ -460,10 +465,10 @@ def create_event_loop(self): class HandlerTests(unittest.TestCase): - def testHandler(self): + def test_handler(self): pass - def testMakeHandler(self): + def test_make_handler(self): def callback(*args): return args h1 = events.Handler(None, callback, ()) @@ -473,7 +478,7 @@ def callback(*args): class PolicyTests(unittest.TestCase): - def testPolicy(self): + def test_policy(self): pass diff --git a/tulip/futures_test.py b/tulip/futures_test.py index 7834fec8..2610b5be 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -11,17 +11,17 @@ def _fakefunc(f): class FutureTests(unittest.TestCase): - def testInitialState(self): + def test_initial_state(self): f = futures.Future() self.assertFalse(f.cancelled()) self.assertFalse(f.running()) self.assertFalse(f.done()) - def testInitEventLoopPositional(self): + def test_init_event_loop_positional(self): # Make sure Future does't accept a positional argument self.assertRaises(TypeError, futures.Future, 42) - def testCancel(self): + def test_cancel(self): f = futures.Future() self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) @@ -33,7 +33,7 @@ def testCancel(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testResult(self): + def test_result(self): f = futures.Future() self.assertRaises(futures.InvalidStateError, f.result) self.assertRaises(futures.InvalidTimeoutError, f.result, 10) @@ -48,7 +48,7 @@ def testResult(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testException(self): + def test_exception(self): exc = RuntimeError() f = futures.Future() self.assertRaises(futures.InvalidStateError, f.exception) @@ -64,7 +64,7 @@ def testException(self): self.assertRaises(futures.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) - def testYieldFromTwice(self): + def test_yield_from_twice(self): f = futures.Future() def fixture(): yield 'A' @@ -80,7 +80,7 @@ def fixture(): # The second "yield from f" does not yield f. self.assertEqual(next(g), ('C', 42)) # yield 'C', y. - def testRepr(self): + def test_repr(self): f_pending = futures.Future() self.assertEqual(repr(f_pending), 'Future') @@ -108,7 +108,7 @@ def testRepr(self): self.assertIn('Future', r) - def testCopyState(self): + def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. f = futures.Future() @@ -153,7 +153,7 @@ def bag_appender(future): def _new_future(self): return futures.Future(event_loop=_FakeEventLoop()) - def testCallbacksInvokedOnSetResult(self): + def test_callbacks_invoked_on_set_result(self): bag = [] f = self._new_future() f.add_done_callback(self._make_callback(bag, 42)) @@ -164,7 +164,7 @@ def testCallbacksInvokedOnSetResult(self): self.assertEqual(bag, [42, 17]) self.assertEqual(f.result(), 'foo') - def testCallbacksInvokedOnSetException(self): + def test_callbacks_invoked_on_set_exception(self): bag = [] f = self._new_future() f.add_done_callback(self._make_callback(bag, 100)) @@ -175,7 +175,7 @@ def testCallbacksInvokedOnSetException(self): self.assertEqual(bag, [100]) self.assertEqual(f.exception(), exc) - def testRemoveDoneCallback(self): + def test_remove_done_callback(self): bag = [] f = self._new_future() cb1 = self._make_callback(bag, 1) diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py index b4bcc26f..3d996f6a 100644 --- a/tulip/subprocess_test.py +++ b/tulip/subprocess_test.py @@ -9,22 +9,27 @@ class MyProto(protocols.Protocol): + def __init__(self): self.state = 'INITIAL' self.nbytes = 0 + def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' transport.write_eof() + def data_received(self, data): logging.info('received: %r', data) assert self.state == 'CONNECTED', self.state self.nbytes += len(data) + def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' self.transport.close() + def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' @@ -39,7 +44,7 @@ def setUp(self): def tearDown(self): self.event_loop.close() - def testUnixSubprocess(self): + def test_unix_subprocess(self): p = MyProto() t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) self.event_loop.run() diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 456cccfa..bc6fa531 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -27,7 +27,7 @@ def tearDown(self): self.event_loop.close() super().tearDown() - def testTaskClass(self): + def test_task_class(self): @tasks.coroutine def notmuch(): yield from [] @@ -37,7 +37,7 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') - def testTaskDecorator(self): + def test_task_decorator(self): @tasks.task def notmuch(): yield from [] @@ -47,7 +47,7 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') - def testTaskRepr(self): + def test_task_repr(self): @tasks.task def notmuch(): yield from [] @@ -64,7 +64,7 @@ def notmuch(): self.event_loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") - def testTaskBasics(self): + def test_task_basics(self): @tasks.task def outer(): a = yield from inner1() @@ -81,7 +81,7 @@ def inner2(): t = outer() self.assertEqual(self.event_loop.run_until_complete(t), 1042) - def testWait(self): + def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) @tasks.coroutine @@ -102,7 +102,7 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. - def testWaitWithException(self): + def test_wait_with_exception(self): self.suppress_log_errors() a = tasks.sleep(0.1) @tasks.coroutine @@ -126,7 +126,7 @@ def foo(): t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) - def testWaitWithTimeout(self): + def test_wait_with_timeout(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) @tasks.coroutine @@ -140,7 +140,7 @@ def foo(): self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) - def testAsCompleted(self): + def test_as_completed(self): @tasks.coroutine def sleeper(dt, x): yield from tasks.sleep(dt) @@ -167,7 +167,7 @@ def foo(): t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) - def testAsCompletedWithTimeout(self): + def test_as_completed_with_timeout(self): self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') @@ -190,7 +190,7 @@ def foo(): self.assertEqual(res[1][0], 2) self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) - def testSleep(self): + def test_sleep(self): @tasks.coroutine def sleeper(dt, arg): yield from tasks.sleep(dt/2) @@ -204,7 +204,7 @@ def sleeper(dt, arg): self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') - def testTaskCancelSleepingTask(self): + def test_task_cancel_sleeping_task(self): sleepfut = None @tasks.task def sleep(dt): From 90c020c7700f2d300125c074fb41793c1d02cd8f Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 12:00:31 +0000 Subject: [PATCH 0278/1502] Rename EventLoop to AbstractEventLoop. --- tulip/events.py | 6 +++--- tulip/unix_events.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index f39ddb79..30b35d2c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,7 @@ """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', - 'EventLoop', 'Handler', 'make_handler', + 'AbstractEventLoop', 'Handler', 'make_handler', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -73,7 +73,7 @@ def make_handler(when, callback, args): return Handler(when, callback, args) -class EventLoop: +class AbstractEventLoop: """Abstract event loop.""" # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). @@ -236,7 +236,7 @@ def get_event_loop(self): def set_event_loop(self, event_loop): """Set the event loop.""" - assert event_loop is None or isinstance(event_loop, EventLoop) + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) self._event_loop = event_loop def new_event_loop(self): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 546e55d5..610d0101 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -72,7 +72,7 @@ def _raise_stop_error(): raise _StopError -class BaseEventLoop(events.EventLoop): +class BaseEventLoop(events.AbstractEventLoop): def __init__(self): self._ready = collections.deque() From 264bae608a7150aa64140d123db502da52569544 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 12:19:57 +0000 Subject: [PATCH 0279/1502] Move BaseEventLoop to base_events.py. --- tulip/base_events.py | 350 +++++++++++++++++++++++++++++++++++++++++++ tulip/iocp_events.py | 4 +- tulip/unix_events.py | 347 +----------------------------------------- 3 files changed, 354 insertions(+), 347 deletions(-) create mode 100644 tulip/base_events.py diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..d9e5633f --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,350 @@ +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._signal_handlers = {} + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 0x7fffffff/1000.0 # 24 days + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, _raise_stop_error) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.make_handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(None, callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host, port, *, ssl=False, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self.SslTransport(self, sock, protocol, sslcontext, + waiter) + else: + transport = self.SocketTransport(self, sock, protocol, waiter) + yield from waiter + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0, + backlog=100): + """XXX""" + infos = yield from self.getaddrinfo(host, port, + family=family, + type=socket.SOCK_STREAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) diff --git a/tulip/iocp_events.py b/tulip/iocp_events.py index 8c63f03c..4167488c 100644 --- a/tulip/iocp_events.py +++ b/tulip/iocp_events.py @@ -15,7 +15,7 @@ from _winapi import CloseHandle -from . import events +from . import base_events from . import futures from . import transports from . import unix_events @@ -236,7 +236,7 @@ def _call_connection_lost(self, exc): self._sock.close() -class IocpEventLoop(unix_events.BaseEventLoop): +class IocpEventLoop(base_events.BaseEventLoop): SocketTransport = _IocpSocketTransport diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 610d0101..60f6161a 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -13,31 +13,24 @@ to modify the meaning of the API call itself. """ -import collections -import concurrent.futures import errno -import heapq import logging -import select import socket try: import ssl except ImportError: ssl = None import sys -import threading -import time try: import signal except ImportError: signal = None +from . import base_events from . import events from . import futures -from . import protocols from . import selectors -from . import tasks from . import transports try: @@ -60,344 +53,8 @@ if sys.platform == 'win32': _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) -# Argument for default thread pool executor creation. -_MAX_WORKERS = 5 - -class _StopError(BaseException): - """Raised to stop the event loop.""" - - -def _raise_stop_error(): - raise _StopError - - -class BaseEventLoop(events.AbstractEventLoop): - - def __init__(self): - self._ready = collections.deque() - self._scheduled = [] - self._default_executor = None - self._signal_handlers = {} - - def run(self): - """Run the event loop until nothing left to do or stop() called. - - This keeps going as long as there are either readable and - writable file descriptors, or scheduled callbacks (of either - variety). - - TODO: Give this a timeout too? - """ - while (self._ready or - self._scheduled or - self._selector.registered_count() > 1): - try: - self._run_once() - except _StopError: - break - - def run_forever(self): - """Run until stop() is called. - - This only makes sense over run() if you have another thread - scheduling callbacks using call_soon_threadsafe(). - """ - handler = self.call_repeatedly(24*3600, lambda: None) - try: - self.run() - finally: - handler.cancel() - - def run_once(self, timeout=None): - """Run through all callbacks and all I/O polls once. - - Calling stop() will break out of this too. - """ - try: - self._run_once(timeout) - except _StopError: - pass - - def run_until_complete(self, future, timeout=None): - """Run until the Future is done, or until a timeout. - - Return the Future's result, or raise its exception. If the - timeout is reached or stop() is called, raise TimeoutError. - """ - if timeout is None: - timeout = 0x7fffffff/1000.0 # 24 days - future.add_done_callback(lambda _: self.stop()) - handler = self.call_later(timeout, _raise_stop_error) - self.run() - handler.cancel() - if future.done(): - return future.result() # May raise future.exception(). - else: - raise futures.TimeoutError - - def stop(self): - """Stop running the event loop. - - Every callback scheduled before stop() is called will run. - Callback scheduled after stop() is called won't. However, - those callbacks will run if run() is called again later. - """ - self.call_soon(_raise_stop_error) - - def call_later(self, delay, callback, *args): - """Arrange for a callback to be called at a given time. - - Return an object with a cancel() method that can be used to - cancel the call. - - The delay can be an int or float, expressed in seconds. It is - always a relative time. - - Each callback will be called exactly once. If two callbacks - are scheduled for exactly the same time, it undefined which - will be called first. - - Callbacks scheduled in the past are passed on to call_soon(), - so these will be called in the order in which they were - registered rather than by time due. This is so you can't - cheat and insert yourself at the front of the ready queue by - using a negative time. - - Any positional arguments after the callback will be passed to - the callback when it is called. - - # TODO: Should delay is None be interpreted as Infinity? - """ - if delay <= 0: - return self.call_soon(callback, *args) - handler = events.make_handler(time.monotonic() + delay, callback, args) - heapq.heappush(self._scheduled, handler) - return handler - - def call_repeatedly(self, interval, callback, *args): - """Call a callback every 'interval' seconds.""" - def wrapper(): - callback(*args) # If this fails, the chain is broken. - handler._when = time.monotonic() + interval - heapq.heappush(self._scheduled, handler) - handler = events.make_handler(time.monotonic() + interval, wrapper, ()) - heapq.heappush(self._scheduled, handler) - return handler - - def call_soon(self, callback, *args): - """Arrange for a callback to be called as soon as possible. - - This operates as a FIFO queue, callbacks are called in the - order in which they are registered. Each callback will be - called exactly once. - - Any positional arguments after the callback will be passed to - the callback when it is called. - """ - handler = events.make_handler(None, callback, args) - self._ready.append(handler) - return handler - - def call_soon_threadsafe(self, callback, *args): - """XXX""" - handler = self.call_soon(callback, *args) - self._write_to_self() - return handler - - def run_in_executor(self, executor, callback, *args): - if isinstance(callback, events.Handler): - assert not args - assert callback.when is None - if callback.cancelled: - f = futures.Future() - f.set_result(None) - return f - callback, args = callback.callback, callback.args - if executor is None: - executor = self._default_executor - if executor is None: - executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) - self._default_executor = executor - return self.wrap_future(executor.submit(callback, *args)) - - def set_default_executor(self, executor): - self._default_executor = executor - - def getaddrinfo(self, host, port, *, - family=0, type=0, proto=0, flags=0): - return self.run_in_executor(None, socket.getaddrinfo, - host, port, family, type, proto, flags) - - def getnameinfo(self, sockaddr, flags=0): - return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - - @tasks.task - def create_connection(self, protocol_factory, host, port, *, ssl=False, - family=0, proto=0, flags=0): - """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - exceptions = [] - for family, type, proto, cname, address in infos: - sock = None - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - yield self.sock_connect(sock, address) - except socket.error as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - else: - break - else: - if len(exceptions) == 1: - raise exceptions[0] - else: - # If they all have the same str(), raise one. - model = str(exceptions[0]) - if all(str(exc) == model for exc in exceptions): - raise exceptions[0] - # Raise a combined exception so the user can see all - # the various error messages. - raise socket.error('Multiple exceptions: {}'.format( - ', '.join(str(exc) for exc in exceptions))) - protocol = protocol_factory() - waiter = futures.Future() - if ssl: - sslcontext = None if isinstance(ssl, bool) else ssl - transport = self.SslTransport(self, sock, protocol, sslcontext, - waiter) - else: - transport = self.SocketTransport(self, sock, protocol, waiter) - yield from waiter - return transport, protocol - - # TODO: Or create_server()? - @tasks.task - def start_serving(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0, - backlog=100): - """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - # TODO: Maybe we want to bind every address in the list - # instead of the first one that works? - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - except socket.error as exc: - sock.close() - exceptions.append(exc) - else: - break - else: - raise exceptions[0] - sock.listen(backlog) - sock.setblocking(False) - self._start_serving(protocol_factory, sock) - return sock - - def _add_callback(self, handler): - """Add a Handler to ready or scheduled.""" - if handler.cancelled: - return - if handler.when is None: - self._ready.append(handler) - else: - heapq.heappush(self._scheduled, handler) - - def wrap_future(self, future): - """XXX""" - if isinstance(future, futures.Future): - return future # Don't wrap our own type of Future. - new_future = futures.Future() - future.add_done_callback( - lambda future: - self.call_soon_threadsafe(new_future._copy_state, future)) - return new_future - - def _run_once(self, timeout=None): - """Run one full iteration of the event loop. - - This calls all currently ready callbacks, polls for I/O, - schedules the resulting callbacks, and finally schedules - 'call_later' callbacks. - """ - # TODO: Break each of these into smaller pieces. - # TODO: Refactor to separate the callbacks from the readers/writers. - # TODO: An alternative API would be to do the *minimal* amount - # of work, e.g. one callback or one I/O poll. - - # Remove delayed calls that were cancelled from head of queue. - while self._scheduled and self._scheduled[0].cancelled: - heapq.heappop(self._scheduled) - - # Inspect the poll queue. If there's exactly one selectable - # file descriptor, it's the self-pipe, and if there's nothing - # scheduled, we should ignore it. - if self._selector.registered_count() > 1 or self._scheduled: - if self._ready: - timeout = 0 - elif self._scheduled: - # Compute the desired timeout. - when = self._scheduled[0].when - deadline = max(0, when - time.monotonic()) - if timeout is None: - timeout = deadline - else: - timeout = min(timeout, deadline) - - t0 = time.monotonic() - event_list = self._selector.select(timeout) - t1 = time.monotonic() - argstr = '' if timeout is None else ' %.3f' % timeout - if t1-t0 >= 1: - level = logging.INFO - else: - level = logging.DEBUG - logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - self._process_events(event_list) - - # Handle 'later' callbacks that are ready. - now = time.monotonic() - while self._scheduled: - handler = self._scheduled[0] - if handler.when > now: - break - handler = heapq.heappop(self._scheduled) - self._ready.append(handler) - - # This is the only place where callbacks are actually *called*. - # All other places just add them to ready. - # Note: We run all currently scheduled callbacks, but not any - # callbacks scheduled by callbacks run this time around -- - # they will be run the next time (after another I/O poll). - # Use an idiom that is threadsafe without using locks. - ntodo = len(self._ready) - for i in range(ntodo): - handler = self._ready.popleft() - if not handler.cancelled: - try: - handler.callback(*handler.args) - except Exception: - logging.exception('Exception in callback %s %r', - handler.callback, handler.args) - - -class UnixEventLoop(BaseEventLoop): +class UnixEventLoop(base_events.BaseEventLoop): """Unix event loop. See events.EventLoop for API specification. From 6fe9fe5f5ea7c1041c3229e4d4ffb51650cd1a91 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 12:36:33 +0000 Subject: [PATCH 0280/1502] Rename unix_events.py to selector_events.py. --- tulip/events.py | 4 ++-- tulip/events_test.py | 22 ++++++++++---------- tulip/iocp_events.py | 2 +- tulip/{unix_events.py => selector_events.py} | 16 +++++++------- 4 files changed, 22 insertions(+), 22 deletions(-) rename tulip/{unix_events.py => selector_events.py} (98%) diff --git a/tulip/events.py b/tulip/events.py index 30b35d2c..e9600291 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -246,8 +246,8 @@ def new_event_loop(self): loop. """ # TODO: Do something else for Windows. - from . import unix_events - return unix_events.UnixEventLoop() + from . import selector_events + return selector_events.SelectorEventLoop() # Event loop policy. The policy itself is always global, even if the diff --git a/tulip/events_test.py b/tulip/events_test.py index c9b11d4c..c6bdd88e 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -19,7 +19,7 @@ from . import transports from . import protocols from . import selectors -from . import unix_events +from . import selector_events class MyProto(protocols.Protocol): @@ -170,7 +170,7 @@ def run(arg): def test_reader_callback(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() bytes_read = [] def reader(): try: @@ -193,7 +193,7 @@ def reader(): def test_reader_callback_with_handler(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() bytes_read = [] def reader(): try: @@ -217,7 +217,7 @@ def reader(): def test_reader_callback_cancel(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() bytes_read = [] def reader(): try: @@ -239,7 +239,7 @@ def reader(): def test_writer_callback(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() w.setblocking(False) el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): @@ -253,7 +253,7 @@ def remove_writer(): def test_writer_callback_with_handler(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) self.assertEqual(el.add_writer(w.fileno(), handler), handler) @@ -268,7 +268,7 @@ def remove_writer(): def test_writer_callback_cancel(self): el = events.get_event_loop() - r, w = unix_events.socketpair() + r, w = selector_events.socketpair() w.setblocking(False) def sender(): w.send(b'x'*256) @@ -424,17 +424,17 @@ def test_start_serving(self): if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.KqueueSelector()) + return selector_events.SelectorEventLoop(selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.EpollSelector()) + return selector_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.PollSelector()) + return selector_events.SelectorEventLoop(selectors.PollSelector()) if sys.platform == 'win32': from . import iocp_events @@ -460,7 +460,7 @@ def test_writer_callback_with_handler(self): # Should always exist. class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.SelectSelector()) + return selector_events.SelectorEventLoop(selectors.SelectSelector()) class HandlerTests(unittest.TestCase): diff --git a/tulip/iocp_events.py b/tulip/iocp_events.py index 4167488c..3eccd004 100644 --- a/tulip/iocp_events.py +++ b/tulip/iocp_events.py @@ -18,7 +18,7 @@ from . import base_events from . import futures from . import transports -from . import unix_events +from . import selector_events from . import winsocketpair from . import _overlapped diff --git a/tulip/unix_events.py b/tulip/selector_events.py similarity index 98% rename from tulip/unix_events.py rename to tulip/selector_events.py index 60f6161a..99c94dbe 100644 --- a/tulip/unix_events.py +++ b/tulip/selector_events.py @@ -54,20 +54,20 @@ _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) -class UnixEventLoop(base_events.BaseEventLoop): - """Unix event loop. +class SelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. See events.EventLoop for API specification. """ @staticmethod def SocketTransport(event_loop, sock, protocol, waiter=None): - return _UnixSocketTransport(event_loop, sock, protocol, waiter) + return _SelectorSocketTransport(event_loop, sock, protocol, waiter) @staticmethod def SslTransport(event_loop, rawsock, protocol, sslcontext, waiter): - return _UnixSslTransport(event_loop, rawsock, protocol, - sslcontext, waiter) + return _SelectorSslTransport(event_loop, rawsock, protocol, + sslcontext, waiter) def __init__(self, selector=None): super().__init__() @@ -126,7 +126,7 @@ def _accept_connection(self, protocol_factory, sock): logging.exception('Accept failed') return protocol = protocol_factory() - transport = _UnixSocketTransport(self, conn, protocol) + transport = self.SocketTransport(self, conn, protocol) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -447,7 +447,7 @@ def _check_signal(self, sig): raise RuntimeError('Signals are not really supported on Windows') -class _UnixSocketTransport(transports.Transport): +class _SelectorSocketTransport(transports.Transport): def __init__(self, event_loop, sock, protocol, waiter=None): self._event_loop = event_loop @@ -543,7 +543,7 @@ def _call_connection_lost(self, exc): self._sock.close() -class _UnixSslTransport(transports.Transport): +class _SelectorSslTransport(transports.Transport): def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): self._event_loop = event_loop From 1bce93891b59006617892e21b54957611a58062d Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 12:58:40 +0000 Subject: [PATCH 0281/1502] Recreate unix_events.py and move signal handling there. --- tulip/events.py | 9 +++- tulip/events_test.py | 46 ++++++++++------- tulip/selector_events.py | 90 -------------------------------- tulip/unix_events.py | 107 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 111 deletions(-) create mode 100644 tulip/unix_events.py diff --git a/tulip/events.py b/tulip/events.py index e9600291..70eb9b01 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -10,6 +10,7 @@ 'get_event_loop', 'set_event_loop', 'new_event_loop', ] +import sys import threading @@ -246,8 +247,12 @@ def new_event_loop(self): loop. """ # TODO: Do something else for Windows. - from . import selector_events - return selector_events.SelectorEventLoop() + if sys.platform == 'win32': + from . import selector_events + return selector_events.SelectorEventLoop() + else: + from . import unix_events + return unix_events.UnixEventLoop() # Event loop policy. The policy itself is always global, even if the diff --git a/tulip/events_test.py b/tulip/events_test.py index c6bdd88e..1ff23ee2 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -421,23 +421,13 @@ def test_start_serving(self): client.close() -if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - def create_event_loop(self): - return selector_events.SelectorEventLoop(selectors.KqueueSelector()) - -if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - def create_event_loop(self): - return selector_events.SelectorEventLoop(selectors.EpollSelector()) - -if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - def create_event_loop(self): - return selector_events.SelectorEventLoop(selectors.PollSelector()) - if sys.platform == 'win32': from . import iocp_events + from . import selector_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): + return selector_events.SelectorEventLoop(selectors.SelectSelector()) class IocpEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): @@ -457,10 +447,28 @@ def test_writer_callback_cancel(self): def test_writer_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") -# Should always exist. -class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): - def create_event_loop(self): - return selector_events.SelectorEventLoop(selectors.SelectSelector()) +else: + from . import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): + return unix_events.UnixEventLoop(selectors.SelectSelector()) class HandlerTests(unittest.TestCase): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 99c94dbe..5742fce4 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -22,11 +22,6 @@ ssl = None import sys -try: - import signal -except ImportError: - signal = None - from . import base_events from . import events from . import futures @@ -361,91 +356,6 @@ def _process_events(self, event_list): else: self._add_callback(connector) - def add_signal_handler(self, sig, callback, *args): - """Add a handler for a signal. UNIX only. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - self._check_signal(sig) - try: - # set_wakeup_fd() raises ValueError if this is not the - # main thread. By calling it early we ensure that an - # event loop running in another thread cannot add a signal - # handler. - signal.set_wakeup_fd(self._csock.fileno()) - except ValueError as exc: - raise RuntimeError(str(exc)) - handler = events.make_handler(None, callback, args) - self._signal_handlers[sig] = handler - try: - signal.signal(sig, self._handle_signal) - except OSError as exc: - del self._signal_handlers[sig] - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as nexc: - logging.info('set_wakeup_fd(-1) failed: %s', nexc) - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - return handler - - def _handle_signal(self, sig, arg): - """Internal helper that is the actual signal handler.""" - handler = self._signal_handlers.get(sig) - if handler is None: - return # Assume it's some race condition. - if handler.cancelled: - self.remove_signal_handler(sig) # Remove it properly. - else: - self.call_soon_threadsafe(handler.callback, *handler.args) - - def remove_signal_handler(self, sig): - """Remove a handler for a signal. UNIX only. - - Return True if a signal handler was removed, False if not.""" - self._check_signal(sig) - try: - del self._signal_handlers[sig] - except KeyError: - return False - if sig == signal.SIGINT: - handler = signal.default_int_handler - else: - handler = signal.SIG_DFL - try: - signal.signal(sig, handler) - except OSError as exc: - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) - else: - raise - if not self._signal_handlers: - try: - signal.set_wakeup_fd(-1) - except ValueError as exc: - logging.info('set_wakeup_fd(-1) failed: %s', exc) - return True - - def _check_signal(self, sig): - """Internal helper to validate a signal. - - Raise ValueError if the signal number is invalid or uncatchable. - Raise RuntimeError if there is a problem setting up the handler. - """ - if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) - if signal is None: - raise RuntimeError('Signals are not supported') - if not (1 <= sig < signal.NSIG): - raise ValueError('sig {} out of range(1, {})'.format(sig, - signal.NSIG)) - if sys.platform == 'win32': - raise RuntimeError('Signals are not really supported on Windows') - class _SelectorSocketTransport(transports.Transport): diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..9d83fe3f --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,107 @@ +import errno +import sys + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import selector_events + + +__all__ = ['UnixEventLoop'] + + +if sys.platform == 'win32': + raise ImportError('Signals are not really supported on Windows') + + +class UnixEventLoop(selector_events.SelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) From fa4b0769d514a0a73d3eafb81e34abfc9fcbd9d9 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 30 Jan 2013 14:39:42 +0000 Subject: [PATCH 0282/1502] Replace iocp_events.py by proactor_events.py and create windows_events.py. --- tulip/base_events.py | 16 ++ tulip/events.py | 6 +- tulip/events_test.py | 29 ++-- tulip/{iocp_events.py => proactor_events.py} | 164 ++----------------- tulip/selector_events.py | 32 ++-- tulip/unix_events.py | 10 +- tulip/windows_events.py | 160 ++++++++++++++++++ 7 files changed, 225 insertions(+), 192 deletions(-) rename tulip/{iocp_events.py => proactor_events.py} (50%) create mode 100644 tulip/windows_events.py diff --git a/tulip/base_events.py b/tulip/base_events.py index d9e5633f..92304211 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -1,3 +1,19 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + import collections import concurrent.futures import heapq diff --git a/tulip/events.py b/tulip/events.py index 70eb9b01..6aaea66d 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -248,11 +248,11 @@ def new_event_loop(self): """ # TODO: Do something else for Windows. if sys.platform == 'win32': - from . import selector_events - return selector_events.SelectorEventLoop() + from . import windows_events + return windows_events.SelectorEventLoop() else: from . import unix_events - return unix_events.UnixEventLoop() + return unix_events.SelectorEventLoop() # Event loop policy. The policy itself is always global, even if the diff --git a/tulip/events_test.py b/tulip/events_test.py index 1ff23ee2..0b27e94d 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -170,7 +170,7 @@ def run(arg): def test_reader_callback(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() bytes_read = [] def reader(): try: @@ -193,7 +193,7 @@ def reader(): def test_reader_callback_with_handler(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() bytes_read = [] def reader(): try: @@ -217,7 +217,7 @@ def reader(): def test_reader_callback_cancel(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() bytes_read = [] def reader(): try: @@ -239,7 +239,7 @@ def reader(): def test_writer_callback(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() w.setblocking(False) el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): @@ -253,7 +253,7 @@ def remove_writer(): def test_writer_callback_with_handler(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) self.assertEqual(el.add_writer(w.fileno(), handler), handler) @@ -268,7 +268,7 @@ def remove_writer(): def test_writer_callback_cancel(self): el = events.get_event_loop() - r, w = selector_events.socketpair() + r, w = el._socketpair() w.setblocking(False) def sender(): w.send(b'x'*256) @@ -422,16 +422,15 @@ def test_start_serving(self): if sys.platform == 'win32': - from . import iocp_events - from . import selector_events + from . import windows_events class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return selector_events.SelectorEventLoop(selectors.SelectSelector()) + return windows_events.SelectorEventLoop() - class IocpEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return iocp_events.IocpEventLoop() + return windows_events.ProactorEventLoop() def test_create_ssl_transport(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") def test_reader_callback(self): @@ -453,22 +452,22 @@ def test_writer_callback_with_handler(self): if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.KqueueSelector()) + return unix_events.SelectorEventLoop(selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.EpollSelector()) + return unix_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.PollSelector()) + return unix_events.SelectorEventLoop(selectors.PollSelector()) # Should always exist. class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.UnixEventLoop(selectors.SelectSelector()) + return unix_events.SelectorEventLoop(selectors.SelectSelector()) class HandlerTests(unittest.TestCase): diff --git a/tulip/iocp_events.py b/tulip/proactor_events.py similarity index 50% rename from tulip/iocp_events.py rename to tulip/proactor_events.py index 3eccd004..16099ca3 100644 --- a/tulip/iocp_events.py +++ b/tulip/proactor_events.py @@ -1,155 +1,18 @@ -# -# Module implementing the Proactor pattern -# -# A proactor is used to initiate asynchronous I/O, and to wait for -# completion of previously initiated operations. -# +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" -import errno import logging -import os -import sys -import socket -import time -import weakref -from _winapi import CloseHandle from . import base_events -from . import futures from . import transports -from . import selector_events from . import winsocketpair -from . import _overlapped - - -NULL = 0 -INFINITE = 0xffffffff -ERROR_CONNECTION_REFUSED = 1225 -ERROR_CONNECTION_ABORTED = 1236 - - -class IocpProactor: - - def __init__(self, concurrency=0xffffffff): - self._results = [] - self._iocp = _overlapped.CreateIoCompletionPort( - _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) - self._cache = {} - self._registered = weakref.WeakSet() - - def registered_count(self): - return len(self._cache) - - def select(self, timeout=None): - if not self._results: - self._poll(timeout) - tmp, self._results = self._results, [] - return tmp - - def recv(self, conn, nbytes, flags=0): - self._register_with_iocp(conn) - ov = _overlapped.Overlapped(NULL) - ov.WSARecv(conn.fileno(), nbytes, flags) - return self._register(ov, conn, ov.getresult) - - def send(self, conn, buf, flags=0): - self._register_with_iocp(conn) - ov = _overlapped.Overlapped(NULL) - ov.WSASend(conn.fileno(), buf, flags) - return self._register(ov, conn, ov.getresult) - - def accept(self, listener): - self._register_with_iocp(listener) - conn = self._get_accept_socket() - ov = _overlapped.Overlapped(NULL) - ov.AcceptEx(listener.fileno(), conn.fileno()) - def finish_accept(): - addr = ov.getresult() - conn.setsockopt(socket.SOL_SOCKET, - _overlapped.SO_UPDATE_ACCEPT_CONTEXT, - listener.fileno()) - conn.settimeout(listener.gettimeout()) - return conn, conn.getpeername() - return self._register(ov, listener, finish_accept) - - def connect(self, conn, address): - self._register_with_iocp(conn) - _overlapped.BindLocal(conn.fileno(), len(address)) - ov = _overlapped.Overlapped(NULL) - ov.ConnectEx(conn.fileno(), address) - def finish_connect(): - try: - ov.getresult() - except OSError as e: - if e.winerror == ERROR_CONNECTION_REFUSED: - raise ConnectionRefusedError(errno.ECONNREFUSED, - 'connection refused') - raise - conn.setsockopt(socket.SOL_SOCKET, - _overlapped.SO_UPDATE_CONNECT_CONTEXT, - 0) - return conn - return self._register(ov, conn, finish_connect) - - def _register_with_iocp(self, obj): - if obj not in self._registered: - self._registered.add(obj) - _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) - def _register(self, ov, obj, callback): - f = futures.Future() - self._cache[ov.address] = (f, ov, obj, callback) - return f - def _get_accept_socket(self): - s = socket.socket() - s.settimeout(0) - return s - - def _poll(self, timeout=None): - if timeout is None: - ms = INFINITE - elif timeout < 0: - raise ValueError("negative timeout") - else: - ms = int(timeout * 1000 + 0.5) - if ms >= INFINITE: - raise ValueError("timeout too big") - while True: - status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) - if status is None: - return - address = status[3] - f, ov, obj, callback = self._cache.pop(address) - try: - value = callback() - except OSError as e: - f.set_exception(e) - self._results.append(f) - else: - f.set_result(value) - self._results.append(f) - ms = 0 - - def close(self): - for (f, ov, obj, callback) in self._cache.values(): - try: - ov.cancel() - except OSError: - pass - - while self._cache: - if not self._poll(1000): - logging.debug('taking long time to close proactor') - - self._results = [] - if self._iocp is not None: - CloseHandle(self._iocp) - self._iocp = None - - -class _IocpSocketTransport(transports.Transport): +class _ProactorSocketTransport(transports.Transport): def __init__(self, event_loop, sock, protocol, waiter=None): self._event_loop = event_loop @@ -236,19 +99,17 @@ def _call_connection_lost(self, exc): self._sock.close() -class IocpEventLoop(base_events.BaseEventLoop): +class BaseProactorEventLoop(base_events.BaseEventLoop): - SocketTransport = _IocpSocketTransport + SocketTransport = _ProactorSocketTransport @staticmethod def SslTransport(*args, **kwds): raise NotImplementedError - def __init__(self, proactor=None): + def __init__(self, proactor): super().__init__() - if proactor is None: - proactor = IocpProactor() - logging.debug('Using proactor: %s', proactor.__class__.__name__) + logging.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias self._make_self_pipe() @@ -273,9 +134,12 @@ def sock_connect(self, sock, address): def sock_accept(self, sock): return self._proactor.accept(sock) + def _socketpair(self): + raise NotImplementedError + def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = winsocketpair.socketpair() + self._ssock, self._csock = self._socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) def loop(f=None): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 5742fce4..8f05b443 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -1,16 +1,7 @@ -"""UNIX event loop and related classes. - -The event loop can be broken up into a selector (the part responsible -for telling us when file descriptors are ready) and the event loop -proper, which wraps a selector with functionality for scheduling -callbacks, immediately or at a given time in the future. - -Whenever a public API takes a callback, subsequent positional -arguments will be passed to the callback if/when it is called. This -avoids the proliferation of trivial lambdas implementing closures. -Keyword arguments for the callback are not supported; this is a -conscious design decision, leaving the door open for keyword arguments -to modify the meaning of the API call itself. +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. """ import errno @@ -28,11 +19,6 @@ from . import selectors from . import transports -try: - from socket import socketpair -except ImportError: - assert sys.platform == 'win32' - from .winsocketpair import socketpair # Errno values indicating the connection was disconnected. _DISCONNECTED = frozenset((errno.ECONNRESET, @@ -49,7 +35,7 @@ _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) -class SelectorEventLoop(base_events.BaseEventLoop): +class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. See events.EventLoop for API specification. @@ -67,9 +53,8 @@ def SslTransport(event_loop, rawsock, protocol, sslcontext, waiter): def __init__(self, selector=None): super().__init__() if selector is None: - # pick the best selector class for the platform selector = selectors.Selector() - logging.debug('Using selector: %s', selector.__class__.__name__) + logging.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._make_self_pipe() @@ -80,9 +65,12 @@ def close(self): self._ssock.close() self._csock.close() + def _socketpair(self): + raise NotImplementedError + def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = socketpair() + self._ssock, self._csock = self._socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) self.add_reader(self._ssock.fileno(), self._read_from_self) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 9d83fe3f..7c4d3cf6 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,4 +1,8 @@ +"""Selector eventloop for Unix with signal handling. +""" + import errno +import socket import sys try: @@ -10,18 +14,20 @@ from . import selector_events -__all__ = ['UnixEventLoop'] +__all__ = ['SelectorEventLoop'] if sys.platform == 'win32': raise ImportError('Signals are not really supported on Windows') -class UnixEventLoop(selector_events.SelectorEventLoop): +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): """Unix event loop Adds signal handling to SelectorEventLoop """ + def _socketpair(self): + return socket.socketpair() def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..23aed886 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,160 @@ +"""Selector and proactor eventloops for Windows. +""" + +import errno +import logging +import socket +import weakref +import _winapi + + +from . import futures +from . import proactor_events +from . import selectors +from . import selector_events +from . import winsocketpair +from . import _overlapped + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp, self._results = self._results, [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(): + addr = ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + listener.fileno()) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(): + try: + ov.getresult() + except OSError as e: + if e.winerror == ERROR_CONNECTION_REFUSED: + raise ConnectionRefusedError(errno.ECONNREFUSED, + 'connection refused') + raise + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1000): + logging.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None From ab2d6e6a7f5534fd5c9de3605a036041b7da9e24 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 1 Feb 2013 14:26:01 -0800 Subject: [PATCH 0283/1502] Add event_loop parameter to Task ctor --- tulip/tasks.py | 6 +++--- tulip/tasks_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index af7eaaff..6ab221b4 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -46,9 +46,9 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro): + def __init__(self, coro, event_loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__() # Sets self._event_loop. + super().__init__(event_loop=event_loop) # Sets self._event_loop. self._coro = coro self._must_cancel = False self._event_loop.call_soon(self._step) @@ -82,7 +82,7 @@ def _step_maybe(self): return self._step() def _step(self, value=None, exc=None): - if self.done(): # pragma: no cover + if self.done(): logging.warn('_step(): already done: %r, %r, %r', self, value, exc) return # We'll call either coro.throw(exc) or coro.send(value). diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index bc6fa531..81854fd3 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -36,6 +36,11 @@ def notmuch(): self.event_loop.run() self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) def test_task_decorator(self): @tasks.task @@ -81,6 +86,34 @@ def inner2(): t = outer() self.assertEqual(self.event_loop.run_until_complete(t), 1042) + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) From 52b718002adef27db1a8dda7b97869c0cd4e2898 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 1 Feb 2013 15:41:03 -0800 Subject: [PATCH 0284/1502] Allow to use existing socket object with start_serving and create_connection methods --- tulip/events.py | 8 +- tulip/events_test.py | 334 ++++++++++++++++++++++++++++--------------- tulip/unix_events.py | 137 ++++++++++-------- 3 files changed, 301 insertions(+), 178 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index f39ddb79..9ff6e3fa 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -139,12 +139,12 @@ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError - def create_connection(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): raise NotImplementedError - def start_serving(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): raise NotImplementedError # Ready-based callback registration methods. diff --git a/tulip/events_test.py b/tulip/events_test.py index c0d53c22..fa4b996e 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,6 +1,7 @@ """Tests for events.py.""" import concurrent.futures +import errno import gc import os import select @@ -14,11 +15,13 @@ import threading import time import unittest +import unittest.mock from . import events from . import transports from . import protocols from . import selectors +from . import test_utils from . import unix_events @@ -51,6 +54,7 @@ def connection_lost(self, exc): class EventLoopTestsMixin: def setUp(self): + super().setUp() self.selector = self.SELECTOR_CLASS() self.event_loop = unix_events.UnixEventLoop(self.selector) events.set_event_loop(self.event_loop) @@ -60,117 +64,109 @@ def tearDown(self): gc.collect() def test_run(self): - el = events.get_event_loop() - el.run() # Returns immediately. + self.event_loop.run() # Returns immediately. def test_call_later(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) - el.call_later(0.1, callback, 'hello world') + self.event_loop.call_later(0.1, callback, 'hello world') t0 = time.monotonic() - el.run() + self.event_loop.run() t1 = time.monotonic() self.assertEqual(results, ['hello world']) self.assertTrue(t1-t0 >= 0.09) def test_call_repeatedly(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) - el.call_repeatedly(0.03, callback, 'ho') - el.call_later(0.1, el.stop) - el.run() + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() self.assertEqual(results, ['ho', 'ho', 'ho']) def test_call_soon(self): - el = events.get_event_loop() results = [] def callback(arg1, arg2): results.append((arg1, arg2)) - el.call_soon(callback, 'hello', 'world') - el.run() + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() self.assertEqual(results, [('hello', 'world')]) def test_call_soon_with_handler(self): - el = events.get_event_loop() results = [] def callback(): results.append('yeah') handler = events.Handler(None, callback, ()) - self.assertEqual(el.call_soon(handler), handler) - el.run() + self.assertIs(self.event_loop.call_soon(handler), handler) + self.event_loop.run() self.assertEqual(results, ['yeah']) def test_call_soon_threadsafe(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) def run(): - el.call_soon_threadsafe(callback, 'hello') + self.event_loop.call_soon_threadsafe(callback, 'hello') t = threading.Thread(target=run) - el.call_later(0.1, callback, 'world') + self.event_loop.call_later(0.1, callback, 'world') t0 = time.monotonic() t.start() - el.run() + self.event_loop.run() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) def test_call_soon_threadsafe_with_handler(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) + handler = events.Handler(None, callback, ('hello',)) def run(): - self.assertEqual(el.call_soon_threadsafe(handler), handler) + self.assertIs(self.event_loop.call_soon_threadsafe(handler),handler) + t = threading.Thread(target=run) - el.call_later(0.1, callback, 'world') + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() t.start() - el.run() + self.event_loop.run() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) def test_wrap_future(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') - f2 = el.wrap_future(f1) - res = el.run_until_complete(f2) + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'oi') def test_run_in_executor(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg - f2 = el.run_in_executor(None, run, 'yo') - res = el.run_until_complete(f2) + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') def test_run_in_executor_with_handler(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg handler = events.Handler(None, run, ('yo',)) - f2 = el.run_in_executor(None, handler) - res = el.run_until_complete(f2) + f2 = self.event_loop.run_in_executor(None, handler) + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') def test_reader_callback(self): - el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] def reader(): @@ -183,17 +179,16 @@ def reader(): if data: bytes_read.append(data) else: - self.assertTrue(el.remove_reader(r.fileno())) + self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() - el.add_reader(r.fileno(), reader) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_with_handler(self): - el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] def reader(): @@ -206,18 +201,19 @@ def reader(): if data: bytes_read.append(data) else: - self.assertTrue(el.remove_reader(r.fileno())) + self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() + handler = events.Handler(None, reader, ()) - self.assertEqual(el.add_reader(r.fileno(), handler), handler) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_cancel(self): - el = events.get_event_loop() r, w = unix_events.socketpair() bytes_read = [] def reader(): @@ -228,74 +224,73 @@ def reader(): handler.cancel() if not data: r.close() - handler = el.add_reader(r.fileno(), reader) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + handler = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): - el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) - el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): - self.assertTrue(el.remove_writer(w.fileno())) - el.call_later(0.1, remove_writer) - el.run() + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() w.close() data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) def test_writer_callback_with_handler(self): - el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) - self.assertEqual(el.add_writer(w.fileno(), handler), handler) + self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) def remove_writer(): - self.assertTrue(el.remove_writer(w.fileno())) - el.call_later(0.1, remove_writer) - el.run() + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() w.close() data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) def test_writer_callback_cancel(self): - el = events.get_event_loop() r, w = unix_events.socketpair() w.setblocking(False) def sender(): w.send(b'x'*256) handler.cancel() - handler = el.add_writer(w.fileno(), sender) - el.run() + handler = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() w.close() data = r.recv(1024) r.close() self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): - el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) # TODO: This depends on python.org behavior! - el.run_until_complete(el.sock_connect(sock, ('python.org', 80))) - el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = el.run_until_complete(el.sock_recv(sock, 1024)) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, ('python.org', 80))) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) def test_sock_client_fail(self): - el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) # TODO: This depends on python.org behavior! with self.assertRaises(ConnectionRefusedError): - el.run_until_complete(el.sock_connect(sock, ('python.org', 12345))) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, ('python.org', 12345))) sock.close() def test_sock_accept(self): @@ -305,9 +300,9 @@ def test_sock_accept(self): listener.listen(1) client = socket.socket() client.connect(listener.getsockname()) - el = events.get_event_loop() - f = el.sock_accept(listener) - conn, addr = el.run_until_complete(f) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) self.assertEqual(conn.gettimeout(), 0) self.assertEqual(addr, client.getsockname()) self.assertEqual(client.getpeername(), listener.getsockname()) @@ -321,33 +316,42 @@ def test_add_signal_handler(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() + # Check error behavior first. - self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) - self.assertRaises(TypeError, el.remove_signal_handler, 'boom') - self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, - my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) - self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, 0) - self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, -1) - self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, - my_handler) + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) # Removing SIGKILL doesn't raise, since we don't call signal(). - self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) # Now set a handler and handle it. - el.add_signal_handler(signal.SIGINT, my_handler) - el.run_once() + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() os.kill(os.getpid(), signal.SIGINT) - el.run_once() + self.event_loop.run_once() self.assertEqual(caught, 1) # Removing it should restore the default handler. - self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) self.assertEqual(signal.getsignal(signal.SIGINT), signal.default_int_handler) # Removing again returns False. - self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) @unittest.skipIf(sys.platform == 'win32', 'Unix only') def test_cancel_signal_handler(self): @@ -356,11 +360,11 @@ def test_cancel_signal_handler(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() - handler = el.add_signal_handler(signal.SIGINT, my_handler) + + handler = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) handler.cancel() os.kill(os.getpid(), signal.SIGINT) - el.run_once() + self.event_loop.run_once() self.assertEqual(caught, 0) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') @@ -370,70 +374,166 @@ def test_signal_handling_while_selecting(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() - handler = el.add_signal_handler(signal.SIGALRM, my_handler) + + handler = self.event_loop.add_signal_handler(signal.SIGALRM, my_handler) signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - el.call_later(0.15, el.stop) - el.run_forever() + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() self.assertEqual(caught, 1) def test_create_transport(self): - el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_connection(MyProto, 'xkcd.com', 80) - tr, pr = el.run_until_complete(f) + f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - el.run() + self.event_loop.run() self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_transport(self): - el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) - tr, pr = el.run_until_complete(f) + f = self.event_loop.create_connection( + MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) - el.run() + self.event_loop.run() self.assertTrue(pr.nbytes > 0) + def test_create_transport_host_port_sock(self): + fut = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_host_port_sock(self): + fut = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_transport_connect_err(self): + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + def test_start_serving(self): - el = events.get_event_loop() - f = el.start_serving(MyProto, '0.0.0.0', 0) - sock = el.run_until_complete(f) + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - el.run_once() # This is quite mysterious, but necessary. - el.run_once() - el.run_once() + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() sock.close() # the client socket must be closed after to avoid ECONNRESET upon # recv()/send() on the serving socket client.close() + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto,'0.0.0.0',0,sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.unix_events.socket') + def test_start_serving_cant_bind(self, m_socket): + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [(2, 1, 6, '', ('127.0.0.1',10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + class Err(socket.error): + errno = errno.EAGAIN + + sock = unittest.mock.Mock() + sock.accept.side_effect = Err + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = socket.error + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): SELECTOR_CLASS = selectors.KqueueSelector if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): SELECTOR_CLASS = selectors.EpollSelector if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): SELECTOR_CLASS = selectors.PollSelector # Should always exist. -class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): +class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): SELECTOR_CLASS = selectors.SelectSelector diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fc2cb4a7..94bcde9d 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -284,78 +284,101 @@ def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) @tasks.task - def create_connection(self, protocol_factory, host, port, *, ssl=False, - family=0, proto=0, flags=0): + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - exceptions = [] - for family, type, proto, cname, address in infos: - sock = None - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - yield self.sock_connect(sock, address) - except socket.error as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - else: - break - else: - if len(exceptions) == 1: - raise exceptions[0] + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break else: - # If they all have the same str(), raise one. - model = str(exceptions[0]) - if all(str(exc) == model for exc in exceptions): + if len(exceptions) == 1: raise exceptions[0] - # Raise a combined exception so the user can see all - # the various error messages. - raise socket.error('Multiple exceptions: {}'.format( - ', '.join(str(exc) for exc in exceptions))) + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + protocol = protocol_factory() waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = _UnixSslTransport(self, sock, protocol, sslcontext, - waiter) + transport = _UnixSslTransport( + self, sock, protocol, sslcontext, waiter) else: - transport = _UnixSocketTransport(self, sock, protocol, waiter) + transport = _UnixSocketTransport( + self, sock, protocol, waiter) + yield from waiter return transport, protocol # TODO: Or create_server()? @tasks.task - def start_serving(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0, - backlog=100): + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - # TODO: Maybe we want to bind every address in the list - # instead of the first one that works? - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - except socket.error as exc: - sock.close() - exceptions.append(exc) + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break else: - break - else: - raise exceptions[0] + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + sock.listen(backlog) sock.setblocking(False) self.add_reader(sock.fileno(), self._accept_connection, From 0cef6b5eebc10392c95934021149daa4ea91bcc3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 4 Feb 2013 12:51:01 -0800 Subject: [PATCH 0285/1502] EventLoopPolicy tests --- tulip/events_test.py | 57 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 2 deletions(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index fa4b996e..34806710 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -552,8 +552,61 @@ def callback(*args): class PolicyTests(unittest.TestCase): - def test_policy(self): - pass + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.EventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.EventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) if __name__ == '__main__': From 35f015995fc2ef9cb8e54e8abd9cbe30e43e8ed7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 4 Feb 2013 16:30:49 -0800 Subject: [PATCH 0286/1502] remove test_utils.sync decorator --- tulip/http_client_test.py | 95 +++++++++++++++++++++++---------------- tulip/test_utils.py | 12 ----- 2 files changed, 56 insertions(+), 51 deletions(-) diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index 65a8b69d..c598a339 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -5,7 +5,6 @@ from . import events from . import http_client from . import tasks -from . import test_utils class StreamReaderTests(unittest.TestCase): @@ -14,10 +13,11 @@ class StreamReaderTests(unittest.TestCase): def setUp(self): self.event_loop = events.new_event_loop() - self.addCleanup(self.event_loop.close) - events.set_event_loop(self.event_loop) + def tearDown(self): + self.event_loop.close() + def test_feed_empty_data(self): stream = http_client.StreamReader() @@ -32,56 +32,62 @@ def test_feed_data_line_byte_count(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) - @test_utils.sync def test_read_zero(self): - """ Read zero bytes """ + """Read zero bytes""" stream = http_client.StreamReader() stream.feed_data(self.DATA) - data = yield from stream.read(0) + read_task = tasks.Task(stream.read(0)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_read(self): """ Read bytes """ stream = http_client.StreamReader() - - res = stream.read(30) + read_task = tasks.Task(stream.read(30)) def cb(): stream.feed_data(self.DATA) self.event_loop.call_soon(cb) - data = yield from res + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync + def test_read_line_breaks(self): + """ Read bytes without line breaks """ + stream = http_client.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + read_task = tasks.Task(stream.read(5)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + self.assertFalse(stream.line_count) + def test_read_eof(self): """ Read bytes, stop at eof """ stream = http_client.StreamReader() - - read = tasks.Task(stream.read(1024)) + read_task = tasks.Task(stream.read(1024)) def cb(): stream.feed_eof() self.event_loop.call_soon(cb) - data = yield from read - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync def test_read_until_eof(self): """ Read all bytes until eof """ stream = http_client.StreamReader() - - read = tasks.Task(stream.read(-1)) + read_task = tasks.Task(stream.read(-1)) def cb(): stream.feed_data(b'chunk1\n') @@ -89,17 +95,17 @@ def cb(): stream.feed_eof() self.event_loop.call_soon(cb) - data = yield from read + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync def test_readline(self): """ Read one line """ stream = http_client.StreamReader() stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) def cb(): stream.feed_data(b'chunk2 ') @@ -107,30 +113,41 @@ def cb(): stream.feed_data(b'\n chunk4') self.event_loop.call_soon(cb) - line = yield from stream.readline() - + line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) self.assertFalse(stream.line_count) self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) - @test_utils.sync def test_readline_line_byte_count(self): stream = http_client.StreamReader() - stream.feed_data(self.DATA) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) - line = yield from stream.readline() + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line1\n', line) self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) - @test_utils.sync + def test_readline_empty_eof(self): + stream = http_client.StreamReader() + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'', line) + def test_readline_read_byte_count(self): stream = http_client.StreamReader() stream.feed_data(self.DATA) - line = yield from stream.readline() - data = yield from stream.read(7) + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + read_task = tasks.Task(stream.read(7)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -139,53 +156,53 @@ def test_readline_read_byte_count(self): len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) - @test_utils.sync def test_readexactly_zero_or_less(self): """ Read exact number of bytes (zero or less) """ stream = http_client.StreamReader() stream.feed_data(self.DATA) - data = yield from stream.readexactly(0) + read_task = tasks.Task(stream.readexactly(0)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - data = yield from stream.readexactly(-1) + read_task = tasks.Task(stream.readexactly(-1)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_readexactly(self): """ Read exact number of bytes """ stream = http_client.StreamReader() + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + def cb(): stream.feed_data(self.DATA) stream.feed_data(self.DATA) stream.feed_data(self.DATA) self.event_loop.call_soon(cb) - n = 2*len(self.DATA) - data = yield from stream.readexactly(n) - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA+self.DATA, data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_readexactly_eof(self): """ Read exact number of bytes (eof) """ stream = http_client.StreamReader() + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) def cb(): stream.feed_data(self.DATA) stream.feed_eof() self.event_loop.call_soon(cb) - n = 2*len(self.DATA) - data = yield from stream.readexactly(n) - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index ac737c25..f07c34ce 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,20 +1,8 @@ """Utilities shared by tests.""" -import functools import logging import unittest -from . import events -from . import tasks - -def sync(gen): - @functools.wraps(gen) - def wrapper(*args, **kw): - return events.get_event_loop().run_until_complete( - tasks.Task(gen(*args, **kw))) - - return wrapper - class LogTrackingTestCase(unittest.TestCase): From cad8584cf82585e98fb51884ad5af873e7e93f96 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 4 Feb 2013 16:34:13 -0800 Subject: [PATCH 0287/1502] Handler comparison operations --- tulip/events.py | 20 ++++++++++++ tulip/events_test.py | 77 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 95 insertions(+), 2 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 9ff6e3fa..83b1adf3 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -50,15 +50,35 @@ def cancel(self): self._cancelled = True def __lt__(self, other): + if self._when is None: + return other._when is not None + elif other._when is None: + return False + return self._when < other._when def __le__(self, other): + if self._when is None: + return True + elif other._when is None: + return False + return self._when <= other._when def __gt__(self, other): + if self._when is None: + return False + elif other._when is None: + return True + return self._when > other._when def __ge__(self, other): + if self._when is None: + return other._when is None + elif other._when is None: + return True + return self._when >= other._when def __eq__(self, other): diff --git a/tulip/events_test.py b/tulip/events_test.py index 34806710..384c31ae 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -540,14 +540,87 @@ class SelectEventLoopTests(EventLoopTestsMixin, class HandlerTests(unittest.TestCase): def test_handler(self): - pass + def callback(*args): + return args + + args = () + h = events.Handler(None, callback, args) + self.assertIsNone(h.when) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_handler_comparison(self): + def callback(*args): + return args + + h1 = events.Handler(None, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + when = time.monotonic() + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == False) + self.assertTrue((h1 > h2) == True) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == False) + + self.assertTrue((h2 < h1) == True) + self.assertTrue((h2 <= h1) == True) + self.assertTrue((h2 > h1) == False) + self.assertTrue((h2 >= h1) == False) + self.assertTrue((h2 == h1) == False) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when + 10.0, callback, ()) + self.assertTrue((h1 < h2) == True) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == False) + self.assertTrue((h1 == h2) == False) def test_make_handler(self): def callback(*args): return args h1 = events.Handler(None, callback, ()) h2 = events.make_handler(None, h1, ()) - self.assertEqual(h1, h2) + self.assertIs(h1, h2) + + self.assertRaises(AssertionError, + events.make_handler, 10.0, h1, ()) + + self.assertRaises(AssertionError, + events.make_handler, None, h1, (1,2,)) class PolicyTests(unittest.TestCase): From 2bf355832244e10f1395a3e3e342ff03d5acb5e7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 09:43:45 -0800 Subject: [PATCH 0288/1502] Add a few comments. --- Makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Makefile b/Makefile index d11e9716..65a48111 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +# Some simple testing tasks (sorry, UNIX only). + PYTHON=python3 COVERAGE=coverage3 NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` @@ -9,6 +11,7 @@ test: testloop: while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done +# See README for coverage installation instructions. cov coverage: $(COVERAGE) run --branch runtests.py -v $(FLAGS) $(COVERAGE) html $(NONTESTS) From 7828acfee9a9aa1f7f5b6ec1e7d756b69774a02f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 10:17:30 -0800 Subject: [PATCH 0289/1502] Merge default rev 299 into iocp. Babysteps. --- tulip/tasks.py | 6 +++--- tulip/tasks_test.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index af7eaaff..6ab221b4 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -46,9 +46,9 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro): + def __init__(self, coro, event_loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__() # Sets self._event_loop. + super().__init__(event_loop=event_loop) # Sets self._event_loop. self._coro = coro self._must_cancel = False self._event_loop.call_soon(self._step) @@ -82,7 +82,7 @@ def _step_maybe(self): return self._step() def _step(self, value=None, exc=None): - if self.done(): # pragma: no cover + if self.done(): logging.warn('_step(): already done: %r, %r, %r', self, value, exc) return # We'll call either coro.throw(exc) or coro.send(value). diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index bc6fa531..81854fd3 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -36,6 +36,11 @@ def notmuch(): self.event_loop.run() self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) def test_task_decorator(self): @tasks.task @@ -81,6 +86,34 @@ def inner2(): t = outer() self.assertEqual(self.event_loop.run_until_complete(t), 1042) + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) From 81a1fab10c83305db73150b0ba2b06498af5abe0 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 5 Feb 2013 10:25:13 -0800 Subject: [PATCH 0290/1502] Increase test coverage of tasks.py to 100% --- .hgeol | 2 + .hgignore | 8 + Makefile | 33 ++ NOTES | 130 +++++ README | 30 + TODO | 165 ++++++ check.py | 41 ++ crawl.py | 136 +++++ curl.py | 31 + old/Makefile | 16 + old/echoclt.py | 79 +++ old/echosvr.py | 60 ++ old/http_client.py | 78 +++ old/http_server.py | 68 +++ old/main.py | 134 +++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++++++ old/scheduling.py | 354 ++++++++++++ old/sockets.py | 348 ++++++++++++ old/transports.py | 496 ++++++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 +++ runtests.py | 97 ++++ srv.py | 131 +++++ sslsrv.py | 56 ++ tulip/TODO | 26 + tulip/__init__.py | 14 + tulip/events.py | 307 ++++++++++ tulip/events_test.py | 686 ++++++++++++++++++++++ tulip/futures.py | 240 ++++++++ tulip/futures_test.py | 210 +++++++ tulip/http_client.py | 295 ++++++++++ tulip/http_client_test.py | 212 +++++++ tulip/protocols.py | 58 ++ tulip/selectors.py | 430 ++++++++++++++ tulip/subprocess_test.py | 54 ++ tulip/subprocess_transport.py | 133 +++++ tulip/tasks.py | 287 ++++++++++ tulip/tasks_test.py | 473 +++++++++++++++ tulip/test_utils.py | 22 + tulip/transports.py | 90 +++ tulip/unix_events.py | 1012 +++++++++++++++++++++++++++++++++ tulip/winsocketpair.py | 30 + 43 files changed, 7747 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 runtests.py create mode 100644 srv.py create mode 100644 sslsrv.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/events.py create mode 100644 tulip/events_test.py create mode 100644 tulip/futures.py create mode 100644 tulip/futures_test.py create mode 100644 tulip/http_client.py create mode 100644 tulip/http_client_test.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selectors.py create mode 100644 tulip/subprocess_test.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/tasks_test.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..42309f0c --- /dev/null +++ b/.hgignore @@ -0,0 +1,8 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..65a48111 --- /dev/null +++ b/Makefile @@ -0,0 +1,33 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +COVERAGE=coverage3 +NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(COVERAGE) run --branch runtests.py -v $(FLAGS) + $(COVERAGE) html $(NONTESTS) + $(COVERAGE) report -m $(NONTESTS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf __pycache__ */__pycache__ + rm -f *.py[co] */*.py[co] + rm -f *~ */*~ + rm -f .*~ */.*~ + rm -f @* */@* + rm -f '#'*'#' */'#'*'#' + rm -f *.orig */*.orig + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..6f41578e --- /dev/null +++ b/NOTES @@ -0,0 +1,130 @@ +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..c1c86a54 --- /dev/null +++ b/README @@ -0,0 +1,30 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (after installing coverage3, see below): + - make coverage + +To install coverage3 (coverage.py for Python 3), you need: + - Distribute (http://packages.python.org/distribute/) + - Coverage (http://nedbatchelder.com/code/coverage/) + What worked for me: + - curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - + - cd coveragepy + - python3 setup.py install + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..b9559ef0 --- /dev/null +++ b/TODO @@ -0,0 +1,165 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Implement various lock styles a la threading.py. + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..64bc2cdd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..b9881fcb --- /dev/null +++ b/crawl.py @@ -0,0 +1,136 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +from tulip import http_client + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..1a73c194 --- /dev/null +++ b/curl.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +from tulip import http_client + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = http_client.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + main() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..6758c742 --- /dev/null +++ b/runtests.py @@ -0,0 +1,97 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tulip.events_test.PolicyTests.testPolicy'. +""" + +# Originally written by Beech Horn (for NDB). + +import logging +import os +import re +import sys +import unittest + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tulip') + + +def load_tests(includes=(), excludes=()): + test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) + if f.endswith('_test.py')] + + if sys.platform == 'win32': + try: + test_mods.remove('subprocess_test') + except ValueError: + pass + tulip = __import__('tulip', fromlist=test_mods) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in [getattr(tulip, name) for name in test_mods]: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def main(): + excludes = [] + includes = [] + patterns = includes # A reference. + v = 1 + for arg in sys.argv[1:]: + if arg.startswith('-v'): + v += arg.count('v') + elif arg == '-q': + v = 0 + elif arg == '-x': + if patterns is includes: + patterns = excludes + else: + patterns = includes + elif arg and not arg.startswith('-'): + patterns.append(arg) + tests = load_tests(includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +if __name__ == '__main__': + main() diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..077279c1 --- /dev/null +++ b/srv.py @@ -0,0 +1,131 @@ +"""Simple server written using an event loop.""" + +import email.message +import email.parser +import os +import re + +import tulip +from tulip.http_client import StreamReader + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + line = yield from self.reader.readline() + print('request line', line) + match = re.match(rb'([A-Z]+) (\S+) HTTP/(1.\d)\r?\n\Z', line) + if not match: + self.transport.close() + return + bmethod, bpath, bversion = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format( + bmethod, bpath, bversion)) + try: + path = bpath.decode('ascii') + except UnicodeError as exc: + print('not ascii', repr(bpath), exc) + path = None + else: + if (not (path.isprintable() and path.startswith('/')) or + '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return + lines = [] + while True: + line = yield from self.reader.readline() + print('header line', line) + if not line.strip(b' \t\r\n'): + break + lines.append(line) + if line == b'\r\n': + break + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(lines)) + write = self.transport.write + if isdir and not path.endswith('/'): + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError as exc: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError as exc: + write(b'Cannot open\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport._sock) + self.reader = StreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..185fe3fe --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,14 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .transports import * +from .protocols import * +from .tasks import * + +__all__ = (futures.__all__ + + events.__all__ + + transports.__all__ + + protocols.__all__ + + tasks.__all__) diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..83b1adf3 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,307 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'EventLoop', 'Handler', 'make_handler', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import threading + + +class Handler: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args): + self._when = when + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handler({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def __lt__(self, other): + if self._when is None: + return other._when is not None + elif other._when is None: + return False + + return self._when < other._when + + def __le__(self, other): + if self._when is None: + return True + elif other._when is None: + return False + + return self._when <= other._when + + def __gt__(self, other): + if self._when is None: + return False + elif other._when is None: + return True + + return self._when > other._when + + def __ge__(self, other): + if self._when is None: + return other._when is None + elif other._when is None: + return True + + return self._when >= other._when + + def __eq__(self, other): + return self._when == other._when + + +def make_handler(when, callback, args): + if isinstance(callback, Handler): + assert not args + assert when is None + return callback + return Handler(when, callback, args) + + +class EventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handlers for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handler. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + def add_connector(self, fd, callback, *args): + raise NotImplementedError + + def remove_connector(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, EventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + # TODO: Do something else for Windows. + from . import unix_events + return unix_events.UnixEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/events_test.py b/tulip/events_test.py new file mode 100644 index 00000000..384c31ae --- /dev/null +++ b/tulip/events_test.py @@ -0,0 +1,686 @@ +"""Tests for events.py.""" + +import concurrent.futures +import errno +import gc +import os +import select +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from . import events +from . import transports +from . import protocols +from . import selectors +from . import test_utils +from . import unix_events + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.selector = self.SELECTOR_CLASS() + self.event_loop = unix_events.UnixEventLoop(self.selector) + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_call_later(self): + results = [] + def callback(arg): + results.append(arg) + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + def callback(arg): + results.append(arg) + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handler(self): + results = [] + def callback(): + results.append('yeah') + handler = events.Handler(None, callback, ()) + self.assertIs(self.event_loop.call_soon(handler), handler) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + def callback(arg): + results.append(arg) + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_with_handler(self): + results = [] + def callback(arg): + results.append(arg) + + handler = events.Handler(None, callback, ('hello',)) + def run(): + self.assertIs(self.event_loop.call_soon_threadsafe(handler),handler) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handler(self): + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(None, run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handler) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handler(self): + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handler = events.Handler(None, reader, ()) + self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = unix_events.socketpair() + bytes_read = [] + def reader(): + data = r.recv(1024) + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handler.cancel() + if not data: + r.close() + handler = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = unix_events.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handler(self): + r, w = unix_events.socketpair() + w.setblocking(False) + handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = unix_events.socketpair() + w.setblocking(False) + def sender(): + w.send(b'x'*256) + handler.cancel() + handler = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, ('python.org', 80))) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + + def test_sock_client_fail(self): + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, ('python.org', 12345))) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + + handler = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + def my_handler(): + nonlocal caught + caught += 1 + + handler = self.event_loop.add_signal_handler(signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_transport(self): + # TODO: This depends on xkcd.com behavior! + f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_transport(self): + # TODO: This depends on xkcd.com behavior! + f = self.event_loop.create_connection( + MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_transport_host_port_sock(self): + fut = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_host_port_sock(self): + fut = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_transport_connect_err(self): + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_start_serving(self): + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto,'0.0.0.0',0,sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.unix_events.socket') + def test_start_serving_cant_bind(self, m_socket): + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [(2, 1, 6, '', ('127.0.0.1',10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + class Err(socket.error): + errno = errno.EAGAIN + + sock = unittest.mock.Mock() + sock.accept.side_effect = Err + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = socket.error + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + +if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + SELECTOR_CLASS = selectors.KqueueSelector + + +if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + SELECTOR_CLASS = selectors.EpollSelector + + +if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + SELECTOR_CLASS = selectors.PollSelector + + +# Should always exist. +class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + SELECTOR_CLASS = selectors.SelectSelector + + +class HandlerTests(unittest.TestCase): + + def test_handler(self): + def callback(*args): + return args + + args = () + h = events.Handler(None, callback, args) + self.assertIsNone(h.when) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_handler_comparison(self): + def callback(*args): + return args + + h1 = events.Handler(None, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + when = time.monotonic() + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == False) + self.assertTrue((h1 > h2) == True) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == False) + + self.assertTrue((h2 < h1) == True) + self.assertTrue((h2 <= h1) == True) + self.assertTrue((h2 > h1) == False) + self.assertTrue((h2 >= h1) == False) + self.assertTrue((h2 == h1) == False) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when + 10.0, callback, ()) + self.assertTrue((h1 < h2) == True) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == False) + self.assertTrue((h1 == h2) == False) + + def test_make_handler(self): + def callback(*args): + return args + h1 = events.Handler(None, callback, ()) + h2 = events.make_handler(None, h1, ()) + self.assertIs(h1, h2) + + self.assertRaises(AssertionError, + events.make_handler, 10.0, h1, ()) + + self.assertRaises(AssertionError, + events.make_handler, None, h1, (1,2,)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.EventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.EventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..e79999fc --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,240 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res +='<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + yield self # This tells Task to wait for completion. + return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py new file mode 100644 index 00000000..2610b5be --- /dev/null +++ b/tulip/futures_test.py @@ -0,0 +1,210 @@ +"""Tests for futures.py.""" + +import unittest + +from . import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/http_client.py b/tulip/http_client.py new file mode 100644 index 00000000..0a03d81f --- /dev/null +++ b/tulip/http_client.py @@ -0,0 +1,295 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +import collections +import email.message +import email.parser +import re + +import tulip +from . import events +from . import futures +from . import tasks + + +# TODO: Move to another module. +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + # TODO: Limit line length for security. + while not self.line_count and not self.eof: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + parts = [] + while self.buffer: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + parts.append(head) + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + break + + line = b''.join(parts) + self.byte_count -= len(line) + + return line + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + return (yield from self.read(n)) + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + f = p.connect() # Returns a Future + ...now what?... + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = events.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection(lambda: self, + self.host, + self.port, + ssl=self.ssl) + # TODO: A better mechanism to return all info from the + # status line, all headers, and the buffer, without having + # an N-tuple return value. + status_line = yield from self.stream.readline() + m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', + status_line) + if not m: + raise 'Invalid HTTP status line ({!r})'.format(status_line) + version, status, message = m.groups() + raw_headers = [] + while True: + header = yield from self.stream.readline() + if not header.strip(): + break + raw_headers.append(header) + parser = email.parser.BytesHeaderParser() + headers = parser.parsebytes(b''.join(raw_headers)) + content_length = headers.get('content-length') + if content_length: + content_length = int(content_length) # May raise. + if content_length is None: + stream = self.stream + else: + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. + body = yield from self.stream.readexactly(content_length) + stream = StreamReader() + stream.feed_data(body) + stream.feed_eof() + sts = '{} {}'.format(self.decode(status), self.decode(message)) + return (sts, headers, stream) + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, s): + if not s: + return + data = self.encode(s) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') + + def connection_made(self, transport): + self.transport = transport + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.write_str(line) + for key, value in self.headers.items(): + self.write_str('{}: {}\r\n'.format(key, value)) + self.transport.write(b'\r\n') + self.stream = StreamReader() + if self.make_body is not None: + if self.chunked: + self.make_body(self.write_chunked, self.write_chunked_eof) + else: + self.make_body(self.write_str, self.transport.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py new file mode 100644 index 00000000..c598a339 --- /dev/null +++ b/tulip/http_client_test.py @@ -0,0 +1,212 @@ +"""Tests for http_client.py.""" + +import unittest + +from . import events +from . import http_client +from . import tasks + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_feed_empty_data(self): + stream = http_client.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.line_count) + self.assertEqual(0, stream.byte_count) + + def test_feed_data_line_byte_count(self): + stream = http_client.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + """Read zero bytes""" + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.read(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + def test_read(self): + """ Read bytes """ + stream = http_client.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + def test_read_line_breaks(self): + """ Read bytes without line breaks """ + stream = http_client.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + read_task = tasks.Task(stream.read(5)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + self.assertFalse(stream.line_count) + + def test_read_eof(self): + """ Read bytes, stop at eof """ + stream = http_client.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + def test_read_until_eof(self): + """ Read all bytes until eof """ + stream = http_client.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + def test_readline(self): + """ Read one line """ + stream = http_client.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertFalse(stream.line_count) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1\n', line) + self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_empty_eof(self): + stream = http_client.StreamReader() + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + read_task = tasks.Task(stream.read(7)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + 1, stream.line_count) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readexactly_zero_or_less(self): + """ Read exact number of bytes (zero or less) """ + stream = http_client.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readexactly(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + read_task = tasks.Task(stream.readexactly(-1)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + def test_readexactly(self): + """ Read exact number of bytes """ + stream = http_client.StreamReader() + + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA+self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(self.DATA.count(b'\n'), stream.line_count) + + def test_readexactly_eof(self): + """ Read exact number of bytes (eof) """ + stream = http_client.StreamReader() + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + self.assertFalse(stream.line_count) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..ad294f3a --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,58 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..05434630 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,430 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) +# connect event +EVENT_CONNECT = (1 << 2) + +# In most cases we treat EVENT_WRITE and EVENT_CONNECT as aliases for +# each other, and in fact we return both flags when a FD is found +# either writable or connectable. The distinction is necessary +# only for poll() on Windows. + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE|EVENT_CONNECT)): + raise ValueError("Invalid events: {}".format(events)) + + if events & (EVENT_WRITE|EVENT_CONNECT) == (EVENT_WRITE|EVENT_CONNECT): + raise ValueError("WRITE and CONNECT are mutually exclusive. " + "Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of + EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + self.register(fileobj, events, data) + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE|EVENT_CONNECT + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & (EVENT_WRITE|EVENT_CONNECT): + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & (EVENT_WRITE|EVENT_CONNECT): + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE|EVENT_CONNECT + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & (EVENT_WRITE|EVENT_CONNECT): + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE|EVENT_CONNECT + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/subprocess_test.py b/tulip/subprocess_test.py new file mode 100644 index 00000000..3d996f6a --- /dev/null +++ b/tulip/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from . import events +from . import protocols +from . import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..721013f8 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,133 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..617040ce --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,287 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop) # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + result.add_done_callback(self._wakeup) + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout != None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + bail = futures.Future() # Will always be cancelled eventually. + timeout_handler = None + debugstuff = locals() + + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.cancel) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py new file mode 100644 index 00000000..4895732d --- /dev/null +++ b/tulip/tasks_test.py @@ -0,0 +1,473 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from . import events +from . import futures +from . import tasks +from . import test_utils + + +class Dummy: + def __repr__(self): + return 'Dummy()' + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task(tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + """Coroutine returns Future""" + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_result_cuncurrent_future(self): + """Coroutine returns cuncurrent.future.Future""" + class Fut(concurrent.futures.Future): + + def __init__(self): + self.cb_added = False + super().__init__() + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..f07c34ce --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,22 @@ +"""Utilities shared by tests.""" + +import logging +import unittest + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..4aaae3c7 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,90 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..94bcde9d --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,1012 @@ +"""UNIX event loop and related classes. + +The event loop can be broken up into a selector (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a selector with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + +import collections +import concurrent.futures +import errno +import heapq +import logging +import select +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import futures +from . import protocols +from . import selectors +from . import tasks +from . import transports + +try: + from socket import socketpair +except ImportError: + assert sys.platform == 'win32' + from .winsocketpair import socketpair + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(): + raise _StopError + + +class UnixEventLoop(events.EventLoop): + """Unix event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + if selector is None: + # pick the best selector class for the platform + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._signal_handlers = {} + self._make_self_pipe() + + def close(self): + if self._selector is not None: + self._selector.close() + self._selector = None + self._ssock.close() + self._csock.close() + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if timeout is None: + timeout = 0x7fffffff/1000.0 # 24 days + future.add_done_callback(lambda _: self.stop()) + handler = self.call_later(timeout, _raise_stop_error) + self.run() + handler.cancel() + if future.done(): + return future.result() # May raise future.exception(). + else: + raise futures.TimeoutError + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.make_handler(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(None, callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert callback.when is None + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = _UnixSslTransport( + self, sock, protocol, sslcontext, waiter) + else: + transport = _UnixSocketTransport( + self, sock, protocol, waiter) + + yield from waiter + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + return sock + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + protocol = protocol_factory() + transport = _UnixSocketTransport(self, conn, protocol) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handler, None, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handler, writer, connector)) + + return handler + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer, connector)) + if reader is not None: + reader.cancel() + return True + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handler, None)) + else: + # Remove connector. + mask &= ~selectors.EVENT_CONNECT + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handler, None)) + return handler + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() + return True + + # NOTE: add_connector() and add_writer() are mutually exclusive. + # While you can independently manipulate readers and writers, + # adding a connector for a particular FD automatically removes the + # writer for that FD, and vice versa, and removing a writer or a + # connector actually removes both writer and connector. This is + # because in most cases writers and connectors use the same mode + # for the platform polling function; the distinction is only + # important for PollSelector() on Windows. + + def add_connector(self, fd, callback, *args): + """Add a connector callback. Return a Handler instance.""" + handler = events.make_handler(None, callback, args) + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_CONNECT, + (None, None, handler)) + else: + # Remove writer. + mask &= ~selectors.EVENT_WRITE + self._selector.modify(fd, mask | selectors.EVENT_CONNECT, + (reader, None, handler)) + return handler + + def remove_connector(self, fd): + """Remove a connector callback.""" + try: + mask, (reader, writer, connector) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None, None)) + if writer is not None: + writer.cancel() + if connector is not None: + connector.cancel() + return True + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + self._sock_sendall(fut, False, sock, data) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + n = 0 + try: + if data: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_connector(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_connector(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(None, callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler.callback, *handler.args) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) + if sys.platform == 'win32': + raise RuntimeError('Signals are not really supported on Windows') + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if handler.when is None: + self._ready.append(handler) + else: + heapq.heappush(self._scheduled, handler) + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if self._selector.registered_count() > 1 or self._scheduled: + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for fileobj, mask, (reader, writer, connector) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + elif mask & selectors.EVENT_CONNECT and connector is not None: + if connector.cancelled: + self.remove_connector(fileobj) + else: + self._add_callback(connector) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + try: + handler.callback(*handler.args) + except Exception: + logging.exception('Exception in callback %s %r', + handler.callback, handler.args) + + +class _UnixSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None): + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _UnixSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if n < len(data): + self._buffer.append(data[n:]) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..87d54c91 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,30 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From ededd4d0c75b62ea6d60960b90047e67a0db8026 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 12:43:18 -0800 Subject: [PATCH 0291/1502] Merge rev 300 into iocp branch. --- tulip/base_events.py | 136 ++++++++++------- tulip/events.py | 8 +- tulip/events_test.py | 354 ++++++++++++++++++++++++++++--------------- 3 files changed, 312 insertions(+), 186 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 92304211..973ce177 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -202,81 +202,103 @@ def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) @tasks.task - def create_connection(self, protocol_factory, host, port, *, ssl=False, - family=0, proto=0, flags=0): + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - exceptions = [] - for family, type, proto, cname, address in infos: - sock = None - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - yield self.sock_connect(sock, address) - except socket.error as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - else: - break - else: - if len(exceptions) == 1: - raise exceptions[0] + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break else: - # If they all have the same str(), raise one. - model = str(exceptions[0]) - if all(str(exc) == model for exc in exceptions): + if len(exceptions) == 1: raise exceptions[0] - # Raise a combined exception so the user can see all - # the various error messages. - raise socket.error('Multiple exceptions: {}'.format( - ', '.join(str(exc) for exc in exceptions))) + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + protocol = protocol_factory() waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = self.SslTransport(self, sock, protocol, sslcontext, - waiter) + transport = self.SslTransport(self, sock, protocol, sslcontext, waiter) else: transport = self.SocketTransport(self, sock, protocol, waiter) + yield from waiter return transport, protocol # TODO: Or create_server()? @tasks.task - def start_serving(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0, - backlog=100): + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): """XXX""" - infos = yield from self.getaddrinfo(host, port, - family=family, - type=socket.SOCK_STREAM, - proto=proto, flags=flags) - if not infos: - raise socket.error('getaddrinfo() returned empty list') - # TODO: Maybe we want to bind every address in the list - # instead of the first one that works? - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - except socket.error as exc: - sock.close() - exceptions.append(exc) + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break else: - break - else: - raise exceptions[0] + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock) + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) return sock def _add_callback(self, handler): diff --git a/tulip/events.py b/tulip/events.py index 6aaea66d..e9fceaab 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -140,12 +140,12 @@ def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError - def create_connection(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): raise NotImplementedError - def start_serving(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): raise NotImplementedError # Ready-based callback registration methods. diff --git a/tulip/events_test.py b/tulip/events_test.py index 0b27e94d..8b356681 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -1,6 +1,7 @@ """Tests for events.py.""" import concurrent.futures +import errno import gc import os import select @@ -14,12 +15,14 @@ import threading import time import unittest +import unittest.mock from . import events from . import transports from . import protocols -from . import selectors from . import selector_events +from . import test_utils +from . import unix_events class MyProto(protocols.Protocol): @@ -51,6 +54,7 @@ def connection_lost(self, exc): class EventLoopTestsMixin: def setUp(self): + super().setUp() self.event_loop = self.create_event_loop() events.set_event_loop(self.event_loop) @@ -59,118 +63,110 @@ def tearDown(self): gc.collect() def test_run(self): - el = events.get_event_loop() - el.run() # Returns immediately. + self.event_loop.run() # Returns immediately. def test_call_later(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) - el.call_later(0.1, callback, 'hello world') + self.event_loop.call_later(0.1, callback, 'hello world') t0 = time.monotonic() - el.run() + self.event_loop.run() t1 = time.monotonic() self.assertEqual(results, ['hello world']) self.assertTrue(t1-t0 >= 0.09) def test_call_repeatedly(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) - el.call_repeatedly(0.03, callback, 'ho') - el.call_later(0.1, el.stop) - el.run() + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() self.assertEqual(results, ['ho', 'ho', 'ho']) def test_call_soon(self): - el = events.get_event_loop() results = [] def callback(arg1, arg2): results.append((arg1, arg2)) - el.call_soon(callback, 'hello', 'world') - el.run() + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() self.assertEqual(results, [('hello', 'world')]) def test_call_soon_with_handler(self): - el = events.get_event_loop() results = [] def callback(): results.append('yeah') handler = events.Handler(None, callback, ()) - self.assertEqual(el.call_soon(handler), handler) - el.run() + self.assertIs(self.event_loop.call_soon(handler), handler) + self.event_loop.run() self.assertEqual(results, ['yeah']) def test_call_soon_threadsafe(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) def run(): - el.call_soon_threadsafe(callback, 'hello') + self.event_loop.call_soon_threadsafe(callback, 'hello') t = threading.Thread(target=run) - el.call_later(0.1, callback, 'world') + self.event_loop.call_later(0.1, callback, 'world') t0 = time.monotonic() t.start() - el.run() + self.event_loop.run() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) def test_call_soon_threadsafe_with_handler(self): - el = events.get_event_loop() results = [] def callback(arg): results.append(arg) + handler = events.Handler(None, callback, ('hello',)) def run(): - self.assertEqual(el.call_soon_threadsafe(handler), handler) + self.assertIs(self.event_loop.call_soon_threadsafe(handler),handler) + t = threading.Thread(target=run) - el.call_later(0.1, callback, 'world') + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() t.start() - el.run() + self.event_loop.run() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) def test_wrap_future(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') - f2 = el.wrap_future(f1) - res = el.run_until_complete(f2) + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'oi') def test_run_in_executor(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg - f2 = el.run_in_executor(None, run, 'yo') - res = el.run_until_complete(f2) + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') def test_run_in_executor_with_handler(self): - el = events.get_event_loop() def run(arg): time.sleep(0.1) return arg handler = events.Handler(None, run, ('yo',)) - f2 = el.run_in_executor(None, handler) - res = el.run_until_complete(f2) + f2 = self.event_loop.run_in_executor(None, handler) + res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') def test_reader_callback(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() bytes_read = [] def reader(): try: @@ -182,18 +178,17 @@ def reader(): if data: bytes_read.append(data) else: - self.assertTrue(el.remove_reader(r.fileno())) + self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() - el.add_reader(r.fileno(), reader) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_with_handler(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() bytes_read = [] def reader(): try: @@ -205,19 +200,20 @@ def reader(): if data: bytes_read.append(data) else: - self.assertTrue(el.remove_reader(r.fileno())) + self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() + handler = events.Handler(None, reader, ()) - self.assertEqual(el.add_reader(r.fileno(), handler), handler) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_cancel(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() bytes_read = [] def reader(): try: @@ -230,76 +226,75 @@ def reader(): handler.cancel() if not data: r.close() - handler = el.add_reader(r.fileno(), reader) - el.call_later(0.05, w.send, b'abc') - el.call_later(0.1, w.send, b'def') - el.call_later(0.15, w.close) - el.run() + handler = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() w.setblocking(False) - el.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): - self.assertTrue(el.remove_writer(w.fileno())) - el.call_later(0.1, remove_writer) - el.run() + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() w.close() data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) def test_writer_callback_with_handler(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() w.setblocking(False) handler = events.Handler(None, w.send, (b'x'*(256*1024),)) - self.assertEqual(el.add_writer(w.fileno(), handler), handler) + self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) def remove_writer(): - self.assertTrue(el.remove_writer(w.fileno())) - el.call_later(0.1, remove_writer) - el.run() + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() w.close() data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) def test_writer_callback_cancel(self): - el = events.get_event_loop() - r, w = el._socketpair() + r, w = self.event_loop._socketpair() w.setblocking(False) def sender(): w.send(b'x'*256) handler.cancel() - handler = el.add_writer(w.fileno(), sender) - el.run() + handler = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() w.close() data = r.recv(1024) r.close() self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): - el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) # TODO: This depends on python.org behavior! address = socket.getaddrinfo('python.org', 80, socket.AF_INET)[0][4] - el.run_until_complete(el.sock_connect(sock, address)) - el.run_until_complete(el.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = el.run_until_complete(el.sock_recv(sock, 1024)) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) sock.close() self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) def test_sock_client_fail(self): - el = events.get_event_loop() sock = socket.socket() sock.setblocking(False) # TODO: This depends on python.org behavior! address = socket.getaddrinfo('python.org', 12345, socket.AF_INET)[0][4] with self.assertRaises(ConnectionRefusedError): - el.run_until_complete(el.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) sock.close() def test_sock_accept(self): @@ -310,8 +305,9 @@ def test_sock_accept(self): listener.listen(1) client = socket.socket() client.connect(listener.getsockname()) - f = el.sock_accept(listener) - conn, addr = el.run_until_complete(f) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) self.assertEqual(conn.gettimeout(), 0) self.assertEqual(addr, client.getsockname()) self.assertEqual(client.getpeername(), listener.getsockname()) @@ -325,33 +321,42 @@ def test_add_signal_handler(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() + # Check error behavior first. - self.assertRaises(TypeError, el.add_signal_handler, 'boom', my_handler) - self.assertRaises(TypeError, el.remove_signal_handler, 'boom') - self.assertRaises(ValueError, el.add_signal_handler, signal.NSIG+1, - my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, signal.NSIG+1) - self.assertRaises(ValueError, el.add_signal_handler, 0, my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, 0) - self.assertRaises(ValueError, el.add_signal_handler, -1, my_handler) - self.assertRaises(ValueError, el.remove_signal_handler, -1) - self.assertRaises(RuntimeError, el.add_signal_handler, signal.SIGKILL, - my_handler) + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) # Removing SIGKILL doesn't raise, since we don't call signal(). - self.assertFalse(el.remove_signal_handler(signal.SIGKILL)) + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) # Now set a handler and handle it. - el.add_signal_handler(signal.SIGINT, my_handler) - el.run_once() + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() os.kill(os.getpid(), signal.SIGINT) - el.run_once() + self.event_loop.run_once() self.assertEqual(caught, 1) # Removing it should restore the default handler. - self.assertTrue(el.remove_signal_handler(signal.SIGINT)) + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) self.assertEqual(signal.getsignal(signal.SIGINT), signal.default_int_handler) # Removing again returns False. - self.assertFalse(el.remove_signal_handler(signal.SIGINT)) + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) @unittest.skipIf(sys.platform == 'win32', 'Unix only') def test_cancel_signal_handler(self): @@ -360,11 +365,11 @@ def test_cancel_signal_handler(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() - handler = el.add_signal_handler(signal.SIGINT, my_handler) + + handler = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) handler.cancel() os.kill(os.getpid(), signal.SIGINT) - el.run_once() + self.event_loop.run_once() self.assertEqual(caught, 0) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') @@ -374,61 +379,155 @@ def test_signal_handling_while_selecting(self): def my_handler(): nonlocal caught caught += 1 - el = events.get_event_loop() - handler = el.add_signal_handler(signal.SIGALRM, my_handler) + + handler = self.event_loop.add_signal_handler(signal.SIGALRM, my_handler) signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - el.call_later(0.15, el.stop) - el.run_forever() + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() self.assertEqual(caught, 1) def test_create_transport(self): - el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_connection(MyProto, 'xkcd.com', 80) - tr, pr = el.run_until_complete(f) + f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - el.run() + self.event_loop.run() self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_transport(self): - el = events.get_event_loop() # TODO: This depends on xkcd.com behavior! - f = el.create_connection(MyProto, 'xkcd.com', 443, ssl=True) - tr, pr = el.run_until_complete(f) + f = self.event_loop.create_connection( + MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) - el.run() + self.event_loop.run() self.assertTrue(pr.nbytes > 0) + def test_create_transport_host_port_sock(self): + fut = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_host_port_sock(self): + fut = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_transport_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_transport_connect_err(self): + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + def test_start_serving(self): - el = events.get_event_loop() - f = el.start_serving(MyProto, '0.0.0.0', 0) - sock = el.run_until_complete(f) + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - el.run_once() # This is quite mysterious, but necessary. - el.run_once() - el.run_once() + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() sock.close() # the client socket must be closed after to avoid ECONNRESET upon # recv()/send() on the serving socket client.close() + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto,'0.0.0.0',0,sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [(2, 1, 6, '', ('127.0.0.1',10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + class Err(socket.error): + errno = errno.EAGAIN + + sock = unittest.mock.Mock() + sock.accept.side_effect = Err + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = socket.error + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + if sys.platform == 'win32': from . import windows_events - class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return windows_events.SelectorEventLoop() - class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return windows_events.ProactorEventLoop() def test_create_ssl_transport(self): @@ -447,25 +546,30 @@ def test_writer_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") else: + from . import selectors from . import unix_events if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.PollSelector()) # Should always exist. - class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.SelectSelector()) From af2af0c138dbc3b7951e53c07b10165b85843263 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 12:45:48 -0800 Subject: [PATCH 0292/1502] Merge the rest of default into iocp. --- Makefile | 3 + tulip/events.py | 20 ++++++ tulip/events_test.py | 134 ++++++++++++++++++++++++++++++++++++-- tulip/http_client_test.py | 95 ++++++++++++++++----------- tulip/test_utils.py | 12 ---- 5 files changed, 209 insertions(+), 55 deletions(-) diff --git a/Makefile b/Makefile index d11e9716..65a48111 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,5 @@ +# Some simple testing tasks (sorry, UNIX only). + PYTHON=python3 COVERAGE=coverage3 NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` @@ -9,6 +11,7 @@ test: testloop: while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done +# See README for coverage installation instructions. cov coverage: $(COVERAGE) run --branch runtests.py -v $(FLAGS) $(COVERAGE) html $(NONTESTS) diff --git a/tulip/events.py b/tulip/events.py index e9fceaab..0c7244fa 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -51,15 +51,35 @@ def cancel(self): self._cancelled = True def __lt__(self, other): + if self._when is None: + return other._when is not None + elif other._when is None: + return False + return self._when < other._when def __le__(self, other): + if self._when is None: + return True + elif other._when is None: + return False + return self._when <= other._when def __gt__(self, other): + if self._when is None: + return False + elif other._when is None: + return True + return self._when > other._when def __ge__(self, other): + if self._when is None: + return other._when is None + elif other._when is None: + return True + return self._when >= other._when def __eq__(self, other): diff --git a/tulip/events_test.py b/tulip/events_test.py index 8b356681..65820620 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -577,20 +577,146 @@ def create_event_loop(self): class HandlerTests(unittest.TestCase): def test_handler(self): - pass + def callback(*args): + return args + + args = () + h = events.Handler(None, callback, args) + self.assertIsNone(h.when) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(None, ' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_handler_comparison(self): + def callback(*args): + return args + + h1 = events.Handler(None, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + when = time.monotonic() + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(None, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == False) + self.assertTrue((h1 > h2) == True) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == False) + + self.assertTrue((h2 < h1) == True) + self.assertTrue((h2 <= h1) == True) + self.assertTrue((h2 > h1) == False) + self.assertTrue((h2 >= h1) == False) + self.assertTrue((h2 == h1) == False) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when, callback, ()) + self.assertTrue((h1 < h2) == False) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == True) + self.assertTrue((h1 == h2) == True) + + h1 = events.Handler(when, callback, ()) + h2 = events.Handler(when + 10.0, callback, ()) + self.assertTrue((h1 < h2) == True) + self.assertTrue((h1 <= h2) == True) + self.assertTrue((h1 > h2) == False) + self.assertTrue((h1 >= h2) == False) + self.assertTrue((h1 == h2) == False) def test_make_handler(self): def callback(*args): return args h1 = events.Handler(None, callback, ()) h2 = events.make_handler(None, h1, ()) - self.assertEqual(h1, h2) + self.assertIs(h1, h2) + + self.assertRaises(AssertionError, + events.make_handler, 10.0, h1, ()) + + self.assertRaises(AssertionError, + events.make_handler, None, h1, (1,2,)) class PolicyTests(unittest.TestCase): - def test_policy(self): - pass + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) if __name__ == '__main__': diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index 65a8b69d..c598a339 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -5,7 +5,6 @@ from . import events from . import http_client from . import tasks -from . import test_utils class StreamReaderTests(unittest.TestCase): @@ -14,10 +13,11 @@ class StreamReaderTests(unittest.TestCase): def setUp(self): self.event_loop = events.new_event_loop() - self.addCleanup(self.event_loop.close) - events.set_event_loop(self.event_loop) + def tearDown(self): + self.event_loop.close() + def test_feed_empty_data(self): stream = http_client.StreamReader() @@ -32,56 +32,62 @@ def test_feed_data_line_byte_count(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) - @test_utils.sync def test_read_zero(self): - """ Read zero bytes """ + """Read zero bytes""" stream = http_client.StreamReader() stream.feed_data(self.DATA) - data = yield from stream.read(0) + read_task = tasks.Task(stream.read(0)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_read(self): """ Read bytes """ stream = http_client.StreamReader() - - res = stream.read(30) + read_task = tasks.Task(stream.read(30)) def cb(): stream.feed_data(self.DATA) self.event_loop.call_soon(cb) - data = yield from res + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync + def test_read_line_breaks(self): + """ Read bytes without line breaks """ + stream = http_client.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + read_task = tasks.Task(stream.read(5)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + self.assertFalse(stream.line_count) + def test_read_eof(self): """ Read bytes, stop at eof """ stream = http_client.StreamReader() - - read = tasks.Task(stream.read(1024)) + read_task = tasks.Task(stream.read(1024)) def cb(): stream.feed_eof() self.event_loop.call_soon(cb) - data = yield from read - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync def test_read_until_eof(self): """ Read all bytes until eof """ stream = http_client.StreamReader() - - read = tasks.Task(stream.read(-1)) + read_task = tasks.Task(stream.read(-1)) def cb(): stream.feed_data(b'chunk1\n') @@ -89,17 +95,17 @@ def cb(): stream.feed_eof() self.event_loop.call_soon(cb) - data = yield from read + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) - @test_utils.sync def test_readline(self): """ Read one line """ stream = http_client.StreamReader() stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) def cb(): stream.feed_data(b'chunk2 ') @@ -107,30 +113,41 @@ def cb(): stream.feed_data(b'\n chunk4') self.event_loop.call_soon(cb) - line = yield from stream.readline() - + line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) self.assertFalse(stream.line_count) self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) - @test_utils.sync def test_readline_line_byte_count(self): stream = http_client.StreamReader() - stream.feed_data(self.DATA) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) - line = yield from stream.readline() + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line1\n', line) self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) - @test_utils.sync + def test_readline_empty_eof(self): + stream = http_client.StreamReader() + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'', line) + def test_readline_read_byte_count(self): stream = http_client.StreamReader() stream.feed_data(self.DATA) - line = yield from stream.readline() - data = yield from stream.read(7) + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + read_task = tasks.Task(stream.read(7)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -139,53 +156,53 @@ def test_readline_read_byte_count(self): len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) - @test_utils.sync def test_readexactly_zero_or_less(self): """ Read exact number of bytes (zero or less) """ stream = http_client.StreamReader() stream.feed_data(self.DATA) - data = yield from stream.readexactly(0) + read_task = tasks.Task(stream.readexactly(0)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - data = yield from stream.readexactly(-1) + read_task = tasks.Task(stream.readexactly(-1)) + data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_readexactly(self): """ Read exact number of bytes """ stream = http_client.StreamReader() + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + def cb(): stream.feed_data(self.DATA) stream.feed_data(self.DATA) stream.feed_data(self.DATA) self.event_loop.call_soon(cb) - n = 2*len(self.DATA) - data = yield from stream.readexactly(n) - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA+self.DATA, data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) - @test_utils.sync def test_readexactly_eof(self): """ Read exact number of bytes (eof) """ stream = http_client.StreamReader() + n = 2*len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) def cb(): stream.feed_data(self.DATA) stream.feed_eof() self.event_loop.call_soon(cb) - n = 2*len(self.DATA) - data = yield from stream.readexactly(n) - + data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index ac737c25..f07c34ce 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,20 +1,8 @@ """Utilities shared by tests.""" -import functools import logging import unittest -from . import events -from . import tasks - -def sync(gen): - @functools.wraps(gen) - def wrapper(*args, **kw): - return events.get_event_loop().run_until_complete( - tasks.Task(gen(*args, **kw))) - - return wrapper - class LogTrackingTestCase(unittest.TestCase): From d72afacbb155281692cc3efc8c519df7ecd0d6d3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 12:51:10 -0800 Subject: [PATCH 0293/1502] Kill bad import of unix_events. --- tulip/events_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tulip/events_test.py b/tulip/events_test.py index 65820620..7f8b4a38 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -22,7 +22,6 @@ from . import protocols from . import selector_events from . import test_utils -from . import unix_events class MyProto(protocols.Protocol): From 065088b573247bd40e7a27983c43f48f00920716 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 13:49:37 -0800 Subject: [PATCH 0294/1502] Actually use _start_serving(). (Merge error probably.) --- tulip/base_events.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 973ce177..a9733eab 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -297,8 +297,7 @@ def start_serving(self, protocol_factory, host=None, port=None, *, sock.listen(backlog) sock.setblocking(False) - self.add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock) + self._start_serving(protocol_factory, sock) return sock def _add_callback(self, handler): From 5caa4e1c44a7b6e43e64335e3081000b9b7e775c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 13:57:11 -0800 Subject: [PATCH 0295/1502] Disable tests for _accept_connection in Proactor test. --- tulip/events_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 7f8b4a38..ea22aeda 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -543,6 +543,10 @@ def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_writer_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest("IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest("IocpEventLoop does not have _accept_connection()") else: from . import selectors From 5f4fd141e6d0eccf5ce64d283f27b4f448263ad2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 14:24:27 -0800 Subject: [PATCH 0296/1502] Get rid of add_connector and everything relating to it. We'll never need it. --- tulip/events.py | 6 --- tulip/selector_events.py | 80 +++++++--------------------------------- tulip/selectors.py | 39 +++++++------------- 3 files changed, 27 insertions(+), 98 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 0c7244fa..f2022fac 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -185,12 +185,6 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): raise NotImplementedError - def add_connector(self, fd, callback, *args): - raise NotImplementedError - - def remove_connector(self, fd): - raise NotImplementedError - # Completion based I/O methods returning Futures. def sock_recv(self, sock, nbytes): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 8f05b443..f738daec 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -116,20 +116,20 @@ def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" handler = events.make_handler(None, callback, args) try: - mask, (reader, writer, connector) = self._selector.get_info(fd) + mask, (reader, writer) = self._selector.get_info(fd) except KeyError: self._selector.register(fd, selectors.EVENT_READ, - (handler, None, None)) + (handler, None)) else: self._selector.modify(fd, mask | selectors.EVENT_READ, - (handler, writer, connector)) + (handler, writer)) return handler def remove_reader(self, fd): """Remove a reader callback.""" try: - mask, (reader, writer, connector) = self._selector.get_info(fd) + mask, (reader, writer) = self._selector.get_info(fd) except KeyError: return False else: @@ -137,7 +137,7 @@ def remove_reader(self, fd): if not mask: self._selector.unregister(fd) else: - self._selector.modify(fd, mask, (None, writer, connector)) + self._selector.modify(fd, mask, (None, writer)) if reader is not None: reader.cancel() return True @@ -146,77 +146,30 @@ def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" handler = events.make_handler(None, callback, args) try: - mask, (reader, writer, connector) = self._selector.get_info(fd) + mask, (reader, writer) = self._selector.get_info(fd) except KeyError: self._selector.register(fd, selectors.EVENT_WRITE, - (None, handler, None)) + (None, handler)) else: - # Remove connector. - mask &= ~selectors.EVENT_CONNECT self._selector.modify(fd, mask | selectors.EVENT_WRITE, - (reader, handler, None)) + (reader, handler)) return handler def remove_writer(self, fd): """Remove a writer callback.""" try: - mask, (reader, writer, connector) = self._selector.get_info(fd) + mask, (reader, writer) = self._selector.get_info(fd) except KeyError: return False else: # Remove both writer and connector. - mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) - if not mask: - self._selector.unregister(fd) - else: - self._selector.modify(fd, mask, (reader, None, None)) - if writer is not None: - writer.cancel() - if connector is not None: - connector.cancel() - return True - - # NOTE: add_connector() and add_writer() are mutually exclusive. - # While you can independently manipulate readers and writers, - # adding a connector for a particular FD automatically removes the - # writer for that FD, and vice versa, and removing a writer or a - # connector actually removes both writer and connector. This is - # because in most cases writers and connectors use the same mode - # for the platform polling function; the distinction is only - # important for PollSelector() on Windows. - - def add_connector(self, fd, callback, *args): - """Add a connector callback. Return a Handler instance.""" - handler = events.make_handler(None, callback, args) - try: - mask, (reader, writer, connector) = self._selector.get_info(fd) - except KeyError: - self._selector.register(fd, selectors.EVENT_CONNECT, - (None, None, handler)) - else: - # Remove writer. mask &= ~selectors.EVENT_WRITE - self._selector.modify(fd, mask | selectors.EVENT_CONNECT, - (reader, None, handler)) - return handler - - def remove_connector(self, fd): - """Remove a connector callback.""" - try: - mask, (reader, writer, connector) = self._selector.get_info(fd) - except KeyError: - return False - else: - # Remove both writer and connector. - mask &= ~(selectors.EVENT_WRITE | selectors.EVENT_CONNECT) if not mask: self._selector.unregister(fd) else: - self._selector.modify(fd, mask, (reader, None, None)) + self._selector.modify(fd, mask, (reader, None)) if writer is not None: writer.cancel() - if connector is not None: - connector.cancel() return True def sock_recv(self, sock, n): @@ -284,7 +237,7 @@ def sock_connect(self, sock, address): def _sock_connect(self, fut, registered, sock, address): fd = sock.fileno() if registered: - self.remove_connector(fd) + self.remove_writer(fd) if fut.cancelled(): return try: @@ -301,8 +254,8 @@ def _sock_connect(self, fut, registered, sock, address): if exc.errno not in _TRYAGAIN: fut.set_exception(exc) else: - self.add_connector(fd, self._sock_connect, - fut, True, sock, address) + self.add_writer(fd, self._sock_connect, + fut, True, sock, address) def sock_accept(self, sock): """XXX""" @@ -327,7 +280,7 @@ def _sock_accept(self, fut, registered, sock): self.add_reader(fd, self._sock_accept, fut, True, sock) def _process_events(self, event_list): - for fileobj, mask, (reader, writer, connector) in event_list: + for fileobj, mask, (reader, writer) in event_list: if mask & selectors.EVENT_READ and reader is not None: if reader.cancelled: self.remove_reader(fileobj) @@ -338,11 +291,6 @@ def _process_events(self, event_list): self.remove_writer(fileobj) else: self._add_callback(writer) - elif mask & selectors.EVENT_CONNECT and connector is not None: - if connector.cancelled: - self.remove_connector(fileobj) - else: - self._add_callback(connector) class _SelectorSocketTransport(transports.Transport): diff --git a/tulip/selectors.py b/tulip/selectors.py index 05434630..d51b976d 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -14,13 +14,6 @@ EVENT_READ = (1 << 0) # write event EVENT_WRITE = (1 << 1) -# connect event -EVENT_CONNECT = (1 << 2) - -# In most cases we treat EVENT_WRITE and EVENT_CONNECT as aliases for -# each other, and in fact we return both flags when a FD is found -# either writable or connectable. The distinction is necessary -# only for poll() on Windows. def _fileobj_to_fd(fileobj): @@ -84,20 +77,15 @@ def register(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of - EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data Returns: SelectorKey instance """ - if (not events) or (events & ~(EVENT_READ|EVENT_WRITE|EVENT_CONNECT)): + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): raise ValueError("Invalid events: {}".format(events)) - if events & (EVENT_WRITE|EVENT_CONNECT) == (EVENT_WRITE|EVENT_CONNECT): - raise ValueError("WRITE and CONNECT are mutually exclusive. " - "Invalid events: {}".format(events)) - if fileobj in self._fileobj_to_key: raise ValueError("{!r} is already registered".format(fileobj)) @@ -128,8 +116,7 @@ def modify(self, fileobj, events, data=None): Parameters: fileobj -- file object - events -- events to monitor (bitwise mask of - EVENT_READ|EVENT_WRITE|EVENT_CONNECT) + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data """ # TODO: Subclasses can probably optimize this even further. @@ -155,7 +142,7 @@ def select(self, timeout=None): Returns: list of (fileobj, events, attached data) for ready file objects - `events` is a bitwise mask of EVENT_READ|EVENT_WRITE|EVENT_CONNECT + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE """ raise NotImplementedError() @@ -223,7 +210,7 @@ def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: self._readers.add(key.fd) - if events & (EVENT_WRITE|EVENT_CONNECT): + if events & EVENT_WRITE: self._writers.add(key.fd) return key @@ -247,7 +234,7 @@ def select(self, timeout=None): if fd in r: events |= EVENT_READ if fd in w: - events |= EVENT_WRITE|EVENT_CONNECT + events |= EVENT_WRITE key = self._key_from_fd(fd) if key: @@ -280,7 +267,7 @@ def register(self, fileobj, events, data=None): poll_events = 0 if events & EVENT_READ: poll_events |= POLLIN - if events & (EVENT_WRITE|EVENT_CONNECT): + if events & EVENT_WRITE: poll_events |= POLLOUT self._poll.register(key.fd, poll_events) return key @@ -301,7 +288,7 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~POLLIN: - events |= EVENT_WRITE|EVENT_CONNECT + events |= EVENT_WRITE if event & ~POLLOUT: events |= EVENT_READ @@ -325,7 +312,7 @@ def register(self, fileobj, events, data=None): epoll_events = 0 if events & EVENT_READ: epoll_events |= EPOLLIN - if events & (EVENT_WRITE|EVENT_CONNECT): + if events & EVENT_WRITE: epoll_events |= EPOLLOUT self._epoll.register(key.fd, epoll_events) return key @@ -347,7 +334,7 @@ def select(self, timeout=None): for fd, event in fd_event_list: events = 0 if event & ~EPOLLIN: - events |= EVENT_WRITE|EVENT_CONNECT + events |= EVENT_WRITE if event & ~EPOLLOUT: events |= EVENT_READ @@ -376,7 +363,7 @@ def unregister(self, fileobj): if key.events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) - if key.events & (EVENT_WRITE|EVENT_CONNECT): + if key.events & EVENT_WRITE: kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key @@ -386,7 +373,7 @@ def register(self, fileobj, events, data=None): if events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) - if events & (EVENT_WRITE|EVENT_CONNECT): + if events & EVENT_WRITE: kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return key @@ -406,7 +393,7 @@ def select(self, timeout=None): if flag == KQ_FILTER_READ: events |= EVENT_READ if flag == KQ_FILTER_WRITE: - events |= EVENT_WRITE|EVENT_CONNECT + events |= EVENT_WRITE key = self._key_from_fd(fd) if key: From 321dd3bda443f4250d18c9c82f0ef22e62110574 Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Wed, 6 Feb 2013 01:37:28 +0100 Subject: [PATCH 0297/1502] Rmoved _when from Handler and added Timer class --- tulip/base_events.py | 14 +++--- tulip/events.py | 84 ++++++++++++++++++----------------- tulip/events_test.py | 94 ++++++++++++++++++++-------------------- tulip/selector_events.py | 4 +- tulip/unix_events.py | 2 +- 5 files changed, 103 insertions(+), 95 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index a9733eab..bce9b682 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -140,7 +140,7 @@ def call_later(self, delay, callback, *args): """ if delay <= 0: return self.call_soon(callback, *args) - handler = events.make_handler(time.monotonic() + delay, callback, args) + handler = events.Timer(time.monotonic() + delay, callback, args) heapq.heappush(self._scheduled, handler) return handler @@ -150,7 +150,7 @@ def wrapper(): callback(*args) # If this fails, the chain is broken. handler._when = time.monotonic() + interval heapq.heappush(self._scheduled, handler) - handler = events.make_handler(time.monotonic() + interval, wrapper, ()) + handler = events.Timer(time.monotonic() + interval, wrapper, ()) heapq.heappush(self._scheduled, handler) return handler @@ -164,7 +164,7 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - handler = events.make_handler(None, callback, args) + handler = events.make_handler(callback, args) self._ready.append(handler) return handler @@ -177,7 +177,7 @@ def call_soon_threadsafe(self, callback, *args): def run_in_executor(self, executor, callback, *args): if isinstance(callback, events.Handler): assert not args - assert callback.when is None + assert not isinstance(callback, events.Timer) if callback.cancelled: f = futures.Future() f.set_result(None) @@ -304,10 +304,10 @@ def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" if handler.cancelled: return - if handler.when is None: - self._ready.append(handler) - else: + if isinstance(handler, events.Timer): heapq.heappush(self._scheduled, handler) + else: + self._ready.append(handler) def wrap_future(self, future): """XXX""" diff --git a/tulip/events.py b/tulip/events.py index f2022fac..1a509378 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,7 @@ """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', - 'AbstractEventLoop', 'Handler', 'make_handler', + 'AbstractEventLoop', 'Timer', 'Handler', 'make_handler', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -17,24 +17,17 @@ class Handler: """Object returned by callback registration methods.""" - def __init__(self, when, callback, args): - self._when = when + def __init__(self, callback, args): self._callback = callback self._args = args self._cancelled = False def __repr__(self): - res = 'Handler({}, {}, {})'.format(self._when, - self._callback, - self._args) + res = 'Handler({}, {})'.format(self._callback, self._args) if self._cancelled: res += '' return res - @property - def when(self): - return self._when - @property def callback(self): return self._callback @@ -50,48 +43,61 @@ def cancelled(self): def cancel(self): self._cancelled = True - def __lt__(self, other): - if self._when is None: - return other._when is not None - elif other._when is None: - return False +def make_handler(callback, args): + if isinstance(callback, Handler): + assert not args + return callback + return Handler(callback, args) + + +class Timer(Handler): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): return self._when < other._when def __le__(self, other): - if self._when is None: + if self._when < other._when: return True - elif other._when is None: - return False - - return self._when <= other._when + return self.__eq__(other) def __gt__(self, other): - if self._when is None: - return False - elif other._when is None: - return True - return self._when > other._when def __ge__(self, other): - if self._when is None: - return other._when is None - elif other._when is None: + if self._when > other._when: return True - - return self._when >= other._when + return self.__eq__(other) def __eq__(self, other): - return self._when == other._when - - -def make_handler(when, callback, args): - if isinstance(callback, Handler): - assert not args - assert when is None - return callback - return Handler(when, callback, args) + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal class AbstractEventLoop: diff --git a/tulip/events_test.py b/tulip/events_test.py index ea22aeda..36fa0094 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -96,7 +96,7 @@ def test_call_soon_with_handler(self): results = [] def callback(): results.append('yeah') - handler = events.Handler(None, callback, ()) + handler = events.Handler(callback, ()) self.assertIs(self.event_loop.call_soon(handler), handler) self.event_loop.run() self.assertEqual(results, ['yeah']) @@ -122,7 +122,7 @@ def test_call_soon_threadsafe_with_handler(self): def callback(arg): results.append(arg) - handler = events.Handler(None, callback, ('hello',)) + handler = events.Handler(callback, ('hello',)) def run(): self.assertIs(self.event_loop.call_soon_threadsafe(handler),handler) @@ -159,7 +159,7 @@ def test_run_in_executor_with_handler(self): def run(arg): time.sleep(0.1) return arg - handler = events.Handler(None, run, ('yo',)) + handler = events.Handler(run, ('yo',)) f2 = self.event_loop.run_in_executor(None, handler) res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') @@ -202,7 +202,7 @@ def reader(): self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() - handler = events.Handler(None, reader, ()) + handler = events.Handler(reader, ()) self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) self.event_loop.call_later(0.05, w.send, b'abc') @@ -248,7 +248,7 @@ def remove_writer(): def test_writer_callback_with_handler(self): r, w = self.event_loop._socketpair() w.setblocking(False) - handler = events.Handler(None, w.send, (b'x'*(256*1024),)) + handler = events.Handler(w.send, (b'x'*(256*1024),)) self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) @@ -584,15 +584,14 @@ def callback(*args): return args args = () - h = events.Handler(None, callback, args) - self.assertIsNone(h.when) + h = events.Handler(callback, args) self.assertIs(h.callback, callback) self.assertIs(h.args, args) self.assertFalse(h.cancelled) r = repr(h) self.assertTrue(r.startswith( - 'Handler(None, ' + 'Handler(' '.callback')) self.assertTrue(r.endswith('())')) @@ -601,67 +600,70 @@ def callback(*args): r = repr(h) self.assertTrue(r.startswith( - 'Handler(None, ' + 'Handler(' '.callback')) self.assertTrue(r.endswith('())')) - def test_handler_comparison(self): + def test_make_handler(self): def callback(*args): return args + h1 = events.Handler(callback, ()) + h2 = events.make_handler(h1, ()) + self.assertIs(h1, h2) - h1 = events.Handler(None, callback, ()) - h2 = events.Handler(None, callback, ()) - self.assertTrue((h1 < h2) == False) - self.assertTrue((h1 <= h2) == True) - self.assertTrue((h1 > h2) == False) - self.assertTrue((h1 >= h2) == True) - self.assertTrue((h1 == h2) == True) + self.assertRaises(AssertionError, + events.make_handler, h1, (1,2,)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + args = () when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) - h1 = events.Handler(when, callback, ()) - h2 = events.Handler(None, callback, ()) - self.assertTrue((h1 < h2) == False) - self.assertTrue((h1 <= h2) == False) - self.assertTrue((h1 > h2) == True) - self.assertTrue((h1 >= h2) == True) - self.assertTrue((h1 == h2) == False) + r = repr(h) + self.assertTrue(r.endswith('())')) - self.assertTrue((h2 < h1) == True) - self.assertTrue((h2 <= h1) == True) - self.assertTrue((h2 > h1) == False) - self.assertTrue((h2 >= h1) == False) - self.assertTrue((h2 == h1) == False) + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() - h1 = events.Handler(when, callback, ()) - h2 = events.Handler(when, callback, ()) + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) self.assertTrue((h1 < h2) == False) self.assertTrue((h1 <= h2) == True) self.assertTrue((h1 > h2) == False) self.assertTrue((h1 >= h2) == True) self.assertTrue((h1 == h2) == True) - h1 = events.Handler(when, callback, ()) - h2 = events.Handler(when + 10.0, callback, ()) + h2.cancel() + self.assertTrue((h1 == h2) == False) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) self.assertTrue((h1 < h2) == True) self.assertTrue((h1 <= h2) == True) self.assertTrue((h1 > h2) == False) self.assertTrue((h1 >= h2) == False) self.assertTrue((h1 == h2) == False) - def test_make_handler(self): - def callback(*args): - return args - h1 = events.Handler(None, callback, ()) - h2 = events.make_handler(None, h1, ()) - self.assertIs(h1, h2) - - self.assertRaises(AssertionError, - events.make_handler, 10.0, h1, ()) - - self.assertRaises(AssertionError, - events.make_handler, None, h1, (1,2,)) - class PolicyTests(unittest.TestCase): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index f738daec..31c3d4b3 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -114,7 +114,7 @@ def _accept_connection(self, protocol_factory, sock): def add_reader(self, fd, callback, *args): """Add a reader callback. Return a Handler instance.""" - handler = events.make_handler(None, callback, args) + handler = events.make_handler(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) except KeyError: @@ -144,7 +144,7 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" - handler = events.make_handler(None, callback, args) + handler = events.make_handler(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) except KeyError: diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 7c4d3cf6..8a2026bb 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -44,7 +44,7 @@ def add_signal_handler(self, sig, callback, *args): signal.set_wakeup_fd(self._csock.fileno()) except ValueError as exc: raise RuntimeError(str(exc)) - handler = events.make_handler(None, callback, args) + handler = events.make_handler(callback, args) self._signal_handlers[sig] = handler try: signal.signal(sig, self._handle_signal) From 8416473adab26a4249e0fe9a80d32ed5ff081360 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 Feb 2013 17:05:27 -0800 Subject: [PATCH 0298/1502] Break some long lines. --- tulip/base_events.py | 3 ++- tulip/events_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index bce9b682..8787b5b3 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -251,7 +251,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = self.SslTransport(self, sock, protocol, sslcontext, waiter) + transport = self.SslTransport(self, sock, protocol, + sslcontext, waiter) else: transport = self.SocketTransport(self, sock, protocol, waiter) diff --git a/tulip/events_test.py b/tulip/events_test.py index 36fa0094..7d443065 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -544,9 +544,11 @@ def test_writer_callback_cancel(self): def test_writer_callback_with_handler(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_accept_connection_retry(self): - raise unittest.SkipTest("IocpEventLoop does not have _accept_connection()") + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") def test_accept_connection_exception(self): - raise unittest.SkipTest("IocpEventLoop does not have _accept_connection()") + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") else: from . import selectors From b368d1437f200ac637b181a362cf8917fa778da6 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 6 Feb 2013 16:03:27 +0000 Subject: [PATCH 0299/1502] Change in handling of ERROR_BROKEN_PIPE. --- overlapped.c | 16 +++++++++------- tulip/windows_events.py | 5 +++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/overlapped.c b/overlapped.c index 09ac4e68..eea56665 100644 --- a/overlapped.c +++ b/overlapped.c @@ -332,10 +332,10 @@ Overlapped_dealloc(OverlappedObject *self) case ERROR_OPERATION_ABORTED: break; default: - PyErr_SetString( + PyErr_Format( PyExc_RuntimeError, - "I/O operations still in flight while destroying " - "Overlapped object, the process may crash"); + "%R still has pending operation at " + "deallocation, the process may crash", self); PyErr_WriteUnraisable(NULL); } } @@ -483,10 +483,12 @@ Overlapped_ReadFile(OverlappedObject *self, PyObject *args) self->error = err = ret ? ERROR_SUCCESS : GetLastError(); switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; case ERROR_SUCCESS: case ERROR_MORE_DATA: case ERROR_IO_PENDING: - case ERROR_BROKEN_PIPE: Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; @@ -540,10 +542,12 @@ Overlapped_WSARecv(OverlappedObject *self, PyObject *args) self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; case ERROR_SUCCESS: case ERROR_MORE_DATA: case ERROR_IO_PENDING: - case ERROR_BROKEN_PIPE: Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; @@ -595,7 +599,6 @@ Overlapped_WriteFile(OverlappedObject *self, PyObject *args) self->error = err = ret ? ERROR_SUCCESS : GetLastError(); switch (err) { case ERROR_SUCCESS: - case ERROR_MORE_DATA: case ERROR_IO_PENDING: Py_RETURN_NONE; default: @@ -653,7 +656,6 @@ Overlapped_WSASend(OverlappedObject *self, PyObject *args) self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); switch (err) { case ERROR_SUCCESS: - case ERROR_MORE_DATA: case ERROR_IO_PENDING: Py_RETURN_NONE; default: diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 23aed886..32c962f7 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -55,7 +55,8 @@ def registered_count(self): def select(self, timeout=None): if not self._results: self._poll(timeout) - tmp, self._results = self._results, [] + tmp = self._results + self._results = [] return tmp def recv(self, conn, nbytes, flags=0): @@ -151,7 +152,7 @@ def close(self): pass while self._cache: - if not self._poll(1000): + if not self._poll(1): logging.debug('taking long time to close proactor') self._results = [] From 92aef55a9c15334c516f059f5715bfdc57597426 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 6 Feb 2013 16:03:29 +0000 Subject: [PATCH 0300/1502] Fix confusion between seconds and milliseconds. From 866d9f4f00fc6f44db3daa55d03a1ecb77b8b134 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 6 Feb 2013 18:27:37 +0000 Subject: [PATCH 0301/1502] Add --iocp argument to curl.py and crawl.py which forces use of IOCP. Also fix use of undeclared error constant in proactor.py, and map some Windows error numbers to correct subclasses of OSError. --- crawl.py | 11 ++++--- curl.py | 10 +++---- overlapped.c | 62 ++++++++++++++++++++++++++++------------ tulip/proactor_events.py | 5 ++-- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/crawl.py b/crawl.py index 8d404f55..3d7055f9 100755 --- a/crawl.py +++ b/crawl.py @@ -133,11 +133,10 @@ def main(): if __name__ == '__main__': - try: - from tulip import events, iocp_events - except ImportError: - pass - else: - el = iocp_events.IocpEventLoop() + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() events.set_event_loop(el) main() diff --git a/curl.py b/curl.py index af3725e8..0ba82404 100755 --- a/curl.py +++ b/curl.py @@ -28,11 +28,9 @@ def main(): if __name__ == '__main__': - try: - from tulip import events, iocp_events - except ImportError: - pass - else: - el = iocp_events.IocpEventLoop() + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() events.set_event_loop(el) main() diff --git a/overlapped.c b/overlapped.c index eea56665..1243cf46 100644 --- a/overlapped.c +++ b/overlapped.c @@ -34,6 +34,30 @@ enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT}; +/* + * Map Windows error codes to subclasses of OSError + */ + +static void * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + /* * Some functions should be loaded at runtime */ @@ -60,7 +84,7 @@ initialize_function_pointers(void) s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (s == INVALID_SOCKET) { - PyErr_SetFromWindowsErr(WSAGetLastError()); + SetFromWindowsErr(WSAGetLastError()); return -1; } @@ -69,7 +93,7 @@ initialize_function_pointers(void) !GET_WSA_POINTER(s, DisconnectEx)) { closesocket(s); - PyErr_SetFromWindowsErr(WSAGetLastError()); + SetFromWindowsErr(WSAGetLastError()); return -1; } @@ -110,7 +134,7 @@ overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) Py_END_ALLOW_THREADS if (ret == NULL) - return PyErr_SetFromWindowsErr(0); + return SetFromWindowsErr(0); return Py_BuildValue(F_HANDLE, ret); } @@ -144,7 +168,7 @@ overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) if (err == WAIT_TIMEOUT) Py_RETURN_NONE; else - return PyErr_SetFromWindowsErr(err); + return SetFromWindowsErr(err); } return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, err, NumberOfBytes, CompletionKey, Overlapped); @@ -175,7 +199,7 @@ overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) Py_END_ALLOW_THREADS if (!ret) - return PyErr_SetFromWindowsErr(0); + return SetFromWindowsErr(0); Py_RETURN_NONE; } @@ -220,7 +244,7 @@ overlapped_BindLocal(PyObject *self, PyObject *args) } if (!ret) - return PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + return SetFromWindowsErr(WSAGetLastError()); Py_RETURN_NONE; } @@ -243,7 +267,7 @@ overlapped_SetFileCompletionNotificationModes(PyObject *self, PyObject *args) return NULL; if (!SetFileCompletionNotificationModes(FileHandle, Flags)) - return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0); + return SetFromWindowsErr(0); Py_RETURN_NONE; } @@ -286,7 +310,7 @@ Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) if (event == INVALID_HANDLE_VALUE) { event = CreateEvent(NULL, TRUE, FALSE, NULL); if (event == NULL) - return PyErr_SetExcFromWindowsErr(PyExc_OSError, 0); + return SetFromWindowsErr(0); } self = PyObject_New(OverlappedObject, type); @@ -375,7 +399,7 @@ Overlapped_cancel(OverlappedObject *self) /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ if (!ret && GetLastError() != ERROR_NOT_FOUND) - return PyErr_SetExcFromWindowsErr(PyExc_IOError, 0); + return SetFromWindowsErr(0); Py_RETURN_NONE; } @@ -422,7 +446,7 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) break; /* fall through */ default: - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } switch (self->type) { @@ -492,7 +516,7 @@ Overlapped_ReadFile(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -551,7 +575,7 @@ Overlapped_WSARecv(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -603,7 +627,7 @@ Overlapped_WriteFile(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -660,7 +684,7 @@ Overlapped_WSASend(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -710,7 +734,7 @@ Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -729,7 +753,7 @@ parse_address(PyObject *obj, SOCKADDR *Address, int Length) { Address->sa_family = AF_INET; if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { - PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + SetFromWindowsErr(WSAGetLastError()); return -1; } ((SOCKADDR_IN*)Address)->sin_port = htons(Port); @@ -740,7 +764,7 @@ parse_address(PyObject *obj, SOCKADDR *Address, int Length) PyErr_Clear(); Address->sa_family = AF_INET6; if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { - PyErr_SetExcFromWindowsErr(PyExc_IOError, WSAGetLastError()); + SetFromWindowsErr(WSAGetLastError()); return -1; } ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); @@ -797,7 +821,7 @@ Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } @@ -836,7 +860,7 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) Py_RETURN_NONE; default: self->type = TYPE_NOT_STARTED; - return PyErr_SetExcFromWindowsErr(PyExc_IOError, err); + return SetFromWindowsErr(err); } } diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 16099ca3..8117e1cc 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -38,9 +38,10 @@ def _loop_reading(self, f=None): return self._event_loop.call_soon(self._protocol.data_received, data) self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) except OSError as exc: - if exc.winerror == ERROR_CONNECTION_ABORTED and self._closing: - return self._fatal_error(exc) else: self._read_fut.add_done_callback(self._loop_reading) From 6c706e40c701f71bbec6ecbdfe918c99ae61eca3 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 6 Feb 2013 19:04:15 +0000 Subject: [PATCH 0302/1502] No more need to explicitly convert OSError to ConnectionRefusedError. --- tulip/windows_events.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 32c962f7..6752f4ed 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -91,13 +91,7 @@ def connect(self, conn, address): ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) def finish_connect(): - try: - ov.getresult() - except OSError as e: - if e.winerror == ERROR_CONNECTION_REFUSED: - raise ConnectionRefusedError(errno.ECONNREFUSED, - 'connection refused') - raise + ov.getresult() conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) From 164ba52a497c9c24525090a1e15667daaa3ceeaf Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 6 Feb 2013 19:08:14 +0000 Subject: [PATCH 0303/1502] Fix return type of SetFromWindowsErr(). --- overlapped.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/overlapped.c b/overlapped.c index 1243cf46..eed0989b 100644 --- a/overlapped.c +++ b/overlapped.c @@ -38,7 +38,7 @@ enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, * Map Windows error codes to subclasses of OSError */ -static void * +static PyObject * SetFromWindowsErr(DWORD err) { PyObject *exception_type; From 7cd90b8ea3e868bd21ae038c2a90abeaca1fe55d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Feb 2013 16:12:29 -0800 Subject: [PATCH 0304/1502] Add missing super().tearDown() class to EventLoopTestsMixin. --- tulip/events_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 7d443065..704fd7e7 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -60,6 +60,7 @@ def setUp(self): def tearDown(self): self.event_loop.close() gc.collect() + super().tearDown() def test_run(self): self.event_loop.run() # Returns immediately. From 2801a0c58c769e31ec146bd7c2e0c1ea0c4a7e37 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Feb 2013 16:16:33 -0800 Subject: [PATCH 0305/1502] Suppress error/warning logs in some tests, and other cleanup. --- tulip/tasks_test.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 4895732d..fac97655 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -163,6 +163,7 @@ def test_wait_first_completed(self): self.assertEqual({a}, pending) def test_wait_really_done(self): + self.suppress_log_errors() # there is possibility that some tasks in the pending list # became done but their callbacks haven't all been called yet @@ -362,7 +363,9 @@ def notmuch(): self.assertEqual(1, m_logging.warn.call_args[0][1]) def test_step_result_future(self): - """Coroutine returns Future""" + # Coroutine returns Future + self.suppress_log_warnings() + class Fut(futures.Future): def __init__(self, *args): self.cb_added = False @@ -387,10 +390,11 @@ def notmuch(): self.event_loop.run() self.assertIs(res, task.result()) - def test_step_result_cuncurrent_future(self): - """Coroutine returns cuncurrent.future.Future""" - class Fut(concurrent.futures.Future): + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + class Fut(concurrent.futures.Future): def __init__(self): self.cb_added = False super().__init__() From 50e3c9fd859a60fd34dcee27f440599f17b7a3f1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Feb 2013 16:20:33 -0800 Subject: [PATCH 0306/1502] More suppress_log_errors() calls. --- tulip/events_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tulip/events_test.py b/tulip/events_test.py index 704fd7e7..cf868b42 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -408,15 +408,18 @@ def test_create_ssl_transport(self): self.assertTrue(pr.nbytes > 0) def test_create_transport_host_port_sock(self): + self.suppress_log_errors() fut = self.event_loop.create_connection( MyProto, 'xkcd.com', 80, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_create_transport_no_host_port_sock(self): + self.suppress_log_errors() fut = self.event_loop.create_connection(MyProto) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_create_transport_no_getaddrinfo(self): + self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] @@ -425,6 +428,7 @@ def test_create_transport_no_getaddrinfo(self): socket.error, self.event_loop.run_until_complete, fut) def test_create_transport_connect_err(self): + self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error @@ -469,14 +473,17 @@ def test_start_serving_sock(self): client.close() def test_start_serving_host_port_sock(self): + self.suppress_log_errors() fut = self.event_loop.start_serving(MyProto,'0.0.0.0',0,sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() fut = self.event_loop.start_serving(MyProto) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] @@ -486,6 +493,8 @@ def test_start_serving_no_getaddrinfo(self): @unittest.mock.patch('tulip.base_events.socket') def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + class Err(socket.error): pass From 9ddc0df517271849dfbdcc99d5787656bab9a58c Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 7 Feb 2013 00:29:19 +0000 Subject: [PATCH 0307/1502] Fixes for 64 bit Windows. --- overlapped.c | 2 +- tulip/windows_events.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/overlapped.c b/overlapped.c index eed0989b..943e788a 100644 --- a/overlapped.c +++ b/overlapped.c @@ -605,7 +605,7 @@ Overlapped_WriteFile(OverlappedObject *self, PyObject *args) return NULL; #if SIZEOF_SIZE_T > SIZEOF_LONG - if (self->write_buffer.len > (Py_ssize_t)PY_ULONG_MAX) { + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { PyBuffer_Release(&self->write_buffer); PyErr_SetString(PyExc_ValueError, "buffer to large"); return NULL; diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 6752f4ed..3a0b8675 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -5,6 +5,7 @@ import logging import socket import weakref +import struct import _winapi @@ -78,9 +79,10 @@ def accept(self, listener): ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(): addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_ACCEPT_CONTEXT, - listener.fileno()) + buf) conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() return self._register(ov, listener, finish_accept) From 878ba83a6a737a8face09126d4341dcd3275ea01 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 7 Feb 2013 01:09:56 +0000 Subject: [PATCH 0308/1502] Forgotten bits for 64 bit Windows. --- overlapped.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/overlapped.c b/overlapped.c index 943e788a..cccdedff 100644 --- a/overlapped.c +++ b/overlapped.c @@ -616,7 +616,8 @@ Overlapped_WriteFile(OverlappedObject *self, PyObject *args) self->handle = handle; Py_BEGIN_ALLOW_THREADS - ret = WriteFile(handle, self->write_buffer.buf, self->write_buffer.len, + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, &written, &self->overlapped); Py_END_ALLOW_THREADS @@ -660,7 +661,7 @@ Overlapped_WSASend(OverlappedObject *self, PyObject *args) return NULL; #if SIZEOF_SIZE_T > SIZEOF_LONG - if (self->write_buffer.len > (Py_ssize_t)PY_ULONG_MAX) { + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { PyBuffer_Release(&self->write_buffer); PyErr_SetString(PyExc_ValueError, "buffer to large"); return NULL; From d127511970fb73603491104b11c125ed6ec2e4e7 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 7 Feb 2013 16:00:14 +0000 Subject: [PATCH 0309/1502] Stop exposing SetFileCompletionNotificationModes() in _overlapped since it is not currently used by tulip, and is not WinXP compatible. --- overlapped.c | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/overlapped.c b/overlapped.c index cccdedff..c9f6ec9f 100644 --- a/overlapped.c +++ b/overlapped.c @@ -248,30 +248,6 @@ overlapped_BindLocal(PyObject *self, PyObject *args) Py_RETURN_NONE; } -/* - * Set notification mode for the handle - */ - -PyDoc_STRVAR( - SetFileCompletionNotificationModes_doc, - "SetFileCompletionNotificationModes(FileHandle, Flags) -> None\n\n" - "Set whether notification happens if operation succeeds without blocking"); - -static PyObject * -overlapped_SetFileCompletionNotificationModes(PyObject *self, PyObject *args) -{ - HANDLE FileHandle; - UCHAR Flags; - - if (!PyArg_ParseTuple(args, F_HANDLE F_BOOL, &FileHandle, &Flags)) - return NULL; - - if (!SetFileCompletionNotificationModes(FileHandle, Flags)) - return SetFromWindowsErr(0); - - Py_RETURN_NONE; -} - /* * A Python object wrapping an OVERLAPPED structure and other useful data * for overlapped I/O @@ -968,9 +944,6 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, PostQueuedCompletionStatus_doc}, {"BindLocal", overlapped_BindLocal, METH_VARARGS, BindLocal_doc}, - {"SetFileCompletionNotificationModes", - overlapped_SetFileCompletionNotificationModes, - METH_VARARGS, SetFileCompletionNotificationModes_doc}, {NULL} }; @@ -1013,7 +986,6 @@ PyInit__overlapped(void) d = PyModule_GetDict(m); WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); - WINAPI_CONSTANT(F_DWORD, FILE_SKIP_COMPLETION_PORT_ON_SUCCESS); WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); WINAPI_CONSTANT(F_HANDLE, NULL); From 04c70017985ebbae6638b1654b4a7a6e8b5e809c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 7 Feb 2013 15:45:55 -0800 Subject: [PATCH 0310/1502] Limit line length for security in StreamReader.readline --- tulip/http_client.py | 49 +++++++++++++++++++++------------- tulip/http_client_test.py | 56 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 19 deletions(-) diff --git a/tulip/http_client.py b/tulip/http_client.py index 0a03d81f..c5cddba0 100644 --- a/tulip/http_client.py +++ b/tulip/http_client.py @@ -65,28 +65,41 @@ def feed_data(self, data): @tasks.coroutine def readline(self): - # TODO: Limit line length for security. - while not self.line_count and not self.eof: - assert self.waiter is None - self.waiter = futures.Future() - yield from self.waiter parts = [] - while self.buffer: - data = self.buffer.popleft() - ichar = data.find(b'\n') - if ichar < 0: - parts.append(data) - else: - ichar += 1 - head, tail = data[:ichar], data[ichar:] - parts.append(head) - if tail: - self.buffer.appendleft(tail) - self.line_count -= 1 + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: break + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + line = b''.join(parts) - self.byte_count -= len(line) + self.byte_count -= parts_size return line diff --git a/tulip/http_client_test.py b/tulip/http_client_test.py index c598a339..3c3b1242 100644 --- a/tulip/http_client_test.py +++ b/tulip/http_client_test.py @@ -5,17 +5,20 @@ from . import events from . import http_client from . import tasks +from . import test_utils -class StreamReaderTests(unittest.TestCase): +class StreamReaderTests(test_utils.LogTrackingTestCase): DATA = b'line1\nline2\nline3\n' def setUp(self): + super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): + super().tearDown() self.event_loop.close() def test_feed_empty_data(self): @@ -118,6 +121,47 @@ def cb(): self.assertFalse(stream.line_count) self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = http_client.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = http_client.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = http_client.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + def test_readline_line_byte_count(self): stream = http_client.StreamReader() stream.feed_data(self.DATA[:6]) @@ -130,6 +174,16 @@ def test_readline_line_byte_count(self): self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + def test_readline_eof(self): + stream = http_client.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'some data', line) + def test_readline_empty_eof(self): stream = http_client.StreamReader() stream.feed_eof() From 3b22ab324b85f19b7110b6bac88511ee6c582db0 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 11 Feb 2013 10:12:02 -0800 Subject: [PATCH 0311/1502] Synchronization primitives --- tulip/__init__.py | 2 + tulip/locks.py | 460 +++++++++++++++++++++++++++ tulip/locks_test.py | 752 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1214 insertions(+) create mode 100644 tulip/locks.py create mode 100644 tulip/locks_test.py diff --git a/tulip/__init__.py b/tulip/__init__.py index 185fe3fe..f3b407ec 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -3,12 +3,14 @@ # This relies on each of the submodules having an __all__ variable. from .futures import * from .events import * +from .locks import * from .transports import * from .protocols import * from .tasks import * __all__ = (futures.__all__ + events.__all__ + + locks.__all__ + transports.__all__ + protocols.__all__ + tasks.__all__) diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..e55487a6 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,460 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%(res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + if handler is not None: + handler.cancel() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s'%self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/locks_test.py b/tulip/locks_test.py new file mode 100644 index 00000000..9444ceaa --- /dev/null +++ b/tulip/locks_test.py @@ -0,0 +1,752 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from . import events +from . import futures +from . import locks +from . import tasks +from . import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete( + tasks.Task(lock.acquire()) + )) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(timeout=0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(tasks.Task(ev.wait())) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(ev.wait(0.1))) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(tasks.Task(ev.wait(10.1))) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(cond.acquire()))) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(tasks.Task(cond.wait(0.1))) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, tasks.Task(cond.wait())) + + def test_wait_for(self): + cond = locks.Condition() + + presult = False + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete(tasks.Task( + cond.wait_for(lambda: [1,2,3]))) + self.assertEqual([1,2,3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + tasks.Task(cond.wait_for(lambda: False))) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1,2,3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1,2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + res = yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + res = yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() From 155a3ec1746abe7671bad5de0689a7518383db0a Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Mon, 11 Feb 2013 22:33:30 +0100 Subject: [PATCH 0312/1502] Added timeout tests for run_until_complete --- tulip/tasks_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index fac97655..10d99893 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -131,6 +131,29 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + self.assertFalse(t.done()) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + r = self.event_loop.run_until_complete(t, 10.0) + self.assertTrue(t.done()) + self.assertEqual(r, 42) + def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) From b0792989af3a4df3b6658d0fca87e582c933a471 Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Tue, 12 Feb 2013 00:01:48 +0100 Subject: [PATCH 0313/1502] Fixed run_until_complete It will internally call run_forever in case the specified timout is None and it will always call result() on the future in case the timeout wasn't hit. This means that if the loop was stopped prematurely, futures.InvalidStateError will be raised. --- tulip/base_events.py | 21 +++++++++++++-------- tulip/tasks_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 8787b5b3..ddf0122b 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -37,7 +37,7 @@ class _StopError(BaseException): """Raised to stop the event loop.""" -def _raise_stop_error(): +def _raise_stop_error(*args): raise _StopError @@ -94,16 +94,21 @@ def run_until_complete(self, future, timeout=None): Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ + handler_called = False + def stop_loop(): + nonlocal handler_called + handler_called = True + raise _StopError + future.add_done_callback(_raise_stop_error) if timeout is None: - timeout = 0x7fffffff/1000.0 # 24 days - future.add_done_callback(lambda _: self.stop()) - handler = self.call_later(timeout, _raise_stop_error) - self.run() - handler.cancel() - if future.done(): - return future.result() # May raise future.exception(). + self.run_forever() else: + handler = self.call_later(timeout, stop_loop) + self.run() + handler.cancel() + if handler_called: raise futures.TimeoutError + return future.result() def stop(self): """Stop running the event loop. diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index 10d99893..f122d279 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -131,6 +131,27 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) + def test_stop_while_run_in_complete(self): + x = 0 + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + def test_timeout(self): @tasks.task def task(): @@ -138,10 +159,13 @@ def task(): return 42 t = task() + t0 = time.monotonic() self.assertRaises( futures.TimeoutError, self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) def test_timeout_not(self): @tasks.task @@ -150,9 +174,12 @@ def task(): return 42 t = task() + t0 = time.monotonic() r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() self.assertTrue(t.done()) self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) def test_wait(self): a = tasks.sleep(0.1) From 5fd30adfa7e8750f84eeb04e6dfe6761f9feb087 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 12 Feb 2013 11:50:22 -0800 Subject: [PATCH 0314/1502] _SelectorSocketTransport and _SelectorSslTransport tests --- tulip/selector_events.py | 20 +- tulip/selector_events_test.py | 567 ++++++++++++++++++++++++++++++++++ 2 files changed, 580 insertions(+), 7 deletions(-) create mode 100644 tulip/selector_events_test.py diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 31c3d4b3..850e2a47 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -9,7 +9,7 @@ import socket try: import ssl -except ImportError: +except ImportError: # pragma: no cover ssl = None import sys @@ -31,7 +31,7 @@ # Errno values indicating the socket isn't ready for I/O just yet. _TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) -if sys.platform == 'win32': +if sys.platform == 'win32': # pragma: no cover _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) @@ -473,16 +473,22 @@ def _on_ready(self): try: n = self._sslsock.send(data) except ssl.SSLWantReadError: - pass + n = 0 except ssl.SSLWantWriteError: - pass + n = 0 except socket.error as exc: if exc.errno not in _TRYAGAIN: self._fatal_error(exc) return - else: - if n < len(data): - self._buffer.append(data[n:]) + else: + n = 0 + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) def write(self, data): assert isinstance(data, bytes) diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py new file mode 100644 index 00000000..bbef84e5 --- /dev/null +++ b/tulip/selector_events_test.py @@ -0,0 +1,567 @@ +"""Tests for selector_events.py""" + +import errno +import socket +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from . import futures +from .selector_events import _SelectorSslTransport +from .selector_events import _SelectorSocketTransport + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_ctor(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertTrue(self.event_loop.add_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + + def test_ctor_with_waiter(self): + fut = futures.Future() + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.assertEqual(2, self.event_loop.call_soon.call_count) + self.assertEqual(fut.set_result, + self.event_loop.call_soon.call_args[0][0]) + + def test_read_ready(self): + data_received = unittest.mock.Mock() + self.protocol.data_received = data_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (data_received, b'data'), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_eof(self): + eof_received = unittest.mock.Mock() + self.protocol.eof_received = eof_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual( + (eof_received,), self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data1', b'data2'], transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'ta'], transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.write, b'data') + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._write_ready() + self.assertFalse(self.sock.send.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'ta'], transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer = [b'data1', b'data2'] + transport._write_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data1data2'], transport._buffer) + + def test_write_ready_exception(self): + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_close(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport._fatal_error(exc) + + self.assertEqual([], transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_connection_lost(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + self.waiter = futures.Future() + + self.transport = _SelectorSslTransport( + self.event_loop, self.sock, + self.protocol, self.sslcontext, self.waiter) + self.event_loop.reset_mock() + self.sock.reset_mock() + self.protocol.reset_mock() + self.sslcontext.reset_mock() + + def test_on_handshake(self): + self.transport._on_handshake() + self.assertTrue(self.sslsock.do_handshake.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_reader.call_args[0]) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_reader.call_args[0]) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_exc(self): + self.sslsock.do_handshake.side_effect = ValueError + self.transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + + def test_on_handshake_base_exc(self): + self.sslsock.do_handshake.side_effect = BaseException + self.assertRaises(BaseException, self.transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + + def test_write_no_data(self): + self.transport._buffer.append(b'data') + self.transport.write(b'') + self.assertEqual([b'data'], self.transport._buffer) + + def test_write_str(self): + self.assertRaises(AssertionError, self.transport.write, 'str') + + def test_write_closing(self): + self.transport.close() + self.assertRaises(AssertionError, self.transport.write, b'data') + + def test_abort(self): + self.transport._fatal_error = unittest.mock.Mock() + + self.transport.abort() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual((None,), self.transport._fatal_error.call_args[0]) + + def test_fatal_error(self): + exc = object() + self.transport._buffer.append(b'data') + self.transport._fatal_error(exc) + + self.assertEqual([], self.transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_close(self): + self.transport.close() + + self.assertTrue(self.transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + self.transport._buffer.append(b'data') + self.transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_on_ready_closed(self): + self.sslsock.fileno.return_value = -1 + self.transport._on_ready() + self.assertFalse(self.sslsock.recv.called) + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + self.transport._on_ready() + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.recv.side_effect = Err + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + class Err(socket.error): + pass + + self.sslsock.recv.side_effect = Err + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([], self.transport._buffer) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data1data2'], self.transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'ta1data2'], self.transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport.close() + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + self.transport._buffer = [b'data'] + + self.sslsock.send.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data'], self.transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.send.side_effect = Err + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + class Err(socket.error): + pass + + self.sslsock.send.side_effect = Err + self.transport._buffer = [b'data'] + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual([], self.transport._buffer) \ No newline at end of file From f06a34b59e4b669b36db02dd0db6f6ad2fda31fd Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 12 Feb 2013 12:46:07 -0800 Subject: [PATCH 0315/1502] more tests for events.Timer --- tulip/events.py | 2 +- tulip/events_test.py | 94 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 84 insertions(+), 12 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index 1a509378..86f8c508 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -267,7 +267,7 @@ def new_event_loop(self): loop. """ # TODO: Do something else for Windows. - if sys.platform == 'win32': + if sys.platform == 'win32': # pragma: no cover from . import windows_events return windows_events.SelectorEventLoop() else: diff --git a/tulip/events_test.py b/tulip/events_test.py index cf868b42..2ad7b00b 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -18,6 +18,7 @@ import unittest.mock from . import events +from . import futures from . import transports from . import protocols from . import selector_events @@ -659,22 +660,93 @@ def callback(*args): h1 = events.Timer(when, callback, ()) h2 = events.Timer(when, callback, ()) - self.assertTrue((h1 < h2) == False) - self.assertTrue((h1 <= h2) == True) - self.assertTrue((h1 > h2) == False) - self.assertTrue((h1 >= h2) == True) - self.assertTrue((h1 == h2) == True) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) h2.cancel() - self.assertTrue((h1 == h2) == False) + self.assertFalse(h1 == h2) h1 = events.Timer(when, callback, ()) h2 = events.Timer(when + 10.0, callback, ()) - self.assertTrue((h1 < h2) == True) - self.assertTrue((h1 <= h2) == True) - self.assertTrue((h1 > h2) == False) - self.assertTrue((h1 >= h2) == False) - self.assertTrue((h1 == h2) == False) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handler(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) class PolicyTests(unittest.TestCase): From fee2dd9ade38bab46ebfbe581acb25751fd09fe6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 12 Feb 2013 15:33:45 -0800 Subject: [PATCH 0316/1502] Rename SocketTransport and SslTransport to _make_socket_transport and _make_ssl_transport --- tulip/base_events.py | 16 +++++++++++++--- tulip/base_events_test.py | 20 ++++++++++++++++++++ tulip/selector_events.py | 21 +++++++++++---------- tulip/selector_events_test.py | 25 +++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 tulip/base_events_test.py diff --git a/tulip/base_events.py b/tulip/base_events.py index ddf0122b..de9396a4 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -49,6 +49,15 @@ def __init__(self): self._default_executor = None self._signal_handlers = {} + def _make_socket_transport(self, event_loop, sock, protocol, waiter=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, event_loop, rawsock, + protocol, sslcontext, waiter): + """Create SSL transport.""" + raise NotImplementedError + def run(self): """Run the event loop until nothing left to do or stop() called. @@ -256,10 +265,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl - transport = self.SslTransport(self, sock, protocol, - sslcontext, waiter) + transport = self._make_ssl_transport( + self, sock, protocol, sslcontext, waiter) else: - transport = self.SocketTransport(self, sock, protocol, waiter) + transport = self._make_socket_transport( + self, sock, protocol, waiter) yield from waiter return transport, protocol diff --git a/tulip/base_events_test.py b/tulip/base_events_test.py new file mode 100644 index 00000000..03f5a736 --- /dev/null +++ b/tulip/base_events_test.py @@ -0,0 +1,20 @@ +"""Tests for base_events.py""" + +import unittest +import unittest.mock + +from . import base_events + + +class BaseEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + base_event_loop = base_events.BaseEventLoop() + + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + base_event_loop._make_socket_transport, m, m, m) + self.assertRaises( + NotImplementedError, + base_event_loop._make_ssl_transport, m, m, m, m, m) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 850e2a47..8064d2ab 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -41,23 +41,24 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): See events.EventLoop for API specification. """ - @staticmethod - def SocketTransport(event_loop, sock, protocol, waiter=None): - return _SelectorSocketTransport(event_loop, sock, protocol, waiter) - - @staticmethod - def SslTransport(event_loop, rawsock, protocol, sslcontext, waiter): - return _SelectorSslTransport(event_loop, rawsock, protocol, - sslcontext, waiter) - def __init__(self, selector=None): super().__init__() + if selector is None: selector = selectors.Selector() logging.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._make_self_pipe() + def _make_socket_transport(self, event_loop, sock, protocol, waiter=None): + return _SelectorSocketTransport( + event_loop, sock, protocol, waiter) + + def _make_ssl_transport(self, event_loop, rawsock, protocol, + sslcontext, waiter): + return _SelectorSslTransport( + event_loop, rawsock, protocol, sslcontext, waiter) + def close(self): if self._selector is not None: self._selector.close() @@ -109,7 +110,7 @@ def _accept_connection(self, protocol_factory, sock): logging.exception('Accept failed') return protocol = protocol_factory() - transport = self.SocketTransport(self, conn, protocol) + transport = self._make_socket_transport(self, conn, protocol) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py index bbef84e5..aca8b04d 100644 --- a/tulip/selector_events_test.py +++ b/tulip/selector_events_test.py @@ -10,10 +10,35 @@ ssl = None from . import futures +from .selector_events import BaseSelectorEventLoop from .selector_events import _SelectorSslTransport from .selector_events import _SelectorSocketTransport +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + pass + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = TestBaseSelectorEventLoop() + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_socket_transport(m, m, m), + _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_ssl_transport(m, m, m, m, m), + _SelectorSslTransport) + + class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): From a7ae8d7fffbac76c147d2a04bb1f0ceec2e0c12d Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 12 Feb 2013 20:08:22 -0800 Subject: [PATCH 0317/1502] BaseSelectorEventLoop tests --- tulip/selector_events.py | 13 ++- tulip/selector_events_test.py | 211 +++++++++++++++++++++++++++++++++- 2 files changed, 219 insertions(+), 5 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 8064d2ab..d23652fc 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -201,23 +201,28 @@ def _sock_recv(self, fut, registered, sock, n): def sock_sendall(self, sock, data): """XXX""" fut = futures.Future() - self._sock_sendall(fut, False, sock, data) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) return fut def _sock_sendall(self, fut, registered, sock, data): fd = sock.fileno() + if registered: self.remove_writer(fd) if fut.cancelled(): return - n = 0 + try: - if data: - n = sock.send(data) + n = sock.send(data) except socket.error as exc: if exc.errno not in _TRYAGAIN: fut.set_exception(exc) return + n = 0 + if n == len(data): fut.set_result(None) else: diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py index aca8b04d..da9fc0ff 100644 --- a/tulip/selector_events_test.py +++ b/tulip/selector_events_test.py @@ -18,7 +18,8 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop): def _make_self_pipe(self): - pass + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() class BaseSelectorEventLoopTests(unittest.TestCase): @@ -38,6 +39,214 @@ def test_make_ssl_transport(self): self.event_loop._make_ssl_transport(m, m, m, m, m), _SelectorSslTransport) + def test_close(self): + self.event_loop._selector.close() + self.event_loop._selector = selector = unittest.mock.Mock() + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertTrue(selector.close.called) + self.assertTrue(self.event_loop._ssock.close.called) + self.assertTrue(self.event_loop._csock.close.called) + + def test_close_no_selector(self): + self.event_loop._selector.close() + self.event_loop._selector = None + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertTrue(self.event_loop._ssock.close.called) + self.assertTrue(self.event_loop._csock.close.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.event_loop._socketpair) + + def test_read_from_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._ssock.recv.side_effect = Err + self.assertIsNone(self.event_loop._read_from_self()) + + def test_read_from_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._ssock.recv.side_effect = Err + self.assertRaises(Err, self.event_loop._read_from_self) + + def test_write_to_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._csock.send.side_effect = Err + self.assertIsNone(self.event_loop._write_to_self()) + + def test_write_to_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._csock.send.side_effect = Err + self.assertRaises(Err, self.event_loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.event_loop._sock_recv = unittest.mock.Mock() + + f = self.event_loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, 1024), + self.event_loop._sock_recv.call_args[0]) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024), + self.event_loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertIsInstance(f.exception(), Err) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.event_loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertFalse(self.event_loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertIsInstance(f.exception(), Err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'ta'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + class SelectorSocketTransportTests(unittest.TestCase): From 04e8437f088516bf715f7a9846aa45ac6e9bc990 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Feb 2013 12:18:38 -0800 Subject: [PATCH 0318/1502] check if key data is different _BaseSelector.modify --- tulip/selectors.py | 6 +- tulip/selectors_test.py | 141 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 tulip/selectors_test.py diff --git a/tulip/selectors.py b/tulip/selectors.py index d51b976d..b8b830eb 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -124,9 +124,11 @@ def modify(self, fileobj, events, data=None): key = self._fileobj_to_key[fileobj] except KeyError: raise ValueError("{!r} is not registered".format(fileobj)) - if events != key.events: + if events != key.events or data != key.data: self.unregister(fileobj) - self.register(fileobj, events, data) + return self.register(fileobj, events, data) + else: + return key def select(self, timeout=None): """Perform the actual selection, until some monitored file objects are diff --git a/tulip/selectors_test.py b/tulip/selectors_test.py new file mode 100644 index 00000000..7970a681 --- /dev/null +++ b/tulip/selectors_test.py @@ -0,0 +1,141 @@ +"""Tests for selectors.py.""" + +import sys +import unittest +import unittest.mock + +from . import events +from . import selectors + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = TypeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(sys.stdin, selectors.EVENT_READ) + self.assertEqual( + "SelectorKey, fd=0, events=0x1, data=None>", + repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) From b854d89667bec613bd059f36fcecbdfe2753df5e Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Feb 2013 12:29:37 -0800 Subject: [PATCH 0319/1502] remove event_loop parameter for _make_socket_transport and _make_ssl_transport methods; more BaseSelectorEventLoopTests --- tulip/base_events.py | 14 +- tulip/base_events_test.py | 7 +- tulip/proactor_events.py | 9 +- tulip/selector_events.py | 27 ++-- tulip/selector_events_test.py | 295 +++++++++++++++++++++++++++++++++- 5 files changed, 326 insertions(+), 26 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index de9396a4..a00ea8c5 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -49,15 +49,18 @@ def __init__(self): self._default_executor = None self._signal_handlers = {} - def _make_socket_transport(self, event_loop, sock, protocol, waiter=None): + def _make_socket_transport(self, sock, protocol, waiter=None): """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, event_loop, rawsock, - protocol, sslcontext, waiter): + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter): """Create SSL transport.""" raise NotImplementedError + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + def run(self): """Run the event loop until nothing left to do or stop() called. @@ -266,10 +269,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( - self, sock, protocol, sslcontext, waiter) + sock, protocol, sslcontext, waiter) else: - transport = self._make_socket_transport( - self, sock, protocol, waiter) + transport = self._make_socket_transport(sock, protocol, waiter) yield from waiter return transport, protocol diff --git a/tulip/base_events_test.py b/tulip/base_events_test.py index 03f5a736..c5196e07 100644 --- a/tulip/base_events_test.py +++ b/tulip/base_events_test.py @@ -14,7 +14,10 @@ def test_not_implemented(self): m = unittest.mock.Mock() self.assertRaises( NotImplementedError, - base_event_loop._make_socket_transport, m, m, m) + base_event_loop._make_socket_transport, m, m) self.assertRaises( NotImplementedError, - base_event_loop._make_ssl_transport, m, m, m, m, m) + base_event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + base_event_loop._process_events, []) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 8117e1cc..2322f0cd 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -102,12 +102,6 @@ def _call_connection_lost(self, exc): class BaseProactorEventLoop(base_events.BaseEventLoop): - SocketTransport = _ProactorSocketTransport - - @staticmethod - def SslTransport(*args, **kwds): - raise NotImplementedError - def __init__(self, proactor): super().__init__() logging.debug('Using proactor: %s', proactor.__class__.__name__) @@ -115,6 +109,9 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._make_self_pipe() + def _make_socket_transport(self, sock, protocol, waiter=None): + return _ProactorSocketTransport(self, sock, protocol, waiter) + def close(self): if self._proactor is not None: self._proactor.close() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index d23652fc..ff301df8 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -50,14 +50,12 @@ def __init__(self, selector=None): self._selector = selector self._make_self_pipe() - def _make_socket_transport(self, event_loop, sock, protocol, waiter=None): - return _SelectorSocketTransport( - event_loop, sock, protocol, waiter) + def _make_socket_transport(self, sock, protocol, waiter=None): + return _SelectorSocketTransport(self, sock, protocol, waiter) - def _make_ssl_transport(self, event_loop, rawsock, protocol, - sslcontext, waiter): + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter): return _SelectorSslTransport( - event_loop, rawsock, protocol, sslcontext, waiter) + self, rawsock, protocol, sslcontext, waiter) def close(self): if self._selector is not None: @@ -110,7 +108,7 @@ def _accept_connection(self, protocol_factory, sock): logging.exception('Accept failed') return protocol = protocol_factory() - transport = self._make_socket_transport(self, conn, protocol) + transport = self._make_socket_transport(conn, protocol) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -124,6 +122,8 @@ def add_reader(self, fd, callback, *args): else: self._selector.modify(fd, mask | selectors.EVENT_READ, (handler, writer)) + if reader is not None: + reader.cancel() return handler @@ -139,9 +139,12 @@ def remove_reader(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (None, writer)) + if reader is not None: reader.cancel() - return True + return True + else: + return False def add_writer(self, fd, callback, *args): """Add a writer callback. Return a Handler instance.""" @@ -154,6 +157,9 @@ def add_writer(self, fd, callback, *args): else: self._selector.modify(fd, mask | selectors.EVENT_WRITE, (reader, handler)) + if writer is not None: + writer.cancel() + return handler def remove_writer(self, fd): @@ -169,9 +175,12 @@ def remove_writer(self, fd): self._selector.unregister(fd) else: self._selector.modify(fd, mask, (reader, None)) + if writer is not None: writer.cancel() - return True + return True + else: + return False def sock_recv(self, sock, n): """XXX""" diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py index da9fc0ff..b68fdbfb 100644 --- a/tulip/selector_events_test.py +++ b/tulip/selector_events_test.py @@ -10,6 +10,7 @@ ssl = None from . import futures +from . import selectors from .selector_events import BaseSelectorEventLoop from .selector_events import _SelectorSslTransport from .selector_events import _SelectorSocketTransport @@ -25,18 +26,23 @@ def _make_self_pipe(self): class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): - self.event_loop = TestBaseSelectorEventLoop() + self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) def test_make_socket_transport(self): m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() self.assertIsInstance( - self.event_loop._make_socket_transport(m, m, m), + self.event_loop._make_socket_transport(m, m), _SelectorSocketTransport) def test_make_ssl_transport(self): m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() self.assertIsInstance( - self.event_loop._make_ssl_transport(m, m, m, m, m), + self.event_loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) def test_close(self): @@ -247,6 +253,289 @@ def test__sock_sendall_none(self): (10, self.event_loop._sock_sendall, f, True, sock, b'data'), self.event_loop.add_writer.call_args[0]) + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.event_loop._sock_connect = unittest.mock.Mock() + + f = self.event_loop.sock_connect(sock, ('127.0.0.1',8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1',8080)), + self.event_loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1',8080)) + self.assertTrue(f.done()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1',8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.assertEqual( + (10, self.event_loop._sock_connect, f, + True, sock, ('127.0.0.1',8080)), + self.event_loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.assertIsInstance(f.exception(), socket.error) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.event_loop._sock_accept = unittest.mock.Mock() + + f = self.event_loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.event_loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future() + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.event_loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.event_loop._sock_accept, f, True, sock), + self.event_loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop._sock_accept(f, False, sock) + self.assertIsInstance(f.exception(), Err) + + def test_add_reader(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_READ, (h, None)), + self.event_loop._selector.register.call_args[0]) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (reader, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (None, None)) + self.assertFalse(self.event_loop.remove_reader(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_reader(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_reader(1)) + + def test_add_writer(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, h)), + self.event_loop._selector.register.call_args[0]) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (reader, writer)) + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.event_loop.remove_writer(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_writer(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((reader,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual((1,), self.event_loop.remove_reader.call_args[0]) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((writer,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual((1,), self.event_loop.remove_writer.call_args[0]) + class SelectorSocketTransportTests(unittest.TestCase): From b31a2f7a4a3daf00c60fa1727a670757ee98bdc8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 Feb 2013 17:07:34 -0800 Subject: [PATCH 0320/1502] Fix call to create socket transport. --- tulip/proactor_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 2322f0cd..684352b2 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -161,7 +161,7 @@ def loop(f=None): if f: conn, addr = f.result() protocol = protocol_factory() - transport = self.SocketTransport(self, conn, protocol) + transport = self._make_socket_transport(conn, protocol) f = self._proactor.accept(sock) except OSError as exc: sock.close() From 3fbab5c745f0d2a2e6658698191df5c410d04f37 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Feb 2013 20:49:20 -0800 Subject: [PATCH 0321/1502] Fix SelectorKey tests --- tulip/selectors_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tulip/selectors_test.py b/tulip/selectors_test.py index 7970a681..1cdd0080 100644 --- a/tulip/selectors_test.py +++ b/tulip/selectors_test.py @@ -1,6 +1,5 @@ """Tests for selectors.py.""" -import sys import unittest import unittest.mock @@ -21,11 +20,9 @@ def test_fileobj_to_fd(self): self.assertRaises(ValueError, selectors._fileobj_to_fd, f) def test_selector_key_repr(self): - key = selectors.SelectorKey(sys.stdin, selectors.EVENT_READ) + key = selectors.SelectorKey(10, selectors.EVENT_READ) self.assertEqual( - "SelectorKey, fd=0, events=0x1, data=None>", - repr(key)) + "SelectorKey", repr(key)) def test_register(self): fobj = unittest.mock.Mock() From 12ed221aa5e934ed0c6b7f7b2ee76b16acafb644 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 15 Feb 2013 09:29:12 -0800 Subject: [PATCH 0322/1502] new Transport.get_extra_info() api --- tulip/events_test.py | 31 +++++++++++++++++++++++++--- tulip/selector_events.py | 23 ++++++++++++++------- tulip/selector_events_test.py | 2 +- tulip/transports.py | 9 ++++++++ tulip/transports_test.py | 39 +++++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 11 deletions(-) create mode 100644 tulip/transports_test.py diff --git a/tulip/events_test.py b/tulip/events_test.py index 2ad7b00b..b65bf9ca 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -405,6 +405,7 @@ def test_create_ssl_transport(self): self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) self.event_loop.run() self.assertTrue(pr.nbytes > 0) @@ -438,17 +439,41 @@ def test_create_transport_connect_err(self): socket.error, self.event_loop.run_until_complete, fut) def test_start_serving(self): - f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + proto = None + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.event_loop.run_once() # This is quite mysterious, but necessary. self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) self.event_loop.run_once() - sock.close() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + # the client socket must be closed after to avoid ECONNRESET upon # recv()/send() on the serving socket client.close() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index ff301df8..6954f701 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -50,12 +50,13 @@ def __init__(self, selector=None): self._selector = selector self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None): - return _SelectorSocketTransport(self, sock, protocol, waiter) + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter): + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter) + self, rawsock, protocol, sslcontext, waiter, extra) def close(self): if self._selector is not None: @@ -108,7 +109,8 @@ def _accept_connection(self, protocol_factory, sock): logging.exception('Accept failed') return protocol = protocol_factory() - transport = self._make_socket_transport(conn, protocol) + transport = self._make_socket_transport( + conn, protocol, extra={'addr': addr}) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -310,7 +312,9 @@ def _process_events(self, event_list): class _SelectorSocketTransport(transports.Transport): - def __init__(self, event_loop, sock, protocol, waiter=None): + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock self._event_loop = event_loop self._sock = sock self._protocol = protocol @@ -406,7 +410,10 @@ def _call_connection_lost(self, exc): class _SelectorSslTransport(transports.Transport): - def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + self._event_loop = event_loop self._rawsock = rawsock self._protocol = protocol @@ -418,6 +425,8 @@ def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter): self._sslsock = sslsock self._buffer = [] self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + self._on_handshake() def _on_handshake(self): diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py index b68fdbfb..542662de 100644 --- a/tulip/selector_events_test.py +++ b/tulip/selector_events_test.py @@ -1087,4 +1087,4 @@ class Err(socket.error): self.transport._fatal_error = unittest.mock.Mock() self.transport._on_ready() self.assertTrue(self.transport._fatal_error.called) - self.assertEqual([], self.transport._buffer) \ No newline at end of file + self.assertEqual([], self.transport._buffer) diff --git a/tulip/transports.py b/tulip/transports.py index 4aaae3c7..6eb1c554 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -24,6 +24,15 @@ class Transport: except writelines(), which calls write() in a loop. """ + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + def write(self, data): """Write some data bytes to the transport. diff --git a/tulip/transports_test.py b/tulip/transports_test.py new file mode 100644 index 00000000..59b10fd7 --- /dev/null +++ b/tulip/transports_test.py @@ -0,0 +1,39 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from . import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) From b7ebf542e0eb6bab19ee783f04dcf0a490f02ffe Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 15 Feb 2013 14:22:05 -0800 Subject: [PATCH 0323/1502] more BaseEventLoop tests --- TODO | 2 - tulip/TODO | 2 + tulip/base_events.py | 13 +- tulip/base_events_test.py | 248 ++++++++++++++++++++++++++++++++++++-- tulip/events_test.py | 50 +++++++- 5 files changed, 298 insertions(+), 17 deletions(-) diff --git a/TODO b/TODO index b9559ef0..c6d4eead 100644 --- a/TODO +++ b/TODO @@ -66,8 +66,6 @@ FROM OLDER LIST - Add explicit wait queue to wait for Task's completion, instead of callbacks? -- Implement various lock styles a la threading.py. - - Look at pyfdpdlib's ioloop.py: http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py diff --git a/tulip/TODO b/tulip/TODO index b3a9302e..acec5c24 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -24,3 +24,5 @@ TODO in tulip v2 (tulip/ package directory) - buffered stream implementation - Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/base_events.py b/tulip/base_events.py index a00ea8c5..2e323efc 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -57,6 +57,14 @@ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter): """Create SSL transport.""" raise NotImplementedError + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + def _process_events(self, event_list): """Process selector events.""" raise NotImplementedError @@ -236,11 +244,10 @@ def create_connection(self, protocol_factory, host=None, port=None, *, exceptions = [] for family, type, proto, cname, address in infos: - sock = None try: sock = socket.socket(family=family, type=type, proto=proto) sock.setblocking(False) - yield self.sock_connect(sock, address) + yield from self.sock_connect(sock, address) except socket.error as exc: if sock is not None: sock.close() @@ -264,6 +271,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise ValueError( "host and port was not specified and no sock specified") + sock.setblocking(False) + protocol = protocol_factory() waiter = futures.Future() if ssl: diff --git a/tulip/base_events_test.py b/tulip/base_events_test.py index c5196e07..5dd0263c 100644 --- a/tulip/base_events_test.py +++ b/tulip/base_events_test.py @@ -1,23 +1,257 @@ """Tests for base_events.py""" +import concurrent.futures +import logging +import socket +import time import unittest import unittest.mock from . import base_events +from . import events +from . import futures +from . import protocols +from . import test_utils -class BaseEventLoopTests(unittest.TestCase): +class BaseEventLoopTests(test_utils.LogTrackingTestCase): - def test_not_implemented(self): - base_event_loop = base_events.BaseEventLoop() + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + def test_not_implemented(self): m = unittest.mock.Mock() self.assertRaises( NotImplementedError, - base_event_loop._make_socket_transport, m, m) + self.event_loop._make_socket_transport, m, m) self.assertRaises( NotImplementedError, - base_event_loop._make_ssl_transport, m, m, m, m) + self.event_loop._make_ssl_transport, m, m, m, m) self.assertRaises( - NotImplementedError, - base_event_loop._process_events, []) + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + + def test_add_callback_handler(self): + h = events.Handler(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handler(self): + h = events.Handler(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handler) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handler(self): + def cb(): pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handler(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): pass + h = events.Handler(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): pass + h = events.Handler(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda:True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda:True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + """If event loop has ready callbacks, select timeout is always 0.""" + h = events.Timer(time.monotonic() + 10.0, lambda:True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.logging') + def test__run_once_logging(self, m_logging, m_time): + """Log to INFO level if timeout > 1.0 sec.""" + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda:True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handler(self): + handler = None + processed = False + def cb(event_loop): + nonlocal processed, handler + processed = True + handler = event_loop.call_soon(lambda:True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handler], list(self.event_loop._ready)) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + def getaddrinfo(*args, **kw): + yield from [] + return [(2,1,6,'',('107.6.106.82',80)), + (2,1,6,'',('107.6.106.82',80))] + + idx = -1 + errors = ['err1', 'err2'] + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tulip/events_test.py b/tulip/events_test.py index b65bf9ca..90ff019b 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -387,7 +387,7 @@ def my_handler(): self.event_loop.run_forever() self.assertEqual(caught, 1) - def test_create_transport(self): + def test_create_connection(self): # TODO: This depends on xkcd.com behavior! f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) tr, pr = self.event_loop.run_until_complete(f) @@ -396,8 +396,31 @@ def test_create_transport(self): self.event_loop.run() self.assertTrue(pr.nbytes > 0) + def test_create_connection_sock(self): + # TODO: This depends on xkcd.com behavior! + sock = None + infos = self.event_loop.run_until_complete( + self.event_loop.getaddrinfo('xkcd.com', 80,type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + except: + pass + else: + break + + f = self.event_loop.create_connection(MyProto, sock=sock) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + @unittest.skipIf(ssl is None, 'No ssl module') - def test_create_ssl_transport(self): + def test_create_ssl_connection(self): # TODO: This depends on xkcd.com behavior! f = self.event_loop.create_connection( MyProto, 'xkcd.com', 443, ssl=True) @@ -409,18 +432,18 @@ def test_create_ssl_transport(self): self.event_loop.run() self.assertTrue(pr.nbytes > 0) - def test_create_transport_host_port_sock(self): + def test_create_connection_host_port_sock(self): self.suppress_log_errors() fut = self.event_loop.create_connection( MyProto, 'xkcd.com', 80, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) - def test_create_transport_no_host_port_sock(self): + def test_create_connection_no_host_port_sock(self): self.suppress_log_errors() fut = self.event_loop.create_connection(MyProto) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) - def test_create_transport_no_getaddrinfo(self): + def test_create_connection_no_getaddrinfo(self): self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] @@ -429,7 +452,7 @@ def test_create_transport_no_getaddrinfo(self): self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) - def test_create_transport_connect_err(self): + def test_create_connection_connect_err(self): self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error @@ -438,6 +461,21 @@ def test_create_transport_connect_err(self): self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2,1,6,'',('107.6.106.82',80)), + (2,1,6,'',('107.6.106.82',80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + def test_start_serving(self): proto = None def factory(): From 7ef3212eb8e6d1b62dc37a7ac2f7bc9f484da3aa Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 15 Feb 2013 21:24:02 -0800 Subject: [PATCH 0324/1502] add BaseEventLoop internal fds counter --- tulip/base_events.py | 4 +++- tulip/events_test.py | 20 ++++++++++++++++++++ tulip/selector_events.py | 12 ++++++++++-- tulip/selector_events_test.py | 24 ++++++++++++++++++++---- 4 files changed, 53 insertions(+), 7 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 2e323efc..0bd47ed9 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -47,6 +47,7 @@ def __init__(self): self._ready = collections.deque() self._scheduled = [] self._default_executor = None + self._internal_fds = 0 self._signal_handlers = {} def _make_socket_transport(self, sock, protocol, waiter=None): @@ -365,7 +366,8 @@ def _run_once(self, timeout=None): # Inspect the poll queue. If there's exactly one selectable # file descriptor, it's the self-pipe, and if there's nothing # scheduled, we should ignore it. - if self._selector.registered_count() > 1 or self._scheduled: + if (self._scheduled or + self._selector.registered_count() > self._internal_fds): if self._ready: timeout = 0 elif self._scheduled: diff --git a/tulip/events_test.py b/tulip/events_test.py index 90ff019b..c327575a 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -119,6 +119,15 @@ def run(): self.assertEqual(results, ['hello', 'world']) self.assertTrue(t1-t0 >= 0.09) + def test_call_soon_threadsafe_same_thread(self): + results = [] + def callback(arg): + results.append(arg) + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + def test_call_soon_threadsafe_with_handler(self): results = [] def callback(arg): @@ -590,6 +599,17 @@ def test_accept_connection_exception(self): self.event_loop._accept_connection(MyProto, sock) self.assertTrue(sock.close.called) + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + if sys.platform == 'win32': from . import windows_events diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 6954f701..7851f205 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -60,19 +60,27 @@ def _make_ssl_transport(self, rawsock, protocol, def close(self): if self._selector is not None: + self._close_self_pipe() self._selector.close() self._selector = None - self._ssock.close() - self._csock.close() def _socketpair(self): raise NotImplementedError + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + def _make_self_pipe(self): # A self-socket, really. :-) self._ssock, self._csock = self._socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) + self._internal_fds += 1 self.add_reader(self._ssock.fileno(), self._read_from_self) def _read_from_self(self): diff --git a/tulip/selector_events_test.py b/tulip/selector_events_test.py index 542662de..13fce021 100644 --- a/tulip/selector_events_test.py +++ b/tulip/selector_events_test.py @@ -21,6 +21,7 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop): def _make_self_pipe(self): self._ssock = unittest.mock.Mock() self._csock = unittest.mock.Mock() + self._internal_fds += 1 class BaseSelectorEventLoopTests(unittest.TestCase): @@ -46,21 +47,36 @@ def test_make_ssl_transport(self): _SelectorSslTransport) def test_close(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._selector.close() self.event_loop._selector = selector = unittest.mock.Mock() self.event_loop.close() self.assertIsNone(self.event_loop._selector) + self.assertIsNone(self.event_loop._csock) + self.assertIsNone(self.event_loop._ssock) self.assertTrue(selector.close.called) - self.assertTrue(self.event_loop._ssock.close.called) - self.assertTrue(self.event_loop._csock.close.called) + self.assertTrue(ssock.close.called) + self.assertTrue(csock.close.called) + self.assertTrue(remove_reader.called) + + self.event_loop.close() + self.event_loop.close() def test_close_no_selector(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._selector.close() self.event_loop._selector = None self.event_loop.close() self.assertIsNone(self.event_loop._selector) - self.assertTrue(self.event_loop._ssock.close.called) - self.assertTrue(self.event_loop._csock.close.called) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) def test_socketpair(self): self.assertRaises(NotImplementedError, self.event_loop._socketpair) From fdd938b8ae34e2b651a5b68603909d3d5468d2a5 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 15 Feb 2013 21:58:47 -0800 Subject: [PATCH 0325/1502] SelectorEventLoop tests --- tulip/unix_events.py | 8 +- tulip/unix_events_test.py | 168 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+), 4 deletions(-) create mode 100644 tulip/unix_events_test.py diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 8a2026bb..760db1bf 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,13 +1,13 @@ -"""Selector eventloop for Unix with signal handling. -""" +"""Selector eventloop for Unix with signal handling.""" import errno +import logging import socket import sys try: import signal -except ImportError: +except ImportError: # pragma: no cover signal = None from . import events @@ -17,7 +17,7 @@ __all__ = ['SelectorEventLoop'] -if sys.platform == 'win32': +if sys.platform == 'win32': # pragma: no cover raise ImportError('Signals are not really supported on Windows') diff --git a/tulip/unix_events_test.py b/tulip/unix_events_test.py new file mode 100644 index 00000000..1ed344cd --- /dev/null +++ b/tulip/unix_events_test.py @@ -0,0 +1,168 @@ +"""Tests for unix_events.py.""" + +import errno +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from . import events +from . import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handler) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + h = self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) From 3dad7cdc45707995b5bd70ec2f19489729cb6b75 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 19 Feb 2013 18:03:06 -0800 Subject: [PATCH 0326/1502] raise exception if coroutine uses yield instead of yield from for future or generator --- tulip/futures.py | 3 +++ tulip/tasks.py | 27 ++++++++++++++++++----- tulip/tasks_test.py | 53 +++++++++++++++++++++++++++++++++++---------- 3 files changed, 66 insertions(+), 17 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index e79999fc..cf794417 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -52,6 +52,8 @@ class Future: _result = None _exception = None + _blocking = False # proper use of future (yield vs yield from) + def __init__(self, *, event_loop=None): """Initialize the future. @@ -236,5 +238,6 @@ def _copy_state(self, other): def __iter__(self): if not self.done(): + self._blocking = True yield self # This tells Task to wait for completion. return self.result() # May raise too. diff --git a/tulip/tasks.py b/tulip/tasks.py index 617040ce..2e6b73f3 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -117,7 +117,16 @@ def _step(self, value=None, exc=None): else: # XXX No check for self._must_cancel here? if isinstance(result, futures.Future): - result.add_done_callback(self._wakeup) + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + elif isinstance(result, concurrent.futures.Future): # This ought to be more efficient than wrap_future(), # because we don't create an extra Future. @@ -126,9 +135,17 @@ def _step(self, value=None, exc=None): self._event_loop.call_soon_threadsafe( self._wakeup, future)) else: - if result is not None: - logging.warn('_step(): bad yield: %r', result) - self._event_loop.call_soon(self._step) + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) def _wakeup(self, future): try: @@ -180,7 +197,7 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): pending.add(f) if (not pending or - timeout != None and timeout <= 0 or + timeout is not None and timeout <= 0 or return_when == FIRST_COMPLETED and done or return_when == FIRST_EXCEPTION and errors): return done, pending diff --git a/tulip/tasks_test.py b/tulip/tasks_test.py index f122d279..12b6daf1 100644 --- a/tulip/tasks_test.py +++ b/tulip/tasks_test.py @@ -413,7 +413,7 @@ def notmuch(): self.assertEqual(1, m_logging.warn.call_args[0][1]) def test_step_result_future(self): - # Coroutine returns Future + """If coroutine returns future, task waits on this future.""" self.suppress_log_warnings() class Fut(futures.Future): @@ -424,21 +424,22 @@ def add_done_callback(self, fn): self.cb_added = True super().add_done_callback(fn) - c_fut = Fut() + fut = Fut() + result = None - @tasks.coroutine - def notmuch(): - yield from [c_fut] - return (yield) + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut - task = tasks.Task(notmuch()) - task._step() - self.assertTrue(c_fut.cb_added) + task = wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) res = object() - c_fut.set_result(res) - self.event_loop.run() - self.assertIs(res, task.result()) + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) def test_step_result_concurrent_future(self): # Coroutine returns concurrent.futures.Future @@ -522,6 +523,34 @@ def fn2(): yield self.assertTrue(tasks.iscoroutinefunction(fn2)) + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + if __name__ == '__main__': unittest.main() From 10690fc6bae8e73933abbea69ee169c5310c4491 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Feb 2013 09:40:34 -0800 Subject: [PATCH 0327/1502] _ProactorSocketTransport get_extra_info support --- runtests.py | 12 +++++++----- tulip/events_test.py | 4 +++- tulip/proactor_events.py | 22 ++++++++++++++++------ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/runtests.py b/runtests.py index 6758c742..64036d57 100644 --- a/runtests.py +++ b/runtests.py @@ -30,17 +30,19 @@ def load_tests(includes=(), excludes=()): test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] - if sys.platform == 'win32': + mods = [] + for mod in test_mods: try: - test_mods.remove('subprocess_test') - except ValueError: + __import__('tulip', fromlist=[mod]) + mods.append(mod) + except ImportError: pass - tulip = __import__('tulip', fromlist=test_mods) loader = unittest.TestLoader() suite = unittest.TestSuite() + tulip = sys.modules['tulip'] - for mod in [getattr(tulip, name) for name in test_mods]: + for mod in [getattr(tulip, name) for name in mods]: for name in set(dir(mod)): if name.endswith('Tests'): test_module = getattr(mod, name) diff --git a/tulip/events_test.py b/tulip/events_test.py index c327575a..2635da16 100644 --- a/tulip/events_test.py +++ b/tulip/events_test.py @@ -496,6 +496,7 @@ def factory(): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') + self.event_loop.run_once(0.01) # for windows proactor selector client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') @@ -506,6 +507,7 @@ def factory(): self.assertEqual('CONNECTED', proto.state) self.assertEqual(0, proto.nbytes) self.event_loop.run_once() + self.event_loop.run_once(0.1) # for windows proactor selector self.assertEqual(3, proto.nbytes) # extra info is available @@ -623,7 +625,7 @@ class ProactorEventLoopTests(EventLoopTestsMixin, test_utils.LogTrackingTestCase): def create_event_loop(self): return windows_events.ProactorEventLoop() - def test_create_ssl_transport(self): + def test_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 684352b2..0391eb49 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -14,7 +14,9 @@ class _ProactorSocketTransport(transports.Transport): - def __init__(self, event_loop, sock, protocol, waiter=None): + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock self._event_loop = event_loop self._sock = sock self._protocol = protocol @@ -109,16 +111,15 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None): - return _ProactorSocketTransport(self, sock, protocol, waiter) + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) def close(self): if self._proactor is not None: + self._close_self_pipe() self._proactor.close() self._proactor = None self._selector = None - self._ssock.close() - self._csock.close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) @@ -135,11 +136,19 @@ def sock_accept(self, sock): def _socketpair(self): raise NotImplementedError + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + def _make_self_pipe(self): # A self-socket, really. :-) self._ssock, self._csock = self._socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) + self._internal_fds += 1 def loop(f=None): try: if f: @@ -161,7 +170,8 @@ def loop(f=None): if f: conn, addr = f.result() protocol = protocol_factory() - transport = self._make_socket_transport(conn, protocol) + transport = self._make_socket_transport( + conn, protocol, extra={'addr': addr}) f = self._proactor.accept(sock) except OSError as exc: sock.close() From e830375782209bb395fca870adc9fad90326e048 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Feb 2013 14:23:00 -0800 Subject: [PATCH 0328/1502] assert if pending future is used as iterator --- tulip/futures.py | 1 + tulip/futures_test.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/tulip/futures.py b/tulip/futures.py index cf794417..4bb2f198 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -240,4 +240,5 @@ def __iter__(self): if not self.done(): self._blocking = True yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. diff --git a/tulip/futures_test.py b/tulip/futures_test.py index 2610b5be..36910163 100644 --- a/tulip/futures_test.py +++ b/tulip/futures_test.py @@ -134,6 +134,16 @@ def test_copy_state(self): newf_cancelled._copy_state(f_cancelled) self.assertTrue(newf_cancelled.cancelled()) + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + # A fake event loop for tests. All it does is implement a call_soon method # that immediately invokes the given function. From e594979c18438f51c446ac0612106066c32a0d95 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 22 Feb 2013 21:24:01 -0800 Subject: [PATCH 0329/1502] tests moved to separate directory --- runtests.py | 14 ++++++++------ {tulip => tests}/base_events_test.py | 10 +++++----- {tulip => tests}/events_test.py | 16 ++++++++-------- {tulip => tests}/futures_test.py | 2 +- {tulip => tests}/http_client_test.py | 8 ++++---- {tulip => tests}/locks_test.py | 10 +++++----- {tulip => tests}/selector_events_test.py | 10 +++++----- {tulip => tests}/selectors_test.py | 4 ++-- {tulip => tests}/subprocess_test.py | 6 +++--- {tulip => tests}/tasks_test.py | 8 ++++---- {tulip => tests}/transports_test.py | 2 +- {tulip => tests}/unix_events_test.py | 4 ++-- 12 files changed, 48 insertions(+), 46 deletions(-) rename {tulip => tests}/base_events_test.py (98%) rename {tulip => tests}/events_test.py (99%) rename {tulip => tests}/futures_test.py (99%) rename {tulip => tests}/http_client_test.py (98%) rename {tulip => tests}/locks_test.py (99%) rename {tulip => tests}/selector_events_test.py (99%) rename {tulip => tests}/selectors_test.py (98%) rename {tulip => tests}/subprocess_test.py (93%) rename {tulip => tests}/tasks_test.py (99%) rename {tulip => tests}/transports_test.py (97%) rename {tulip => tests}/unix_events_test.py (99%) diff --git a/runtests.py b/runtests.py index 64036d57..9b00c9d8 100644 --- a/runtests.py +++ b/runtests.py @@ -20,21 +20,23 @@ import re import sys import unittest +import importlib.machinery assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' -TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tulip') +TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tests') def load_tests(includes=(), excludes=()): - test_mods = [f[:-3] for f in os.listdir(TULIP_DIR) + test_mods = [(f[:-3], f) for f in os.listdir(TULIP_DIR) if f.endswith('_test.py')] mods = [] - for mod in test_mods: + for mod, sourcefile in test_mods: try: - __import__('tulip', fromlist=[mod]) - mods.append(mod) + loader = importlib.machinery.SourceFileLoader( + mod, os.path.join(TULIP_DIR, sourcefile)) + mods.append(loader.load_module()) except ImportError: pass @@ -42,7 +44,7 @@ def load_tests(includes=(), excludes=()): suite = unittest.TestSuite() tulip = sys.modules['tulip'] - for mod in [getattr(tulip, name) for name in mods]: + for mod in mods: for name in set(dir(mod)): if name.endswith('Tests'): test_module = getattr(mod, name) diff --git a/tulip/base_events_test.py b/tests/base_events_test.py similarity index 98% rename from tulip/base_events_test.py rename to tests/base_events_test.py index 5dd0263c..34f04f6d 100644 --- a/tulip/base_events_test.py +++ b/tests/base_events_test.py @@ -7,11 +7,11 @@ import unittest import unittest.mock -from . import base_events -from . import events -from . import futures -from . import protocols -from . import test_utils +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils class BaseEventLoopTests(test_utils.LogTrackingTestCase): diff --git a/tulip/events_test.py b/tests/events_test.py similarity index 99% rename from tulip/events_test.py rename to tests/events_test.py index 2635da16..f5a2ebcf 100644 --- a/tulip/events_test.py +++ b/tests/events_test.py @@ -17,12 +17,12 @@ import unittest import unittest.mock -from . import events -from . import futures -from . import transports -from . import protocols -from . import selector_events -from . import test_utils +from tulip import events +from tulip import futures +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import test_utils class MyProto(protocols.Protocol): @@ -647,8 +647,8 @@ def test_accept_connection_exception(self): "IocpEventLoop does not have _accept_connection()") else: - from . import selectors - from . import unix_events + from tulip import selectors + from tulip import unix_events if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(EventLoopTestsMixin, diff --git a/tulip/futures_test.py b/tests/futures_test.py similarity index 99% rename from tulip/futures_test.py rename to tests/futures_test.py index 36910163..fd49493d 100644 --- a/tulip/futures_test.py +++ b/tests/futures_test.py @@ -2,7 +2,7 @@ import unittest -from . import futures +from tulip import futures def _fakefunc(f): diff --git a/tulip/http_client_test.py b/tests/http_client_test.py similarity index 98% rename from tulip/http_client_test.py rename to tests/http_client_test.py index 3c3b1242..177ad31b 100644 --- a/tulip/http_client_test.py +++ b/tests/http_client_test.py @@ -2,10 +2,10 @@ import unittest -from . import events -from . import http_client -from . import tasks -from . import test_utils +from tulip import events +from tulip import http_client +from tulip import tasks +from tulip import test_utils class StreamReaderTests(test_utils.LogTrackingTestCase): diff --git a/tulip/locks_test.py b/tests/locks_test.py similarity index 99% rename from tulip/locks_test.py rename to tests/locks_test.py index 9444ceaa..4267134c 100644 --- a/tulip/locks_test.py +++ b/tests/locks_test.py @@ -4,11 +4,11 @@ import unittest import unittest.mock -from . import events -from . import futures -from . import locks -from . import tasks -from . import test_utils +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils class LockTests(unittest.TestCase): diff --git a/tulip/selector_events_test.py b/tests/selector_events_test.py similarity index 99% rename from tulip/selector_events_test.py rename to tests/selector_events_test.py index 13fce021..80f0d73d 100644 --- a/tulip/selector_events_test.py +++ b/tests/selector_events_test.py @@ -9,11 +9,11 @@ except ImportError: ssl = None -from . import futures -from . import selectors -from .selector_events import BaseSelectorEventLoop -from .selector_events import _SelectorSslTransport -from .selector_events import _SelectorSocketTransport +from tulip import futures +from tulip import selectors +from tulip.selector_events import BaseSelectorEventLoop +from tulip.selector_events import _SelectorSslTransport +from tulip.selector_events import _SelectorSocketTransport class TestBaseSelectorEventLoop(BaseSelectorEventLoop): diff --git a/tulip/selectors_test.py b/tests/selectors_test.py similarity index 98% rename from tulip/selectors_test.py rename to tests/selectors_test.py index 1cdd0080..3ebaab8c 100644 --- a/tulip/selectors_test.py +++ b/tests/selectors_test.py @@ -3,8 +3,8 @@ import unittest import unittest.mock -from . import events -from . import selectors +from tulip import events +from tulip import selectors class BaseSelectorTests(unittest.TestCase): diff --git a/tulip/subprocess_test.py b/tests/subprocess_test.py similarity index 93% rename from tulip/subprocess_test.py rename to tests/subprocess_test.py index 3d996f6a..14ce11d7 100644 --- a/tulip/subprocess_test.py +++ b/tests/subprocess_test.py @@ -3,9 +3,9 @@ import logging import unittest -from . import events -from . import protocols -from . import subprocess_transport +from tulip import events +from tulip import protocols +from tulip import subprocess_transport class MyProto(protocols.Protocol): diff --git a/tulip/tasks_test.py b/tests/tasks_test.py similarity index 99% rename from tulip/tasks_test.py rename to tests/tasks_test.py index 12b6daf1..25ca5a4f 100644 --- a/tulip/tasks_test.py +++ b/tests/tasks_test.py @@ -5,10 +5,10 @@ import unittest import unittest.mock -from . import events -from . import futures -from . import tasks -from . import test_utils +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils class Dummy: diff --git a/tulip/transports_test.py b/tests/transports_test.py similarity index 97% rename from tulip/transports_test.py rename to tests/transports_test.py index 59b10fd7..eb61d914 100644 --- a/tulip/transports_test.py +++ b/tests/transports_test.py @@ -3,7 +3,7 @@ import unittest import unittest.mock -from . import transports +from tulip import transports class TransportTests(unittest.TestCase): diff --git a/tulip/unix_events_test.py b/tests/unix_events_test.py similarity index 99% rename from tulip/unix_events_test.py rename to tests/unix_events_test.py index 1ed344cd..2504648b 100644 --- a/tulip/unix_events_test.py +++ b/tests/unix_events_test.py @@ -9,8 +9,8 @@ except ImportError: signal = None -from . import events -from . import unix_events +from tulip import events +from tulip import unix_events @unittest.skipUnless(signal, 'Signals are not supported') From 461e72986b9e8b8dc053dd9c8a3bcbfbabe3e8ce Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 25 Feb 2013 12:52:58 -0800 Subject: [PATCH 0330/1502] Move http_client to submodule --- crawl.py | 6 ++--- curl.py | 5 ++-- srv.py | 6 ++--- tests/http_client_test.py | 46 ++++++++++++++++----------------- tulip/http/__init__.py | 5 ++++ tulip/{ => http}/http_client.py | 33 +++++++++++------------ 6 files changed, 53 insertions(+), 48 deletions(-) create mode 100644 tulip/http/__init__.py rename tulip/{ => http}/http_client.py (93%) diff --git a/crawl.py b/crawl.py index 3d7055f9..3dd270b8 100755 --- a/crawl.py +++ b/crawl.py @@ -8,7 +8,7 @@ import urllib.parse import tulip -from tulip import http_client +import tulip.http END = '\n' MAXTASKS = 100 @@ -71,8 +71,8 @@ def process(self, url): path = '/' if query: path = '?'.join([path, query]) - p = http_client.HttpClientProtocol(netloc, path=path, - ssl=(scheme=='https')) + p = tulip.http.HttpClientProtocol(netloc, path=path, + ssl=(scheme=='https')) delay = 1 while True: try: diff --git a/curl.py b/curl.py index 0ba82404..b0566b67 100755 --- a/curl.py +++ b/curl.py @@ -4,7 +4,7 @@ import urllib.parse import tulip -from tulip import http_client +import tulip.http def main(): @@ -15,8 +15,7 @@ def main(): if query: path = '?'.join([path, query]) print(netloc, path, scheme) - p = http_client.HttpClientProtocol(netloc, path=path, - ssl=(scheme=='https')) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) f = p.connect() sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) print(sts) diff --git a/srv.py b/srv.py index 077279c1..545ba4de 100644 --- a/srv.py +++ b/srv.py @@ -6,7 +6,7 @@ import re import tulip -from tulip.http_client import StreamReader +import tulip.http class HttpServer(tulip.Protocol): @@ -99,8 +99,8 @@ def handle_request(self): def connection_made(self, transport): self.transport = transport - print('connection made', transport, transport._sock) - self.reader = StreamReader() + print('connection made', transport, transport.get_extra_info('socket')) + self.reader = tulip.http.StreamReader() self.handler = self.handle_request() def data_received(self, data): diff --git a/tests/http_client_test.py b/tests/http_client_test.py index 177ad31b..e0ca45af 100644 --- a/tests/http_client_test.py +++ b/tests/http_client_test.py @@ -1,9 +1,9 @@ -"""Tests for http_client.py.""" +"""Tests for http/http_client.py.""" import unittest from tulip import events -from tulip import http_client +from tulip import http from tulip import tasks from tulip import test_utils @@ -22,14 +22,14 @@ def tearDown(self): self.event_loop.close() def test_feed_empty_data(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(b'') self.assertEqual(0, stream.line_count) self.assertEqual(0, stream.byte_count) def test_feed_data_line_byte_count(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(self.DATA) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) @@ -37,7 +37,7 @@ def test_feed_data_line_byte_count(self): def test_read_zero(self): """Read zero bytes""" - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.read(0)) @@ -48,7 +48,7 @@ def test_read_zero(self): def test_read(self): """ Read bytes """ - stream = http_client.StreamReader() + stream = http.StreamReader() read_task = tasks.Task(stream.read(30)) def cb(): @@ -62,7 +62,7 @@ def cb(): def test_read_line_breaks(self): """ Read bytes without line breaks """ - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -75,7 +75,7 @@ def test_read_line_breaks(self): def test_read_eof(self): """ Read bytes, stop at eof """ - stream = http_client.StreamReader() + stream = http.StreamReader() read_task = tasks.Task(stream.read(1024)) def cb(): @@ -89,7 +89,7 @@ def cb(): def test_read_until_eof(self): """ Read all bytes until eof """ - stream = http_client.StreamReader() + stream = http.StreamReader() read_task = tasks.Task(stream.read(-1)) def cb(): @@ -106,7 +106,7 @@ def cb(): def test_readline(self): """ Read one line """ - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(b'chunk1 ') read_task = tasks.Task(stream.readline()) @@ -124,7 +124,7 @@ def cb(): def test_readline_limit_with_existing_data(self): self.suppress_log_errors() - stream = http_client.StreamReader(3) + stream = http.StreamReader(3) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -133,7 +133,7 @@ def test_readline_limit_with_existing_data(self): ValueError, self.event_loop.run_until_complete, read_task) self.assertEqual([b'line2\n'], list(stream.buffer)) - stream = http_client.StreamReader(3) + stream = http.StreamReader(3) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -147,7 +147,7 @@ def test_readline_limit_with_existing_data(self): def test_readline_limit(self): self.suppress_log_errors() - stream = http_client.StreamReader(7) + stream = http.StreamReader(7) def cb(): stream.feed_data(b'chunk1') @@ -163,7 +163,7 @@ def cb(): self.assertEqual(7, stream.byte_count) def test_readline_line_byte_count(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -175,7 +175,7 @@ def test_readline_line_byte_count(self): self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) def test_readline_eof(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(b'some data') stream.feed_eof() @@ -185,7 +185,7 @@ def test_readline_eof(self): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_eof() read_task = tasks.Task(stream.readline()) @@ -194,7 +194,7 @@ def test_readline_empty_eof(self): self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.readline()) @@ -212,7 +212,7 @@ def test_readline_read_byte_count(self): def test_readexactly_zero_or_less(self): """ Read exact number of bytes (zero or less) """ - stream = http_client.StreamReader() + stream = http.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.readexactly(0)) @@ -229,9 +229,9 @@ def test_readexactly_zero_or_less(self): def test_readexactly(self): """ Read exact number of bytes """ - stream = http_client.StreamReader() + stream = http.StreamReader() - n = 2*len(self.DATA) + n = 2 * len(self.DATA) read_task = tasks.Task(stream.readexactly(n)) def cb(): @@ -241,14 +241,14 @@ def cb(): self.event_loop.call_soon(cb) data = self.event_loop.run_until_complete(read_task) - self.assertEqual(self.DATA+self.DATA, data) + self.assertEqual(self.DATA + self.DATA, data) self.assertEqual(len(self.DATA), stream.byte_count) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_readexactly_eof(self): """ Read exact number of bytes (eof) """ - stream = http_client.StreamReader() - n = 2*len(self.DATA) + stream = http.StreamReader() + n = 2 * len(self.DATA) read_task = tasks.Task(stream.readexactly(n)) def cb(): diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..cb07d9ff --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,5 @@ +# This relies on each of the submodules having an __all__ variable. +from .http_client import * + + +__all__ = http_client.__all__ diff --git a/tulip/http_client.py b/tulip/http/http_client.py similarity index 93% rename from tulip/http_client.py rename to tulip/http/http_client.py index c5cddba0..a5a4f04b 100644 --- a/tulip/http_client.py +++ b/tulip/http/http_client.py @@ -23,15 +23,15 @@ TODO: How do we do connection keep alive? Pooling? """ +__all__ = ['StreamReader', 'HttpClientProtocol'] + + import collections import email.message import email.parser import re import tulip -from . import events -from . import futures -from . import tasks # TODO: Move to another module. @@ -63,7 +63,7 @@ def feed_data(self, data): self.waiter = None waiter.set_result(False) - @tasks.coroutine + @tulip.coroutine def readline(self): parts = [] parts_size = 0 @@ -95,7 +95,7 @@ def readline(self): if not_enough: assert self.waiter is None - self.waiter = futures.Future() + self.waiter = tulip.Future() yield from self.waiter line = b''.join(parts) @@ -103,19 +103,19 @@ def readline(self): return line - @tasks.coroutine + @tulip.coroutine def read(self, n=-1): if not n: return b'' if n < 0: while not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = tulip.Future() yield from self.waiter else: if not self.byte_count and not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = tulip.Future() yield from self.waiter if n < 0 or self.byte_count <= n: data = b''.join(self.buffer) @@ -139,16 +139,17 @@ def read(self, n=-1): self.line_count -= data.count(b'\n') return b''.join(parts) - @tasks.coroutine + @tulip.coroutine def readexactly(self, n): if n <= 0: return b'' while self.byte_count < n and not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = tulip.Future() yield from self.waiter return (yield from self.read(n)) + class HttpClientProtocol: """This Protocol class is also used to initiate the connection. @@ -200,24 +201,24 @@ def __init__(self, host, port=None, *, assert self.headers['Transfer-Encoding'].lower() == 'chunked' if 'host' not in self.headers: self.headers['Host'] = self.host - self.event_loop = events.get_event_loop() + self.event_loop = tulip.get_event_loop() self.transport = None def validate(self, value, name, embedded_spaces_okay=False): - # Must be a string. If embedded_spaces_okay is False, no + # Must be a string. If embedded_spaces_okay is False, no # whitespace is allowed; otherwise, internal single spaces are # allowed (but no other whitespace). assert isinstance(value, str), \ - '{} should be str, not {}'.format(name, type(value)) + '{} should be str, not {}'.format(name, type(value)) parts = value.split() assert parts, '{} should not be empty'.format(name) if embedded_spaces_okay: assert ' '.join(parts) == value, \ - '{} can only contain embedded single spaces ({!r})'.format( - name, value) + '{} can only contain embedded single spaces ({!r})'.format( + name, value) else: assert parts == [value], \ - '{} cannot contain whitespace ({!r})'.format(name, value) + '{} cannot contain whitespace ({!r})'.format(name, value) return value @tulip.coroutine From e3febbfea65876dae325605de76302b34942b1bb Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 25 Feb 2013 14:42:27 -0800 Subject: [PATCH 0331/1502] Move StreamReader to separate module --- .../{http_client_test.py => streams_test.py} | 58 ++++---- tulip/__init__.py | 2 + tulip/http/http_client.py | 116 --------------- tulip/streams.py | 132 ++++++++++++++++++ 4 files changed, 163 insertions(+), 145 deletions(-) rename tests/{http_client_test.py => streams_test.py} (86%) create mode 100644 tulip/streams.py diff --git a/tests/http_client_test.py b/tests/streams_test.py similarity index 86% rename from tests/http_client_test.py rename to tests/streams_test.py index e0ca45af..118f6a05 100644 --- a/tests/http_client_test.py +++ b/tests/streams_test.py @@ -1,9 +1,9 @@ -"""Tests for http/http_client.py.""" +"""Tests for streams.py.""" import unittest from tulip import events -from tulip import http +from tulip import streams from tulip import tasks from tulip import test_utils @@ -22,22 +22,22 @@ def tearDown(self): self.event_loop.close() def test_feed_empty_data(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_data(b'') self.assertEqual(0, stream.line_count) self.assertEqual(0, stream.byte_count) def test_feed_data_line_byte_count(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_data(self.DATA) self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) def test_read_zero(self): - """Read zero bytes""" - stream = http.StreamReader() + """Read zero bytes.""" + stream = streams.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.read(0)) @@ -47,8 +47,8 @@ def test_read_zero(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_read(self): - """ Read bytes """ - stream = http.StreamReader() + """Read bytes.""" + stream = streams.StreamReader() read_task = tasks.Task(stream.read(30)) def cb(): @@ -61,8 +61,8 @@ def cb(): self.assertFalse(stream.line_count) def test_read_line_breaks(self): - """ Read bytes without line breaks """ - stream = http.StreamReader() + """Read bytes without line breaks.""" + stream = streams.StreamReader() stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -74,8 +74,8 @@ def test_read_line_breaks(self): self.assertFalse(stream.line_count) def test_read_eof(self): - """ Read bytes, stop at eof """ - stream = http.StreamReader() + """Read bytes, stop at eof.""" + stream = streams.StreamReader() read_task = tasks.Task(stream.read(1024)) def cb(): @@ -88,8 +88,8 @@ def cb(): self.assertFalse(stream.line_count) def test_read_until_eof(self): - """ Read all bytes until eof """ - stream = http.StreamReader() + """Read all bytes until eof.""" + stream = streams.StreamReader() read_task = tasks.Task(stream.read(-1)) def cb(): @@ -105,8 +105,8 @@ def cb(): self.assertFalse(stream.line_count) def test_readline(self): - """ Read one line """ - stream = http.StreamReader() + """Read one line.""" + stream = streams.StreamReader() stream.feed_data(b'chunk1 ') read_task = tasks.Task(stream.readline()) @@ -124,7 +124,7 @@ def cb(): def test_readline_limit_with_existing_data(self): self.suppress_log_errors() - stream = http.StreamReader(3) + stream = streams.StreamReader(3) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -133,7 +133,7 @@ def test_readline_limit_with_existing_data(self): ValueError, self.event_loop.run_until_complete, read_task) self.assertEqual([b'line2\n'], list(stream.buffer)) - stream = http.StreamReader(3) + stream = streams.StreamReader(3) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -147,7 +147,7 @@ def test_readline_limit_with_existing_data(self): def test_readline_limit(self): self.suppress_log_errors() - stream = http.StreamReader(7) + stream = streams.StreamReader(7) def cb(): stream.feed_data(b'chunk1') @@ -163,7 +163,7 @@ def cb(): self.assertEqual(7, stream.byte_count) def test_readline_line_byte_count(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -175,7 +175,7 @@ def test_readline_line_byte_count(self): self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) def test_readline_eof(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_data(b'some data') stream.feed_eof() @@ -185,7 +185,7 @@ def test_readline_eof(self): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_eof() read_task = tasks.Task(stream.readline()) @@ -194,7 +194,7 @@ def test_readline_empty_eof(self): self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = http.StreamReader() + stream = streams.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.readline()) @@ -211,8 +211,8 @@ def test_readline_read_byte_count(self): stream.byte_count) def test_readexactly_zero_or_less(self): - """ Read exact number of bytes (zero or less) """ - stream = http.StreamReader() + """Read exact number of bytes (zero or less).""" + stream = streams.StreamReader() stream.feed_data(self.DATA) read_task = tasks.Task(stream.readexactly(0)) @@ -228,8 +228,8 @@ def test_readexactly_zero_or_less(self): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_readexactly(self): - """ Read exact number of bytes """ - stream = http.StreamReader() + """Read exact number of bytes.""" + stream = streams.StreamReader() n = 2 * len(self.DATA) read_task = tasks.Task(stream.readexactly(n)) @@ -246,8 +246,8 @@ def cb(): self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_readexactly_eof(self): - """ Read exact number of bytes (eof) """ - stream = http.StreamReader() + """Read exact number of bytes (eof).""" + stream = streams.StreamReader() n = 2 * len(self.DATA) read_task = tasks.Task(stream.readexactly(n)) diff --git a/tulip/__init__.py b/tulip/__init__.py index f3b407ec..9a9b305e 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -6,6 +6,7 @@ from .locks import * from .transports import * from .protocols import * +from .streams import * from .tasks import * __all__ = (futures.__all__ + @@ -13,4 +14,5 @@ locks.__all__ + transports.__all__ + protocols.__all__ + + streams.__all__ + tasks.__all__) diff --git a/tulip/http/http_client.py b/tulip/http/http_client.py index a5a4f04b..e7b884d6 100644 --- a/tulip/http/http_client.py +++ b/tulip/http/http_client.py @@ -34,122 +34,6 @@ import tulip -# TODO: Move to another module. -class StreamReader: - - def __init__(self, limit=2**16): - self.limit = limit # Max line length. (Security feature.) - self.buffer = collections.deque() # Deque of bytes objects. - self.byte_count = 0 # Bytes in buffer. - self.line_count = 0 # Number of complete lines in buffer. - self.eof = False # Whether we're done. - self.waiter = None # A future. - - def feed_eof(self): - self.eof = True - waiter = self.waiter - if waiter is not None: - self.waiter = None - waiter.set_result(True) - - def feed_data(self, data): - if not data: - return - self.buffer.append(data) - self.line_count += data.count(b'\n') - self.byte_count += len(data) - waiter = self.waiter - if waiter is not None: - self.waiter = None - waiter.set_result(False) - - @tulip.coroutine - def readline(self): - parts = [] - parts_size = 0 - not_enough = True - - while not_enough: - while self.buffer and not_enough: - data = self.buffer.popleft() - ichar = data.find(b'\n') - if ichar < 0: - parts.append(data) - parts_size += len(data) - else: - ichar += 1 - head, tail = data[:ichar], data[ichar:] - if tail: - self.buffer.appendleft(tail) - self.line_count -= 1 - not_enough = False - parts.append(head) - parts_size += len(head) - - if parts_size > self.limit: - self.byte_count -= parts_size - raise ValueError('Line is too long') - - if self.eof: - break - - if not_enough: - assert self.waiter is None - self.waiter = tulip.Future() - yield from self.waiter - - line = b''.join(parts) - self.byte_count -= parts_size - - return line - - @tulip.coroutine - def read(self, n=-1): - if not n: - return b'' - if n < 0: - while not self.eof: - assert not self.waiter - self.waiter = tulip.Future() - yield from self.waiter - else: - if not self.byte_count and not self.eof: - assert not self.waiter - self.waiter = tulip.Future() - yield from self.waiter - if n < 0 or self.byte_count <= n: - data = b''.join(self.buffer) - self.buffer.clear() - self.byte_count = 0 - self.line_count = 0 - return data - parts = [] - parts_bytes = 0 - while self.buffer and parts_bytes < n: - data = self.buffer.popleft() - data_bytes = len(data) - if n < parts_bytes + data_bytes: - data_bytes = n - parts_bytes - data, rest = data[:data_bytes], data[data_bytes:] - self.buffer.appendleft(rest) - parts.append(data) - parts_bytes += data_bytes - self.byte_count -= data_bytes - if self.line_count: - self.line_count -= data.count(b'\n') - return b''.join(parts) - - @tulip.coroutine - def readexactly(self, n): - if n <= 0: - return b'' - while self.byte_count < n and not self.eof: - assert not self.waiter - self.waiter = tulip.Future() - yield from self.waiter - return (yield from self.read(n)) - - class HttpClientProtocol: """This Protocol class is also used to initiate the connection. diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..b10aa55c --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,132 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.line_count = 0 # Number of complete lines in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.line_count += data.count(b'\n') + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + self.line_count -= 1 + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self.line_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + if self.line_count: + self.line_count -= data.count(b'\n') + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) From 4a2cb59a8623e883420df206e526bedebe81db75 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 26 Feb 2013 17:10:03 -0800 Subject: [PATCH 0332/1502] There is nothing in runtests.py that is tulip-specific any more. --- runtests.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/runtests.py b/runtests.py index 9b00c9d8..c0ef265f 100644 --- a/runtests.py +++ b/runtests.py @@ -10,7 +10,7 @@ Note that the test id is the fully qualified name of the test, including package, module, class and method, -e.g. 'tulip.events_test.PolicyTests.testPolicy'. +e.g. 'tests.events_test.PolicyTests.testPolicy'. """ # Originally written by Beech Horn (for NDB). @@ -24,25 +24,24 @@ assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' -TULIP_DIR = os.path.join(os.path.dirname(__file__), 'tests') +TESTS_DIR = os.path.join(os.path.dirname(__file__), 'tests') def load_tests(includes=(), excludes=()): - test_mods = [(f[:-3], f) for f in os.listdir(TULIP_DIR) + test_mods = [(f[:-3], f) for f in os.listdir(TESTS_DIR) if f.endswith('_test.py')] mods = [] for mod, sourcefile in test_mods: try: loader = importlib.machinery.SourceFileLoader( - mod, os.path.join(TULIP_DIR, sourcefile)) + mod, os.path.join(TESTS_DIR, sourcefile)) mods.append(loader.load_module()) except ImportError: pass loader = unittest.TestLoader() suite = unittest.TestSuite() - tulip = sys.modules['tulip'] for mod in mods: for name in set(dir(mod)): From df656a087ec28a6876460e17a8c287c487e538f2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Feb 2013 09:36:35 -0800 Subject: [PATCH 0333/1502] runtests improvements. --- runtests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtests.py b/runtests.py index c0ef265f..fb6038a4 100644 --- a/runtests.py +++ b/runtests.py @@ -24,12 +24,12 @@ assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' -TESTS_DIR = os.path.join(os.path.dirname(__file__), 'tests') +TESTS_DIR = 'tests' def load_tests(includes=(), excludes=()): test_mods = [(f[:-3], f) for f in os.listdir(TESTS_DIR) - if f.endswith('_test.py')] + if f.endswith('_test.py') and not f.startswith('.')] mods = [] for mod, sourcefile in test_mods: From 75a39964a5f7f25d703d72356674544ea7c1f541 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 4 Mar 2013 14:01:52 -0800 Subject: [PATCH 0334/1502] Fix tests for Windows. --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index f5a2ebcf..538a3d7a 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -614,7 +614,7 @@ def test_internal_fds(self): if sys.platform == 'win32': - from . import windows_events + from tulip import windows_events class SelectEventLoopTests(EventLoopTestsMixin, test_utils.LogTrackingTestCase): From 6a0bb329272a1e045eb5e4fce59fb2a3dae9d516 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 6 Mar 2013 22:20:03 -0800 Subject: [PATCH 0335/1502] winsocketpair tests --- tests/winsocketpair_test.py | 32 ++++++++++++++++++++++++++++++++ tulip/winsocketpair.py | 4 ++++ 2 files changed, 36 insertions(+) create mode 100644 tests/winsocketpair_test.py diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..0175e9b9 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,32 @@ +"""Tests for winsocketpair.py""" + +import errno +import socket +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.error = socket.error + class Err(socket.error): + errno = errno.WSAEWOULDBLOCK + 1 + + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = Err + + self.assertRaises(Err, winsocketpair.socketpair) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index 87d54c91..805b215f 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -5,6 +5,10 @@ import errno import socket +import sys + +if sys.platform != 'win32': + raise ImportError('winsocketpair is win32 only') def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): From 8a3cb5101c47d210f76ee770257a90d210483a23 Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Thu, 7 Mar 2013 09:06:20 +0100 Subject: [PATCH 0336/1502] Avoid using event_loop._socketpair in tests --- tests/events_test.py | 12 ++++++------ tulip/test_utils.py | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 538a3d7a..3392d164 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -176,7 +176,7 @@ def run(arg): self.assertEqual(res, 'yo') def test_reader_callback(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() bytes_read = [] def reader(): try: @@ -198,7 +198,7 @@ def reader(): self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_with_handler(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() bytes_read = [] def reader(): try: @@ -223,7 +223,7 @@ def reader(): self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_cancel(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() bytes_read = [] def reader(): try: @@ -244,7 +244,7 @@ def reader(): self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() w.setblocking(False) self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): @@ -257,7 +257,7 @@ def remove_writer(): self.assertTrue(len(data) >= 200) def test_writer_callback_with_handler(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() w.setblocking(False) handler = events.Handler(w.send, (b'x'*(256*1024),)) self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) @@ -271,7 +271,7 @@ def remove_writer(): self.assertTrue(len(data) >= 200) def test_writer_callback_cancel(self): - r, w = self.event_loop._socketpair() + r, w = test_utils.socketpair() w.setblocking(False) def sender(): w.send(b'x'*256) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index f07c34ce..75e6514e 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,9 +1,16 @@ """Utilities shared by tests.""" import logging +import socket import unittest +try: + from socket import socketpair +except ImportError: + from .winsocketpair import socketpair + + class LogTrackingTestCase(unittest.TestCase): def setUp(self): From 9e19042649baf552ad5c0865242faf2837f47dd5 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 7 Mar 2013 10:55:14 -0800 Subject: [PATCH 0337/1502] runtests.py coverage support --- Makefile | 6 +- README | 13 +--- runtests.py | 149 +++++++++++++++++++++++++++++++------- tulip/__init__.py | 8 ++ tulip/http/http_client.py | 2 +- tulip/test_utils.py | 11 +-- 6 files changed, 139 insertions(+), 50 deletions(-) diff --git a/Makefile b/Makefile index 65a48111..db06318b 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,6 @@ # Some simple testing tasks (sorry, UNIX only). PYTHON=python3 -COVERAGE=coverage3 -NONTESTS=`find tulip -name [a-z]\*.py ! -name \*_test.py` FLAGS= test: @@ -13,9 +11,7 @@ testloop: # See README for coverage installation instructions. cov coverage: - $(COVERAGE) run --branch runtests.py -v $(FLAGS) - $(COVERAGE) html $(NONTESTS) - $(COVERAGE) report -m $(NONTESTS) + $(PYTHON) runtests.py --coverage tulip -v $(FLAGS) echo "open file://`pwd`/htmlcov/index.html" check: diff --git a/README b/README index c1c86a54..85bfe5a7 100644 --- a/README +++ b/README @@ -4,7 +4,7 @@ PEP 3156: http://www.python.org/dev/peps/pep-3156/ *** This requires Python 3.3 or later! *** -Copyright/license: Open source, Apache 2.0. Enjoy. +Copyright/license: Open source, Apache 2.0. Enjoy. Master Mercurial repo: http://code.google.com/p/tulip/ @@ -14,17 +14,8 @@ to PEP 3156, under construction) lives in the 'tulip' subdirectory. To run tests: - make test -To run coverage (after installing coverage3, see below): +To run coverage (coverage package is required): - make coverage -To install coverage3 (coverage.py for Python 3), you need: - - Distribute (http://packages.python.org/distribute/) - - Coverage (http://nedbatchelder.com/code/coverage/) - What worked for me: - - curl -O http://python-distribute.org/distribute_setup.py - - python3 distribute_setup.py - - - - cd coveragepy - - python3 setup.py install --Guido van Rossum diff --git a/runtests.py b/runtests.py index fb6038a4..6fa888b3 100644 --- a/runtests.py +++ b/runtests.py @@ -11,35 +11,92 @@ Note that the test id is the fully qualified name of the test, including package, module, class and method, e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + """ # Originally written by Beech Horn (for NDB). +import argparse import logging import os import re import sys +import subprocess import unittest import importlib.machinery assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' -TESTS_DIR = 'tests' +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if name.endswith(suffix) and not name.startswith(('.', '_')): + files.append(('%s%s'%(prefix, name[:-3]), path)) + return files -def load_tests(includes=(), excludes=()): - test_mods = [(f[:-3], f) for f in os.listdir(TESTS_DIR) - if f.endswith('_test.py') and not f.startswith('.')] + files = [] + modpath = os.path.join(basedir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(basedir)[-1] + prefix = '%s.' % mod + files.append((mod, modpath)) + else: + prefix = '' mods = [] - for mod, sourcefile in test_mods: + for modname, sourcefile in files + list_dir(prefix, basedir): + if modname == 'runtests': + continue try: - loader = importlib.machinery.SourceFileLoader( - mod, os.path.join(TESTS_DIR, sourcefile)) - mods.append(loader.load_module()) - except ImportError: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except: pass + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + loader = unittest.TestLoader() suite = unittest.TestSuite() @@ -63,24 +120,24 @@ def load_tests(includes=(), excludes=()): return suite -def main(): - excludes = [] - includes = [] - patterns = includes # A reference. - v = 1 - for arg in sys.argv[1:]: - if arg.startswith('-v'): - v += arg.count('v') - elif arg == '-q': - v = 0 - elif arg == '-x': - if patterns is includes: - patterns = excludes - else: - patterns = includes - elif arg and not arg.startswith('-'): - patterns.append(arg) - tests = load_tests(includes, excludes) +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n"%testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() if v == 0: logger.setLevel(logging.CRITICAL) @@ -96,5 +153,41 @@ def main(): sys.exit(not result.wasSuccessful()) +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n"%sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + subprocess.check_call(coverage + ['run', '--branch', 'runtests.py'] + args) + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + if __name__ == '__main__': - main() + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/tulip/__init__.py b/tulip/__init__.py index 9a9b305e..e8fa861a 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -1,5 +1,7 @@ """Tulip 2.0, tracking PEP 3156.""" +import sys + # This relies on each of the submodules having an __all__ variable. from .futures import * from .events import * @@ -9,6 +11,12 @@ from .streams import * from .tasks import * +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * + + __all__ = (futures.__all__ + events.__all__ + locks.__all__ + diff --git a/tulip/http/http_client.py b/tulip/http/http_client.py index e7b884d6..71f97c5d 100644 --- a/tulip/http/http_client.py +++ b/tulip/http/http_client.py @@ -23,7 +23,7 @@ TODO: How do we do connection keep alive? Pooling? """ -__all__ = ['StreamReader', 'HttpClientProtocol'] +__all__ = ['HttpClientProtocol'] import collections diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 75e6514e..946e1ef7 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -2,13 +2,14 @@ import logging import socket +import sys import unittest -try: - from socket import socketpair -except ImportError: +if sys.platform == 'win32': # pragma: no cover from .winsocketpair import socketpair +else: + from socket import socketpair class LogTrackingTestCase(unittest.TestCase): @@ -20,10 +21,10 @@ def setUp(self): def tearDown(self): self._logger.setLevel(self._log_level) - def suppress_log_errors(self): + def suppress_log_errors(self): # pragma: no cover if self._log_level >= logging.WARNING: self._logger.setLevel(logging.CRITICAL) - def suppress_log_warnings(self): + def suppress_log_warnings(self): # pragma: no cover if self._log_level >= logging.WARNING: self._logger.setLevel(logging.ERROR) From a946a77dc01f62981e65968bbe9be5acb68abcaf Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Fri, 8 Mar 2013 09:31:04 +0100 Subject: [PATCH 0338/1502] Added Handler.run --- tests/events_test.py | 17 +++++++++++++++++ tulip/base_events.py | 6 +----- tulip/events.py | 8 ++++++++ tulip/unix_events.py | 2 +- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 3392d164..4c3056a9 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -396,6 +396,23 @@ def my_handler(): self.event_loop.run_forever() self.assertEqual(caught, 1) + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + handler = self.event_loop.add_signal_handler(signal.SIGALRM, + my_handler, + *some_args) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + def test_create_connection(self): # TODO: This depends on xkcd.com behavior! f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) diff --git a/tulip/base_events.py b/tulip/base_events.py index 0bd47ed9..09bdd617 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -409,8 +409,4 @@ def _run_once(self, timeout=None): for i in range(ntodo): handler = self._ready.popleft() if not handler.cancelled: - try: - handler.callback(*handler.args) - except Exception: - logging.exception('Exception in callback %s %r', - handler.callback, handler.args) + handler.run() diff --git a/tulip/events.py b/tulip/events.py index 86f8c508..6b2887b1 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -10,6 +10,7 @@ 'get_event_loop', 'set_event_loop', 'new_event_loop', ] +import logging import sys import threading @@ -43,6 +44,13 @@ def cancelled(self): def cancel(self): self._cancelled = True + def run(self): + try: + self._callback(*self._args) + except Exception: + logging.exception('Exception in callback %s %r', + self._callback, self._args) + def make_handler(callback, args): if isinstance(callback, Handler): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 760db1bf..db2c560d 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -69,7 +69,7 @@ def _handle_signal(self, sig, arg): if handler.cancelled: self.remove_signal_handler(sig) # Remove it properly. else: - self.call_soon_threadsafe(handler.callback, *handler.args) + self.call_soon_threadsafe(handler) def remove_signal_handler(self, sig): """Remove a handler for a signal. UNIX only. From bca244f1642f1364c70b6e927e68fe4b43ab65d1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 8 Mar 2013 15:12:49 -0800 Subject: [PATCH 0339/1502] better submodule support, more verbose on skipping modules --- runtests.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/runtests.py b/runtests.py index 6fa888b3..89204e88 100644 --- a/runtests.py +++ b/runtests.py @@ -61,35 +61,35 @@ def load_modules(basedir, suffix='.py'): def list_dir(prefix, dir): files = [] + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + for name in os.listdir(dir): path = os.path.join(dir, name) if os.path.isdir(path): files.extend(list_dir('%s%s.' % (prefix, name), path)) else: - if name.endswith(suffix) and not name.startswith(('.', '_')): - files.append(('%s%s'%(prefix, name[:-3]), path)) + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) return files - files = [] - modpath = os.path.join(basedir, '__init__.py') - if os.path.isfile(modpath): - mod = os.path.split(basedir)[-1] - prefix = '%s.' % mod - files.append((mod, modpath)) - else: - prefix = '' - mods = [] - for modname, sourcefile in files + list_dir(prefix, basedir): + for modname, sourcefile in list_dir('', basedir): if modname == 'runtests': continue try: loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) - except: - pass + except Exception as err: + print("Skipping '%s': %s" % (modname, err)) return mods @@ -125,7 +125,7 @@ def runtests(): testsdir = os.path.abspath(args.testsdir) if not os.path.isdir(testsdir): - print("Tests directory is not found: %s\n"%testsdir) + print("Tests directory is not found: %s\n" % testsdir) ARGS.print_help() return @@ -173,7 +173,7 @@ def runcoverage(sdir, args): sdir = os.path.abspath(sdir) if not os.path.isdir(sdir): - print("Python files directory is not found: %s\n"%sdir) + print("Python files directory is not found: %s\n" % sdir) ARGS.print_help() return From ccdd90bbad0f951c06c5f80be93b0dbbbc847122 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 8 Mar 2013 22:30:01 -0800 Subject: [PATCH 0340/1502] proactor socket transport optimization --- tests/events_test.py | 2 -- tulip/proactor_events.py | 26 ++++++++++++++++---------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 4c3056a9..807aa501 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -513,7 +513,6 @@ def factory(): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') - self.event_loop.run_once(0.01) # for windows proactor selector client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') @@ -524,7 +523,6 @@ def factory(): self.assertEqual('CONNECTED', proto.state) self.assertEqual(0, proto.nbytes) self.event_loop.run_once() - self.event_loop.run_once(0.1) # for windows proactor selector self.assertEqual(3, proto.nbytes) # extra info is available diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 0391eb49..45c075e3 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -6,10 +6,8 @@ import logging - from . import base_events from . import transports -from . import winsocketpair class _ProactorSocketTransport(transports.Transport): @@ -29,16 +27,18 @@ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): if waiter is not None: self._event_loop.call_soon(waiter.set_result, None) - def _loop_reading(self, f=None): + def _loop_reading(self, fut=None): + data = None + try: - assert f is self._read_fut - if f: - data = f.result() + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause if not data: - self._event_loop.call_soon(self._protocol.eof_received) self._read_fut = None return - self._event_loop.call_soon(self._protocol.data_received, data) + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) except ConnectionAbortedError as exc: if not self._closing: @@ -47,6 +47,11 @@ def _loop_reading(self, f=None): self._fatal_error(exc) else: self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() def write(self, data): assert isinstance(data, bytes) @@ -149,6 +154,7 @@ def _make_self_pipe(self): self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 + def loop(f=None): try: if f: @@ -170,10 +176,10 @@ def loop(f=None): if f: conn, addr = f.result() protocol = protocol_factory() - transport = self._make_socket_transport( + self._make_socket_transport( conn, protocol, extra={'addr': addr}) f = self._proactor.accept(sock) - except OSError as exc: + except OSError: sock.close() logging.exception('Accept failed') else: From 80e928d685bef38c7dbaed3222f500f09793784f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 11 Mar 2013 10:40:20 -0700 Subject: [PATCH 0341/1502] Clean deeper. --- Makefile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index db06318b..2391d89c 100644 --- a/Makefile +++ b/Makefile @@ -18,12 +18,12 @@ check: $(PYTHON) check.py clean: - rm -rf __pycache__ */__pycache__ - rm -f *.py[co] */*.py[co] - rm -f *~ */*~ - rm -f .*~ */.*~ - rm -f @* */@* - rm -f '#'*'#' */'#'*'#' - rm -f *.orig */*.orig + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` rm -f .coverage rm -rf htmlcov From 8acaa423659ce547dca56707c0c0eea79277c053 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 11 Mar 2013 15:54:55 -0700 Subject: [PATCH 0342/1502] helper coroutines for http protocol --- srv.py | 46 ++++++------ tests/http_protocol_test.py | 142 ++++++++++++++++++++++++++++++++++++ tulip/http/__init__.py | 5 +- tulip/http/http_client.py | 27 +++---- tulip/http/protocol.py | 113 ++++++++++++++++++++++++++++ 5 files changed, 292 insertions(+), 41 deletions(-) create mode 100644 tests/http_protocol_test.py create mode 100644 tulip/http/protocol.py diff --git a/srv.py b/srv.py index 545ba4de..0a7c87a4 100644 --- a/srv.py +++ b/srv.py @@ -1,9 +1,9 @@ """Simple server written using an event loop.""" +import http.client import email.message import email.parser import os -import re import tulip import tulip.http @@ -19,36 +19,31 @@ def __init__(self): @tulip.task def handle_request(self): - line = yield from self.reader.readline() - print('request line', line) - match = re.match(rb'([A-Z]+) (\S+) HTTP/(1.\d)\r?\n\Z', line) - if not match: + try: + method, path, version = yield from self.reader.read_request_line() + except http.client.BadStatusLine: self.transport.close() return - bmethod, bpath, bversion = match.groups() + print('method = {!r}; path = {!r}; version = {!r}'.format( - bmethod, bpath, bversion)) - try: - path = bpath.decode('ascii') - except UnicodeError as exc: - print('not ascii', repr(bpath), exc) + method, path, version)) + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) path = None else: - if (not (path.isprintable() and path.startswith('/')) or - '/.' in path): - print('bad path', repr(path)) + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) path = None else: - path = '.' + path - if not os.path.exists(path): - print('no file', repr(path)) - path = None - else: - isdir = os.path.isdir(path) + isdir = os.path.isdir(path) + if not path: self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') self.transport.close() return + lines = [] while True: line = yield from self.reader.readline() @@ -58,10 +53,13 @@ def handle_request(self): lines.append(line) if line == b'\r\n': break + parser = email.parser.BytesHeaderParser() - headers = parser.parsebytes(b''.join(lines)) + parser.parsebytes(b''.join(lines)) + write = self.transport.write if isdir and not path.endswith('/'): + bpath = path.encode('ascii') write(b'HTTP/1.0 302 Redirected\r\n' b'URI: ' + bpath + b'/\r\n' b'Location: ' + bpath + b'/\r\n' @@ -79,7 +77,7 @@ def handle_request(self): if name.isprintable() and not name.startswith('.'): try: bname = name.encode('ascii') - except UnicodeError as exc: + except UnicodeError: pass else: if os.path.isdir(os.path.join(path, name)): @@ -93,14 +91,14 @@ def handle_request(self): try: with open(path, 'rb') as f: write(f.read()) - except OSError as exc: + except OSError: write(b'Cannot open\r\n') self.transport.close() def connection_made(self, transport): self.transport = transport print('connection made', transport, transport.get_extra_info('socket')) - self.reader = tulip.http.StreamReader() + self.reader = tulip.http.HttpStreamReader() self.handler = self.handle_request() def data_received(self, data): diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..61d7ebc9 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,142 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line())) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py index cb07d9ff..16f4383b 100644 --- a/tulip/http/__init__.py +++ b/tulip/http/__init__.py @@ -1,5 +1,8 @@ # This relies on each of the submodules having an __all__ variable. + +from .protocol import * from .http_client import * -__all__ = http_client.__all__ +__all__ = (protocol.__all__ + + http_client.__all__) diff --git a/tulip/http/http_client.py b/tulip/http/http_client.py index 71f97c5d..520d162d 100644 --- a/tulip/http/http_client.py +++ b/tulip/http/http_client.py @@ -26,21 +26,21 @@ __all__ = ['HttpClientProtocol'] -import collections import email.message import email.parser -import re import tulip +from . import protocol + class HttpClientProtocol: """This Protocol class is also used to initiate the connection. Usage: p = HttpClientProtocol(url, ...) - f = p.connect() # Returns a Future - ...now what?... + sts, headers, stream = yield from p.connect() + """ def __init__(self, host, port=None, *, @@ -107,19 +107,14 @@ def validate(self, value, name, embedded_spaces_okay=False): @tulip.coroutine def connect(self): - yield from self.event_loop.create_connection(lambda: self, - self.host, - self.port, - ssl=self.ssl) + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + # TODO: A better mechanism to return all info from the # status line, all headers, and the buffer, without having # an N-tuple return value. - status_line = yield from self.stream.readline() - m = re.match(rb'HTTP/(\d\.\d)\s+(\d\d\d)\s+([^\r\n]+)\r?\n\Z', - status_line) - if not m: - raise 'Invalid HTTP status line ({!r})'.format(status_line) - version, status, message = m.groups() + version, status, message = yield from self.stream.read_response_status() + raw_headers = [] while True: header = yield from self.stream.readline() @@ -137,7 +132,7 @@ def connect(self): # TODO: A wrapping stream that limits how much it can read # without reading it all into memory at once. body = yield from self.stream.readexactly(content_length) - stream = StreamReader() + stream = protocol.HttpStreamReader() stream.feed_data(body) stream.feed_eof() sts = '{} {}'.format(self.decode(status), self.decode(message)) @@ -176,7 +171,7 @@ def connection_made(self, transport): for key, value in self.headers.items(): self.write_str('{}: {}\r\n'.format(key, value)) self.transport.write(b'\r\n') - self.stream = StreamReader() + self.stream = protocol.HttpStreamReader() if self.make_body is not None: if self.chunked: self.make_body(self.write_chunked, self.write_chunked_eof) diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..c6f5a60b --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,113 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', 'RequestLine', 'ResponseStatus'] + +import collections +import http.client +import re + +import tulip + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') + + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +class HttpStreamReader(tulip.StreamReader): + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise http.client.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise http.client.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise http.client.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise http.client.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise http.client.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise http.client.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) From 2bd936c2ae67c3aa31c9d8e3809afe5d5bd2943e Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 11 Mar 2013 21:38:01 -0700 Subject: [PATCH 0343/1502] read_headers helper for http protocol --- srv.py | 16 +++------ tests/http_protocol_test.py | 61 ++++++++++++++++++++++++++++++++++ tulip/http/http_client.py | 24 ++++++-------- tulip/http/protocol.py | 65 +++++++++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 26 deletions(-) diff --git a/srv.py b/srv.py index 0a7c87a4..296e157c 100644 --- a/srv.py +++ b/srv.py @@ -44,18 +44,10 @@ def handle_request(self): self.transport.close() return - lines = [] - while True: - line = yield from self.reader.readline() - print('header line', line) - if not line.strip(b' \t\r\n'): - break - lines.append(line) - if line == b'\r\n': - break - - parser = email.parser.BytesHeaderParser() - parser.parsebytes(b''.join(lines)) + headers = email.message.Message() + for hdr, val in (yield from self.reader.read_headers()): + print(hdr, val) + headers.add_header(hdr, val) write = self.transport.write if isdir and not path.endswith('/'): diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 61d7ebc9..1f3aa659 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -140,3 +140,64 @@ def test_response_status_bad_code_not_int(self): tulip.Task(self.stream.read_response_status())) self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + tulip.Task(self.stream.read_headers())) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) diff --git a/tulip/http/http_client.py b/tulip/http/http_client.py index 520d162d..fc9365e1 100644 --- a/tulip/http/http_client.py +++ b/tulip/http/http_client.py @@ -110,19 +110,15 @@ def connect(self): yield from self.event_loop.create_connection( lambda: self, self.host, self.port, ssl=self.ssl) - # TODO: A better mechanism to return all info from the - # status line, all headers, and the buffer, without having - # an N-tuple return value. - version, status, message = yield from self.stream.read_response_status() - - raw_headers = [] - while True: - header = yield from self.stream.readline() - if not header.strip(): - break - raw_headers.append(header) - parser = email.parser.BytesHeaderParser() - headers = parser.parsebytes(b''.join(raw_headers)) + # read response status + version, status, reason = yield from self.stream.read_response_status() + + # read headers + headers = email.message.Message() + for hdr, val in (yield from self.stream.read_headers()): + headers.add_header(hdr, val) + + # read payload content_length = headers.get('content-length') if content_length: content_length = int(content_length) # May raise. @@ -135,7 +131,7 @@ def connect(self): stream = protocol.HttpStreamReader() stream.feed_data(body) stream.feed_eof() - sts = '{} {}'.format(self.decode(status), self.decode(message)) + sts = '{} {}'.format(status, reason) return (sts, headers, stream) def encode(self, s): diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index c6f5a60b..dce35a2b 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -10,6 +10,8 @@ METHRE = re.compile('[A-Z0-9$-_.]+') VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') RequestLine = collections.namedtuple( @@ -22,6 +24,9 @@ class HttpStreamReader(tulip.StreamReader): + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + @tulip.coroutine def read_request_line(self): """Read request status line. Exception http.client.BadStatusLine @@ -111,3 +116,63 @@ def read_response_status(self): raise http.client.BadStatusLine(line) return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise http.client.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers From 879d8ca9b8ab2596254efc26e90e3ac622f24a70 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Mar 2013 13:12:31 -0700 Subject: [PATCH 0344/1502] http message payload related helpers --- crawl.py | 5 +- curl.py | 2 +- tests/http_protocol_test.py | 146 ++++++++++++++ tests/streams_test.py | 66 ++++++- tulip/http/__init__.py | 6 +- tulip/http/{http_client.py => client.py} | 54 ++++-- tulip/http/protocol.py | 231 ++++++++++++++++++++++- tulip/streams.py | 19 ++ 8 files changed, 506 insertions(+), 23 deletions(-) rename tulip/http/{http_client.py => client.py} (80%) diff --git a/crawl.py b/crawl.py index 3dd270b8..4e5bebe2 100755 --- a/crawl.py +++ b/crawl.py @@ -71,8 +71,8 @@ def process(self, url): path = '/' if query: path = '?'.join([path, query]) - p = tulip.http.HttpClientProtocol(netloc, path=path, - ssl=(scheme=='https')) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) delay = 1 while True: try: @@ -85,6 +85,7 @@ def process(self, url): 'retrying after sleep', delay, '...', end=END) yield from tulip.sleep(delay) delay *= 2 + if status[:3] in ('301', '302'): # Redirect. u = headers.get('location') or headers.get('uri') diff --git a/curl.py b/curl.py index b0566b67..0624df86 100755 --- a/curl.py +++ b/curl.py @@ -22,7 +22,7 @@ def main(): for k, v in headers.items(): print('{}: {}'.format(k, v)) print() - data = p.event_loop.run_until_complete(tulip.Task(stream.read(1000000))) + data = p.event_loop.run_until_complete(tulip.Task(stream)) print(data.decode('utf-8', 'replace')) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 1f3aa659..7fee4ff7 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -3,6 +3,7 @@ import http.client import unittest import unittest.mock +import zlib import tulip from tulip.http import protocol @@ -201,3 +202,148 @@ def test_read_headers_continuation_headers_size(self): tulip.Task(self.stream.read_headers())) self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_payload_unknown_encoding(self): + self.assertRaises( + ValueError, self.stream.read_length_payload, encoding='unknown') + + def test_read_payload(self): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + stream = self.stream.read_length_payload(4) + self.assertIsInstance(stream, tulip.StreamReader) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_payload_eof(self): + self.stream.feed_data(b'da') + self.stream.feed_eof() + stream = self.stream.read_length_payload(4) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_payload_eof_exc(self): + self.stream.feed_data(b'da') + stream = self.stream.read_length_payload(4) + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._reader) + + def test_read_payload_deflate(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + stream = self.stream.read_length_payload(len(data), encoding='deflate') + + self.stream.feed_data(data) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + + def _test_read_payload_compress_error(self): + data = b'123123123datadatadata' + reader = protocol.length_reader(4) + self.stream.feed_data(data) + stream = self.stream.read_payload(reader, 'deflate') + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload(self): + stream = self.stream.read_chunked_payload() + self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_chunks(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_incomplete(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload_extension(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_size_error(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'blah\r\n') + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_length_payload(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'datadata', data) + + def test_read_length_payload_zero(self): + stream = self.stream.read_length_payload(0) + + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'', data) + + def test_read_length_payload_incomplete(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_eof_payload(self): + stream = self.stream.read_eof_payload() + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) diff --git a/tests/streams_test.py b/tests/streams_test.py index 118f6a05..0772fcde 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -104,6 +104,18 @@ def cb(): self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(tasks.Task(stream.read(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.read(2))) + def test_readline(self): """Read one line.""" stream = streams.StreamReader() @@ -198,7 +210,7 @@ def test_readline_read_byte_count(self): stream.feed_data(self.DATA) read_task = tasks.Task(stream.readline()) - line = self.event_loop.run_until_complete(read_task) + self.event_loop.run_until_complete(read_task) read_task = tasks.Task(stream.read(7)) data = self.event_loop.run_until_complete(read_task) @@ -210,6 +222,19 @@ def test_readline_read_byte_count(self): len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readline())) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.readline())) + def test_readexactly_zero_or_less(self): """Read exact number of bytes (zero or less).""" stream = streams.StreamReader() @@ -261,6 +286,45 @@ def cb(): self.assertFalse(stream.byte_count) self.assertFalse(stream.line_count) + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readexactly(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + tasks.Task(stream.readexactly(2))) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.Task(tasks.wait([t1, t2]))) + + self.assertRaises(ValueError, t1.result) + if __name__ == '__main__': unittest.main() diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py index 16f4383b..d436383f 100644 --- a/tulip/http/__init__.py +++ b/tulip/http/__init__.py @@ -1,8 +1,8 @@ # This relies on each of the submodules having an __all__ variable. +from .client import * from .protocol import * -from .http_client import * -__all__ = (protocol.__all__ + - http_client.__all__) +__all__ = (client.__all__ + + protocol.__all__) diff --git a/tulip/http/http_client.py b/tulip/http/client.py similarity index 80% rename from tulip/http/http_client.py rename to tulip/http/client.py index fc9365e1..cf8e52cc 100644 --- a/tulip/http/http_client.py +++ b/tulip/http/client.py @@ -63,6 +63,7 @@ def __init__(self, host, port=None, *, self.path = self.validate(path, 'path') self.method = self.validate(method, 'method') self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' if headers: for key, value in headers.items(): self.validate(key, 'header key') @@ -114,25 +115,48 @@ def connect(self): version, status, reason = yield from self.stream.read_response_status() # read headers - headers = email.message.Message() - for hdr, val in (yield from self.stream.read_headers()): - headers.add_header(hdr, val) + headers = yield from self.stream.read_headers() + msg_headers = email.message.Message() + for hdr, val in headers: + msg_headers.add_header(hdr, val) + + # TODO: A wrapping stream that limits how much it can read + # without reading it all into memory at once. # read payload - content_length = headers.get('content-length') - if content_length: - content_length = int(content_length) # May raise. - if content_length is None: - stream = self.stream + chunked = False + length = None + encoding = None + + for (name, value) in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + # payload + if chunked: + payload = self.stream.read_chunked_payload(encoding=encoding) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise ValueError('CONTENT-LENGTH') + + if length < 0: + raise ValueError('CONTENT-LENGTH') + + payload = self.stream.read_length_payload(length, encoding=encoding) else: - # TODO: A wrapping stream that limits how much it can read - # without reading it all into memory at once. - body = yield from self.stream.readexactly(content_length) - stream = protocol.HttpStreamReader() - stream.feed_data(body) - stream.feed_eof() + payload = self.stream.read_length_payload(0, encoding=encoding) + sts = '{} {}'.format(status, reason) - return (sts, headers, stream) + return (sts, msg_headers, payload) def encode(self, s): if isinstance(s, bytes): diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index dce35a2b..ba6bd00b 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -3,8 +3,10 @@ __all__ = ['HttpStreamReader', 'RequestLine', 'ResponseStatus'] import collections +import functools import http.client import re +import zlib import tulip @@ -22,11 +24,141 @@ 'ResponseStatus', ['version', 'code', 'reason']) +class StreamEofException(http.client.HTTPException): + """eof received""" + + +def wrap_payload_reader(f): + """wrap_payload_reader wraps payload readers and redirect stream. + payload readers are generator functions, read_chunked_payload, + read_length_payload, read_eof_payload. + payload reader allows to modify data stream and feed data into stream. + + StreamReader instance should be send to generator as first parameter. + This steam is used as destination stream for processed data. + To send data to reader use generator's send() method. + + To indicate eof stream, throw StreamEofException exception into the reader. + In case of errors in incoming stream reader sets exception to + destination stream with StreamReader.set_exception() method. + + Before exit, reader generator returns unprocessed data. + """ + + @functools.wraps(f) + def wrapper(self, *args, **kw): + assert self._reader is None + + rstream = stream = tulip.StreamReader() + + encoding = kw.pop('encoding', None) + if encoding is not None: + if encoding not in ('gzip', 'deflate'): + raise ValueError( + 'Content-Encoding %r is not supported' % encoding) + + stream = DeflateStream(stream, encoding) + + reader = f(self, *args, **kw) + next(reader) + try: + reader.send(stream) + except StopIteration: + pass + else: + # feed buffer + self.line_count = 0 + self.byte_count = 0 + while self.buffer: + try: + reader.send(self.buffer.popleft()) + except StopIteration as exc: + buf = b''.join(self.buffer) + self.buffer.clear() + reader = None + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + if reader is not None: + if self.eof: + try: + reader.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._reader = reader + + return rstream + + return wrapper + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified steram.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + self.stream.feed_eof() + + class HttpStreamReader(tulip.StreamReader): MAX_HEADERS = 32768 MAX_HEADERFIELD_SIZE = 8190 + # if _reader is set, feed_data and feed_eof sends data into + # _reader instead of self. is it being used as stream redirection for + # read_chunked_payload, read_length_payload and read_eof_payload + _reader = None + + def feed_data(self, data): + """_reader is a generator, if _reader is set, feed_data sends + incoming data into this generator untile generates stops.""" + if self._reader: + try: + self._reader.send(data) + except StopIteration as exc: + self._reader = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_reader is a generator, if _reader is set feed_eof throws + StreamEofException into this generator.""" + if self._reader: + try: + self._reader.throw(StreamEofException()) + except StopIteration: + self._reader = None + + super().feed_eof() + @tulip.coroutine def read_request_line(self): """Read request status line. Exception http.client.BadStatusLine @@ -122,7 +254,7 @@ def read_headers(self): """Read and parses RFC2822 headers from a stream. Line continuations are supported. Returns list of header name - and value pairs. + and value pairs. Header name is in upper case. """ size = 0 headers = [] @@ -176,3 +308,100 @@ def read_headers(self): b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) return headers + + @wrap_payload_reader + def read_chunked_payload(self): + """Read chunked stream.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b";") + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise http.client.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + except http.client.IncompleteRead as exc: + stream.set_exception(exc) + + @wrap_payload_reader + def read_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + + @wrap_payload_reader + def read_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() diff --git a/tulip/streams.py b/tulip/streams.py index b10aa55c..d68c8d6d 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -17,6 +17,16 @@ def __init__(self, limit=2**16): self.line_count = 0 # Number of complete lines in buffer. self.eof = False # Whether we're done. self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) def feed_eof(self): self.eof = True @@ -40,6 +50,9 @@ def feed_data(self, data): @tasks.coroutine def readline(self): + if self._exception is not None: + raise self._exception + parts = [] parts_size = 0 not_enough = True @@ -80,6 +93,9 @@ def readline(self): @tasks.coroutine def read(self, n=-1): + if self._exception is not None: + raise self._exception + if not n: return b'' @@ -121,6 +137,9 @@ def read(self, n=-1): @tasks.coroutine def readexactly(self, n): + if self._exception is not None: + raise self._exception + if n <= 0: return b'' From 320ae0d60b8bb30697f6d6852846c916d595d722 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Mar 2013 14:29:27 -0700 Subject: [PATCH 0345/1502] StreamReader.line_count is not required --- tests/streams_test.py | 17 +---------------- tulip/streams.py | 6 ------ 2 files changed, 1 insertion(+), 22 deletions(-) diff --git a/tests/streams_test.py b/tests/streams_test.py index 0772fcde..832ce371 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -25,14 +25,12 @@ def test_feed_empty_data(self): stream = streams.StreamReader() stream.feed_data(b'') - self.assertEqual(0, stream.line_count) self.assertEqual(0, stream.byte_count) - def test_feed_data_line_byte_count(self): + def test_feed_data_byte_count(self): stream = streams.StreamReader() stream.feed_data(self.DATA) - self.assertEqual(self.DATA.count(b'\n'), stream.line_count) self.assertEqual(len(self.DATA), stream.byte_count) def test_read_zero(self): @@ -44,7 +42,6 @@ def test_read_zero(self): data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) - self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_read(self): """Read bytes.""" @@ -58,7 +55,6 @@ def cb(): data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) - self.assertFalse(stream.line_count) def test_read_line_breaks(self): """Read bytes without line breaks.""" @@ -71,7 +67,6 @@ def test_read_line_breaks(self): self.assertEqual(b'line1', data) self.assertEqual(5, stream.byte_count) - self.assertFalse(stream.line_count) def test_read_eof(self): """Read bytes, stop at eof.""" @@ -85,7 +80,6 @@ def cb(): data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertFalse(stream.byte_count) - self.assertFalse(stream.line_count) def test_read_until_eof(self): """Read all bytes until eof.""" @@ -102,7 +96,6 @@ def cb(): self.assertEqual(b'chunk1\nchunk2', data) self.assertFalse(stream.byte_count) - self.assertFalse(stream.line_count) def test_read_exception(self): stream = streams.StreamReader() @@ -130,7 +123,6 @@ def cb(): line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) - self.assertFalse(stream.line_count) self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) def test_readline_limit_with_existing_data(self): @@ -183,7 +175,6 @@ def test_readline_line_byte_count(self): line = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line1\n', line) - self.assertEqual(self.DATA.count(b'\n')-1, stream.line_count) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) def test_readline_eof(self): @@ -216,8 +207,6 @@ def test_readline_read_byte_count(self): data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'line2\nl', data) - self.assertEqual( - 1, stream.line_count) self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), stream.byte_count) @@ -244,13 +233,11 @@ def test_readexactly_zero_or_less(self): data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) - self.assertEqual(self.DATA.count(b'\n'), stream.line_count) read_task = tasks.Task(stream.readexactly(-1)) data = self.event_loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) - self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_readexactly(self): """Read exact number of bytes.""" @@ -268,7 +255,6 @@ def cb(): data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) self.assertEqual(len(self.DATA), stream.byte_count) - self.assertEqual(self.DATA.count(b'\n'), stream.line_count) def test_readexactly_eof(self): """Read exact number of bytes (eof).""" @@ -284,7 +270,6 @@ def cb(): data = self.event_loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) - self.assertFalse(stream.line_count) def test_readexactly_exception(self): stream = streams.StreamReader() diff --git a/tulip/streams.py b/tulip/streams.py index d68c8d6d..8d7f6236 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -14,7 +14,6 @@ def __init__(self, limit=2**16): self.limit = limit # Max line length. (Security feature.) self.buffer = collections.deque() # Deque of bytes objects. self.byte_count = 0 # Bytes in buffer. - self.line_count = 0 # Number of complete lines in buffer. self.eof = False # Whether we're done. self.waiter = None # A future. self._exception = None @@ -40,7 +39,6 @@ def feed_data(self, data): return self.buffer.append(data) - self.line_count += data.count(b'\n') self.byte_count += len(data) waiter = self.waiter @@ -69,7 +67,6 @@ def readline(self): head, tail = data[:ichar], data[ichar:] if tail: self.buffer.appendleft(tail) - self.line_count -= 1 not_enough = False parts.append(head) parts_size += len(head) @@ -114,7 +111,6 @@ def read(self, n=-1): data = b''.join(self.buffer) self.buffer.clear() self.byte_count = 0 - self.line_count = 0 return data parts = [] @@ -130,8 +126,6 @@ def read(self, n=-1): parts.append(data) parts_bytes += data_bytes self.byte_count -= data_bytes - if self.line_count: - self.line_count -= data.count(b'\n') return b''.join(parts) From 5f86502ee98fcc4fdbeeb7ffeb00527c31924ec3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Mar 2013 14:35:06 -0700 Subject: [PATCH 0346/1502] http message helper --- srv.py | 4 +- tests/http_protocol_test.py | 142 ++++++++++++++++++++++++++++++++++++ tulip/http/client.py | 48 ++---------- tulip/http/protocol.py | 68 ++++++++++++++++- 4 files changed, 219 insertions(+), 43 deletions(-) diff --git a/srv.py b/srv.py index 296e157c..540e63b9 100644 --- a/srv.py +++ b/srv.py @@ -44,8 +44,10 @@ def handle_request(self): self.transport.close() return + message = yield from self.reader.read_message() + headers = email.message.Message() - for hdr, val in (yield from self.reader.read_headers()): + for hdr, val in message.headers: print(hdr, val) headers.add_header(hdr, val) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 7fee4ff7..98e9d6f2 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -347,3 +347,145 @@ def test_read_eof_payload(self): data = self.loop.run_until_complete(tulip.Task(stream.read())) self.assertEqual(b'data', data) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 1)))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 0)))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=False))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) diff --git a/tulip/http/client.py b/tulip/http/client.py index cf8e52cc..7a494447 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -114,49 +114,15 @@ def connect(self): # read response status version, status, reason = yield from self.stream.read_response_status() - # read headers - headers = yield from self.stream.read_headers() - msg_headers = email.message.Message() - for hdr, val in headers: - msg_headers.add_header(hdr, val) - - # TODO: A wrapping stream that limits how much it can read - # without reading it all into memory at once. - - # read payload - chunked = False - length = None - encoding = None - - for (name, value) in headers: - if name == 'CONTENT-LENGTH': - length = value - elif name == 'TRANSFER-ENCODING': - chunked = value.lower() == 'chunked' - elif name == 'CONTENT-ENCODING': - enc = value.lower() - if enc in ('gzip', 'deflate'): - encoding = enc - - # payload - if chunked: - payload = self.stream.read_chunked_payload(encoding=encoding) - - elif length is not None: - try: - length = int(length) - except ValueError: - raise ValueError('CONTENT-LENGTH') - - if length < 0: - raise ValueError('CONTENT-LENGTH') - - payload = self.stream.read_length_payload(length, encoding=encoding) - else: - payload = self.stream.read_length_payload(0, encoding=encoding) + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) sts = '{} {}'.format(status, reason) - return (sts, msg_headers, payload) + return (sts, headers, message.payload) def encode(self, s): if isinstance(s, bytes): diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index ba6bd00b..b4454f5f 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -1,6 +1,6 @@ """Http related helper utils.""" -__all__ = ['HttpStreamReader', 'RequestLine', 'ResponseStatus'] +__all__ = ['HttpStreamReader', 'HttpMessage', 'RequestLine', 'ResponseStatus'] import collections import functools @@ -24,6 +24,10 @@ 'ResponseStatus', ['version', 'code', 'reason']) +HttpMessage = collections.namedtuple( + 'HttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + class StreamEofException(http.client.HTTPException): """eof received""" @@ -405,3 +409,65 @@ def read_eof_payload(self): stream.feed_data((yield)) except StreamEofException: stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == "CONNECTION": + v = value.lower() + if v == "close": + close_conn = True + elif v == "keep-alive": + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload stream + if chunked: + payload = self.read_chunked_payload(encoding=encoding) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise http.client.HTTPException('CONTENT-LENGTH') from None + + if length < 0: + raise http.client.HTTPException('CONTENT-LENGTH') + + payload = self.read_length_payload(length, encoding=encoding) + else: + if readall: + payload = self.read_eof_payload(encoding=encoding) + else: + payload = self.read_length_payload(0, encoding=encoding) + + return HttpMessage(headers, payload, close_conn, encoding) From 9b74d3c09ae6123d88b15741ed9367a8d3ab2b83 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Mar 2013 15:07:26 -0700 Subject: [PATCH 0347/1502] basic udp support --- examples/udp_echo.py | 71 +++++++ runtests.py | 11 +- tests/base_events_test.py | 40 ++-- tests/events_test.py | 323 ++++++++++++++++++++++++++++++-- tests/selector_events_test.py | 337 ++++++++++++++++++++++++++++++++-- tulip/base_events.py | 90 ++++++++- tulip/events.py | 13 +- tulip/protocols.py | 18 +- tulip/selector_events.py | 122 +++++++++++- 9 files changed, 972 insertions(+), 53 deletions(-) create mode 100644 examples/udp_echo.py diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..b76aa07a --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,71 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('localhost', 10000) + + +class MyUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + loop.start_serving_datagram(MyUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + loop.create_datagram_connection(MyClientUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/runtests.py b/runtests.py index 89204e88..e4678d70 100644 --- a/runtests.py +++ b/runtests.py @@ -180,9 +180,14 @@ def runcoverage(sdir, args): mods = [source for _, source in load_modules(sdir)] coverage = [sys.executable, '-m', 'coverage'] - subprocess.check_call(coverage + ['run', '--branch', 'runtests.py'] + args) - subprocess.check_call(coverage + ['html'] + mods) - subprocess.check_call(coverage + ['report'] + mods) + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) if __name__ == '__main__': diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 34f04f6d..f5bfc363 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -31,6 +31,9 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) self.assertRaises( NotImplementedError, self.event_loop._process_events, []) self.assertRaises( @@ -86,7 +89,8 @@ def test_getnameinfo(self): self.event_loop.run_in_executor.call_args[0]) def test_call_soon(self): - def cb(): pass + def cb(): + pass h = self.event_loop.call_soon(cb) self.assertEqual(h._callback, cb) @@ -94,7 +98,8 @@ def cb(): pass self.assertIn(h, self.event_loop._ready) def test_call_later(self): - def cb(): pass + def cb(): + pass h = self.event_loop.call_later(10.0, cb) self.assertIsInstance(h, events.Timer) @@ -102,14 +107,16 @@ def cb(): pass self.assertNotIn(h, self.event_loop._ready) def test_call_later_no_delay(self): - def cb(): pass + def cb(): + pass h = self.event_loop.call_later(0, cb) self.assertIn(h, self.event_loop._ready) self.assertNotIn(h, self.event_loop._scheduled) def test_run_once_in_executor_handler(self): - def cb(): pass + def cb(): + pass self.assertRaises( AssertionError, self.event_loop.run_in_executor, @@ -119,7 +126,8 @@ def cb(): pass None, events.Timer(10, cb, ())) def test_run_once_in_executor_canceled(self): - def cb(): pass + def cb(): + pass h = events.Handler(cb, ()) h.cancel() @@ -128,7 +136,8 @@ def cb(): pass self.assertTrue(f.done()) def test_run_once_in_executor(self): - def cb(): pass + def cb(): + pass h = events.Handler(cb, ()) f = futures.Future() executor = unittest.mock.Mock() @@ -152,8 +161,8 @@ def test_run_once(self): self.assertTrue(self.event_loop._run_once.called) def test__run_once(self): - h1 = events.Timer(time.monotonic() + 0.1, lambda:True, ()) - h2 = events.Timer(time.monotonic() + 10.0, lambda:True, ()) + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) h1.cancel() @@ -177,7 +186,7 @@ def test__run_once_timeout(self): def test__run_once_timeout_with_ready(self): """If event loop has ready callbacks, select timeout is always 0.""" - h = events.Timer(time.monotonic() + 10.0, lambda:True, ()) + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._scheduled.append(h) @@ -192,6 +201,7 @@ def test__run_once_logging(self, m_logging, m_time): """Log to INFO level if timeout > 1.0 sec.""" idx = -1 data = [10.0, 10.0, 12.0, 13.0] + def monotonic(): nonlocal data, idx idx += 1 @@ -201,7 +211,7 @@ def monotonic(): m_logging.INFO = logging.INFO m_logging.DEBUG = logging.DEBUG - self.event_loop._scheduled.append(events.Timer(11.0, lambda:True, ())) + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._run_once() self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) @@ -215,10 +225,11 @@ def monotonic(): def test__run_once_schedule_handler(self): handler = None processed = False + def cb(event_loop): nonlocal processed, handler processed = True - handler = event_loop.call_soon(lambda:True) + handler = event_loop.call_soon(lambda: True) h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) @@ -235,17 +246,20 @@ def test_create_connection_mutiple_errors(self, m_socket): class MyProto(protocols.Protocol): pass + def getaddrinfo(*args, **kw): yield from [] - return [(2,1,6,'',('107.6.106.82',80)), - (2,1,6,'',('107.6.106.82',80))] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] idx = -1 errors = ['err1', 'err2'] + def _socket(*args, **kw): nonlocal idx, errors idx += 1 raise socket.error(errors[idx]) + m_socket.socket = _socket m_socket.error = socket.error diff --git a/tests/events_test.py b/tests/events_test.py index 807aa501..9da8e105 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -4,7 +4,6 @@ import errno import gc import os -import select import signal import socket try: @@ -18,7 +17,6 @@ import unittest.mock from tulip import events -from tulip import futures from tulip import transports from tulip import protocols from tulip import selector_events @@ -51,6 +49,29 @@ def connection_lost(self, exc): self.state = 'CLOSED' +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + class EventLoopTestsMixin: def setUp(self): @@ -68,8 +89,10 @@ def test_run(self): def test_call_later(self): results = [] + def callback(arg): results.append(arg) + self.event_loop.call_later(0.1, callback, 'hello world') t0 = time.monotonic() self.event_loop.run() @@ -79,8 +102,10 @@ def callback(arg): def test_call_repeatedly(self): results = [] + def callback(arg): results.append(arg) + self.event_loop.call_repeatedly(0.03, callback, 'ho') self.event_loop.call_later(0.1, self.event_loop.stop) self.event_loop.run() @@ -88,16 +113,20 @@ def callback(arg): def test_call_soon(self): results = [] + def callback(arg1, arg2): results.append((arg1, arg2)) + self.event_loop.call_soon(callback, 'hello', 'world') self.event_loop.run() self.assertEqual(results, [('hello', 'world')]) def test_call_soon_with_handler(self): results = [] + def callback(): results.append('yeah') + handler = events.Handler(callback, ()) self.assertIs(self.event_loop.call_soon(handler), handler) self.event_loop.run() @@ -105,10 +134,13 @@ def callback(): def test_call_soon_threadsafe(self): results = [] + def callback(arg): results.append(arg) + def run(): self.event_loop.call_soon_threadsafe(callback, 'hello') + t = threading.Thread(target=run) self.event_loop.call_later(0.1, callback, 'world') t0 = time.monotonic() @@ -121,8 +153,10 @@ def run(): def test_call_soon_threadsafe_same_thread(self): results = [] + def callback(arg): results.append(arg) + self.event_loop.call_later(0.1, callback, 'world') self.event_loop.call_soon_threadsafe(callback, 'hello') self.event_loop.run() @@ -130,12 +164,15 @@ def callback(arg): def test_call_soon_threadsafe_with_handler(self): results = [] + def callback(arg): results.append(arg) handler = events.Handler(callback, ('hello',)) + def run(): - self.assertIs(self.event_loop.call_soon_threadsafe(handler),handler) + self.assertIs( + self.event_loop.call_soon_threadsafe(handler), handler) t = threading.Thread(target=run) self.event_loop.call_later(0.1, callback, 'world') @@ -178,6 +215,7 @@ def run(arg): def test_reader_callback(self): r, w = test_utils.socketpair() bytes_read = [] + def reader(): try: data = r.recv(1024) @@ -190,6 +228,7 @@ def reader(): else: self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() + self.event_loop.add_reader(r.fileno(), reader) self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') @@ -200,6 +239,7 @@ def reader(): def test_reader_callback_with_handler(self): r, w = test_utils.socketpair() bytes_read = [] + def reader(): try: data = r.recv(1024) @@ -225,6 +265,7 @@ def reader(): def test_reader_callback_cancel(self): r, w = test_utils.socketpair() bytes_read = [] + def reader(): try: data = r.recv(1024) @@ -236,6 +277,7 @@ def reader(): handler.cancel() if not data: r.close() + handler = self.event_loop.add_reader(r.fileno(), reader) self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') @@ -247,8 +289,10 @@ def test_writer_callback(self): r, w = test_utils.socketpair() w.setblocking(False) self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) self.event_loop.run() w.close() @@ -261,8 +305,10 @@ def test_writer_callback_with_handler(self): w.setblocking(False) handler = events.Handler(w.send, (b'x'*(256*1024),)) self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) + def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.event_loop.call_later(0.1, remove_writer) self.event_loop.run() w.close() @@ -273,9 +319,11 @@ def remove_writer(): def test_writer_callback_cancel(self): r, w = test_utils.socketpair() w.setblocking(False) + def sender(): w.send(b'x'*256) handler.cancel() + handler = self.event_loop.add_writer(w.fileno(), sender) self.event_loop.run() w.close() @@ -308,7 +356,6 @@ def test_sock_client_fail(self): sock.close() def test_sock_accept(self): - el = events.get_event_loop() listener = socket.socket() listener.setblocking(False) listener.bind(('127.0.0.1', 0)) @@ -328,6 +375,7 @@ def test_sock_accept(self): @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') def test_add_signal_handler(self): caught = 0 + def my_handler(): nonlocal caught caught += 1 @@ -372,6 +420,7 @@ def my_handler(): def test_cancel_signal_handler(self): # Cancelling the handler should remove it (eventually). caught = 0 + def my_handler(): nonlocal caught caught += 1 @@ -386,11 +435,14 @@ def my_handler(): def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. caught = 0 + def my_handler(): nonlocal caught caught += 1 - handler = self.event_loop.add_signal_handler(signal.SIGALRM, my_handler) + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. self.event_loop.call_later(0.15, self.event_loop.stop) self.event_loop.run_forever() @@ -400,14 +452,15 @@ def my_handler(): def test_signal_handling_args(self): some_args = (42,) caught = 0 + def my_handler(*args): nonlocal caught caught += 1 self.assertEqual(args, some_args) - handler = self.event_loop.add_signal_handler(signal.SIGALRM, - my_handler, - *some_args) + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. self.event_loop.call_later(0.15, self.event_loop.stop) self.event_loop.run_forever() @@ -426,7 +479,8 @@ def test_create_connection_sock(self): # TODO: This depends on xkcd.com behavior! sock = None infos = self.event_loop.run_until_complete( - self.event_loop.getaddrinfo('xkcd.com', 80,type=socket.SOCK_STREAM)) + self.event_loop.getaddrinfo( + 'xkcd.com', 80, type=socket.SOCK_STREAM)) for family, type, proto, cname, address in infos: try: sock = socket.socket(family=family, type=type, proto=proto) @@ -492,8 +546,8 @@ def test_create_connection_mutiple_errors(self): def getaddrinfo(*args, **kw): yield from [] - return [(2,1,6,'',('107.6.106.82',80)), - (2,1,6,'',('107.6.106.82',80))] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] self.event_loop.getaddrinfo = getaddrinfo self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error @@ -504,6 +558,7 @@ def getaddrinfo(*args, **kw): def test_start_serving(self): proto = None + def factory(): nonlocal proto proto = MyProto() @@ -564,7 +619,8 @@ def test_start_serving_sock(self): def test_start_serving_host_port_sock(self): self.suppress_log_errors() - fut = self.event_loop.start_serving(MyProto,'0.0.0.0',0,sock=object()) + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_start_serving_no_host_port_sock(self): @@ -589,7 +645,8 @@ class Err(socket.error): pass m_socket.error = socket.error - m_socket.getaddrinfo.return_value = [(2, 1, 6, '', ('127.0.0.1',10100))] + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.setsockopt.side_effect = Err @@ -597,6 +654,204 @@ class Err(socket.error): self.assertRaises(Err, self.event_loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) + def test_create_datagram_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection( + MyDatagramProto, host, port) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx') + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection(MyDatagramProto) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx', (host, port)) + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_datagram_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_connection_sockopt_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setsockopt.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_start_serving_datagram(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + proto = None + + def factory(): + nonlocal proto + proto = TestMyDatagramProto() + return proto + + f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + self.assertEqual('INITIALIZED', proto.state) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + client.sendto(b'xxx', ('127.0.0.1', port)) + self.event_loop.run_once() + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + data, server = client.recvfrom(4096) + self.assertEqual(b'resp:xxx', data) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '0.0.0.0', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + client.close() + + def test_start_serving_datagram_no_getaddrinfoc(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_datagram_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + def test_accept_connection_retry(self): class Err(socket.error): errno = errno.EAGAIN @@ -660,6 +915,18 @@ def test_accept_connection_retry(self): def test_accept_connection_exception(self): raise unittest.SkipTest( "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_connection()") + def test_start_serving_datagram(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_no_getaddrinfoc(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_cant_bind(self, m_socket): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_udp()") else: from tulip import selectors @@ -669,7 +936,8 @@ def test_accept_connection_exception(self): class KqueueEventLoopTests(EventLoopTestsMixin, test_utils.LogTrackingTestCase): def create_event_loop(self): - return unix_events.SelectorEventLoop(selectors.KqueueSelector()) + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(EventLoopTestsMixin, @@ -724,8 +992,8 @@ def callback(*args): h2 = events.make_handler(h1, ()) self.assertIs(h1, h2) - self.assertRaises(AssertionError, - events.make_handler, h1, (1,2,)) + self.assertRaises( + AssertionError, events.make_handler, h1, (1, 2)) class TimerTests(unittest.TestCase): @@ -827,6 +1095,11 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.create_connection, f) self.assertRaises( NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving_datagram, + f, 'localhost', 8080) self.assertRaises( NotImplementedError, ev_loop.add_reader, 1, f) self.assertRaises( @@ -849,6 +1122,23 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.remove_signal_handler, 1) +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + class PolicyTests(unittest.TestCase): def test_event_loop_policy(self): @@ -906,6 +1196,7 @@ def test_set_event_loop_policy(self): policy = events.DefaultEventLoopPolicy() events.set_event_loop_policy(policy) self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) if __name__ == '__main__': diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 80f0d73d..b8fb10b0 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -14,6 +14,7 @@ from tulip.selector_events import BaseSelectorEventLoop from tulip.selector_events import _SelectorSslTransport from tulip.selector_events import _SelectorSocketTransport +from tulip.selector_events import _SelectorDatagramTransport class TestBaseSelectorEventLoop(BaseSelectorEventLoop): @@ -273,10 +274,10 @@ def test_sock_connect(self): sock = unittest.mock.Mock() self.event_loop._sock_connect = unittest.mock.Mock() - f = self.event_loop.sock_connect(sock, ('127.0.0.1',8080)) + f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, futures.Future) self.assertEqual( - (f, False, sock, ('127.0.0.1',8080)), + (f, False, sock, ('127.0.0.1', 8080)), self.event_loop._sock_connect.call_args[0]) def test__sock_connect(self): @@ -285,7 +286,7 @@ def test__sock_connect(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - self.event_loop._sock_connect(f, False, sock, ('127.0.0.1',8080)) + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.assertTrue(f.done()) self.assertTrue(sock.connect.called) @@ -295,7 +296,7 @@ def test__sock_connect_canceled_fut(self): f = futures.Future() f.cancel() - self.event_loop._sock_connect(f, False, sock, ('127.0.0.1',8080)) + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.assertFalse(sock.connect.called) def test__sock_connect_unregister(self): @@ -306,7 +307,7 @@ def test__sock_connect_unregister(self): f.cancel() self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) def test__sock_connect_tryagain(self): @@ -318,10 +319,10 @@ def test__sock_connect_tryagain(self): self.event_loop.add_writer = unittest.mock.Mock() self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertEqual( (10, self.event_loop._sock_connect, f, - True, sock, ('127.0.0.1',8080)), + True, sock, ('127.0.0.1', 8080)), self.event_loop.add_writer.call_args[0]) def test__sock_connect_exception(self): @@ -331,7 +332,7 @@ def test__sock_connect_exception(self): sock.getsockopt.return_value = errno.ENOTCONN self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1',8080)) + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertIsInstance(f.exception(), socket.error) def test_sock_accept(self): @@ -561,15 +562,14 @@ def setUp(self): self.protocol = unittest.mock.Mock() def test_ctor(self): - transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + _SelectorSocketTransport(self.event_loop, self.sock, self.protocol) self.assertTrue(self.event_loop.add_reader.called) self.assertTrue(self.event_loop.call_soon.called) def test_ctor_with_waiter(self): fut = futures.Future() - transport = _SelectorSocketTransport( + _SelectorSocketTransport( self.event_loop, self.sock, self.protocol, fut) self.assertEqual(2, self.event_loop.call_soon.call_count) self.assertEqual(fut.set_result, @@ -1104,3 +1104,318 @@ class Err(socket.error): self.transport._on_ready() self.assertTrue(self.transport._fatal_error.called) self.assertEqual([], self.transport._buffer) + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_read_ready(self): + datagram_received = unittest.mock.Mock() + self.protocol.datagram_received = datagram_received + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (datagram_received, b'data', ('0.0.0.0', 1234)), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._sendto_ready, + self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.sendto, b'data', ()) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual(self.sock.sendto.call_args[0], (data, ())) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + transport._sendto_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + def test_close(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport._fatal_error(exc) + + self.assertEqual([], list(transport._buffer)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_fatal_error_connected(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.event_loop.reset_mock() + transport._fatal_error(ConnectionRefusedError()) + + self.assertEqual( + 2, self.event_loop.call_soon.call_count) + + def test_transport_closing(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) diff --git a/tulip/base_events.py b/tulip/base_events.py index 09bdd617..9e2e161c 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -50,14 +50,20 @@ def __init__(self): self._internal_fds = 0 self._signal_handlers = {} - def _make_socket_transport(self, sock, protocol, waiter=None): + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter): + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): """Create SSL transport.""" raise NotImplementedError + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + def _read_from_self(self): """XXX""" raise NotImplementedError @@ -121,14 +127,17 @@ def stop_loop(): handler_called = True raise _StopError future.add_done_callback(_raise_stop_error) + if timeout is None: self.run_forever() else: handler = self.call_later(timeout, stop_loop) self.run() handler.cancel() + if handler_called: raise futures.TimeoutError + return future.result() def stop(self): @@ -286,6 +295,52 @@ def create_connection(self, protocol_factory, host=None, port=None, *, yield from waiter return transport, protocol + @tasks.task + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=socket.AF_INET, proto=0, flags=0): + """Create datagram connection.""" + + addr = None + if host is not None or port is not None: + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + + try: + yield from self.sock_connect(sock, address) + addr = address + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + if exceptions: + raise exceptions[0] + else: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + except socket.error: + sock.close() + raise + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, addr) + + return transport, protocol + # TODO: Or create_server()? @tasks.task def start_serving(self, protocol_factory, host=None, port=None, *, @@ -328,6 +383,37 @@ def start_serving(self, protocol_factory, host=None, port=None, *, self._start_serving(protocol_factory, sock) return sock + @tasks.task + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + sock.setblocking(False) + break + else: + raise exceptions[0] + + self._make_datagram_transport( + sock, protocol_factory(), extra={'addr': sock.getsockname()}) + + return sock + def _add_callback(self, handler): """Add a Handler to ready or scheduled.""" if handler.cancelled: diff --git a/tulip/events.py b/tulip/events.py index 6b2887b1..b8dd4334 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -49,7 +49,7 @@ def run(self): self._callback(*self._args) except Exception: logging.exception('Exception in callback %s %r', - self._callback, self._args) + self._callback, self._args) def make_handler(callback, args): @@ -182,6 +182,15 @@ def start_serving(self, protocol_factory, host=None, port=None, *, family=0, proto=0, flags=0, sock=None): raise NotImplementedError + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + # Ready-based callback registration methods. # The add_*() methods return a Handler. # The remove_*() methods return True if something was removed, @@ -275,7 +284,7 @@ def new_event_loop(self): loop. """ # TODO: Do something else for Windows. - if sys.platform == 'win32': # pragma: no cover + if sys.platform == 'win32': # pragma: no cover from . import windows_events return windows_events.SelectorEventLoop() else: diff --git a/tulip/protocols.py b/tulip/protocols.py index ad294f3a..f01e2fd2 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -1,6 +1,6 @@ """Abstract Protocol class.""" -__all__ = ['Protocol'] +__all__ = ['Protocol', 'DatagramProtocol'] class Protocol: @@ -56,3 +56,19 @@ def connection_lost(self, exc): meaning a regular EOF is received or the connection was aborted or closed). """ + + +class DatagramProtocol: + """ABC representing a datagram protocol.""" + + def connection_made(self, transport): + """Called when a datagram transport is ready.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + def connection_lost(self, exc): + """Called when the connection is lost or closed.""" diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 7851f205..5feaafa4 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -4,12 +4,13 @@ also includes support for signal handling, see the unix_events sub-module. """ +import collections import errno import logging import socket try: import ssl -except ImportError: # pragma: no cover +except ImportError: # pragma: no cover ssl = None import sys @@ -31,7 +32,7 @@ # Errno values indicating the socket isn't ready for I/O just yet. _TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) -if sys.platform == 'win32': # pragma: no cover +if sys.platform == 'win32': # pragma: no cover _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) @@ -58,6 +59,10 @@ def _make_ssl_transport(self, rawsock, protocol, return _SelectorSslTransport( self, rawsock, protocol, sslcontext, waiter, extra) + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + def close(self): if self._selector is not None: self._close_self_pipe() @@ -116,9 +121,9 @@ def _accept_connection(self, protocol_factory, sock): # TODO: Someone will want an error handler for this. logging.exception('Accept failed') return - protocol = protocol_factory() - transport = self._make_socket_transport( - conn, protocol, extra={'addr': addr}) + + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -547,3 +552,110 @@ def _fatal_error(self, exc): self._event_loop.remove_reader(self._sslsock.fileno()) self._buffer = [] self._event_loop.call_soon(self._protocol.connection_lost, exc) + + +class _SelectorDatagramTransport(transports.Transport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + self._event_loop.call_soon( + self._protocol.datagram_received, data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + + self._event_loop.add_writer(self._fileno, self._sendto_ready) + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + + # Try again later. + self._buffer.appendleft((data, addr)) + break + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._event_loop.call_soon(self._protocol.connection_refused, exc) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() From 6f255f21e8dee114ef46e6ad11562078ae4b4162 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 13 Mar 2013 15:52:22 -0700 Subject: [PATCH 0348/1502] move HttpClientProtocol.write_xxx methods to separate class --- tests/http_protocol_test.py | 44 +++++++++++++++++++++++++++++++++++++ tulip/http/client.py | 40 +++++++++------------------------ tulip/http/protocol.py | 37 ++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 31 deletions(-) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 98e9d6f2..c0952287 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -489,3 +489,47 @@ def test_read_message_readall(self): payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) self.assertEqual(b'dataline', payload) + + +class HttpStreamWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = protocol.HttpStreamWriter(self.transport) + + def test_ctor(self): + transport = unittest.mock.Mock() + writer = protocol.HttpStreamWriter(transport, 'latin-1') + self.assertIs(writer.transport, transport) + self.assertEqual(writer.encoding, 'latin-1') + + def test_encode(self): + self.assertEqual(b'test', self.writer.encode('test')) + self.assertEqual(b'test', self.writer.encode(b'test')) + + def test_decode(self): + self.assertEqual('test', self.writer.decode('test')) + self.assertEqual('test', self.writer.decode(b'test')) + + def test_write(self): + self.writer.write(b'test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_str(self): + self.writer.write_str('test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_cunked(self): + self.writer.write_chunked('') + self.assertFalse(self.transport.write.called) + + self.writer.write_chunked('data') + self.assertEqual( + [(b'4\r\n',), (b'data',), (b'\r\n',)], + [c[0] for c in self.transport.write.call_args_list]) + + def test_write_eof(self): + self.writer.write_chunked_eof() + self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0]) diff --git a/tulip/http/client.py b/tulip/http/client.py index 7a494447..b4db5ccb 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -124,45 +124,25 @@ def connect(self): sts = '{} {}'.format(status, reason) return (sts, headers, message.payload) - def encode(self, s): - if isinstance(s, bytes): - return s - return s.encode(self.encoding) - - def decode(self, s): - if isinstance(s, str): - return s - return s.decode(self.encoding) - - def write_str(self, s): - self.transport.write(self.encode(s)) - - def write_chunked(self, s): - if not s: - return - data = self.encode(s) - self.write_str('{:x}\r\n'.format(len(data))) - self.transport.write(data) - self.transport.write(b'\r\n') - - def write_chunked_eof(self): - self.transport.write(b'0\r\n\r\n') - def connection_made(self, transport): self.transport = transport + self.stream = protocol.HttpStreamReader() + self.wstream = protocol.HttpStreamWriter(transport) + line = '{} {} HTTP/{}\r\n'.format(self.method, self.path, self.version) - self.write_str(line) + self.wstream.write_str(line) for key, value in self.headers.items(): - self.write_str('{}: {}\r\n'.format(key, value)) - self.transport.write(b'\r\n') - self.stream = protocol.HttpStreamReader() + self.wstream.write_str('{}: {}\r\n'.format(key, value)) + self.wstream.write(b'\r\n') if self.make_body is not None: if self.chunked: - self.make_body(self.write_chunked, self.write_chunked_eof) + self.make_body( + self.wstream.write_chunked, self.wstream.write_chunked_eof) else: - self.make_body(self.write_str, self.transport.write_eof) + self.make_body( + self.wstream.write_str, self.wstream.write_eof) def data_received(self, data): self.stream.feed_data(data) diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index b4454f5f..2ff7876f 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -1,6 +1,7 @@ """Http related helper utils.""" -__all__ = ['HttpStreamReader', 'HttpMessage', 'RequestLine', 'ResponseStatus'] +__all__ = ['HttpStreamReader', 'HttpStreamWriter', + 'HttpMessage', 'RequestLine', 'ResponseStatus'] import collections import functools @@ -471,3 +472,37 @@ def read_message(self, version=(1, 1), payload = self.read_length_payload(0, encoding=encoding) return HttpMessage(headers, payload, close_conn, encoding) + + +class HttpStreamWriter: + + def __init__(self, transport, encoding='utf-8'): + self.transport = transport + self.encoding = encoding + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write(self, b): + self.transport.write(b) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, chunk): + if not chunk: + return + data = self.encode(chunk) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') From e6e6912c12ce685dfefb916c4bb8b4899b55d29f Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 14 Mar 2013 01:46:44 -0700 Subject: [PATCH 0349/1502] Fix udp tests for Windows --- examples/udp_echo.py | 2 +- tests/events_test.py | 16 ++++++++++------ tulip/__init__.py | 2 +- tulip/events.py | 3 +-- tulip/test_utils.py | 2 +- tulip/winsocketpair.py | 2 +- 6 files changed, 15 insertions(+), 12 deletions(-) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index b76aa07a..c92cb06d 100644 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -9,7 +9,7 @@ import sys import tulip -ADDRESS = ('localhost', 10000) +ADDRESS = ('127.0.0.1', 10000) class MyUdpEchoProtocol: diff --git a/tests/events_test.py b/tests/events_test.py index 9da8e105..d92ab932 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -667,7 +667,7 @@ def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) - f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() @@ -713,7 +713,7 @@ def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) - f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() @@ -792,12 +792,12 @@ def factory(): proto = TestMyDatagramProto() return proto - f = self.event_loop.start_serving_datagram(factory, '0.0.0.0', 0) + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) sock = self.event_loop.run_until_complete(f) self.assertEqual('INITIALIZED', proto.state) host, port = sock.getsockname() - self.assertEqual(host, '0.0.0.0') + self.assertEqual(host, '127.0.0.1') client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) client.sendto(b'xxx', ('127.0.0.1', port)) self.event_loop.run_once() @@ -813,7 +813,7 @@ def factory(): conn = proto.transport.get_extra_info('socket') self.assertTrue(hasattr(conn, 'getsockname')) self.assertEqual( - '0.0.0.0', proto.transport.get_extra_info('addr')[0]) + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) # close connection proto.transport.close() @@ -918,13 +918,17 @@ def test_accept_connection_exception(self): def test_create_datagram_connection(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_connection()") + def test_create_datagram_connection_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have " + "create_datagram_connection_no_connection()") def test_start_serving_datagram(self): raise unittest.SkipTest( "IocpEventLoop does not have start_serving_datagram()") def test_start_serving_datagram_no_getaddrinfoc(self): raise unittest.SkipTest( "IocpEventLoop does not have start_serving_datagram()") - def test_start_serving_datagram_cant_bind(self, m_socket): + def test_start_serving_datagram_cant_bind(self): raise unittest.SkipTest( "IocpEventLoop does not have start_serving_udp()") diff --git a/tulip/__init__.py b/tulip/__init__.py index e8fa861a..faf307fb 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -14,7 +14,7 @@ if sys.platform == 'win32': # pragma: no cover from .windows_events import * else: - from .unix_events import * + from .unix_events import * # pragma: no cover __all__ = (futures.__all__ + diff --git a/tulip/events.py b/tulip/events.py index b8dd4334..a5c2dd4c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -283,11 +283,10 @@ def new_event_loop(self): You must call set_event_loop() to make this the current event loop. """ - # TODO: Do something else for Windows. if sys.platform == 'win32': # pragma: no cover from . import windows_events return windows_events.SelectorEventLoop() - else: + else: # pragma: no cover from . import unix_events return unix_events.SelectorEventLoop() diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 946e1ef7..9b87db2f 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -9,7 +9,7 @@ if sys.platform == 'win32': # pragma: no cover from .winsocketpair import socketpair else: - from socket import socketpair + from socket import socketpair # pragma: no cover class LogTrackingTestCase(unittest.TestCase): diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index 805b215f..374616f6 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -7,7 +7,7 @@ import socket import sys -if sys.platform != 'win32': +if sys.platform != 'win32': # pragma: no cover raise ImportError('winsocketpair is win32 only') From 1ee0a70c25873c29e0860004534943624a2f0d3d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 18 Mar 2013 10:29:22 -0700 Subject: [PATCH 0350/1502] Make python.org 302 test more robust. --- tests/events_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index d92ab932..32f2379c 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -4,6 +4,7 @@ import errno import gc import os +import re import signal import socket try: @@ -343,7 +344,7 @@ def test_sock_client_ops(self): data = self.event_loop.run_until_complete( self.event_loop.sock_recv(sock, 1024)) sock.close() - self.assertTrue(data.startswith(b'HTTP/1.1 302 Found\r\n')) + self.assertTrue(re.match(rb'HTTP/1.\d 302', data), data) def test_sock_client_fail(self): sock = socket.socket() From e90eb2cb9975c18e2065f853f4ffed4e3250459f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 19 Mar 2013 09:13:45 -0700 Subject: [PATCH 0351/1502] Get rid of network interaction in tests --- tests/events_test.py | 150 ++++++++++++++++++++++++++++--------------- tests/sample.crt | 14 ++++ tests/sample.key | 15 +++++ 3 files changed, 126 insertions(+), 53 deletions(-) create mode 100644 tests/sample.crt create mode 100644 tests/sample.key diff --git a/tests/events_test.py b/tests/events_test.py index 32f2379c..cdad9bb0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1,8 +1,10 @@ """Tests for events.py.""" import concurrent.futures +import contextlib import errno import gc +import io import os import re import signal @@ -17,6 +19,8 @@ import unittest import unittest.mock +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + from tulip import events from tulip import transports from tulip import protocols @@ -85,6 +89,50 @@ def tearDown(self): gc.collect() super().tearDown() + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SSLWSGIServer(WSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else WSGIServer + httpd = make_server('', 0, app, server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + def test_run(self): self.event_loop.run() # Returns immediately. @@ -333,24 +381,32 @@ def sender(): self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): - sock = socket.socket() - sock.setblocking(False) - # TODO: This depends on python.org behavior! - address = socket.getaddrinfo('python.org', 80, socket.AF_INET)[0][4] - self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, address)) - self.event_loop.run_until_complete( - self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = self.event_loop.run_until_complete( - self.event_loop.sock_recv(sock, 1024)) - sock.close() - self.assertTrue(re.match(rb'HTTP/1.\d 302', data), data) + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('', 0)) + address = s.getsockname() + finally: + s.close() + sock = socket.socket() sock.setblocking(False) - # TODO: This depends on python.org behavior! - address = socket.getaddrinfo('python.org', 12345, socket.AF_INET)[0][4] with self.assertRaises(ConnectionRefusedError): self.event_loop.run_until_complete( self.event_loop.sock_connect(sock, address)) @@ -468,50 +524,38 @@ def my_handler(*args): self.assertEqual(caught, 1) def test_create_connection(self): - # TODO: This depends on xkcd.com behavior! - f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) - tr, pr = self.event_loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run() - self.assertTrue(pr.nbytes > 0) + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = self.event_loop.create_connection(MyProto, host, port) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) def test_create_connection_sock(self): - # TODO: This depends on xkcd.com behavior! - sock = None - infos = self.event_loop.run_until_complete( - self.event_loop.getaddrinfo( - 'xkcd.com', 80, type=socket.SOCK_STREAM)) - for family, type, proto, cname, address in infos: - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, address)) - except: - pass - else: - break - - f = self.event_loop.create_connection(MyProto, sock=sock) - tr, pr = self.event_loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run() - self.assertTrue(pr.nbytes > 0) + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = self.event_loop.create_connection(MyProto, host, port) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): - # TODO: This depends on xkcd.com behavior! - f = self.event_loop.create_connection( - MyProto, 'xkcd.com', 443, ssl=True) - tr, pr = self.event_loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) - self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) - self.event_loop.run() - self.assertTrue(pr.nbytes > 0) + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) def test_create_connection_host_port_sock(self): self.suppress_log_errors() diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 00000000..6a1e3f3c --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 00000000..edfea8dc --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- From 1f44bdc9a950c53583c3d95dfc2239c0ecbcf03c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 14 Mar 2013 13:04:32 -0700 Subject: [PATCH 0352/1502] Fix events tests for Windows --- tests/events_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index cdad9bb0..be8ccfeb 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -124,7 +124,8 @@ def app(environ, start_response): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread server_class = SSLWSGIServer if use_ssl else WSGIServer - httpd = make_server('', 0, app, server_class, SilentWSGIRequestHandler) + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) server_thread = threading.Thread(target=httpd.serve_forever) server_thread.start() try: @@ -400,7 +401,7 @@ def test_sock_client_fail(self): address = None try: s = socket.socket() - s.bind(('', 0)) + s.bind(('127.0.0.1', 0)) address = s.getsockname() finally: s.close() From 5ab0270315eceb875d5d7abef00614ebe6a232c5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 18 Mar 2013 17:03:58 -0700 Subject: [PATCH 0353/1502] Add notes from PyCon (more to come). --- .hgeol | 2 + .hgignore | 8 + Makefile | 29 + NOTES | 143 ++++ README | 21 + TODO | 163 ++++ check.py | 41 + crawl.py | 143 ++++ curl.py | 35 + examples/udp_echo.py | 71 ++ old/Makefile | 16 + old/echoclt.py | 79 ++ old/echosvr.py | 60 ++ old/http_client.py | 78 ++ old/http_server.py | 68 ++ old/main.py | 134 ++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++ old/scheduling.py | 354 ++++++++ old/sockets.py | 348 ++++++++ old/transports.py | 496 ++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++ overlapped.c | 997 +++++++++++++++++++++++ runtests.py | 198 +++++ setup.cfg | 2 + setup.py | 4 + srv.py | 123 +++ sslsrv.py | 56 ++ tests/base_events_test.py | 271 +++++++ tests/events_test.py | 1208 ++++++++++++++++++++++++++++ tests/futures_test.py | 220 +++++ tests/http_protocol_test.py | 535 +++++++++++++ tests/locks_test.py | 752 +++++++++++++++++ tests/selector_events_test.py | 1421 +++++++++++++++++++++++++++++++++ tests/selectors_test.py | 138 ++++ tests/streams_test.py | 315 ++++++++ tests/subprocess_test.py | 54 ++ tests/tasks_test.py | 556 +++++++++++++ tests/transports_test.py | 39 + tests/unix_events_test.py | 168 ++++ tests/winsocketpair_test.py | 32 + tulip/TODO | 28 + tulip/__init__.py | 26 + tulip/base_events.py | 498 ++++++++++++ tulip/events.py | 328 ++++++++ tulip/futures.py | 244 ++++++ tulip/http/__init__.py | 8 + tulip/http/client.py | 154 ++++ tulip/http/protocol.py | 508 ++++++++++++ tulip/locks.py | 460 +++++++++++ tulip/proactor_events.py | 190 +++++ tulip/protocols.py | 74 ++ tulip/selector_events.py | 661 +++++++++++++++ tulip/selectors.py | 419 ++++++++++ tulip/streams.py | 145 ++++ tulip/subprocess_transport.py | 133 +++ tulip/tasks.py | 304 +++++++ tulip/test_utils.py | 30 + tulip/transports.py | 99 +++ tulip/unix_events.py | 113 +++ tulip/windows_events.py | 157 ++++ tulip/winsocketpair.py | 34 + 63 files changed, 14666 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 srv.py create mode 100644 sslsrv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/locks.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..42309f0c --- /dev/null +++ b/.hgignore @@ -0,0 +1,8 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..2391d89c --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..1121cca9 --- /dev/null +++ b/NOTES @@ -0,0 +1,143 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..64bc2cdd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..4e5bebe2 --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..0624df86 --- /dev/null +++ b/curl.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream)) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..c92cb06d --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,71 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + loop.start_serving_datagram(MyUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + loop.create_datagram_connection(MyClientUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..e4678d70 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err)) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..67b037cc --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from distutils.core import setup, Extension + +ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) +setup(name='_overlapped', ext_modules=[ext]) diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..540e63b9 --- /dev/null +++ b/srv.py @@ -0,0 +1,123 @@ +"""Simple server written using an event loop.""" + +import http.client +import email.message +import email.parser +import os + +import tulip +import tulip.http + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + try: + method, path, version = yield from self.reader.read_request_line() + except http.client.BadStatusLine: + self.transport.close() + return + + print('method = {!r}; path = {!r}; version = {!r}'.format( + method, path, version)) + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return + + message = yield from self.reader.read_message() + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + write = self.transport.write + if isdir and not path.endswith('/'): + bpath = path.encode('ascii') + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError: + write(b'Cannot open\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport.get_extra_info('socket')) + self.reader = tulip.http.HttpStreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..f5bfc363 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,271 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + + def test_add_callback_handler(self): + h = events.Handler(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handler(self): + h = events.Handler(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handler) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handler(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handler(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handler(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handler(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + """If event loop has ready callbacks, select timeout is always 0.""" + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.logging') + def test__run_once_logging(self, m_logging, m_time): + """Log to INFO level if timeout > 1.0 sec.""" + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handler(self): + handler = None + processed = False + + def cb(event_loop): + nonlocal processed, handler + processed = True + handler = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handler], list(self.event_loop._ready)) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..32f2379c --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1208 @@ +"""Tests for events.py.""" + +import concurrent.futures +import errno +import gc +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handler(self): + results = [] + + def callback(): + results.append('yeah') + + handler = events.Handler(callback, ()) + self.assertIs(self.event_loop.call_soon(handler), handler) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handler(self): + results = [] + + def callback(arg): + results.append(arg) + + handler = events.Handler(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handler), handler) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handler(self): + def run(arg): + time.sleep(0.1) + return arg + handler = events.Handler(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handler) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handler(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handler = events.Handler(reader, ()) + self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handler.cancel() + if not data: + r.close() + + handler = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handler(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handler = events.Handler(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handler.cancel() + + handler = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + address = socket.getaddrinfo('python.org', 80, socket.AF_INET)[0][4] + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(re.match(rb'HTTP/1.\d 302', data), data) + + def test_sock_client_fail(self): + sock = socket.socket() + sock.setblocking(False) + # TODO: This depends on python.org behavior! + address = socket.getaddrinfo('python.org', 12345, socket.AF_INET)[0][4] + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handler = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handler.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + # TODO: This depends on xkcd.com behavior! + f = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + # TODO: This depends on xkcd.com behavior! + sock = None + infos = self.event_loop.run_until_complete( + self.event_loop.getaddrinfo( + 'xkcd.com', 80, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + except: + pass + else: + break + + f = self.event_loop.create_connection(MyProto, sock=sock) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + # TODO: This depends on xkcd.com behavior! + f = self.event_loop.create_connection( + MyProto, 'xkcd.com', 443, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_create_datagram_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection( + MyDatagramProto, host, port) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx') + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection(MyDatagramProto) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx', (host, port)) + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_datagram_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_connection_sockopt_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setsockopt.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_start_serving_datagram(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + proto = None + + def factory(): + nonlocal proto + proto = TestMyDatagramProto() + return proto + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + self.assertEqual('INITIALIZED', proto.state) + + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + client.sendto(b'xxx', ('127.0.0.1', port)) + self.event_loop.run_once() + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + data, server = client.recvfrom(4096) + self.assertEqual(b'resp:xxx', data) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + client.close() + + def test_start_serving_datagram_no_getaddrinfoc(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_datagram_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + class Err(socket.error): + errno = errno.EAGAIN + + sock = unittest.mock.Mock() + sock.accept.side_effect = Err + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = socket.error + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handler(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handler(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_connection()") + def test_create_datagram_connection_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have " + "create_datagram_connection_no_connection()") + def test_start_serving_datagram(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_no_getaddrinfoc(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_udp()") + +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandlerTests(unittest.TestCase): + + def test_handler(self): + def callback(*args): + return args + + args = () + h = events.Handler(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handler(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handler(self): + def callback(*args): + return args + h1 = events.Handler(callback, ()) + h2 = events.make_handler(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handler, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handler(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving_datagram, + f, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..fd49493d --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,220 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..c0952287 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,535 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line())) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + tulip.Task(self.stream.read_headers())) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_payload_unknown_encoding(self): + self.assertRaises( + ValueError, self.stream.read_length_payload, encoding='unknown') + + def test_read_payload(self): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + stream = self.stream.read_length_payload(4) + self.assertIsInstance(stream, tulip.StreamReader) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_payload_eof(self): + self.stream.feed_data(b'da') + self.stream.feed_eof() + stream = self.stream.read_length_payload(4) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_payload_eof_exc(self): + self.stream.feed_data(b'da') + stream = self.stream.read_length_payload(4) + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._reader) + + def test_read_payload_deflate(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + stream = self.stream.read_length_payload(len(data), encoding='deflate') + + self.stream.feed_data(data) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + + def _test_read_payload_compress_error(self): + data = b'123123123datadatadata' + reader = protocol.length_reader(4) + self.stream.feed_data(data) + stream = self.stream.read_payload(reader, 'deflate') + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload(self): + stream = self.stream.read_chunked_payload() + self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_chunks(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_incomplete(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload_extension(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_size_error(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'blah\r\n') + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_length_payload(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'datadata', data) + + def test_read_length_payload_zero(self): + stream = self.stream.read_length_payload(0) + + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'', data) + + def test_read_length_payload_incomplete(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_eof_payload(self): + stream = self.stream.read_eof_payload() + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 1)))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 0)))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=False))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) + + +class HttpStreamWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = protocol.HttpStreamWriter(self.transport) + + def test_ctor(self): + transport = unittest.mock.Mock() + writer = protocol.HttpStreamWriter(transport, 'latin-1') + self.assertIs(writer.transport, transport) + self.assertEqual(writer.encoding, 'latin-1') + + def test_encode(self): + self.assertEqual(b'test', self.writer.encode('test')) + self.assertEqual(b'test', self.writer.encode(b'test')) + + def test_decode(self): + self.assertEqual('test', self.writer.decode('test')) + self.assertEqual('test', self.writer.decode(b'test')) + + def test_write(self): + self.writer.write(b'test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_str(self): + self.writer.write_str('test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_cunked(self): + self.writer.write_chunked('') + self.assertFalse(self.transport.write.called) + + self.writer.write_chunked('data') + self.assertEqual( + [(b'4\r\n',), (b'data',), (b'\r\n',)], + [c[0] for c in self.transport.write.call_args_list]) + + def test_write_eof(self): + self.writer.write_chunked_eof() + self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..4267134c --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,752 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete( + tasks.Task(lock.acquire()) + )) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(timeout=0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(tasks.Task(ev.wait())) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(ev.wait(0.1))) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(tasks.Task(ev.wait(10.1))) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(cond.acquire()))) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(tasks.Task(cond.wait(0.1))) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, tasks.Task(cond.wait())) + + def test_wait_for(self): + cond = locks.Condition() + + presult = False + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete(tasks.Task( + cond.wait_for(lambda: [1,2,3]))) + self.assertEqual([1,2,3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + tasks.Task(cond.wait_for(lambda: False))) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1,2,3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1,2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + res = yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + res = yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py new file mode 100644 index 00000000..b8fb10b0 --- /dev/null +++ b/tests/selector_events_test.py @@ -0,0 +1,1421 @@ +"""Tests for selector_events.py""" + +import errno +import socket +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from tulip import futures +from tulip import selectors +from tulip.selector_events import BaseSelectorEventLoop +from tulip.selector_events import _SelectorSslTransport +from tulip.selector_events import _SelectorSocketTransport +from tulip.selector_events import _SelectorDatagramTransport + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() + self._internal_fds += 1 + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_socket_transport(m, m), + _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_ssl_transport(m, m, m, m), + _SelectorSslTransport) + + def test_close(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = selector = unittest.mock.Mock() + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertIsNone(self.event_loop._csock) + self.assertIsNone(self.event_loop._ssock) + self.assertTrue(selector.close.called) + self.assertTrue(ssock.close.called) + self.assertTrue(csock.close.called) + self.assertTrue(remove_reader.called) + + self.event_loop.close() + self.event_loop.close() + + def test_close_no_selector(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = None + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.event_loop._socketpair) + + def test_read_from_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._ssock.recv.side_effect = Err + self.assertIsNone(self.event_loop._read_from_self()) + + def test_read_from_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._ssock.recv.side_effect = Err + self.assertRaises(Err, self.event_loop._read_from_self) + + def test_write_to_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._csock.send.side_effect = Err + self.assertIsNone(self.event_loop._write_to_self()) + + def test_write_to_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._csock.send.side_effect = Err + self.assertRaises(Err, self.event_loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.event_loop._sock_recv = unittest.mock.Mock() + + f = self.event_loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, 1024), + self.event_loop._sock_recv.call_args[0]) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024), + self.event_loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertIsInstance(f.exception(), Err) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.event_loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertFalse(self.event_loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertIsInstance(f.exception(), Err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'ta'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.event_loop._sock_connect = unittest.mock.Mock() + + f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.event_loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.event_loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.event_loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), socket.error) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.event_loop._sock_accept = unittest.mock.Mock() + + f = self.event_loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.event_loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future() + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.event_loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.event_loop._sock_accept, f, True, sock), + self.event_loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop._sock_accept(f, False, sock) + self.assertIsInstance(f.exception(), Err) + + def test_add_reader(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_READ, (h, None)), + self.event_loop._selector.register.call_args[0]) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (reader, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (None, None)) + self.assertFalse(self.event_loop.remove_reader(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_reader(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_reader(1)) + + def test_add_writer(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, h)), + self.event_loop._selector.register.call_args[0]) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (reader, writer)) + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.event_loop.remove_writer(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_writer(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((reader,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual((1,), self.event_loop.remove_reader.call_args[0]) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((writer,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual((1,), self.event_loop.remove_writer.call_args[0]) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_ctor(self): + _SelectorSocketTransport(self.event_loop, self.sock, self.protocol) + self.assertTrue(self.event_loop.add_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + + def test_ctor_with_waiter(self): + fut = futures.Future() + + _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.assertEqual(2, self.event_loop.call_soon.call_count) + self.assertEqual(fut.set_result, + self.event_loop.call_soon.call_args[0][0]) + + def test_read_ready(self): + data_received = unittest.mock.Mock() + self.protocol.data_received = data_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (data_received, b'data'), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_eof(self): + eof_received = unittest.mock.Mock() + self.protocol.eof_received = eof_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual( + (eof_received,), self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data1', b'data2'], transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'ta'], transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.write, b'data') + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._write_ready() + self.assertFalse(self.sock.send.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'ta'], transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer = [b'data1', b'data2'] + transport._write_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data1data2'], transport._buffer) + + def test_write_ready_exception(self): + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_close(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport._fatal_error(exc) + + self.assertEqual([], transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_connection_lost(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + self.waiter = futures.Future() + + self.transport = _SelectorSslTransport( + self.event_loop, self.sock, + self.protocol, self.sslcontext, self.waiter) + self.event_loop.reset_mock() + self.sock.reset_mock() + self.protocol.reset_mock() + self.sslcontext.reset_mock() + + def test_on_handshake(self): + self.transport._on_handshake() + self.assertTrue(self.sslsock.do_handshake.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_reader.call_args[0]) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_reader.call_args[0]) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_exc(self): + self.sslsock.do_handshake.side_effect = ValueError + self.transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + + def test_on_handshake_base_exc(self): + self.sslsock.do_handshake.side_effect = BaseException + self.assertRaises(BaseException, self.transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + + def test_write_no_data(self): + self.transport._buffer.append(b'data') + self.transport.write(b'') + self.assertEqual([b'data'], self.transport._buffer) + + def test_write_str(self): + self.assertRaises(AssertionError, self.transport.write, 'str') + + def test_write_closing(self): + self.transport.close() + self.assertRaises(AssertionError, self.transport.write, b'data') + + def test_abort(self): + self.transport._fatal_error = unittest.mock.Mock() + + self.transport.abort() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual((None,), self.transport._fatal_error.call_args[0]) + + def test_fatal_error(self): + exc = object() + self.transport._buffer.append(b'data') + self.transport._fatal_error(exc) + + self.assertEqual([], self.transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_close(self): + self.transport.close() + + self.assertTrue(self.transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + self.transport._buffer.append(b'data') + self.transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_on_ready_closed(self): + self.sslsock.fileno.return_value = -1 + self.transport._on_ready() + self.assertFalse(self.sslsock.recv.called) + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + self.transport._on_ready() + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.recv.side_effect = Err + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + class Err(socket.error): + pass + + self.sslsock.recv.side_effect = Err + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([], self.transport._buffer) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data1data2'], self.transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'ta1data2'], self.transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport.close() + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + self.transport._buffer = [b'data'] + + self.sslsock.send.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data'], self.transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.send.side_effect = Err + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + class Err(socket.error): + pass + + self.sslsock.send.side_effect = Err + self.transport._buffer = [b'data'] + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual([], self.transport._buffer) + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_read_ready(self): + datagram_received = unittest.mock.Mock() + self.protocol.datagram_received = datagram_received + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (datagram_received, b'data', ('0.0.0.0', 1234)), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._sendto_ready, + self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.sendto, b'data', ()) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual(self.sock.sendto.call_args[0], (data, ())) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + transport._sendto_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + def test_close(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport._fatal_error(exc) + + self.assertEqual([], list(transport._buffer)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_fatal_error_connected(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.event_loop.reset_mock() + transport._fatal_error(ConnectionRefusedError()) + + self.assertEqual( + 2, self.event_loop.call_soon.call_count) + + def test_transport_closing(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) diff --git a/tests/selectors_test.py b/tests/selectors_test.py new file mode 100644 index 00000000..3ebaab8c --- /dev/null +++ b/tests/selectors_test.py @@ -0,0 +1,138 @@ +"""Tests for selectors.py.""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import selectors + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = TypeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(10, selectors.EVENT_READ) + self.assertEqual( + "SelectorKey", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..832ce371 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,315 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + """Read zero bytes.""" + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.read(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + """Read bytes.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + """Read bytes without line breaks.""" + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + read_task = tasks.Task(stream.read(5)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + """Read bytes, stop at eof.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + """Read all bytes until eof.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(tasks.Task(stream.read(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.read(2))) + + def test_readline(self): + """Read one line.""" + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readline()) + self.event_loop.run_until_complete(read_task) + + read_task = tasks.Task(stream.read(7)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readline())) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.readline())) + + def test_readexactly_zero_or_less(self): + """Read exact number of bytes (zero or less).""" + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readexactly(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + read_task = tasks.Task(stream.readexactly(-1)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + """Read exact number of bytes.""" + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + """Read exact number of bytes (eof).""" + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readexactly(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + tasks.Task(stream.readexactly(2))) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.Task(tasks.wait([t1, t2]))) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..14ce11d7 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..25ca5a4f --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,556 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + def __repr__(self): + return 'Dummy()' + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + @tasks.task + def inner1(): + yield from [] + return 42 + @tasks.task + def inner2(): + yield from [] + return 1000 + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task(tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + t0 = time.monotonic() + yield from sleepfut + finally: + t1 = time.monotonic() + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + t0 = time.monotonic() + yield from sleeper + except futures.CancelledError: + t1 = time.monotonic() + return 'cancelled' + else: + return 'slept in' + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + """If coroutine returns future, task waits on this future.""" + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + task = wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..eb61d914 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,39 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..2504648b --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,168 @@ +"""Tests for unix_events.py.""" + +import errno +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handler) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + h = self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..0175e9b9 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,32 @@ +"""Tests for winsocketpair.py""" + +import errno +import socket +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.error = socket.error + class Err(socket.error): + errno = errno.WSAEWOULDBLOCK + 1 + + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = Err + + self.assertRaises(Err, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..acec5c24 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..9e2e161c --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,498 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._signal_handlers = {} + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handler = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handler.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + handler_called = False + def stop_loop(): + nonlocal handler_called + handler_called = True + raise _StopError + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handler = self.call_later(timeout, stop_loop) + self.run() + handler.cancel() + + if handler_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + handler = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handler) + return handler + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handler._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handler) + handler = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handler) + return handler + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handler = events.make_handler(callback, args) + self._ready.append(handler) + return handler + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handler = self.call_soon(callback, *args) + self._write_to_self() + return handler + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handler): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.task + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=socket.AF_INET, proto=0, flags=0): + """Create datagram connection.""" + + addr = None + if host is not None or port is not None: + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + + try: + yield from self.sock_connect(sock, address) + addr = address + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + if exceptions: + raise exceptions[0] + else: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + except socket.error: + sock.close() + raise + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, addr) + + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.task + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + sock.setblocking(False) + break + else: + raise exceptions[0] + + self._make_datagram_transport( + sock, protocol_factory(), extra={'addr': sock.getsockname()}) + + return sock + + def _add_callback(self, handler): + """Add a Handler to ready or scheduled.""" + if handler.cancelled: + return + if isinstance(handler, events.Timer): + heapq.heappush(self._scheduled, handler) + else: + self._ready.append(handler) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if (self._scheduled or + self._selector.registered_count() > self._internal_fds): + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handler = self._scheduled[0] + if handler.when > now: + break + handler = heapq.heappop(self._scheduled) + self._ready.append(handler) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handler = self._ready.popleft() + if not handler.cancelled: + handler.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..a5c2dd4c --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,328 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handler', 'make_handler', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import logging +import sys +import threading + + +class Handler: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handler({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + logging.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handler(callback, args): + if isinstance(callback, Handler): + assert not args + return callback + return Handler(callback, args) + + +class Timer(Handler): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handlers for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handler. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..4bb2f198 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,244 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res +='<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..d436383f --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,8 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .protocol import * + + +__all__ = (client.__all__ + + protocol.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..b4db5ccb --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,154 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + self.wstream = protocol.HttpStreamWriter(transport) + + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.wstream.write_str(line) + for key, value in self.headers.items(): + self.wstream.write_str('{}: {}\r\n'.format(key, value)) + self.wstream.write(b'\r\n') + if self.make_body is not None: + if self.chunked: + self.make_body( + self.wstream.write_chunked, self.wstream.write_chunked_eof) + else: + self.make_body( + self.wstream.write_str, self.wstream.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..2ff7876f --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,508 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', 'HttpStreamWriter', + 'HttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import functools +import http.client +import re +import zlib + +import tulip + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') + + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +HttpMessage = collections.namedtuple( + 'HttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class StreamEofException(http.client.HTTPException): + """eof received""" + + +def wrap_payload_reader(f): + """wrap_payload_reader wraps payload readers and redirect stream. + payload readers are generator functions, read_chunked_payload, + read_length_payload, read_eof_payload. + payload reader allows to modify data stream and feed data into stream. + + StreamReader instance should be send to generator as first parameter. + This steam is used as destination stream for processed data. + To send data to reader use generator's send() method. + + To indicate eof stream, throw StreamEofException exception into the reader. + In case of errors in incoming stream reader sets exception to + destination stream with StreamReader.set_exception() method. + + Before exit, reader generator returns unprocessed data. + """ + + @functools.wraps(f) + def wrapper(self, *args, **kw): + assert self._reader is None + + rstream = stream = tulip.StreamReader() + + encoding = kw.pop('encoding', None) + if encoding is not None: + if encoding not in ('gzip', 'deflate'): + raise ValueError( + 'Content-Encoding %r is not supported' % encoding) + + stream = DeflateStream(stream, encoding) + + reader = f(self, *args, **kw) + next(reader) + try: + reader.send(stream) + except StopIteration: + pass + else: + # feed buffer + self.line_count = 0 + self.byte_count = 0 + while self.buffer: + try: + reader.send(self.buffer.popleft()) + except StopIteration as exc: + buf = b''.join(self.buffer) + self.buffer.clear() + reader = None + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + if reader is not None: + if self.eof: + try: + reader.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._reader = reader + + return rstream + + return wrapper + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified steram.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + self.stream.feed_eof() + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _reader is set, feed_data and feed_eof sends data into + # _reader instead of self. is it being used as stream redirection for + # read_chunked_payload, read_length_payload and read_eof_payload + _reader = None + + def feed_data(self, data): + """_reader is a generator, if _reader is set, feed_data sends + incoming data into this generator untile generates stops.""" + if self._reader: + try: + self._reader.send(data) + except StopIteration as exc: + self._reader = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_reader is a generator, if _reader is set feed_eof throws + StreamEofException into this generator.""" + if self._reader: + try: + self._reader.throw(StreamEofException()) + except StopIteration: + self._reader = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise http.client.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise http.client.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise http.client.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise http.client.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise http.client.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise http.client.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise http.client.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + @wrap_payload_reader + def read_chunked_payload(self): + """Read chunked stream.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b";") + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise http.client.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + except http.client.IncompleteRead as exc: + stream.set_exception(exc) + + @wrap_payload_reader + def read_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + + @wrap_payload_reader + def read_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == "CONNECTION": + v = value.lower() + if v == "close": + close_conn = True + elif v == "keep-alive": + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload stream + if chunked: + payload = self.read_chunked_payload(encoding=encoding) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise http.client.HTTPException('CONTENT-LENGTH') from None + + if length < 0: + raise http.client.HTTPException('CONTENT-LENGTH') + + payload = self.read_length_payload(length, encoding=encoding) + else: + if readall: + payload = self.read_eof_payload(encoding=encoding) + else: + payload = self.read_length_payload(0, encoding=encoding) + + return HttpMessage(headers, payload, close_conn, encoding) + + +class HttpStreamWriter: + + def __init__(self, transport, encoding='utf-8'): + self.transport = transport + self.encoding = encoding + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write(self, b): + self.transport.write(b) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, chunk): + if not chunk: + return + data = self.encode(chunk) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..e55487a6 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,460 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%(res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + if handler is not None: + handler.cancel() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>'%( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s'%self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handler = self._event_loop.call_later(timeout, fut.cancel) + else: + handler = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handler is not None: + handler.cancel() + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..45c075e3 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,190 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import logging + +from . import base_events +from . import transports + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logging.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + logging.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..f01e2fd2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,74 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class DatagramProtocol: + """ABC representing a datagram protocol.""" + + def connection_made(self, transport): + """Called when a datagram transport is ready.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + def connection_lost(self, exc): + """Called when the connection is lost or closed.""" diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..5feaafa4 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,661 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import logging +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None +import sys + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': # pragma: no cover + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handler instance.""" + handler = events.make_handler(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handler, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handler, writer)) + if reader is not None: + reader.cancel() + + return handler + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handler instance.""" + handler = events.make_handler(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handler)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handler)) + if writer is not None: + writer.cancel() + + return handler + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + n = 0 + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_writer(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + n = 0 + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) + + +class _SelectorDatagramTransport(transports.Transport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + self._event_loop.call_soon( + self._protocol.datagram_received, data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + + self._event_loop.add_writer(self._fileno, self._sendto_ready) + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + + # Try again later. + self._buffer.appendleft((data, addr)) + break + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._event_loop.call_soon(self._protocol.connection_refused, exc) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..b8b830eb --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,419 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + mask = 0 + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..8d7f6236 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..721013f8 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,133 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..2e6b73f3 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,304 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop) # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + bail = futures.Future() # Will always be cancelled eventually. + timeout_handler = None + debugstuff = locals() + + if timeout is not None: + loop = events.get_event_loop() + timeout_handler = loop.call_later(timeout, bail.cancel) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handler is not None: + timeout_handler.cancel() + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..9b87db2f --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..6eb1c554 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,99 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..db2c560d --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,113 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import logging +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + handler = events.make_handler(callback, args) + self._signal_handlers[sig] = handler + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + return handler + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handler = self._signal_handlers.get(sig) + if handler is None: + return # Assume it's some race condition. + if handler.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handler) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: + raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): + raise ValueError('sig {} out of range(1, {})'.format(sig, + signal.NSIG)) diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..3a0b8675 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows. +""" + +import errno +import logging +import socket +import weakref +import struct +import _winapi + + +from . import futures +from . import proactor_events +from . import selectors +from . import selector_events +from . import winsocketpair +from . import _overlapped + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + logging.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..374616f6 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From f2b4fe4008aef3762bab4da21bdde9f15e93b2cb Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 19 Mar 2013 10:13:33 -0700 Subject: [PATCH 0354/1502] Handler renamed to Handle --- tests/base_events_test.py | 28 ++++++------- tests/events_test.py | 70 ++++++++++++++++---------------- tests/futures_test.py | 2 + tests/locks_test.py | 17 ++++---- tests/selectors_test.py | 7 ++-- tests/subprocess_test.py | 2 +- tests/tasks_test.py | 50 ++++++++++++++++++----- tests/unix_events_test.py | 12 +++--- tests/winsocketpair_test.py | 1 + tulip/base_events.py | 75 ++++++++++++++++++----------------- tulip/events.py | 20 +++++----- tulip/futures.py | 4 +- tulip/locks.py | 41 +++++++++---------- tulip/selector_events.py | 20 +++++----- tulip/selectors.py | 3 +- tulip/subprocess_transport.py | 1 + tulip/tasks.py | 12 ++++-- tulip/unix_events.py | 33 ++++++++++----- tulip/windows_events.py | 10 ++--- 19 files changed, 232 insertions(+), 176 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index f5bfc363..03c3296b 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -41,8 +41,8 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, self.event_loop._read_from_self) - def test_add_callback_handler(self): - h = events.Handler(lambda: False, ()) + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) self.event_loop._add_callback(h) self.assertFalse(self.event_loop._scheduled) @@ -59,8 +59,8 @@ def test_add_callback_timer(self): self.assertEqual([h1, h2], self.event_loop._scheduled) self.assertFalse(self.event_loop._ready) - def test_add_callback_cancelled_handler(self): - h = events.Handler(lambda: False, ()) + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) h.cancel() self.event_loop._add_callback(h) @@ -94,7 +94,7 @@ def cb(): h = self.event_loop.call_soon(cb) self.assertEqual(h._callback, cb) - self.assertIsInstance(h, events.Handler) + self.assertIsInstance(h, events.Handle) self.assertIn(h, self.event_loop._ready) def test_call_later(self): @@ -114,13 +114,13 @@ def cb(): self.assertIn(h, self.event_loop._ready) self.assertNotIn(h, self.event_loop._scheduled) - def test_run_once_in_executor_handler(self): + def test_run_once_in_executor_handle(self): def cb(): pass self.assertRaises( AssertionError, self.event_loop.run_in_executor, - None, events.Handler(cb, ()), ('',)) + None, events.Handle(cb, ()), ('',)) self.assertRaises( AssertionError, self.event_loop.run_in_executor, None, events.Timer(10, cb, ())) @@ -128,7 +128,7 @@ def cb(): def test_run_once_in_executor_canceled(self): def cb(): pass - h = events.Handler(cb, ()) + h = events.Handle(cb, ()) h.cancel() f = self.event_loop.run_in_executor(None, h) @@ -138,7 +138,7 @@ def cb(): def test_run_once_in_executor(self): def cb(): pass - h = events.Handler(cb, ()) + h = events.Handle(cb, ()) f = futures.Future() executor = unittest.mock.Mock() executor.submit.return_value = f @@ -222,14 +222,14 @@ def monotonic(): self.event_loop._run_once() self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) - def test__run_once_schedule_handler(self): - handler = None + def test__run_once_schedule_handle(self): + handle = None processed = False def cb(event_loop): - nonlocal processed, handler + nonlocal processed, handle processed = True - handler = event_loop.call_soon(lambda: True) + handle = event_loop.call_soon(lambda: True) h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) @@ -238,7 +238,7 @@ def cb(event_loop): self.event_loop._run_once() self.assertTrue(processed) - self.assertEqual([handler], list(self.event_loop._ready)) + self.assertEqual([handle], list(self.event_loop._ready)) @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors(self, m_socket): diff --git a/tests/events_test.py b/tests/events_test.py index be8ccfeb..9fe2d82f 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -171,14 +171,14 @@ def callback(arg1, arg2): self.event_loop.run() self.assertEqual(results, [('hello', 'world')]) - def test_call_soon_with_handler(self): + def test_call_soon_with_handle(self): results = [] def callback(): results.append('yeah') - handler = events.Handler(callback, ()) - self.assertIs(self.event_loop.call_soon(handler), handler) + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) self.event_loop.run() self.assertEqual(results, ['yeah']) @@ -212,17 +212,17 @@ def callback(arg): self.event_loop.run() self.assertEqual(results, ['hello', 'world']) - def test_call_soon_threadsafe_with_handler(self): + def test_call_soon_threadsafe_with_handle(self): results = [] def callback(arg): results.append(arg) - handler = events.Handler(callback, ('hello',)) + handle = events.Handle(callback, ('hello',)) def run(): self.assertIs( - self.event_loop.call_soon_threadsafe(handler), handler) + self.event_loop.call_soon_threadsafe(handle), handle) t = threading.Thread(target=run) self.event_loop.call_later(0.1, callback, 'world') @@ -253,12 +253,12 @@ def run(arg): res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') - def test_run_in_executor_with_handler(self): + def test_run_in_executor_with_handle(self): def run(arg): time.sleep(0.1) return arg - handler = events.Handler(run, ('yo',)) - f2 = self.event_loop.run_in_executor(None, handler) + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') @@ -286,7 +286,7 @@ def reader(): self.event_loop.run() self.assertEqual(b''.join(bytes_read), b'abcdef') - def test_reader_callback_with_handler(self): + def test_reader_callback_with_handle(self): r, w = test_utils.socketpair() bytes_read = [] @@ -303,8 +303,8 @@ def reader(): self.assertTrue(self.event_loop.remove_reader(r.fileno())) r.close() - handler = events.Handler(reader, ()) - self.assertIs(handler, self.event_loop.add_reader(r.fileno(), handler)) + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') @@ -324,11 +324,11 @@ def reader(): if data: bytes_read.append(data) if sum(len(b) for b in bytes_read) >= 6: - handler.cancel() + handle.cancel() if not data: r.close() - handler = self.event_loop.add_reader(r.fileno(), reader) + handle = self.event_loop.add_reader(r.fileno(), reader) self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') self.event_loop.call_later(0.15, w.close) @@ -350,11 +350,11 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def test_writer_callback_with_handler(self): + def test_writer_callback_with_handle(self): r, w = test_utils.socketpair() w.setblocking(False) - handler = events.Handler(w.send, (b'x'*(256*1024),)) - self.assertIs(self.event_loop.add_writer(w.fileno(), handler), handler) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) @@ -372,9 +372,9 @@ def test_writer_callback_cancel(self): def sender(): w.send(b'x'*256) - handler.cancel() + handle.cancel() - handler = self.event_loop.add_writer(w.fileno(), sender) + handle = self.event_loop.add_writer(w.fileno(), sender) self.event_loop.run() w.close() data = r.recv(1024) @@ -483,8 +483,8 @@ def my_handler(): nonlocal caught caught += 1 - handler = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) - handler.cancel() + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() os.kill(os.getpid(), signal.SIGINT) self.event_loop.run_once() self.assertEqual(caught, 0) @@ -947,13 +947,13 @@ def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") def test_reader_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") - def test_reader_callback_with_handler(self): + def test_reader_callback_with_handle(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") def test_writer_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def test_writer_callback_with_handler(self): + def test_writer_callback_with_handle(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_accept_connection_retry(self): raise unittest.SkipTest( @@ -1008,22 +1008,22 @@ def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.SelectSelector()) -class HandlerTests(unittest.TestCase): +class HandleTests(unittest.TestCase): - def test_handler(self): + def test_handle(self): def callback(*args): return args args = () - h = events.Handler(callback, args) + h = events.Handle(callback, args) self.assertIs(h.callback, callback) self.assertIs(h.args, args) self.assertFalse(h.cancelled) r = repr(h) self.assertTrue(r.startswith( - 'Handler(' - '.callback')) + 'Handle(' + '.callback')) self.assertTrue(r.endswith('())')) h.cancel() @@ -1031,19 +1031,19 @@ def callback(*args): r = repr(h) self.assertTrue(r.startswith( - 'Handler(' - '.callback')) + 'Handle(' + '.callback')) self.assertTrue(r.endswith('())')) - def test_make_handler(self): + def test_make_handle(self): def callback(*args): return args - h1 = events.Handler(callback, ()) - h2 = events.make_handler(h1, ()) + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) self.assertIs(h1, h2) self.assertRaises( - AssertionError, events.make_handler, h1, (1, 2)) + AssertionError, events.make_handle, h1, (1, 2)) class TimerTests(unittest.TestCase): @@ -1105,7 +1105,7 @@ def callback(*args): self.assertFalse(h1 == h2) self.assertTrue(h1 != h2) - h3 = events.Handler(callback, ()) + h3 = events.Handle(callback, ()) self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3)) diff --git a/tests/futures_test.py b/tests/futures_test.py index fd49493d..5569cca1 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -66,12 +66,14 @@ def test_exception(self): def test_yield_from_twice(self): f = futures.Future() + def fixture(): yield 'A' x = yield from f yield 'B', x y = yield from f yield 'C', y + g = fixture() self.assertEqual(next(g), 'A') # yield 'A'. self.assertEqual(next(g), f) # First yield from f. diff --git a/tests/locks_test.py b/tests/locks_test.py index 4267134c..20dc222b 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -245,7 +245,8 @@ def test_wait_timeout(self): ev = locks.EventWaiter() self.event_loop.call_later(0.1, ev.set) - acquired = self.event_loop.run_until_complete(tasks.Task(ev.wait(10.1))) + acquired = self.event_loop.run_until_complete( + tasks.Task(ev.wait(10.1))) self.assertTrue(acquired) def test_wait_timeout_mixed(self): @@ -411,8 +412,8 @@ def test_wait_unacquired(self): def test_wait_for(self): cond = locks.Condition() - presult = False + def predicate(): return presult @@ -487,8 +488,8 @@ def test_wait_for_unacquired(self): # predicate can return true immediately res = self.event_loop.run_until_complete(tasks.Task( - cond.wait_for(lambda: [1,2,3]))) - self.assertEqual([1,2,3], res) + cond.wait_for(lambda: [1, 2, 3]))) + self.assertEqual([1, 2, 3], res) self.assertRaises( RuntimeError, @@ -538,7 +539,7 @@ def c3(result): cond.notify(2048) cond.release() self.event_loop.run_once() - self.assertEqual([1,2,3], result) + self.assertEqual([1, 2, 3], result) def test_notify_all(self): cond = locks.Condition() @@ -569,7 +570,7 @@ def c2(result): cond.notify_all() cond.release() self.event_loop.run_once() - self.assertEqual([1,2], result) + self.assertEqual([1, 2], result) def test_notify_unacquired(self): cond = locks.Condition() @@ -629,12 +630,12 @@ def test_acquire(self): @tasks.coroutine def c1(result): - res = yield from sem.acquire() + yield from sem.acquire() result.append(1) @tasks.coroutine def c2(result): - res = yield from sem.acquire() + yield from sem.acquire() result.append(2) @tasks.coroutine diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 3ebaab8c..996c0130 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -3,7 +3,6 @@ import unittest import unittest.mock -from tulip import events from tulip import selectors @@ -43,7 +42,7 @@ def test_register_already_registered(self): fobj.fileno.return_value = 10 s = selectors._BaseSelector() - key = s.register(fobj, selectors.EVENT_READ) + s.register(fobj, selectors.EVENT_READ) self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) def test_unregister(self): @@ -51,7 +50,7 @@ def test_unregister(self): fobj.fileno.return_value = 10 s = selectors._BaseSelector() - key = s.register(fobj, selectors.EVENT_READ) + s.register(fobj, selectors.EVENT_READ) s.unregister(fobj) self.assertFalse(s._fd_to_key) self.assertFalse(s._fileobj_to_key) @@ -105,7 +104,7 @@ def test_select(self): def test_close(self): s = selectors._BaseSelector() - key = s.register(1, selectors.EVENT_READ) + s.register(1, selectors.EVENT_READ) s.close() self.assertFalse(s._fd_to_key) diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py index 14ce11d7..09aaed52 100644 --- a/tests/subprocess_test.py +++ b/tests/subprocess_test.py @@ -46,7 +46,7 @@ def tearDown(self): def test_unix_subprocess(self): p = MyProto() - t = subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) self.event_loop.run() diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 25ca5a4f..2a11c202 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -12,8 +12,10 @@ class Dummy: + def __repr__(self): return 'Dummy()' + def __call__(self, *args): pass @@ -92,14 +94,17 @@ def outer(): a = yield from inner1() b = yield from inner2() return a+b + @tasks.task def inner1(): yield from [] return 42 + @tasks.task def inner2(): yield from [] return 1000 + t = outer() self.assertEqual(self.event_loop.run_until_complete(t), 1042) @@ -133,6 +138,7 @@ def task(): def test_stop_while_run_in_complete(self): x = 0 + @tasks.coroutine def task(): nonlocal x @@ -184,12 +190,14 @@ def task(): def test_wait(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) + @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a]) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) return 42 + t0 = time.monotonic() res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() @@ -206,7 +214,7 @@ def test_wait_first_completed(self): a = tasks.sleep(10.0) b = tasks.sleep(0.1) task = tasks.Task(tasks.wait( - [b, a], return_when=tasks.FIRST_COMPLETED)) + [b, a], return_when=tasks.FIRST_COMPLETED)) done, pending = self.event_loop.run_until_complete(task) self.assertEqual({b}, done) @@ -220,13 +228,15 @@ def test_wait_really_done(self): @tasks.coroutine def coro1(): yield from [None] + @tasks.coroutine def coro2(): yield from [None, None] a = tasks.Task(coro1()) b = tasks.Task(coro2()) - task = tasks.Task(tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) done, pending = self.event_loop.run_until_complete(task) self.assertEqual({a, b}, done) @@ -235,6 +245,7 @@ def test_wait_first_exception(self): self.suppress_log_errors() a = tasks.sleep(10.0) + @tasks.coroutine def exc(): yield from [] @@ -242,7 +253,7 @@ def exc(): b = tasks.Task(exc()) task = tasks.Task(tasks.wait( - [b, a], return_when=tasks.FIRST_EXCEPTION)) + [b, a], return_when=tasks.FIRST_EXCEPTION)) done, pending = self.event_loop.run_until_complete(task) self.assertEqual({b}, done) @@ -251,11 +262,14 @@ def exc(): def test_wait_with_exception(self): self.suppress_log_errors() a = tasks.sleep(0.1) + @tasks.coroutine def sleeper(): yield from tasks.sleep(0.15) raise ZeroDivisionError('really') + b = tasks.Task(sleeper()) + @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a]) @@ -263,25 +277,28 @@ def foo(): self.assertEqual(pending, set()) errors = set(f for f in done if f.exception() is not None) self.assertEqual(len(errors), 1) + t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) def test_wait_with_timeout(self): a = tasks.sleep(0.1) b = tasks.sleep(0.15) + @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a], timeout=0.11) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) + t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) @@ -291,15 +308,18 @@ def test_as_completed(self): def sleeper(dt, x): yield from tasks.sleep(dt) return x + a = sleeper(0.1, 'a') b = sleeper(0.1, 'b') c = sleeper(0.15, 'c') + @tasks.coroutine def foo(): values = [] for f in tasks.as_completed([b, c, a]): values.append((yield from f)) return values + t0 = time.monotonic() res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() @@ -317,6 +337,7 @@ def test_as_completed_with_timeout(self): self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') + @tasks.coroutine def foo(): values = [] @@ -327,6 +348,7 @@ def foo(): except futures.TimeoutError as exc: values.append((2, exc)) return values + t0 = time.monotonic() res = self.event_loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() @@ -342,6 +364,7 @@ def sleeper(dt, arg): yield from tasks.sleep(dt/2) res = yield from tasks.sleep(dt/2, arg) return res + t = tasks.Task(sleeper(0.1, 'yeah')) t0 = time.monotonic() self.event_loop.run() @@ -352,27 +375,30 @@ def sleeper(dt, arg): def test_task_cancel_sleeping_task(self): sleepfut = None + @tasks.task def sleep(dt): nonlocal sleepfut sleepfut = tasks.sleep(dt) try: - t0 = time.monotonic() + time.monotonic() yield from sleepfut finally: - t1 = time.monotonic() + time.monotonic() + @tasks.task def doit(): sleeper = sleep(5000) self.event_loop.call_later(0.1, sleeper.cancel) try: - t0 = time.monotonic() + time.monotonic() yield from sleeper except futures.CancelledError: - t1 = time.monotonic() + time.monotonic() return 'cancelled' else: return 'slept in' + t0 = time.monotonic() doer = doit() self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') @@ -420,6 +446,7 @@ class Fut(futures.Future): def __init__(self, *args): self.cb_added = False super().__init__(*args) + def add_done_callback(self, fn): self.cb_added = True super().add_done_callback(fn) @@ -432,7 +459,7 @@ def wait_for_future(): nonlocal result result = yield from fut - task = wait_for_future() + wait_for_future() self.event_loop.run_once() self.assertTrue(fut.cb_added) @@ -449,6 +476,7 @@ class Fut(concurrent.futures.Future): def __init__(self): self.cb_added = False super().__init__() + def add_done_callback(self, fn): self.cb_added = True super().add_done_callback(fn) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 2504648b..24ea4945 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -52,7 +52,7 @@ def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) - self.assertIsInstance(h, events.Handler) + self.assertIsInstance(h, events.Handle) @unittest.mock.patch('tulip.unix_events.signal') def test_add_signal_handler_install_error(self, m_signal): @@ -108,7 +108,7 @@ class Err(OSError): def test_remove_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG - h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) self.assertTrue( self.event_loop.remove_signal_handler(signal.SIGHUP)) @@ -122,7 +122,7 @@ def test_remove_signal_handler_2(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.SIGINT = signal.SIGINT - h = self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) self.event_loop._signal_handlers[signal.SIGHUP] = object() m_signal.set_wakeup_fd.reset_mock() @@ -138,7 +138,7 @@ def test_remove_signal_handler_2(self, m_signal): @unittest.mock.patch('tulip.unix_events.logging') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG - h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) m_signal.set_wakeup_fd.side_effect = ValueError @@ -148,7 +148,7 @@ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): @unittest.mock.patch('tulip.unix_events.signal') def test_remove_signal_handler_error(self, m_signal): m_signal.NSIG = signal.NSIG - h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) m_signal.signal.side_effect = OSError @@ -158,7 +158,7 @@ def test_remove_signal_handler_error(self, m_signal): @unittest.mock.patch('tulip.unix_events.signal') def test_remove_signal_handler_error2(self, m_signal): m_signal.NSIG = signal.NSIG - h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) class Err(OSError): errno = errno.EINVAL diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py index 0175e9b9..b93bdaf1 100644 --- a/tests/winsocketpair_test.py +++ b/tests/winsocketpair_test.py @@ -22,6 +22,7 @@ def test_winsocketpair(self): @unittest.mock.patch('tulip.winsocketpair.socket') def test_winsocketpair_exc(self, m_socket): m_socket.error = socket.error + class Err(socket.error): errno = errno.WSAEWOULDBLOCK + 1 diff --git a/tulip/base_events.py b/tulip/base_events.py index 9e2e161c..c2cf18a5 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -48,7 +48,6 @@ def __init__(self): self._scheduled = [] self._default_executor = None self._internal_fds = 0 - self._signal_handlers = {} def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): """Create socket transport.""" @@ -99,11 +98,11 @@ def run_forever(self): This only makes sense over run() if you have another thread scheduling callbacks using call_soon_threadsafe(). """ - handler = self.call_repeatedly(24*3600, lambda: None) + handle = self.call_repeatedly(24*3600, lambda: None) try: self.run() finally: - handler.cancel() + handle.cancel() def run_once(self, timeout=None): """Run through all callbacks and all I/O polls once. @@ -121,21 +120,23 @@ def run_until_complete(self, future, timeout=None): Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ - handler_called = False + handle_called = False + def stop_loop(): - nonlocal handler_called - handler_called = True + nonlocal handle_called + handle_called = True raise _StopError + future.add_done_callback(_raise_stop_error) if timeout is None: self.run_forever() else: - handler = self.call_later(timeout, stop_loop) + handle = self.call_later(timeout, stop_loop) self.run() - handler.cancel() + handle.cancel() - if handler_called: + if handle_called: raise futures.TimeoutError return future.result() @@ -175,19 +176,21 @@ def call_later(self, delay, callback, *args): """ if delay <= 0: return self.call_soon(callback, *args) - handler = events.Timer(time.monotonic() + delay, callback, args) - heapq.heappush(self._scheduled, handler) - return handler + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle def call_repeatedly(self, interval, callback, *args): """Call a callback every 'interval' seconds.""" def wrapper(): callback(*args) # If this fails, the chain is broken. - handler._when = time.monotonic() + interval - heapq.heappush(self._scheduled, handler) - handler = events.Timer(time.monotonic() + interval, wrapper, ()) - heapq.heappush(self._scheduled, handler) - return handler + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. @@ -199,18 +202,18 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - handler = events.make_handler(callback, args) - self._ready.append(handler) - return handler + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle def call_soon_threadsafe(self, callback, *args): """XXX""" - handler = self.call_soon(callback, *args) + handle = self.call_soon(callback, *args) self._write_to_self() - return handler + return handle def run_in_executor(self, executor, callback, *args): - if isinstance(callback, events.Handler): + if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.Timer) if callback.cancelled: @@ -414,14 +417,14 @@ def start_serving_datagram(self, protocol_factory, host, port, *, return sock - def _add_callback(self, handler): - """Add a Handler to ready or scheduled.""" - if handler.cancelled: + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: return - if isinstance(handler, events.Timer): - heapq.heappush(self._scheduled, handler) + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) else: - self._ready.append(handler) + self._ready.append(handle) def wrap_future(self, future): """XXX""" @@ -479,11 +482,11 @@ def _run_once(self, timeout=None): # Handle 'later' callbacks that are ready. now = time.monotonic() while self._scheduled: - handler = self._scheduled[0] - if handler.when > now: + handle = self._scheduled[0] + if handle.when > now: break - handler = heapq.heappop(self._scheduled) - self._ready.append(handler) + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) # This is the only place where callbacks are actually *called*. # All other places just add them to ready. @@ -493,6 +496,6 @@ def _run_once(self, timeout=None): # Use an idiom that is threadsafe without using locks. ntodo = len(self._ready) for i in range(ntodo): - handler = self._ready.popleft() - if not handler.cancelled: - handler.run() + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py index a5c2dd4c..9bad35fb 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,7 @@ """ __all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', - 'AbstractEventLoop', 'Timer', 'Handler', 'make_handler', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -15,7 +15,7 @@ import threading -class Handler: +class Handle: """Object returned by callback registration methods.""" def __init__(self, callback, args): @@ -24,7 +24,7 @@ def __init__(self, callback, args): self._cancelled = False def __repr__(self): - res = 'Handler({}, {})'.format(self._callback, self._args) + res = 'Handle({}, {})'.format(self._callback, self._args) if self._cancelled: res += '' return res @@ -52,19 +52,20 @@ def run(self): self._callback, self._args) -def make_handler(callback, args): - if isinstance(callback, Handler): +def make_handle(callback, args): + if isinstance(callback, Handle): assert not args return callback - return Handler(callback, args) + return Handle(callback, args) -class Timer(Handler): +class Timer(Handle): """Object returned by timed callback registration methods.""" def __init__(self, when, callback, args): assert when is not None super().__init__(callback, args) + self._when = when def __repr__(self): @@ -73,6 +74,7 @@ def __repr__(self): self._args) if self._cancelled: res += '' + return res @property @@ -144,7 +146,7 @@ def stop(self): # NEW! """ raise NotImplementedError - # Methods returning Handlers for scheduling callbacks. + # Methods returning Handles for scheduling callbacks. def call_later(self, delay, callback, *args): raise NotImplementedError @@ -192,7 +194,7 @@ def start_serving_datagram(self, protocol_factory, host, port, *, raise NotImplementedError # Ready-based callback registration methods. - # The add_*() methods return a Handler. + # The add_*() methods return a Handle. # The remove_*() methods return True if something was removed, # False if there was nothing to delete. diff --git a/tulip/futures.py b/tulip/futures.py index 4bb2f198..68735f35 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -52,7 +52,7 @@ class Future: _result = None _exception = None - _blocking = False # proper use of future (yield vs yield from) + _blocking = False # proper use of future (yield vs yield from) def __init__(self, *, event_loop=None): """Initialize the future. @@ -83,7 +83,7 @@ def __repr__(self): else: res += '<{}, {}>'.format(self._state, self._callbacks) else: - res +='<{}>'.format(self._state) + res += '<{}>'.format(self._state) return res def cancel(self): diff --git a/tulip/locks.py b/tulip/locks.py index e55487a6..c86048f4 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -69,7 +69,8 @@ def __init__(self): def __repr__(self): res = super().__repr__() - return '<%s [%s]>'%(res[1:-1], 'locked' if self._locked else 'unlocked') + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') def locked(self): """Return true if lock is acquired.""" @@ -95,9 +96,9 @@ def acquire(self, timeout=None): fut = futures.Future(event_loop=self._event_loop) if timeout is not None: - handler = self._event_loop.call_later(timeout, fut.cancel) + handle = self._event_loop.call_later(timeout, fut.cancel) else: - handler = None + handle = None self._waiters.append(fut) try: @@ -109,8 +110,8 @@ def acquire(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handler is not None: - handler.cancel() + if handle is not None: + handle.cancel() self._locked = True return True @@ -163,7 +164,7 @@ def __init__(self): def __repr__(self): res = super().__repr__() - return '<%s [%s]>'%(res[1:-1], 'set' if self._value else 'unset') + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') def is_set(self): """Return true if and only if the internal flag is true.""" @@ -210,9 +211,9 @@ def wait(self, timeout=None): fut = futures.Future(event_loop=self._event_loop) if timeout is not None: - handler = self._event_loop.call_later(timeout, fut.cancel) + handle = self._event_loop.call_later(timeout, fut.cancel) else: - handler = None + handle = None self._waiters.append(fut) try: @@ -224,8 +225,8 @@ def wait(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handler is not None: - handler.cancel() + if handle is not None: + handle.cancel() return True @@ -268,9 +269,9 @@ def wait(self, timeout=None): fut = futures.Future(event_loop=self._event_loop) if timeout is not None: - handler = self._event_loop.call_later(timeout, fut.cancel) + handle = self._event_loop.call_later(timeout, fut.cancel) else: - handler = None + handle = None self._condition_waiters.append(fut) try: @@ -284,8 +285,8 @@ def wait(self, timeout=None): finally: yield from self.acquire() - if handler is not None: - handler.cancel() + if handle is not None: + handle.cancel() return True @@ -378,9 +379,9 @@ def __init__(self, value=1, bound=False): def __repr__(self): res = super().__repr__() - return '<%s [%s]>'%( + return '<%s [%s]>' % ( res[1:-1], - 'locked' if self._locked else 'unlocked,value:%s'%self._value) + 'locked' if self._locked else 'unlocked,value:%s' % self._value) def locked(self): """Returns True if semaphore can not be acquired immediately.""" @@ -407,9 +408,9 @@ def acquire(self, timeout=None): fut = futures.Future(event_loop=self._event_loop) if timeout is not None: - handler = self._event_loop.call_later(timeout, fut.cancel) + handle = self._event_loop.call_later(timeout, fut.cancel) else: - handler = None + handle = None self._waiters.append(fut) try: @@ -421,8 +422,8 @@ def acquire(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handler is not None: - handler.cancel() + if handle is not None: + handle.cancel() self._value -= 1 if self._value == 0: diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 5feaafa4..cc7fe33c 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -127,20 +127,20 @@ def _accept_connection(self, protocol_factory, sock): # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): - """Add a reader callback. Return a Handler instance.""" - handler = events.make_handler(callback, args) + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) except KeyError: self._selector.register(fd, selectors.EVENT_READ, - (handler, None)) + (handle, None)) else: self._selector.modify(fd, mask | selectors.EVENT_READ, - (handler, writer)) + (handle, writer)) if reader is not None: reader.cancel() - return handler + return handle def remove_reader(self, fd): """Remove a reader callback.""" @@ -162,20 +162,20 @@ def remove_reader(self, fd): return False def add_writer(self, fd, callback, *args): - """Add a writer callback. Return a Handler instance.""" - handler = events.make_handler(callback, args) + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) except KeyError: self._selector.register(fd, selectors.EVENT_WRITE, - (None, handler)) + (None, handle)) else: self._selector.modify(fd, mask | selectors.EVENT_WRITE, - (reader, handler)) + (reader, handle)) if writer is not None: writer.cancel() - return handler + return handle def remove_writer(self, fd): """Remove a writer callback.""" diff --git a/tulip/selectors.py b/tulip/selectors.py index b8b830eb..e8fd12e9 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -11,7 +11,7 @@ # generic events, that must be mapped to implementation-specific ones # read event -EVENT_READ = (1 << 0) +EVENT_READ = (1 << 0) # write event EVENT_WRITE = (1 << 1) @@ -361,7 +361,6 @@ def __init__(self): def unregister(self, fileobj): key = super().unregister(fileobj) - mask = 0 if key.events & EVENT_READ: kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index 721013f8..45e6f6fe 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -1,4 +1,5 @@ import fcntl +import logging import os import traceback diff --git a/tulip/tasks.py b/tulip/tasks.py index 2e6b73f3..08bbb31c 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -203,12 +203,12 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): return done, pending bail = futures.Future() # Will always be cancelled eventually. - timeout_handler = None + timeout_handle = None debugstuff = locals() if timeout is not None: loop = events.get_event_loop() - timeout_handler = loop.call_later(timeout, bail.cancel) + timeout_handle = loop.call_later(timeout, bail.cancel) def _on_completion(f): pending.remove(f) @@ -230,8 +230,8 @@ def _on_completion(f): finally: for f in pending: f.remove_done_callback(_on_completion) - if timeout_handler is not None: - timeout_handler.cancel() + if timeout_handle is not None: + timeout_handle.cancel() really_done = set(f for f in pending if f.done()) if really_done: @@ -259,11 +259,14 @@ def as_completed(fs, timeout=None): deadline = None if timeout is not None: deadline = time.monotonic() + timeout + done = None # Make nonlocal happy. fs = _wrap_coroutines(fs) + while fs: if deadline is not None: timeout = deadline - time.monotonic() + @coroutine def _wait_for_some(): nonlocal done, fs @@ -273,6 +276,7 @@ def _wait_for_some(): fs = set() raise futures.TimeoutError() return done.pop().result() # May raise. + yield Task(_wait_for_some()) for f in done: yield f diff --git a/tulip/unix_events.py b/tulip/unix_events.py index db2c560d..41f8e0b4 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -26,6 +26,11 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop): Adds signal handling to SelectorEventLoop """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + def _socketpair(self): return socket.socketpair() @@ -44,8 +49,10 @@ def add_signal_handler(self, sig, callback, *args): signal.set_wakeup_fd(self._csock.fileno()) except ValueError as exc: raise RuntimeError(str(exc)) - handler = events.make_handler(callback, args) - self._signal_handlers[sig] = handler + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + try: signal.signal(sig, self._handle_signal) except OSError as exc: @@ -55,21 +62,23 @@ def add_signal_handler(self, sig, callback, *args): signal.set_wakeup_fd(-1) except ValueError as nexc: logging.info('set_wakeup_fd(-1) failed: %s', nexc) + if exc.errno == errno.EINVAL: raise RuntimeError('sig {} cannot be caught'.format(sig)) else: raise - return handler + + return handle def _handle_signal(self, sig, arg): """Internal helper that is the actual signal handler.""" - handler = self._signal_handlers.get(sig) - if handler is None: + handle = self._signal_handlers.get(sig) + if handle is None: return # Assume it's some race condition. - if handler.cancelled: + if handle.cancelled: self.remove_signal_handler(sig) # Remove it properly. else: - self.call_soon_threadsafe(handler) + self.call_soon_threadsafe(handle) def remove_signal_handler(self, sig): """Remove a handler for a signal. UNIX only. @@ -80,10 +89,12 @@ def remove_signal_handler(self, sig): del self._signal_handlers[sig] except KeyError: return False + if sig == signal.SIGINT: handler = signal.default_int_handler else: handler = signal.SIG_DFL + try: signal.signal(sig, handler) except OSError as exc: @@ -91,11 +102,13 @@ def remove_signal_handler(self, sig): raise RuntimeError('sig {} cannot be caught'.format(sig)) else: raise + if not self._signal_handlers: try: signal.set_wakeup_fd(-1) except ValueError as exc: logging.info('set_wakeup_fd(-1) failed: %s', exc) + return True def _check_signal(self, sig): @@ -106,8 +119,10 @@ def _check_signal(self, sig): """ if not isinstance(sig, int): raise TypeError('sig must be an int, not {!r}'.format(sig)) + if signal is None: raise RuntimeError('Signals are not supported') + if not (1 <= sig < signal.NSIG): - raise ValueError('sig {} out of range(1, {})'.format(sig, - signal.NSIG)) + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 3a0b8675..4297f804 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -1,17 +1,13 @@ -"""Selector and proactor eventloops for Windows. -""" +"""Selector and proactor eventloops for Windows.""" -import errno import logging import socket import weakref import struct import _winapi - from . import futures from . import proactor_events -from . import selectors from . import selector_events from . import winsocketpair from . import _overlapped @@ -77,6 +73,7 @@ def accept(self, listener): conn = self._get_accept_socket() ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(): addr = ov.getresult() buf = struct.pack('@P', listener.fileno()) @@ -85,6 +82,7 @@ def finish_accept(): buf) conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) def connect(self, conn, address): @@ -92,12 +90,14 @@ def connect(self, conn, address): _overlapped.BindLocal(conn.fileno(), len(address)) ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) + def finish_connect(): ov.getresult() conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) return conn + return self._register(ov, conn, finish_connect) def _register_with_iocp(self, obj): From cbc3f91f48d0b37ae1d0dd249ce63b01e0e2456b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Mar 2013 11:49:53 -0700 Subject: [PATCH 0355/1502] More notes from the sprint. --- NOTES | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/NOTES b/NOTES index 1121cca9..4b710d15 100644 --- a/NOTES +++ b/NOTES @@ -10,6 +10,37 @@ Notes from PyCon 2013 sprints points to lottery scheduling but also mentions that's just one of the options. +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up UTP, a reimplementation of TCP over UDP with more refined + congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + Notes from the second Tulip/Twisted meet-up =========================================== From a80e95f160226fb85fb30022f4f0a0694ee24ed9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Mar 2013 15:54:50 -0700 Subject: [PATCH 0356/1502] It is uTP, not UTP. --- NOTES | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/NOTES b/NOTES index 4b710d15..c8f274de 100644 --- a/NOTES +++ b/NOTES @@ -29,8 +29,8 @@ Notes from PyCon 2013 sprints but if someone thinks it's interesting we could imagine having some kind of notion of context part of the event loop state, e.g. associated with a Task (see Cancellation point above). He - brought up UTP, a reimplementation of TCP over UDP with more refined - congestion control. + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. - Mumblings about UNIX domain sockets and IPv6 addresses being 4-tuples. The former can be handled by passing in a socket. There From 2022ede504248d70ecc8088de2bf36d0afcbdff1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Mar 2013 17:30:07 -0700 Subject: [PATCH 0357/1502] Suppress tracebacks from test server. --- tests/events_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 9fe2d82f..20c398de 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -99,7 +99,11 @@ def get_stderr(self): def log_message(self, format, *args): pass - class SSLWSGIServer(WSGIServer): + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): def finish_request(self, request, client_address): here = os.path.dirname(__file__) keyfile = os.path.join(here, 'sample.key') @@ -123,7 +127,7 @@ def app(environ, start_response): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else WSGIServer + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer httpd = make_server('127.0.0.1', 0, app, server_class, SilentWSGIRequestHandler) server_thread = threading.Thread(target=httpd.serve_forever) From 20b16180ebe83a003ac859173b3c50d4b4c2bc1c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 19 Mar 2013 20:32:31 -0700 Subject: [PATCH 0358/1502] Add venv and distribute to ignore list --- .hgeol | 2 + .hgignore | 11 + Makefile | 29 + NOTES | 174 ++++ README | 21 + TODO | 163 ++++ check.py | 41 + crawl.py | 143 ++++ curl.py | 35 + examples/udp_echo.py | 71 ++ old/Makefile | 16 + old/echoclt.py | 79 ++ old/echosvr.py | 60 ++ old/http_client.py | 78 ++ old/http_server.py | 68 ++ old/main.py | 134 ++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++ old/scheduling.py | 354 ++++++++ old/sockets.py | 348 ++++++++ old/transports.py | 496 ++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++ overlapped.c | 997 +++++++++++++++++++++++ runtests.py | 198 +++++ setup.cfg | 2 + setup.py | 4 + srv.py | 123 +++ sslsrv.py | 56 ++ tests/base_events_test.py | 271 +++++++ tests/events_test.py | 1257 +++++++++++++++++++++++++++++ tests/futures_test.py | 222 +++++ tests/http_protocol_test.py | 535 +++++++++++++ tests/locks_test.py | 753 +++++++++++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1421 +++++++++++++++++++++++++++++++++ tests/selectors_test.py | 137 ++++ tests/streams_test.py | 315 ++++++++ tests/subprocess_test.py | 54 ++ tests/tasks_test.py | 584 ++++++++++++++ tests/transports_test.py | 39 + tests/unix_events_test.py | 168 ++++ tests/winsocketpair_test.py | 33 + tulip/TODO | 28 + tulip/__init__.py | 26 + tulip/base_events.py | 501 ++++++++++++ tulip/events.py | 330 ++++++++ tulip/futures.py | 244 ++++++ tulip/http/__init__.py | 8 + tulip/http/client.py | 154 ++++ tulip/http/protocol.py | 508 ++++++++++++ tulip/locks.py | 461 +++++++++++ tulip/proactor_events.py | 190 +++++ tulip/protocols.py | 74 ++ tulip/selector_events.py | 661 +++++++++++++++ tulip/selectors.py | 418 ++++++++++ tulip/streams.py | 145 ++++ tulip/subprocess_transport.py | 134 ++++ tulip/tasks.py | 308 +++++++ tulip/test_utils.py | 30 + tulip/transports.py | 99 +++ tulip/unix_events.py | 128 +++ tulip/windows_events.py | 157 ++++ tulip/winsocketpair.py | 34 + 65 files changed, 14834 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 srv.py create mode 100644 sslsrv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/locks.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..25902497 --- /dev/null +++ b/.hgignore @@ -0,0 +1,11 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..2391d89c --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..c8f274de --- /dev/null +++ b/NOTES @@ -0,0 +1,174 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..64bc2cdd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..4e5bebe2 --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..0624df86 --- /dev/null +++ b/curl.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(tulip.Task(stream)) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..c92cb06d --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,71 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + loop.start_serving_datagram(MyUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + loop.create_datagram_connection(MyClientUdpEchoProtocol, *ADDRESS) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..e4678d70 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err)) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..67b037cc --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from distutils.core import setup, Extension + +ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) +setup(name='_overlapped', ext_modules=[ext]) diff --git a/srv.py b/srv.py new file mode 100644 index 00000000..540e63b9 --- /dev/null +++ b/srv.py @@ -0,0 +1,123 @@ +"""Simple server written using an event loop.""" + +import http.client +import email.message +import email.parser +import os + +import tulip +import tulip.http + + +class HttpServer(tulip.Protocol): + + def __init__(self): + super().__init__() + self.transport = None + self.reader = None + self.handler = None + + @tulip.task + def handle_request(self): + try: + method, path, version = yield from self.reader.read_request_line() + except http.client.BadStatusLine: + self.transport.close() + return + + print('method = {!r}; path = {!r}; version = {!r}'.format( + method, path, version)) + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') + self.transport.close() + return + + message = yield from self.reader.read_message() + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + write = self.transport.write + if isdir and not path.endswith('/'): + bpath = path.encode('ascii') + write(b'HTTP/1.0 302 Redirected\r\n' + b'URI: ' + bpath + b'/\r\n' + b'Location: ' + bpath + b'/\r\n' + b'\r\n') + return + write(b'HTTP/1.0 200 Ok\r\n') + if isdir: + write(b'Content-type: text/html\r\n') + else: + write(b'Content-type: text/plain\r\n') + write(b'\r\n') + if isdir: + write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + write(b'
  • ' + bname + b'/
  • \r\n') + else: + write(b'
  • ' + bname + b'
  • \r\n') + write(b'
') + else: + try: + with open(path, 'rb') as f: + write(f.read()) + except OSError: + write(b'Cannot open\r\n') + self.transport.close() + + def connection_made(self, transport): + self.transport = transport + print('connection made', transport, transport.get_extra_info('socket')) + self.reader = tulip.http.HttpStreamReader() + self.handler = self.handle_request() + + def data_received(self, data): + print('data received', data) + self.reader.feed_data(data) + + def eof_received(self): + print('eof received') + self.reader.feed_eof() + + def connection_lost(self, exc): + print('connection lost', exc) + if (self.handler.done() and + not self.handler.cancelled() and + self.handler.exception() is not None): + print('handler exception:', self.handler.exception()) + + +def main(): + loop = tulip.get_event_loop() + f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..03c3296b --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,271 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + """If event loop has ready callbacks, select timeout is always 0.""" + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.logging') + def test__run_once_logging(self, m_logging, m_time): + """Log to INFO level if timeout > 1.0 sec.""" + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..20c398de --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1257 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import errno +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.1) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = self.event_loop.create_connection(MyProto, host, port) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = self.event_loop.create_connection(MyProto, host, port) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_create_datagram_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection( + MyDatagramProto, host, port) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx') + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_connection(self): + server = None + + def factory(): + nonlocal server + server = TestMyDatagramProto() + return server + + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + f = self.event_loop.create_datagram_connection(MyDatagramProto) + transport, protocol = self.event_loop.run_until_complete(f) + + self.assertEqual('INITIALIZED', protocol.state) + transport.sendto(b'xxx', (host, port)) + self.event_loop.run_once() + self.assertEqual(0, server.nbytes) + self.event_loop.run_once() + self.assertEqual(3, server.nbytes) + self.event_loop.run_once() + + # received + self.event_loop.run_once() + self.assertEqual(8, protocol.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', protocol.state) + + server.transport.close() + + def test_create_datagram_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + def test_create_datagram_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_connection_sockopt_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setsockopt.side_effect = socket.error + + fut = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, fut) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_start_serving_datagram(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + proto = None + + def factory(): + nonlocal proto + proto = TestMyDatagramProto() + return proto + + f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) + sock = self.event_loop.run_until_complete(f) + self.assertEqual('INITIALIZED', proto.state) + + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + client.sendto(b'xxx', ('127.0.0.1', port)) + self.event_loop.run_once() + self.assertEqual(0, proto.nbytes) + self.event_loop.run_once() + self.assertEqual(3, proto.nbytes) + + data, server = client.recvfrom(4096) + self.assertEqual(b'resp:xxx', data) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + client.close() + + def test_start_serving_datagram_no_getaddrinfoc(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_datagram_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving_datagram( + MyDatagramProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + class Err(socket.error): + errno = errno.EAGAIN + + sock = unittest.mock.Mock() + sock.accept.side_effect = Err + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = socket.error + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_connection()") + def test_create_datagram_connection_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have " + "create_datagram_connection_no_connection()") + def test_start_serving_datagram(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_no_getaddrinfoc(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_datagram()") + def test_start_serving_datagram_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have start_serving_udp()") + +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving_datagram, + f, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..5569cca1 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..c0952287 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,535 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line()))) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_request_line())) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_request_line())) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + tulip.Task(self.stream.read_response_status())) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_response_status())) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + tulip.Task(self.stream.read_headers())) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete( + tulip.Task(self.stream.read_headers())) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_payload_unknown_encoding(self): + self.assertRaises( + ValueError, self.stream.read_length_payload, encoding='unknown') + + def test_read_payload(self): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + stream = self.stream.read_length_payload(4) + self.assertIsInstance(stream, tulip.StreamReader) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_payload_eof(self): + self.stream.feed_data(b'da') + self.stream.feed_eof() + stream = self.stream.read_length_payload(4) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_payload_eof_exc(self): + self.stream.feed_data(b'da') + stream = self.stream.read_length_payload(4) + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._reader) + + def test_read_payload_deflate(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + stream = self.stream.read_length_payload(len(data), encoding='deflate') + + self.stream.feed_data(data) + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + + def _test_read_payload_compress_error(self): + data = b'123123123datadatadata' + reader = protocol.length_reader(4) + self.stream.feed_data(data) + stream = self.stream.read_payload(reader, 'deflate') + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload(self): + stream = self.stream.read_chunked_payload() + self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_chunks(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_incomplete(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_chunked_payload_extension(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'dataline', data) + + def test_read_chunked_payload_size_error(self): + stream = self.stream.read_chunked_payload() + + self.stream.feed_data(b'blah\r\n') + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_length_payload(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'datadata', data) + + def test_read_length_payload_zero(self): + stream = self.stream.read_length_payload(0) + + self.stream.feed_data(b'data') + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'', data) + + def test_read_length_payload_incomplete(self): + stream = self.stream.read_length_payload(8) + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(stream.read())) + + def test_read_eof_payload(self): + stream = self.stream.read_eof_payload() + + self.stream.feed_data(b'data') + self.stream.feed_eof() + + data = self.loop.run_until_complete(tulip.Task(stream.read())) + self.assertEqual(b'data', data) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 1)))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(version=(1, 0)))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + tulip.Task(self.stream.read_message())) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=False))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(compression=False))) + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'dataline', payload) + + +class HttpStreamWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = protocol.HttpStreamWriter(self.transport) + + def test_ctor(self): + transport = unittest.mock.Mock() + writer = protocol.HttpStreamWriter(transport, 'latin-1') + self.assertIs(writer.transport, transport) + self.assertEqual(writer.encoding, 'latin-1') + + def test_encode(self): + self.assertEqual(b'test', self.writer.encode('test')) + self.assertEqual(b'test', self.writer.encode(b'test')) + + def test_decode(self): + self.assertEqual('test', self.writer.decode('test')) + self.assertEqual('test', self.writer.decode(b'test')) + + def test_write(self): + self.writer.write(b'test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_str(self): + self.writer.write_str('test') + self.assertTrue(self.transport.write.called) + self.assertEqual((b'test',), self.transport.write.call_args[0]) + + def test_write_cunked(self): + self.writer.write_chunked('') + self.assertFalse(self.transport.write.called) + + self.writer.write_chunked('data') + self.assertEqual( + [(b'4\r\n',), (b'data',), (b'\r\n',)], + [c[0] for c in self.transport.write.call_args_list]) + + def test_write_eof(self): + self.writer.write_chunked_eof() + self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..20dc222b --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,753 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete( + tasks.Task(lock.acquire()) + )) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(timeout=0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(lock.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(tasks.Task(ev.wait())) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(ev.wait(0.1))) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete( + tasks.Task(ev.wait(10.1))) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(cond.acquire()))) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(tasks.Task(cond.wait(0.1))) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, tasks.Task(cond.wait())) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete(tasks.Task( + cond.wait_for(lambda: [1, 2, 3]))) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + tasks.Task(cond.wait_for(lambda: False))) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertTrue( + self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(0.1))) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete( + tasks.Task(sem.acquire(10.1))) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 00000000..6a1e3f3c --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 00000000..edfea8dc --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py new file mode 100644 index 00000000..b8fb10b0 --- /dev/null +++ b/tests/selector_events_test.py @@ -0,0 +1,1421 @@ +"""Tests for selector_events.py""" + +import errno +import socket +import unittest +import unittest.mock +try: + import ssl +except ImportError: + ssl = None + +from tulip import futures +from tulip import selectors +from tulip.selector_events import BaseSelectorEventLoop +from tulip.selector_events import _SelectorSslTransport +from tulip.selector_events import _SelectorSocketTransport +from tulip.selector_events import _SelectorDatagramTransport + + +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + + def _make_self_pipe(self): + self._ssock = unittest.mock.Mock() + self._csock = unittest.mock.Mock() + self._internal_fds += 1 + + +class BaseSelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + + def test_make_socket_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_socket_transport(m, m), + _SelectorSocketTransport) + + def test_make_ssl_transport(self): + m = unittest.mock.Mock() + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + self.assertIsInstance( + self.event_loop._make_ssl_transport(m, m, m, m), + _SelectorSslTransport) + + def test_close(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = selector = unittest.mock.Mock() + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertIsNone(self.event_loop._csock) + self.assertIsNone(self.event_loop._ssock) + self.assertTrue(selector.close.called) + self.assertTrue(ssock.close.called) + self.assertTrue(csock.close.called) + self.assertTrue(remove_reader.called) + + self.event_loop.close() + self.event_loop.close() + + def test_close_no_selector(self): + ssock = self.event_loop._ssock + csock = self.event_loop._csock + remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() + + self.event_loop._selector.close() + self.event_loop._selector = None + self.event_loop.close() + self.assertIsNone(self.event_loop._selector) + self.assertFalse(ssock.close.called) + self.assertFalse(csock.close.called) + self.assertFalse(remove_reader.called) + + def test_socketpair(self): + self.assertRaises(NotImplementedError, self.event_loop._socketpair) + + def test_read_from_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._ssock.recv.side_effect = Err + self.assertIsNone(self.event_loop._read_from_self()) + + def test_read_from_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._ssock.recv.side_effect = Err + self.assertRaises(Err, self.event_loop._read_from_self) + + def test_write_to_self_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.event_loop._csock.send.side_effect = Err + self.assertIsNone(self.event_loop._write_to_self()) + + def test_write_to_self_exception(self): + class Err(socket.error): + pass + + self.event_loop._csock.send.side_effect = Err + self.assertRaises(Err, self.event_loop._write_to_self) + + def test_sock_recv(self): + sock = unittest.mock.Mock() + self.event_loop._sock_recv = unittest.mock.Mock() + + f = self.event_loop.sock_recv(sock, 1024) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, 1024), + self.event_loop._sock_recv.call_args[0]) + + def test__sock_recv_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertFalse(sock.recv.called) + + def test__sock_recv_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_recv_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024), + self.event_loop.add_reader.call_args[0]) + + def test__sock_recv_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.recv.side_effect = Err + + self.event_loop._sock_recv(f, False, sock, 1024) + self.assertIsInstance(f.exception(), Err) + + def test_sock_sendall(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'data') + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, b'data'), + self.event_loop._sock_sendall.call_args[0]) + + def test_sock_sendall_nodata(self): + sock = unittest.mock.Mock() + self.event_loop._sock_sendall = unittest.mock.Mock() + + f = self.event_loop.sock_sendall(sock, b'') + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertFalse(self.event_loop._sock_sendall.called) + + def test__sock_sendall_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(sock.send.called) + + def test__sock_sendall_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_sendall_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = Err + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertIsInstance(f.exception(), Err) + + def test__sock_sendall(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 4 + + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertTrue(f.done()) + + def test__sock_sendall_partial(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 2 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'ta'), + self.event_loop.add_writer.call_args[0]) + + def test__sock_sendall_none(self): + sock = unittest.mock.Mock() + + f = futures.Future() + sock.fileno.return_value = 10 + sock.send.return_value = 0 + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop._sock_sendall(f, False, sock, b'data') + self.assertFalse(f.done()) + self.assertEqual( + (10, self.event_loop._sock_sendall, f, True, sock, b'data'), + self.event_loop.add_writer.call_args[0]) + + def test_sock_connect(self): + sock = unittest.mock.Mock() + self.event_loop._sock_connect = unittest.mock.Mock() + + f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock, ('127.0.0.1', 8080)), + self.event_loop._sock_connect.call_args[0]) + + def test__sock_connect(self): + f = futures.Future() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertTrue(f.done()) + self.assertTrue(sock.connect.called) + + def test__sock_connect_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.connect.called) + + def test__sock_connect_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + + def test__sock_connect_tryagain(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.EAGAIN + + self.event_loop.add_writer = unittest.mock.Mock() + self.event_loop.remove_writer = unittest.mock.Mock() + + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual( + (10, self.event_loop._sock_connect, f, + True, sock, ('127.0.0.1', 8080)), + self.event_loop.add_writer.call_args[0]) + + def test__sock_connect_exception(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.getsockopt.return_value = errno.ENOTCONN + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertIsInstance(f.exception(), socket.error) + + def test_sock_accept(self): + sock = unittest.mock.Mock() + self.event_loop._sock_accept = unittest.mock.Mock() + + f = self.event_loop.sock_accept(sock) + self.assertIsInstance(f, futures.Future) + self.assertEqual( + (f, False, sock), self.event_loop._sock_accept.call_args[0]) + + def test__sock_accept(self): + f = futures.Future() + + conn = unittest.mock.Mock() + + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.return_value = conn, ('127.0.0.1', 1000) + + self.event_loop._sock_accept(f, False, sock) + self.assertTrue(f.done()) + self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) + self.assertEqual((False,), conn.setblocking.call_args[0]) + + def test__sock_accept_canceled_fut(self): + sock = unittest.mock.Mock() + + f = futures.Future() + f.cancel() + + self.event_loop._sock_accept(f, False, sock) + self.assertFalse(sock.accept.called) + + def test__sock_accept_unregister(self): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + + f = futures.Future() + f.cancel() + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + + def test__sock_accept_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop.add_reader = unittest.mock.Mock() + self.event_loop._sock_accept(f, False, sock) + self.assertEqual( + (10, self.event_loop._sock_accept, f, True, sock), + self.event_loop.add_reader.call_args[0]) + + def test__sock_accept_exception(self): + class Err(socket.error): + pass + + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = Err + + self.event_loop._sock_accept(f, False, sock) + self.assertIsInstance(f.exception(), Err) + + def test_add_reader(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_READ, (h, None)), + self.event_loop._selector.register.call_args[0]) + + def test_add_reader_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (reader, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertTrue(reader.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_add_reader_existing_writer(self): + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, writer)) + h = self.event_loop.add_reader(1, lambda: True) + + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (None, None)) + self.assertFalse(self.event_loop.remove_reader(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_reader_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_reader(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, writer)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_reader_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_reader(1)) + + def test_add_writer(self): + self.event_loop._selector.get_info.side_effect = KeyError + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(self.event_loop._selector.register.called) + self.assertEqual( + (1, selectors.EVENT_WRITE, (None, h)), + self.event_loop._selector.register.call_args[0]) + + def test_add_writer_existing(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ, (reader, writer)) + h = self.event_loop.add_writer(1, lambda: True) + + self.assertTrue(writer.cancel.called) + self.assertFalse(self.event_loop._selector.register.called) + self.assertTrue(self.event_loop._selector.modify.called) + self.assertEqual( + (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer(self): + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_WRITE, (None, None)) + self.assertFalse(self.event_loop.remove_writer(1)) + + self.assertTrue(self.event_loop._selector.unregister.called) + + def test_remove_writer_read_write(self): + reader = unittest.mock.Mock() + writer = unittest.mock.Mock() + self.event_loop._selector.get_info.return_value = ( + selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.assertTrue( + self.event_loop.remove_writer(1)) + + self.assertFalse(self.event_loop._selector.unregister.called) + self.assertEqual( + (1, selectors.EVENT_READ, (reader, None)), + self.event_loop._selector.modify.call_args[0]) + + def test_remove_writer_unknown(self): + self.event_loop._selector.get_info.side_effect = KeyError + self.assertFalse( + self.event_loop.remove_writer(1)) + + def test_process_events_read(self): + reader = unittest.mock.Mock() + reader.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((reader,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_read_cancelled(self): + reader = unittest.mock.Mock() + reader.cancelled = True + + self.event_loop.remove_reader = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_READ, (reader, None)),)) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual((1,), self.event_loop.remove_reader.call_args[0]) + + def test_process_events_write(self): + writer = unittest.mock.Mock() + writer.cancelled = False + + self.event_loop._add_callback = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop._add_callback.called) + self.assertEqual((writer,), self.event_loop._add_callback.call_args[0]) + + def test_process_events_write_cancelled(self): + writer = unittest.mock.Mock() + writer.cancelled = True + + self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._process_events( + ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual((1,), self.event_loop.remove_writer.call_args[0]) + + +class SelectorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_ctor(self): + _SelectorSocketTransport(self.event_loop, self.sock, self.protocol) + self.assertTrue(self.event_loop.add_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + + def test_ctor_with_waiter(self): + fut = futures.Future() + + _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.assertEqual(2, self.event_loop.call_soon.call_count) + self.assertEqual(fut.set_result, + self.event_loop.call_soon.call_args[0][0]) + + def test_read_ready(self): + data_received = unittest.mock.Mock() + self.protocol.data_received = data_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (data_received, b'data'), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_eof(self): + eof_received = unittest.mock.Mock() + self.protocol.eof_received = eof_received + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recv.return_value = b'' + transport._read_ready() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertEqual( + (eof_received,), self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recv.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_write(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + + def test_write_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.write(b'') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(b'data1') + transport.write(b'data2') + self.assertFalse(self.sock.send.called) + self.assertEqual([b'data1', b'data2'], transport._buffer) + + def test_write_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'ta'], transport._buffer) + + def test_write_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.write(data) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual([b'data'], transport._buffer) + + def test_write_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.write(data) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_write_str(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.write, 'str') + + def test_write_closing(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.write, b'data') + + def test_write_ready(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append(data) + transport._write_ready() + self.assertTrue(self.sock.send.called) + self.assertEqual(self.sock.send.call_args[0], (data,)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_write_ready_no_data(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._write_ready() + self.assertFalse(self.sock.send.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_write_ready_partial(self): + data = b'data' + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'ta'], transport._buffer) + + def test_write_ready_partial_none(self): + data = b'data' + self.sock.send.return_value = 0 + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append(data) + transport._write_ready() + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], transport._buffer) + + def test_write_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer = [b'data1', b'data2'] + transport._write_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data1data2'], transport._buffer) + + def test_write_ready_exception(self): + class Err(socket.error): + pass + + self.sock.send.side_effect = Err + + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._write_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_close(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append(b'data') + transport._fatal_error(exc) + + self.assertEqual([], transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_connection_lost(self): + exc = object() + transport = _SelectorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) + + +@unittest.skipIf(ssl is None, 'No ssl module') +class SelectorSslTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + self.sslsock = unittest.mock.Mock() + self.sslsock.fileno.return_value = 1 + self.sslcontext = unittest.mock.Mock() + self.sslcontext.wrap_socket.return_value = self.sslsock + self.waiter = futures.Future() + + self.transport = _SelectorSslTransport( + self.event_loop, self.sock, + self.protocol, self.sslcontext, self.waiter) + self.event_loop.reset_mock() + self.sock.reset_mock() + self.protocol.reset_mock() + self.sslcontext.reset_mock() + + def test_on_handshake(self): + self.transport._on_handshake() + self.assertTrue(self.sslsock.do_handshake.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_reader.call_args[0]) + self.assertEqual( + (1, self.transport._on_ready,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_reader_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_reader.call_args[0]) + + def test_on_handshake_writer_retry(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + self.transport._on_handshake() + self.assertEqual( + (1, self.transport._on_handshake,), + self.event_loop.add_writer.call_args[0]) + + def test_on_handshake_exc(self): + self.sslsock.do_handshake.side_effect = ValueError + self.transport._on_handshake() + self.assertTrue(self.sslsock.close.called) + + def test_on_handshake_base_exc(self): + self.sslsock.do_handshake.side_effect = BaseException + self.assertRaises(BaseException, self.transport._on_handshake) + self.assertTrue(self.sslsock.close.called) + + def test_write_no_data(self): + self.transport._buffer.append(b'data') + self.transport.write(b'') + self.assertEqual([b'data'], self.transport._buffer) + + def test_write_str(self): + self.assertRaises(AssertionError, self.transport.write, 'str') + + def test_write_closing(self): + self.transport.close() + self.assertRaises(AssertionError, self.transport.write, b'data') + + def test_abort(self): + self.transport._fatal_error = unittest.mock.Mock() + + self.transport.abort() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual((None,), self.transport._fatal_error.call_args[0]) + + def test_fatal_error(self): + exc = object() + self.transport._buffer.append(b'data') + self.transport._fatal_error(exc) + + self.assertEqual([], self.transport._buffer) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_close(self): + self.transport.close() + + self.assertTrue(self.transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (self.protocol.connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + self.transport._buffer.append(b'data') + self.transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_on_ready_closed(self): + self.sslsock.fileno.return_value = -1 + self.transport._on_ready() + self.assertFalse(self.sslsock.recv.called) + + def test_on_ready_recv(self): + self.sslsock.recv.return_value = b'data' + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) + + def test_on_ready_recv_eof(self): + self.sslsock.recv.return_value = b'' + self.transport._on_ready() + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_recv_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.recv.called) + self.assertFalse(self.protocol.data_received.called) + + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.recv.side_effect = Err + self.transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_on_ready_recv_exc(self): + class Err(socket.error): + pass + + self.sslsock.recv.side_effect = Err + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + + def test_on_ready_send(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([], self.transport._buffer) + + def test_on_ready_send_none(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 0 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data1data2'], self.transport._buffer) + + def test_on_ready_send_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'ta1data2'], self.transport._buffer) + + def test_on_ready_send_closing_partial(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 2 + self.transport._buffer = [b'data1', b'data2'] + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertFalse(self.sslsock.close.called) + + def test_on_ready_send_closing(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + self.transport.close() + self.transport._buffer = [b'data'] + self.transport._on_ready() + self.assertTrue(self.sslsock.close.called) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_on_ready_send_retry(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + self.transport._buffer = [b'data'] + + self.sslsock.send.side_effect = ssl.SSLWantReadError + self.transport._on_ready() + self.assertTrue(self.sslsock.send.called) + self.assertEqual([b'data'], self.transport._buffer) + + self.sslsock.send.side_effect = ssl.SSLWantWriteError + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sslsock.send.side_effect = Err + self.transport._on_ready() + self.assertEqual([b'data'], self.transport._buffer) + + def test_on_ready_send_exc(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + + class Err(socket.error): + pass + + self.sslsock.send.side_effect = Err + self.transport._buffer = [b'data'] + self.transport._fatal_error = unittest.mock.Mock() + self.transport._on_ready() + self.assertTrue(self.transport._fatal_error.called) + self.assertEqual([], self.transport._buffer) + + +class SelectorDatagramTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock() + self.protocol = unittest.mock.Mock() + + def test_read_ready(self): + datagram_received = unittest.mock.Mock() + self.protocol.datagram_received = datagram_received + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) + transport._read_ready() + + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (datagram_received, b'data', ('0.0.0.0', 1234)), + self.event_loop.call_soon.call_args[0]) + + def test_read_ready_tryagain(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_read_ready_err(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + + class Err(socket.error): + pass + + self.sock.recvfrom.side_effect = Err + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_abort(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + + transport.abort() + self.assertTrue(transport._fatal_error.called) + self.assertIsNone(transport._fatal_error.call_args[0][0]) + + def test_sendto(self): + data = b'data' + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data', ('0.0.0.0', 12345))) + transport.sendto(b'', ()) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(b'data2', ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + + def test_sendto_tryagain(self): + data = b'data' + + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 12345)) + + self.assertTrue(self.event_loop.add_writer.called) + self.assertEqual( + transport._sendto_ready, + self.event_loop.add_writer.call_args[0][1]) + + self.assertEqual( + [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) + + def test_sendto_exception(self): + data = b'data' + + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_connection_refused(self): + data = b'data' + + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data, ()) + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_connection_refused_connected(self): + data = b'data' + + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport.sendto(data) + + self.assertTrue(transport._fatal_error.called) + + def test_sendto_str(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.assertRaises(AssertionError, transport.sendto, 'str', ()) + + def test_sendto_connected_addr(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.assertRaises( + AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + + def test_sendto_closing(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + self.assertRaises(AssertionError, transport.sendto, b'data', ()) + + def test_sendto_ready(self): + data = b'data' + self.sock.sendto.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.append((data, ('0.0.0.0', 12345))) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_closing(self): + data = b'data' + self.sock.send.return_value = len(data) + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._closing = True + transport._buffer.append((data, ())) + transport._sendto_ready() + self.assertTrue(self.sock.sendto.called) + self.assertEqual(self.sock.sendto.call_args[0], (data, ())) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_sendto_ready_no_data(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._sendto_ready() + self.assertFalse(self.sock.sendto.called) + self.assertTrue(self.event_loop.remove_writer.called) + + def test_sendto_ready_tryagain(self): + class Err(socket.error): + errno = errno.EAGAIN + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + transport._sendto_ready() + + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual( + [(b'data1', ()), (b'data2', ())], + list(transport._buffer)) + + def test_sendto_ready_exception(self): + class Err(socket.error): + pass + + self.sock.sendto.side_effect = Err + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + + def test_sendto_ready_connection_refused(self): + self.sock.sendto.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertFalse(transport._fatal_error.called) + + def test_sendto_ready_connection_refused_connection(self): + self.sock.send.side_effect = ConnectionRefusedError + + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport._fatal_error = unittest.mock.Mock() + transport._buffer.append((b'data', ())) + transport._sendto_ready() + + self.assertTrue(transport._fatal_error.called) + + def test_close(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + transport.close() + + self.assertTrue(transport._closing) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, None), + self.event_loop.call_soon.call_args[0]) + + def test_close_write_buffer(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport.close() + + self.assertTrue(self.event_loop.remove_reader.called) + self.assertFalse(self.event_loop.call_soon.called) + + def test_fatal_error(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + transport._buffer.append((b'data', ())) + transport._fatal_error(exc) + + self.assertEqual([], list(transport._buffer)) + self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.event_loop.call_soon.called) + self.assertEqual( + (transport._call_connection_lost, exc), + self.event_loop.call_soon.call_args[0]) + + def test_fatal_error_connected(self): + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.event_loop.reset_mock() + transport._fatal_error(ConnectionRefusedError()) + + self.assertEqual( + 2, self.event_loop.call_soon.call_count) + + def test_transport_closing(self): + exc = object() + transport = _SelectorDatagramTransport( + self.event_loop, self.sock, self.protocol) + self.sock.reset_mock() + self.protocol.reset_mock() + transport._call_connection_lost(exc) + + self.assertTrue(self.protocol.connection_lost.called) + self.assertEqual( + (exc,), self.protocol.connection_lost.call_args[0]) + self.assertTrue(self.sock.close.called) diff --git a/tests/selectors_test.py b/tests/selectors_test.py new file mode 100644 index 00000000..996c0130 --- /dev/null +++ b/tests/selectors_test.py @@ -0,0 +1,137 @@ +"""Tests for selectors.py.""" + +import unittest +import unittest.mock + +from tulip import selectors + + +class BaseSelectorTests(unittest.TestCase): + + def test_fileobj_to_fd(self): + self.assertEqual(10, selectors._fileobj_to_fd(10)) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + self.assertEqual(10, selectors._fileobj_to_fd(f)) + + f.fileno.side_effect = TypeError + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + + def test_selector_key_repr(self): + key = selectors.SelectorKey(10, selectors.EVENT_READ) + self.assertEqual( + "SelectorKey", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..832ce371 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,315 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + """Read zero bytes.""" + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.read(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + """Read bytes.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + """Read bytes without line breaks.""" + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + read_task = tasks.Task(stream.read(5)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + """Read bytes, stop at eof.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + """Read all bytes until eof.""" + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(tasks.Task(stream.read(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.read(2))) + + def test_readline(self): + """Read one line.""" + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + read_task = tasks.Task(stream.readline()) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, read_task) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + read_task = tasks.Task(stream.readline()) + line = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readline()) + self.event_loop.run_until_complete(read_task) + + read_task = tasks.Task(stream.read(7)) + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readline())) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, tasks.Task(stream.readline())) + + def test_readexactly_zero_or_less(self): + """Read exact number of bytes (zero or less).""" + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + read_task = tasks.Task(stream.readexactly(0)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + read_task = tasks.Task(stream.readexactly(-1)) + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + """Read exact number of bytes.""" + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + """Read exact number of bytes (eof).""" + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete( + tasks.Task(stream.readexactly(2))) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + tasks.Task(stream.readexactly(2))) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.Task(tasks.wait([t1, t2]))) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..09aaed52 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..2a11c202 --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,584 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + yield from [] + return 42 + + @tasks.task + def inner2(): + yield from [] + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + """If coroutine returns future, task waits on this future.""" + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..eb61d914 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,39 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..24ea4945 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,168 @@ +"""Tests for unix_events.py.""" + +import errno +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..b93bdaf1 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,33 @@ +"""Tests for winsocketpair.py""" + +import errno +import socket +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.error = socket.error + + class Err(socket.error): + errno = errno.WSAEWOULDBLOCK + 1 + + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = Err + + self.assertRaises(Err, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..acec5c24 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..c2cf18a5 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,501 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=None): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + try: + self._run_once(timeout) + except _StopError: + pass + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + + # TODO: Should delay is None be interpreted as Infinity? + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.task + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.task + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=socket.AF_INET, proto=0, flags=0): + """Create datagram connection.""" + + addr = None + if host is not None or port is not None: + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + + try: + yield from self.sock_connect(sock, address) + addr = address + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + if exceptions: + raise exceptions[0] + else: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + except socket.error: + sock.close() + raise + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, addr) + + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.task + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + """XXX""" + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_DGRAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + sock.setblocking(False) + break + else: + raise exceptions[0] + + self._make_datagram_transport( + sock, protocol_factory(), extra={'addr': sock.getsockname()}) + + return sock + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + # Inspect the poll queue. If there's exactly one selectable + # file descriptor, it's the self-pipe, and if there's nothing + # scheduled, we should ignore it. + if (self._scheduled or + self._selector.registered_count() > self._internal_fds): + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..9bad35fb --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,330 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import logging +import sys +import threading + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + logging.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_connection(self, protocol_factory, + host=None, port=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def start_serving_datagram(self, protocol_factory, host, port, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..68735f35 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,244 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..d436383f --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,8 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .protocol import * + + +__all__ = (client.__all__ + + protocol.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..b4db5ccb --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,154 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +However you can also open a stream: + + f, wstream = http_client.open_stream(url, method, headers) + wstream.write(b'abc') + wstream.writelines([b'def', b'ghi']) + wstream.write_eof() + sts, headers, rstream = yield from f + response = yield from rstream.read() + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version='1.1', + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = self.validate(version, 'version') + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + self.wstream = protocol.HttpStreamWriter(transport) + + line = '{} {} HTTP/{}\r\n'.format(self.method, + self.path, + self.version) + self.wstream.write_str(line) + for key, value in self.headers.items(): + self.wstream.write_str('{}: {}\r\n'.format(key, value)) + self.wstream.write(b'\r\n') + if self.make_body is not None: + if self.chunked: + self.make_body( + self.wstream.write_chunked, self.wstream.write_chunked_eof) + else: + self.make_body( + self.wstream.write_str, self.wstream.write_eof) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..2ff7876f --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,508 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', 'HttpStreamWriter', + 'HttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import functools +import http.client +import re +import zlib + +import tulip + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') + + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +HttpMessage = collections.namedtuple( + 'HttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class StreamEofException(http.client.HTTPException): + """eof received""" + + +def wrap_payload_reader(f): + """wrap_payload_reader wraps payload readers and redirect stream. + payload readers are generator functions, read_chunked_payload, + read_length_payload, read_eof_payload. + payload reader allows to modify data stream and feed data into stream. + + StreamReader instance should be send to generator as first parameter. + This steam is used as destination stream for processed data. + To send data to reader use generator's send() method. + + To indicate eof stream, throw StreamEofException exception into the reader. + In case of errors in incoming stream reader sets exception to + destination stream with StreamReader.set_exception() method. + + Before exit, reader generator returns unprocessed data. + """ + + @functools.wraps(f) + def wrapper(self, *args, **kw): + assert self._reader is None + + rstream = stream = tulip.StreamReader() + + encoding = kw.pop('encoding', None) + if encoding is not None: + if encoding not in ('gzip', 'deflate'): + raise ValueError( + 'Content-Encoding %r is not supported' % encoding) + + stream = DeflateStream(stream, encoding) + + reader = f(self, *args, **kw) + next(reader) + try: + reader.send(stream) + except StopIteration: + pass + else: + # feed buffer + self.line_count = 0 + self.byte_count = 0 + while self.buffer: + try: + reader.send(self.buffer.popleft()) + except StopIteration as exc: + buf = b''.join(self.buffer) + self.buffer.clear() + reader = None + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + if reader is not None: + if self.eof: + try: + reader.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._reader = reader + + return rstream + + return wrapper + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified steram.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + self.stream.feed_eof() + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _reader is set, feed_data and feed_eof sends data into + # _reader instead of self. is it being used as stream redirection for + # read_chunked_payload, read_length_payload and read_eof_payload + _reader = None + + def feed_data(self, data): + """_reader is a generator, if _reader is set, feed_data sends + incoming data into this generator untile generates stops.""" + if self._reader: + try: + self._reader.send(data) + except StopIteration as exc: + self._reader = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_reader is a generator, if _reader is set feed_eof throws + StreamEofException into this generator.""" + if self._reader: + try: + self._reader.throw(StreamEofException()) + except StopIteration: + self._reader = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise http.client.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise http.client.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception http.client.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise http.client.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise http.client.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise http.client.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise http.client.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise http.client.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise http.client.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise http.client.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise http.client.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + @wrap_payload_reader + def read_chunked_payload(self): + """Read chunked stream.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b";") + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise http.client.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + except http.client.IncompleteRead as exc: + stream.set_exception(exc) + + @wrap_payload_reader + def read_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(http.client.IncompleteRead(b'')) + + @wrap_payload_reader + def read_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == "CONNECTION": + v = value.lower() + if v == "close": + close_conn = True + elif v == "keep-alive": + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload stream + if chunked: + payload = self.read_chunked_payload(encoding=encoding) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise http.client.HTTPException('CONTENT-LENGTH') from None + + if length < 0: + raise http.client.HTTPException('CONTENT-LENGTH') + + payload = self.read_length_payload(length, encoding=encoding) + else: + if readall: + payload = self.read_eof_payload(encoding=encoding) + else: + payload = self.read_length_payload(0, encoding=encoding) + + return HttpMessage(headers, payload, close_conn, encoding) + + +class HttpStreamWriter: + + def __init__(self, transport, encoding='utf-8'): + self.transport = transport + self.encoding = encoding + + def encode(self, s): + if isinstance(s, bytes): + return s + return s.encode(self.encoding) + + def decode(self, s): + if isinstance(s, str): + return s + return s.decode(self.encoding) + + def write(self, b): + self.transport.write(b) + + def write_str(self, s): + self.transport.write(self.encode(s)) + + def write_chunked(self, chunk): + if not chunk: + return + data = self.encode(chunk) + self.write_str('{:x}\r\n'.format(len(data))) + self.transport.write(data) + self.transport.write(b'\r\n') + + def write_chunked_eof(self): + self.transport.write(b'0\r\n\r\n') diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..c86048f4 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,461 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handle = self._event_loop.call_later(timeout, fut.cancel) + else: + handle = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handle is not None: + handle.cancel() + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handle = self._event_loop.call_later(timeout, fut.cancel) + else: + handle = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handle is not None: + handle.cancel() + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handle = self._event_loop.call_later(timeout, fut.cancel) + else: + handle = None + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + if handle is not None: + handle.cancel() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s' % self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop) + if timeout is not None: + handle = self._event_loop.call_later(timeout, fut.cancel) + else: + handle = None + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + if handle is not None: + handle.cancel() + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..45c075e3 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,190 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import logging + +from . import base_events +from . import transports + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logging.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + logging.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..f01e2fd2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,74 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class DatagramProtocol: + """ABC representing a datagram protocol.""" + + def connection_made(self, transport): + """Called when a datagram transport is ready.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + def connection_lost(self, exc): + """Called when the connection is lost or closed.""" diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..cc7fe33c --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,661 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import logging +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None +import sys + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) +if sys.platform == 'win32': # pragma: no cover + _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _write_to_self(self): + try: + self._csock.send(b'x') + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return + raise # Halp! + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except socket.error as exc: + if exc.errno in _TRYAGAIN: + return # False alarm. + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + return + + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + return + n = 0 + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_writer(fd, self._sock_connect, + fut, True, sock, address) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + fut.set_exception(exc) + else: + self.add_reader(fd, self._sock_accept, fut, True, sock) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._event_loop.call_soon(self._protocol.eof_received) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + return + if n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = self._sock.send(data) + else: + n = 0 + except socket.error as exc: + if exc.errno in _TRYAGAIN: + n = 0 + else: + self._fatal_error(exc) + return + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + return + if n: + data = data[n:] + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + else: + n = 0 + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._event_loop.call_soon(self._protocol.connection_lost, exc) + + +class _SelectorDatagramTransport(transports.Transport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + else: + self._event_loop.call_soon( + self._protocol.datagram_received, data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + + self._event_loop.add_writer(self._fileno, self._sendto_ready) + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._fatal_error(exc) + return + + # Try again later. + self._buffer.appendleft((data, addr)) + break + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._event_loop.call_soon(self._call_connection_lost, None) + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._event_loop.call_soon(self._protocol.connection_refused, exc) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..e8fd12e9 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..8d7f6236 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..45e6f6fe --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,134 @@ +import fcntl +import logging +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + if not data: + return + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + return + if n > 0: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + self._buffer = [] + try: + if data: + n = os.write(self._wstdin, data) + else: + n = 0 + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + if n > 0: + data = data[n:] + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + return + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..08bbb31c --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,308 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import inspect +import logging +import time + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines.""" + # TODO: This is a feel-good API only. It is not enforced. + assert inspect.isgeneratorfunction(func) + func._is_coroutine = True # Not sure who can use this. + return func + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop) # Sets self._event_loop. + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + bail = futures.Future() # Will always be cancelled eventually. + timeout_handle = None + debugstuff = locals() + + if timeout is not None: + loop = events.get_event_loop() + timeout_handle = loop.call_later(timeout, bail.cancel) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + if timeout_handle is not None: + timeout_handle.cancel() + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..9b87db2f --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..6eb1c554 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,99 @@ +"""Abstract Transport class.""" + +__all__ = ['Transport'] + + +class Transport: + """ABC representing a transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..41f8e0b4 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,128 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import logging +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..4297f804 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import logging +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + logging.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..374616f6 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import errno +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From 0ba989c59fcabb110c60b81133fe54980d045b7c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Mar 2013 08:07:05 -0700 Subject: [PATCH 0359/1502] Fix curl.py for changes in http interface. --- curl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/curl.py b/curl.py index 0624df86..6986fc20 100755 --- a/curl.py +++ b/curl.py @@ -22,7 +22,7 @@ def main(): for k, v in headers.items(): print('{}: {}'.format(k, v)) print() - data = p.event_loop.run_until_complete(tulip.Task(stream)) + data = p.event_loop.run_until_complete(tulip.Task(stream.read())) print(data.decode('utf-8', 'replace')) From 8a1378d1da92fcfce02b0127af8ac224aa10abea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Mar 2013 08:37:03 -0700 Subject: [PATCH 0360/1502] Update note about scheduling. --- NOTES | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/NOTES b/NOTES index c8f274de..3b94ba96 100644 --- a/NOTES +++ b/NOTES @@ -8,7 +8,9 @@ Notes from PyCon 2013 sprints - Adam Sah suggests that there might be a need for scheduling (especially when multiple frameworks share an event loop). He points to lottery scheduling but also mentions that's just one of - the options. + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. - Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't think connected UDP is worth supporting, it doesn't do anything From b780325d12b2b0faca58d87792aa26f1067963cc Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Mar 2013 10:11:28 -0700 Subject: [PATCH 0361/1502] http Response and Request helpers --- srv.py | 65 ++- tests/http_protocol_test.py | 814 ++++++++++++++++++++++++++++-------- tulip/http/client.py | 35 +- tulip/http/protocol.py | 701 +++++++++++++++++++++++-------- tulip/selector_events.py | 2 +- 5 files changed, 1231 insertions(+), 386 deletions(-) diff --git a/srv.py b/srv.py index 540e63b9..710e0e7e 100644 --- a/srv.py +++ b/srv.py @@ -51,22 +51,36 @@ def handle_request(self): print(hdr, val) headers.add_header(hdr, val) - write = self.transport.write if isdir and not path.endswith('/'): - bpath = path.encode('ascii') - write(b'HTTP/1.0 302 Redirected\r\n' - b'URI: ' + bpath + b'/\r\n' - b'Location: ' + bpath + b'/\r\n' - b'\r\n') + path = path + '/' + response = tulip.http.Response(self.transport, 302) + response.add_headers( + ('URI', path), + ('Location', path)) + response.send_headers() + response.write_eof() + self.transport.close() return - write(b'HTTP/1.0 200 Ok\r\n') - if isdir: - write(b'Content-type: text/html\r\n') - else: - write(b'Content-type: text/plain\r\n') - write(b'\r\n') + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + if isdir: - write(b'
    \r\n') + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
      \r\n') for name in sorted(os.listdir(path)): if name.isprintable() and not name.startswith('.'): try: @@ -75,18 +89,27 @@ def handle_request(self): pass else: if os.path.isdir(os.path.join(path, name)): - write(b'
    • ' + bname + b'/
    • \r\n') + response.write(b'
    • ' + bname + b'/
    • \r\n') else: - write(b'
    • ' + bname + b'
    • \r\n') - write(b'
    ') + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + try: - with open(path, 'rb') as f: - write(f.read()) + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) except OSError: - write(b'Cannot open\r\n') + response.write(b'Cannot open') + + response.write_eof() self.transport.close() def connection_made(self, transport): diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index c0952287..408bfc7d 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -203,151 +203,6 @@ def test_read_headers_continuation_headers_size(self): self.assertIn("limit request headers fields size", str(cm.exception)) - def test_read_payload_unknown_encoding(self): - self.assertRaises( - ValueError, self.stream.read_length_payload, encoding='unknown') - - def test_read_payload(self): - self.stream.feed_data(b'da') - self.stream.feed_data(b't') - self.stream.feed_data(b'ali') - self.stream.feed_data(b'ne') - - stream = self.stream.read_length_payload(4) - self.assertIsInstance(stream, tulip.StreamReader) - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - self.assertEqual(b'line', b''.join(self.stream.buffer)) - - def test_read_payload_eof(self): - self.stream.feed_data(b'da') - self.stream.feed_eof() - stream = self.stream.read_length_payload(4) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_payload_eof_exc(self): - self.stream.feed_data(b'da') - stream = self.stream.read_length_payload(4) - - def eof(): - yield from [] - self.stream.feed_eof() - - t1 = tulip.Task(stream.read()) - t2 = tulip.Task(eof()) - - self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) - self.assertRaises(http.client.IncompleteRead, t1.result) - self.assertIsNone(self.stream._reader) - - def test_read_payload_deflate(self): - comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - - data = b''.join([comp.compress(b'data'), comp.flush()]) - stream = self.stream.read_length_payload(len(data), encoding='deflate') - - self.stream.feed_data(data) - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - - def _test_read_payload_compress_error(self): - data = b'123123123datadatadata' - reader = protocol.length_reader(4) - self.stream.feed_data(data) - stream = self.stream.read_payload(reader, 'deflate') - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_chunked_payload(self): - stream = self.stream.read_chunked_payload() - self.stream.feed_data(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_chunks(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'4\r\ndata\r') - self.stream.feed_data(b'\n4') - self.stream.feed_data(b'\r') - self.stream.feed_data(b'\n') - self.stream.feed_data(b'line\r\n0\r\n') - self.stream.feed_data(b'test\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_incomplete(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'4\r\ndata\r\n') - self.stream.feed_eof() - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_chunked_payload_extension(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data( - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'dataline', data) - - def test_read_chunked_payload_size_error(self): - stream = self.stream.read_chunked_payload() - - self.stream.feed_data(b'blah\r\n') - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_length_payload(self): - stream = self.stream.read_length_payload(8) - - self.stream.feed_data(b'data') - self.stream.feed_data(b'data') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'datadata', data) - - def test_read_length_payload_zero(self): - stream = self.stream.read_length_payload(0) - - self.stream.feed_data(b'data') - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'', data) - - def test_read_length_payload_incomplete(self): - stream = self.stream.read_length_payload(8) - - self.stream.feed_data(b'data') - self.stream.feed_eof() - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(stream.read())) - - def test_read_eof_payload(self): - stream = self.stream.read_eof_payload() - - self.stream.feed_data(b'data') - self.stream.feed_eof() - - data = self.loop.run_until_complete(tulip.Task(stream.read())) - self.assertEqual(b'data', data) - def test_read_message_should_close(self): self.stream.feed_data( b'Host: example.com\r\nConnection: close\r\n\r\n') @@ -477,7 +332,7 @@ def test_read_message_chunked(self): payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) self.assertEqual(b'dataline', payload) - def test_read_message_readall(self): + def test_read_message_readall_eof(self): self.stream.feed_data( b'Host: example.com\r\n\r\n') self.stream.feed_data(b'data') @@ -490,46 +345,653 @@ def test_read_message_readall(self): payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) self.assertEqual(b'dataline', payload) + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(msg.payload.read())) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) -class HttpStreamWriterTests(unittest.TestCase): + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message(readall=True))) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(tulip.Task(coro())) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete( + tulip.Task(self.stream.read_message())) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, tulip.Task(coro())) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): def setUp(self): self.transport = unittest.mock.Mock() - self.writer = protocol.HttpStreamWriter(self.transport) - def test_ctor(self): - transport = unittest.mock.Mock() - writer = protocol.HttpStreamWriter(transport, 'latin-1') - self.assertIs(writer.transport, transport) - self.assertEqual(writer.encoding, 'latin-1') + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) - def test_encode(self): - self.assertEqual(b'test', self.writer.encode('test')) - self.assertEqual(b'test', self.writer.encode(b'test')) + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) - def test_decode(self): - self.assertEqual('test', self.writer.decode('test')) - self.assertEqual('test', self.writer.decode(b'test')) + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) - def test_write(self): - self.writer.write(b'test') - self.assertTrue(self.transport.write.called) - self.assertEqual((b'test',), self.transport.write.call_args[0]) + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) - def test_write_str(self): - self.writer.write_str('test') - self.assertTrue(self.transport.write.called) - self.assertEqual((b'test',), self.transport.write.call_args[0]) + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) - def test_write_cunked(self): - self.writer.write_chunked('') - self.assertFalse(self.transport.write.called) + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) - self.writer.write_chunked('data') + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) self.assertEqual( - [(b'4\r\n',), (b'data',), (b'\r\n',)], - [c[0] for c in self.transport.write.call_args_list]) + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() - def test_write_eof(self): - self.writer.write_chunked_eof() - self.assertEqual((b'0\r\n\r\n',), self.transport.write.call_args[0]) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tulip/http/client.py b/tulip/http/client.py index b4db5ccb..b65b90a8 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -10,15 +10,6 @@ headers['status'] == '200 Ok' # or some such assert isinstance(response, bytes) -However you can also open a stream: - - f, wstream = http_client.open_stream(url, method, headers) - wstream.write(b'abc') - wstream.writelines([b'def', b'ghi']) - wstream.write_eof() - sts, headers, rstream = yield from f - response = yield from rstream.read() - TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). TODO: How do we do connection keep alive? Pooling? """ @@ -45,7 +36,7 @@ class HttpClientProtocol: def __init__(self, host, port=None, *, path='/', method='GET', headers=None, ssl=None, - make_body=None, encoding='utf-8', version='1.1', + make_body=None, encoding='utf-8', version=(1, 1), chunked=False): host = self.validate(host, 'host') if ':' in host: @@ -70,7 +61,7 @@ def __init__(self, host, port=None, *, self.validate(value, 'header value', True) self.headers[key] = value self.encoding = self.validate(encoding, 'encoding') - self.version = self.validate(version, 'version') + self.version = version self.make_body = make_body self.chunked = chunked self.ssl = ssl @@ -127,22 +118,22 @@ def connect(self): def connection_made(self, transport): self.transport = transport self.stream = protocol.HttpStreamReader() - self.wstream = protocol.HttpStreamWriter(transport) - - line = '{} {} HTTP/{}\r\n'.format(self.method, - self.path, - self.version) - self.wstream.write_str(line) - for key, value in self.headers.items(): - self.wstream.write_str('{}: {}\r\n'.format(key, value)) - self.wstream.write(b'\r\n') + + self.request = protocol.Request( + transport, self.method, self.path, self.version) + + self.request.add_headers(*self.headers.items()) + self.request.send_headers() + if self.make_body is not None: if self.chunked: self.make_body( - self.wstream.write_chunked, self.wstream.write_chunked_eof) + self.request.write, self.request.eof) else: self.make_body( - self.wstream.write_str, self.wstream.write_eof) + self.request.write, self.request.eof) + else: + self.request.write_eof() def data_received(self, data): self.stream.feed_data(data) diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 2ff7876f..1773cabd 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -1,12 +1,17 @@ """Http related helper utils.""" -__all__ = ['HttpStreamReader', 'HttpStreamWriter', - 'HttpMessage', 'RequestLine', 'ResponseStatus'] +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] import collections +import email.utils import functools import http.client +import http.server +import itertools import re +import sys import zlib import tulip @@ -15,7 +20,7 @@ VERSRE = re.compile('HTTP/(\d+).(\d+)') HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") CONTINUATION = (b' ', b'\t') - +RESPONSES = http.server.BaseHTTPRequestHandler.responses RequestLine = collections.namedtuple( 'RequestLine', ['method', 'uri', 'version']) @@ -25,109 +30,8 @@ 'ResponseStatus', ['version', 'code', 'reason']) -HttpMessage = collections.namedtuple( - 'HttpMessage', ['headers', 'payload', 'should_close', 'compression']) - - -class StreamEofException(http.client.HTTPException): - """eof received""" - - -def wrap_payload_reader(f): - """wrap_payload_reader wraps payload readers and redirect stream. - payload readers are generator functions, read_chunked_payload, - read_length_payload, read_eof_payload. - payload reader allows to modify data stream and feed data into stream. - - StreamReader instance should be send to generator as first parameter. - This steam is used as destination stream for processed data. - To send data to reader use generator's send() method. - - To indicate eof stream, throw StreamEofException exception into the reader. - In case of errors in incoming stream reader sets exception to - destination stream with StreamReader.set_exception() method. - - Before exit, reader generator returns unprocessed data. - """ - - @functools.wraps(f) - def wrapper(self, *args, **kw): - assert self._reader is None - - rstream = stream = tulip.StreamReader() - - encoding = kw.pop('encoding', None) - if encoding is not None: - if encoding not in ('gzip', 'deflate'): - raise ValueError( - 'Content-Encoding %r is not supported' % encoding) - - stream = DeflateStream(stream, encoding) - - reader = f(self, *args, **kw) - next(reader) - try: - reader.send(stream) - except StopIteration: - pass - else: - # feed buffer - self.line_count = 0 - self.byte_count = 0 - while self.buffer: - try: - reader.send(self.buffer.popleft()) - except StopIteration as exc: - buf = b''.join(self.buffer) - self.buffer.clear() - reader = None - if exc.value: - self.feed_data(exc.value) - - if buf: - self.feed_data(buf) - - break - - if reader is not None: - if self.eof: - try: - reader.throw(StreamEofException()) - except StopIteration as exc: - pass - else: - self._reader = reader - - return rstream - - return wrapper - - -class DeflateStream: - """DeflateStream decomress stream and feed data into specified steram.""" - - def __init__(self, stream, encoding): - self.stream = stream - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) - - self.zlib = zlib.decompressobj(wbits=zlib_mode) - - def feed_data(self, chunk): - try: - chunk = self.zlib.decompress(chunk) - except: - self.stream.set_exception(http.client.IncompleteRead(b'')) - - if chunk: - self.stream.feed_data(chunk) - - def feed_eof(self): - self.stream.feed_data(self.zlib.flush()) - if not self.zlib.eof: - self.stream.set_exception(http.client.IncompleteRead(b'')) - - self.stream.feed_eof() +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) class HttpStreamReader(tulip.StreamReader): @@ -135,32 +39,32 @@ class HttpStreamReader(tulip.StreamReader): MAX_HEADERS = 32768 MAX_HEADERFIELD_SIZE = 8190 - # if _reader is set, feed_data and feed_eof sends data into - # _reader instead of self. is it being used as stream redirection for - # read_chunked_payload, read_length_payload and read_eof_payload - _reader = None + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None def feed_data(self, data): - """_reader is a generator, if _reader is set, feed_data sends - incoming data into this generator untile generates stops.""" - if self._reader: + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: try: - self._reader.send(data) + self._parser.send(data) except StopIteration as exc: - self._reader = None + self._parser = None if exc.value: self.feed_data(exc.value) else: super().feed_data(data) def feed_eof(self): - """_reader is a generator, if _reader is set feed_eof throws + """_parser is a generator, if _parser is set feed_eof throws StreamEofException into this generator.""" - if self._reader: + if self._parser: try: - self._reader.throw(StreamEofException()) + self._parser.throw(StreamEofException()) except StopIteration: - self._reader = None + self._parser = None super().feed_eof() @@ -314,9 +218,8 @@ def read_headers(self): return headers - @wrap_payload_reader - def read_chunked_payload(self): - """Read chunked stream.""" + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" stream = yield try: @@ -331,7 +234,7 @@ def read_chunked_payload(self): line, data = data.split(b'\n', 1) # Read the next chunk size from the file - i = line.find(b";") + i = line.find(b';') if i >= 0: line = line[:i] # strip chunk-extensions try: @@ -375,8 +278,7 @@ def read_chunked_payload(self): except http.client.IncompleteRead as exc: stream.set_exception(exc) - @wrap_payload_reader - def read_length_payload(self, length): + def _parse_length_payload(self, length): """Read specified amount of bytes.""" stream = yield @@ -400,8 +302,7 @@ def read_length_payload(self, length): except StreamEofException: stream.set_exception(http.client.IncompleteRead(b'')) - @wrap_payload_reader - def read_eof_payload(self): + def _parse_eof_payload(self): """Read all bytes untile eof.""" stream = yield @@ -437,11 +338,11 @@ def read_message(self, version=(1, 1), chunked = value.lower() == 'chunked' elif name == 'SEC-WEBSOCKET-KEY1': length = 8 - elif name == "CONNECTION": + elif name == 'CONNECTION': v = value.lower() - if v == "close": + if v == 'close': close_conn = True - elif v == "keep-alive": + elif v == 'keep-alive': close_conn = False elif compression and name == 'CONTENT-ENCODING': enc = value.lower() @@ -451,9 +352,9 @@ def read_message(self, version=(1, 1), if close_conn is None: close_conn = version <= (1, 0) - # payload stream + # payload parser if chunked: - payload = self.read_chunked_payload(encoding=encoding) + parser = self._parse_chunked_payload() elif length is not None: try: @@ -464,45 +365,513 @@ def read_message(self, version=(1, 1), if length < 0: raise http.client.HTTPException('CONTENT-LENGTH') - payload = self.read_length_payload(length, encoding=encoding) + parser = self._parse_length_payload(length) else: if readall: - payload = self.read_eof_payload(encoding=encoding) + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(http.client.HTTPException): + """eof received""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(http.client.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) else: - payload = self.read_length_payload(0, encoding=encoding) + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: - return HttpMessage(headers, payload, close_conn, encoding) + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. -class HttpStreamWriter: + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) - def __init__(self, transport, encoding='utf-8'): + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): self.transport = transport - self.encoding = encoding - - def encode(self, s): - if isinstance(s, bytes): - return s - return s.encode(self.encoding) - - def decode(self, s): - if isinstance(s, str): - return s - return s.decode(self.encoding) - - def write(self, b): - self.transport.write(b) - - def write_str(self, s): - self.transport.write(self.encode(s)) - - def write_chunked(self, chunk): - if not chunk: - return - data = self.encode(chunk) - self.write_str('{:x}\r\n'.format(len(data))) - self.transport.write(data) - self.transport.write(b'\r\n') - - def write_chunked_eof(self): - self.transport.write(b'0\r\n\r\n') + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '%r is not a string' % name + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('%s\r\n\r\n' % '\r\n'.join( + ('%s: %s' % (k, v) for k, v in + itertools.chain(self._default_headers(), self.headers))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(chunk) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk, buf = buf[:chunk_size], buf[chunk_size:] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', email.utils.formatdate()), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/selector_events.py b/tulip/selector_events.py index cc7fe33c..9bc9c23f 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -352,7 +352,7 @@ def _read_ready(self): self._event_loop.call_soon(self._protocol.eof_received) def write(self, data): - assert isinstance(data, bytes) + assert isinstance(data, (bytes, bytearray)), repr(data) assert not self._closing if not data: return From 35981988fbf103d88d6415be790dfd2aa140260d Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Mar 2013 11:25:20 -0700 Subject: [PATCH 0362/1502] timeout support for Future --- tests/base_events_test.py | 4 +++- tests/events_test.py | 48 +++++++++++++++++++++++---------------- tests/tasks_test.py | 33 +++++++++++++++++++++++++++ tulip/base_events.py | 4 ++-- tulip/futures.py | 13 ++++++++++- tulip/locks.py | 36 ++++------------------------- tulip/tasks.py | 16 ++++--------- 7 files changed, 86 insertions(+), 68 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 03c3296b..1839837b 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -11,6 +11,7 @@ from tulip import events from tulip import futures from tulip import protocols +from tulip import tasks from tulip import test_utils @@ -265,7 +266,8 @@ def _socket(*args, **kw): self.event_loop.getaddrinfo = getaddrinfo - task = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) task._step() exc = task.exception() self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py index 20c398de..1c0883b6 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -25,6 +25,7 @@ from tulip import transports from tulip import protocols from tulip import selector_events +from tulip import tasks from tulip import test_utils @@ -531,7 +532,8 @@ def my_handler(*args): def test_create_connection(self): with self.run_test_server() as httpd: host, port = httpd.socket.getsockname() - f = self.event_loop.create_connection(MyProto, host, port) + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -541,7 +543,8 @@ def test_create_connection(self): def test_create_connection_sock(self): with self.run_test_server() as httpd: host, port = httpd.socket.getsockname() - f = self.event_loop.create_connection(MyProto, host, port) + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -552,25 +555,26 @@ def test_create_connection_sock(self): def test_create_ssl_connection(self): with self.run_test_server(use_ssl=True) as httpsd: host, port = httpsd.socket.getsockname() - f = self.event_loop.create_connection( - MyProto, host, port, ssl=True) + f = tasks.Task(self.event_loop.create_connection( + MyProto, host, port, ssl=True)) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertTrue(hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) self.event_loop.run() self.assertTrue(pr.nbytes > 0) def test_create_connection_host_port_sock(self): self.suppress_log_errors() - fut = self.event_loop.create_connection( - MyProto, 'xkcd.com', 80, sock=object()) + fut = tasks.Task(self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object())) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_create_connection_no_host_port_sock(self): self.suppress_log_errors() - fut = self.event_loop.create_connection(MyProto) + fut = tasks.Task(self.event_loop.create_connection(MyProto)) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_create_connection_no_getaddrinfo(self): @@ -578,7 +582,8 @@ def test_create_connection_no_getaddrinfo(self): getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + fut = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) @@ -587,7 +592,8 @@ def test_create_connection_connect_err(self): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + fut = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) @@ -602,7 +608,8 @@ def getaddrinfo(*args, **kw): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + fut = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) @@ -721,8 +728,8 @@ def datagram_received(self, data, addr): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() - f = self.event_loop.create_datagram_connection( - MyDatagramProto, host, port) + f = tasks.Task(self.event_loop.create_datagram_connection( + MyDatagramProto, host, port)) transport, protocol = self.event_loop.run_until_complete(f) self.assertEqual('INITIALIZED', protocol.state) @@ -767,7 +774,8 @@ def datagram_received(self, data, addr): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() - f = self.event_loop.create_datagram_connection(MyDatagramProto) + f = tasks.Task( + self.event_loop.create_datagram_connection(MyDatagramProto)) transport, protocol = self.event_loop.run_until_complete(f) self.assertEqual('INITIALIZED', protocol.state) @@ -800,8 +808,8 @@ def test_create_datagram_connection_no_getaddrinfo(self): getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - fut = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80) + fut = tasks.Task(self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) @@ -810,8 +818,8 @@ def test_create_datagram_connection_connect_err(self): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80) + fut = tasks.Task(self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) @@ -822,8 +830,8 @@ def test_create_datagram_connection_sockopt_err(self, m_socket): m_socket.error = socket.error m_socket.socket.return_value.setsockopt.side_effect = socket.error - fut = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol) + fut = tasks.Task(self.event_loop.create_datagram_connection( + protocols.DatagramProtocol)) self.assertRaises( socket.error, self.event_loop.run_until_complete, fut) self.assertTrue( diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 2a11c202..03eb5261 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -122,6 +122,39 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + err = None + + @tasks.coroutine + def coro2(): + nonlocal err + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro2())) + self.assertIsInstance(err, futures.CancelledError) + def test_cancel_in_coro(self): @tasks.coroutine def task(): diff --git a/tulip/base_events.py b/tulip/base_events.py index c2cf18a5..b71be09a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -239,7 +239,7 @@ def getaddrinfo(self, host, port, *, def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - @tasks.task + @tasks.coroutine def create_connection(self, protocol_factory, host=None, port=None, *, ssl=False, family=0, proto=0, flags=0, sock=None): """XXX""" @@ -298,7 +298,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, yield from waiter return transport, protocol - @tasks.task + @tasks.coroutine def create_datagram_connection(self, protocol_factory, host=None, port=None, *, family=socket.AF_INET, proto=0, flags=0): diff --git a/tulip/futures.py b/tulip/futures.py index 68735f35..39137aa6 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -51,10 +51,11 @@ class Future: _state = _PENDING _result = None _exception = None + _timeout_handle = None _blocking = False # proper use of future (yield vs yield from) - def __init__(self, *, event_loop=None): + def __init__(self, *, event_loop=None, timeout=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -67,6 +68,10 @@ def __init__(self, *, event_loop=None): self._event_loop = event_loop self._callbacks = [] + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -105,9 +110,15 @@ def _schedule_callbacks(self): The callbacks are scheduled to be called as soon as possible. Also clears the callback list. """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + callbacks = self._callbacks[:] if not callbacks: return + self._callbacks[:] = [] for callback in callbacks: self._event_loop.call_soon(callback, self) diff --git a/tulip/locks.py b/tulip/locks.py index c86048f4..40247962 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -94,11 +94,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._event_loop) - if timeout is not None: - handle = self._event_loop.call_later(timeout, fut.cancel) - else: - handle = None + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) self._waiters.append(fut) try: @@ -110,9 +106,6 @@ def acquire(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handle is not None: - handle.cancel() - self._locked = True return True @@ -209,11 +202,7 @@ def wait(self, timeout=None): if self._value: return True - fut = futures.Future(event_loop=self._event_loop) - if timeout is not None: - handle = self._event_loop.call_later(timeout, fut.cancel) - else: - handle = None + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) self._waiters.append(fut) try: @@ -225,9 +214,6 @@ def wait(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handle is not None: - handle.cancel() - return True @@ -267,11 +253,7 @@ def wait(self, timeout=None): self.release() - fut = futures.Future(event_loop=self._event_loop) - if timeout is not None: - handle = self._event_loop.call_later(timeout, fut.cancel) - else: - handle = None + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) self._condition_waiters.append(fut) try: @@ -285,9 +267,6 @@ def wait(self, timeout=None): finally: yield from self.acquire() - if handle is not None: - handle.cancel() - return True @tasks.coroutine @@ -406,11 +385,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._event_loop) - if timeout is not None: - handle = self._event_loop.call_later(timeout, fut.cancel) - else: - handle = None + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) self._waiters.append(fut) try: @@ -422,9 +397,6 @@ def acquire(self, timeout=None): f = self._waiters.popleft() assert f is fut - if handle is not None: - handle.cancel() - self._value -= 1 if self._value == 0: self._locked = True diff --git a/tulip/tasks.py b/tulip/tasks.py index 08bbb31c..45aba017 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -10,7 +10,6 @@ import logging import time -from . import events from . import futures @@ -46,9 +45,9 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro, event_loop=None): + def __init__(self, coro, event_loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__(event_loop=event_loop) # Sets self._event_loop. + super().__init__(event_loop=event_loop, timeout=timeout) self._coro = coro self._must_cancel = False self._event_loop.call_soon(self._step) @@ -202,13 +201,8 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): return_when == FIRST_EXCEPTION and errors): return done, pending - bail = futures.Future() # Will always be cancelled eventually. - timeout_handle = None - debugstuff = locals() - - if timeout is not None: - loop = events.get_event_loop() - timeout_handle = loop.call_later(timeout, bail.cancel) + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) def _on_completion(f): pending.remove(f) @@ -230,8 +224,6 @@ def _on_completion(f): finally: for f in pending: f.remove_done_callback(_on_completion) - if timeout_handle is not None: - timeout_handle.cancel() really_done = set(f for f in pending if f.done()) if really_done: From 29fe05e8753927d27fdc0e64f6a6f8311f34ac8f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 20 Mar 2013 12:45:05 -0700 Subject: [PATCH 0363/1502] Tests should not have docstrings, they pollute the output. --- runtests.py | 2 +- tests/base_events_test.py | 4 ++-- tests/streams_test.py | 18 +++++++++--------- tests/tasks_test.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/runtests.py b/runtests.py index e4678d70..096a2561 100644 --- a/runtests.py +++ b/runtests.py @@ -89,7 +89,7 @@ def list_dir(prefix, dir): loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) except Exception as err: - print("Skipping '%s': %s" % (modname, err)) + print("Skipping '%s': %s" % (modname, err), file=sys.stderr) return mods diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 1839837b..5a71bc5d 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -186,7 +186,7 @@ def test__run_once_timeout(self): self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) def test__run_once_timeout_with_ready(self): - """If event loop has ready callbacks, select timeout is always 0.""" + # If event loop has ready callbacks, select timeout is always 0. h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) self.event_loop._process_events = unittest.mock.Mock() @@ -199,7 +199,7 @@ def test__run_once_timeout_with_ready(self): @unittest.mock.patch('tulip.base_events.time') @unittest.mock.patch('tulip.base_events.logging') def test__run_once_logging(self, m_logging, m_time): - """Log to INFO level if timeout > 1.0 sec.""" + # Log to INFO level if timeout > 1.0 sec. idx = -1 data = [10.0, 10.0, 12.0, 13.0] diff --git a/tests/streams_test.py b/tests/streams_test.py index 832ce371..15fece12 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -34,7 +34,7 @@ def test_feed_data_byte_count(self): self.assertEqual(len(self.DATA), stream.byte_count) def test_read_zero(self): - """Read zero bytes.""" + # Read zero bytes. stream = streams.StreamReader() stream.feed_data(self.DATA) @@ -44,7 +44,7 @@ def test_read_zero(self): self.assertEqual(len(self.DATA), stream.byte_count) def test_read(self): - """Read bytes.""" + # Read bytes. stream = streams.StreamReader() read_task = tasks.Task(stream.read(30)) @@ -57,7 +57,7 @@ def cb(): self.assertFalse(stream.byte_count) def test_read_line_breaks(self): - """Read bytes without line breaks.""" + # Read bytes without line breaks. stream = streams.StreamReader() stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -69,7 +69,7 @@ def test_read_line_breaks(self): self.assertEqual(5, stream.byte_count) def test_read_eof(self): - """Read bytes, stop at eof.""" + # Read bytes, stop at eof. stream = streams.StreamReader() read_task = tasks.Task(stream.read(1024)) @@ -82,7 +82,7 @@ def cb(): self.assertFalse(stream.byte_count) def test_read_until_eof(self): - """Read all bytes until eof.""" + # Read all bytes until eof. stream = streams.StreamReader() read_task = tasks.Task(stream.read(-1)) @@ -110,7 +110,7 @@ def test_read_exception(self): self.event_loop.run_until_complete, tasks.Task(stream.read(2))) def test_readline(self): - """Read one line.""" + # Read one line. stream = streams.StreamReader() stream.feed_data(b'chunk1 ') read_task = tasks.Task(stream.readline()) @@ -225,7 +225,7 @@ def test_readline_exception(self): self.event_loop.run_until_complete, tasks.Task(stream.readline())) def test_readexactly_zero_or_less(self): - """Read exact number of bytes (zero or less).""" + # Read exact number of bytes (zero or less). stream = streams.StreamReader() stream.feed_data(self.DATA) @@ -240,7 +240,7 @@ def test_readexactly_zero_or_less(self): self.assertEqual(len(self.DATA), stream.byte_count) def test_readexactly(self): - """Read exact number of bytes.""" + # Read exact number of bytes. stream = streams.StreamReader() n = 2 * len(self.DATA) @@ -257,7 +257,7 @@ def cb(): self.assertEqual(len(self.DATA), stream.byte_count) def test_readexactly_eof(self): - """Read exact number of bytes (eof).""" + # Read exact number of bytes (eof). stream = streams.StreamReader() n = 2 * len(self.DATA) read_task = tasks.Task(stream.readexactly(n)) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 03eb5261..2bcb7745 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -472,7 +472,7 @@ def notmuch(): self.assertEqual(1, m_logging.warn.call_args[0][1]) def test_step_result_future(self): - """If coroutine returns future, task waits on this future.""" + # If coroutine returns future, task waits on this future. self.suppress_log_warnings() class Fut(futures.Future): From c20f8647cd1e5eaeffc03085bc2d0ae0ba397e32 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Mar 2013 13:18:40 -0700 Subject: [PATCH 0364/1502] run_until_complete accepts coroutine --- curl.py | 4 +- tests/events_test.py | 58 ++++++------- tests/http_protocol_test.py | 165 +++++++++++++++--------------------- tests/locks_test.py | 80 ++++++++--------- tests/streams_test.py | 54 +++++------- tulip/base_events.py | 6 ++ 6 files changed, 161 insertions(+), 206 deletions(-) diff --git a/curl.py b/curl.py index 6986fc20..37fce75c 100755 --- a/curl.py +++ b/curl.py @@ -17,12 +17,12 @@ def main(): print(netloc, path, scheme) p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) f = p.connect() - sts, headers, stream = p.event_loop.run_until_complete(tulip.Task(f)) + sts, headers, stream = p.event_loop.run_until_complete(f) print(sts) for k, v in headers.items(): print('{}: {}'.format(k, v)) print() - data = p.event_loop.run_until_complete(tulip.Task(stream.read())) + data = p.event_loop.run_until_complete(stream.read()) print(data.decode('utf-8', 'replace')) diff --git a/tests/events_test.py b/tests/events_test.py index 1c0883b6..794de307 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -555,8 +555,8 @@ def test_create_connection_sock(self): def test_create_ssl_connection(self): with self.run_test_server(use_ssl=True) as httpsd: host, port = httpsd.socket.getsockname() - f = tasks.Task(self.event_loop.create_connection( - MyProto, host, port, ssl=True)) + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -568,34 +568,32 @@ def test_create_ssl_connection(self): def test_create_connection_host_port_sock(self): self.suppress_log_errors() - fut = tasks.Task(self.event_loop.create_connection( - MyProto, 'xkcd.com', 80, sock=object())) - self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + coro = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_host_port_sock(self): self.suppress_log_errors() - fut = tasks.Task(self.event_loop.create_connection(MyProto)) - self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - fut = tasks.Task( - self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_connect_err(self): self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = tasks.Task( - self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_mutiple_errors(self): self.suppress_log_errors() @@ -608,10 +606,9 @@ def getaddrinfo(*args, **kw): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = tasks.Task( - self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) def test_start_serving(self): proto = None @@ -728,9 +725,9 @@ def datagram_received(self, data, addr): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() - f = tasks.Task(self.event_loop.create_datagram_connection( - MyDatagramProto, host, port)) - transport, protocol = self.event_loop.run_until_complete(f) + coro = self.event_loop.create_datagram_connection( + MyDatagramProto, host, port) + transport, protocol = self.event_loop.run_until_complete(coro) self.assertEqual('INITIALIZED', protocol.state) transport.sendto(b'xxx') @@ -774,9 +771,8 @@ def datagram_received(self, data, addr): sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() - f = tasks.Task( - self.event_loop.create_datagram_connection(MyDatagramProto)) - transport, protocol = self.event_loop.run_until_complete(f) + coro = self.event_loop.create_datagram_connection(MyDatagramProto) + transport, protocol = self.event_loop.run_until_complete(coro) self.assertEqual('INITIALIZED', protocol.state) transport.sendto(b'xxx', (host, port)) @@ -808,20 +804,20 @@ def test_create_datagram_connection_no_getaddrinfo(self): getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - fut = tasks.Task(self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80)) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) def test_create_datagram_connection_connect_err(self): self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - fut = tasks.Task(self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80)) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, 'xkcd.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_connection_sockopt_err(self, m_socket): @@ -830,10 +826,10 @@ def test_create_datagram_connection_sockopt_err(self, m_socket): m_socket.error = socket.error m_socket.socket.return_value.setsockopt.side_effect = socket.error - fut = tasks.Task(self.event_loop.create_datagram_connection( - protocols.DatagramProtocol)) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol) self.assertRaises( - socket.error, self.event_loop.run_until_complete, fut) + socket.error, self.event_loop.run_until_complete, coro) self.assertTrue( m_socket.socket.return_value.close.called) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 408bfc7d..74aef7c8 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -30,22 +30,19 @@ def test_request_line(self): self.stream.feed_data(b'get /path HTTP/1.1\r\n') self.assertEqual( ('GET', '/path', (1, 1)), - self.loop.run_until_complete( - tulip.Task(self.stream.read_request_line()))) + self.loop.run_until_complete(self.stream.read_request_line())) def test_request_line_two_slashes(self): self.stream.feed_data(b'get //path HTTP/1.1\r\n') self.assertEqual( ('GET', '//path', (1, 1)), - self.loop.run_until_complete( - tulip.Task(self.stream.read_request_line()))) + self.loop.run_until_complete(self.stream.read_request_line())) def test_request_line_non_ascii(self): self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_request_line())) + self.loop.run_until_complete(self.stream.read_request_line()) self.assertEqual( b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) @@ -55,49 +52,47 @@ def test_request_line_bad_status_line(self): self.assertRaises( http.client.BadStatusLine, self.loop.run_until_complete, - tulip.Task(self.stream.read_request_line())) + self.stream.read_request_line()) def test_request_line_bad_method(self): self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') self.assertRaises( http.client.BadStatusLine, self.loop.run_until_complete, - tulip.Task(self.stream.read_request_line())) + self.stream.read_request_line()) def test_request_line_bad_version(self): self.stream.feed_data(b'GET //get HT/11\r\n') self.assertRaises( http.client.BadStatusLine, self.loop.run_until_complete, - tulip.Task(self.stream.read_request_line())) + self.stream.read_request_line()) def test_response_status_bad_status_line(self): self.stream.feed_data(b'\r\n') self.assertRaises( http.client.BadStatusLine, self.loop.run_until_complete, - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) def test_response_status_bad_status_line_eof(self): self.stream.feed_eof() self.assertRaises( http.client.BadStatusLine, self.loop.run_until_complete, - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) def test_response_status_bad_status_non_ascii(self): self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.loop.run_until_complete(self.stream.read_response_status()) self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) def test_response_status_bad_version(self): self.stream.feed_data(b'HT/11 200 Ok\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.loop.run_until_complete(self.stream.read_response_status()) self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) @@ -105,7 +100,7 @@ def test_response_status_no_reason(self): self.stream.feed_data(b'HTTP/1.1 200\r\n') v, s, r = self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) self.assertEqual(v, (1, 1)) self.assertEqual(s, 200) self.assertEqual(r, '') @@ -114,7 +109,7 @@ def test_response_status_bad(self): self.stream.feed_data(b'HTT/1\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) self.assertIn('HTT/1', str(cm.exception)) @@ -122,7 +117,7 @@ def test_response_status_bad_code_under_100(self): self.stream.feed_data(b'HTTP/1.1 99 test\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) self.assertIn('HTTP/1.1 99 test', str(cm.exception)) @@ -130,7 +125,7 @@ def test_response_status_bad_code_above_999(self): self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) @@ -138,7 +133,7 @@ def test_response_status_bad_code_not_int(self): self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') with self.assertRaises(http.client.BadStatusLine) as cm: self.loop.run_until_complete( - tulip.Task(self.stream.read_response_status())) + self.stream.read_response_status()) self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) @@ -148,8 +143,7 @@ def test_read_headers(self): b'test2: data\r\n' b'\r\n') - headers = self.loop.run_until_complete( - tulip.Task(self.stream.read_headers())) + headers = self.loop.run_until_complete(self.stream.read_headers()) self.assertEqual(headers, [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) @@ -163,14 +157,13 @@ def test_read_headers_size(self): self.assertRaises( http.client.LineTooLong, self.loop.run_until_complete, - tulip.Task(self.stream.read_headers())) + self.stream.read_headers()) def test_read_headers_invalid_header(self): self.stream.feed_data(b'test line\r\n') with self.assertRaises(ValueError) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_headers())) + self.loop.run_until_complete(self.stream.read_headers()) self.assertIn("Invalid header b'test line'", str(cm.exception)) @@ -178,8 +171,7 @@ def test_read_headers_invalid_name(self): self.stream.feed_data(b'test[]: line\r\n') with self.assertRaises(ValueError) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_headers())) + self.loop.run_until_complete(self.stream.read_headers()) self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) @@ -188,8 +180,7 @@ def test_read_headers_headers_size(self): self.stream.feed_data(b'test: line data data\r\ndata\r\n') with self.assertRaises(http.client.LineTooLong) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_headers())) + self.loop.run_until_complete(self.stream.read_headers()) self.assertIn("limit request headers fields size", str(cm.exception)) @@ -198,8 +189,7 @@ def test_read_headers_continuation_headers_size(self): self.stream.feed_data(b'test: line\r\n test\r\n') with self.assertRaises(http.client.LineTooLong) as cm: - self.loop.run_until_complete( - tulip.Task(self.stream.read_headers())) + self.loop.run_until_complete(self.stream.read_headers()) self.assertIn("limit request headers fields size", str(cm.exception)) @@ -207,8 +197,7 @@ def test_read_message_should_close(self): self.stream.feed_data( b'Host: example.com\r\nConnection: close\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) self.assertTrue(msg.should_close) def test_read_message_should_close_http11(self): @@ -216,7 +205,7 @@ def test_read_message_should_close_http11(self): b'Host: example.com\r\n\r\n') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(version=(1, 1)))) + self.stream.read_message(version=(1, 1))) self.assertFalse(msg.should_close) def test_read_message_should_close_http10(self): @@ -224,15 +213,14 @@ def test_read_message_should_close_http10(self): b'Host: example.com\r\n\r\n') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(version=(1, 0)))) + self.stream.read_message(version=(1, 0))) self.assertTrue(msg.should_close) def test_read_message_should_close_keep_alive(self): self.stream.feed_data( b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) self.assertFalse(msg.should_close) def test_read_message_content_length_broken(self): @@ -242,7 +230,7 @@ def test_read_message_content_length_broken(self): self.assertRaises( http.client.HTTPException, self.loop.run_until_complete, - tulip.Task(self.stream.read_message())) + self.stream.read_message()) def test_read_message_content_length_wrong(self): self.stream.feed_data( @@ -251,25 +239,24 @@ def test_read_message_content_length_wrong(self): self.assertRaises( http.client.HTTPException, self.loop.run_until_complete, - tulip.Task(self.stream.read_message())) + self.stream.read_message()) def test_read_message_content_length(self): self.stream.feed_data( b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'12', payload) def test_read_message_content_length_no_val(self): self.stream.feed_data(b'Host: example.com\r\n\r\n12') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=False))) + self.stream.read_message(readall=False)) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'', payload) _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) @@ -282,9 +269,8 @@ def test_read_message_deflate(self): len(self._COMPRESSED)).encode()) self.stream.feed_data(self._COMPRESSED) - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'data', payload) def test_read_message_deflate_disabled(self): @@ -295,8 +281,8 @@ def test_read_message_deflate_disabled(self): self.stream.feed_data(self._COMPRESSED) msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(compression=False))) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(self._COMPRESSED, payload) def test_read_message_deflate_unknown(self): @@ -306,18 +292,17 @@ def test_read_message_deflate_unknown(self): self.stream.feed_data(self._COMPRESSED) msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(compression=False))) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(self._COMPRESSED, payload) def test_read_message_websocket(self): self.stream.feed_data( b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'12345678', payload) def test_read_message_chunked(self): @@ -326,10 +311,9 @@ def test_read_message_chunked(self): self.stream.feed_data( b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'dataline', payload) def test_read_message_readall_eof(self): @@ -340,9 +324,9 @@ def test_read_message_readall_eof(self): self.stream.feed_eof() msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) - payload = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + payload = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'dataline', payload) def test_read_message_payload(self): @@ -353,9 +337,9 @@ def test_read_message_payload(self): self.stream.feed_data(b'data') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) - data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + data = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'datadata', data) def test_read_message_payload_eof(self): @@ -366,11 +350,11 @@ def test_read_message_payload_eof(self): self.stream.feed_eof() msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) self.assertRaises( http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(msg.payload.read())) + self.loop.run_until_complete, msg.payload.read()) def test_read_message_length_payload_zero(self): self.stream.feed_data( @@ -378,10 +362,8 @@ def test_read_message_length_payload_zero(self): b'Content-Length: 0\r\n\r\n') self.stream.feed_data(b'data') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) - - data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'', data) def test_read_message_length_payload_incomplete(self): @@ -389,8 +371,7 @@ def test_read_message_length_payload_incomplete(self): b'Host: example.com\r\n' b'Content-Length: 8\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data(b'data') @@ -399,20 +380,20 @@ def coro(): self.assertRaises( http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(coro())) + self.loop.run_until_complete, coro()) def test_read_message_eof_payload(self): self.stream.feed_data(b'Host: example.com\r\n\r\n') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) def coro(): self.stream.feed_data(b'data') self.stream.feed_eof() return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'data', data) def test_read_message_length_payload(self): @@ -425,11 +406,11 @@ def test_read_message_length_payload(self): self.stream.feed_data(b'ne') msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) self.assertIsInstance(msg.payload, tulip.StreamReader) - data = self.loop.run_until_complete(tulip.Task(msg.payload.read())) + data = self.loop.run_until_complete(msg.payload.read()) self.assertEqual(b'data', data) self.assertEqual(b'line', b''.join(self.stream.buffer)) @@ -438,8 +419,7 @@ def test_read_message_length_payload_extra(self): b'Host: example.com\r\n' b'Content-Length: 4\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data(b'da') @@ -448,7 +428,7 @@ def coro(): self.stream.feed_data(b'ne') return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'data', data) self.assertEqual(b'line', b''.join(self.stream.buffer)) @@ -468,7 +448,7 @@ def eof(): t1 = tulip.Task(stream.read()) t2 = tulip.Task(eof()) - self.loop.run_until_complete(tulip.Task(tulip.wait([t1, t2]))) + self.loop.run_until_complete(tulip.wait([t1, t2])) self.assertRaises(http.client.IncompleteRead, t1.result) self.assertIsNone(self.stream._parser) @@ -483,13 +463,13 @@ def test_read_message_deflate_payload(self): ('Content-Length: %s\r\n\r\n' % len(data)).encode()) msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message(readall=True))) + self.stream.read_message(readall=True)) def coro(): self.stream.feed_data(data) return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'data', data) def test_read_message_chunked_payload(self): @@ -497,15 +477,14 @@ def test_read_message_chunked_payload(self): b'Host: example.com\r\n' b'Transfer-Encoding: chunked\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data( b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'dataline', data) def test_read_message_chunked_payload_chunks(self): @@ -513,8 +492,7 @@ def test_read_message_chunked_payload_chunks(self): b'Host: example.com\r\n' b'Transfer-Encoding: chunked\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data(b'4\r\ndata\r') @@ -525,7 +503,7 @@ def coro(): self.stream.feed_data(b'test\r\n\r\n') return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'dataline', data) def test_read_message_chunked_payload_incomplete(self): @@ -533,8 +511,7 @@ def test_read_message_chunked_payload_incomplete(self): b'Host: example.com\r\n' b'Transfer-Encoding: chunked\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data(b'4\r\ndata\r\n') @@ -543,22 +520,21 @@ def coro(): self.assertRaises( http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(coro())) + self.loop.run_until_complete, coro()) def test_read_message_chunked_payload_extension(self): self.stream.feed_data( b'Host: example.com\r\n' b'Transfer-Encoding: chunked\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data( b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') return (yield from msg.payload.read()) - data = self.loop.run_until_complete(tulip.Task(coro())) + data = self.loop.run_until_complete(coro()) self.assertEqual(b'dataline', data) def test_read_message_chunked_payload_size_error(self): @@ -566,8 +542,7 @@ def test_read_message_chunked_payload_size_error(self): b'Host: example.com\r\n' b'Transfer-Encoding: chunked\r\n\r\n') - msg = self.loop.run_until_complete( - tulip.Task(self.stream.read_message())) + msg = self.loop.run_until_complete(self.stream.read_message()) def coro(): self.stream.feed_data(b'blah\r\n') @@ -575,7 +550,7 @@ def coro(): self.assertRaises( http.client.IncompleteRead, - self.loop.run_until_complete, tulip.Task(coro())) + self.loop.run_until_complete, coro()) def test_deflate_stream_set_exception(self): stream = tulip.StreamReader() diff --git a/tests/locks_test.py b/tests/locks_test.py index 20dc222b..7d2111d9 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -27,7 +27,7 @@ def test_repr(self): def acquire_lock(): yield from lock - self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + self.event_loop.run_until_complete(acquire_lock()) self.assertTrue(repr(lock).endswith('[locked]>')) def test_lock(self): @@ -36,7 +36,7 @@ def test_lock(self): def acquire_lock(): return (yield from lock) - res = self.event_loop.run_until_complete(tasks.Task(acquire_lock())) + res = self.event_loop.run_until_complete(acquire_lock()) self.assertTrue(res) self.assertTrue(lock.locked()) @@ -49,9 +49,7 @@ def test_acquire(self): result = [] self.assertTrue( - self.event_loop.run_until_complete( - tasks.Task(lock.acquire()) - )) + self.event_loop.run_until_complete(lock.acquire())) @tasks.coroutine def c1(result): @@ -94,27 +92,26 @@ def c3(result): def test_acquire_timeout(self): lock = locks.Lock() self.assertTrue( - self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + self.event_loop.run_until_complete(lock.acquire())) t0 = time.monotonic() acquired = self.event_loop.run_until_complete( - tasks.Task(lock.acquire(timeout=0.1))) + lock.acquire(timeout=0.1)) self.assertFalse(acquired) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) lock = locks.Lock() - self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.event_loop.run_until_complete(lock.acquire()) self.event_loop.call_later(0.1, lock.release) - acquired = self.event_loop.run_until_complete( - tasks.Task(lock.acquire(10.1))) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) self.assertTrue(acquired) def test_acquire_timeout_mixed(self): lock = locks.Lock() - self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.event_loop.run_until_complete(lock.acquire()) tasks.Task(lock.acquire()) tasks.Task(lock.acquire()) acquire_task = tasks.Task(lock.acquire(0.1)) @@ -132,7 +129,7 @@ def test_acquire_timeout_mixed(self): def test_acquire_cancel(self): lock = locks.Lock() self.assertTrue( - self.event_loop.run_until_complete(tasks.Task(lock.acquire()))) + self.event_loop.run_until_complete(lock.acquire())) task = tasks.Task(lock.acquire()) self.event_loop.call_soon(task.cancel) @@ -148,7 +145,7 @@ def test_release_not_acquired(self): def test_release_no_waiters(self): lock = locks.Lock() - self.event_loop.run_until_complete(tasks.Task(lock.acquire())) + self.event_loop.run_until_complete(lock.acquire()) self.assertTrue(lock.locked()) lock.release() @@ -231,22 +228,21 @@ def test_wait_on_set(self): ev = locks.EventWaiter() ev.set() - res = self.event_loop.run_until_complete(tasks.Task(ev.wait())) + res = self.event_loop.run_until_complete(ev.wait()) self.assertTrue(res) def test_wait_timeout(self): ev = locks.EventWaiter() t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(ev.wait(0.1))) + res = self.event_loop.run_until_complete(ev.wait(0.1)) self.assertFalse(res) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) ev = locks.EventWaiter() self.event_loop.call_later(0.1, ev.set) - acquired = self.event_loop.run_until_complete( - tasks.Task(ev.wait(10.1))) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) self.assertTrue(acquired) def test_wait_timeout_mixed(self): @@ -352,7 +348,7 @@ def c3(result): self.assertFalse(cond.locked()) self.assertTrue( - self.event_loop.run_until_complete(tasks.Task(cond.acquire()))) + self.event_loop.run_until_complete(cond.acquire())) cond.notify() self.event_loop.run_once() self.assertEqual([], result) @@ -380,10 +376,10 @@ def c3(result): def test_wait_timeout(self): cond = locks.Condition() - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) t0 = time.monotonic() - wait = self.event_loop.run_until_complete(tasks.Task(cond.wait(0.1))) + wait = self.event_loop.run_until_complete(cond.wait(0.1)) self.assertFalse(wait) self.assertTrue(cond.locked()) @@ -392,7 +388,7 @@ def test_wait_timeout(self): def test_wait_cancel(self): cond = locks.Condition() - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) wait = tasks.Task(cond.wait()) self.event_loop.call_soon(wait.cancel) @@ -408,7 +404,7 @@ def test_wait_unacquired(self): cond = locks.Condition() self.assertRaises( RuntimeError, - self.event_loop.run_until_complete, tasks.Task(cond.wait())) + self.event_loop.run_until_complete, cond.wait()) def test_wait_for(self): cond = locks.Condition() @@ -431,14 +427,14 @@ def c1(result): self.event_loop.run_once() self.assertEqual([], result) - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify() cond.release() self.event_loop.run_once() self.assertEqual([], result) presult = True - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify() cond.release() self.event_loop.run_once() @@ -468,7 +464,7 @@ def c1(result): self.event_loop.run_once() self.assertEqual([], result) - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify() cond.release() self.event_loop.run_once() @@ -487,14 +483,14 @@ def test_wait_for_unacquired(self): cond = locks.Condition() # predicate can return true immediately - res = self.event_loop.run_until_complete(tasks.Task( - cond.wait_for(lambda: [1, 2, 3]))) + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) self.assertEqual([1, 2, 3], res) self.assertRaises( RuntimeError, self.event_loop.run_until_complete, - tasks.Task(cond.wait_for(lambda: False))) + cond.wait_for(lambda: False)) def test_notify(self): cond = locks.Condition() @@ -528,13 +524,13 @@ def c3(result): self.event_loop.run_once() self.assertEqual([], result) - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify(1) cond.release() self.event_loop.run_once() self.assertEqual([1], result) - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify(1) cond.notify(2048) cond.release() @@ -566,7 +562,7 @@ def c2(result): self.event_loop.run_once() self.assertEqual([], result) - self.event_loop.run_until_complete(tasks.Task(cond.acquire())) + self.event_loop.run_until_complete(cond.acquire()) cond.notify_all() cond.release() self.event_loop.run_once() @@ -594,7 +590,7 @@ def test_repr(self): sem = locks.Semaphore() self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) self.assertTrue(repr(sem).endswith('[locked]>')) def test_semaphore(self): @@ -623,9 +619,9 @@ def test_acquire(self): result = [] self.assertTrue( - self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.event_loop.run_until_complete(sem.acquire())) self.assertTrue( - self.event_loop.run_until_complete(tasks.Task(sem.acquire()))) + self.event_loop.run_until_complete(sem.acquire())) self.assertFalse(sem.locked()) @tasks.coroutine @@ -673,27 +669,25 @@ def c4(result): def test_acquire_timeout(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) t0 = time.monotonic() - acquired = self.event_loop.run_until_complete( - tasks.Task(sem.acquire(0.1))) + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) self.assertFalse(acquired) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) sem = locks.Semaphore() - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) self.event_loop.call_later(0.1, sem.release) - acquired = self.event_loop.run_until_complete( - tasks.Task(sem.acquire(10.1))) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) self.assertTrue(acquired) def test_acquire_timeout_mixed(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) tasks.Task(sem.acquire()) tasks.Task(sem.acquire()) acquire_task = tasks.Task(sem.acquire(0.1)) @@ -710,7 +704,7 @@ def test_acquire_timeout_mixed(self): def test_acquire_cancel(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) acquire = tasks.Task(sem.acquire()) self.event_loop.call_soon(acquire.cancel) @@ -726,7 +720,7 @@ def test_release_not_acquired(self): def test_release_no_waiters(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(tasks.Task(sem.acquire())) + self.event_loop.run_until_complete(sem.acquire()) self.assertTrue(sem.locked()) sem.release() diff --git a/tests/streams_test.py b/tests/streams_test.py index 15fece12..dc6eeaf4 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -38,8 +38,7 @@ def test_read_zero(self): stream = streams.StreamReader() stream.feed_data(self.DATA) - read_task = tasks.Task(stream.read(0)) - data = self.event_loop.run_until_complete(read_task) + data = self.event_loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) @@ -62,8 +61,7 @@ def test_read_line_breaks(self): stream.feed_data(b'line1') stream.feed_data(b'line2') - read_task = tasks.Task(stream.read(5)) - data = self.event_loop.run_until_complete(read_task) + data = self.event_loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) self.assertEqual(5, stream.byte_count) @@ -101,13 +99,13 @@ def test_read_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete(tasks.Task(stream.read(2))) + data = self.event_loop.run_until_complete(stream.read(2)) self.assertEqual(b'li', data) stream.set_exception(ValueError()) self.assertRaises( ValueError, - self.event_loop.run_until_complete, tasks.Task(stream.read(2))) + self.event_loop.run_until_complete, stream.read(2)) def test_readline(self): # Read one line. @@ -132,9 +130,8 @@ def test_readline_limit_with_existing_data(self): stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') - read_task = tasks.Task(stream.readline()) self.assertRaises( - ValueError, self.event_loop.run_until_complete, read_task) + ValueError, self.event_loop.run_until_complete, stream.readline()) self.assertEqual([b'line2\n'], list(stream.buffer)) stream = streams.StreamReader(3) @@ -142,9 +139,8 @@ def test_readline_limit_with_existing_data(self): stream.feed_data(b'ne1') stream.feed_data(b'li') - read_task = tasks.Task(stream.readline()) self.assertRaises( - ValueError, self.event_loop.run_until_complete, read_task) + ValueError, self.event_loop.run_until_complete, stream.readline()) self.assertEqual([b'li'], list(stream.buffer)) self.assertEqual(2, stream.byte_count) @@ -160,9 +156,8 @@ def cb(): stream.feed_eof() self.event_loop.call_soon(cb) - read_task = tasks.Task(stream.readline()) self.assertRaises( - ValueError, self.event_loop.run_until_complete, read_task) + ValueError, self.event_loop.run_until_complete, stream.readline()) self.assertEqual([b'chunk3\n'], list(stream.buffer)) self.assertEqual(7, stream.byte_count) @@ -171,8 +166,7 @@ def test_readline_line_byte_count(self): stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) - read_task = tasks.Task(stream.readline()) - line = self.event_loop.run_until_complete(read_task) + line = self.event_loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) @@ -182,29 +176,23 @@ def test_readline_eof(self): stream.feed_data(b'some data') stream.feed_eof() - read_task = tasks.Task(stream.readline()) - line = self.event_loop.run_until_complete(read_task) - + line = self.event_loop.run_until_complete(stream.readline()) self.assertEqual(b'some data', line) def test_readline_empty_eof(self): stream = streams.StreamReader() stream.feed_eof() - read_task = tasks.Task(stream.readline()) - line = self.event_loop.run_until_complete(read_task) - + line = self.event_loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): stream = streams.StreamReader() stream.feed_data(self.DATA) - read_task = tasks.Task(stream.readline()) - self.event_loop.run_until_complete(read_task) + self.event_loop.run_until_complete(stream.readline()) - read_task = tasks.Task(stream.read(7)) - data = self.event_loop.run_until_complete(read_task) + data = self.event_loop.run_until_complete(stream.read(7)) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -215,27 +203,24 @@ def test_readline_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete( - tasks.Task(stream.readline())) + data = self.event_loop.run_until_complete(stream.readline()) self.assertEqual(b'line\n', data) stream.set_exception(ValueError()) self.assertRaises( ValueError, - self.event_loop.run_until_complete, tasks.Task(stream.readline())) + self.event_loop.run_until_complete, stream.readline()) def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). stream = streams.StreamReader() stream.feed_data(self.DATA) - read_task = tasks.Task(stream.readexactly(0)) - data = self.event_loop.run_until_complete(read_task) + data = self.event_loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) - read_task = tasks.Task(stream.readexactly(-1)) - data = self.event_loop.run_until_complete(read_task) + data = self.event_loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) @@ -275,15 +260,14 @@ def test_readexactly_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete( - tasks.Task(stream.readexactly(2))) + data = self.event_loop.run_until_complete(stream.readexactly(2)) self.assertEqual(b'li', data) stream.set_exception(ValueError()) self.assertRaises( ValueError, self.event_loop.run_until_complete, - tasks.Task(stream.readexactly(2))) + stream.readexactly(2)) def test_exception(self): stream = streams.StreamReader() @@ -306,7 +290,7 @@ def readline(): t1 = tasks.Task(stream.readline()) t2 = tasks.Task(set_err()) - self.event_loop.run_until_complete(tasks.Task(tasks.wait([t1, t2]))) + self.event_loop.run_until_complete(tasks.wait([t1, t2])) self.assertRaises(ValueError, t1.result) diff --git a/tulip/base_events.py b/tulip/base_events.py index b71be09a..11940de6 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -120,6 +120,12 @@ def run_until_complete(self, future, timeout=None): Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ + if (not isinstance(future, futures.Future) and + tasks.iscoroutine(future)): + future = tasks.Task(future) + + assert isinstance(future, futures.Future), 'Future is required' + handle_called = False def stop_loop(): From d3bad6863a6215852b8e93cfa4cfa06a242d8eff Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 20 Mar 2013 14:33:54 -0700 Subject: [PATCH 0365/1502] disallow all forms of nested running --- tests/events_test.py | 32 ++++++++++++++++++++++++++++++++ tulip/base_events.py | 32 +++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 7 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 794de307..79648340 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -142,6 +142,38 @@ def app(environ, start_response): def test_run(self): self.event_loop.run() # Returns immediately. + def test_run_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + try: + self.event_loop.run_until_complete( + tasks.sleep(0.1)) + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + tasks.sleep(0.1) + try: + self.event_loop.run_once() + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + def test_call_later(self): results = [] diff --git a/tulip/base_events.py b/tulip/base_events.py index 11940de6..75cb9ddd 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -48,6 +48,7 @@ def __init__(self): self._scheduled = [] self._default_executor = None self._internal_fds = 0 + self._running = False def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): """Create socket transport.""" @@ -75,6 +76,10 @@ def _process_events(self, event_list): """Process selector events.""" raise NotImplementedError + def is_running(self): + """Returns running status of event loop.""" + return self._running + def run(self): """Run the event loop until nothing left to do or stop() called. @@ -84,13 +89,20 @@ def run(self): TODO: Give this a timeout too? """ - while (self._ready or - self._scheduled or - self._selector.registered_count() > 1): - try: - self._run_once() - except _StopError: - break + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False def run_forever(self): """Run until stop() is called. @@ -109,10 +121,16 @@ def run_once(self, timeout=None): Calling stop() will break out of this too. """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True try: self._run_once(timeout) except _StopError: pass + finally: + self._running = False def run_until_complete(self, future, timeout=None): """Run until the Future is done, or until a timeout. From 08c727556f870ed8d6fd077f529fc9c3a2525961 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Wed, 20 Mar 2013 20:59:38 -0400 Subject: [PATCH 0366/1502] Initial queues implementation (Fixes issue 7) --- tests/queues_test.py | 370 +++++++++++++++++++++++++++++++++++++++++++ tulip/queues.py | 291 ++++++++++++++++++++++++++++++++++ 2 files changed, 661 insertions(+) create mode 100644 tests/queues_test.py create mode 100644 tulip/queues.py diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..c722c1ac --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self._maxsize, ) + if getattr(self, '_queue', None): + result += ' _queue=%r' % list(self._queue) + if self._getters: + result += ' _getters[%s]' % len(self._getters) + if self._putters: + result += ' _putters[%s]' % len(self._putters) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super(JoinableQueue, self).__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + def _put(self, item): + super(JoinableQueue, self)._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) From 114810cc17efd55896ee6e667be5e943a671e16e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 20 Mar 2013 18:37:08 -0700 Subject: [PATCH 0367/1502] Implement transport/protocol pair for unix pipes --- tests/events_test.py | 137 +++++++++++++ tests/unix_events_test.py | 406 ++++++++++++++++++++++++++++++++++++++ tulip/base_events.py | 28 +++ tulip/events.py | 29 +++ tulip/protocols.py | 56 +++--- tulip/selector_events.py | 2 + tulip/transports.py | 104 +++++----- tulip/unix_events.py | 171 ++++++++++++++++ 8 files changed, 861 insertions(+), 72 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 79648340..a71f10c3 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -3,6 +3,7 @@ import concurrent.futures import contextlib import errno +import fcntl import gc import io import os @@ -78,6 +79,48 @@ def connection_lost(self, exc): self.state = 'CLOSED' +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + class EventLoopTestsMixin: def setUp(self): @@ -968,6 +1011,92 @@ def test_internal_fds(self): self.assertIsNone(event_loop._csock) self.assertIsNone(event_loop._ssock) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + if sys.platform == 'win32': from tulip import windows_events @@ -1210,6 +1339,14 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.add_signal_handler, 1, f) self.assertRaises( NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) class ProtocolsAbsTests(unittest.TestCase): diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 24ea4945..a30be2c7 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -1,6 +1,7 @@ """Tests for unix_events.py.""" import errno +import io import unittest import unittest.mock @@ -10,6 +11,8 @@ signal = None from tulip import events +from tulip import futures +from tulip import protocols from tulip import unix_events @@ -166,3 +169,406 @@ class Err(OSError): self.assertRaises( RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('logging.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('logging.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tulip/base_events.py b/tulip/base_events.py index 75cb9ddd..1dc0b52b 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -64,6 +64,16 @@ def _make_datagram_transport(self, sock, protocol, """Create datagram transport.""" raise NotImplementedError + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + def _read_from_self(self): """XXX""" raise NotImplementedError @@ -441,6 +451,24 @@ def start_serving_datagram(self, protocol_factory, host, port, *, return sock + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" if handle.cancelled: diff --git a/tulip/events.py b/tulip/events.py index 9bad35fb..da892995 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -193,6 +193,35 @@ def start_serving_datagram(self, protocol_factory, host, port, *, family=0, proto=0, flags=0): raise NotImplementedError + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + # Ready-based callback registration methods. # The add_*() methods return a Handle. # The remove_*() methods return True if something was removed, diff --git a/tulip/protocols.py b/tulip/protocols.py index f01e2fd2..593ee745 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -3,7 +3,34 @@ __all__ = ['Protocol', 'DatagramProtocol'] -class Protocol: +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): """ABC representing a protocol. The user should implement this interface. They can inherit from @@ -16,7 +43,7 @@ class Protocol: When the connection is made successfully, connection_made() is called with a suitable transport object. Then data_received() will be called 0 or more times with data (bytes) received from the - transport; finally, connection_list() will be called exactly once + transport; finally, connection_lost() will be called exactly once with either an exception object or None as an argument. State machine of calls: @@ -24,15 +51,6 @@ class Protocol: start -> CM [-> DR*] [-> ER?] -> CL -> end """ - def connection_made(self, transport): - """Called when a connection is made. - - The argument is the transport representing the connection. - To send data, call its write() or writelines() method. - To receive data, wait for data_received() calls. - When the connection is closed, connection_lost() is called. - """ - def data_received(self, data): """Called when some data is received. @@ -49,26 +67,12 @@ def eof_received(self): set it). """ - def connection_lost(self, exc): - """Called when the connection is lost or closed. - The argument is an exception object or None (the latter - meaning a regular EOF is received or the connection was - aborted or closed). - """ - - -class DatagramProtocol: +class DatagramProtocol(BaseProtocol): """ABC representing a datagram protocol.""" - def connection_made(self, transport): - """Called when a datagram transport is ready.""" - def datagram_received(self, data, addr): """Called when some datagram is received.""" def connection_refused(self, exc): """Connection is refused.""" - - def connection_lost(self, exc): - """Called when the connection is lost or closed.""" diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 9bc9c23f..59998461 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -659,3 +659,5 @@ def _call_connection_lost(self, exc): self._protocol.connection_lost(exc) finally: self._sock.close() + + diff --git a/tulip/transports.py b/tulip/transports.py index 6eb1c554..984e7e80 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -1,28 +1,10 @@ """Abstract Transport class.""" -__all__ = ['Transport'] +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] -class Transport: - """ABC representing a transport. - - There may be several implementations, but typically, the user does - not implement new transports; rather, the platform provides some - useful transports that are implemented using the platform's best - practices. - - The user never instantiates a transport directly; they call a - utility function, passing it a protocol factory and other - information necessary to create the transport and protocol. (E.g. - EventLoop.create_connection() or EventLoop.start_serving().) - - The utility function will asynchronously create a transport and a - protocol and hook them up by calling the protocol's - connection_made() method, passing it the transport. - - The implementation here raises NotImplemented for every method - except writelines(), which calls write() in a loop. - """ +class BaseTransport: + """Base ABC for transports.""" def __init__(self, extra=None): if extra is None: @@ -33,6 +15,40 @@ def get_extra_info(self, name, default=None): """Get optional transport information.""" return self._extra.get(name, default) + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + def write(self, data): """Write some data bytes to the transport. @@ -63,37 +79,33 @@ def can_write_eof(self): """Return True if this protocol supports write_eof(), False if not.""" raise NotImplementedError - def pause(self): - """Pause the receiving end. + def abort(self): + """Closes the transport immediately. - No data will be passed to the protocol's data_received() - method until resume() is called. + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. """ raise NotImplementedError - def resume(self): - """Resume the receiving end. - Data received will once again be passed to the protocol's - data_received() method. - """ - raise NotImplementedError +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. - def close(self): - """Closes the transport. + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. - Buffered data will be flushed asynchronously. No more data - will be received. After all buffered data is flushed, the - protocol's connection_lost() method will (eventually) called - with None as its argument. - """ - raise NotImplementedError + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) - def abort(self): - """Closes the transport immediately. + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. - Buffered data will be lost. No more data will be received. - The protocol's connection_lost() method will (eventually) be - called with None as its argument. - """ - raise NotImplementedError + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 41f8e0b4..833c6612 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,7 +1,9 @@ """Selector eventloop for Unix with signal handling.""" import errno +import fcntl import logging +import os import socket import sys @@ -12,6 +14,7 @@ from . import events from . import selector_events +from . import transports __all__ = ['SelectorEventLoop'] @@ -126,3 +129,171 @@ def _check_signal(self, sig): if not (1 <= sig < signal.NSIG): raise ValueError( 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + assert data, "Data shold not be empty" + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer = [data] + return + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + self._buffer = [] + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + self._buffer = [data] # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer = [] + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() From 2a895ddf4d612a2dc380d842c56b860c94043a01 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 20 Mar 2013 18:37:52 -0700 Subject: [PATCH 0368/1502] Implement transport/protocol pair for unix pipes --- tulip/selector_events.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 59998461..9bc9c23f 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -659,5 +659,3 @@ def _call_connection_lost(self, exc): self._protocol.connection_lost(exc) finally: self._sock.close() - - From dd9d56eb76175dbedfa70c964e00e031d6cbe157 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 21 Mar 2013 06:17:01 -0400 Subject: [PATCH 0369/1502] Use new-style super() in queues --- tulip/queues.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/queues.py b/tulip/queues.py index e6d52ab3..e77665b9 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -245,7 +245,7 @@ def __init__(self, maxsize=0): self._unfinished_tasks = 0 self._finished = locks.EventWaiter() self._finished.set() - super(JoinableQueue, self).__init__(maxsize=maxsize) + super().__init__(maxsize=maxsize) def _format(self): result = Queue._format(self) @@ -254,7 +254,7 @@ def _format(self): return result def _put(self, item): - super(JoinableQueue, self)._put(item) + super()._put(item) self._unfinished_tasks += 1 self._finished.clear() From 11a96090e7b09af90c3fa80d774b8f9eb48d3af7 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 21 Mar 2013 06:19:58 -0400 Subject: [PATCH 0370/1502] Remove unneeded code in queues --- tulip/queues.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/queues.py b/tulip/queues.py index e77665b9..ee349e13 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -30,9 +30,9 @@ def __init__(self, maxsize=0): self._maxsize = maxsize # Futures. - self._getters = collections.deque([]) + self._getters = collections.deque() # Pairs of (item, Future). - self._putters = collections.deque([]) + self._putters = collections.deque() self._init(maxsize) def _init(self, maxsize): From f715824f3beeecfc4ac717f4f3c21ce91cc3b5a6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 21 Mar 2013 08:11:36 -0700 Subject: [PATCH 0371/1502] Remove .rej files. --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 2391d89c..724c57c9 100644 --- a/Makefile +++ b/Makefile @@ -25,5 +25,6 @@ clean: rm -f `find . -type f -name '@*' ` rm -f `find . -type f -name '#*#' ` rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` rm -f .coverage rm -rf htmlcov From e4dd61775faa10364671e562f9d63be75df90489 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 21 Mar 2013 10:00:51 -0700 Subject: [PATCH 0372/1502] Introduce DatagramTransport ABC --- tulip/selector_events.py | 2 +- tulip/transports.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 9bc9c23f..2992e9e6 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -554,7 +554,7 @@ def _fatal_error(self, exc): self._event_loop.call_soon(self._protocol.connection_lost, exc) -class _SelectorDatagramTransport(transports.Transport): +class _SelectorDatagramTransport(transports.DatagramTransport): max_size = 256 * 1024 # max bytes we read in one eventloop iteration diff --git a/tulip/transports.py b/tulip/transports.py index 984e7e80..35d5bb17 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -109,3 +109,25 @@ class Transport(ReadTransport, WriteTransport): The implementation here raises NotImplemented for every method except writelines(), which calls write() in a loop. """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError From 410c2dc46a8dd406c5b9b92d2491f44fb7551085 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 21 Mar 2013 10:42:20 -0700 Subject: [PATCH 0373/1502] http server --- srv.py | 64 ++-------- tests/http_server_test.py | 242 ++++++++++++++++++++++++++++++++++++++ tulip/http/__init__.py | 6 +- tulip/http/errors.py | 44 +++++++ tulip/http/protocol.py | 52 ++++---- tulip/http/server.py | 176 +++++++++++++++++++++++++++ 6 files changed, 503 insertions(+), 81 deletions(-) create mode 100644 tests/http_server_test.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/server.py diff --git a/srv.py b/srv.py index 710e0e7e..6f9b225b 100644 --- a/srv.py +++ b/srv.py @@ -1,32 +1,19 @@ """Simple server written using an event loop.""" -import http.client import email.message -import email.parser import os import tulip import tulip.http -class HttpServer(tulip.Protocol): - - def __init__(self): - super().__init__() - self.transport = None - self.reader = None - self.handler = None - - @tulip.task - def handle_request(self): - try: - method, path, version = yield from self.reader.read_request_line() - except http.client.BadStatusLine: - self.transport.close() - return +class HttpServer(tulip.http.ServerHttpProtocol): + def handle_request(self, request_info, message): print('method = {!r}; path = {!r}; version = {!r}'.format( - method, path, version)) + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri if (not (path.isprintable() and path.startswith('/')) or '/.' in path): print('bad path', repr(path)) @@ -40,11 +27,7 @@ def handle_request(self): isdir = os.path.isdir(path) if not path: - self.transport.write(b'HTTP/1.0 404 Not found\r\n\r\n') - self.transport.close() - return - - message = yield from self.reader.read_message() + raise tulip.http.HttpStatusException(404) headers = email.message.Message() for hdr, val in message.headers: @@ -53,14 +36,8 @@ def handle_request(self): if isdir and not path.endswith('/'): path = path + '/' - response = tulip.http.Response(self.transport, 302) - response.add_headers( - ('URI', path), - ('Location', path)) - response.send_headers() - response.write_eof() - self.transport.close() - return + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) response = tulip.http.Response(self.transport, 200) response.add_header('Transfer-Encoding', 'chunked') @@ -110,33 +87,12 @@ def handle_request(self): response.write(b'Cannot open') response.write_eof() - self.transport.close() - - def connection_made(self, transport): - self.transport = transport - print('connection made', transport, transport.get_extra_info('socket')) - self.reader = tulip.http.HttpStreamReader() - self.handler = self.handle_request() - - def data_received(self, data): - print('data received', data) - self.reader.feed_data(data) - - def eof_received(self): - print('eof received') - self.reader.feed_eof() - - def connection_lost(self, exc): - print('connection lost', exc) - if (self.handler.done() and - not self.handler.cancelled() and - self.handler.exception() is not None): - print('handler exception:', self.handler.exception()) + self.close() def main(): loop = tulip.get_event_loop() - f = loop.start_serving(HttpServer, '127.0.0.1', 8080) + f = loop.start_serving(lambda: HttpServer(debug=True), '127.0.0.1', 8080) x = loop.run_until_complete(f) print('serving on', x.getsockname()) loop.run_forever() diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..dc55eff9 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,242 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv.closing) + + srv.close() + self.assertTrue(srv.closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + yield from [] + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + yield from [] + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py index d436383f..582f0809 100644 --- a/tulip/http/__init__.py +++ b/tulip/http/__init__.py @@ -1,8 +1,12 @@ # This relies on each of the submodules having an __all__ variable. from .client import * +from .errors import * from .protocol import * +from .server import * __all__ = (client.__all__ + - protocol.__all__) + errors.__all__ + + protocol.__all__ + + server.__all__) diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..41344de1 --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: %s' % hdr) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 1773cabd..6a0e1279 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -7,7 +7,6 @@ import collections import email.utils import functools -import http.client import http.server import itertools import re @@ -15,6 +14,7 @@ import zlib import tulip +from . import errors METHRE = re.compile('[A-Z0-9$-_.]+') VERSRE = re.compile('HTTP/(\d+).(\d+)') @@ -70,7 +70,7 @@ def feed_eof(self): @tulip.coroutine def read_request_line(self): - """Read request status line. Exception http.client.BadStatusLine + """Read request status line. Exception errors.BadStatusLine could be raised in case of any errors in status line. Returns three values (method, uri, version) @@ -86,29 +86,29 @@ def read_request_line(self): try: line = bline.decode('ascii').rstrip() except UnicodeDecodeError: - raise http.client.BadStatusLine(bline) from None + raise errors.BadStatusLine(bline) from None try: method, uri, version = line.split(None, 2) except ValueError: - raise http.client.BadStatusLine(line) from None + raise errors.BadStatusLine(line) from None # method method = method.upper() if not METHRE.match(method): - raise http.client.BadStatusLine(method) + raise errors.BadStatusLine(method) # version match = VERSRE.match(version) if match is None: - raise http.client.BadStatusLine(version) + raise errors.BadStatusLine(version) version = (int(match.group(1)), int(match.group(2))) return RequestLine(method, uri, version) @tulip.coroutine def read_response_status(self): - """Read response status line. Exception http.client.BadStatusLine + """Read response status line. Exception errors.BadStatusLine could be raised in case of any errors in status line. Returns three values (version, status_code, reason) @@ -124,17 +124,17 @@ def read_response_status(self): if not bline: # Presumably, the server closed the connection before # sending a valid response. - raise http.client.BadStatusLine(bline) + raise errors.BadStatusLine(bline) try: line = bline.decode('ascii').rstrip() except UnicodeDecodeError: - raise http.client.BadStatusLine(bline) from None + raise errors.BadStatusLine(bline) from None try: version, status = line.split(None, 1) except ValueError: - raise http.client.BadStatusLine(line) from None + raise errors.BadStatusLine(line) from None else: try: status, reason = status.split(None, 1) @@ -144,17 +144,17 @@ def read_response_status(self): # version match = VERSRE.match(version) if match is None: - raise http.client.BadStatusLine(line) + raise errors.BadStatusLine(line) version = (int(match.group(1)), int(match.group(2))) # The status code is a three-digit number try: status = int(status) except ValueError: - raise http.client.BadStatusLine(line) from None + raise errors.BadStatusLine(line) from None if status < 100 or status > 999: - raise http.client.BadStatusLine(line) + raise errors.BadStatusLine(line) return ResponseStatus(version, status, reason.strip()) @@ -196,7 +196,7 @@ def read_headers(self): while continuation: header_length += len(line) if header_length > self.MAX_HEADERFIELD_SIZE: - raise http.client.LineTooLong( + raise errors.LineTooLong( 'limit request headers fields size') value.append(line) @@ -204,13 +204,13 @@ def read_headers(self): continuation = line.startswith(CONTINUATION) else: if header_length > self.MAX_HEADERFIELD_SIZE: - raise http.client.LineTooLong( + raise errors.LineTooLong( 'limit request headers fields size') # total headers size size += header_length if size >= self.MAX_HEADERS: - raise http.client.LineTooLong('limit request headers fields') + raise errors.LineTooLong('limit request headers fields') headers.append( (name, @@ -240,7 +240,7 @@ def _parse_chunked_payload(self): try: size = int(line, 16) except ValueError: - raise http.client.IncompleteRead(b'') from None + raise errors.IncompleteRead(b'') from None if size == 0: break @@ -274,8 +274,8 @@ def _parse_chunked_payload(self): return data except StreamEofException: - stream.set_exception(http.client.IncompleteRead(b'')) - except http.client.IncompleteRead as exc: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: stream.set_exception(exc) def _parse_length_payload(self, length): @@ -300,7 +300,7 @@ def _parse_length_payload(self, length): stream.feed_eof() return data except StreamEofException: - stream.set_exception(http.client.IncompleteRead(b'')) + stream.set_exception(errors.IncompleteRead(b'')) def _parse_eof_payload(self): """Read all bytes untile eof.""" @@ -360,10 +360,10 @@ def read_message(self, version=(1, 1), try: length = int(length) except ValueError: - raise http.client.HTTPException('CONTENT-LENGTH') from None + raise errors.InvalidHeader('CONTENT-LENGTH') from None if length < 0: - raise http.client.HTTPException('CONTENT-LENGTH') + raise errors.InvalidHeader('CONTENT-LENGTH') parser = self._parse_length_payload(length) else: @@ -421,8 +421,8 @@ def read_message(self, version=(1, 1), return RawHttpMessage(headers, payload, close_conn, encoding) -class StreamEofException(http.client.HTTPException): - """eof received""" +class StreamEofException(Exception): + """Internal exception: eof received.""" class DeflateStream: @@ -442,7 +442,7 @@ def feed_data(self, chunk): try: chunk = self.zlib.decompress(chunk) except: - self.stream.set_exception(http.client.IncompleteRead(b'')) + self.stream.set_exception(errors.IncompleteRead(b'')) if chunk: self.stream.feed_data(chunk) @@ -450,7 +450,7 @@ def feed_data(self, chunk): def feed_eof(self): self.stream.feed_data(self.zlib.flush()) if not self.zlib.eof: - self.stream.set_exception(http.client.IncompleteRead(b'')) + self.stream.set_exception(errors.IncompleteRead(b'')) self.stream.feed_eof() diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..7590e47b --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + %(status)s %(reason)s + + +

%(status)s %(reason)s

+ %(message)s + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + closing = False + request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self.closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self.request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self.closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
%s
' % tb + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE % { + 'status': status, 'reason': reason, 'message': msg} + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) From 23ebc684be72210135d28de5aba47cd2ee695a1c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 22 Mar 2013 01:35:41 -0700 Subject: [PATCH 0374/1502] Use tail recursion idiom for protocol callbacks, cleanup tests --- tests/events_test.py | 18 --- tests/selector_events_test.py | 244 +++++++++++++--------------------- tulip/selector_events.py | 40 +++--- 3 files changed, 120 insertions(+), 182 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index a71f10c3..6a7dcb75 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -705,8 +705,6 @@ def factory(): self.assertEqual('INITIAL', proto.state) self.event_loop.run_once() self.assertEqual('CONNECTED', proto.state) - self.assertEqual(0, proto.nbytes) - self.event_loop.run_once() self.assertEqual(3, proto.nbytes) # extra info is available @@ -719,7 +717,6 @@ def factory(): # close connection proto.transport.close() - self.event_loop.run_once() self.assertEqual('CLOSED', proto.state) # the client socket must be closed after to avoid ECONNRESET upon @@ -742,7 +739,6 @@ def test_start_serving_sock(self): client.send(b'xxx') self.event_loop.run_once() # This is quite mysterious, but necessary. self.event_loop.run_once() - self.event_loop.run_once() sock.close() client.close() @@ -807,13 +803,10 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', protocol.state) transport.sendto(b'xxx') self.event_loop.run_once() - self.assertEqual(0, server.nbytes) - self.event_loop.run_once() self.assertEqual(3, server.nbytes) self.event_loop.run_once() # received - self.event_loop.run_once() self.assertEqual(8, protocol.nbytes) # extra info is available @@ -823,10 +816,7 @@ def datagram_received(self, data, addr): # close connection transport.close() - - self.event_loop.run_once() self.assertEqual('CLOSED', protocol.state) - server.transport.close() def test_create_datagram_connection_no_connection(self): @@ -852,13 +842,10 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', protocol.state) transport.sendto(b'xxx', (host, port)) self.event_loop.run_once() - self.assertEqual(0, server.nbytes) - self.event_loop.run_once() self.assertEqual(3, server.nbytes) self.event_loop.run_once() # received - self.event_loop.run_once() self.assertEqual(8, protocol.nbytes) # extra info is available @@ -868,10 +855,7 @@ def datagram_received(self, data, addr): # close connection transport.close() - - self.event_loop.run_once() self.assertEqual('CLOSED', protocol.state) - server.transport.close() def test_create_datagram_connection_no_getaddrinfo(self): @@ -930,8 +914,6 @@ def factory(): client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) client.sendto(b'xxx', ('127.0.0.1', port)) self.event_loop.run_once() - self.assertEqual(0, proto.nbytes) - self.event_loop.run_once() self.assertEqual(3, proto.nbytes) data, server = client.recvfrom(4096) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index b8fb10b0..6ea34318 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -11,6 +11,8 @@ from tulip import futures from tulip import selectors +from tulip.events import AbstractEventLoop +from tulip.protocols import DatagramProtocol, Protocol from tulip.selector_events import BaseSelectorEventLoop from tulip.selector_events import _SelectorSslTransport from tulip.selector_events import _SelectorSocketTransport @@ -49,7 +51,9 @@ def test_make_ssl_transport(self): def test_close(self): ssock = self.event_loop._ssock + ssock.fileno.return_value = 7 csock = self.event_loop._csock + csock.fileno.return_value = 1 remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() self.event_loop._selector.close() @@ -58,10 +62,10 @@ def test_close(self): self.assertIsNone(self.event_loop._selector) self.assertIsNone(self.event_loop._csock) self.assertIsNone(self.event_loop._ssock) - self.assertTrue(selector.close.called) - self.assertTrue(ssock.close.called) - self.assertTrue(csock.close.called) - self.assertTrue(remove_reader.called) + selector.close.assert_called_with() + ssock.close.assert_called_with() + csock.close.assert_called_with() + remove_reader.assert_called_with(7) self.event_loop.close() self.event_loop.close() @@ -116,9 +120,8 @@ def test_sock_recv(self): f = self.event_loop.sock_recv(sock, 1024) self.assertIsInstance(f, futures.Future) - self.assertEqual( - (f, False, sock, 1024), - self.event_loop._sock_recv.call_args[0]) + self.event_loop._sock_recv.assert_called_with( + f, False, sock, 1024) def test__sock_recv_canceled_fut(self): sock = unittest.mock.Mock() @@ -521,7 +524,7 @@ def test_process_events_read(self): self.event_loop._process_events( ((1, selectors.EVENT_READ, (reader, None)),)) self.assertTrue(self.event_loop._add_callback.called) - self.assertEqual((reader,), self.event_loop._add_callback.call_args[0]) + self.event_loop._add_callback.assert_called_with(reader) def test_process_events_read_cancelled(self): reader = unittest.mock.Mock() @@ -530,8 +533,7 @@ def test_process_events_read_cancelled(self): self.event_loop.remove_reader = unittest.mock.Mock() self.event_loop._process_events( ((1, selectors.EVENT_READ, (reader, None)),)) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertEqual((1,), self.event_loop.remove_reader.call_args[0]) + self.event_loop.remove_reader.assert_called_with(1) def test_process_events_write(self): writer = unittest.mock.Mock() @@ -540,31 +542,31 @@ def test_process_events_write(self): self.event_loop._add_callback = unittest.mock.Mock() self.event_loop._process_events( ((1, selectors.EVENT_WRITE, (None, writer)),)) - self.assertTrue(self.event_loop._add_callback.called) - self.assertEqual((writer,), self.event_loop._add_callback.call_args[0]) + self.event_loop._add_callback.assert_called_with(writer) def test_process_events_write_cancelled(self): writer = unittest.mock.Mock() writer.cancelled = True - self.event_loop.remove_writer = unittest.mock.Mock() + self.event_loop._process_events( ((1, selectors.EVENT_WRITE, (None, writer)),)) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertEqual((1,), self.event_loop.remove_writer.call_args[0]) + self.event_loop.remove_writer.assert_called_with(1) class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock() - self.sock = unittest.mock.Mock() - self.protocol = unittest.mock.Mock() + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(Protocol) def test_ctor(self): - _SelectorSocketTransport(self.event_loop, self.sock, self.protocol) - self.assertTrue(self.event_loop.add_reader.called) - self.assertTrue(self.event_loop.call_soon.called) + tr = _SelectorSocketTransport(self.event_loop, self.sock, self.protocol) + self.event_loop.add_reader.assert_called_with(7, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) def test_ctor_with_waiter(self): fut = futures.Future() @@ -576,24 +578,15 @@ def test_ctor_with_waiter(self): self.event_loop.call_soon.call_args[0][0]) def test_read_ready(self): - data_received = unittest.mock.Mock() - self.protocol.data_received = data_received - transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) self.sock.recv.return_value = b'data' transport._read_ready() - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (data_received, b'data'), - self.event_loop.call_soon.call_args[0]) + self.protocol.data_received.assert_called_with(b'data') def test_read_ready_eof(self): - eof_received = unittest.mock.Mock() - self.protocol.eof_received = eof_received - transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -601,44 +594,43 @@ def test_read_ready_eof(self): transport._read_ready() self.assertTrue(self.event_loop.remove_reader.called) - self.assertEqual( - (eof_received,), self.event_loop.call_soon.call_args[0]) + self.protocol.eof_received.assert_called_with() - def test_read_ready_tryagain(self): + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain(self, m_exc): transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) class Err(socket.error): errno = errno.EAGAIN - self.sock.recv.side_effect = Err + err = self.sock.recv.side_effect = Err() transport._fatal_error = unittest.mock.Mock() transport._read_ready() self.assertFalse(transport._fatal_error.called) - def test_read_ready_err(self): + @unittest.mock.patch('logging.exception') + def test_read_ready_err(self, m_exc): transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) class Err(socket.error): pass - self.sock.recv.side_effect = Err + err = self.sock.recv.side_effect = Err() transport._fatal_error = unittest.mock.Mock() transport._read_ready() - self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_abort(self): transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._close = unittest.mock.Mock() transport.abort() - self.assertTrue(transport._fatal_error.called) - self.assertIsNone(transport._fatal_error.call_args[0][0]) + transport._close.assert_called_with(None) def test_write(self): data = b'data' @@ -647,8 +639,7 @@ def test_write(self): transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.sock.send.called) - self.assertEqual(self.sock.send.call_args[0], (data,)) + self.sock.send.assert_called_with(data) def test_write_no_data(self): transport = _SelectorSocketTransport( @@ -688,10 +679,7 @@ def test_write_partial_none(self): self.event_loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.event_loop.add_writer.called) - self.assertEqual( - transport._write_ready, self.event_loop.add_writer.call_args[0][1]) - + self.event_loop.add_writer.assert_called_with(7, transport._write_ready) self.assertEqual([b'data'], transport._buffer) def test_write_tryagain(self): @@ -760,13 +748,9 @@ def test_write_ready_closing(self): transport._closing = True transport._buffer.append(data) transport._write_ready() - self.assertTrue(self.sock.send.called) - self.assertEqual(self.sock.send.call_args[0], (data,)) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, None), - self.event_loop.call_soon.call_args[0]) + self.sock.send.assert_called_with(data) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) def test_write_ready_no_data(self): transport = _SelectorSocketTransport( @@ -815,16 +799,14 @@ def test_write_ready_exception(self): class Err(socket.error): pass - self.sock.send.side_effect = Err + err = self.sock.send.side_effect = Err() transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._buffer.append(b'data') transport._write_ready() - - self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_close(self): transport = _SelectorSocketTransport( @@ -832,11 +814,8 @@ def test_close(self): transport.close() self.assertTrue(transport._closing) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, None), - self.event_loop.call_soon.call_args[0]) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost(None) def test_close_write_buffer(self): transport = _SelectorSocketTransport( @@ -848,43 +827,38 @@ def test_close_write_buffer(self): self.assertTrue(self.event_loop.remove_reader.called) self.assertFalse(self.event_loop.call_soon.called) - def test_fatal_error(self): - exc = object() + @unittest.mock.patch('logging.exception') + def test_fatal_error(self, m_exc): + exc = OSError() transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() transport._buffer.append(b'data') transport._fatal_error(exc) self.assertEqual([], transport._buffer) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, exc), - self.event_loop.call_soon.call_args[0]) + self.event_loop.remove_reader.assert_called_with(7) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', transport) def test_connection_lost(self): exc = object() transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) - self.sock.reset_mock() - self.protocol.reset_mock() transport._call_connection_lost(exc) - self.assertTrue(self.protocol.connection_lost.called) - self.assertEqual( - (exc,), self.protocol.connection_lost.call_args[0]) - self.assertTrue(self.sock.close.called) + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() @unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock() - self.sock = unittest.mock.Mock() - self.protocol = unittest.mock.Mock() + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(spec_set=Protocol) self.sslsock = unittest.mock.Mock() self.sslsock.fileno.return_value = 1 self.sslcontext = unittest.mock.Mock() @@ -948,34 +922,27 @@ def test_write_closing(self): self.assertRaises(AssertionError, self.transport.write, b'data') def test_abort(self): - self.transport._fatal_error = unittest.mock.Mock() - + self.transport._close = unittest.mock.Mock() self.transport.abort() - self.assertTrue(self.transport._fatal_error.called) - self.assertEqual((None,), self.transport._fatal_error.call_args[0]) + self.transport._close.assert_called_with(None) - def test_fatal_error(self): - exc = object() + @unittest.mock.patch('logging.exception') + def test_fatal_error(self, m_exc): + exc = OSError() self.transport._buffer.append(b'data') self.transport._fatal_error(exc) self.assertEqual([], self.transport._buffer) self.assertTrue(self.event_loop.remove_writer.called) self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (self.protocol.connection_lost, exc), - self.event_loop.call_soon.call_args[0]) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', self.transport) def test_close(self): self.transport.close() - self.assertTrue(self.transport._closing) self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (self.protocol.connection_lost, None), - self.event_loop.call_soon.call_args[0]) + self.protocol.connection_lost.assert_called_with(None) def test_close_write_buffer(self): self.transport._buffer.append(b'data') @@ -1109,24 +1076,20 @@ class Err(socket.error): class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock() - self.sock = unittest.mock.Mock() - self.protocol = unittest.mock.Mock() + self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(spec_set=socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(spec_set=DatagramProtocol) def test_read_ready(self): - datagram_received = unittest.mock.Mock() - self.protocol.datagram_received = datagram_received - transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) transport._read_ready() - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (datagram_received, b'data', ('0.0.0.0', 1234)), - self.event_loop.call_soon.call_args[0]) + self.protocol.datagram_received.assert_called_with( + b'data', ('0.0.0.0', 1234)) def test_read_ready_tryagain(self): transport = _SelectorDatagramTransport( @@ -1148,21 +1111,19 @@ def test_read_ready_err(self): class Err(socket.error): pass - self.sock.recvfrom.side_effect = Err + err = self.sock.recvfrom.side_effect = Err() transport._fatal_error = unittest.mock.Mock() transport._read_ready() - self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_abort(self): transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._close = unittest.mock.Mock() transport.abort() - self.assertTrue(transport._fatal_error.called) - self.assertIsNone(transport._fatal_error.call_args[0][0]) + transport._close.assert_called_with(None) def test_sendto(self): data = b'data' @@ -1292,13 +1253,9 @@ def test_sendto_ready_closing(self): transport._closing = True transport._buffer.append((data, ())) transport._sendto_ready() - self.assertTrue(self.sock.sendto.called) - self.assertEqual(self.sock.sendto.call_args[0], (data, ())) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, None), - self.event_loop.call_soon.call_args[0]) + self.sock.sendto.assert_called_with(data, ()) + self.event_loop.remove_writer.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) def test_sendto_ready_no_data(self): transport = _SelectorDatagramTransport( @@ -1366,24 +1323,21 @@ def test_close(self): transport.close() self.assertTrue(transport._closing) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, None), - self.event_loop.call_soon.call_args[0]) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(None) def test_close_write_buffer(self): transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() transport._buffer.append((b'data', ())) transport.close() - self.assertTrue(self.event_loop.remove_reader.called) - self.assertFalse(self.event_loop.call_soon.called) + self.event_loop.remove_reader.assert_called_with(7) + self.assertFalse(self.protocol.connection_lost.called) - def test_fatal_error(self): - exc = object() + @unittest.mock.patch('logging.exception') + def test_fatal_error(self, m_exc): + exc = OSError() transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) self.event_loop.reset_mock() @@ -1391,31 +1345,25 @@ def test_fatal_error(self): transport._fatal_error(exc) self.assertEqual([], list(transport._buffer)) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.call_soon.called) - self.assertEqual( - (transport._call_connection_lost, exc), - self.event_loop.call_soon.call_args[0]) + self.event_loop.remove_writer.assert_called_with(7) + self.event_loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost.assert_called_with(exc) + m_exc.assert_called_with('Fatal error for %s', transport) - def test_fatal_error_connected(self): + @unittest.mock.patch('logging.exception') + def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) - self.event_loop.reset_mock() - transport._fatal_error(ConnectionRefusedError()) - - self.assertEqual( - 2, self.event_loop.call_soon.call_count) + err = ConnectionRefusedError() + transport._fatal_error(err) + self.protocol.connection_refused.assert_called_with(err) + m_exc.assert_called_with('Fatal error for %s', transport) def test_transport_closing(self): exc = object() transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) - self.sock.reset_mock() - self.protocol.reset_mock() transport._call_connection_lost(exc) - self.assertTrue(self.protocol.connection_lost.called) - self.assertEqual( - (exc,), self.protocol.connection_lost.call_args[0]) - self.assertTrue(self.sock.close.called) + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 2992e9e6..464c5808 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -346,10 +346,10 @@ def _read_ready(self): self._fatal_error(exc) else: if data: - self._event_loop.call_soon(self._protocol.data_received, data) + self._protocol.data_received(data) else: self._event_loop.remove_reader(self._sock.fileno()) - self._event_loop.call_soon(self._protocol.eof_received) + self._protocol.eof_received() def write(self, data): assert isinstance(data, (bytes, bytearray)), repr(data) @@ -390,7 +390,7 @@ def _write_ready(self): if n == len(data): self._event_loop.remove_writer(self._sock.fileno()) if self._closing: - self._event_loop.call_soon(self._call_connection_lost, None) + self._call_connection_lost(None) return if n: data = data[n:] @@ -399,20 +399,24 @@ def _write_ready(self): # TODO: write_eof(), can_write_eof(). def abort(self): - self._fatal_error(None) + self._close(None) def close(self): self._closing = True self._event_loop.remove_reader(self._sock.fileno()) if not self._buffer: - self._event_loop.call_soon(self._call_connection_lost, None) + self._call_connection_lost(None) def _fatal_error(self, exc): + # should be called from exception handler only logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): self._event_loop.remove_writer(self._sock.fileno()) self._event_loop.remove_reader(self._sock.fileno()) self._buffer = [] - self._event_loop.call_soon(self._call_connection_lost, exc) + self._call_connection_lost(exc) def _call_connection_lost(self, exc): try: @@ -538,20 +542,23 @@ def write(self, data): # TODO: write_eof(), can_write_eof(). def abort(self): - self._fatal_error(None) + self._close(None) def close(self): self._closing = True self._event_loop.remove_reader(self._sslsock.fileno()) if not self._buffer: - self._event_loop.call_soon(self._protocol.connection_lost, None) + self._protocol.connection_lost(None) def _fatal_error(self, exc): logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): self._event_loop.remove_writer(self._sslsock.fileno()) self._event_loop.remove_reader(self._sslsock.fileno()) self._buffer = [] - self._event_loop.call_soon(self._protocol.connection_lost, exc) + self._protocol.connection_lost(exc) class _SelectorDatagramTransport(transports.DatagramTransport): @@ -578,8 +585,7 @@ def _read_ready(self): if exc.errno not in _TRYAGAIN: self._fatal_error(exc) else: - self._event_loop.call_soon( - self._protocol.datagram_received, data, addr) + self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): assert isinstance(data, bytes) @@ -633,26 +639,28 @@ def _sendto_ready(self): if not self._buffer: self._event_loop.remove_writer(self._fileno) if self._closing: - self._event_loop.call_soon(self._call_connection_lost, None) + self._call_connection_lost(None) def abort(self): - self._fatal_error(None) + self._close(None) def close(self): self._closing = True self._event_loop.remove_reader(self._fileno) if not self._buffer: - self._event_loop.call_soon(self._call_connection_lost, None) + self._call_connection_lost(None) def _fatal_error(self, exc): logging.exception('Fatal error for %s', self) + self._close(exc) + def _close(self, exc): self._buffer.clear() self._event_loop.remove_writer(self._fileno) self._event_loop.remove_reader(self._fileno) if self._address and isinstance(exc, ConnectionRefusedError): - self._event_loop.call_soon(self._protocol.connection_refused, exc) - self._event_loop.call_soon(self._call_connection_lost, exc) + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) def _call_connection_lost(self, exc): try: From 325d5d9dd7850e4fb28c6bfb3b12bd9d13e5406e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 22 Mar 2013 10:48:46 -0700 Subject: [PATCH 0375/1502] Clarify logic in run_until_complete(). No changes in behavior. --- tulip/base_events.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 1dc0b52b..3e87548a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -145,14 +145,20 @@ def run_once(self, timeout=None): def run_until_complete(self, future, timeout=None): """Run until the Future is done, or until a timeout. + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ - if (not isinstance(future, futures.Future) and - tasks.iscoroutine(future)): - future = tasks.Task(future) - - assert isinstance(future, futures.Future), 'Future is required' + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' handle_called = False @@ -167,7 +173,7 @@ def stop_loop(): self.run_forever() else: handle = self.call_later(timeout, stop_loop) - self.run() + self.run_forever() handle.cancel() if handle_called: From 7ed385941f72be69eabc7ce36f759e2b308e5825 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 22 Mar 2013 15:17:14 -0700 Subject: [PATCH 0376/1502] udp api refactoring --- examples/udp_echo.py | 8 +- tests/events_test.py | 201 ++++++++++++++++--------------------------- tulip/base_events.py | 123 +++++++++++++------------- tulip/events.py | 8 +- 4 files changed, 137 insertions(+), 203 deletions(-) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index c92cb06d..1597812a 100644 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -12,7 +12,7 @@ ADDRESS = ('127.0.0.1', 10000) -class MyUdpEchoProtocol: +class MyServerUdpEchoProtocol: def connection_made(self, transport): print('start', transport) @@ -54,13 +54,15 @@ def connection_lost(self, exc): def start_server(): loop = tulip.get_event_loop() - loop.start_serving_datagram(MyUdpEchoProtocol, *ADDRESS) + tulip.Task(loop.create_datagram_connection( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) loop.run_forever() def start_client(): loop = tulip.get_event_loop() - loop.create_datagram_connection(MyClientUdpEchoProtocol, *ADDRESS) + tulip.Task(loop.create_datagram_connection( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) loop.run_forever() diff --git a/tests/events_test.py b/tests/events_test.py index 6a7dcb75..01507243 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -3,7 +3,6 @@ import concurrent.futures import contextlib import errno -import fcntl import gc import io import os @@ -779,74 +778,53 @@ class Err(socket.error): self.assertRaises(Err, self.event_loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) - def test_create_datagram_connection(self): - server = None - - def factory(): - nonlocal server - server = TestMyDatagramProto() - return server - - class TestMyDatagramProto(MyDatagramProto): - def datagram_received(self, data, addr): - super().datagram_received(data, addr) - self.transport.sendto(b'resp:'+data, addr) + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_connection_no_addrinfo(self, m_socket): + self.suppress_log_errors() - f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) - sock = self.event_loop.run_until_complete(f) - host, port = sock.getsockname() + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] coro = self.event_loop.create_datagram_connection( - MyDatagramProto, host, port) - transport, protocol = self.event_loop.run_until_complete(coro) - - self.assertEqual('INITIALIZED', protocol.state) - transport.sendto(b'xxx') - self.event_loop.run_once() - self.assertEqual(3, server.nbytes) - self.event_loop.run_once() - - # received - self.assertEqual(8, protocol.nbytes) - - # extra info is available - self.assertIsNotNone(transport.get_extra_info('socket')) - conn = transport.get_extra_info('socket') - self.assertTrue(hasattr(conn, 'getsockname')) - - # close connection - transport.close() - self.assertEqual('CLOSED', protocol.state) - server.transport.close() + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) - def test_create_datagram_connection_no_connection(self): - server = None + def test_create_datagram_connection_addr_error(self): + self.suppress_log_errors() - def factory(): - nonlocal server - server = TestMyDatagramProto() - return server + coro = self.event_loop.create_datagram_connection( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_connection( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + def test_create_datagram_connection(self): class TestMyDatagramProto(MyDatagramProto): def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) - f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) - sock = self.event_loop.run_until_complete(f) - host, port = sock.getsockname() + coro = self.event_loop.create_datagram_connection( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') - coro = self.event_loop.create_datagram_connection(MyDatagramProto) - transport, protocol = self.event_loop.run_until_complete(coro) + coro = self.event_loop.create_datagram_connection( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) - self.assertEqual('INITIALIZED', protocol.state) - transport.sendto(b'xxx', (host, port)) + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') self.event_loop.run_once() self.assertEqual(3, server.nbytes) self.event_loop.run_once() # received - self.assertEqual(8, protocol.nbytes) + self.assertEqual(8, client.nbytes) # extra info is available self.assertIsNotNone(transport.get_extra_info('socket')) @@ -855,18 +833,9 @@ def datagram_received(self, data, addr): # close connection transport.close() - self.assertEqual('CLOSED', protocol.state) - server.transport.close() - def test_create_datagram_connection_no_getaddrinfo(self): - self.suppress_log_errors() - getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() - getaddrinfo.return_value = [] - - coro = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80) - self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + self.assertEqual('CLOSED', client.state) + server.transport.close() def test_create_datagram_connection_connect_err(self): self.suppress_log_errors() @@ -874,92 +843,74 @@ def test_create_datagram_connection_connect_err(self): self.event_loop.sock_connect.side_effect = socket.error coro = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol, 'xkcd.com', 80) + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_connection_sockopt_err(self, m_socket): + def test_create_datagram_connection_socket_err(self, m_socket): self.suppress_log_errors() m_socket.error = socket.error - m_socket.socket.return_value.setsockopt.side_effect = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error coro = self.event_loop.create_datagram_connection( - protocols.DatagramProtocol) + protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) - self.assertTrue( - m_socket.socket.return_value.close.called) - def test_start_serving_datagram(self): - class TestMyDatagramProto(MyDatagramProto): - def datagram_received(self, data, addr): - super().datagram_received(data, addr) - self.transport.sendto(b'resp:'+data, addr) - - proto = None - - def factory(): - nonlocal proto - proto = TestMyDatagramProto() - return proto - - f = self.event_loop.start_serving_datagram(factory, '127.0.0.1', 0) - sock = self.event_loop.run_until_complete(f) - self.assertEqual('INITIALIZED', proto.state) - - host, port = sock.getsockname() - self.assertEqual(host, '127.0.0.1') - client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - client.sendto(b'xxx', ('127.0.0.1', port)) - self.event_loop.run_once() - self.assertEqual(3, proto.nbytes) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) - data, server = client.recvfrom(4096) - self.assertEqual(b'resp:xxx', data) + def test_create_datagram_connection_no_matching_family(self): + self.suppress_log_errors() - # extra info is available - self.assertIsNotNone(proto.transport.get_extra_info('socket')) - conn = proto.transport.get_extra_info('socket') - self.assertTrue(hasattr(conn, 'getsockname')) - self.assertEqual( - '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) - # close connection - proto.transport.close() + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_connection_setblk_err(self, m_socket): + self.suppress_log_errors() - self.event_loop.run_once() - self.assertEqual('CLOSED', proto.state) + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error - client.close() + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) - def test_start_serving_datagram_no_getaddrinfoc(self): + def test_create_datagram_connection_noaddr_nofamily(self): self.suppress_log_errors() - getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() - getaddrinfo.return_value = [] - f = self.event_loop.start_serving_datagram( - MyDatagramProto, '0.0.0.0', 0) - - self.assertRaises( - socket.error, self.event_loop.run_until_complete, f) + coro = self.event_loop.create_datagram_connection( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') - def test_start_serving_datagram_cant_bind(self, m_socket): + def test_create_datagram_connection_cant_bind(self, m_socket): self.suppress_log_errors() class Err(socket.error): pass m_socket.error = socket.error - m_socket.getaddrinfo.return_value = [ - (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo m_sock = m_socket.socket.return_value = unittest.mock.Mock() - m_sock.setsockopt.side_effect = Err + m_sock.bind.side_effect = Err - fut = self.event_loop.start_serving_datagram( - MyDatagramProto, '0.0.0.0', 0) + fut = self.event_loop.create_datagram_connection( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) self.assertRaises(Err, self.event_loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) @@ -1028,7 +979,8 @@ def connect(): os.close(wpipe) self.event_loop.run_once() - self.assertEqual(['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) @@ -1119,13 +1071,7 @@ def test_create_datagram_connection_no_connection(self): raise unittest.SkipTest( "IocpEventLoop does not have " "create_datagram_connection_no_connection()") - def test_start_serving_datagram(self): - raise unittest.SkipTest( - "IocpEventLoop does not have start_serving_datagram()") - def test_start_serving_datagram_no_getaddrinfoc(self): - raise unittest.SkipTest( - "IocpEventLoop does not have start_serving_datagram()") - def test_start_serving_datagram_cant_bind(self): + def test_create_datagram_connection_cant_bind(self): raise unittest.SkipTest( "IocpEventLoop does not have start_serving_udp()") @@ -1298,9 +1244,6 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.start_serving, f) self.assertRaises( NotImplementedError, ev_loop.create_datagram_connection, f) - self.assertRaises( - NotImplementedError, ev_loop.start_serving_datagram, - f, 'localhost', 8080) self.assertRaises( NotImplementedError, ev_loop.add_reader, 1, f) self.assertRaises( diff --git a/tulip/base_events.py b/tulip/base_events.py index 3e87548a..8c05c345 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -70,7 +70,7 @@ def _make_read_pipe_transport(self, pipe, protocol, waiter=None, raise NotImplementedError def _make_write_pipe_transport(self, pipe, protocol, waiter=None, - extra=None): + extra=None): """Create write pipe transport.""" raise NotImplementedError @@ -340,48 +340,72 @@ def create_connection(self, protocol_factory, host=None, port=None, *, @tasks.coroutine def create_datagram_connection(self, protocol_factory, - host=None, port=None, *, - family=socket.AF_INET, proto=0, flags=0): + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') - addr = None - if host is not None or port is not None: - infos = yield from self.getaddrinfo( - host, port, family=family, - type=socket.SOCK_DGRAM, proto=proto, flags=flags) - - if not infos: - raise socket.error('getaddrinfo() returned empty list') + exceptions = [] - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) - try: - yield from self.sock_connect(sock, address) - addr = address - except socket.error as exc: + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: sock.close() - exceptions.append(exc) - else: - break + exceptions.append(exc) else: - if exceptions: - raise exceptions[0] + break else: - sock = socket.socket( - family=family, type=socket.SOCK_DGRAM, proto=proto) - - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setblocking(False) - except socket.error: - sock.close() - raise + raise exceptions[0] protocol = protocol_factory() - transport = self._make_datagram_transport(sock, protocol, addr) - + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) return transport, protocol # TODO: Or create_server()? @@ -426,37 +450,6 @@ def start_serving(self, protocol_factory, host=None, port=None, *, self._start_serving(protocol_factory, sock) return sock - @tasks.task - def start_serving_datagram(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): - """XXX""" - infos = yield from self.getaddrinfo( - host, port, family=family, - type=socket.SOCK_DGRAM, proto=proto, flags=flags) - - if not infos: - raise socket.error('getaddrinfo() returned empty list') - - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - except socket.error as exc: - sock.close() - exceptions.append(exc) - else: - sock.setblocking(False) - break - else: - raise exceptions[0] - - self._make_datagram_transport( - sock, protocol_factory(), extra={'addr': sock.getsockname()}) - - return sock - @tasks.coroutine def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() diff --git a/tulip/events.py b/tulip/events.py index da892995..f6a8352f 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -184,15 +184,11 @@ def start_serving(self, protocol_factory, host=None, port=None, *, family=0, proto=0, flags=0, sock=None): raise NotImplementedError - def create_datagram_connection(self, protocol_factory, - host=None, port=None, *, + def create_datagram_connection(self, protocol, + local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): raise NotImplementedError - def start_serving_datagram(self, protocol_factory, host, port, *, - family=0, proto=0, flags=0): - raise NotImplementedError - def connect_read_pipe(self, protocol_factory, pipe): """Register read pipe in eventloop. From 115139d8149ff371bcbcfb3152d5bb9956db6506 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 22 Mar 2013 18:52:57 -0700 Subject: [PATCH 0377/1502] better name --- examples/udp_echo.py | 4 +-- tests/events_test.py | 67 ++++++++++++++++++++++++-------------------- tulip/base_events.py | 6 ++-- tulip/events.py | 6 ++-- 4 files changed, 45 insertions(+), 38 deletions(-) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 1597812a..9e995d14 100644 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -54,14 +54,14 @@ def connection_lost(self, exc): def start_server(): loop = tulip.get_event_loop() - tulip.Task(loop.create_datagram_connection( + tulip.Task(loop.create_datagram_endpoint( MyServerUdpEchoProtocol, local_addr=ADDRESS)) loop.run_forever() def start_client(): loop = tulip.get_event_loop() - tulip.Task(loop.create_datagram_connection( + tulip.Task(loop.create_datagram_endpoint( MyClientUdpEchoProtocol, remote_addr=ADDRESS)) loop.run_forever() diff --git a/tests/events_test.py b/tests/events_test.py index 01507243..dcfd36c0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -779,41 +779,41 @@ class Err(socket.error): self.assertTrue(m_sock.close.called) @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_connection_no_addrinfo(self, m_socket): + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): self.suppress_log_errors() m_socket.error = socket.error m_socket.getaddrinfo.return_value = [] - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) - def test_create_datagram_connection_addr_error(self): + def test_create_datagram_endpoint_addr_error(self): self.suppress_log_errors() - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( MyDatagramProto, local_addr='localhost') self.assertRaises( AssertionError, self.event_loop.run_until_complete, coro) - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 1, 2, 3)) self.assertRaises( AssertionError, self.event_loop.run_until_complete, coro) - def test_create_datagram_connection(self): + def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( TestMyDatagramProto, local_addr=('127.0.0.1', 0)) s_transport, server = self.event_loop.run_until_complete(coro) host, port = s_transport.get_extra_info('addr') - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( MyDatagramProto, remote_addr=(host, port)) transport, client = self.event_loop.run_until_complete(coro) @@ -837,66 +837,66 @@ def datagram_received(self, data, addr): self.assertEqual('CLOSED', client.state) server.transport.close() - def test_create_datagram_connection_connect_err(self): + def test_create_datagram_endpoint_connect_err(self): self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_connection_socket_err(self, m_socket): + def test_create_datagram_endpoint_socket_err(self, m_socket): self.suppress_log_errors() m_socket.error = socket.error m_socket.getaddrinfo = socket.getaddrinfo m_socket.socket.side_effect = socket.error - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) - def test_create_datagram_connection_no_matching_family(self): + def test_create_datagram_endpoint_no_matching_family(self): self.suppress_log_errors() - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) self.assertRaises( ValueError, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_connection_setblk_err(self, m_socket): + def test_create_datagram_endpoint_setblk_err(self, m_socket): self.suppress_log_errors() m_socket.error = socket.error m_socket.socket.return_value.setblocking.side_effect = socket.error - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) self.assertTrue( m_socket.socket.return_value.close.called) - def test_create_datagram_connection_noaddr_nofamily(self): + def test_create_datagram_endpoint_noaddr_nofamily(self): self.suppress_log_errors() - coro = self.event_loop.create_datagram_connection( + coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol) self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_connection_cant_bind(self, m_socket): + def test_create_datagram_endpoint_cant_bind(self, m_socket): self.suppress_log_errors() class Err(socket.error): @@ -908,7 +908,7 @@ class Err(socket.error): m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err - fut = self.event_loop.create_datagram_connection( + fut = self.event_loop.create_datagram_endpoint( MyDatagramProto, local_addr=('127.0.0.1', 0), family=socket.AF_INET) self.assertRaises(Err, self.event_loop.run_until_complete, fut) @@ -1064,17 +1064,24 @@ def test_accept_connection_retry(self): def test_accept_connection_exception(self): raise unittest.SkipTest( "IocpEventLoop does not have _accept_connection()") - def test_create_datagram_connection(self): + def test_create_datagram_endpoint(self): raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_connection()") - def test_create_datagram_connection_no_connection(self): + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): raise unittest.SkipTest( - "IocpEventLoop does not have " - "create_datagram_connection_no_connection()") - def test_create_datagram_connection_cant_bind(self): + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): raise unittest.SkipTest( - "IocpEventLoop does not have start_serving_udp()") - + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") else: from tulip import selectors from tulip import unix_events @@ -1243,7 +1250,7 @@ def test_not_imlemented(self): self.assertRaises( NotImplementedError, ev_loop.start_serving, f) self.assertRaises( - NotImplementedError, ev_loop.create_datagram_connection, f) + NotImplementedError, ev_loop.create_datagram_endpoint, f) self.assertRaises( NotImplementedError, ev_loop.add_reader, 1, f) self.assertRaises( diff --git a/tulip/base_events.py b/tulip/base_events.py index 8c05c345..46cab27a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -339,9 +339,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, return transport, protocol @tasks.coroutine - def create_datagram_connection(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): """Create datagram connection.""" if not (local_addr or remote_addr): if family == 0: diff --git a/tulip/events.py b/tulip/events.py index f6a8352f..ba9a50f9 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -184,9 +184,9 @@ def start_serving(self, protocol_factory, host=None, port=None, *, family=0, proto=0, flags=0, sock=None): raise NotImplementedError - def create_datagram_connection(self, protocol, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): raise NotImplementedError def connect_read_pipe(self, protocol_factory, pipe): From a2442fe9dae56aadeb22ec282ea3373757e1a431 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 22 Mar 2013 19:32:13 -0700 Subject: [PATCH 0378/1502] code cleanup --- tests/base_events_test.py | 10 ++++++ tests/events_test.py | 1 + tests/queues_test.py | 6 ++-- tests/selector_events_test.py | 13 ++++---- tests/transports_test.py | 6 ++++ tests/unix_events_test.py | 9 +++--- tulip/proactor_events.py | 2 +- tulip/selector_events.py | 37 ++++++++++++---------- tulip/subprocess_transport.py | 59 +++++++++++++++++++---------------- tulip/transports.py | 1 + tulip/unix_events.py | 34 ++++++++++---------- 11 files changed, 103 insertions(+), 75 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 5a71bc5d..85b013fa 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -41,6 +41,12 @@ def test_not_implemented(self): NotImplementedError, self.event_loop._write_to_self) self.assertRaises( NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) def test_add_callback_handle(self): h = events.Handle(lambda: False, ()) @@ -241,6 +247,10 @@ def cb(event_loop): self.assertTrue(processed) self.assertEqual([handle], list(self.event_loop._ready)) + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors(self, m_socket): self.suppress_log_errors() diff --git a/tests/events_test.py b/tests/events_test.py index dcfd36c0..3b54ab47 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -191,6 +191,7 @@ def test_run_nesting(self): def coro(): nonlocal err yield from [] + self.assertTrue(self.event_loop.is_running()) try: self.event_loop.run_until_complete( tasks.sleep(0.1)) diff --git a/tests/queues_test.py b/tests/queues_test.py index c722c1ac..d86abd7f 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -23,9 +23,9 @@ class QueueBasicTests(_QueueTestBase): def _test_repr_or_str(self, fn, expect_id): """Test Queue's repr or str. - - fn is repr or str. expect_id is True if we expect the Queue's id to - appear in fn(Queue()). + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). """ q = queues.Queue() self.assertTrue(fn(q).startswith('= len(data): - return - if n > 0: - data = data[n:] + else: + if n == len(data): + return + elif n: + data = data[n:] self._event_loop.add_writer(self._wstdin, self._stdin_callback) self._buffer.append(data) @@ -94,39 +97,41 @@ def _fatal_error(self, exc): def _stdin_callback(self): data = b''.join(self._buffer) + assert data, "Data shold not be empty" + self._buffer = [] try: - if data: - n = os.write(self._wstdin, data) - else: - n = 0 + n = os.write(self._wstdin, data) except BlockingIOError: - n = 0 + self._buffer.append(data) except Exception as exc: self._fatal_error(exc) - return - if n >= len(data): - self._event_loop.remove_writer(self._wstdin) - if self._eof: - os.close(self._wstdin) - self._wstdin = -1 - return - if n > 0: - data = data[n:] - self._buffer.append(data) # Try again later. + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. def _stdout_callback(self): try: data = os.read(self._rstdout, 1024) except BlockingIOError: - return - if data: - self._event_loop.call_soon(self._protocol.data_received, data) + pass else: - self._event_loop.remove_reader(self._rstdout) - os.close(self._rstdout) - self._rstdout = -1 - self._event_loop.call_soon(self._protocol.eof_received) + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) def _setnonblocking(fd): diff --git a/tulip/transports.py b/tulip/transports.py index 35d5bb17..a9ec07a0 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -122,6 +122,7 @@ def sendto(self, data, addr=None): addr is target socket address. If addr is None use target address pointed on transport creation. """ + raise NotImplementedError def abort(self): """Closes the transport immediately. diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 833c6612..9e75e3e9 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -221,10 +221,11 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): self._event_loop.call_soon(waiter.set_result, None) def write(self, data): - assert isinstance(data, (bytes, bytearray)), repr(data) + assert isinstance(data, bytes), repr(data) assert not self._closing if not data: return + if not self._buffer: # Attempt to send it right away first. try: @@ -239,29 +240,30 @@ def write(self, data): elif n > 0: data = data[n:] self._event_loop.add_writer(self._fileno, self._write_ready) - assert data, "Data shold not be empty" + self._buffer.append(data) def _write_ready(self): data = b''.join(self._buffer) - assert data, "Data shold not be empty" + assert data, "Data should not be empty" + + self._buffer.clear() try: n = os.write(self._fileno, data) except BlockingIOError: - self._buffer = [data] - return + self._buffer.append(data) except Exception as exc: self._fatal_error(exc) - return - if n == len(data): - self._buffer = [] - self._event_loop.remove_writer(self._fileno) - if self._closing: - self._call_connection_lost(None) - return - elif n > 0: - data = data[n:] - self._buffer = [data] # Try again later. + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. def can_write_eof(self): return True @@ -288,7 +290,7 @@ def _fatal_error(self, exc): def _close(self, exc=None): self._closing = True - self._buffer = [] + self._buffer.clear() self._event_loop.remove_writer(self._fileno) self._call_connection_lost(exc) From ba1e6285fea62ea9265229c96f2e1ef604aa7cba Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 23 Mar 2013 11:07:58 -0700 Subject: [PATCH 0379/1502] Switch to BlockingIOError instead of errno checking --- tests/events_test.py | 8 +- tests/selector_events_test.py | 164 +++++++++------------------------- tests/winsocketpair_test.py | 11 +-- tulip/selector_events.py | 141 +++++++++++++---------------- tulip/winsocketpair.py | 12 +-- 5 files changed, 112 insertions(+), 224 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 3b54ab47..7a64012b 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -2,7 +2,6 @@ import concurrent.futures import contextlib -import errno import gc import io import os @@ -916,11 +915,8 @@ class Err(socket.error): self.assertTrue(m_sock.close.called) def test_accept_connection_retry(self): - class Err(socket.error): - errno = errno.EAGAIN - sock = unittest.mock.Mock() - sock.accept.side_effect = Err + sock.accept.side_effect = BlockingIOError() self.event_loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) @@ -929,7 +925,7 @@ def test_accept_connection_exception(self): self.suppress_log_errors() sock = unittest.mock.Mock() - sock.accept.side_effect = socket.error + sock.accept.side_effect = OSError() self.event_loop._accept_connection(MyProto, sock) self.assertTrue(sock.close.called) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index b187c6bf..4836008b 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -87,32 +87,20 @@ def test_socketpair(self): self.assertRaises(NotImplementedError, self.event_loop._socketpair) def test_read_from_self_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - - self.event_loop._ssock.recv.side_effect = Err + self.event_loop._ssock.recv.side_effect = BlockingIOError self.assertIsNone(self.event_loop._read_from_self()) def test_read_from_self_exception(self): - class Err(socket.error): - pass - - self.event_loop._ssock.recv.side_effect = Err - self.assertRaises(Err, self.event_loop._read_from_self) + self.event_loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.event_loop._read_from_self) def test_write_to_self_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - - self.event_loop._csock.send.side_effect = Err + self.event_loop._csock.send.side_effect = BlockingIOError self.assertIsNone(self.event_loop._write_to_self()) def test_write_to_self_exception(self): - class Err(socket.error): - pass - - self.event_loop._csock.send.side_effect = Err - self.assertRaises(Err, self.event_loop._write_to_self) + self.event_loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.event_loop._write_to_self) def test_sock_recv(self): sock = unittest.mock.Mock() @@ -144,13 +132,10 @@ def test__sock_recv_unregister(self): self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) def test__sock_recv_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.recv.side_effect = Err + sock.recv.side_effect = BlockingIOError self.event_loop.add_reader = unittest.mock.Mock() self.event_loop._sock_recv(f, False, sock, 1024) @@ -158,16 +143,13 @@ class Err(socket.error): self.event_loop.add_reader.call_args[0]) def test__sock_recv_exception(self): - class Err(socket.error): - pass - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.recv.side_effect = Err + err = sock.recv.side_effect = OSError() self.event_loop._sock_recv(f, False, sock, 1024) - self.assertIsInstance(f.exception(), Err) + self.assertIs(err, f.exception()) def test_sock_sendall(self): sock = unittest.mock.Mock() @@ -209,13 +191,10 @@ def test__sock_sendall_unregister(self): self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) def test__sock_sendall_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.send.side_effect = Err + sock.send.side_effect = BlockingIOError self.event_loop.add_writer = unittest.mock.Mock() self.event_loop._sock_sendall(f, False, sock, b'data') @@ -224,16 +203,13 @@ class Err(socket.error): self.event_loop.add_writer.call_args[0]) def test__sock_sendall_exception(self): - class Err(socket.error): - pass - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.send.side_effect = Err + err = sock.send.side_effect = OSError() self.event_loop._sock_sendall(f, False, sock, b'data') - self.assertIsInstance(f.exception(), Err) + self.assertIs(f.exception(), err) def test__sock_sendall(self): sock = unittest.mock.Mock() @@ -382,13 +358,10 @@ def test__sock_accept_unregister(self): self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) def test__sock_accept_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.accept.side_effect = Err + sock.accept.side_effect = BlockingIOError self.event_loop.add_reader = unittest.mock.Mock() self.event_loop._sock_accept(f, False, sock) @@ -397,16 +370,13 @@ class Err(socket.error): self.event_loop.add_reader.call_args[0]) def test__sock_accept_exception(self): - class Err(socket.error): - pass - f = futures.Future() sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.accept.side_effect = Err + err = sock.accept.side_effect = OSError() self.event_loop._sock_accept(f, False, sock) - self.assertIsInstance(f.exception(), Err) + self.assertIs(err, f.exception()) def test_add_reader(self): self.event_loop._selector.get_info.side_effect = KeyError @@ -599,13 +569,10 @@ def test_read_ready_eof(self): @unittest.mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): + self.sock.recv.side_effect = BlockingIOError + transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) - - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.recv.side_effect = Err() transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -613,13 +580,10 @@ class Err(socket.error): @unittest.mock.patch('logging.exception') def test_read_ready_err(self, m_exc): + err = self.sock.recv.side_effect = OSError() + transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) - - class Err(socket.error): - pass - - err = self.sock.recv.side_effect = Err() transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -686,13 +650,9 @@ def test_write_partial_none(self): self.assertEqual([b'data'], transport._buffer) def test_write_tryagain(self): - data = b'data' - - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.send.side_effect = Err + self.sock.send.side_effect = BlockingIOError + data = b'data' transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) transport.write(data) @@ -704,20 +664,14 @@ class Err(socket.error): self.assertEqual([b'data'], transport._buffer) def test_write_exception(self): - data = b'data' - - class Err(socket.error): - pass - - self.sock.send.side_effect = Err + err = self.sock.send.side_effect = OSError() + data = b'data' transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport.write(data) - - self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_write_str(self): transport = _SelectorSocketTransport( @@ -783,10 +737,7 @@ def test_write_ready_partial_none(self): self.assertEqual([b'data'], transport._buffer) def test_write_ready_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.send.side_effect = Err + self.sock.send.side_effect = BlockingIOError transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -797,10 +748,7 @@ class Err(socket.error): self.assertEqual([b'data1data2'], transport._buffer) def test_write_ready_exception(self): - class Err(socket.error): - pass - - err = self.sock.send.side_effect = Err() + err = self.sock.send.side_effect = OSError() transport = _SelectorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -981,21 +929,15 @@ def test_on_ready_recv_retry(self): self.transport._on_ready() self.assertFalse(self.protocol.data_received.called) - class Err(socket.error): - errno = errno.EAGAIN - - self.sslsock.recv.side_effect = Err + self.sslsock.recv.side_effect = BlockingIOError self.transport._on_ready() self.assertFalse(self.protocol.data_received.called) def test_on_ready_recv_exc(self): - class Err(socket.error): - pass - - self.sslsock.recv.side_effect = Err + err = self.sslsock.recv.side_effect = OSError() self.transport._fatal_error = unittest.mock.Mock() self.transport._on_ready() - self.assertTrue(self.transport._fatal_error.called) + self.transport._fatal_error.assert_called_with(err) def test_on_ready_send(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError @@ -1053,24 +995,18 @@ def test_on_ready_send_retry(self): self.transport._on_ready() self.assertEqual([b'data'], self.transport._buffer) - class Err(socket.error): - errno = errno.EAGAIN - - self.sslsock.send.side_effect = Err + self.sslsock.send.side_effect = BlockingIOError() self.transport._on_ready() self.assertEqual([b'data'], self.transport._buffer) def test_on_ready_send_exc(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError + err = self.sslsock.send.side_effect = OSError() - class Err(socket.error): - pass - - self.sslsock.send.side_effect = Err self.transport._buffer = [b'data'] self.transport._fatal_error = unittest.mock.Mock() self.transport._on_ready() - self.assertTrue(self.transport._fatal_error.called) + self.transport._fatal_error.assert_called_with(err) self.assertEqual([], self.transport._buffer) @@ -1096,10 +1032,7 @@ def test_read_ready_tryagain(self): transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.recvfrom.side_effect = Err + self.sock.recvfrom.side_effect = BlockingIOError transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -1109,10 +1042,7 @@ def test_read_ready_err(self): transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) - class Err(socket.error): - pass - - err = self.sock.recvfrom.side_effect = Err() + err = self.sock.recvfrom.side_effect = OSError() transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -1158,10 +1088,7 @@ def test_sendto_buffer(self): def test_sendto_tryagain(self): data = b'data' - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.sendto.side_effect = Err + self.sock.sendto.side_effect = BlockingIOError transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) @@ -1177,11 +1104,7 @@ class Err(socket.error): def test_sendto_exception(self): data = b'data' - - class Err(socket.error): - pass - - self.sock.sendto.side_effect = Err + err = self.sock.sendto.side_effect = OSError() transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) @@ -1189,7 +1112,7 @@ class Err(socket.error): transport.sendto(data, ()) self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_sendto_connection_refused(self): data = b'data' @@ -1266,10 +1189,7 @@ def test_sendto_ready_no_data(self): self.assertTrue(self.event_loop.remove_writer.called) def test_sendto_ready_tryagain(self): - class Err(socket.error): - errno = errno.EAGAIN - - self.sock.sendto.side_effect = Err + self.sock.sendto.side_effect = BlockingIOError transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) @@ -1282,10 +1202,7 @@ class Err(socket.error): list(transport._buffer)) def test_sendto_ready_exception(self): - class Err(socket.error): - pass - - self.sock.sendto.side_effect = Err + err = self.sock.sendto.side_effect = OSError() transport = _SelectorDatagramTransport( self.event_loop, self.sock, self.protocol) @@ -1293,8 +1210,7 @@ class Err(socket.error): transport._buffer.append((b'data', ())) transport._sendto_ready() - self.assertTrue(transport._fatal_error.called) - self.assertIsInstance(transport._fatal_error.call_args[0][0], Err) + transport._fatal_error.assert_called_with(err) def test_sendto_ready_connection_refused(self): self.sock.sendto.side_effect = ConnectionRefusedError diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py index b93bdaf1..381fb227 100644 --- a/tests/winsocketpair_test.py +++ b/tests/winsocketpair_test.py @@ -1,7 +1,5 @@ """Tests for winsocketpair.py""" -import errno -import socket import unittest import unittest.mock @@ -21,13 +19,8 @@ def test_winsocketpair(self): @unittest.mock.patch('tulip.winsocketpair.socket') def test_winsocketpair_exc(self, m_socket): - m_socket.error = socket.error - - class Err(socket.error): - errno = errno.WSAEWOULDBLOCK + 1 - m_socket.socket.return_value.getsockname.return_value = ('', 12345) m_socket.socket.return_value.accept.return_value = object(), object() - m_socket.socket.return_value.connect.side_effect = Err + m_socket.socket.return_value.connect.side_effect = OSError() - self.assertRaises(Err, winsocketpair.socketpair) + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index d488990b..46e04dae 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -12,7 +12,6 @@ import ssl except ImportError: # pragma: no cover ssl = None -import sys from . import base_events from . import events @@ -30,11 +29,6 @@ errno.EBADF, )) -# Errno values indicating the socket isn't ready for I/O just yet. -_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK, errno.EINPROGRESS)) -if sys.platform == 'win32': # pragma: no cover - _TRYAGAIN = frozenset(list(_TRYAGAIN) + [errno.WSAEWOULDBLOCK]) - class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. @@ -91,18 +85,14 @@ def _make_self_pipe(self): def _read_from_self(self): try: self._ssock.recv(1) - except socket.error as exc: - if exc.errno in _TRYAGAIN: - return - raise # Halp! + except (BlockingIOError, InterruptedError): + pass def _write_to_self(self): try: self._csock.send(b'x') - except socket.error as exc: - if exc.errno in _TRYAGAIN: - return - raise # Halp! + except (BlockingIOError, InterruptedError): + pass def _start_serving(self, protocol_factory, sock): self.add_reader(sock.fileno(), self._accept_connection, @@ -111,19 +101,18 @@ def _start_serving(self, protocol_factory, sock): def _accept_connection(self, protocol_factory, sock): try: conn, addr = sock.accept() - except socket.error as exc: - if exc.errno in _TRYAGAIN: - return # False alarm. - # Bad error. Stop serving. + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. self.remove_reader(sock.fileno()) sock.close() # There's nowhere to send the error, so just log it. # TODO: Someone will want an error handler for this. logging.exception('Accept failed') - return - - self._make_socket_transport( - conn, protocol_factory(), extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -216,11 +205,10 @@ def _sock_recv(self, fut, registered, sock, n): try: data = sock.recv(n) fut.set_result(data) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - fut.set_exception(exc) - else: - self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) def sock_sendall(self, sock, data): """XXX""" @@ -241,11 +229,11 @@ def _sock_sendall(self, fut, registered, sock, data): try: n = sock.send(data) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - fut.set_exception(exc) - return + except (BlockingIOError, InterruptedError): n = 0 + except Exception as exc: + fut.set_exception(exc) + return if n == len(data): fut.set_result(None) @@ -280,12 +268,10 @@ def _sock_connect(self, fut, registered, sock, address): # Jump to the except clause below. raise socket.error(err, 'Connect call failed') fut.set_result(None) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - fut.set_exception(exc) - else: - self.add_writer(fd, self._sock_connect, - fut, True, sock, address) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) def sock_accept(self, sock): """XXX""" @@ -303,11 +289,10 @@ def _sock_accept(self, fut, registered, sock): conn, address = sock.accept() conn.setblocking(False) fut.set_result((conn, address)) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - fut.set_exception(exc) - else: - self.add_reader(fd, self._sock_accept, fut, True, sock) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) def _process_events(self, event_list): for fileobj, mask, (reader, writer) in event_list: @@ -341,9 +326,10 @@ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): def _read_ready(self): try: data = self._sock.recv(16*1024) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) else: if data: self._protocol.data_received(data) @@ -361,12 +347,11 @@ def write(self, data): # Attempt to send it right away first. try: n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 except socket.error as exc: - if exc.errno in _TRYAGAIN: - n = 0 - else: - self._fatal_error(exc) - return + self._fatal_error(exc) + return if n == len(data): return @@ -383,11 +368,10 @@ def _write_ready(self): self._buffer.clear() try: n = self._sock.send(data) - except socket.error as exc: - if exc.errno in _TRYAGAIN: - self._buffer.append(data) - else: - self._fatal_error(exc) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) else: if n == len(data): self._event_loop.remove_writer(self._sock.fileno()) @@ -492,10 +476,10 @@ def _on_ready(self): pass except ssl.SSLWantWriteError: pass - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) - return + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) else: if data: self._protocol.data_received(data) @@ -520,12 +504,11 @@ def _on_ready(self): n = 0 except ssl.SSLWantWriteError: n = 0 - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) - return - else: - n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return if n < len(data): self._buffer.append(data[n:]) @@ -584,9 +567,10 @@ def __init__(self, event_loop, sock, protocol, address=None, extra=None): def _read_ready(self): try: data, addr = self._sock.recvfrom(self.max_size) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) else: self._protocol.datagram_received(data, addr) @@ -610,11 +594,12 @@ def sendto(self, data, addr=None): except ConnectionRefusedError as exc: if self._address: self._fatal_error(exc) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) - + return + except (BlockingIOError, InterruptedError): self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return self._buffer.append((data, addr)) @@ -630,14 +615,12 @@ def _sendto_ready(self): if self._address: self._fatal_error(exc) return - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._fatal_error(exc) - return - - # Try again later. - self._buffer.appendleft((data, addr)) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. break + except Exception as exc: + self._fatal_error(exc) + return if not self._buffer: self._event_loop.remove_writer(self._fileno) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index 374616f6..bd1e0928 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -3,7 +3,6 @@ Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. """ -import errno import socket import sys @@ -23,11 +22,12 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): csock.setblocking(False) try: csock.connect((addr, port)) - except socket.error as e: - if e.errno != errno.WSAEWOULDBLOCK: - lsock.close() - csock.close() - raise + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise ssock, _ = lsock.accept() csock.setblocking(True) lsock.close() From e523ec807b808d7b3101cbc75f8509f1b9da0c56 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Mar 2013 13:22:31 -0700 Subject: [PATCH 0380/1502] Tiny cleanup. --- tulip/base_events.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 46cab27a..2aebcb89 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -211,8 +211,6 @@ def call_later(self, delay, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. - - # TODO: Should delay is None be interpreted as Infinity? """ if delay <= 0: return self.call_soon(callback, *args) @@ -223,6 +221,7 @@ def call_later(self, delay, callback, *args): def call_repeatedly(self, interval, callback, *args): """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: %r' % (interval,) def wrapper(): callback(*args) # If this fails, the chain is broken. handle._when = time.monotonic() + interval From 76b0b24d72996d0071ab2d966d71abeeb0ca6d4f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Mar 2013 13:25:23 -0700 Subject: [PATCH 0381/1502] Fix -v flag in Makefile. Add TODO to call_repeatedly(). --- Makefile | 7 ++++--- tulip/base_events.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 724c57c9..274da4c8 100644 --- a/Makefile +++ b/Makefile @@ -1,17 +1,18 @@ # Some simple testing tasks (sorry, UNIX only). PYTHON=python3 +VERBOSE=1 FLAGS= test: - $(PYTHON) runtests.py -v $(FLAGS) + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) testloop: - while sleep 1; do $(PYTHON) runtests.py -v $(FLAGS); done + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done # See README for coverage installation instructions. cov coverage: - $(PYTHON) runtests.py --coverage tulip -v $(FLAGS) + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) echo "open file://`pwd`/htmlcov/index.html" check: diff --git a/tulip/base_events.py b/tulip/base_events.py index 2aebcb89..0b55ee4c 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -222,6 +222,7 @@ def call_later(self, delay, callback, *args): def call_repeatedly(self, interval, callback, *args): """Call a callback every 'interval' seconds.""" assert interval > 0, 'Interval must be > 0: %r' % (interval,) + # TODO: What if callback is already a Handle? def wrapper(): callback(*args) # If this fails, the chain is broken. handle._when = time.monotonic() + interval From e856f88a8c4a21f94daca17d47e5e298eeaa550d Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Sun, 24 Mar 2013 21:33:50 +0100 Subject: [PATCH 0382/1502] Don't skip polling even if the self-pipe is th eonly registered fd --- tests/events_test.py | 24 ++++++++++++++++++++-- tulip/base_events.py | 47 ++++++++++++++++++++------------------------ 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 7a64012b..40859211 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -216,6 +216,26 @@ def coro(): self.event_loop.run_until_complete(tasks.Task(coro())) self.assertIsInstance(err, RuntimeError) + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + def test_call_later(self): results = [] @@ -819,9 +839,9 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', client.state) transport.sendto(b'xxx') - self.event_loop.run_once() + self.event_loop.run_once(None) self.assertEqual(3, server.nbytes) - self.event_loop.run_once() + self.event_loop.run_once(None) # received self.assertEqual(8, client.nbytes) diff --git a/tulip/base_events.py b/tulip/base_events.py index 0b55ee4c..573807d8 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -126,7 +126,7 @@ def run_forever(self): finally: handle.cancel() - def run_once(self, timeout=None): + def run_once(self, timeout=0): """Run through all callbacks and all I/O polls once. Calling stop() will break out of this too. @@ -503,32 +503,27 @@ def _run_once(self, timeout=None): while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) - # Inspect the poll queue. If there's exactly one selectable - # file descriptor, it's the self-pipe, and if there's nothing - # scheduled, we should ignore it. - if (self._scheduled or - self._selector.registered_count() > self._internal_fds): - if self._ready: - timeout = 0 - elif self._scheduled: - # Compute the desired timeout. - when = self._scheduled[0].when - deadline = max(0, when - time.monotonic()) - if timeout is None: - timeout = deadline - else: - timeout = min(timeout, deadline) - - t0 = time.monotonic() - event_list = self._selector.select(timeout) - t1 = time.monotonic() - argstr = '' if timeout is None else ' %.3f' % timeout - if t1-t0 >= 1: - level = logging.INFO + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline else: - level = logging.DEBUG - logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - self._process_events(event_list) + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) # Handle 'later' callbacks that are ready. now = time.monotonic() From e997776c77c58e5047829423a51c618658b44544 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 24 Mar 2013 13:41:52 -0700 Subject: [PATCH 0383/1502] Take host:port from command line. --- srv.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) mode change 100644 => 100755 srv.py diff --git a/srv.py b/srv.py old mode 100644 new mode 100755 index 6f9b225b..b28abbda --- a/srv.py +++ b/srv.py @@ -1,7 +1,11 @@ +#!/usr/bin/env python3 """Simple server written using an event loop.""" import email.message import os +import sys + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' import tulip import tulip.http @@ -91,8 +95,17 @@ def handle_request(self, request_info, message): def main(): + host = '127.0.0.1' + port = 8080 + if sys.argv[1:]: + host = sys.argv[1] + if sys.argv[2:]: + port = int(sys.argv[2]) + elif ':' in host: + host, port = host.split(':', 1) + port = int(port) loop = tulip.get_event_loop() - f = loop.start_serving(lambda: HttpServer(debug=True), '127.0.0.1', 8080) + f = loop.start_serving(lambda: HttpServer(debug=True), host, port) x = loop.run_until_complete(f) print('serving on', x.getsockname()) loop.run_forever() From b8977733c6651f6cb1b5e98948a32fa9c15d55f2 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 25 Mar 2013 09:36:25 -0700 Subject: [PATCH 0384/1502] coroutine wrapper for non generator functions --- tests/tasks_test.py | 30 ++++++++++++++++++++++++++++++ tulip/tasks.py | 30 +++++++++++++++++++++++++----- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 2bcb7745..88c23b59 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -612,6 +612,36 @@ def wait_for_future(): RuntimeError, self.event_loop.run_until_complete, task) + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + if __name__ == '__main__': unittest.main() diff --git a/tulip/tasks.py b/tulip/tasks.py index 45aba017..41093564 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -6,6 +6,7 @@ ] import concurrent.futures +import functools import inspect import logging import time @@ -14,11 +15,30 @@ def coroutine(func): - """Decorator to mark coroutines.""" - # TODO: This is a feel-good API only. It is not enforced. - assert inspect.isgeneratorfunction(func) - func._is_coroutine = True # Not sure who can use this. - return func + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + logging.warning( + 'Coroutine function %s is not a generator.', func.__name__) + + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro # TODO: Do we need this? From 463ed7c479a4c543f378595b0ad3b3f57695f97d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 26 Mar 2013 09:09:15 -0700 Subject: [PATCH 0385/1502] Improved setup.py from issue 23. --- setup.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 67b037cc..dcaee96f 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,14 @@ +import os from distutils.core import setup, Extension -ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) -setup(name='_overlapped', ext_modules=[ext]) +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) From 6bb226c830cb3ea010bcd5355c397dd988c04459 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 15:17:43 -0700 Subject: [PATCH 0386/1502] proactor tests --- .hgeol | 2 + .hgignore | 11 + Makefile | 31 + NOTES | 176 +++++ README | 21 + TODO | 163 ++++ check.py | 41 + crawl.py | 143 ++++ curl.py | 35 + examples/udp_echo.py | 73 ++ old/Makefile | 16 + old/echoclt.py | 79 ++ old/echosvr.py | 60 ++ old/http_client.py | 78 ++ old/http_server.py | 68 ++ old/main.py | 134 ++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++ old/scheduling.py | 354 +++++++++ old/sockets.py | 348 +++++++++ old/transports.py | 496 ++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++ overlapped.c | 997 ++++++++++++++++++++++++ runtests.py | 198 +++++ setup.cfg | 2 + setup.py | 14 + srv.py | 115 +++ sslsrv.py | 56 ++ tests/base_events_test.py | 283 +++++++ tests/events_test.py | 1379 +++++++++++++++++++++++++++++++++ tests/futures_test.py | 222 ++++++ tests/http_protocol_test.py | 972 +++++++++++++++++++++++ tests/http_server_test.py | 242 ++++++ tests/locks_test.py | 747 ++++++++++++++++++ tests/proactor_events_test.py | 327 ++++++++ tests/queues_test.py | 370 +++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1286 ++++++++++++++++++++++++++++++ tests/selectors_test.py | 137 ++++ tests/streams_test.py | 299 +++++++ tests/subprocess_test.py | 54 ++ tests/tasks_test.py | 647 ++++++++++++++++ tests/transports_test.py | 45 ++ tests/unix_events_test.py | 573 ++++++++++++++ tests/winsocketpair_test.py | 26 + tulip/TODO | 28 + tulip/__init__.py | 26 + tulip/base_events.py | 547 +++++++++++++ tulip/events.py | 355 +++++++++ tulip/futures.py | 255 ++++++ tulip/http/__init__.py | 12 + tulip/http/client.py | 145 ++++ tulip/http/errors.py | 44 ++ tulip/http/protocol.py | 877 +++++++++++++++++++++ tulip/http/server.py | 176 +++++ tulip/locks.py | 433 +++++++++++ tulip/proactor_events.py | 190 +++++ tulip/protocols.py | 78 ++ tulip/queues.py | 291 +++++++ tulip/selector_events.py | 655 ++++++++++++++++ tulip/selectors.py | 418 ++++++++++ tulip/streams.py | 145 ++++ tulip/subprocess_transport.py | 139 ++++ tulip/tasks.py | 320 ++++++++ tulip/test_utils.py | 30 + tulip/transports.py | 134 ++++ tulip/unix_events.py | 301 +++++++ tulip/windows_events.py | 157 ++++ tulip/winsocketpair.py | 34 + 71 files changed, 17814 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100755 srv.py create mode 100644 sslsrv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/locks.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..25902497 --- /dev/null +++ b/.hgignore @@ -0,0 +1,11 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..274da4c8 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..64bc2cdd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..4e5bebe2 --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..37fce75c --- /dev/null +++ b/curl.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(f) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(stream.read()) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..9e995d14 --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,73 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..096a2561 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/srv.py b/srv.py new file mode 100755 index 00000000..b28abbda --- /dev/null +++ b/srv.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import email.message +import os +import sys + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, request_info, message): + print('method = {!r}; path = {!r}; version = {!r}'.format( + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +def main(): + host = '127.0.0.1' + port = 8080 + if sys.argv[1:]: + host = sys.argv[1] + if sys.argv[2:]: + port = int(sys.argv[2]) + elif ':' in host: + host, port = host.split(':', 1) + port = int(port) + loop = tulip.get_event_loop() + f = loop.start_serving(lambda: HttpServer(debug=True), host, port) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..85b013fa --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,283 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.logging') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..40859211 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1379 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_run_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + self.assertTrue(self.event_loop.is_running()) + try: + self.event_loop.run_until_complete( + tasks.sleep(0.1)) + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + tasks.sleep(0.1) + try: + self.event_loop.run_once() + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.1) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..5569cca1 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..74aef7c8 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,972 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_request_line()) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + self.stream.read_response_status()) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete(self.stream.read_headers()) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + self.stream.read_headers()) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 1))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 0))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=False)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, msg.payload.read()) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.wait([t1, t2])) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..dc55eff9 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,242 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv.closing) + + srv.close() + self.assertTrue(srv.closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + yield from [] + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + yield from [] + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..7d2111d9 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,747 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..6b801ffa --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,327 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock(socket.socket) + self.protocol = unittest.mock.Mock(tulip.Protocol) + + def test_ctor(self): + fut = tulip.Future() + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.event_loop.call_soon.mock_calls[0].assert_called_with( + tr._loop_reading) + self.event_loop.call_soon.mock_calls[1].assert_called_with( + self.protocol.connection_made, tr) + self.event_loop.call_soon.mock_calls[2].assert_called_with( + fut.set_result, None) + + def test_loop_reading(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._loop_reading() + self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future() + res.set_result(b'data') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future() + res.set_result(b'') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.event_loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + + def test_loop_reading_aborted(self): + err = self.event_loop._proactor.recv.side_effect = ( + ConnectionAbortedError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.event_loop._proactor.recv.side_effect = ( + ConnectionAbortedError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_exception(self): + err = self.event_loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.event_loop._proactor.send.assert_called_with(self.sock, b'data') + self.event_loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + def test_loop_writing_err(self): + err = self.event_loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + + def test_loop_writing_stop(self): + fut = tulip.Future() + fut.set_result(b'data') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_abort(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr.abort() + tr._fatal_error.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + + tr._write_fut.cancel.assert_called_with() + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertTrue(tr._closing) + + def test_close_2(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + self.event_loop.reset_mock() + tr.close() + + self.assertFalse(self.event_loop.call_soon.called) + + @unittest.mock.patch('tulip.proactor_events.logging') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._fatal_error(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('tulip.proactor_events.logging') + def test_fatal_error_2(self, m_logging): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._fatal_error(None) + + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.event_loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + event_loop = BaseProactorEventLoop(self.proactor) + self.assertIs(event_loop._ssock, ssock) + self.assertIs(event_loop._csock, csock) + self.assertEqual(event_loop._internal_fds, 1) + call_soon.assert_called_with(event_loop._loop_self_reading) + + def test_close_self_pipe(self): + self.event_loop._close_self_pipe() + self.assertEqual(self.event_loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.event_loop._ssock) + self.assertIsNone(self.event_loop._csock) + + def test_close(self): + self.event_loop._close_self_pipe = unittest.mock.Mock() + self.event_loop.close() + self.assertTrue(self.event_loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.event_loop._proactor) + + self.event_loop._close_self_pipe.reset_mock() + self.event_loop.close() + self.assertFalse(self.event_loop._close_self_pipe.called) + + def test_sock_recv(self): + self.event_loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.event_loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.event_loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.event_loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.event_loop._make_socket_transport( + self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.event_loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.event_loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.event_loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.event_loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.event_loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.event_loop._loop_self_reading) + self.assertTrue(self.event_loop.close.called) + + def test_write_to_self(self): + self.event_loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.event_loop._process_events([]) + + def test_start_serving(self): + pf = unittest.mock.Mock() + call_soon = self.event_loop.call_soon = unittest.mock.Mock() + + self.event_loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_transport = self.event_loop._make_socket_transport = ( + unittest.mock.Mock()) + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_transport.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..d86abd7f --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..dc6eeaf4 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,299 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..09aaed52 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..88c23b59 --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,647 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + yield from [] + return 42 + + @tasks.task + def inner2(): + yield from [] + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + err = None + + @tasks.coroutine + def coro2(): + nonlocal err + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro2())) + self.assertIsInstance(err, futures.CancelledError) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.logging') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..4b24b50b --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..ce250fd0 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,573 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logging') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('logging.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('logging.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..381fb227 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..acec5c24 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..573807d8 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,547 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: %r' % (interval,) + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..ba9a50f9 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,355 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import logging +import sys +import threading + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + logging.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..39137aa6 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,255 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..582f0809 --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,12 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..b65b90a8 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,145 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version=(1, 1), + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = version + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + self.request = protocol.Request( + transport, self.method, self.path, self.version) + + self.request.add_headers(*self.headers.items()) + self.request.send_headers() + + if self.make_body is not None: + if self.chunked: + self.make_body( + self.request.write, self.request.eof) + else: + self.make_body( + self.request.write, self.request.eof) + else: + self.request.write_eof() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..41344de1 --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: %s' % hdr) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..6a0e1279 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,877 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import email.utils +import functools +import http.server +import itertools +import re +import sys +import zlib + +import tulip +from . import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') +RESPONSES = http.server.BaseHTTPRequestHandler.responses + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None + + def feed_data(self, data): + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: + try: + self._parser.send(data) + except StopIteration as exc: + self._parser = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_parser is a generator, if _parser is set feed_eof throws + StreamEofException into this generator.""" + if self._parser: + try: + self._parser.throw(StreamEofException()) + except StopIteration: + self._parser = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise errors.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: + stream.set_exception(exc) + + def _parse_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + + def _parse_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload parser + if chunked: + parser = self._parse_chunked_payload() + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + + parser = self._parse_length_payload(length) + else: + if readall: + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(Exception): + """Internal exception: eof received.""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(errors.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(errors.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '%r is not a string' % name + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('%s\r\n\r\n' % '\r\n'.join( + ('%s: %s' % (k, v) for k, v in + itertools.chain(self._default_headers(), self.headers))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(chunk) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk, buf = buf[:chunk_size], buf[chunk_size:] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', email.utils.formatdate()), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..7590e47b --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + %(status)s %(reason)s + + +

%(status)s %(reason)s

+ %(message)s + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + closing = False + request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self.closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self.request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self.closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
%s
' % tb + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE % { + 'status': status, 'reason': reason, 'message': msg} + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..40247962 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,433 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s' % self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..d5865dbd --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,190 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import logging + +from . import base_events +from . import transports + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logging.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + logging.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..593ee745 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..ee349e13 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self._maxsize, ) + if getattr(self, '_queue', None): + result += ' _queue=%r' % list(self._queue) + if self._getters: + result += ' _getters[%s]' % len(self._getters) + if self._putters: + result += ' _putters[%s]' % len(self._putters) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..46e04dae --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,655 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import logging +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + logging.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logging.exception('Accept failed') + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + # should be called from exception handler only + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._protocol.connection_lost(None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._protocol.connection_lost(exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..e8fd12e9 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import logging +import sys + +from select import * + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + logging.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..8d7f6236 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..8d760f5b --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,139 @@ +import fcntl +import logging +import os +import traceback + +from . import transports +from . import events + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + logging.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..41093564 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,320 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import logging +import time + +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + logging.warning( + 'Coroutine function %s is not a generator.', func.__name__) + + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + logging.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + logging.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + logging.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..9b87db2f --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..a9ec07a0 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..9e75e3e9 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,301 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import logging +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events +from . import transports + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logging.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logging.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logging.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..4297f804 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import logging +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + logging.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..bd1e0928 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From 3ba566906add5cee12c2d930326e326a776282e5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Mar 2013 15:24:38 -0700 Subject: [PATCH 0387/1502] Ignore build subdirectory. --- .hgignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.hgignore b/.hgignore index 25902497..99870025 100644 --- a/.hgignore +++ b/.hgignore @@ -9,3 +9,4 @@ htmlcov$ venv$ distribute_setup.py$ distribute-\d+.\d+.\d+.tar.gz$ +build$ From a3bb5299cbebbb5e850d0b602acdc05ea4a14cdc Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 28 Mar 2013 00:54:02 +0200 Subject: [PATCH 0388/1502] Replace xkcd.com to example.com for tests, finally get rid of xkcd.com mentioning. --- tests/base_events_test.py | 2 +- tests/events_test.py | 10 +++++----- tulip/TODO | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 85b013fa..361b1791 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -277,7 +277,7 @@ def _socket(*args, **kw): self.event_loop.getaddrinfo = getaddrinfo task = tasks.Task( - self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + self.event_loop.create_connection(MyProto, 'example.com', 80)) task._step() exc = task.exception() self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py index 40859211..f4db2c34 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -38,7 +38,7 @@ def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' - transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') def data_received(self, data): assert self.state == 'CONNECTED', self.state @@ -663,7 +663,7 @@ def test_create_ssl_connection(self): def test_create_connection_host_port_sock(self): self.suppress_log_errors() coro = self.event_loop.create_connection( - MyProto, 'xkcd.com', 80, sock=object()) + MyProto, 'example.com', 80, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_host_port_sock(self): @@ -676,7 +676,7 @@ def test_create_connection_no_getaddrinfo(self): getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) @@ -685,7 +685,7 @@ def test_create_connection_connect_err(self): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) @@ -700,7 +700,7 @@ def getaddrinfo(*args, **kw): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) diff --git a/tulip/TODO b/tulip/TODO index acec5c24..b3a9302e 100644 --- a/tulip/TODO +++ b/tulip/TODO @@ -24,5 +24,3 @@ TODO in tulip v2 (tulip/ package directory) - buffered stream implementation - Primitives like par() and wait_one() - -- Remove test dependency on xkcd.com, write our own test server From 4337dc0ebc7006d6b78350fe1fa4850a2d1e10db Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 16:24:59 -0700 Subject: [PATCH 0389/1502] windows tests fixes --- tests/events_test.py | 29 +++++++++++------------------ tests/locks_test.py | 16 ++++++---------- tests/queues_test.py | 20 ++++++++++---------- 3 files changed, 27 insertions(+), 38 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index f4db2c34..281130bb 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -184,37 +184,30 @@ def test_run(self): self.event_loop.run() # Returns immediately. def test_run_nesting(self): - err = None + self.suppress_log_errors() @tasks.coroutine def coro(): - nonlocal err yield from [] self.assertTrue(self.event_loop.is_running()) - try: - self.event_loop.run_until_complete( - tasks.sleep(0.1)) - except Exception as exc: - err = exc + self.event_loop.run_until_complete(tasks.sleep(0.1)) - self.event_loop.run_until_complete(tasks.Task(coro())) - self.assertIsInstance(err, RuntimeError) + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) def test_run_once_nesting(self): - err = None + self.suppress_log_errors() @tasks.coroutine def coro(): - nonlocal err yield from [] tasks.sleep(0.1) - try: - self.event_loop.run_once() - except Exception as exc: - err = exc + self.event_loop.run_once() - self.event_loop.run_until_complete(tasks.Task(coro())) - self.assertIsInstance(err, RuntimeError) + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) def test_run_once_block(self): called = False @@ -354,7 +347,7 @@ def run(arg): def test_run_in_executor_with_handle(self): def run(arg): - time.sleep(0.1) + time.sleep(0.01) return arg handle = events.Handle(run, ('yo',)) f2 = self.event_loop.run_in_executor(None, handle) diff --git a/tests/locks_test.py b/tests/locks_test.py index 7d2111d9..e761c677 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -105,7 +105,7 @@ def test_acquire_timeout(self): lock = locks.Lock() self.event_loop.run_until_complete(lock.acquire()) - self.event_loop.call_later(0.1, lock.release) + self.event_loop.call_later(0.01, lock.release) acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) self.assertTrue(acquired) @@ -114,16 +114,12 @@ def test_acquire_timeout_mixed(self): self.event_loop.run_until_complete(lock.acquire()) tasks.Task(lock.acquire()) tasks.Task(lock.acquire()) - acquire_task = tasks.Task(lock.acquire(0.1)) + acquire_task = tasks.Task(lock.acquire(0.01)) tasks.Task(lock.acquire()) - t0 = time.monotonic() acquired = self.event_loop.run_until_complete(acquire_task) self.assertFalse(acquired) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) - self.assertEqual(3, len(lock._waiters)) def test_acquire_cancel(self): @@ -241,7 +237,7 @@ def test_wait_timeout(self): self.assertTrue(0.08 < total_time < 0.12) ev = locks.EventWaiter() - self.event_loop.call_later(0.1, ev.set) + self.event_loop.call_later(0.01, ev.set) acquired = self.event_loop.run_until_complete(ev.wait(10.1)) self.assertTrue(acquired) @@ -451,7 +447,7 @@ def test_wait_for_timeout(self): @tasks.coroutine def c1(result): yield from cond.acquire() - if (yield from cond.wait_for(predicate, 0.2)): + if (yield from cond.wait_for(predicate, 0.1)): result.append(1) else: result.append(2) @@ -475,7 +471,7 @@ def c1(result): self.assertEqual(3, predicate.call_count) total_time = (time.monotonic() - t0) - self.assertTrue(0.18 < total_time < 0.22) + self.assertTrue(0.08 < total_time < 0.12) def test_wait_for_unacquired(self): self.suppress_log_errors() @@ -681,7 +677,7 @@ def test_acquire_timeout(self): sem = locks.Semaphore() self.event_loop.run_until_complete(sem.acquire()) - self.event_loop.call_later(0.1, sem.release) + self.event_loop.call_later(0.01, sem.release) acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) self.assertTrue(acquired) diff --git a/tests/queues_test.py b/tests/queues_test.py index d86abd7f..8c1c0afb 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -103,14 +103,14 @@ def putter(): @tasks.coroutine def test(): tasks.Task(putter()) - yield from tasks.sleep(0.1) + yield from tasks.sleep(0.01) # The putter is blocked after putting two items. self.assertEqual([0, 1], have_been_put) self.assertEqual(0, q.get_nowait()) # Let the putter resume and put last item. - yield from tasks.sleep(0.1) + yield from tasks.sleep(0.01) self.assertEqual([0, 1, 2], have_been_put) self.assertEqual(1, q.get_nowait()) self.assertEqual(2, q.get_nowait()) @@ -146,7 +146,7 @@ def queue_get(): @tasks.coroutine def queue_put(): - self.event_loop.call_later(0.1, q.put_nowait, 1) + self.event_loop.call_later(0.01, q.put_nowait, 1) queue_get_task = tasks.Task(queue_get()) yield from started.wait() self.assertFalse(finished) @@ -172,7 +172,7 @@ def test_get_timeout(self): @tasks.coroutine def queue_get(): with self.assertRaises(queue.Empty): - return (yield from q.get(timeout=0.1)) + return (yield from q.get(timeout=0.01)) # Get works after timeout, with blocking and non-blocking put. q.put_nowait(1) @@ -188,12 +188,12 @@ def test_get_timeout_cancelled(self): @tasks.coroutine def queue_get(): - return (yield from q.get(timeout=0.2)) + return (yield from q.get(timeout=0.05)) @tasks.coroutine def test(): get_task = tasks.Task(queue_get()) - yield from tasks.sleep(0.1) # let the task start + yield from tasks.sleep(0.01) # let the task start q.put_nowait(1) return (yield from get_task) @@ -227,7 +227,7 @@ def queue_put(): @tasks.coroutine def queue_get(): - self.event_loop.call_later(0.1, q.get_nowait) + self.event_loop.call_later(0.01, q.get_nowait) queue_put_task = tasks.Task(queue_put()) yield from started.wait() self.assertFalse(finished) @@ -253,14 +253,14 @@ def test_put_timeout(self): @tasks.coroutine def queue_put(): with self.assertRaises(queue.Full): - return (yield from q.put(1, timeout=0.1)) + return (yield from q.put(1, timeout=0.01)) self.assertEqual(0, q.get_nowait()) # Put works after timeout, with blocking and non-blocking get. get_task = tasks.Task(q.get()) # Let the get start waiting. - yield from tasks.sleep(0.1) + yield from tasks.sleep(0.01) q.put_nowait(2) self.assertEqual(2, (yield from get_task)) @@ -274,7 +274,7 @@ def test_put_timeout_cancelled(self): @tasks.coroutine def queue_put(): - yield from q.put(1, timeout=0.1) + yield from q.put(1, timeout=0.01) @tasks.coroutine def test(): From dd25f6a21c198bf3c277fa55c77e70477b4533ff Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 16:51:00 -0700 Subject: [PATCH 0390/1502] server ssl socket transport support --- srv.py | 63 +++++++++++++++++++++++++++++++++------- sslsrv.py | 56 ----------------------------------- tests/events_test.py | 63 +++++++++++++++++++++++++++++++++++++++- tulip/base_events.py | 20 +++++++++---- tulip/http/protocol.py | 5 ++-- tulip/selector_events.py | 32 ++++++++++++-------- 6 files changed, 152 insertions(+), 87 deletions(-) delete mode 100644 sslsrv.py diff --git a/srv.py b/srv.py index b28abbda..47a1ed45 100755 --- a/srv.py +++ b/srv.py @@ -1,9 +1,15 @@ #!/usr/bin/env python3 """Simple server written using an event loop.""" +import argparse import email.message +import logging import os import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' @@ -94,18 +100,55 @@ def handle_request(self, request_info, message): self.close() +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + def main(): - host = '127.0.0.1' - port = 8080 - if sys.argv[1:]: - host = sys.argv[1] - if sys.argv[2:]: - port = int(sys.argv[2]) - elif ':' in host: - host, port = host.split(':', 1) - port = int(port) + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = False + loop = tulip.get_event_loop() - f = loop.start_serving(lambda: HttpServer(debug=True), host, port) + f = loop.start_serving( + lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) x = loop.run_until_complete(f) print('serving on', x.getsockname()) loop.run_forever() diff --git a/sslsrv.py b/sslsrv.py deleted file mode 100644 index a1bc04f9..00000000 --- a/sslsrv.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Serve up an SSL connection, after Python ssl module docs.""" - -import socket -import ssl -import os - - -def main(): - context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) - certfile = getcertfile() - context.load_cert_chain(certfile=certfile, keyfile=certfile) - bindsocket = socket.socket() - bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - bindsocket.bind(('', 4443)) - bindsocket.listen(5) - - while True: - newsocket, fromaddr = bindsocket.accept() - try: - connstream = context.wrap_socket(newsocket, server_side=True) - try: - deal_with_client(connstream) - finally: - connstream.shutdown(socket.SHUT_RDWR) - connstream.close() - except Exception as exc: - print(exc.__class__.__name__ + ':', exc) - - -def getcertfile(): - import test # Test package - testdir = os.path.dirname(test.__file__) - certfile = os.path.join(testdir, 'keycert.pem') - print('certfile =', certfile) - return certfile - - -def deal_with_client(connstream): - data = connstream.recv(1024) - # empty data means the client is finished with us - while data: - if not do_something(connstream, data): - # we'll assume do_something returns False - # when we're finished with client - break - data = connstream.recv(1024) - # finished with client - - -def do_something(connstream, data): - # just echo back - connstream.sendall(data) - - -if __name__ == '__main__': - main() diff --git a/tests/events_test.py b/tests/events_test.py index 281130bb..6df6d672 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -675,6 +675,11 @@ def test_create_connection_no_getaddrinfo(self): def test_create_connection_connect_err(self): self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error @@ -693,7 +698,8 @@ def getaddrinfo(*args, **kw): self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + coro = self.event_loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) @@ -735,6 +741,61 @@ def factory(): # recv()/send() on the serving socket client.close() + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.event_loop.start_serving( + factory, '0.0.0.0', 0, ssl=sslcontext) + + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + + f_c = self.event_loop.create_connection( + ClientMyProto, host, port, ssl=True) + client, pr = self.event_loop.run_until_complete(f_c) + + client.write(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + def test_start_serving_sock(self): sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) diff --git a/tulip/base_events.py b/tulip/base_events.py index 573807d8..42f2217a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -50,12 +50,13 @@ def __init__(self): self._internal_fds = 0 self._running = False - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, - sslcontext, waiter, extra=None): + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): """Create SSL transport.""" raise NotImplementedError @@ -326,12 +327,18 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.setblocking(False) + if ssl: + import ssl as sslmod + sslcontext = sslmod.SSLContext(sslmod.PROTOCOL_SSLv23) + sock = sslcontext.wrap_socket(sock, server_side=False, + do_handshake_on_connect=False) + protocol = protocol_factory() waiter = futures.Future() if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( - sock, protocol, sslcontext, waiter) + sock, protocol, sslcontext, waiter, server_side=False) else: transport = self._make_socket_transport(sock, protocol, waiter) @@ -411,7 +418,8 @@ def create_datagram_endpoint(self, protocol_factory, # TODO: Or create_server()? @tasks.task def start_serving(self, protocol_factory, host=None, port=None, *, - family=0, proto=0, flags=0, backlog=100, sock=None): + family=0, proto=0, flags=0, backlog=100, sock=None, + ssl=False): """XXX""" if host is not None or port is not None: if sock is not None: @@ -447,7 +455,7 @@ def start_serving(self, protocol_factory, host=None, port=None, *, sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock) + self._start_serving(protocol_factory, sock, ssl) return sock @tasks.coroutine diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 6a0e1279..c5c4e499 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -752,7 +752,7 @@ def _write_chunked_payload(self): break self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) - self.transport.write(chunk) + self.transport.write(bytes(chunk)) self.transport.write(b'\r\n') def _write_length_payload(self, length): @@ -798,7 +798,8 @@ def add_chunking_filter(self, chunk_size=16*1024): buf.extend(chunk) while len(buf) >= chunk_size: - chunk, buf = buf[:chunk_size], buf[chunk_size:] + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] yield chunk chunk = yield EOL_MARKER diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 46e04dae..68ec6a92 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -45,13 +45,14 @@ def __init__(self, selector=None): self._selector = selector self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): return _SelectorSocketTransport(self, sock, protocol, waiter, extra) - def _make_ssl_transport(self, rawsock, protocol, - sslcontext, waiter, extra=None): + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, extra) + self, rawsock, protocol, sslcontext, waiter, server_side, extra) def _make_datagram_transport(self, sock, protocol, address=None, extra=None): @@ -94,13 +95,14 @@ def _write_to_self(self): except (BlockingIOError, InterruptedError): pass - def _start_serving(self, protocol_factory, sock): + def _start_serving(self, protocol_factory, sock, ssl=False): self.add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock) + protocol_factory, sock, ssl) - def _accept_connection(self, protocol_factory, sock): + def _accept_connection(self, protocol_factory, sock, ssl=False): try: conn, addr = sock.accept() + conn.setblocking(False) except (BlockingIOError, InterruptedError): pass # False alarm. except: @@ -111,8 +113,14 @@ def _accept_connection(self, protocol_factory, sock): # TODO: Someone will want an error handler for this. logging.exception('Accept failed') else: - self._make_socket_transport( - conn, protocol_factory(), extra={'addr': addr}) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + self._make_ssl_transport( + conn, protocol_factory(), sslcontext, futures.Future(), + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -414,8 +422,8 @@ def _call_connection_lost(self, exc): class _SelectorSslTransport(transports.Transport): - def __init__(self, event_loop, rawsock, - protocol, sslcontext, waiter, extra=None): + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter, + server_side=False, extra=None): super().__init__(extra) self._event_loop = event_loop @@ -424,7 +432,7 @@ def __init__(self, event_loop, rawsock, sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) self._sslcontext = sslcontext self._waiter = waiter - sslsock = sslcontext.wrap_socket(rawsock, + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, do_handshake_on_connect=False) self._sslsock = sslsock self._buffer = [] From 312f408391d68b51681f3327649242f6f04c7279 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 17:03:38 -0700 Subject: [PATCH 0391/1502] replace wsgiref tests server with tulip.http.server --- tests/events_test.py | 93 ++++++++++++-------------------------------- tulip/test_utils.py | 60 ++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 68 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 6df6d672..1bb92b4d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1,7 +1,6 @@ """Tests for events.py.""" import concurrent.futures -import contextlib import gc import io import os @@ -18,8 +17,6 @@ import unittest import unittest.mock -from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer - from tulip import events from tulip import transports from tulip import protocols @@ -131,55 +128,6 @@ def tearDown(self): gc.collect() super().tearDown() - @contextlib.contextmanager - def run_test_server(self, *, use_ssl=False): - - class SilentWSGIRequestHandler(WSGIRequestHandler): - def get_stderr(self): - return io.StringIO() - - def log_message(self, format, *args): - pass - - class SilentWSGIServer(WSGIServer): - def handle_error(self, request, client_address): - pass - - class SSLWSGIServer(SilentWSGIServer): - def finish_request(self, request, client_address): - here = os.path.dirname(__file__) - keyfile = os.path.join(here, 'sample.key') - certfile = os.path.join(here, 'sample.crt') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass - - def app(environ, start_response): - status = '302 Found' - headers = [('Content-type', 'text/plain')] - start_response(status, headers) - return [b'Test message'] - - # Run the test WSGI server in a separate thread in order not to - # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else SilentWSGIServer - httpd = make_server('127.0.0.1', 0, app, - server_class, SilentWSGIRequestHandler) - server_thread = threading.Thread(target=httpd.serve_forever) - server_thread.start() - try: - yield httpd - finally: - httpd.shutdown() - server_thread.join() - def test_run(self): self.event_loop.run() # Returns immediately. @@ -474,19 +422,18 @@ def sender(): self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): - with self.run_test_server() as httpd: + with test_utils.run_test_server(self.event_loop) as addr: sock = socket.socket() sock.setblocking(False) - address = httpd.socket.getsockname() self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, address)) + self.event_loop.sock_connect(sock, addr)) self.event_loop.run_until_complete( self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) data = self.event_loop.run_until_complete( self.event_loop.sock_recv(sock, 1024)) sock.close() - self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) def test_sock_client_fail(self): # Make sure that we will get an unused port @@ -617,10 +564,8 @@ def my_handler(*args): self.assertEqual(caught, 1) def test_create_connection(self): - with self.run_test_server() as httpd: - host, port = httpd.socket.getsockname() - f = tasks.Task( - self.event_loop.create_connection(MyProto, host, port)) + with test_utils.run_test_server(self.event_loop) as addr: + f = self.event_loop.create_connection(MyProto, *addr) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -628,10 +573,24 @@ def test_create_connection(self): self.assertTrue(pr.nbytes > 0) def test_create_connection_sock(self): - with self.run_test_server() as httpd: - host, port = httpd.socket.getsockname() - f = tasks.Task( - self.event_loop.create_connection(MyProto, host, port)) + with test_utils.run_test_server(self.event_loop) as addr: + sock = None + infos = self.event_loop.run_until_complete( + self.event_loop.getaddrinfo(*addr, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.event_loop.create_connection(MyProto, sock=sock) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -640,10 +599,8 @@ def test_create_connection_sock(self): @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): - with self.run_test_server(use_ssl=True) as httpsd: - host, port = httpsd.socket.getsockname() - f = self.event_loop.create_connection( - MyProto, host, port, ssl=True) + with test_utils.run_test_server(self.event_loop, use_ssl=True) as addr: + f = self.event_loop.create_connection(MyProto, *addr, ssl=True) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 9b87db2f..ac0b5f59 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,9 +1,19 @@ """Utilities shared by tests.""" +import contextlib import logging +import os import socket import sys +import threading import unittest +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http if sys.platform == 'win32': # pragma: no cover @@ -28,3 +38,53 @@ def suppress_log_errors(self): # pragma: no cover def suppress_log_warnings(self): # pragma: no cover if self._log_level >= logging.WARNING: self._logger.setLevel(logging.ERROR) + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False): + class HttpServer(tulip.http.ServerHttpProtocol): + def handle_request(self, info, message): + response = tulip.http.Response( + self.transport, 200, info.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + self.transport.close() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = False + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + sock = thread_loop.run_until_complete( + thread_loop.start_serving(HttpServer, host, port, ssl=sslcontext)) + + waiter = tulip.Future() + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, sock.getsockname())) + + thread_loop.run_until_complete(waiter) + thread_loop.stop() + + fut = tulip.Future() + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield addr + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() From 9fae46b20d6d6612a62a0f2c53a4134eeac9fddc Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 20:45:02 -0700 Subject: [PATCH 0392/1502] windows tests fixes --- srv.py | 2 +- tests/events_test.py | 12 ++++++++++-- tulip/proactor_events.py | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/srv.py b/srv.py index 47a1ed45..8fd6ccf6 100755 --- a/srv.py +++ b/srv.py @@ -133,7 +133,7 @@ def main(): if args.ssl: here = os.path.join(os.path.dirname(__file__), 'tests') - + if args.certfile: certfile = args.certfile or os.path.join(here, 'sample.crt') keyfile = args.keyfile or os.path.join(here, 'sample.key') diff --git a/tests/events_test.py b/tests/events_test.py index 1bb92b4d..956e861d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -168,6 +168,8 @@ def run(): time.sleep(0.1) self.event_loop.call_soon_threadsafe(callback) + self.event_loop.run_once(0.001) # windows iocp + t = threading.Thread(target=run) t0 = time.monotonic() t.start() @@ -422,6 +424,8 @@ def sender(): self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): + self.suppress_log_errors() + with test_utils.run_test_server(self.event_loop) as addr: sock = socket.socket() sock.setblocking(False) @@ -680,6 +684,7 @@ def factory(): self.assertEqual('INITIAL', proto.state) self.event_loop.run_once() self.assertEqual('CONNECTED', proto.state) + self.event_loop.run_once(0.001) # windows iocp self.assertEqual(3, proto.nbytes) # extra info is available @@ -691,6 +696,7 @@ def factory(): # close connection proto.transport.close() + self.event_loop.run_once(0.001) # windows iocp self.assertEqual('CLOSED', proto.state) @@ -720,11 +726,11 @@ def factory(): keyfile=os.path.join(here, 'sample.key')) f = self.event_loop.start_serving( - factory, '0.0.0.0', 0, ssl=sslcontext) + factory, '127.0.0.1', 0, ssl=sslcontext) sock = self.event_loop.run_until_complete(f) host, port = sock.getsockname() - self.assertEqual(host, '0.0.0.0') + self.assertEqual(host, '127.0.0.1') f_c = self.event_loop.create_connection( ClientMyProto, host, port, ssl=True) @@ -1074,6 +1080,8 @@ def create_event_loop(self): return windows_events.ProactorEventLoop() def test_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") def test_reader_callback_cancel(self): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index d5865dbd..4b4b0acc 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -170,7 +170,9 @@ def _loop_self_reading(self, f=None): def _write_to_self(self): self._csock.send(b'x') - def _start_serving(self, protocol_factory, sock): + def _start_serving(self, protocol_factory, sock, ssl=False): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + def loop(f=None): try: if f: From 3793ba41877fb64eacb3317a77a22cc5e84fb0de Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 27 Mar 2013 21:43:39 -0700 Subject: [PATCH 0393/1502] wsgi server --- tests/http_server_test.py | 4 +- tests/http_wsgi_test.py | 242 ++++++++++++++++++++++++++++++++++++++ tulip/http/__init__.py | 4 +- tulip/http/protocol.py | 5 +- tulip/http/server.py | 10 +- tulip/http/wsgi.py | 219 ++++++++++++++++++++++++++++++++++ 6 files changed, 474 insertions(+), 10 deletions(-) create mode 100644 tests/http_wsgi_test.py create mode 100644 tulip/http/wsgi.py diff --git a/tests/http_server_test.py b/tests/http_server_test.py index dc55eff9..04fc60b5 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -80,10 +80,10 @@ def test_connection_lost(self): def test_close(self): srv = server.ServerHttpProtocol() - self.assertFalse(srv.closing) + self.assertFalse(srv._closing) srv.close() - self.assertTrue(srv.closing) + self.assertTrue(srv._closing) def test_handle_error(self): transport = unittest.mock.Mock() diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..2fc0fee8 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,242 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpWsgiServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.payload = b'data' + self.info = protocol.RequestLine('GET', '/path', (1, 0)) + self.headers = [] + self.message = protocol.RawHttpMessage( + self.headers, b'data', True, 'deflate') + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.info, self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.info = protocol.RequestLine('GET', '/path', (1, 1)) + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future() + f1.set_result(b'data') + fut = tulip.Future() + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader() + stream.feed_data(b'data') + stream.feed_eof() + self.message = protocol.RawHttpMessage( + self.headers, stream, True, 'deflate') + self.info = protocol.RequestLine('GET', '/path', (1, 1)) + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py index 582f0809..b2a0a26d 100644 --- a/tulip/http/__init__.py +++ b/tulip/http/__init__.py @@ -4,9 +4,11 @@ from .errors import * from .protocol import * from .server import * +from .wsgi import * __all__ = (client.__all__ + errors.__all__ + protocol.__all__ + - server.__all__) + server.__all__ + + wsgi.__all__) diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index c5c4e499..478226ba 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -5,13 +5,14 @@ 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] import collections -import email.utils import functools import http.server import itertools import re import sys +import time import zlib +from wsgiref.handlers import format_date_time import tulip from . import errors @@ -852,7 +853,7 @@ def __init__(self, transport, status, http_version=(1, 1), close=False): def _default_headers(self): headers = super()._default_headers() - headers.extend((('DATE', email.utils.formatdate()), + headers.extend((('DATE', format_date_time(time.time())), ('SERVER', self.SERVER_SOFTWARE))) return headers diff --git a/tulip/http/server.py b/tulip/http/server.py index 7590e47b..d22e7925 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -36,8 +36,8 @@ class ServerHttpProtocol(tulip.Protocol): status line, bad headers or incomplete payload. If any error occurs, connection gets closed. """ - closing = False - request_count = 0 + _closing = False + _request_count = 0 _request_handle = None def __init__(self, log=logging, debug=False): @@ -61,7 +61,7 @@ def eof_received(self): self.stream.feed_eof() def close(self): - self.closing = True + self._closing = True def log_access(self, status, info, message, *args, **kw): pass @@ -85,7 +85,7 @@ def start(self): while True: info = None message = None - self.request_count += 1 + self._request_count += 1 try: info = yield from self.stream.read_request_line() @@ -104,7 +104,7 @@ def start(self): except Exception as exc: self.handle_error(500, info, message, exc) finally: - if self.closing: + if self._closing: self.transport.close() break diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..5957f9c8 --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,219 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, info, message): + return WsgiResponse(self.transport, info, message) + + def create_wsgi_environ(self, info, message, payload): + uri_parts = urlsplit(info.uri) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': info.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': info.uri, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % info.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_' + hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, info, message): + """Handle a single HTTP request""" + + if self.readpayload: + payload = io.BytesIO((yield from message.payload.read())) + else: + payload = message.payload + + environ = self.create_wsgi_environ(info, message, payload) + response = self.create_wsgi_response(info, message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if not resp.keep_alive(): + self.close() + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, info, message): + self.transport = transport + self.info = info + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + self.response = tulip.http.Response( + self.transport, status_code, + self.info.version, self.message.should_close) + self.response.add_headers(*headers) + self.response._send_headers = True + return self.response.write From 80d1312a3e9c869f26fa4790a8978fd7f8486fb1 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 28 Mar 2013 15:39:55 -0400 Subject: [PATCH 0394/1502] Use logger named 'tulip' for library events, Issue 26 --- .hgeol | 2 + .hgignore | 11 + Makefile | 31 + NOTES | 176 +++++ README | 21 + TODO | 163 ++++ check.py | 41 + crawl.py | 143 ++++ curl.py | 35 + examples/udp_echo.py | 73 ++ old/Makefile | 16 + old/echoclt.py | 79 ++ old/echosvr.py | 60 ++ old/http_client.py | 78 ++ old/http_server.py | 68 ++ old/main.py | 134 ++++ old/p3time.py | 47 ++ old/polling.py | 535 +++++++++++++ old/scheduling.py | 354 +++++++++ old/sockets.py | 348 +++++++++ old/transports.py | 496 ++++++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++ overlapped.c | 997 ++++++++++++++++++++++++ runtests.py | 198 +++++ setup.cfg | 2 + setup.py | 14 + srv.py | 115 +++ sslsrv.py | 56 ++ tests/base_events_test.py | 283 +++++++ tests/events_test.py | 1379 +++++++++++++++++++++++++++++++++ tests/futures_test.py | 222 ++++++ tests/http_protocol_test.py | 972 +++++++++++++++++++++++ tests/http_server_test.py | 242 ++++++ tests/locks_test.py | 747 ++++++++++++++++++ tests/queues_test.py | 370 +++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1286 ++++++++++++++++++++++++++++++ tests/selectors_test.py | 137 ++++ tests/streams_test.py | 299 +++++++ tests/subprocess_test.py | 54 ++ tests/tasks_test.py | 647 ++++++++++++++++ tests/transports_test.py | 45 ++ tests/unix_events_test.py | 573 ++++++++++++++ tests/winsocketpair_test.py | 26 + tulip/TODO | 28 + tulip/__init__.py | 26 + tulip/base_events.py | 548 +++++++++++++ tulip/events.py | 356 +++++++++ tulip/futures.py | 255 ++++++ tulip/http/__init__.py | 12 + tulip/http/client.py | 145 ++++ tulip/http/errors.py | 44 ++ tulip/http/protocol.py | 877 +++++++++++++++++++++ tulip/http/server.py | 176 +++++ tulip/locks.py | 433 +++++++++++ tulip/log.py | 6 + tulip/proactor_events.py | 189 +++++ tulip/protocols.py | 78 ++ tulip/queues.py | 291 +++++++ tulip/selector_events.py | 655 ++++++++++++++++ tulip/selectors.py | 418 ++++++++++ tulip/streams.py | 145 ++++ tulip/subprocess_transport.py | 139 ++++ tulip/tasks.py | 320 ++++++++ tulip/test_utils.py | 30 + tulip/transports.py | 134 ++++ tulip/unix_events.py | 301 +++++++ tulip/windows_events.py | 157 ++++ tulip/winsocketpair.py | 34 + 71 files changed, 17494 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100755 srv.py create mode 100644 sslsrv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..b6910a22 --- /dev/null +++ b/.hgeol @@ -0,0 +1,2 @@ +[patterns] +** = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..25902497 --- /dev/null +++ b/.hgignore @@ -0,0 +1,11 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..274da4c8 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..64bc2cdd --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('%s:%d:%s%s' % ( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..4e5bebe2 --- /dev/null +++ b/crawl.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import socket +import sys +import urllib.parse + +import tulip +import tulip.http + +END = '\n' +MAXTASKS = 100 + + +class Crawler: + + def __init__(self, rooturl): + self.rooturl = rooturl + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.waiter = None + self.addurl(self.rooturl, '') # Set initial work. + self.run() # Kick off work. + + def addurl(self, url, parenturl): + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if not url.startswith(self.rooturl): + return False + if url in self.busy or url in self.done or url in self.todo: + return False + self.todo.add(url) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + return True + + @tulip.task + def run(self): + while self.todo or self.busy or self.tasks: + complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) + print(len(complete), 'completed tasks,', len(self.tasks), + 'still pending ', end=END) + for task in complete: + try: + yield from task + except Exception as exc: + print('Exception in task:', exc, end=END) + while self.todo and len(self.tasks) < MAXTASKS: + url = self.todo.pop() + self.busy.add(url) + self.tasks.add(self.process(url)) # Async task. + if self.busy: + self.waiter = tulip.Future() + yield from self.waiter + tulip.get_event_loop().stop() + + @tulip.task + def process(self, url): + ok = False + p = None + try: + print('processing', url, end=END) + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + p = tulip.http.HttpClientProtocol( + netloc, path=path, ssl=(scheme=='https')) + delay = 1 + while True: + try: + status, headers, stream = yield from p.connect() + break + except socket.error as exc: + if delay >= 60: + raise + print('...', url, 'has error', repr(str(exc)), + 'retrying after sleep', delay, '...', end=END) + yield from tulip.sleep(delay) + delay *= 2 + + if status[:3] in ('301', '302'): + # Redirect. + u = headers.get('location') or headers.get('uri') + if self.addurl(u, url): + print(' ', url, status[:3], 'redirect to', u, end=END) + elif status.startswith('200'): + ctype = headers.get_content_type() + if ctype == 'text/html': + while True: + line = yield from stream.readline() + if not line: + break + line = line.decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + line) + for u in urls: + if self.addurl(u, url): + print(' ', url, 'href to', u, end=END) + ok = True + finally: + if p is not None: + p.transport.close() + self.done[url] = ok + self.busy.remove(url) + if not ok: + print('failure for', url, sys.exc_info(), end=END) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(None) + + +def main(): + rooturl = sys.argv[1] + c = Crawler(rooturl) + loop = tulip.get_event_loop() + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..37fce75c --- /dev/null +++ b/curl.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +import sys +import urllib.parse + +import tulip +import tulip.http + + +def main(): + url = sys.argv[1] + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not path: + path = '/' + if query: + path = '?'.join([path, query]) + print(netloc, path, scheme) + p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) + f = p.connect() + sts, headers, stream = p.event_loop.run_until_complete(f) + print(sts) + for k, v in headers.items(): + print('{}: {}'.format(k, v)) + print() + data = p.event_loop.run_until_complete(stream.read()) + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..9e995d14 --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,73 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "%s"' % self.message) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "%s"' % data.decode()) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..096a2561 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('%s%s' % (prefix, mod), modpath)) + + prefix = '%s%s.' % (prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('%s%s.' % (prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('%s%s' % (prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '%s': %s" % (modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: %s\n" % testsdir) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: %s\n" % sdir) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/srv.py b/srv.py new file mode 100755 index 00000000..b28abbda --- /dev/null +++ b/srv.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import email.message +import os +import sys + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, request_info, message): + print('method = {!r}; path = {!r}; version = {!r}'.format( + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +def main(): + host = '127.0.0.1' + port = 8080 + if sys.argv[1:]: + host = sys.argv[1] + if sys.argv[2:]: + port = int(sys.argv[2]) + elif ':' in host: + host, port = host.split(':', 1) + port = int(port) + loop = tulip.get_event_loop() + f = loop.start_serving(lambda: HttpServer(debug=True), host, port) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/sslsrv.py b/sslsrv.py new file mode 100644 index 00000000..a1bc04f9 --- /dev/null +++ b/sslsrv.py @@ -0,0 +1,56 @@ +"""Serve up an SSL connection, after Python ssl module docs.""" + +import socket +import ssl +import os + + +def main(): + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + certfile = getcertfile() + context.load_cert_chain(certfile=certfile, keyfile=certfile) + bindsocket = socket.socket() + bindsocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + bindsocket.bind(('', 4443)) + bindsocket.listen(5) + + while True: + newsocket, fromaddr = bindsocket.accept() + try: + connstream = context.wrap_socket(newsocket, server_side=True) + try: + deal_with_client(connstream) + finally: + connstream.shutdown(socket.SHUT_RDWR) + connstream.close() + except Exception as exc: + print(exc.__class__.__name__ + ':', exc) + + +def getcertfile(): + import test # Test package + testdir = os.path.dirname(test.__file__) + certfile = os.path.join(testdir, 'keycert.pem') + print('certfile =', certfile) + return certfile + + +def deal_with_client(connstream): + data = connstream.recv(1024) + # empty data means the client is finished with us + while data: + if not do_something(connstream, data): + # we'll assume do_something returns False + # when we're finished with client + break + data = connstream.recv(1024) + # finished with client + + +def do_something(connstream, data): + # just echo back + connstream.sendall(data) + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..88f3faf4 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,283 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'xkcd.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..40859211 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1379 @@ +"""Tests for events.py.""" + +import concurrent.futures +import contextlib +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: xkcd.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + @contextlib.contextmanager + def run_test_server(self, *, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.dirname(__file__) + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '302 Found' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server('127.0.0.1', 0, app, + server_class, SilentWSGIRequestHandler) + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_run_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + self.assertTrue(self.event_loop.is_running()) + try: + self.event_loop.run_until_complete( + tasks.sleep(0.1)) + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_nesting(self): + err = None + + @tasks.coroutine + def coro(): + nonlocal err + yield from [] + tasks.sleep(0.1) + try: + self.event_loop.run_once() + except Exception as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro())) + self.assertIsInstance(err, RuntimeError) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.1) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with self.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + address = httpd.socket.getsockname() + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 302 Found', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with self.run_test_server() as httpd: + host, port = httpd.socket.getsockname() + f = tasks.Task( + self.event_loop.create_connection(MyProto, host, port)) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with self.run_test_server(use_ssl=True) as httpsd: + host, port = httpsd.socket.getsockname() + f = self.event_loop.create_connection( + MyProto, host, port, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection( + MyProto, 'xkcd.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'xkcd.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + sock.close() + client.close() + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..5569cca1 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..74aef7c8 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,972 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_request_line()) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + self.stream.read_response_status()) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete(self.stream.read_headers()) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + self.stream.read_headers()) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 1))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 0))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=False)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: %s\r\n' + 'Content-Encoding: deflate\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: %s\r\n\r\n' % + len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, msg.payload.read()) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + def eof(): + yield from [] + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.wait([t1, t2])) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..dc55eff9 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,242 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv.closing) + + srv.close() + self.assertTrue(srv.closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + yield from [] + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + yield from [] + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..7d2111d9 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,747 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.1, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.1)) + tasks.Task(lock.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.1, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.2)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.18 < total_time < 0.22) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.1, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..d86abd7f --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..dc6eeaf4 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,299 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + def set_err(): + yield from [] + stream.set_exception(ValueError()) + + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..09aaed52 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..9ac15bb9 --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,647 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + def coro(): + yield from [] + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + yield from [] + return 42 + + @tasks.task + def inner2(): + yield from [] + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + err = None + + @tasks.coroutine + def coro2(): + nonlocal err + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError as exc: + err = exc + + self.event_loop.run_until_complete(tasks.Task(coro2())) + self.assertIsInstance(err, futures.CancelledError) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + yield from [] + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + self.suppress_log_errors() + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield from [None] + + @tasks.coroutine + def coro2(): + yield from [None, None] + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from [] + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [] + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warn.call_args[0][0].startswith( + '_step(): already done: ')) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield from [None, 1] + return 'ko' + + task = tasks.Task(notmuch()) + task._step() + self.assertFalse(m_logging.warn.called) + + task._step() + self.assertTrue(m_logging.warn.called) + self.assertEqual( + '_step(): bad yield: %r', + m_logging.warn.call_args[0][0]) + self.assertEqual(1, m_logging.warn.call_args[0][1]) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + self.suppress_log_warnings() + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + yield from [c_fut] + return (yield) + + task = tasks.Task(notmuch()) + task._step() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + yield from [] + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_yield_vs_yield_from_generator(self): + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..4b24b50b --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..d7af7ecc --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,573 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..381fb227 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..acec5c24 --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,28 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() + +- Remove test dependency on xkcd.com, write our own test server diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..5ed257a2 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,548 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: %r' % (interval,) + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock) + return sock + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..3a6ad40c --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,356 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import sys +import threading + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..39137aa6 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,255 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..582f0809 --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,12 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..b65b90a8 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,145 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + sts, headers, response = yield from http_client.fetch(url, + method='GET', headers={}, request=b'') + assert isinstance(sts, int) + assert isinstance(headers, dict) + # sort of; case insensitive (what about multiple values for same header?) + headers['status'] == '200 Ok' # or some such + assert isinstance(response, bytes) + +TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). +TODO: How do we do connection keep alive? Pooling? +""" + +__all__ = ['HttpClientProtocol'] + + +import email.message +import email.parser + +import tulip + +from . import protocol + + +class HttpClientProtocol: + """This Protocol class is also used to initiate the connection. + + Usage: + p = HttpClientProtocol(url, ...) + sts, headers, stream = yield from p.connect() + + """ + + def __init__(self, host, port=None, *, + path='/', method='GET', headers=None, ssl=None, + make_body=None, encoding='utf-8', version=(1, 1), + chunked=False): + host = self.validate(host, 'host') + if ':' in host: + assert port is None + host, port_s = host.split(':', 1) + port = int(port_s) + self.host = host + if port is None: + if ssl: + port = 443 + else: + port = 80 + assert isinstance(port, int) + self.port = port + self.path = self.validate(path, 'path') + self.method = self.validate(method, 'method') + self.headers = email.message.Message() + self.headers['Accept-Encoding'] = 'gzip, deflate' + if headers: + for key, value in headers.items(): + self.validate(key, 'header key') + self.validate(value, 'header value', True) + self.headers[key] = value + self.encoding = self.validate(encoding, 'encoding') + self.version = version + self.make_body = make_body + self.chunked = chunked + self.ssl = ssl + if 'content-length' not in self.headers: + if self.make_body is None: + self.headers['Content-Length'] = '0' + else: + self.chunked = True + if self.chunked: + if 'Transfer-Encoding' not in self.headers: + self.headers['Transfer-Encoding'] = 'chunked' + else: + assert self.headers['Transfer-Encoding'].lower() == 'chunked' + if 'host' not in self.headers: + self.headers['Host'] = self.host + self.event_loop = tulip.get_event_loop() + self.transport = None + + def validate(self, value, name, embedded_spaces_okay=False): + # Must be a string. If embedded_spaces_okay is False, no + # whitespace is allowed; otherwise, internal single spaces are + # allowed (but no other whitespace). + assert isinstance(value, str), \ + '{} should be str, not {}'.format(name, type(value)) + parts = value.split() + assert parts, '{} should not be empty'.format(name) + if embedded_spaces_okay: + assert ' '.join(parts) == value, \ + '{} can only contain embedded single spaces ({!r})'.format( + name, value) + else: + assert parts == [value], \ + '{} cannot contain whitespace ({!r})'.format(name, value) + return value + + @tulip.coroutine + def connect(self): + yield from self.event_loop.create_connection( + lambda: self, self.host, self.port, ssl=self.ssl) + + # read response status + version, status, reason = yield from self.stream.read_response_status() + + message = yield from self.stream.read_message(version) + + # headers + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + sts = '{} {}'.format(status, reason) + return (sts, headers, message.payload) + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + self.request = protocol.Request( + transport, self.method, self.path, self.version) + + self.request.add_headers(*self.headers.items()) + self.request.send_headers() + + if self.make_body is not None: + if self.chunked: + self.make_body( + self.request.write, self.request.eof) + else: + self.make_body( + self.request.write, self.request.eof) + else: + self.request.write_eof() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..41344de1 --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: %s' % hdr) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..6a0e1279 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,877 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import email.utils +import functools +import http.server +import itertools +import re +import sys +import zlib + +import tulip +from . import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') +RESPONSES = http.server.BaseHTTPRequestHandler.responses + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None + + def feed_data(self, data): + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: + try: + self._parser.send(data) + except StopIteration as exc: + self._parser = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_parser is a generator, if _parser is set feed_eof throws + StreamEofException into this generator.""" + if self._parser: + try: + self._parser.throw(StreamEofException()) + except StopIteration: + self._parser = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header %s' % line.strip()) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name %s' % name) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise errors.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: + stream.set_exception(exc) + + def _parse_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + + def _parse_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload parser + if chunked: + parser = self._parse_chunked_payload() + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + + parser = self._parse_length_payload(length) + else: + if readall: + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(Exception): + """Internal exception: eof received.""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(errors.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(errors.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '%r is not a string' % name + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('%s\r\n\r\n' % '\r\n'.join( + ('%s: %s' % (k, v) for k, v in + itertools.chain(self._default_headers(), self.headers))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(chunk) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk, buf = buf[:chunk_size], buf[chunk_size:] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', email.utils.formatdate()), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..7590e47b --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + %(status)s %(reason)s + + +

%(status)s %(reason)s

+ %(message)s + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + closing = False + request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self.closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self.request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self.closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
%s
' % tb + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE % { + 'status': status, 'reason': reason, 'message': msg} + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..40247962 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,433 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<%s [%s]>' % ( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:%s' % self._value) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..46ffa58b --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,189 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +from . import base_events +from . import transports +from .log import tulip_log + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + + def loop(f=None): + try: + if f: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock): + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..593ee745 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..ee349e13 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<%s at %s %s>' % ( + type(self).__name__, hex(id(self)), self._format()) + + def __str__(self): + return '<%s %s>' % (type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize=%r' % (self._maxsize, ) + if getattr(self, '_queue', None): + result += ' _queue=%r' % list(self._queue) + if self._getters: + result += ' _getters[%s]' % len(self._getters) + if self._putters: + result += ' _putters[%s]' % len(self._putters) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks=%s' % self._unfinished_tasks + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..20e5db00 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,655 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, + sslcontext, waiter, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock) + + def _accept_connection(self, protocol_factory, sock): + try: + conn, addr = sock.accept() + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, + protocol, sslcontext, waiter, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._protocol.connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._protocol.connection_lost(exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..57be7abe --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warn('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..8d7f6236 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..734a5fa7 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,139 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events +from .log import tulip_log + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + tulip_log.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..81359a27 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,320 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import time + +from . import futures +from .log import tulip_log + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + tulip_log.warning( + 'Coroutine function %s is not a generator.', func.__name__) + + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + def task_wrapper(*args, **kwds): + coro = func(*args, **kwds) + return Task(coro) + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + if self.done(): + tulip_log.warn('_step(): already done: %r, %r, %r', self, value, exc) + return + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from in task %r ' + 'with %r' % (self, result))) + else: + result._blocking = False + result.add_done_callback(self._wakeup) + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, + None, RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task %r with %s' % (self, result))) + else: + if result is not None: + tulip_log.warn('_step(): bad yield: %r', result) + + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + try: + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + finally: + for f in pending: + f.remove_done_callback(_on_completion) + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..9b87db2f --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,30 @@ +"""Utilities shared by tests.""" + +import logging +import socket +import sys +import unittest + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..a9ec07a0 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..3073ab64 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,301 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..2ec8561c --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..bd1e0928 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From e589623c9fbd7ef5245c63949c518fb5b0e6a327 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Thu, 28 Mar 2013 16:36:50 -0400 Subject: [PATCH 0395/1502] Replace deprecated logger.warn with warning --- tests/tasks_test.py | 12 ++++++------ tulip/selectors.py | 2 +- tulip/tasks.py | 5 +++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 9ac15bb9..ab713869 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -449,8 +449,8 @@ def notmuch(): task.set_result('ok') task._step() - self.assertTrue(m_logging.warn.called) - self.assertTrue(m_logging.warn.call_args[0][0].startswith( + self.assertTrue(m_logging.warning.called) + self.assertTrue(m_logging.warning.call_args[0][0].startswith( '_step(): already done: ')) @unittest.mock.patch('tulip.tasks.tulip_log') @@ -462,14 +462,14 @@ def notmuch(): task = tasks.Task(notmuch()) task._step() - self.assertFalse(m_logging.warn.called) + self.assertFalse(m_logging.warning.called) task._step() - self.assertTrue(m_logging.warn.called) + self.assertTrue(m_logging.warning.called) self.assertEqual( '_step(): bad yield: %r', - m_logging.warn.call_args[0][0]) - self.assertEqual(1, m_logging.warn.call_args[0][1]) + m_logging.warning.call_args[0][0]) + self.assertEqual(1, m_logging.warning.call_args[0][1]) def test_step_result_future(self): # If coroutine returns future, task waits on this future. diff --git a/tulip/selectors.py b/tulip/selectors.py index 57be7abe..bd81e554 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -196,7 +196,7 @@ def _key_from_fd(self, fd): try: return self._fd_to_key[fd] except KeyError: - tulip_log.warn('No key found for fd %r', fd) + tulip_log.warning('No key found for fd %r', fd) return None diff --git a/tulip/tasks.py b/tulip/tasks.py index 81359a27..f385f49b 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -102,7 +102,8 @@ def _step_maybe(self): def _step(self, value=None, exc=None): if self.done(): - tulip_log.warn('_step(): already done: %r, %r, %r', self, value, exc) + tulip_log.warning( + '_step(): already done: %r, %r, %r', self, value, exc) return # We'll call either coro.throw(exc) or coro.send(value). if self._must_cancel: @@ -162,7 +163,7 @@ def _step(self, value=None, exc=None): 'generator in task %r with %s' % (self, result))) else: if result is not None: - tulip_log.warn('_step(): bad yield: %r', result) + tulip_log.warning('_step(): bad yield: %r', result) self._event_loop.call_soon(self._step) From f5bc0dd2dc0abad22cea786f754a887e12d0bf0a Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 28 Mar 2013 17:54:31 -0700 Subject: [PATCH 0396/1502] fix proactor_tests --- tests/proactor_events_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 6b801ffa..039aa886 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -169,7 +169,7 @@ def test_close_2(self): self.assertFalse(self.event_loop.call_soon.called) - @unittest.mock.patch('tulip.proactor_events.logging') + @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error(self, m_logging): tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -184,7 +184,7 @@ def test_fatal_error(self, m_logging): tr._call_connection_lost, None) self.assertEqual([], tr._buffer) - @unittest.mock.patch('tulip.proactor_events.logging') + @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error_2(self, m_logging): tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) From c8a09d8dded231afb00fd64da6014fdeffab0f50 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 28 Mar 2013 18:38:25 -0700 Subject: [PATCH 0397/1502] http client refactoring --- crawl.py | 143 +++---- curl.py | 27 +- tests/events_test.py | 19 +- tests/http_client_functional_test.py | 395 ++++++++++++++++++ tests/http_client_test.py | 308 ++++++++++++++ tulip/http/client.py | 596 ++++++++++++++++++++++----- tulip/test_utils.py | 205 ++++++++- 7 files changed, 1455 insertions(+), 238 deletions(-) create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py diff --git a/crawl.py b/crawl.py index 4e5bebe2..723d5305 100755 --- a/crawl.py +++ b/crawl.py @@ -3,125 +3,79 @@ import logging import re import signal -import socket import sys import urllib.parse import tulip import tulip.http -END = '\n' -MAXTASKS = 100 - class Crawler: - def __init__(self, rooturl): + def __init__(self, rooturl, loop, maxtasks=100): self.rooturl = rooturl + self.loop = loop self.todo = set() self.busy = set() self.done = {} self.tasks = set() - self.waiter = None - self.addurl(self.rooturl, '') # Set initial work. - self.run() # Kick off work. - - def addurl(self, url, parenturl): - url = urllib.parse.urljoin(parenturl, url) - url, frag = urllib.parse.urldefrag(url) - if not url.startswith(self.rooturl): - return False - if url in self.busy or url in self.done or url in self.todo: - return False - self.todo.add(url) - waiter = self.waiter - if waiter is not None: - self.waiter = None - waiter.set_result(None) - return True + self.sem = tulip.Semaphore(maxtasks) @tulip.task def run(self): - while self.todo or self.busy or self.tasks: - complete, self.tasks = yield from tulip.wait(self.tasks, timeout=0) - print(len(complete), 'completed tasks,', len(self.tasks), - 'still pending ', end=END) - for task in complete: - try: - yield from task - except Exception as exc: - print('Exception in task:', exc, end=END) - while self.todo and len(self.tasks) < MAXTASKS: - url = self.todo.pop() - self.busy.add(url) - self.tasks.add(self.process(url)) # Async task. - if self.busy: - self.waiter = tulip.Future() - yield from self.waiter - tulip.get_event_loop().stop() + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) @tulip.task def process(self, url): - ok = False - p = None + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) try: - print('processing', url, end=END) - scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) - if not path: - path = '/' - if query: - path = '?'.join([path, query]) - p = tulip.http.HttpClientProtocol( - netloc, path=path, ssl=(scheme=='https')) - delay = 1 - while True: - try: - status, headers, stream = yield from p.connect() - break - except socket.error as exc: - if delay >= 60: - raise - print('...', url, 'has error', repr(str(exc)), - 'retrying after sleep', delay, '...', end=END) - yield from tulip.sleep(delay) - delay *= 2 - - if status[:3] in ('301', '302'): - # Redirect. - u = headers.get('location') or headers.get('uri') - if self.addurl(u, url): - print(' ', url, status[:3], 'redirect to', u, end=END) - elif status.startswith('200'): - ctype = headers.get_content_type() - if ctype == 'text/html': - while True: - line = yield from stream.readline() - if not line: - break - line = line.decode('utf-8', 'replace') - urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', - line) - for u in urls: - if self.addurl(u, url): - print(' ', url, 'href to', u, end=END) - ok = True - finally: - if p is not None: - p.transport.close() - self.done[url] = ok - self.busy.remove(url) - if not ok: - print('failure for', url, sys.exc_info(), end=END) - waiter = self.waiter - if waiter is not None: - self.waiter = None - waiter.set_result(None) + resp = yield from tulip.http.request('get', url) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) def main(): - rooturl = sys.argv[1] - c = Crawler(rooturl) loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + try: loop.add_signal_handler(signal.SIGINT, loop.stop) except RuntimeError: @@ -140,4 +94,5 @@ def main(): logging.info('using iocp') el = windows_events.ProactorEventLoop() events.set_event_loop(el) + main() diff --git a/curl.py b/curl.py index 37fce75c..7063adcd 100755 --- a/curl.py +++ b/curl.py @@ -1,28 +1,15 @@ #!/usr/bin/env python3 import sys -import urllib.parse - import tulip import tulip.http -def main(): - url = sys.argv[1] - scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) - if not path: - path = '/' - if query: - path = '?'.join([path, query]) - print(netloc, path, scheme) - p = tulip.http.HttpClientProtocol(netloc, path=path, ssl=(scheme=='https')) - f = p.connect() - sts, headers, stream = p.event_loop.run_until_complete(f) - print(sts) - for k, v in headers.items(): - print('{}: {}'.format(k, v)) - print() - data = p.event_loop.run_until_complete(stream.read()) +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() print(data.decode('utf-8', 'replace')) @@ -32,4 +19,6 @@ def main(): sys.argv.remove('--iocp') el = windows_events.ProactorEventLoop() events.set_event_loop(el) - main() + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/tests/events_test.py b/tests/events_test.py index 956e861d..58a9c3d9 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -426,11 +426,11 @@ def sender(): def test_sock_client_ops(self): self.suppress_log_errors() - with test_utils.run_test_server(self.event_loop) as addr: + with test_utils.run_test_server(self.event_loop) as httpd: sock = socket.socket() sock.setblocking(False) self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, addr)) + self.event_loop.sock_connect(sock, httpd.address)) self.event_loop.run_until_complete( self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) data = self.event_loop.run_until_complete( @@ -568,8 +568,8 @@ def my_handler(*args): self.assertEqual(caught, 1) def test_create_connection(self): - with test_utils.run_test_server(self.event_loop) as addr: - f = self.event_loop.create_connection(MyProto, *addr) + with test_utils.run_test_server(self.event_loop) as httpd: + f = self.event_loop.create_connection(MyProto, *httpd.address) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -577,10 +577,11 @@ def test_create_connection(self): self.assertTrue(pr.nbytes > 0) def test_create_connection_sock(self): - with test_utils.run_test_server(self.event_loop) as addr: + with test_utils.run_test_server(self.event_loop) as httpd: sock = None infos = self.event_loop.run_until_complete( - self.event_loop.getaddrinfo(*addr, type=socket.SOCK_STREAM)) + self.event_loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) for family, type, proto, cname, address in infos: try: sock = socket.socket(family=family, type=type, proto=proto) @@ -603,8 +604,10 @@ def test_create_connection_sock(self): @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): - with test_utils.run_test_server(self.event_loop, use_ssl=True) as addr: - f = self.event_loop.create_connection(MyProto, *addr, ssl=True) + with test_utils.run_test_server( + self.event_loop, use_ssl=True) as httpd: + f = self.event_loop.create_connection( + MyProto, *httpd.address, ssl=True) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..99c758ab --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,395 @@ +"""Http client functional tests.""" + +import io +import os.path +import http.cookies + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth))) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2))) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'))) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'})) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate')) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data])) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'))) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'))) + self.assertEqual(r.status, 200) + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', content) + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'))) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), timeout=0.1)) + + def test_request_conn_error(self): + self.assertRaises( + ConnectionRefusedError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..0aa9d0bd --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpProtocol, HttpRequest, HttpResponse + + +class HttpProtocolTests(unittest.TestCase): + + def test_protocol(self): + transport = unittest.mock.Mock() + + p = HttpProtocol() + p.connection_made(transport) + self.assertIs(p.transport, transport) + self.assertIsInstance(p.stream, tulip.http.HttpStreamReader) + + p.data_received(b'data') + self.assertEqual(4, p.stream.byte_count) + + p.eof_received() + self.assertTrue(p.stream.eof) + + p.connection_lost(None) + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.http.HttpStreamReader(self.transport) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response._transport = self.transport + self.response.close() + self.assertIsNone(self.response._transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.http.HttpStreamReader(self.transport) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual(0, req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual(0, req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'føø'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'føø': 'føø'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('ø'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tulip/http/client.py b/tulip/http/client.py index b65b90a8..0fff6e86 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -2,144 +2,532 @@ Most basic usage: - sts, headers, response = yield from http_client.fetch(url, - method='GET', headers={}, request=b'') - assert isinstance(sts, int) - assert isinstance(headers, dict) - # sort of; case insensitive (what about multiple values for same header?) - headers['status'] == '200 Ok' # or some such - assert isinstance(response, bytes) - -TODO: Reuse email.Message class (or its subclass, http.client.HTTPMessage). -TODO: How do we do connection keep alive? Pooling? -""" + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 -__all__ = ['HttpClientProtocol'] + content = yield from response.content.read() +""" +__all__ = ['request'] +import base64 import email.message -import email.parser +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse import tulip +from tulip.http import protocol -from . import protocol +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None): + """Constructs and sends a request. Returns response object. -class HttpClientProtocol: - """This Protocol class is also used to initiate the connection. + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. Usage: - p = HttpClientProtocol(url, ...) - sts, headers, stream = yield from p.connect() + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() """ + redirects = 0 + loop = tulip.get_event_loop() - def __init__(self, host, port=None, *, - path='/', method='GET', headers=None, ssl=None, - make_body=None, encoding='utf-8', version=(1, 1), - chunked=False): - host = self.validate(host, 'host') - if ':' in host: - assert port is None - host, port_s = host.split(':', 1) - port = int(port_s) - self.host = host - if port is None: + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + # connection timeout + try: + resp = yield from tulip.Task(start(req, loop), timeout=timeout) + except tulip.CancelledError: + raise tulip.TimeoutError from None + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + HttpProtocol, req.host, req.port, ssl=req.ssl) + try: + resp = req.send(transport) + yield from resp.start(p.stream, transport) + except: + transport.close() + raise + + return resp + + +class HttpProtocol(tulip.Protocol): + + stream = None + transport = None + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except: + raise ValueError( + 'Port number could not be converted.') from None + else: if ssl: - port = 443 + port = http.client.HTTPS_PORT else: - port = 80 - assert isinstance(port, int) + port = http.client.HTTP_PORT + + self.host = netloc self.port = port - self.path = self.validate(path, 'path') - self.method = self.validate(method, 'method') - self.headers = email.message.Message() - self.headers['Accept-Encoding'] = 'gzip, deflate' - if headers: - for key, value in headers.items(): - self.validate(key, 'header key') - self.validate(value, 'header value', True) - self.headers[key] = value - self.encoding = self.validate(encoding, 'encoding') - self.version = version - self.make_body = make_body - self.chunked = chunked self.ssl = ssl - if 'content-length' not in self.headers: - if self.make_body is None: - self.headers['Content-Length'] = '0' - else: - self.chunked = True - if self.chunked: - if 'Transfer-Encoding' not in self.headers: - self.headers['Transfer-Encoding'] = 'chunked' + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) else: - assert self.headers['Transfer-Encoding'].lower() == 'chunked' + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host if 'host' not in self.headers: self.headers['Host'] = self.host - self.event_loop = tulip.get_event_loop() - self.transport = None - - def validate(self, value, name, embedded_spaces_okay=False): - # Must be a string. If embedded_spaces_okay is False, no - # whitespace is allowed; otherwise, internal single spaces are - # allowed (but no other whitespace). - assert isinstance(value, str), \ - '{} should be str, not {}'.format(name, type(value)) - parts = value.split() - assert parts, '{} should not be empty'.format(name) - if embedded_spaces_okay: - assert ' '.join(parts) == value, \ - '{} can only contain embedded single spaces ({!r})'.format( - name, value) + + # cookies + if cookies: + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + for name, value in cookies.items(): + if isinstance(value, http.cookies.Morsel): + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + self._params = (chunked, compress, files, data, encoding) + + def send(self, transport): + chunked, compress, files, data, encoding = self._params + + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + request.add_compression_filter(enc) + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + request.add_compression_filter(compress) + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = len(self.body) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['Transfer-encoding'] = 'chunked' + + chunk_size = chunked if type(chunked) is int else 8196 + request.add_chunking_filter(chunk_size) else: - assert parts == [value], \ - '{} cannot contain whitespace ({!r})'.format(name, value) - return value + if 'chunked' in te: + request.add_chunking_filter(8196) + else: + chunked = False + self.headers['content-length'] = len(self.body) - @tulip.coroutine - def connect(self): - yield from self.event_loop.create_connection( - lambda: self, self.host, self.port, ssl=self.ssl) + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + content = None # payload stream + + _content = None + _transport = None + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self._transport = transport - # read response status - version, status, reason = yield from self.stream.read_response_status() + # read status + self.version, self.status, self.reason = ( + yield from stream.read_response_status()) - message = yield from self.stream.read_message(version) + # does the body have a fixed length? (of zero) + length = None + if (self.status == http.client.NO_CONTENT or + self.status == http.client.NOT_MODIFIED or + 100 <= self.status < 200 or self.method == "HEAD"): + length = 0 + + # http message + message = yield from stream.read_message(length=length) # headers - headers = email.message.Message() for hdr, val in message.headers: - headers.add_header(hdr, val) + self.add_header(hdr, val) - sts = '{} {}'.format(status, reason) - return (sts, headers, message.payload) + # payload + self.content = message.payload - def connection_made(self, transport): - self.transport = transport - self.stream = protocol.HttpStreamReader() + return self - self.request = protocol.Request( - transport, self.method, self.path, self.version) + def close(self): + if self._transport is not None: + self._transport.close() + self._transport = None - self.request.add_headers(*self.headers.items()) - self.request.send_headers() + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + self._content = yield from self.content.read() + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' - if self.make_body is not None: - if self.chunked: - self.make_body( - self.request.write, self.request.eof) - else: - self.make_body( - self.request.write, self.request.eof) else: - self.request.write_eof() + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') - def data_received(self, data): - self.stream.feed_data(data) + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) - def eof_received(self): - self.stream.feed_eof() + if isinstance(fp, str): + fp = fp.encode(encoding) - def connection_lost(self, exc): - pass + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/test_utils.py b/tulip/test_utils.py index ac0b5f59..d6219143 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,11 +1,19 @@ """Utilities shared by tests.""" +import cgi import contextlib +import email.parser +import http.server +import json import logging +import io import os +import re import socket import sys import threading +import traceback +import urllib.parse import unittest try: import ssl @@ -14,6 +22,7 @@ import tulip import tulip.http +from tulip.http import client if sys.platform == 'win32': # pragma: no cover @@ -41,18 +50,53 @@ def suppress_log_warnings(self): # pragma: no cover @contextlib.contextmanager -def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False): - class HttpServer(tulip.http.ServerHttpProtocol): +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + def handle_request(self, info, message): - response = tulip.http.Response( - self.transport, 200, info.version) - - text = b'Test message' - response.add_header('Content-type', 'text/plain') - response.add_header('Content-length', str(len(text))) - response.send_headers() - response.write(text) - response.write_eof() + if properties.get('noresponse', False): + return + + if router is not None: + payload = io.BytesIO((yield from message.payload.read())) + rob = router( + properties, self.transport, + info, message.headers, payload, message.compression) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, info.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + self.transport.close() if use_ssl: @@ -69,7 +113,8 @@ def run(loop, fut): tulip.set_event_loop(thread_loop) sock = thread_loop.run_until_complete( - thread_loop.start_serving(HttpServer, host, port, ssl=sslcontext)) + thread_loop.start_serving( + TestHttpServer, host, port, ssl=sslcontext)) waiter = tulip.Future() loop.call_soon_threadsafe( @@ -84,7 +129,141 @@ def run(loop, fut): thread_loop, waiter, addr = loop.run_until_complete(fut) try: - yield addr + yield HttpServer(*addr) finally: thread_loop.call_soon_threadsafe(waiter.set_result, None) server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, props, transport, rline, headers, body, cmode): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in headers: + self._headers.add_header(hdr, val) + + self._props = props + self._transport = transport + self._method = rline.method + self._uri = rline.uri + self._version = rline.version + self._compression = cmode + self._body = body.read() + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() From 224be9a6e6947d6f36b0ff36623bf3ddecbd41da Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Fri, 29 Mar 2013 13:26:31 -0400 Subject: [PATCH 0398/1502] Replace % with format() everywhere. --- check.py | 2 +- examples/udp_echo.py | 4 ++-- runtests.py | 14 +++++++------- tests/http_protocol_test.py | 19 ++++++++++--------- tulip/base_events.py | 4 ++-- tulip/http/errors.py | 2 +- tulip/http/protocol.py | 12 ++++++------ tulip/http/server.py | 12 ++++++------ tulip/locks.py | 9 +++++---- tulip/queues.py | 16 ++++++++-------- tulip/tasks.py | 7 ++++--- 11 files changed, 52 insertions(+), 49 deletions(-) diff --git a/check.py b/check.py index 64bc2cdd..d28b31f7 100644 --- a/check.py +++ b/check.py @@ -33,7 +33,7 @@ def process(fn): line = line.rstrip('\n') sline = line.rstrip() if len(line) > 80 or line != sline or not isascii(line): - print('%s:%d:%s%s' % ( + print('{}:{:d}:{}{}'.format( fn, i+1, sline, '_' * (len(line) - len(sline)))) finally: f.close() diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 9e995d14..5d1e02ec 100644 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -35,12 +35,12 @@ class MyClientUdpEchoProtocol: def connection_made(self, transport): self.transport = transport - print('sending "%s"' % self.message) + print('sending "{}"'.format(self.message)) self.transport.sendto(self.message.encode()) print('waiting to receive') def datagram_received(self, data, addr): - print('received "%s"' % data.decode()) + print('received "{}"'.format(data.decode())) self.transport.close() def connection_refused(self, exc): diff --git a/runtests.py b/runtests.py index 096a2561..0ec5ba31 100644 --- a/runtests.py +++ b/runtests.py @@ -64,20 +64,20 @@ def list_dir(prefix, dir): modpath = os.path.join(dir, '__init__.py') if os.path.isfile(modpath): mod = os.path.split(dir)[-1] - files.append(('%s%s' % (prefix, mod), modpath)) + files.append(('{}{}'.format(prefix, mod), modpath)) - prefix = '%s%s.' % (prefix, mod) + prefix = '{}{}.'.format(prefix, mod) for name in os.listdir(dir): path = os.path.join(dir, name) if os.path.isdir(path): - files.extend(list_dir('%s%s.' % (prefix, name), path)) + files.extend(list_dir('{}{}.'.format(prefix, name), path)) else: if (name != '__init__.py' and name.endswith(suffix) and not name.startswith(('.', '_'))): - files.append(('%s%s' % (prefix, name[:-3]), path)) + files.append(('{}{}'.format(prefix, name[:-3]), path)) return files @@ -89,7 +89,7 @@ def list_dir(prefix, dir): loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) except Exception as err: - print("Skipping '%s': %s" % (modname, err), file=sys.stderr) + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) return mods @@ -125,7 +125,7 @@ def runtests(): testsdir = os.path.abspath(args.testsdir) if not os.path.isdir(testsdir): - print("Tests directory is not found: %s\n" % testsdir) + print("Tests directory is not found: {}\n".format(testsdir)) ARGS.print_help() return @@ -173,7 +173,7 @@ def runcoverage(sdir, args): sdir = os.path.abspath(sdir) if not os.path.isdir(sdir): - print("Python files directory is not found: %s\n" % sdir) + print("Python files directory is not found: {}\n".format(sdir)) ARGS.print_help() return diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 74aef7c8..6fd0b64c 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -264,9 +264,9 @@ def test_read_message_content_length_no_val(self): def test_read_message_deflate(self): self.stream.feed_data( - ('Host: example.com\r\nContent-Length: %s\r\n' - 'Content-Encoding: deflate\r\n\r\n' % - len(self._COMPRESSED)).encode()) + ('Host: example.com\r\nContent-Length: {}\r\n' + 'Content-Encoding: deflate\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) self.stream.feed_data(self._COMPRESSED) msg = self.loop.run_until_complete(self.stream.read_message()) @@ -276,8 +276,8 @@ def test_read_message_deflate(self): def test_read_message_deflate_disabled(self): self.stream.feed_data( ('Host: example.com\r\nContent-Encoding: deflate\r\n' - 'Content-Length: %s\r\n\r\n' % - len(self._COMPRESSED)).encode()) + 'Content-Length: {}\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) self.stream.feed_data(self._COMPRESSED) msg = self.loop.run_until_complete( @@ -288,7 +288,8 @@ def test_read_message_deflate_disabled(self): def test_read_message_deflate_unknown(self): self.stream.feed_data( ('Host: example.com\r\nContent-Encoding: compress\r\n' - 'Content-Length: %s\r\n\r\n' % len(self._COMPRESSED)).encode()) + 'Content-Length: {}\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) self.stream.feed_data(self._COMPRESSED) msg = self.loop.run_until_complete( @@ -460,7 +461,7 @@ def test_read_message_deflate_payload(self): self.stream.feed_data( b'Host: example.com\r\n' b'Content-Encoding: deflate\r\n' + - ('Content-Length: %s\r\n\r\n' % len(data)).encode()) + ('Content-Length: {}\r\n\r\n'.format(len(data)).encode())) msg = self.loop.run_until_complete( self.stream.read_message(readall=True)) @@ -928,7 +929,7 @@ def test_write_payload_chunked_large_chunk(self): def test_write_payload_deflate_filter(self): write = self.transport.write = unittest.mock.Mock() msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) msg.send_headers() msg.add_compression_filter('deflate') @@ -958,7 +959,7 @@ def test_write_payload_deflate_and_chunked(self): def test_write_payload_chunked_and_deflate(self): write = self.transport.write = unittest.mock.Mock() msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-length', '%s' % len(self._COMPRESSED))) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) msg.add_chunking_filter(2) msg.add_compression_filter('deflate') diff --git a/tulip/base_events.py b/tulip/base_events.py index 573807d8..127530f9 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -221,7 +221,7 @@ def call_later(self, delay, callback, *args): def call_repeatedly(self, interval, callback, *args): """Call a callback every 'interval' seconds.""" - assert interval > 0, 'Interval must be > 0: %r' % (interval,) + assert interval > 0, 'Interval must be > 0: {!r}'.format(interval) # TODO: What if callback is already a Handle? def wrapper(): callback(*args) # If this fails, the chain is broken. @@ -517,7 +517,7 @@ def _run_once(self, timeout=None): t0 = time.monotonic() event_list = self._selector.select(timeout) t1 = time.monotonic() - argstr = '' if timeout is None else ' %.3f' % timeout + argstr = '' if timeout is None else '{:.3f}'.format(timeout) if t1-t0 >= 1: level = logging.INFO else: diff --git a/tulip/http/errors.py b/tulip/http/errors.py index 41344de1..24032337 100644 --- a/tulip/http/errors.py +++ b/tulip/http/errors.py @@ -40,5 +40,5 @@ class LineTooLong(BadRequestException, http.client.LineTooLong): class InvalidHeader(BadRequestException): def __init__(self, hdr): - super().__init__('Invalid HTTP Header: %s' % hdr) + super().__init__('Invalid HTTP Header: {}'.format(hdr)) self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 6a0e1279..339b68e9 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -176,12 +176,12 @@ def read_headers(self): # Parse initial header name : value pair. sep_pos = line.find(b':') if sep_pos < 0: - raise ValueError('Invalid header %s' % line.strip()) + raise ValueError('Invalid header {}'.format(line.strip())) name, value = line[:sep_pos], line[sep_pos+1:] name = name.rstrip(b' \t').upper() if HDRRE.search(name): - raise ValueError('Invalid header name %s' % name) + raise ValueError('Invalid header name {}'.format(name)) name = name.strip().decode('ascii', 'surrogateescape') value = [value.lstrip()] @@ -629,7 +629,7 @@ def add_header(self, name, value): """Analyze headers. Calculate content length, removes hop headers, etc.""" assert not self.headers_sent, 'headers have been sent already' - assert isinstance(name, str), '%r is not a string' % name + assert isinstance(name, str), '{!r} is not a string'.format(name) name = name.strip().upper() @@ -692,9 +692,9 @@ def send_headers(self): # send headers self.transport.write( - ('%s\r\n\r\n' % '\r\n'.join( - ('%s: %s' % (k, v) for k, v in - itertools.chain(self._default_headers(), self.headers))) + ('{}\r\n\r\n'.format('\r\n'.join( + ('{}: {}'.format(k, v) for k, v in + itertools.chain(self._default_headers(), self.headers)))) ).encode('ascii')) def _default_headers(self): diff --git a/tulip/http/server.py b/tulip/http/server.py index 7590e47b..5d68cfb5 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -16,11 +16,11 @@ DEFAULT_ERROR_MESSAGE = """ - %(status)s %(reason)s + {status} {reason} -

%(status)s %(reason)s

- %(message)s +

{status} {reason}

+ {message} """ @@ -129,14 +129,14 @@ def handle_error(self, status=500, info=None, if self.debug and exc is not None: try: tb = traceback.format_exc() - msg += '

Traceback:

\n
%s
' % tb + msg += '

Traceback:

\n
{}
'.format(tb) except: pass self.log_access(status, info, message) - html = DEFAULT_ERROR_MESSAGE % { - 'status': status, 'reason': reason, 'message': msg} + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) response = tulip.http.Response(self.transport, status, close=True) response.add_headers( diff --git a/tulip/locks.py b/tulip/locks.py index 40247962..ff841442 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -69,7 +69,7 @@ def __init__(self): def __repr__(self): res = super().__repr__() - return '<%s [%s]>' % ( + return '<{} [{}]>'.format( res[1:-1], 'locked' if self._locked else 'unlocked') def locked(self): @@ -157,7 +157,7 @@ def __init__(self): def __repr__(self): res = super().__repr__() - return '<%s [%s]>' % (res[1:-1], 'set' if self._value else 'unset') + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') def is_set(self): """Return true if and only if the internal flag is true.""" @@ -358,9 +358,10 @@ def __init__(self, value=1, bound=False): def __repr__(self): res = super().__repr__() - return '<%s [%s]>' % ( + return '<{} [{}]>'.format( res[1:-1], - 'locked' if self._locked else 'unlocked,value:%s' % self._value) + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) def locked(self): """Returns True if semaphore can not be acquired immediately.""" diff --git a/tulip/queues.py b/tulip/queues.py index ee349e13..a87a8557 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -45,20 +45,20 @@ def _put(self, item): self._queue.append(item) def __repr__(self): - return '<%s at %s %s>' % ( - type(self).__name__, hex(id(self)), self._format()) + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) def __str__(self): - return '<%s %s>' % (type(self).__name__, self._format()) + return '<{} {}>'.format(type(self).__name__, self._format()) def _format(self): - result = 'maxsize=%r' % (self._maxsize, ) + result = 'maxsize={!r}'.format(self._maxsize) if getattr(self, '_queue', None): - result += ' _queue=%r' % list(self._queue) + result += ' _queue={!r}'.format(list(self._queue)) if self._getters: - result += ' _getters[%s]' % len(self._getters) + result += ' _getters[{}]'.format(len(self._getters)) if self._putters: - result += ' _putters[%s]' % len(self._putters) + result += ' _putters[{}]'.format(len(self._putters)) return result def _consume_done_getters(self, waiters): @@ -250,7 +250,7 @@ def __init__(self, maxsize=0): def _format(self): result = Queue._format(self) if self._unfinished_tasks: - result += ' tasks=%s' % self._unfinished_tasks + result += ' tasks={}'.format(self._unfinished_tasks) return result def _put(self, item): diff --git a/tulip/tasks.py b/tulip/tasks.py index 41093564..5d3d289a 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -140,8 +140,8 @@ def _step(self, value=None, exc=None): self._event_loop.call_soon( self._step, None, RuntimeError( - 'yield was used instead of yield from in task %r ' - 'with %r' % (self, result))) + 'yield was used instead of yield from in task {!r} ' + 'with {!r}'.format(self, result))) else: result._blocking = False result.add_done_callback(self._wakeup) @@ -159,7 +159,8 @@ def _step(self, value=None, exc=None): self._step, None, RuntimeError( 'yield was used instead of yield from for ' - 'generator in task %r with %s' % (self, result))) + 'generator in task {!r} with {}'.format( + self, result))) else: if result is not None: logging.warn('_step(): bad yield: %r', result) From eb771b087710f9e8f68baae77098947384d4d9d1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 3 Apr 2013 22:23:33 -0700 Subject: [PATCH 0399/1502] Sorry, I do not want non-ASCII characters in source code. --- tests/http_client_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/http_client_test.py b/tests/http_client_test.py index 0aa9d0bd..973a6cb0 100644 --- a/tests/http_client_test.py +++ b/tests/http_client_test.py @@ -215,13 +215,13 @@ def join(*suffix): return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) url = 'http://python.org' - req = HttpRequest('get', url, params={'foo': 'føø'}) + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) - req = HttpRequest('', url, params={'føø': 'føø'}) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) req = HttpRequest('', url, params={'foo': 'foo'}) self.assertEqual('/?foo=foo', req.path) - req = HttpRequest('', join('ø'), params={'foo': 'foo'}) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) self.assertEqual('/%C3%B8?foo=foo', req.path) def test_query_multivalued_param(self): From a27d678bb508d7e4e6eaac75b56cc51a4409b4e9 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 4 Apr 2013 04:16:00 -0700 Subject: [PATCH 0400/1502] race condition in tasks.wait() --- tests/tasks_test.py | 1 - tulip/tasks.py | 17 ++++++++--------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index ab713869..3b515179 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -254,7 +254,6 @@ def test_wait_first_completed(self): self.assertEqual({a}, pending) def test_wait_really_done(self): - self.suppress_log_errors() # there is possibility that some tasks in the pending list # became done but their callbacks haven't all been called yet diff --git a/tulip/tasks.py b/tulip/tasks.py index 16b50000..fb8a0306 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -236,16 +236,15 @@ def _on_completion(f): f.exception() is not None)): bail.cancel() + for f in pending: + f.remove_done_callback(_on_completion) + + for f in pending: + f.add_done_callback(_on_completion) try: - for f in pending: - f.add_done_callback(_on_completion) - try: - yield from bail - except futures.CancelledError: - pass - finally: - for f in pending: - f.remove_done_callback(_on_completion) + yield from bail + except futures.CancelledError: + pass really_done = set(f for f in pending if f.done()) if really_done: From b9d849371d7051fc609320d51bafaa0679ffdfd6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 4 Apr 2013 17:21:51 -0700 Subject: [PATCH 0401/1502] Task.cancel() race condition --- tests/tasks_test.py | 51 ++++++++++++++++++++++++++------------------- tulip/tasks.py | 40 ++++++++++++++++++++++------------- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 3b515179..097319d1 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -437,38 +437,47 @@ def doit(): t1 = time.monotonic() self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + def test_task_cancel_waiter_future(self): + fut = futures.Future() + + @tasks.task + def coro(): + try: + yield from fut + except futures.CancelledError: + pass + + task = coro() + self.event_loop.run_once() + self.assertIs(task._fut_waiter, fut) + + task.cancel() + self.assertRaises( + futures.CancelledError, self.event_loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + @unittest.mock.patch('tulip.tasks.tulip_log') def test_step_in_completed_task(self, m_logging): @tasks.coroutine def notmuch(): - yield from [] return 'ko' task = tasks.Task(notmuch()) task.set_result('ok') - task._step() - self.assertTrue(m_logging.warning.called) - self.assertTrue(m_logging.warning.call_args[0][0].startswith( - '_step(): already done: ')) + self.assertRaises(AssertionError, task._step) @unittest.mock.patch('tulip.tasks.tulip_log') def test_step_result(self, m_logging): @tasks.coroutine def notmuch(): - yield from [None, 1] + yield None + yield 1 return 'ko' - task = tasks.Task(notmuch()) - task._step() - self.assertFalse(m_logging.warning.called) - - task._step() - self.assertTrue(m_logging.warning.called) - self.assertEqual( - '_step(): bad yield: %r', - m_logging.warning.call_args[0][0]) - self.assertEqual(1, m_logging.warning.call_args[0][1]) + self.assertRaises( + RuntimeError, self.event_loop.run_until_complete, notmuch()) def test_step_result_future(self): # If coroutine returns future, task waits on this future. @@ -502,7 +511,6 @@ def wait_for_future(): def test_step_result_concurrent_future(self): # Coroutine returns concurrent.futures.Future - self.suppress_log_warnings() class Fut(concurrent.futures.Future): def __init__(self): @@ -517,16 +525,15 @@ def add_done_callback(self, fn): @tasks.coroutine def notmuch(): - yield from [c_fut] - return (yield) + return (yield c_fut) task = tasks.Task(notmuch()) - task._step() + self.event_loop.run_once() self.assertTrue(c_fut.cb_added) res = object() c_fut.set_result(res) - self.event_loop.run() + self.event_loop.run_once() self.assertIs(res, task.result()) def test_step_with_baseexception(self): @@ -584,6 +591,7 @@ def fn2(): self.assertTrue(tasks.iscoroutinefunction(fn2)) def test_yield_vs_yield_from(self): + self.suppress_log_errors() fut = futures.Future() @tasks.task @@ -596,6 +604,7 @@ def wait_for_future(): self.event_loop.run_until_complete, task) def test_yield_vs_yield_from_generator(self): + self.suppress_log_errors() fut = futures.Future() @tasks.coroutine diff --git a/tulip/tasks.py b/tulip/tasks.py index fb8a0306..52988134 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -25,7 +25,7 @@ def coroutine(func): if inspect.isgeneratorfunction(func): coro = func else: - tulip_log.warning( + tulip_log.debug( 'Coroutine function %s is not a generator.', func.__name__) @functools.wraps(func) @@ -69,6 +69,7 @@ def __init__(self, coro, event_loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__(event_loop=event_loop, timeout=timeout) self._coro = coro + self._fut_waiter = None self._must_cancel = False self._event_loop.call_soon(self._step) @@ -89,7 +90,11 @@ def cancel(self): return False self._must_cancel = True # _step() will call super().cancel() to call the callbacks. - self._event_loop.call_soon(self._step_maybe) + if self._fut_waiter is not None: + assert not self._fut_waiter.done(), 'Assume it is a race condition.' + self._fut_waiter.cancel() + else: + self._event_loop.call_soon(self._step_maybe) return True def cancelled(self): @@ -101,10 +106,11 @@ def _step_maybe(self): return self._step() def _step(self, value=None, exc=None): - if self.done(): - tulip_log.warning( - '_step(): already done: %r, %r, %r', self, value, exc) - return + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + + self._fut_waiter = None + # We'll call either coro.throw(exc) or coro.send(value). if self._must_cancel: exc = futures.CancelledError @@ -139,13 +145,14 @@ def _step(self, value=None, exc=None): if isinstance(result, futures.Future): if not result._blocking: self._event_loop.call_soon( - self._step, - None, RuntimeError( - 'yield was used instead of yield from in task {!r} ' - 'with {!r}'.format(self, result))) + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) else: result._blocking = False result.add_done_callback(self._wakeup) + self._fut_waiter = result elif isinstance(result, concurrent.futures.Future): # This ought to be more efficient than wrap_future(), @@ -157,16 +164,19 @@ def _step(self, value=None, exc=None): else: if inspect.isgenerator(result): self._event_loop.call_soon( - self._step, - None, RuntimeError( + self._step, None, + RuntimeError( 'yield was used instead of yield from for ' 'generator in task {!r} with {}'.format( self, result))) else: if result is not None: - tulip_log.warning('_step(): bad yield: %r', result) - - self._event_loop.call_soon(self._step) + self._event_loop.call_soon( + self._step, None, + RuntimeError( + 'Task received bad yield: {!r}'.format(result))) + else: + self._event_loop.call_soon(self._step) def _wakeup(self, future): try: From 9d049e8bf6622d89c133014fe6a2d560e6dd807f Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 5 Apr 2013 20:19:48 -0700 Subject: [PATCH 0402/1502] allow to use @task decorator for non generator functions --- tests/base_events_test.py | 1 + tests/events_test.py | 13 ++++++++-- tests/http_protocol_test.py | 11 ++++++++- tests/http_server_test.py | 2 -- tests/locks_test.py | 2 ++ tests/streams_test.py | 3 ++- tests/tasks_test.py | 49 +++++++++++++++++++++++++------------ tulip/tasks.py | 18 ++++++++------ 8 files changed, 71 insertions(+), 28 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 21b7a4d2..10f7c480 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -258,6 +258,7 @@ def test_create_connection_mutiple_errors(self, m_socket): class MyProto(protocols.Protocol): pass + @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80)), diff --git a/tests/events_test.py b/tests/events_test.py index 58a9c3d9..80f4ac35 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -136,7 +136,6 @@ def test_run_nesting(self): @tasks.coroutine def coro(): - yield from [] self.assertTrue(self.event_loop.is_running()) self.event_loop.run_until_complete(tasks.sleep(0.1)) @@ -149,7 +148,6 @@ def test_run_once_nesting(self): @tasks.coroutine def coro(): - yield from [] tasks.sleep(0.1) self.event_loop.run_once() @@ -640,6 +638,7 @@ def test_create_connection_no_getaddrinfo(self): def test_create_connection_connect_err(self): self.suppress_log_errors() + @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] @@ -654,6 +653,7 @@ def getaddrinfo(*args, **kw): def test_create_connection_mutiple_errors(self): self.suppress_log_errors() + @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80)), @@ -1188,6 +1188,15 @@ def callback(*args): self.assertRaises( AssertionError, events.make_handle, h1, (1, 2)) + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h.run() + self.assertTrue(log.exception.called) + class TimerTests(unittest.TestCase): diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 6fd0b64c..c806337c 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -374,6 +374,7 @@ def test_read_message_length_payload_incomplete(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data(b'data') self.stream.feed_eof() @@ -389,6 +390,7 @@ def test_read_message_eof_payload(self): msg = self.loop.run_until_complete( self.stream.read_message(readall=True)) + @tulip.coroutine def coro(): self.stream.feed_data(b'data') self.stream.feed_eof() @@ -422,6 +424,7 @@ def test_read_message_length_payload_extra(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data(b'da') self.stream.feed_data(b't') @@ -442,8 +445,8 @@ def test_parse_length_payload_eof_exc(self): self.stream._parser = parser self.stream.feed_data(b'da') + @tulip.coroutine def eof(): - yield from [] self.stream.feed_eof() t1 = tulip.Task(stream.read()) @@ -466,6 +469,7 @@ def test_read_message_deflate_payload(self): msg = self.loop.run_until_complete( self.stream.read_message(readall=True)) + @tulip.coroutine def coro(): self.stream.feed_data(data) return (yield from msg.payload.read()) @@ -480,6 +484,7 @@ def test_read_message_chunked_payload(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data( b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') @@ -495,6 +500,7 @@ def test_read_message_chunked_payload_chunks(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data(b'4\r\ndata\r') self.stream.feed_data(b'\n4') @@ -514,6 +520,7 @@ def test_read_message_chunked_payload_incomplete(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data(b'4\r\ndata\r\n') self.stream.feed_eof() @@ -530,6 +537,7 @@ def test_read_message_chunked_payload_extension(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data( b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') @@ -545,6 +553,7 @@ def test_read_message_chunked_payload_size_error(self): msg = self.loop.run_until_complete(self.stream.read_message()) + @tulip.coroutine def coro(): self.stream.feed_data(b'blah\r\n') return (yield from msg.payload.read()) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 04fc60b5..2ab41840 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -160,7 +160,6 @@ def test_handle_coro(self): def coro(rline, message): nonlocal called called = True - yield from [] srv.eof_received() srv.handle_request = coro @@ -198,7 +197,6 @@ def test_handle_cancel(self): @tulip.task def cancel(): - yield from [] srv._request_handle.cancel() srv.close() diff --git a/tests/locks_test.py b/tests/locks_test.py index e761c677..5f1c180a 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -24,6 +24,7 @@ def test_repr(self): lock = locks.Lock() self.assertTrue(repr(lock).endswith('[unlocked]>')) + @tasks.coroutine def acquire_lock(): yield from lock @@ -33,6 +34,7 @@ def acquire_lock(): def test_lock(self): lock = locks.Lock() + @tasks.coroutine def acquire_lock(): return (yield from lock) diff --git a/tests/streams_test.py b/tests/streams_test.py index dc6eeaf4..f7e2992b 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -280,10 +280,11 @@ def test_exception(self): def test_exception_waiter(self): stream = streams.StreamReader() + @tasks.coroutine def set_err(): - yield from [] stream.set_exception(ValueError()) + @tasks.coroutine def readline(): yield from stream.readline() diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 097319d1..fc08c3da 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -34,7 +34,6 @@ def tearDown(self): def test_task_class(self): @tasks.coroutine def notmuch(): - yield from [] return 'ok' t = tasks.Task(notmuch()) self.event_loop.run() @@ -56,6 +55,27 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') + def test_task_decorator_func(self): + @tasks.task + def notmuch(): + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_fut(self): + fut = futures.Future() + fut.set_result('ko') + + @tasks.task + def notmuch(): + return fut + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + def test_task_repr(self): @tasks.task def notmuch(): @@ -74,8 +94,9 @@ def notmuch(): self.assertEqual(repr(t), "Task()") def test_task_repr_custom(self): + @tasks.coroutine def coro(): - yield from [] + pass class T(futures.Future): def __repr__(self): @@ -97,12 +118,10 @@ def outer(): @tasks.task def inner1(): - yield from [] return 42 @tasks.task def inner2(): - yield from [] return 1000 t = outer() @@ -137,29 +156,30 @@ def coro(): self.assertFalse(t.cancel()) def test_future_timeout_catch(self): + self.suppress_log_errors() + @tasks.coroutine def coro(): yield from tasks.sleep(10.0) return 12 - err = None + class Cancelled(Exception): + pass @tasks.coroutine def coro2(): - nonlocal err try: yield from tasks.Task(coro(), timeout=0.1) - except futures.CancelledError as exc: - err = exc + except futures.CancelledError: + raise Cancelled() - self.event_loop.run_until_complete(tasks.Task(coro2())) - self.assertIsInstance(err, futures.CancelledError) + self.assertRaises( + Cancelled, self.event_loop.run_until_complete, coro2()) def test_cancel_in_coro(self): @tasks.coroutine def task(): t.cancel() - yield from [] return 12 t = tasks.Task(task()) @@ -259,11 +279,12 @@ def test_wait_really_done(self): @tasks.coroutine def coro1(): - yield from [None] + yield @tasks.coroutine def coro2(): - yield from [None, None] + yield + yield a = tasks.Task(coro1()) b = tasks.Task(coro2()) @@ -280,7 +301,6 @@ def test_wait_first_exception(self): @tasks.coroutine def exc(): - yield from [] raise ZeroDivisionError('err') b = tasks.Task(exc()) @@ -541,7 +561,6 @@ def test_step_with_baseexception(self): @tasks.coroutine def notmutch(): - yield from [] raise BaseException() task = tasks.Task(notmutch()) diff --git a/tulip/tasks.py b/tulip/tasks.py index 52988134..c54df5bd 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -25,16 +25,11 @@ def coroutine(func): if inspect.isgeneratorfunction(func): coro = func else: - tulip_log.debug( - 'Coroutine function %s is not a generator.', func.__name__) - @functools.wraps(func) def coro(*args, **kw): res = func(*args, **kw) - if isinstance(res, futures.Future) or inspect.isgenerator(res): res = yield from res - return res coro._is_coroutine = True # Not sure who can use this. @@ -56,9 +51,18 @@ def iscoroutine(obj): def task(func): """Decorator for a coroutine to be wrapped in a Task.""" + if inspect.isgeneratorfunction(func): + coro = func + else: + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + def task_wrapper(*args, **kwds): - coro = func(*args, **kwds) - return Task(coro) + return Task(coro(*args, **kwds)) + return task_wrapper From 66288cce2230747f93632b44db1ff32ccd2b269e Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 5 Apr 2013 20:23:30 -0700 Subject: [PATCH 0403/1502] set future exception instead of throw into coroutine --- tests/tasks_test.py | 8 +++++--- tulip/tasks.py | 11 +++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index fc08c3da..fa22d62c 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -618,9 +618,11 @@ def wait_for_future(): yield fut task = wait_for_future() - self.assertRaises( - RuntimeError, - self.event_loop.run_until_complete, task) + with self.assertRaises(RuntimeError) as cm: + self.event_loop.run_until_complete(task) + + self.assertTrue(fut.done()) + self.assertIs(fut.exception(), cm.exception) def test_yield_vs_yield_from_generator(self): self.suppress_log_errors() diff --git a/tulip/tasks.py b/tulip/tasks.py index c54df5bd..3d7acc79 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -148,15 +148,14 @@ def _step(self, value=None, exc=None): # XXX No check for self._must_cancel here? if isinstance(result, futures.Future): if not result._blocking: - self._event_loop.call_soon( - self._step, None, + result.set_exception( RuntimeError( 'yield was used instead of yield from ' 'in task {!r} with {!r}'.format(self, result))) - else: - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result + + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result elif isinstance(result, concurrent.futures.Future): # This ought to be more efficient than wrap_future(), From d2b63fab205462bdefcf556720cc9bc3269003a2 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 8 Apr 2013 10:48:44 -0700 Subject: [PATCH 0404/1502] EventLoop.stop_serving() - stop listening for incoming connections --- tests/events_test.py | 18 ++++++++++++++++++ tulip/events.py | 4 ++++ tulip/selector_events.py | 4 ++++ 3 files changed, 26 insertions(+) diff --git a/tests/events_test.py b/tests/events_test.py index 80f4ac35..4df5e04d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -781,6 +781,22 @@ def test_start_serving_sock(self): sock.close() client.close() + def test_stop_serving(self): + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.event_loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + def test_start_serving_host_port_sock(self): self.suppress_log_errors() fut = self.event_loop.start_serving( @@ -1297,6 +1313,8 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.create_connection, f) self.assertRaises( NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.stop_serving, f) self.assertRaises( NotImplementedError, ev_loop.create_datagram_endpoint, f) self.assertRaises( diff --git a/tulip/events.py b/tulip/events.py index 3a6ad40c..68cd7211 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -185,6 +185,10 @@ def start_serving(self, protocol_factory, host=None, port=None, *, family=0, proto=0, flags=0, sock=None): raise NotImplementedError + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index b8ce3bcf..e072298c 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -315,6 +315,10 @@ def _process_events(self, event_list): else: self._add_callback(writer) + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + class _SelectorSocketTransport(transports.Transport): From aa83cbd0e5024cdc404c68c8706b7c85a6c1cfab Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 8 Apr 2013 11:13:34 -0700 Subject: [PATCH 0405/1502] iocp event loop does not support stop_serving() --- tests/events_test.py | 5 ++++- tests/http_client_functional_test.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 4df5e04d..e8855548 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -166,7 +166,7 @@ def run(): time.sleep(0.1) self.event_loop.call_soon_threadsafe(callback) - self.event_loop.run_once(0.001) # windows iocp + self.event_loop.run_once(0) # windows iocp t = threading.Thread(target=run) t0 = time.monotonic() @@ -1137,6 +1137,9 @@ def test_create_datagram_endpoint_socket_err(self): def test_create_datagram_endpoint_connect_err(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_endpoint()") + def test_stop_serving(self): + raise unittest.SkipTest( + "IocpEventLoop does not support stop_serving()") else: from tulip import selectors from tulip import unix_events diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 99c758ab..120f78b8 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -344,7 +344,7 @@ def test_timeout(self): def test_request_conn_error(self): self.assertRaises( - ConnectionRefusedError, + OSError, self.loop.run_until_complete, client.request('get', 'http://0.0.0.0:1', timeout=0.1)) From 18de1219c0033893c0b677df3c20b2d6ae63623b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 8 Apr 2013 11:36:47 -0700 Subject: [PATCH 0406/1502] prevent writing to transport after socket exception --- tests/proactor_events_test.py | 12 +++++++++- tests/selector_events_test.py | 41 +++++++++++++++++++++++++++++++++-- tulip/proactor_events.py | 7 ++++++ tulip/selector_events.py | 31 +++++++++++++++++++++++++- 4 files changed, 87 insertions(+), 4 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 039aa886..7a92ad08 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -123,7 +123,8 @@ def test_loop_writing(self): self.event_loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) - def test_loop_writing_err(self): + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): err = self.event_loop._proactor.send.side_effect = OSError() tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -131,6 +132,15 @@ def test_loop_writing_err(self): tr._buffer = [b'da', b'ta'] tr._loop_writing() tr._fatal_error.assert_called_with(err) + self.assertEqual(tr._conn_lost, 1) + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') def test_loop_writing_stop(self): fut = tulip.Future() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index c63db15b..031afeb9 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -663,7 +663,8 @@ def test_write_tryagain(self): self.assertEqual([b'data'], transport._buffer) - def test_write_exception(self): + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() data = b'data' @@ -672,6 +673,17 @@ def test_write_exception(self): transport._fatal_error = unittest.mock.Mock() transport.write(data) transport._fatal_error.assert_called_with(err) + self.assertEqual(transport._conn_lost, 1) + + self.sock.reset_mock() + transport.write(data) + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._conn_lost, 2) + transport.write(data) + transport.write(data) + transport.write(data) + transport.write(data) + m_log.warning.assert_called_with('socket.send() raised exception.') def test_write_str(self): transport = _SelectorSocketTransport( @@ -756,6 +768,7 @@ def test_write_ready_exception(self): transport._buffer.append(b'data') transport._write_ready() transport._fatal_error.assert_called_with(err) + self.assertEqual(transport._conn_lost, 1) def test_close(self): transport = _SelectorSocketTransport( @@ -870,6 +883,17 @@ def test_write_closing(self): self.transport.close() self.assertRaises(AssertionError, self.transport.write, b'data') + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_write_exception(self, m_log): + self.transport._conn_lost = 1 + self.transport.write(b'data') + self.assertEqual(self.transport._buffer, []) + self.transport.write(b'data') + self.transport.write(b'data') + self.transport.write(b'data') + self.transport.write(b'data') + m_log.warning.assert_called_with('socket.send() raised exception.') + def test_abort(self): self.transport._close = unittest.mock.Mock() self.transport.abort() @@ -1008,6 +1032,7 @@ def test_on_ready_send_exc(self): self.transport._on_ready() self.transport._fatal_error.assert_called_with(err) self.assertEqual([], self.transport._buffer) + self.assertEqual(self.transport._conn_lost, 1) class SelectorDatagramTransportTests(unittest.TestCase): @@ -1102,7 +1127,8 @@ def test_sendto_tryagain(self): self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) - def test_sendto_exception(self): + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = OSError() @@ -1111,9 +1137,18 @@ def test_sendto_exception(self): transport._fatal_error = unittest.mock.Mock() transport.sendto(data, ()) + self.assertEqual(transport._conn_lost, 1) self.assertTrue(transport._fatal_error.called) transport._fatal_error.assert_called_with(err) + transport._address = ('123',) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + transport.sendto(data) + m_log.warning.assert_called_with('socket.send() raised exception.') + def test_sendto_connection_refused(self): data = b'data' @@ -1124,6 +1159,7 @@ def test_sendto_connection_refused(self): transport._fatal_error = unittest.mock.Mock() transport.sendto(data, ()) + self.assertEqual(transport._conn_lost, 0) self.assertFalse(transport._fatal_error.called) def test_sendto_connection_refused_connected(self): @@ -1136,6 +1172,7 @@ def test_sendto_connection_refused_connected(self): transport._fatal_error = unittest.mock.Mock() transport.sendto(data) + self.assertEqual(transport._conn_lost, 1) self.assertTrue(transport._fatal_error.called) def test_sendto_str(self): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index e812c640..6f38db7d 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -20,6 +20,7 @@ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): self._buffer = [] self._read_fut = None self._write_fut = None + self._conn_lost = 0 self._closing = False # Set when close() called. self._event_loop.call_soon(self._protocol.connection_made, self) self._event_loop.call_soon(self._loop_reading) @@ -57,6 +58,11 @@ def write(self, data): assert not self._closing if not data: return + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return self._buffer.append(data) if not self._write_fut: self._loop_writing() @@ -73,6 +79,7 @@ def _loop_writing(self, f=None): return self._write_fut = self._event_loop._proactor.send(self._sock, data) except OSError as exc: + self._conn_lost += 1 self._fatal_error(exc) else: self._write_fut.add_done_callback(self._loop_writing) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index e072298c..2e93132e 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -329,6 +329,7 @@ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): self._sock = sock self._protocol = protocol self._buffer = [] + self._conn_lost = 0 self._closing = False # Set when close() called. self._event_loop.add_reader(self._sock.fileno(), self._read_ready) self._event_loop.call_soon(self._protocol.connection_made, self) @@ -355,6 +356,12 @@ def write(self, data): if not data: return + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + if not self._buffer: # Attempt to send it right away first. try: @@ -362,6 +369,7 @@ def write(self, data): except (BlockingIOError, InterruptedError): n = 0 except socket.error as exc: + self._conn_lost += 1 self._fatal_error(exc) return @@ -383,6 +391,7 @@ def _write_ready(self): except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) else: if n == len(data): @@ -440,6 +449,7 @@ def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter, do_handshake_on_connect=False) self._sslsock = sslsock self._buffer = [] + self._conn_lost = 0 self._closing = False # Set when close() called. self._extra['socket'] = sslsock @@ -519,6 +529,7 @@ def _on_ready(self): except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) return @@ -534,6 +545,13 @@ def write(self, data): assert not self._closing if not data: return + + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) # We could optimize, but the callback can do this for now. @@ -572,6 +590,7 @@ def __init__(self, event_loop, sock, protocol, address=None, extra=None): self._protocol = protocol self._address = address self._buffer = collections.deque() + self._conn_lost = 0 self._closing = False # Set when close() called. self._event_loop.add_reader(self._fileno, self._read_ready) self._event_loop.call_soon(self._protocol.connection_made, self) @@ -595,6 +614,12 @@ def sendto(self, data, addr=None): if self._address: assert addr in (None, self._address) + if self._conn_lost and self._address: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + if not self._buffer: # Attempt to send it right away first. try: @@ -605,11 +630,13 @@ def sendto(self, data, addr=None): return except ConnectionRefusedError as exc: if self._address: + self._conn_lost += 1 self._fatal_error(exc) - return + return except (BlockingIOError, InterruptedError): self._event_loop.add_writer(self._fileno, self._sendto_ready) except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) return @@ -625,12 +652,14 @@ def _sendto_ready(self): self._sock.sendto(data, addr) except ConnectionRefusedError as exc: if self._address: + self._conn_lost += 1 self._fatal_error(exc) return except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) return From 2ef0332dc5c621d1eb44851d2903aa577759fa61 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 8 Apr 2013 11:49:53 -0700 Subject: [PATCH 0407/1502] Attempt to fix some line ending issue on Windows. --- .hgeol | 4 + .hgignore | 12 + Makefile | 31 + NOTES | 176 ++++ README | 21 + TODO | 163 +++ check.py | 41 + crawl.py | 98 ++ curl.py | 24 + examples/udp_echo.py | 73 ++ old/Makefile | 16 + old/echoclt.py | 79 ++ old/echosvr.py | 60 ++ old/http_client.py | 78 ++ old/http_server.py | 68 ++ old/main.py | 134 +++ old/p3time.py | 47 + old/polling.py | 535 ++++++++++ old/scheduling.py | 354 +++++++ old/sockets.py | 348 +++++++ old/transports.py | 496 +++++++++ old/xkcd.py | 18 + old/yyftime.py | 75 ++ overlapped.c | 997 ++++++++++++++++++ runtests.py | 198 ++++ setup.cfg | 2 + setup.py | 14 + srv.py | 158 +++ tests/base_events_test.py | 284 +++++ tests/events_test.py | 1431 ++++++++++++++++++++++++++ tests/futures_test.py | 222 ++++ tests/http_client_functional_test.py | 395 +++++++ tests/http_client_test.py | 308 ++++++ tests/http_protocol_test.py | 982 ++++++++++++++++++ tests/http_server_test.py | 240 +++++ tests/http_wsgi_test.py | 242 +++++ tests/locks_test.py | 745 ++++++++++++++ tests/proactor_events_test.py | 337 ++++++ tests/queues_test.py | 370 +++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1323 ++++++++++++++++++++++++ tests/selectors_test.py | 137 +++ tests/streams_test.py | 300 ++++++ tests/subprocess_test.py | 54 + tests/tasks_test.py | 676 ++++++++++++ tests/transports_test.py | 45 + tests/unix_events_test.py | 573 +++++++++++ tests/winsocketpair_test.py | 26 + tulip/TODO | 26 + tulip/__init__.py | 26 + tulip/base_events.py | 556 ++++++++++ tulip/events.py | 360 +++++++ tulip/futures.py | 255 +++++ tulip/http/__init__.py | 14 + tulip/http/client.py | 533 ++++++++++ tulip/http/errors.py | 44 + tulip/http/protocol.py | 879 ++++++++++++++++ tulip/http/server.py | 176 ++++ tulip/http/wsgi.py | 219 ++++ tulip/locks.py | 434 ++++++++ tulip/log.py | 6 + tulip/proactor_events.py | 198 ++++ tulip/protocols.py | 78 ++ tulip/queues.py | 291 ++++++ tulip/selector_events.py | 696 +++++++++++++ tulip/selectors.py | 418 ++++++++ tulip/streams.py | 145 +++ tulip/subprocess_transport.py | 139 +++ tulip/tasks.py | 334 ++++++ tulip/test_utils.py | 269 +++++ tulip/transports.py | 134 +++ tulip/unix_events.py | 301 ++++++ tulip/windows_events.py | 157 +++ tulip/winsocketpair.py | 34 + 75 files changed, 19761 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 crawl.py create mode 100755 curl.py create mode 100644 examples/udp_echo.py create mode 100644 old/Makefile create mode 100644 old/echoclt.py create mode 100644 old/echosvr.py create mode 100644 old/http_client.py create mode 100644 old/http_server.py create mode 100644 old/main.py create mode 100644 old/p3time.py create mode 100644 old/polling.py create mode 100644 old/scheduling.py create mode 100644 old/sockets.py create mode 100644 old/transports.py create mode 100755 old/xkcd.py create mode 100644 old/yyftime.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100755 srv.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..274da4c8 --- /dev/null +++ b/Makefile @@ -0,0 +1,31 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..d28b31f7 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines > 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) > 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/crawl.py b/crawl.py new file mode 100755 index 00000000..723d5305 --- /dev/null +++ b/crawl.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request('get', url) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/curl.py b/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100644 index 00000000..5d1e02ec --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,73 @@ +"""UDP echo example. + +Start server: + + >> python ./udp_echo.py --server + +""" + +import sys +import tulip + +ADDRESS = ('127.0.0.1', 10000) + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be repeated.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=ADDRESS)) + loop.run_forever() + + +def start_client(): + loop = tulip.get_event_loop() + tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=ADDRESS)) + loop.run_forever() + + +if __name__ == '__main__': + if '--server' in sys.argv: + start_server() + else: + start_client() diff --git a/old/Makefile b/old/Makefile new file mode 100644 index 00000000..d352cd70 --- /dev/null +++ b/old/Makefile @@ -0,0 +1,16 @@ +PYTHON=python3 + +main: + $(PYTHON) main.py -v + +echo: + $(PYTHON) echosvr.py -v + +profile: + $(PYTHON) -m profile -s time main.py + +time: + $(PYTHON) p3time.py + +ytime: + $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py new file mode 100644 index 00000000..c24c573e --- /dev/null +++ b/old/echoclt.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3.3 +"""Example echo client.""" + +# Stdlib imports. +import logging +import socket +import sys +import time + +# Local imports. +import scheduling +import sockets + + +def echoclient(host, port): + """COROUTINE""" + testdata = b'hi hi hi ha ha ha\n' + try: + trans = yield from sockets.create_transport(host, port, + af=socket.AF_INET) + except OSError: + return False + try: + ok = yield from trans.send(testdata) + if ok: + response = yield from trans.recv(100) + ok = response == testdata.upper() + return ok + finally: + trans.close() + + +def doit(n): + """COROUTINE""" + t0 = time.time() + tasks = set() + for i in range(n): + t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) + tasks.add(t) + ok = 0 + bad = 0 + for t in tasks: + try: + yield from t + except Exception: + bad += 1 + else: + ok += 1 + t1 = time.time() + print('ok: ', ok) + print('bad:', bad) + print('dt: ', round(t1-t0, 6)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Get integer from command line. + n = 1 + for arg in sys.argv[1:]: + if not arg.startswith('-'): + n = int(arg) + break + + # Run scheduler, starting it off with doit(). + scheduling.run(doit(n)) + + +if __name__ == '__main__': + main() diff --git a/old/echosvr.py b/old/echosvr.py new file mode 100644 index 00000000..4085f4c6 --- /dev/null +++ b/old/echosvr.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3.3 +"""Example echo server.""" + +# Stdlib imports. +import logging +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + while True: + line = yield from rdr.readline() + logging.debug('Received: %r from %r', line, addr) + if not line: + break + yield from trans.send(line.upper()) + logging.debug('Closing %r', addr) + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 1111, + af=socket.AF_INET, + backlog=100) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/http_client.py b/old/http_client.py new file mode 100644 index 00000000..8937ba20 --- /dev/null +++ b/old/http_client.py @@ -0,0 +1,78 @@ +"""Crummy HTTP client. + +This is not meant as an example of how to write a good client. +""" + +# Stdlib. +import re +import time + +# Local. +import sockets + + +def urlfetch(host, port=None, path='/', method='GET', + body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): + """COROUTINE: Make an HTTP 1.0 request.""" + t0 = time.time() + if port is None: + if ssl: + port = 443 + else: + port = 80 + trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) + yield from trans.send(method.encode(encoding) + b' ' + + path.encode(encoding) + b' HTTP/1.0\r\n') + if hdrs: + kwds = dict(hdrs) + else: + kwds = {} + if 'host' not in kwds: + kwds['host'] = host + if body is not None: + kwds['content_length'] = len(body) + for header, value in kwds.items(): + yield from trans.send(header.replace('_', '-').encode(encoding) + + b': ' + value.encode(encoding) + b'\r\n') + + yield from trans.send(b'\r\n') + if body is not None: + yield from trans.send(body) + + # Read HTTP response line. + rdr = sockets.BufferedReader(trans) + resp = yield from rdr.readline() + m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', + resp) + if not m: + trans.close() + raise IOError('No valid HTTP response: %r' % resp) + http_version, status, message = m.groups() + + # Read HTTP headers. + headers = [] + hdict = {} + while True: + line = yield from rdr.readline() + if not line.strip(): + break + m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) + if not m: + raise IOError('Invalid header: %r' % line) + header, value = m.groups() + headers.append((header, value)) + hdict[header.decode(encoding).lower()] = value.decode(encoding) + + # Read response body. + content_length = hdict.get('content-length') + if content_length is not None: + size = int(content_length) # TODO: Catch errors. + assert size >= 0, size + else: + size = 2**20 # Protective limit (1 MB). + data = yield from rdr.readexactly(size) + trans.close() # Can this block? + t1 = time.time() + result = (host, port, path, int(status), len(data), round(t1-t0, 3)) +## print(result) + return result diff --git a/old/http_server.py b/old/http_server.py new file mode 100644 index 00000000..2b1e3dd6 --- /dev/null +++ b/old/http_server.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python3.3 +"""Simple HTTP server. + +This currenty exists just so we can benchmark this thing! +""" + +# Stdlib imports. +import logging +import re +import socket +import sys + +# Local imports. +import scheduling +import sockets + + +def handler(conn, addr): + """COROUTINE: Handle one connection.""" + ##logging.info('Accepting connection from %r', addr) + trans = sockets.SocketTransport(conn) + rdr = sockets.BufferedReader(trans) + + # Read but ignore request line. + request_line = yield from rdr.readline() + + # Consume headers but don't interpret them. + while True: + header_line = yield from rdr.readline() + if not header_line.strip(): + break + + # Always send an empty 200 response and close. + yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') + trans.close() + + +def doit(): + """COROUTINE: Set the wheels in motion.""" + # Set up listener. + listener = yield from sockets.create_listener('localhost', 8080, + af=socket.AF_INET) + logging.info('Listening on %r', listener.sock.getsockname()) + + # Loop accepting connections. + while True: + conn, addr = yield from listener.accept() + t = scheduling.Task(handler(conn, addr)) + + +def main(): + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + scheduling.run(doit()) + + +if __name__ == '__main__': + main() diff --git a/old/main.py b/old/main.py new file mode 100644 index 00000000..c1f9d0a8 --- /dev/null +++ b/old/main.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3.3 +"""Example HTTP client using yield-from coroutines (PEP 380). + +Requires Python 3.3. + +There are many micro-optimizations possible here, but that's not the point. + +Some incomplete laundry lists: + +TODO: +- Take test urls from command line. +- Move urlfetch to a separate module. +- Profiling. +- Docstrings. +- Unittests. + +FUNCTIONALITY: +- Connection pool (keep connection open). +- Chunked encoding (request and response). +- Pipelining, e.g. zlib (request and response). +- Automatic encoding/decoding. +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +import logging +import os +import time +import socket +import sys + +# Local imports (keep in alphabetic order). +import scheduling +import http_client + + + +def doit2(): + argses = [ + ('localhost', 8080, '/'), + ('127.0.0.1', 8080, '/home'), + ('python.org', 80, '/'), + ('xkcd.com', 443, '/'), + ] + results = yield from scheduling.map_over( + lambda args: http_client.urlfetch(*args), argses, timeout=2) + for res in results: + print('-->', res) + return [] + + +def doit(): + TIMEOUT = 2 + tasks = set() + + # This references NDB's default test service. + # (Sadly the service is single-threaded.) + task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), + 'root', timeout=TIMEOUT) + tasks.add(task1) + task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, + path='/home'), + 'home', timeout=TIMEOUT) + tasks.add(task2) + + # Fetch python.org home page. + task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), + 'python', timeout=TIMEOUT) + tasks.add(task3) + + # Fetch XKCD home page using SSL. (Doesn't like IPv6.) + task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', + af=socket.AF_INET), + 'xkcd', timeout=TIMEOUT) + tasks.add(task4) + +## # Fetch many links from python.org (/x.y.z). +## for x in '123': +## for y in '0123456789': +## path = '/{}.{}'.format(x, y) +## g = http_client.urlfetch('82.94.164.162', 80, +## path=path, hdrs={'host': 'python.org'}) +## t = scheduling.Task(g, path, timeout=2) +## tasks.add(t) + +## print(tasks) + yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() + winners = yield from scheduling.wait_any(tasks) + print('And the winners are:', [w.name for w in winners]) + tasks = yield from scheduling.wait_all(tasks) + print('And the players were:', [t.name for t in tasks]) + return tasks + + +def logtimes(real): + utime, stime, cutime, cstime, unused = os.times() + logging.info('real %10.3f', real) + logging.info('user %10.3f', utime + cutime) + logging.info('sys %10.3f', stime + cstime) + + +def main(): + t0 = time.time() + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + # Run scheduler, starting it off with doit(). + task = scheduling.run(doit()) + if task.exception: + print('Exception:', repr(task.exception)) + if isinstance(task.exception, AssertionError): + raise task.exception + else: + for t in task.result: + print(t.name + ':', + repr(t.exception) if t.exception else t.result) + + # Report real, user, sys times. + t1 = time.time() + logtimes(t1-t0) + + +if __name__ == '__main__': + main() diff --git a/old/p3time.py b/old/p3time.py new file mode 100644 index 00000000..35e14c96 --- /dev/null +++ b/old/p3time.py @@ -0,0 +1,47 @@ +"""Compare timing of plain vs. yield-from calls.""" + +import gc +import time + +def plain(n): + if n <= 0: + return 1 + l = plain(n-1) + r = plain(n-1) + return l + 1 + r + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def submain(depth): + t0 = time.time() + k = plain(depth) + t1 = time.time() + fmt = ' {} {} {:-9,.5f}' + delta0 = t1-t0 + print(('plain' + fmt).format(depth, k, delta0)) + + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + delta1 = t1-t0 + print(('coro.' + fmt).format(depth, k, delta1)) + if delta0: + print(('relat' + fmt).format(depth, k, delta1/delta0)) + +def main(reasonable=16): + gc.disable() + for depth in range(reasonable): + submain(depth) + +if __name__ == '__main__': + main() diff --git a/old/polling.py b/old/polling.py new file mode 100644 index 00000000..6586efcc --- /dev/null +++ b/old/polling.py @@ -0,0 +1,535 @@ +"""Event loop and related classes. + +The event loop can be broken up into a pollster (the part responsible +for telling us when file descriptors are ready) and the event loop +proper, which wraps a pollster with functionality for scheduling +callbacks, immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. + +There are several implementations of the pollster part, several using +esoteric system calls that exist only on some platforms. These are: + +- kqueue (most BSD systems) +- epoll (newer Linux systems) +- poll (most UNIX systems) +- select (all UNIX systems, and Windows) +- TODO: Support IOCP on Windows and some UNIX platforms. + +NOTE: We don't use select on systems where any of the others is +available, because select performs poorly as the number of file +descriptors goes up. The ranking is roughly: + + 1. kqueue, epoll, IOCP + 2. poll + 3. select + +TODO: +- Optimize the various pollsters. +- Unittests. +""" + +import collections +import concurrent.futures +import heapq +import logging +import os +import select +import threading +import time + + +class PollsterBase: + """Base class for all polling implementations. + + This defines an interface to register and unregister readers and + writers for specific file descriptors, and an interface to get a + list of events. There's also an interface to check whether any + readers or writers are currently registered. + """ + + def __init__(self): + super().__init__() + self.readers = {} # {fd: token, ...}. + self.writers = {} # {fd: token, ...}. + + def pollable(self): + """Return True if any readers or writers are currently registered.""" + return bool(self.readers or self.writers) + + # Subclasses are expected to extend the add/remove methods. + + def register_reader(self, fd, token): + """Add or update a reader for a file descriptor.""" + self.readers[fd] = token + + def register_writer(self, fd, token): + """Add or update a writer for a file descriptor.""" + self.writers[fd] = token + + def unregister_reader(self, fd): + """Remove the reader for a file descriptor.""" + del self.readers[fd] + + def unregister_writer(self, fd): + """Remove the writer for a file descriptor.""" + del self.writers[fd] + + def poll(self, timeout=None): + """Poll for events. A subclass must implement this. + + If timeout is omitted or None, this blocks until at least one + event is ready. Otherwise, timeout gives a maximum time to + wait (in seconds as an int or float) -- the method returns as + soon as at least one event is ready or when the timeout is + expired. For a non-blocking poll, pass 0. + + The return value is a list of events; it is empty when the + timeout expired before any events were ready. Each event + is a token previously passed to register_reader/writer(). + """ + raise NotImplementedError + + +class SelectPollster(PollsterBase): + """Pollster implementation using select.""" + + def poll(self, timeout=None): + readable, writable, _ = select.select(self.readers, self.writers, + [], timeout) + events = [] + events += (self.readers[fd] for fd in readable) + events += (self.writers[fd] for fd in writable) + return events + + +class PollPollster(PollsterBase): + """Pollster implementation using poll.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def _update(self, fd): + assert isinstance(fd, int), fd + flags = 0 + if fd in self.readers: + flags |= select.POLLIN + if fd in self.writers: + flags |= select.POLLOUT + if flags: + self._poll.register(fd, flags) + else: + self._poll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + # Timeout is in seconds, but poll() takes milliseconds. + msecs = None if timeout is None else int(round(1000 * timeout)) + events = [] + for fd, flags in self._poll.poll(msecs): + if flags & (select.POLLIN | select.POLLHUP): + if fd in self.readers: + events.append(self.readers[fd]) + if flags & (select.POLLOUT | select.POLLHUP): + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class EPollPollster(PollsterBase): + """Pollster implementation using epoll.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def _update(self, fd): + assert isinstance(fd, int), fd + eventmask = 0 + if fd in self.readers: + eventmask |= select.EPOLLIN + if fd in self.writers: + eventmask |= select.EPOLLOUT + if eventmask: + try: + self._epoll.register(fd, eventmask) + except IOError: + self._epoll.modify(fd, eventmask) + else: + self._epoll.unregister(fd) + + def register_reader(self, fd, callback, *args): + super().register_reader(fd, callback, *args) + self._update(fd) + + def register_writer(self, fd, callback, *args): + super().register_writer(fd, callback, *args) + self._update(fd) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + self._update(fd) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + self._update(fd) + + def poll(self, timeout=None): + if timeout is None: + timeout = -1 # epoll.poll() uses -1 to mean "wait forever". + events = [] + for fd, eventmask in self._epoll.poll(timeout): + if eventmask & select.EPOLLIN: + if fd in self.readers: + events.append(self.readers[fd]) + if eventmask & select.EPOLLOUT: + if fd in self.writers: + events.append(self.writers[fd]) + return events + + +class KqueuePollster(PollsterBase): + """Pollster implementation using kqueue.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def register_reader(self, fd, callback, *args): + if fd not in self.readers: + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_reader(fd, callback, *args) + + def register_writer(self, fd, callback, *args): + if fd not in self.writers: + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return super().register_writer(fd, callback, *args) + + def unregister_reader(self, fd): + super().unregister_reader(fd) + kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def unregister_writer(self, fd): + super().unregister_writer(fd) + kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + + def poll(self, timeout=None): + events = [] + max_ev = len(self.readers) + len(self.writers) + for kev in self._kqueue.control(None, max_ev, timeout): + fd = kev.ident + flag = kev.filter + if flag == select.KQ_FILTER_READ and fd in self.readers: + events.append(self.readers[fd]) + elif flag == select.KQ_FILTER_WRITE and fd in self.writers: + events.append(self.writers[fd]) + return events + + +# Pick the best pollster class for the platform. +if hasattr(select, 'kqueue'): + best_pollster = KqueuePollster +elif hasattr(select, 'epoll'): + best_pollster = EPollPollster +elif hasattr(select, 'poll'): + best_pollster = PollPollster +else: + best_pollster = SelectPollster + + +class DelayedCall: + """Object returned by callback registration methods.""" + + def __init__(self, when, callback, args, kwds=None): + self.when = when + self.callback = callback + self.args = args + self.kwds = kwds + self.cancelled = False + + def cancel(self): + self.cancelled = True + + def __lt__(self, other): + return self.when < other.when + + def __le__(self, other): + return self.when <= other.when + + def __eq__(self, other): + return self.when == other.when + + +class EventLoop: + """Event loop functionality. + + This defines public APIs call_soon(), call_later(), run_once() and + run(). It also wraps Pollster APIs register_reader(), + register_writer(), remove_reader(), remove_writer() with + add_reader() etc. + + This class's instance variables are not part of its API. + """ + + def __init__(self, pollster=None): + super().__init__() + if pollster is None: + logging.info('Using pollster: %s', best_pollster.__name__) + pollster = best_pollster() + self.pollster = pollster + self.ready = collections.deque() # [(callback, args), ...] + self.scheduled = [] # [(when, callback, args), ...] + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_reader(fd, dcall) + return dcall + + def remove_reader(self, fd): + """Remove a reader callback.""" + self.pollster.unregister_reader(fd) + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a DelayedCall instance.""" + dcall = DelayedCall(None, callback, args) + self.pollster.register_writer(fd, dcall) + return dcall + + def remove_writer(self, fd): + """Remove a writer callback.""" + self.pollster.unregister_writer(fd) + + def add_callback(self, dcall): + """Add a DelayedCall to ready or scheduled.""" + if dcall.cancelled: + return + if dcall.when is None: + self.ready.append(dcall) + else: + heapq.heappush(self.scheduled, dcall) + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + dcall = DelayedCall(None, callback, args) + self.ready.append(dcall) + return dcall + + def call_later(self, when, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The time can be an int or float, expressed in seconds. + + If when is small enough (~11 days), it's assumed to be a + relative time, meaning the call will be scheduled that many + seconds in the future; otherwise it's assumed to be a posix + timestamp as returned by time.time(). + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if when < 10000000: + when += time.time() + dcall = DelayedCall(when, callback, args) + heapq.heappush(self.scheduled, dcall) + return dcall + + def run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Pass in a timeout or deadline or something. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: As step 4, run everything scheduled by steps 1-3. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # TODO: Ensure this loop always finishes, even if some + # callbacks keeps registering more callbacks. + while self.ready: + dcall = self.ready.popleft() + if not dcall.cancelled: + try: + if dcall.kwds: + dcall.callback(*dcall.args, **dcall.kwds) + else: + dcall.callback(*dcall.args) + except Exception: + logging.exception('Exception in callback %s %r', + dcall.callback, dcall.args) + + # Remove delayed calls that were cancelled from head of queue. + while self.scheduled and self.scheduled[0].cancelled: + heapq.heappop(self.scheduled) + + # Inspect the poll queue. + if self.pollster.pollable(): + if self.scheduled: + when = self.scheduled[0].when + timeout = max(0, when - time.time()) + else: + timeout = None + t0 = time.time() + events = self.pollster.poll(timeout) + t1 = time.time() + argstr = '' if timeout is None else ' %.3f' % timeout + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + for dcall in events: + self.add_callback(dcall) + + # Handle 'later' callbacks that are ready. + now = time.time() + while self.scheduled: + dcall = self.scheduled[0] + if dcall.when > now: + break + dcall = heapq.heappop(self.scheduled) + self.call_soon(dcall.callback, *dcall.args) + + def run(self): + """Run the event loop until there is no work left to do. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + """ + while self.ready or self.scheduled or self.pollster.pollable(): + self.run_once() + + +MAX_WORKERS = 5 # Default max workers when creating an executor. + + +class ThreadRunner: + """Helper to submit work to a thread pool and wait for it. + + This is the glue between the single-threaded callback-based async + world and the threaded world. Use it to call functions that must + block and don't have an async alternative (e.g. getaddrinfo()). + + The only public API is submit(). + """ + + def __init__(self, eventloop, executor=None): + self.eventloop = eventloop + self.executor = executor # Will be constructed lazily. + self.pipe_read_fd, self.pipe_write_fd = os.pipe() + self.active_count = 0 + + def read_callback(self): + """Semi-permanent callback while at least one future is active.""" + assert self.active_count > 0, self.active_count + data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. + self.active_count -= len(data) + if self.active_count == 0: + self.eventloop.remove_reader(self.pipe_read_fd) + assert self.active_count >= 0, self.active_count + + def submit(self, func, *args, executor=None, callback=None): + """Submit a function to the thread pool. + + This returns a concurrent.futures.Future instance. The caller + should not wait for that, but rather use the callback argument.. + """ + if executor is None: + executor = self.executor + if executor is None: + # Lazily construct a default executor. + # TODO: Should this be shared between threads? + executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) + self.executor = executor + assert self.active_count >= 0, self.active_count + future = executor.submit(func, *args) + if self.active_count == 0: + self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) + self.active_count += 1 + def done_callback(fut): + if callback is not None: + self.eventloop.call_soon(callback, fut) + # TODO: Wake up the pipe in call_soon()? + os.write(self.pipe_write_fd, b'x') + future.add_done_callback(done_callback) + return future + + +class Context(threading.local): + """Thread-local context. + + We use this to avoid having to explicitly pass around an event loop + or something to hold the current task. + + TODO: Add an API so frameworks can substitute a different notion + of context more easily. + """ + + def __init__(self, eventloop=None, threadrunner=None): + # Default event loop and thread runner are lazily constructed + # when first accessed. + self._eventloop = eventloop + self._threadrunner = threadrunner + self.current_task = None # For the benefit of scheduling.py. + + @property + def eventloop(self): + if self._eventloop is None: + self._eventloop = EventLoop() + return self._eventloop + + @property + def threadrunner(self): + if self._threadrunner is None: + self._threadrunner = ThreadRunner(self.eventloop) + return self._threadrunner + + +context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py new file mode 100644 index 00000000..3864571d --- /dev/null +++ b/old/scheduling.py @@ -0,0 +1,354 @@ +#!/usr/bin/env python3.3 +"""Example coroutine scheduler, PEP-380-style ('yield from '). + +Requires Python 3.3. + +There are likely micro-optimizations possible here, but that's not the point. + +TODO: +- Docstrings. +- Unittests. + +PATTERNS TO TRY: +- Various synchronization primitives (Lock, RLock, Event, Condition, + Semaphore, BoundedSemaphore, Barrier). +""" + +__author__ = 'Guido van Rossum ' + +# Standard library imports (keep in alphabetic order). +from concurrent.futures import CancelledError, TimeoutError +import logging +import time +import types + +# Local imports (keep in alphabetic order). +import polling + + +context = polling.context + + +class Task: + """Wrapper around a stack of generators. + + This is a bit like a Future, but with a different interface. + + TODO: + - wait for result. + """ + + def __init__(self, gen, name=None, *, timeout=None): + assert isinstance(gen, types.GeneratorType), repr(gen) + self.gen = gen + self.name = name or gen.__name__ + self.timeout = timeout + self.eventloop = context.eventloop + self.canceleer = None + if timeout is not None: + self.canceleer = self.eventloop.call_later(timeout, self.cancel) + self.blocked = False + self.unblocker = None + self.cancelled = False + self.must_cancel = False + self.alive = True + self.result = None + self.exception = None + self.done_callbacks = [] + # Start the task immediately. + self.eventloop.call_soon(self.step) + + def add_done_callback(self, done_callback): + # For better or for worse, the callback will always be called + # with the task as an argument, like concurrent.futures.Future. + # TODO: Call it right away if task is no longer alive. + dcall = polling.DelayedCall(None, done_callback, (self,)) + self.done_callbacks.append(dcall) + self.done_callbacks = [dc for dc in self.done_callbacks + if not dc.cancelled] + return dcall + + def __repr__(self): + parts = [self.name] + is_current = (self is context.current_task) + if self.blocked: + parts.append('blocking' if is_current else 'blocked') + elif self.alive: + parts.append('running' if is_current else 'runnable') + if self.must_cancel: + parts.append('must_cancel') + if self.cancelled: + parts.append('cancelled') + if self.exception is not None: + parts.append('exception=%r' % self.exception) + elif not self.alive: + parts.append('result=%r' % (self.result,)) + if self.timeout is not None: + parts.append('timeout=%.3f' % self.timeout) + return 'Task<' + ', '.join(parts) + '>' + + def cancel(self): + if self.alive: + if not self.must_cancel and not self.cancelled: + self.must_cancel = True + if self.blocked: + self.unblock() + + def step(self): + assert self.alive, self + try: + context.current_task = self + if self.must_cancel: + self.must_cancel = False + self.cancelled = True + self.gen.throw(CancelledError()) + else: + next(self.gen) + except StopIteration as exc: + self.alive = False + self.result = exc.value + except Exception as exc: + self.alive = False + self.exception = exc + logging.debug('Uncaught exception in %s', self, + exc_info=True, stack_info=True) + except BaseException as exc: + self.alive = False + self.exception = exc + raise + else: + if not self.blocked: + self.eventloop.call_soon(self.step) + finally: + context.current_task = None + if not self.alive: + # Cancel timeout callback if set. + if self.canceleer is not None: + self.canceleer.cancel() + # Schedule done_callbacks. + for dcall in self.done_callbacks: + self.eventloop.add_callback(dcall) + + def block(self, unblock_callback=None, *unblock_args): + assert self is context.current_task, self + assert self.alive, self + assert not self.blocked, self + self.blocked = True + self.unblocker = (unblock_callback, unblock_args) + + def unblock_if_alive(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + if self.alive: + self.unblock() + + def unblock(self, unused=None): + # Ignore optional argument so we can be a Future's done_callback. + assert self.alive, self + assert self.blocked, self + self.blocked = False + unblock_callback, unblock_args = self.unblocker + if unblock_callback is not None: + try: + unblock_callback(*unblock_args) + except Exception: + logging.error('Exception in unblocker in task %r', self.name) + raise + finally: + self.unblocker = None + self.eventloop.call_soon(self.step) + + def block_io(self, fd, flag): + assert isinstance(fd, int), repr(fd) + assert flag in ('r', 'w'), repr(flag) + if flag == 'r': + self.block(self.eventloop.remove_reader, fd) + self.eventloop.add_reader(fd, self.unblock) + else: + self.block(self.eventloop.remove_writer, fd) + self.eventloop.add_writer(fd, self.unblock) + + def wait(self): + """COROUTINE: Wait until this task is finished.""" + current_task = context.current_task + assert self is not current_task, (self, current_task) # How confusing! + if not self.alive: + return + current_task.block() + self.add_done_callback(current_task.unblock) + yield + + def __iter__(self): + """COROUTINE: Wait, then return result or raise exception. + + This adds a little magic so you can say + + x = yield from Task(gen()) + + and it is equivalent to + + x = yield from gen() + + but with the option to add a timeout (and only a tad slower). + """ + if self.alive: + yield from self.wait() + assert not self.alive + if self.exception is not None: + raise self.exception + return self.result + + +def run(arg=None): + """Run the event loop until it's out of work. + + If you pass a generator, it will be spawned for you. + You can also pass a task (already started). + Returns the task. + """ + t = None + if arg is not None: + if isinstance(arg, Task): + t = arg + else: + t = Task(arg) + context.eventloop.run() + if t is not None and t.exception is not None: + logging.error('Uncaught exception in startup task: %r', + t.exception) + return t + + +def sleep(secs): + """COROUTINE: Sleep for some time (a float in seconds).""" + current_task = context.current_task + unblocker = context.eventloop.call_later(secs, current_task.unblock) + current_task.block(unblocker.cancel) + yield + + +def block_r(fd): + """COROUTINE: Block until a file descriptor is ready for reading.""" + context.current_task.block_io(fd, 'r') + yield + + +def block_w(fd): + """COROUTINE: Block until a file descriptor is ready for writing.""" + context.current_task.block_io(fd, 'w') + yield + + +def call_in_thread(func, *args, executor=None): + """COROUTINE: Run a function in a thread.""" + task = context.current_task + eventloop = context.eventloop + future = context.threadrunner.submit(func, *args, + executor=executor, + callback=task.unblock_if_alive) + task.block(future.cancel) + yield + assert future.done() + return future.result() + + +def wait_for(count, tasks): + """COROUTINE: Wait for the first N of a set of tasks to complete. + + May return more than N if more than N are immediately ready. + + NOTE: Tasks that were cancelled or raised are also considered ready. + """ + assert tasks + assert all(isinstance(task, Task) for task in tasks) + tasks = set(tasks) + assert 1 <= count <= len(tasks) + current_task = context.current_task + assert all(task is not current_task for task in tasks) + todo = set() + done = set() + dcalls = [] + def wait_for_callback(task): + nonlocal todo, done, current_task, count, dcalls + todo.remove(task) + if len(done) < count: + done.add(task) + if len(done) == count: + for dcall in dcalls: + dcall.cancel() + current_task.unblock() + for task in tasks: + if task.alive: + todo.add(task) + else: + done.add(task) + if len(done) < count: + for task in todo: + dcall = task.add_done_callback(wait_for_callback) + dcalls.append(dcall) + current_task.block() + yield + return done + + +def wait_any(tasks): + """COROUTINE: Wait for the first of a set of tasks to complete.""" + return wait_for(1, tasks) + + +def wait_all(tasks): + """COROUTINE: Wait for all of a set of tasks to complete.""" + return wait_for(len(tasks), tasks) + + +def map_over(gen, *args, timeout=None): + """COROUTINE: map a generator over one or more iterables. + + E.g. map_over(foo, xs, ys) runs + + Task(foo(x, y)) for x, y in zip(xs, ys) + + and returns a list of all results (in that order). However if any + task raises an exception, the remaining tasks are cancelled and + the exception is propagated. + """ + # gen is a generator function. + tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] + return (yield from par_tasks(tasks)) + + +def par(*args): + """COROUTINE: Wait for generators, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + + This differs from par_tasks() in two ways: + - takes *args instead of list of args + - each arg may be a generator or a task + """ + tasks = [] + for arg in args: + if not isinstance(arg, Task): + # TODO: assert arg is a generator or an iterator? + arg = Task(arg) + tasks.append(arg) + return (yield from par_tasks(tasks)) + + +def par_tasks(tasks): + """COROUTINE: Wait for a list of tasks, return a list of results. + + Raises as soon as one of the tasks raises an exception (and then + remaining tasks are cancelled). + """ + todo = set(tasks) + while todo: + ts = yield from wait_any(todo) + for t in ts: + assert not t.alive, t + todo.remove(t) + if t.exception is not None: + for other in todo: + other.cancel() + raise t.exception + return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py new file mode 100644 index 00000000..a5005dc3 --- /dev/null +++ b/old/sockets.py @@ -0,0 +1,348 @@ +"""Socket wrappers to go with scheduling.py. + +Classes: + +- SocketTransport: a transport implementation wrapping a socket. +- SslTransport: a transport implementation wrapping SSL around a socket. +- BufferedReader: a buffer wrapping the read end of a transport. + +Functions (all coroutines): + +- connect(): connect a socket. +- getaddrinfo(): look up an address. +- create_connection(): look up address and return a connected socket for it. +- create_transport(): look up address and return a connected transport. + +TODO: +- Improve transport abstraction. +- Make a nice protocol abstraction. +- Unittests. +- A write() call that isn't a generator (needed so you can substitute it + for sys.stderr, pass it to logging.StreamHandler, etc.). +""" + +__author__ = 'Guido van Rossum ' + +# Stdlib imports. +import errno +import socket +import ssl + +# Local imports. +import scheduling + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class SocketTransport: + """Transport wrapping a socket. + + The socket must already be connected in non-blocking mode. + """ + + def __init__(self, sock): + self.sock = sock + + def recv(self, n): + """COROUTINE: Read up to n bytes, blocking as needed. + + Always returns at least one byte, except if the socket was + closed or disconnected and there's no more data; then it + returns b''. + """ + assert n >= 0, n + while True: + try: + return self.sock.recv(n) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return b'' + else: + raise # Unexpected, propagate. + yield from scheduling.block_r(self.sock.fileno()) + + def send(self, data): + """COROUTINE; Send data to the socket, blocking until all written. + + Return True if all went well, False if socket was disconnected. + """ + while data: + try: + n = self.sock.send(data) + except socket.error as err: + if err.errno in _TRYAGAIN: + pass + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + else: + assert 0 <= n <= len(data), (n, len(data)) + if n == len(data): + break + data = data[n:] + continue + yield from scheduling.block_w(self.sock.fileno()) + + return True + + def close(self): + """Close the socket. (Not a coroutine.)""" + self.sock.close() + + +class SslTransport: + """Transport wrapping a socket in SSL. + + The socket must already be connected at the TCP level in + non-blocking mode. + """ + + def __init__(self, rawsock, sslcontext=None): + self.rawsock = rawsock + self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self.sslsock = self.sslcontext.wrap_socket( + self.rawsock, do_handshake_on_connect=False) + + def do_handshake(self): + """COROUTINE: Finish the SSL handshake.""" + while True: + try: + self.sslsock.do_handshake() + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + else: + break + + def recv(self, n): + """COROUTINE: Read up to n bytes. + + This blocks until at least one byte is read, or until EOF. + """ + while True: + try: + return self.sslsock.recv(n) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + # Can this happen? + return b'' + else: + raise # Unexpected, propagate. + + def send(self, data): + """COROUTINE: Send data to the socket, blocking as needed.""" + while data: + try: + n = self.sslsock.send(data) + except ssl.SSLWantReadError: + yield from scheduling.block_r(self.sslsock.fileno()) + except ssl.SSLWantWriteError: + yield from scheduling.block_w(self.sslsock.fileno()) + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_w(self.sslsock.fileno()) + elif err.errno in _DISCONNECTED: + return False + else: + raise # Unexpected, propagate. + if n == len(data): + break + data = data[n:] + + return True + + def close(self): + """Close the socket. (Not a coroutine.) + + This also closes the raw socket. + """ + self.sslsock.close() + + # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... + + +class BufferedReader: + """A buffered reader wrapping a transport.""" + + def __init__(self, trans, limit=8192): + self.trans = trans + self.limit = limit + self.buffer = b'' + self.eof = False + + def read(self, n): + """COROUTINE: Read up to n bytes, blocking at most once.""" + assert n >= 0, n + if not self.buffer and not self.eof: + yield from self._fillbuffer(max(n, self.limit)) + return self._getfrombuffer(n) + + def readexactly(self, n): + """COUROUTINE: Read exactly n bytes, or until EOF.""" + blocks = [] + count = 0 + while count < n: + block = yield from self.read(n - count) + if not block: + break + blocks.append(block) + count += len(block) + return b''.join(blocks) + + def readline(self): + """COROUTINE: Read up to newline or limit, whichever comes first.""" + end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. + while not end and not self.eof and len(self.buffer) < self.limit: + anchor = len(self.buffer) + yield from self._fillbuffer(self.limit) + end = self.buffer.find(b'\n', anchor) + 1 + if not end: + end = len(self.buffer) + if end > self.limit: + end = self.limit + return self._getfrombuffer(end) + + def _getfrombuffer(self, n): + """Read up to n bytes without blocking (not a coroutine).""" + if n >= len(self.buffer): + result, self.buffer = self.buffer, b'' + else: + result, self.buffer = self.buffer[:n], self.buffer[n:] + return result + + def _fillbuffer(self, n): + """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" + assert not self.eof, '_fillbuffer called at eof' + data = yield from self.trans.recv(n) + if data: + self.buffer += data + else: + self.eof = True + + +def connect(sock, address): + """COROUTINE: Connect a socket to an address.""" + try: + sock.connect(address) + except socket.error as err: + if err.errno != errno.EINPROGRESS: + raise + yield from scheduling.block_w(sock.fileno()) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + raise IOError(err, 'Connection refused') + + +def getaddrinfo(host, port, af=0, socktype=0, proto=0): + """COROUTINE: Look up an address and return a list of infos for it. + + Each info is a tuple (af, socktype, protocol, canonname, address). + """ + infos = yield from scheduling.call_in_thread(socket.getaddrinfo, + host, port, af, + socktype, proto) + return infos + + +def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): + """COROUTINE: Look up address and create a socket connected to it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + yield from connect(sock, address) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return sock + + +def create_transport(host, port, af=0, ssl=None): + """COROUTINE: Look up address and create a transport connected to it.""" + if ssl is None: + ssl = (port == 443) + sock = yield from create_connection(host, port, af) + if ssl: + trans = SslTransport(sock) + yield from trans.do_handshake() + else: + trans = SocketTransport(sock) + return trans + + +class Listener: + """Wrapper for a listening socket.""" + + def __init__(self, sock): + self.sock = sock + + def accept(self): + """COROUTINE: Accept a connection.""" + while True: + try: + conn, addr = self.sock.accept() + except socket.error as err: + if err.errno in _TRYAGAIN: + yield from scheduling.block_r(self.sock.fileno()) + else: + raise # Unexpected, propagate. + else: + conn.setblocking(False) + return conn, addr + + +def create_listener(host, port, af=0, socktype=0, proto=0, + backlog=5, reuse_addr=True): + """COROUTINE: Look up address and create a listener for it.""" + infos = yield from getaddrinfo(host, port, af, socktype, proto) + if not infos: + raise IOError('getaddrinfo() returned an empty list') + exc = None + for af, socktype, proto, cname, address in infos: + sock = None + try: + sock = socket.socket(af, socktype, proto) + sock.setblocking(False) + if reuse_addr: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(backlog) + break + except socket.error as err: + if sock is not None: + sock.close() + if exc is None: + exc = err + else: + raise exc + return Listener(sock) diff --git a/old/transports.py b/old/transports.py new file mode 100644 index 00000000..19095bf4 --- /dev/null +++ b/old/transports.py @@ -0,0 +1,496 @@ +"""Transports and Protocols, actually. + +Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. + +THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. +""" + +# Stdlib imports. +import collections +import errno +import logging +import socket +import ssl +import sys +import time + +# Local imports. +import polling +import scheduling +import sockets + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + +# Errno values indicating the socket isn't ready for I/O just yet. +_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) + + +class Transport: + """ABC representing a transport. + + There may be many implementations. The user never instantiates + this directly; they call some utility function, passing it a + protocol, and the utility function will call the protocol's + connection_made() method with a transport (or it will call + connection_lost() with an exception if it fails to create the + desired transport). + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + def write(self, data): + """Write some data (bytes) to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data (bytes) to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data will + be received. When all buffered data is flushed, the protocol's + connection_lost() method is called with None as its argument. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method is called with None as + its argument. + """ + raise NotImplementedError + + def half_close(self): + """Closes the write end after flushing buffered data. + + Data may still be received. + + TODO: What's the use case for this? How to implement it? + Should it call shutdown(SHUT_WR) after all the data is flushed? + Is there no use case for closing the other half first? + """ + raise NotImplementedError + + def pause(self): + """Pause the receiving end. + + No data will be received until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Cancels a pause() call, resumes receiving data. + """ + raise NotImplementedError + + +class Protocol: + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing. + + When the user wants to requests a transport, they pass a protocol + instance to a utility function. + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_list() will be called exactly once + with either an exception object or None as an argument. + + If the utility function does not succeed in creating a transport, + it will call connection_lost() with an exception object. + + State machine of calls: + + start -> [CM -> DR*] -> CL -> end + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the connection. + To send data, call its write() or writelines() method. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + + TODO: Should we allow it to be a bytesarray or some other + memory buffer? + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + Also called when we fail to make a connection at all (in that + case connection_made() will not be called). + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +# TODO: The rest is platform specific and should move elsewhere. + +class UnixSocketTransport(Transport): + + def __init__(self, eventloop, protocol, sock): + self._eventloop = eventloop + self._protocol = protocol + self._sock = sock + self._buffer = collections.deque() # For write(). + self._write_closed = False + + def _on_readable(self): + try: + data = self._sock.recv(8192) + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + else: + if not data: + self._eventloop.remove_reader(self._sock.fileno()) + self._sock.close() + self._protocol.connection_lost(None) + else: + self._protocol.data_received(data) # XXX call_soon()? + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + # Silly, but it happens. + return + if self._buffer: + # We've already registered a callback, just buffer the data. + self._buffer.append(data) + # Consider pausing if the total length of the buffer is + # truly huge. + return + + # TODO: Refactor so there's more sharing between this and + # _on_writable(). + + # There's no callback registered yet. It's quite possible + # that the kernel has buffer space for our data, so try to + # write now. Since the socket is non-blocking it will + # give us an error in _TRYAGAIN if it doesn't have enough + # space for even one more byte; it will return the number + # of bytes written if it can write at least one byte. + try: + n = self._sock.send(data) + except socket.error as exc: + # An error. + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + # The kernel doesn't have room for more data right now. + n = 0 + else: + # Wrote at least one byte. + if n == len(data): + # Wrote it all. Done! + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + return + # Throw away the data that was already written. + # TODO: Do this without copying the data? + data = data[n:] + self._buffer.append(data) + self._eventloop.add_writer(self._sock.fileno(), self._on_writable) + + def _on_writable(self): + while self._buffer: + data = self._buffer[0] + # TODO: Join small amounts of data? + try: + n = self._sock.send(data) + except socket.error as exc: + # Error handling is the same as in write(). + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + if n < len(data): + self._buffer[0] = data[n:] + return + self._buffer.popleft() + self._eventloop.remove_writer(self._sock.fileno()) + if self._write_closed: + self._sock.shutdown(socket.SHUT_WR) + + def abort(self): + self._bad_error(None) + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def half_close(self): + self._write_closed = True + + +class UnixSslTransport(Transport): + + # TODO: Refactor Socket and Ssl transport to share some code. + # (E.g. buffering.) + + # TODO: Consider using coroutines instead of callbacks, it seems + # much easier that way. + + def __init__(self, eventloop, protocol, rawsock, sslcontext=None): + self._eventloop = eventloop + self._protocol = protocol + self._rawsock = rawsock + self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslsock = self._sslcontext.wrap_socket( + self._rawsock, do_handshake_on_connect=False) + + self._buffer = collections.deque() # For write(). + self._write_closed = False + + # Try the handshake now. Likely it will raise EAGAIN, then it + # will take care of registering the appropriate callback. + self._on_handshake() + + def _bad_error(self, exc): + # A serious error. Close the socket etc. + fd = self._sslsock.fileno() + # TODO: Record whether we have a writer and/or reader registered. + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + self._sslsock.close() + self._protocol.connection_lost(exc) # XXX call_soon()? + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._eventloop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._eventloop.add_writable(fd, self._on_handshake) + return + # TODO: What if it raises another error? + try: + self._eventloop.remove_reader(fd) + except Exception: + pass + try: + self._eventloop.remove_writer(fd) + except Exception: + pass + self._protocol.connection_made(self) + self._eventloop.add_reader(fd, self._on_ready) + self._eventloop.add_writer(fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._eventloop.remove_reader(fd) + self._eventloop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = self._buffer[0] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except socket.error as exc: + if exc.errno not in _TRYAGAIN: + self._bad_error(exc) + return + else: + if n == len(data): + self._buffer.popleft() + # Could try again, but let's just have the next callback do it. + else: + self._buffer[0] = data[n:] + + def write(self, data): + assert isinstance(data, bytes) + assert not self._write_closed + if not data: + return + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def half_close(self): + self._write_closed = True + # Just set the flag. Calling shutdown() on the ssl socket + # breaks something, causing recv() to return binary data. + + +def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, + use_ssl=None): + # TODO: Pass in a protocol factory, not a protocol. + # What should be the exact sequence of events? + # - socket + # - transport + # - protocol + # - tell transport about protocol + # - tell protocol about transport + # Or should the latter two be reversed? Does it matter? + if port is None: + port = 443 if use_ssl else 80 + if use_ssl is None: + use_ssl = (port == 443) + if not socktype: + socktype = socket.SOCK_STREAM + eventloop = polling.context.eventloop + + def on_socket_connected(task): + assert not task.alive + if task.exception is not None: + # TODO: Call some callback. + raise task.exception + sock = task.result + assert sock is not None + logging.debug('on_socket_connected') + if use_ssl: + # You can pass an ssl.SSLContext object as use_ssl, + # or a bool. + if isinstance(use_ssl, bool): + sslcontext = None + else: + sslcontext = use_ssl + transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) + else: + transport = UnixSocketTransport(eventloop, protocol, sock) + # TODO: Should the ransport make the following calls? + protocol.connection_made(transport) # XXX call_soon()? + # Don't do this before connection_made() is called. + eventloop.add_reader(sock.fileno(), transport._on_readable) + + coro = sockets.create_connection(host, port, af, socktype, proto) + task = scheduling.Task(coro) + task.add_done_callback(on_socket_connected) + + +def main(): # Testing... + + # Initialize logging. + if '-d' in sys.argv: + level = logging.DEBUG + elif '-v' in sys.argv: + level = logging.INFO + elif '-q' in sys.argv: + level = logging.ERROR + else: + level = logging.WARN + logging.basicConfig(level=level) + + host = 'xkcd.com' + if sys.argv[1:] and '.' in sys.argv[-1]: + host = sys.argv[-1] + + t0 = time.time() + + class TestProtocol(Protocol): + def connection_made(self, transport): + logging.info('Connection made at %.3f secs', time.time() - t0) + self.transport = transport + self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + + host.encode('ascii') + + b'\r\n\r\n') + self.transport.half_close() + def data_received(self, data): + logging.info('Received %d bytes at t=%.3f', + len(data), time.time() - t0) + logging.debug('Received %r', data) + def connection_lost(self, exc): + logging.debug('Connection lost: %r', exc) + self.t1 = time.time() + logging.info('Total time %.3f secs', self.t1 - t0) + + tp = TestProtocol() + logging.debug('tp = %r', tp) + make_connection(tp, host, use_ssl=('-S' in sys.argv)) + logging.info('Running...') + polling.context.eventloop.run() + logging.info('Done.') + + +if __name__ == '__main__': + main() diff --git a/old/xkcd.py b/old/xkcd.py new file mode 100755 index 00000000..474009d0 --- /dev/null +++ b/old/xkcd.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3.3 +"""Minimal synchronous SSL demo, connecting to xkcd.com.""" + +import socket, ssl + +s = socket.socket() +s.connect(('xkcd.com', 443)) +ss = ssl.wrap_socket(s) + +ss.send(b'GET / HTTP/1.0\r\n\r\n') + +while True: + data = ss.recv(1000000) + print(data) + if not data: + break + +ss.close() diff --git a/old/yyftime.py b/old/yyftime.py new file mode 100644 index 00000000..f55234b9 --- /dev/null +++ b/old/yyftime.py @@ -0,0 +1,75 @@ +"""Compare timing of yield-from vs. yield calls.""" + +import gc +import time + +def coroutine(n): + if n <= 0: + return 1 + l = yield from coroutine(n-1) + r = yield from coroutine(n-1) + return l + 1 + r + +def run_coro(depth): + t0 = time.time() + try: + g = coroutine(depth) + while True: + next(g) + except StopIteration as err: + k = err.value + t1 = time.time() + print('coro', depth, k, round(t1-t0, 6)) + return t1-t0 + +class Future: + + def __init__(self, g): + self.g = g + + def wait(self): + value = None + try: + while True: + f = self.g.send(value) + f.wait() + value = f.value + except StopIteration as err: + self.value = err.value + + + +def task(func): # Decorator + def wrapper(*args): + g = func(*args) + f = Future(g) + return f + return wrapper + +@task +def oldstyle(n): + if n <= 0: + return 1 + l = yield oldstyle(n-1) + r = yield oldstyle(n-1) + return l + 1 + r + +def run_olds(depth): + t0 = time.time() + f = oldstyle(depth) + f.wait() + k = f.value + t1 = time.time() + print('olds', depth, k, round(t1-t0, 6)) + return t1-t0 + +def main(): + gc.disable() + for depth in range(16): + tc = run_coro(depth) + to = run_olds(depth) + if tc: + print('ratio', round(to/tc, 2)) + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..0ec5ba31 --- /dev/null +++ b/runtests.py @@ -0,0 +1,198 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: {}\n".format(sdir)) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/srv.py b/srv.py new file mode 100755 index 00000000..8fd6ccf6 --- /dev/null +++ b/srv.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, request_info, message): + print('method = {!r}; path = {!r}; version = {!r}'.format( + request_info.method, request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = False + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) + x = loop.run_until_complete(f) + print('serving on', x.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + main() diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..10f7c480 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,284 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.Timer(when, lambda: False, ()) + h2 = events.Timer(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future() + self.assertIs(self.event_loop.wrap_future(f), f) + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.Timer) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Timer(10, cb, ())) + + def test_run_once_in_executor_canceled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + + def test_run_once_in_executor(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + self.suppress_log_errors() + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'example.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..e8855548 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1431 @@ +"""Tests for events.py.""" + +import concurrent.futures +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class MyDatagramProto(protocols.DatagramProtocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + + +class MyReadPipeProto(protocols.Protocol): + + def __init__(self): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + + +class MyWritePipeProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.transport = None + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + def test_run(self): + self.event_loop.run() # Returns immediately. + + def test_run_nesting(self): + self.suppress_log_errors() + + @tasks.coroutine + def coro(): + self.assertTrue(self.event_loop.is_running()) + self.event_loop.run_until_complete(tasks.sleep(0.1)) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) + + def test_run_once_nesting(self): + self.suppress_log_errors() + + @tasks.coroutine + def coro(): + tasks.sleep(0.1) + self.event_loop.run_once() + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + self.event_loop.run_once(0) # windows iocp + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.12) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + + def run(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.01) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.run() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.run() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + self.suppress_log_errors() + + with test_utils.run_test_server(self.event_loop) as httpd: + sock = socket.socket() + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, httpd.address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.event_loop) as httpd: + f = self.event_loop.create_connection(MyProto, *httpd.address) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.event_loop) as httpd: + sock = None + infos = self.event_loop.run_until_complete( + self.event_loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.event_loop.create_connection(MyProto, sock=sock) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.event_loop, use_ssl=True) as httpd: + f = self.event_loop.create_connection( + MyProto, *httpd.address, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run() + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + self.suppress_log_errors() + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + self.suppress_log_errors() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + self.suppress_log_errors() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.event_loop.run_once(0.001) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.event_loop.run_once(0.001) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.event_loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.event_loop.create_connection( + ClientMyProto, host, port, ssl=True) + client, pr = self.event_loop.run_until_complete(f_c) + + client.write(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f) + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() # This is quite mysterious, but necessary. + self.event_loop.run_once() + sock.close() + client.close() + + def test_stop_serving(self): + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f) + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.event_loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + + def test_start_serving_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + self.suppress_log_errors() + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + self.suppress_log_errors() + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.setsockopt.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.suppress_log_errors() + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + self.suppress_log_errors() + + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + self.suppress_log_errors() + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + self.suppress_log_errors() + + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + def test_accept_connection_exception(self): + self.suppress_log_errors() + + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_once() + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto() + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_once() + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_stop_serving(self): + raise unittest.SkipTest( + "IocpEventLoop does not support stop_serving()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, + test_utils.LogTrackingTestCase): + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h.run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.Timer(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, events.Timer, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.Timer(when, callback, ()) + h2 = events.Timer(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run) + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.stop_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.EventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..5569cca1 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,222 @@ +"""Tests for futures.py.""" + +import unittest + +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + self.assertEqual(repr(f_exception), 'Future') + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + def coro(): + fut = futures.Future() + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..120f78b8 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,395 @@ +"""Http client functional tests.""" + +import io +import os.path +import http.cookies + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth))) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2))) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'))) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'})) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate')) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data])) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'))) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'))) + self.assertEqual(r.status, 200) + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', content) + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'))) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), timeout=0.1)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..973a6cb0 --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpProtocol, HttpRequest, HttpResponse + + +class HttpProtocolTests(unittest.TestCase): + + def test_protocol(self): + transport = unittest.mock.Mock() + + p = HttpProtocol() + p.connection_made(transport) + self.assertIs(p.transport, transport) + self.assertIsInstance(p.stream, tulip.http.HttpStreamReader) + + p.data_received(b'data') + self.assertEqual(4, p.stream.byte_count) + + p.eof_received() + self.assertTrue(p.stream.eof) + + p.connection_lost(None) + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.http.HttpStreamReader(self.transport) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response._transport = self.transport + self.response.close() + self.assertIsNone(self.response._transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.http.HttpStreamReader(self.transport) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual(0, req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual(0, req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..c806337c --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,982 @@ +"""Tests for http/protocol.py""" + +import http.client +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpStreamReaderTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = protocol.HttpStreamReader() + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_request_line(self): + self.stream.feed_data(b'get /path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '/path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_two_slashes(self): + self.stream.feed_data(b'get //path HTTP/1.1\r\n') + self.assertEqual( + ('GET', '//path', (1, 1)), + self.loop.run_until_complete(self.stream.read_request_line())) + + def test_request_line_non_ascii(self): + self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') + + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_request_line()) + + self.assertEqual( + b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) + + def test_request_line_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_method(self): + self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_request_line_bad_version(self): + self.stream.feed_data(b'GET //get HT/11\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_request_line()) + + def test_response_status_bad_status_line(self): + self.stream.feed_data(b'\r\n') + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_line_eof(self): + self.stream.feed_eof() + self.assertRaises( + http.client.BadStatusLine, + self.loop.run_until_complete, + self.stream.read_response_status()) + + def test_response_status_bad_status_non_ascii(self): + self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) + + def test_response_status_bad_version(self): + self.stream.feed_data(b'HT/11 200 Ok\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete(self.stream.read_response_status()) + + self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) + + def test_response_status_no_reason(self): + self.stream.feed_data(b'HTTP/1.1 200\r\n') + + v, s, r = self.loop.run_until_complete( + self.stream.read_response_status()) + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_response_status_bad(self): + self.stream.feed_data(b'HTT/1\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTT/1', str(cm.exception)) + + def test_response_status_bad_code_under_100(self): + self.stream.feed_data(b'HTTP/1.1 99 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_response_status_bad_code_above_999(self): + self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_response_status_bad_code_not_int(self): + self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') + with self.assertRaises(http.client.BadStatusLine) as cm: + self.loop.run_until_complete( + self.stream.read_response_status()) + + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) + + def test_read_headers(self): + self.stream.feed_data(b'test: line\r\n' + b' continue\r\n' + b'test2: data\r\n' + b'\r\n') + + headers = self.loop.run_until_complete(self.stream.read_headers()) + self.assertEqual(headers, + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + + def test_read_headers_size(self): + self.stream.feed_data(b'test: line\r\n') + self.stream.feed_data(b' continue\r\n') + self.stream.feed_data(b'test2: data\r\n') + self.stream.feed_data(b'\r\n') + + self.stream.MAX_HEADERS = 5 + self.assertRaises( + http.client.LineTooLong, + self.loop.run_until_complete, + self.stream.read_headers()) + + def test_read_headers_invalid_header(self): + self.stream.feed_data(b'test line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header b'test line'", str(cm.exception)) + + def test_read_headers_invalid_name(self): + self.stream.feed_data(b'test[]: line\r\n') + + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) + + def test_read_headers_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line data data\r\ndata\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_headers_continuation_headers_size(self): + self.stream.MAX_HEADERFIELD_SIZE = 5 + self.stream.feed_data(b'test: line\r\n test\r\n') + + with self.assertRaises(http.client.LineTooLong) as cm: + self.loop.run_until_complete(self.stream.read_headers()) + + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_read_message_should_close(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: close\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_http11(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 1))) + self.assertFalse(msg.should_close) + + def test_read_message_should_close_http10(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(version=(1, 0))) + self.assertTrue(msg.should_close) + + def test_read_message_should_close_keep_alive(self): + self.stream.feed_data( + b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + self.assertFalse(msg.should_close) + + def test_read_message_content_length_broken(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length_wrong(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: -1\r\n\r\n') + + self.assertRaises( + http.client.HTTPException, + self.loop.run_until_complete, + self.stream.read_message()) + + def test_read_message_content_length(self): + self.stream.feed_data( + b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12', payload) + + def test_read_message_content_length_no_val(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n12') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=False)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', payload) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_read_message_deflate(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Length: {}\r\n' + 'Content-Encoding: deflate\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete(self.stream.read_message()) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', payload) + + def test_read_message_deflate_disabled(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: deflate\r\n' + 'Content-Length: {}\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_deflate_unknown(self): + self.stream.feed_data( + ('Host: example.com\r\nContent-Encoding: compress\r\n' + 'Content-Length: {}\r\n\r\n'.format( + len(self._COMPRESSED)).encode())) + self.stream.feed_data(self._COMPRESSED) + + msg = self.loop.run_until_complete( + self.stream.read_message(compression=False)) + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(self._COMPRESSED, payload) + + def test_read_message_websocket(self): + self.stream.feed_data( + b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'12345678', payload) + + def test_read_message_chunked(self): + self.stream.feed_data( + b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_readall_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'line') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + payload = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'dataline', payload) + + def test_read_message_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + self.stream.feed_data(b'data') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'datadata', data) + + def test_read_message_payload_eof(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_eof() + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, msg.payload.read()) + + def test_read_message_length_payload_zero(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 0\r\n\r\n') + self.stream.feed_data(b'data') + + msg = self.loop.run_until_complete(self.stream.read_message()) + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'', data) + + def test_read_message_length_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 8\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_eof_payload(self): + self.stream.feed_data(b'Host: example.com\r\n\r\n') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'data') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_length_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + self.assertIsInstance(msg.payload, tulip.StreamReader) + + data = self.loop.run_until_complete(msg.payload.read()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_read_message_length_payload_extra(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Length: 4\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'da') + self.stream.feed_data(b't') + self.stream.feed_data(b'ali') + self.stream.feed_data(b'ne') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + self.assertEqual(b'line', b''.join(self.stream.buffer)) + + def test_parse_length_payload_eof_exc(self): + parser = self.stream._parse_length_payload(4) + next(parser) + + stream = tulip.StreamReader() + parser.send(stream) + self.stream._parser = parser + self.stream.feed_data(b'da') + + @tulip.coroutine + def eof(): + self.stream.feed_eof() + + t1 = tulip.Task(stream.read()) + t2 = tulip.Task(eof()) + + self.loop.run_until_complete(tulip.wait([t1, t2])) + self.assertRaises(http.client.IncompleteRead, t1.result) + self.assertIsNone(self.stream._parser) + + def test_read_message_deflate_payload(self): + comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + + data = b''.join([comp.compress(b'data'), comp.flush()]) + + self.stream.feed_data( + b'Host: example.com\r\n' + b'Content-Encoding: deflate\r\n' + + ('Content-Length: {}\r\n\r\n'.format(len(data)).encode())) + + msg = self.loop.run_until_complete( + self.stream.read_message(readall=True)) + + @tulip.coroutine + def coro(): + self.stream.feed_data(data) + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'data', data) + + def test_read_message_chunked_payload(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data( + b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_chunks(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'4\r\ndata\r') + self.stream.feed_data(b'\n4') + self.stream.feed_data(b'\r') + self.stream.feed_data(b'\n') + self.stream.feed_data(b'line\r\n0\r\n') + self.stream.feed_data(b'test\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_incomplete(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'4\r\ndata\r\n') + self.stream.feed_eof() + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_read_message_chunked_payload_extension(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data( + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') + return (yield from msg.payload.read()) + + data = self.loop.run_until_complete(coro()) + self.assertEqual(b'dataline', data) + + def test_read_message_chunked_payload_size_error(self): + self.stream.feed_data( + b'Host: example.com\r\n' + b'Transfer-Encoding: chunked\r\n\r\n') + + msg = self.loop.run_until_complete(self.stream.read_message()) + + @tulip.coroutine + def coro(): + self.stream.feed_data(b'blah\r\n') + return (yield from msg.payload.read()) + + self.assertRaises( + http.client.IncompleteRead, + self.loop.run_until_complete, coro()) + + def test_deflate_stream_set_exception(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.set_exception(exc) + self.assertIs(exc, stream.exception()) + + def test_deflate_stream_feed_data(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.return_value = b'line' + + dstream.feed_data(b'data') + self.assertEqual([b'line'], list(stream.buffer)) + + def test_deflate_stream_feed_data_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + exc = ValueError() + dstream.zlib = unittest.mock.Mock() + dstream.zlib.decompress.side_effect = exc + + dstream.feed_data(b'data') + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + def test_deflate_stream_feed_eof(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + + dstream.feed_eof() + self.assertEqual([b'line'], list(stream.buffer)) + self.assertTrue(stream.eof) + + def test_deflate_stream_feed_eof_err(self): + stream = tulip.StreamReader() + dstream = protocol.DeflateStream(stream, 'deflate') + + dstream.zlib = unittest.mock.Mock() + dstream.zlib.flush.return_value = b'line' + dstream.zlib.eof = False + + dstream.feed_eof() + self.assertIsInstance(stream.exception(), http.client.IncompleteRead) + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], msg.headers) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], msg.headers) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], msg.headers) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], msg.headers) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + + headers = [r for r, _ in msg._default_headers()] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg.force_chunked() + + headers = [r for r, _ in msg._default_headers()] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + + headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..2ab41840 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,240 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip.test_utils import LogTrackingTestCase + + +class HttpServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', b''.join(srv.stream.buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream.eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv._closing) + + srv.close() + self.assertTrue(srv._closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + self.assertIsNone(srv._request_handle) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + called = False + + @tulip.coroutine + def coro(rline, message): + nonlocal called + called = True + srv.eof_received() + + srv.handle_request = coro + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..2fc0fee8 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,242 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol +from tulip.test_utils import LogTrackingTestCase + + +class HttpWsgiServerProtocolTests(LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.suppress_log_errors() + + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.payload = b'data' + self.info = protocol.RequestLine('GET', '/path', (1, 0)) + self.headers = [] + self.message = protocol.RawHttpMessage( + self.headers, b'data', True, 'deflate') + + def tearDown(self): + self.loop.close() + super().tearDown() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.info, self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.info = protocol.RequestLine('GET', '/path', (1, 1)) + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.info, self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future() + f1.set_result(b'data') + fut = tulip.Future() + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader() + stream.feed_data(b'data') + stream.feed_eof() + self.message = protocol.RawHttpMessage( + self.headers, stream, True, 'deflate') + self.info = protocol.RequestLine('GET', '/path', (1, 1)) + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.info, self.message)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..5f1c180a --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,745 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.01, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.01)) + tasks.Task(lock.acquire()) + + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.01, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + +class ConditionTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + + tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.1)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_for_unacquired(self): + self.suppress_log_errors() + + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + + tasks.Task(c1(result)) + tasks.Task(c2(result)) + tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.01, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..7a92ad08 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,337 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock(socket.socket) + self.protocol = unittest.mock.Mock(tulip.Protocol) + + def test_ctor(self): + fut = tulip.Future() + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol, fut) + self.event_loop.call_soon.mock_calls[0].assert_called_with( + tr._loop_reading) + self.event_loop.call_soon.mock_calls[1].assert_called_with( + self.protocol.connection_made, tr) + self.event_loop.call_soon.mock_calls[2].assert_called_with( + fut.set_result, None) + + def test_loop_reading(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._loop_reading() + self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future() + res.set_result(b'data') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future() + res.set_result(b'') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.event_loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + + def test_loop_reading_aborted(self): + err = self.event_loop._proactor.recv.side_effect = ( + ConnectionAbortedError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.event_loop._proactor.recv.side_effect = ( + ConnectionAbortedError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_exception(self): + err = self.event_loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.event_loop._proactor.send.assert_called_with(self.sock, b'data') + self.event_loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.event_loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + self.assertEqual(tr._conn_lost, 1) + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future() + fut.set_result(b'data') + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_abort(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr.abort() + tr._fatal_error.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + + tr._write_fut.cancel.assert_called_with() + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertTrue(tr._closing) + + def test_close_2(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + self.event_loop.reset_mock() + tr.close() + + self.assertFalse(self.event_loop.call_soon.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._fatal_error(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error_2(self, m_logging): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._fatal_error(None) + + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.event_loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + event_loop = BaseProactorEventLoop(self.proactor) + self.assertIs(event_loop._ssock, ssock) + self.assertIs(event_loop._csock, csock) + self.assertEqual(event_loop._internal_fds, 1) + call_soon.assert_called_with(event_loop._loop_self_reading) + + def test_close_self_pipe(self): + self.event_loop._close_self_pipe() + self.assertEqual(self.event_loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.event_loop._ssock) + self.assertIsNone(self.event_loop._csock) + + def test_close(self): + self.event_loop._close_self_pipe = unittest.mock.Mock() + self.event_loop.close() + self.assertTrue(self.event_loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.event_loop._proactor) + + self.event_loop._close_self_pipe.reset_mock() + self.event_loop.close() + self.assertFalse(self.event_loop._close_self_pipe.called) + + def test_sock_recv(self): + self.event_loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.event_loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.event_loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.event_loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.event_loop._make_socket_transport( + self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.event_loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.event_loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.event_loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.event_loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.event_loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.event_loop._loop_self_reading) + self.assertTrue(self.event_loop.close.called) + + def test_write_to_self(self): + self.event_loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.event_loop._process_events([]) + + def test_start_serving(self): + pf = unittest.mock.Mock() + call_soon = self.event_loop.call_soon = unittest.mock.Mock() + + self.event_loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_transport = self.event_loop._make_socket_transport = ( + unittest.mock.Mock()) + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_transport.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..8c1c0afb --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,370 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_key_from_fd(self): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..f7e2992b --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,300 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(test_utils.LogTrackingTestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + super().tearDown() + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + self.suppress_log_errors() + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + self.suppress_log_errors() + + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..09aaed52 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,54 @@ +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..fa22d62c --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,676 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.LogTrackingTestCase): + + def setUp(self): + super().setUp() + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + super().tearDown() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_func(self): + @tasks.task + def notmuch(): + return 'ko' + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_fut(self): + fut = futures.Future() + fut.set_result('ko') + + @tasks.task + def notmuch(): + return fut + t = notmuch() + self.event_loop.run() + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + return 42 + + @tasks.task + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + self.suppress_log_errors() + + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + class Cancelled(Exception): + pass + + @tasks.coroutine + def coro2(): + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError: + raise Cancelled() + + self.assertRaises( + Cancelled, self.event_loop.run_until_complete, coro2()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.sleep(10.0) + b = tasks.sleep(0.1) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + + def test_wait_first_exception(self): + self.suppress_log_errors() + + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + self.suppress_log_errors() + a = tasks.sleep(0.1) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.sleep(0.1) + b = tasks.sleep(0.15) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + self.suppress_log_errors() + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run() + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + def test_task_cancel_waiter_future(self): + fut = futures.Future() + + @tasks.task + def coro(): + try: + yield from fut + except futures.CancelledError: + pass + + task = coro() + self.event_loop.run_once() + self.assertIs(task._fut_waiter, fut) + + task.cancel() + self.assertRaises( + futures.CancelledError, self.event_loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_in_completed_task(self, m_logging): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + @unittest.mock.patch('tulip.tasks.tulip_log') + def test_step_result(self, m_logging): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.event_loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + self.suppress_log_warnings() + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + return (yield c_fut) + + task = tasks.Task(notmuch()) + self.event_loop.run_once() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + self.suppress_log_errors() + + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + self.suppress_log_errors() + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + self.suppress_log_errors() + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.event_loop.run_until_complete(task) + + self.assertTrue(fut.done()) + self.assertIs(fut.exception(), cm.exception) + + def test_yield_vs_yield_from_generator(self): + self.suppress_log_errors() + fut = futures.Future() + + @tasks.coroutine + def coro(): + yield from fut + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..4b24b50b --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..d7af7ecc --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,573 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.eof_received.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..381fb227 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..d9e2316b --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,556 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def run(self): + """Run the event loop until nothing left to do or stop() called. + + This keeps going as long as there are either readable and + writable file descriptors, or scheduled callbacks (of either + variety). + + TODO: Give this a timeout too? + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + while (self._ready or + self._scheduled or + self._selector.registered_count() > 1): + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_forever(self): + """Run until stop() is called. + + This only makes sense over run() if you have another thread + scheduling callbacks using call_soon_threadsafe(). + """ + handle = self.call_repeatedly(24*3600, lambda: None) + try: + self.run() + finally: + handle.cancel() + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.Timer(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: {!r}'.format(interval) + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.Timer(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.Timer) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=False, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.setblocking(False) + + if ssl: + import ssl as sslmod + sslcontext = sslmod.SSLContext(sslmod.PROTOCOL_SSLv23) + sock = sslcontext.wrap_socket(sock, server_side=False, + do_handshake_on_connect=False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for (family, proto), (local_address, remote_address) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # TODO: Or create_server()? + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, backlog=100, sock=None, + ssl=False): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + "host, port and sock can not be specified at the same time") + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + # TODO: Maybe we want to bind every address in the list + # instead of the first one that works? + exceptions = [] + for family, type, proto, cname, address in infos: + sock = socket.socket(family=family, type=type, proto=proto) + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + except socket.error as exc: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + elif sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sock + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.Timer): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future() + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..68cd7211 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,360 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import sys +import threading + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class Timer(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'Timer({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, Timer): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + + def run(self): + """Run the event loop. Block until there is nothing left to do.""" + raise NotImplementedError + + def run_forever(self): + """Run the event loop. Block until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): # NEW! + """Run one complete cycle of the event loop.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): # NEW! + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): # NEW! + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + # Methods returning Handles for scheduling callbacks. + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): # NEW! + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + # Methods returning Futures for interacting with threads. + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class EventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + assert policy is None or isinstance(policy, EventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..39137aa6 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,255 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + + _blocking = False # proper use of future (yield vs yield from) + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..b2a0a26d --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,14 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..0fff6e86 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,533 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +from tulip.http import protocol + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + # connection timeout + try: + resp = yield from tulip.Task(start(req, loop), timeout=timeout) + except tulip.CancelledError: + raise tulip.TimeoutError from None + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + HttpProtocol, req.host, req.port, ssl=req.ssl) + try: + resp = req.send(transport) + yield from resp.start(p.stream, transport) + except: + transport.close() + raise + + return resp + + +class HttpProtocol(tulip.Protocol): + + stream = None + transport = None + + def connection_made(self, transport): + self.transport = transport + self.stream = protocol.HttpStreamReader() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + pass + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + for name, value in cookies.items(): + if isinstance(value, http.cookies.Morsel): + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + self._params = (chunked, compress, files, data, encoding) + + def send(self, transport): + chunked, compress, files, data, encoding = self._params + + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + request.add_compression_filter(enc) + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + request.add_compression_filter(compress) + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = len(self.body) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['Transfer-encoding'] = 'chunked' + + chunk_size = chunked if type(chunked) is int else 8196 + request.add_chunking_filter(chunk_size) + else: + if 'chunked' in te: + request.add_chunking_filter(8196) + else: + chunked = False + self.headers['content-length'] = len(self.body) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + content = None # payload stream + + _content = None + _transport = None + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self._transport = transport + + # read status + self.version, self.status, self.reason = ( + yield from stream.read_response_status()) + + # does the body have a fixed length? (of zero) + length = None + if (self.status == http.client.NO_CONTENT or + self.status == http.client.NOT_MODIFIED or + 100 <= self.status < 200 or self.method == "HEAD"): + length = 0 + + # http message + message = yield from stream.read_message(length=length) + + # headers + for hdr, val in message.headers: + self.add_header(hdr, val) + + # payload + self.content = message.payload + + return self + + def close(self): + if self._transport is not None: + self._transport.close() + self._transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + self._content = yield from self.content.read() + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..24032337 --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..378feff1 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,879 @@ +"""Http related helper utils.""" + +__all__ = ['HttpStreamReader', + 'HttpMessage', 'Request', 'Response', + 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import time +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from . import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") +CONTINUATION = (b' ', b'\t') +RESPONSES = http.server.BaseHTTPRequestHandler.responses + +RequestLine = collections.namedtuple( + 'RequestLine', ['method', 'uri', 'version']) + + +ResponseStatus = collections.namedtuple( + 'ResponseStatus', ['version', 'code', 'reason']) + + +RawHttpMessage = collections.namedtuple( + 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) + + +class HttpStreamReader(tulip.StreamReader): + + MAX_HEADERS = 32768 + MAX_HEADERFIELD_SIZE = 8190 + + # if _parser is set, feed_data and feed_eof sends data into + # _parser instead of self. is it being used as stream redirection for + # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload + _parser = None + + def feed_data(self, data): + """_parser is a generator, if _parser is set, feed_data sends + incoming data into the generator untile generator stops.""" + if self._parser: + try: + self._parser.send(data) + except StopIteration as exc: + self._parser = None + if exc.value: + self.feed_data(exc.value) + else: + super().feed_data(data) + + def feed_eof(self): + """_parser is a generator, if _parser is set feed_eof throws + StreamEofException into this generator.""" + if self._parser: + try: + self._parser.throw(StreamEofException()) + except StopIteration: + self._parser = None + + super().feed_eof() + + @tulip.coroutine + def read_request_line(self): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (method, uri, version) + + Example: + + GET /path HTTP/1.1 + + >> yield from reader.read_request_line() + ('GET', '/path', (1, 1)) + + """ + bline = yield from self.readline() + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + method, uri, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + return RequestLine(method, uri, version) + + @tulip.coroutine + def read_response_status(self): + """Read response status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns three values (version, status_code, reason) + + Example: + + HTTP/1.1 200 Ok + + >> yield from reader.read_response_status() + ((1, 1), 200, 'Ok') + + """ + bline = yield from self.readline() + if not bline: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(bline) + + try: + line = bline.decode('ascii').rstrip() + except UnicodeDecodeError: + raise errors.BadStatusLine(bline) from None + + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + return ResponseStatus(version, status, reason.strip()) + + @tulip.coroutine + def read_headers(self): + """Read and parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + size = 0 + headers = [] + + line = yield from self.readline() + + while line not in (b'\r\n', b'\n'): + header_length = len(line) + + # Parse initial header name : value pair. + sep_pos = line.find(b':') + if sep_pos < 0: + raise ValueError('Invalid header {}'.format(line.strip())) + + name, value = line[:sep_pos], line[sep_pos+1:] + name = name.rstrip(b' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name {}'.format(name)) + + name = name.strip().decode('ascii', 'surrogateescape') + value = [value.lstrip()] + + # next line + line = yield from self.readline() + + # consume continuation lines + continuation = line.startswith(CONTINUATION) + + if continuation: + while continuation: + header_length += len(line) + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + line = yield from self.readline() + continuation = line.startswith(CONTINUATION) + else: + if header_length > self.MAX_HEADERFIELD_SIZE: + raise errors.LineTooLong( + 'limit request headers fields size') + + # total headers size + size += header_length + if size >= self.MAX_HEADERS: + raise errors.LineTooLong('limit request headers fields') + + headers.append( + (name, + b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) + + return headers + + def _parse_chunked_payload(self): + """Chunked transfer encoding parser.""" + stream = yield + + try: + data = bytearray() + + while True: + # read line + if b'\n' not in data: + data.extend((yield)) + continue + + line, data = data.split(b'\n', 1) + + # Read the next chunk size from the file + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: + break + + # read chunk + while len(data) < size: + data.extend((yield)) + + # feed stream + stream.feed_data(data[:size]) + + data = data[size:] + + # toss the CRLF at the end of the chunk + while len(data) < 2: + data.extend((yield)) + + data = data[2:] + + # read and discard trailer up to the CRLF terminator + while True: + if b'\n' in data: + line, data = data.split(b'\n', 1) + if line in (b'\r', b''): + break + else: + data.extend((yield)) + + # stream eof + stream.feed_eof() + return data + + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + except errors.IncompleteRead as exc: + stream.set_exception(exc) + + def _parse_length_payload(self, length): + """Read specified amount of bytes.""" + stream = yield + + try: + data = bytearray() + while length: + data.extend((yield)) + + data_len = len(data) + if data_len <= length: + stream.feed_data(data) + data = bytearray() + length -= data_len + else: + stream.feed_data(data[:length]) + data = data[length:] + length = 0 + + stream.feed_eof() + return data + except StreamEofException: + stream.set_exception(errors.IncompleteRead(b'')) + + def _parse_eof_payload(self): + """Read all bytes untile eof.""" + stream = yield + + try: + while True: + stream.feed_data((yield)) + except StreamEofException: + stream.feed_eof() + + @tulip.coroutine + def read_message(self, version=(1, 1), + length=None, compression=True, readall=False): + """Read RFC2822 headers and message payload from a stream. + + read_message() automatically decompress gzip and deflate content + encoding. To prevent decompression pass compression=False. + + Returns tuple of headers, payload stream, should close flag, + compression type. + """ + # load headers + headers = yield from self.read_headers() + + # payload params + chunked = False + encoding = None + close_conn = None + + for name, value in headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + elif name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif compression and name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + if close_conn is None: + close_conn = version <= (1, 0) + + # payload parser + if chunked: + parser = self._parse_chunked_payload() + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + + parser = self._parse_length_payload(length) + else: + if readall: + parser = self._parse_eof_payload() + else: + parser = self._parse_length_payload(0) + + next(parser) + + payload = stream = tulip.StreamReader() + + # payload decompression wrapper + if encoding is not None: + stream = DeflateStream(stream, encoding) + + try: + # initialize payload parser with stream, stream is being + # used by parser as destination stream + parser.send(stream) + except StopIteration: + pass + else: + # feed existing buffer to payload parser + self.byte_count = 0 + while self.buffer: + try: + parser.send(self.buffer.popleft()) + except StopIteration as exc: + parser = None + + # parser is done + buf = b''.join(self.buffer) + self.buffer.clear() + + # re-add remaining data back to buffer + if exc.value: + self.feed_data(exc.value) + + if buf: + self.feed_data(buf) + + break + + # parser still require more data + if parser is not None: + if self.eof: + try: + parser.throw(StreamEofException()) + except StopIteration as exc: + pass + else: + self._parser = parser + + return RawHttpMessage(headers, payload, close_conn, encoding) + + +class StreamEofException(Exception): + """Internal exception: eof received.""" + + +class DeflateStream: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, stream, encoding): + self.stream = stream + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def set_exception(self, exc): + self.stream.set_exception(exc) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except: + self.stream.set_exception(errors.IncompleteRead(b'')) + + if chunk: + self.stream.feed_data(chunk) + + def feed_eof(self): + self.stream.feed_data(self.zlib.flush()) + if not self.zlib.eof: + self.stream.set_exception(errors.IncompleteRead(b'')) + + self.stream.feed_eof() + + +EOF_MARKER = object() +EOL_MARKER = object() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = False + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = [] + self.headers_sent = False + + def force_close(self): + self.closing = True + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + return self.keepalive and not self.closing + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower().strip() + # handle websocket + if val == 'upgrade': + self.upgrade = True + # connection keep-alive + elif val == 'close': + self.keepalive = False + elif val == 'keep-alive': + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + # status line + self.transport.write(self.status_line.encode('ascii')) + + # send headers + self.transport.write( + ('{}\r\n\r\n'.format('\r\n'.join( + ('{}: {}'.format(k, v) for k, v in + itertools.chain(self._default_headers(), self.headers)))) + ).encode('ascii')) + + def _default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif self.keep_alive(): + connection = 'keep-alive' + else: + connection = 'close' + + headers = [('CONNECTION', connection)] + + if self.chunked: + headers.append(('TRANSFER-ENCODING', 'chunked')) + + return headers + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(StreamEofException()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except StreamEofException: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except StreamEofException: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except StreamEofException: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( + http_version, status, RESPONSES[status][0]) + + def _default_headers(self): + headers = super()._default_headers() + headers.extend((('DATE', format_date_time(time.time())), + ('SERVER', self.SERVER_SOFTWARE))) + + return headers + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, uri, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.uri = uri + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, uri, http_version) + + def _default_headers(self): + headers = super()._default_headers() + headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + + return headers diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..2eb6f98b --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,176 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +import tulip.http + +from . import errors + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + _closing = False + _request_count = 0 + _request_handle = None + + def __init__(self, log=logging, debug=False): + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.http.HttpStreamReader() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self._closing = True + + def log_access(self, status, info, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while True: + info = None + message = None + self._request_count += 1 + + try: + info = yield from self.stream.read_request_line() + message = yield from self.stream.read_message(info.version) + + handler = self.handle_request(info, message) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._closing: + self.transport.close() + break + + self._request_handle = None + + def handle_error(self, status=500, info=None, + message=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, info, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + + self.close() + + def handle_request(self, info, message): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=info.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, info, message) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..5957f9c8 --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,219 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, info, message): + return WsgiResponse(self.transport, info, message) + + def create_wsgi_environ(self, info, message, payload): + uri_parts = urlsplit(info.uri) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': info.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': info.uri, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % info.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_' + hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, info, message): + """Handle a single HTTP request""" + + if self.readpayload: + payload = io.BytesIO((yield from message.payload.read())) + else: + payload = message.payload + + environ = self.create_wsgi_environ(info, message, payload) + response = self.create_wsgi_response(info, message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if not resp.keep_alive(): + self.close() + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, info, message): + self.transport = transport + self.info = info + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + self.response = tulip.http.Response( + self.transport, status_code, + self.info.version, self.message.should_close) + self.response.add_headers(*headers) + self.response._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..ff841442 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,434 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..6f38db7d --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,198 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +from . import base_events +from . import transports +from .log import tulip_log + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._loop_reading) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + return + self._write_fut = self._event_loop._proactor.send(self._sock, data) + except OSError as exc: + self._conn_lost += 1 + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if self._write_fut: + self._write_fut.cancel() + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=False): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..593ee745 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..a87a8557 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..2e93132e --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,696 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import errno +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +# Errno values indicating the connection was disconnected. +_DISCONNECTED = frozenset((errno.ECONNRESET, + errno.ENOTCONN, + errno.ESHUTDOWN, + errno.ECONNABORTED, + errno.EPIPE, + errno.EBADF, + )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.Selector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=False): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=False): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + self._make_ssl_transport( + conn, protocol_factory(), sslcontext, futures.Future(), + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter, + server_side=False, extra=None): + super().__init__(extra) + + self._event_loop = event_loop + self._rawsock = rawsock + self._protocol = protocol + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._event_loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._event_loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + self._waiter.set_exception(exc) + raise + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._event_loop.add_reader(fd, self._on_ready) + self._event_loop.add_writer(fd, self._on_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._event_loop.remove_reader(fd) + self._event_loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._event_loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._protocol.connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sslsock.fileno()) + self._event_loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._protocol.connection_lost(exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= 5: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._conn_lost += 1 + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._conn_lost += 1 + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._call_connection_lost(None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..bd81e554 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,418 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warning('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + Selector = KqueueSelector +elif 'EpollSelector' in globals(): + Selector = EpollSelector +elif 'PollSelector' in globals(): + Selector = PollSelector +else: + Selector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..8d7f6236 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,145 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self.waiter is not None: + self.waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..734a5fa7 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,139 @@ +import fcntl +import os +import traceback + +from . import transports +from . import events +from .log import tulip_log + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + + def close(self): + if not self._eof: + self.write_eof() + # XXX What else? + + def _fatal_error(self, exc): + tulip_log.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..3d7acc79 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,334 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import time + +from . import futures +from .log import tulip_log + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + if inspect.isgeneratorfunction(func): + coro = func + else: + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + def task_wrapper(*args, **kwds): + return Task(coro(*args, **kwds)) + + return task_wrapper + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + if self._fut_waiter is not None: + assert not self._fut_waiter.done(), 'Assume it is a race condition.' + self._fut_waiter.cancel() + else: + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + + self._fut_waiter = None + + # We'll call either coro.throw(exc) or coro.send(value). + if self._must_cancel: + exc = futures.CancelledError + coro = self._coro + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('Exception in task') + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + tulip_log.exception('BaseException in task') + raise + else: + # XXX No check for self._must_cancel here? + if isinstance(result, futures.Future): + if not result._blocking: + result.set_exception( + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + if result is not None: + self._event_loop.call_soon( + self._step, None, + RuntimeError( + 'Task received bad yield: {!r}'.format(result))) + else: + self._event_loop.call_soon(self._step) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(f): + pending.remove(f) + done.add(f) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not f.cancelled() and + f.exception() is not None)): + bail.cancel() + + for f in pending: + f.remove_done_callback(_on_completion) + + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +def sleep(when, result=None): + """Return a Future that completes after a given time (in seconds). + + It's okay to cancel the Future. + + Undocumented feature: sleep(when, x) sets the Future's result to x. + """ + future = futures.Future() + future._event_loop.call_later(when, future.set_result, result) + return future diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..d6219143 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,269 @@ +"""Utilities shared by tests.""" + +import cgi +import contextlib +import email.parser +import http.server +import json +import logging +import io +import os +import re +import socket +import sys +import threading +import traceback +import urllib.parse +import unittest +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +class LogTrackingTestCase(unittest.TestCase): + + def setUp(self): + self._logger = logging.getLogger() + self._log_level = self._logger.getEffectiveLevel() + + def tearDown(self): + self._logger.setLevel(self._log_level) + + def suppress_log_errors(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.CRITICAL) + + def suppress_log_warnings(self): # pragma: no cover + if self._log_level >= logging.WARNING: + self._logger.setLevel(logging.ERROR) + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, info, message): + if properties.get('noresponse', False): + return + + if router is not None: + payload = io.BytesIO((yield from message.payload.read())) + rob = router( + properties, self.transport, + info, message.headers, payload, message.compression) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, info.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + self.transport.close() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = False + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + sock = thread_loop.run_until_complete( + thread_loop.start_serving( + TestHttpServer, host, port, ssl=sslcontext)) + + waiter = tulip.Future() + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, sock.getsockname())) + + thread_loop.run_until_complete(waiter) + thread_loop.stop() + + fut = tulip.Future() + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, props, transport, rline, headers, body, cmode): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in headers: + self._headers.add_header(hdr, val) + + self._props = props + self._transport = transport + self._method = rline.method + self._uri = rline.uri + self._version = rline.version + self._compression = cmode + self._body = body.read() + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..a9ec07a0 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..3073ab64 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,301 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import events +from . import selector_events +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not.""" + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except BlockingIOError: + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except BlockingIOError: + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._call_connection_lost(None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._call_connection_lost(exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..2ec8561c --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + addr = ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..bd1e0928 --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From c21e4eedef4f0d3d5939c9832cb37fbe69c324b3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 9 Apr 2013 12:09:08 -0700 Subject: [PATCH 0408/1502] warn about omitting future result or exception --- tests/base_events_test.py | 14 ++-- tests/events_test.py | 65 +++++----------- tests/futures_test.py | 71 ++++++++++++++++- tests/http_client_functional_test.py | 7 +- tests/http_protocol_test.py | 7 +- tests/http_server_test.py | 11 +-- tests/http_wsgi_test.py | 7 +- tests/locks_test.py | 112 ++++++++++++++++++++------- tests/proactor_events_test.py | 4 +- tests/queues_test.py | 16 +++- tests/selector_events_test.py | 3 + tests/selectors_test.py | 4 +- tests/streams_test.py | 9 +-- tests/tasks_test.py | 60 +++++++------- tests/unix_events_test.py | 2 + tulip/base_events.py | 17 ++-- tulip/futures.py | 54 +++++++++++++ tulip/http/server.py | 2 +- tulip/selector_events.py | 13 ++-- tulip/tasks.py | 18 ++--- tulip/test_utils.py | 22 +----- 21 files changed, 328 insertions(+), 190 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 10f7c480..e9383264 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -12,14 +12,11 @@ from tulip import futures from tulip import protocols from tulip import tasks -from tulip import test_utils -class BaseEventLoopTests(test_utils.LogTrackingTestCase): +class BaseEventLoopTests(unittest.TestCase): def setUp(self): - super().setUp() - self.event_loop = base_events.BaseEventLoop() self.event_loop._selector = unittest.mock.Mock() self.event_loop._selector.registered_count.return_value = 1 @@ -75,12 +72,15 @@ def test_add_callback_cancelled_handle(self): self.assertFalse(self.event_loop._ready) def test_wrap_future(self): - f = futures.Future() + f = futures.Future(event_loop=self.event_loop) self.assertIs(self.event_loop.wrap_future(f), f) + f.cancel() def test_wrap_future_concurrent(self): f = concurrent.futures.Future() - self.assertIsInstance(self.event_loop.wrap_future(f), futures.Future) + fut = self.event_loop.wrap_future(f) + self.assertIsInstance(fut, futures.Future) + fut.cancel() def test_set_default_executor(self): executor = unittest.mock.Mock() @@ -141,6 +141,7 @@ def cb(): f = self.event_loop.run_in_executor(None, h) self.assertIsInstance(f, futures.Future) self.assertTrue(f.done()) + self.assertIsNone(f.result()) def test_run_once_in_executor(self): def cb(): @@ -253,7 +254,6 @@ def test_run_until_complete_assertion(self): @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors(self, m_socket): - self.suppress_log_errors() class MyProto(protocols.Protocol): pass diff --git a/tests/events_test.py b/tests/events_test.py index e8855548..c28d88eb 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -132,8 +132,6 @@ def test_run(self): self.event_loop.run() # Returns immediately. def test_run_nesting(self): - self.suppress_log_errors() - @tasks.coroutine def coro(): self.assertTrue(self.event_loop.is_running()) @@ -144,8 +142,6 @@ def coro(): self.event_loop.run_until_complete, coro()) def test_run_once_nesting(self): - self.suppress_log_errors() - @tasks.coroutine def coro(): tasks.sleep(0.1) @@ -422,8 +418,6 @@ def sender(): self.assertTrue(data == b'x'*256) def test_sock_client_ops(self): - self.suppress_log_errors() - with test_utils.run_test_server(self.event_loop) as httpd: sock = socket.socket() sock.setblocking(False) @@ -433,6 +427,9 @@ def test_sock_client_ops(self): self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) data = self.event_loop.run_until_complete( self.event_loop.sock_recv(sock, 1024)) + # consume data + self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) sock.close() self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) @@ -616,18 +613,15 @@ def test_create_ssl_connection(self): self.assertTrue(pr.nbytes > 0) def test_create_connection_host_port_sock(self): - self.suppress_log_errors() coro = self.event_loop.create_connection( MyProto, 'example.com', 80, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_host_port_sock(self): - self.suppress_log_errors() coro = self.event_loop.create_connection(MyProto) self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): - self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] @@ -636,8 +630,6 @@ def test_create_connection_no_getaddrinfo(self): socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_connect_err(self): - self.suppress_log_errors() - @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] @@ -651,8 +643,6 @@ def getaddrinfo(*args, **kw): socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_mutiple_errors(self): - self.suppress_log_errors() - @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] @@ -798,18 +788,15 @@ def test_stop_serving(self): ConnectionRefusedError, client.connect, ('127.0.0.1', port)) def test_start_serving_host_port_sock(self): - self.suppress_log_errors() fut = self.event_loop.start_serving( MyProto, '0.0.0.0', 0, sock=object()) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_start_serving_no_host_port_sock(self): - self.suppress_log_errors() fut = self.event_loop.start_serving(MyProto) self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) def test_start_serving_no_getaddrinfo(self): - self.suppress_log_errors() getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] @@ -819,7 +806,6 @@ def test_start_serving_no_getaddrinfo(self): @unittest.mock.patch('tulip.base_events.socket') def test_start_serving_cant_bind(self, m_socket): - self.suppress_log_errors() class Err(socket.error): pass @@ -836,8 +822,6 @@ class Err(socket.error): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): - self.suppress_log_errors() - m_socket.error = socket.error m_socket.getaddrinfo.return_value = [] @@ -847,8 +831,6 @@ def test_create_datagram_endpoint_no_addrinfo(self, m_socket): socket.error, self.event_loop.run_until_complete, coro) def test_create_datagram_endpoint_addr_error(self): - self.suppress_log_errors() - coro = self.event_loop.create_datagram_endpoint( MyDatagramProto, local_addr='localhost') self.assertRaises( @@ -894,7 +876,6 @@ def datagram_received(self, data, addr): server.transport.close() def test_create_datagram_endpoint_connect_err(self): - self.suppress_log_errors() self.event_loop.sock_connect = unittest.mock.Mock() self.event_loop.sock_connect.side_effect = socket.error @@ -905,8 +886,6 @@ def test_create_datagram_endpoint_connect_err(self): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): - self.suppress_log_errors() - m_socket.error = socket.error m_socket.getaddrinfo = socket.getaddrinfo m_socket.socket.side_effect = socket.error @@ -922,8 +901,6 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): socket.error, self.event_loop.run_until_complete, coro) def test_create_datagram_endpoint_no_matching_family(self): - self.suppress_log_errors() - coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) @@ -932,8 +909,6 @@ def test_create_datagram_endpoint_no_matching_family(self): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): - self.suppress_log_errors() - m_socket.error = socket.error m_socket.socket.return_value.setblocking.side_effect = socket.error @@ -945,16 +920,12 @@ def test_create_datagram_endpoint_setblk_err(self, m_socket): m_socket.socket.return_value.close.called) def test_create_datagram_endpoint_noaddr_nofamily(self): - self.suppress_log_errors() - coro = self.event_loop.create_datagram_endpoint( protocols.DatagramProtocol) self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): - self.suppress_log_errors() - class Err(socket.error): pass @@ -977,14 +948,14 @@ def test_accept_connection_retry(self): self.event_loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - def test_accept_connection_exception(self): - self.suppress_log_errors() - + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.accept.side_effect = OSError() self.event_loop._accept_connection(MyProto, sock) self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) def test_internal_fds(self): event_loop = self.create_event_loop() @@ -1088,13 +1059,13 @@ def connect(): if sys.platform == 'win32': from tulip import windows_events - class SelectEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return windows_events.SelectorEventLoop() - class ProactorEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return windows_events.ProactorEventLoop() def test_create_ssl_connection(self): @@ -1145,27 +1116,27 @@ def test_stop_serving(self): from tulip import unix_events if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return unix_events.SelectorEventLoop( selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.PollSelector()) # Should always exist. - class SelectEventLoopTests(EventLoopTestsMixin, - test_utils.LogTrackingTestCase): + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.SelectSelector()) diff --git a/tests/futures_test.py b/tests/futures_test.py index 5569cca1..9e2c4dea 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -1,7 +1,10 @@ """Tests for futures.py.""" +import logging import unittest +import unittest.mock +from tulip import events from tulip import futures @@ -11,11 +14,16 @@ def _fakefunc(f): class FutureTests(unittest.TestCase): + def setUp(self): + self.loop = events.get_event_loop() + def test_initial_state(self): f = futures.Future() self.assertFalse(f.cancelled()) self.assertFalse(f.running()) self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) def test_init_event_loop_positional(self): # Make sure Future does't accept a positional argument @@ -85,6 +93,7 @@ def fixture(): def test_repr(self): f_pending = futures.Future() self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() f_cancelled = futures.Future() f_cancelled.cancel() @@ -93,15 +102,19 @@ def test_repr(self): f_result = futures.Future() f_result.set_result(4) self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + exc = RuntimeError() f_exception = futures.Future() - f_exception.set_exception(RuntimeError()) + f_exception.set_exception(exc) self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) f_few_callbacks = futures.Future() f_few_callbacks.add_done_callback(_fakefunc) self.assertIn('Future', r) + f_many_callbacks.cancel() def test_copy_state(self): # Test the internal _copy_state method since it's being directly @@ -137,14 +151,61 @@ def test_copy_state(self): self.assertTrue(newf_cancelled.cancelled()) def test_iter(self): + fut = futures.Future() + def coro(): - fut = futures.Future() yield from fut def test(): arg1, arg2 = coro() self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_normal(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + fut.set_result(True) + fut.result() + del fut + self.assertFalse(log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_not_done(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + r_fut = repr(fut) + del fut + log.error.mock_calls[-1].assert_called_with( + 'Future abandoned before completion: %r', r_fut) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_done(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + next(iter(fut)) + fut.set_result(1) + r_fut = repr(fut) + del fut + log.error.mock_calls[-1].assert_called_with( + 'Future result has not been requested: %r', r_fut) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_exc(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + exc = ValueError() + fut = futures.Future() + fut.set_exception(exc) + r_fut = repr(fut) + del fut + log.exception.mock_calls[-1].assert_called_with( + 'Future raised an exception and nobody caught it: %r', r_fut, + exc_info=(ValueError, exc, None)) # A fake event loop for tests. All it does is implement a call_soon method @@ -153,6 +214,12 @@ class _FakeEventLoop: def call_soon(self, fn, future): fn(future) + def set_log_level(self, val): + pass + + def get_log_level(self): + return logging.CRITICAL + class FutureDoneCallbackTests(unittest.TestCase): diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 120f78b8..1ee3ab8c 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -3,6 +3,7 @@ import io import os.path import http.cookies +import unittest import tulip import tulip.http @@ -10,18 +11,14 @@ from tulip.http import client -class HttpClientFunctionalTests(test_utils.LogTrackingTestCase): +class HttpClientFunctionalTests(unittest.TestCase): def setUp(self): - super().setUp() - self.suppress_log_errors() - self.loop = tulip.new_event_loop() tulip.set_event_loop(self.loop) def tearDown(self): self.loop.close() - super().tearDown() def test_HTTP_200_OK_METHOD(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index c806337c..272c4c3e 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -7,15 +7,11 @@ import tulip from tulip.http import protocol -from tulip.test_utils import LogTrackingTestCase -class HttpStreamReaderTests(LogTrackingTestCase): +class HttpStreamReaderTests(unittest.TestCase): def setUp(self): - super().setUp() - self.suppress_log_errors() - self.loop = tulip.new_event_loop() tulip.set_event_loop(self.loop) @@ -24,7 +20,6 @@ def setUp(self): def tearDown(self): self.loop.close() - super().tearDown() def test_request_line(self): self.stream.feed_data(b'get /path HTTP/1.1\r\n') diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 2ab41840..299e950d 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -6,21 +6,16 @@ import tulip from tulip.http import server from tulip.http import errors -from tulip.test_utils import LogTrackingTestCase -class HttpServerProtocolTests(LogTrackingTestCase): +class HttpServerProtocolTests(unittest.TestCase): def setUp(self): - super().setUp() - self.suppress_log_errors() - self.loop = tulip.new_event_loop() tulip.set_event_loop(self.loop) def tearDown(self): self.loop.close() - super().tearDown() def test_http_status_exception(self): exc = errors.HttpStatusException(500, message='Internal error') @@ -97,7 +92,8 @@ def test_handle_error(self): @unittest.mock.patch('tulip.http.server.traceback') def test_handle_error_traceback_exc(self, m_trace): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(debug=True) + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log) srv.connection_made(transport) m_trace.format_exc.side_effect = ValueError @@ -106,6 +102,7 @@ def test_handle_error_traceback_exc(self, m_trace): content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) self.assertTrue( content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) def test_handle_error_debug(self): transport = unittest.mock.Mock() diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index 2fc0fee8..e2757c91 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -7,15 +7,11 @@ import tulip from tulip.http import wsgi from tulip.http import protocol -from tulip.test_utils import LogTrackingTestCase -class HttpWsgiServerProtocolTests(LogTrackingTestCase): +class HttpWsgiServerProtocolTests(unittest.TestCase): def setUp(self): - super().setUp() - self.suppress_log_errors() - self.loop = tulip.new_event_loop() tulip.set_event_loop(self.loop) @@ -32,7 +28,6 @@ def setUp(self): def tearDown(self): self.loop.close() - super().tearDown() def test_ctor(self): srv = wsgi.WSGIServerHttpProtocol(self.wsgi) diff --git a/tests/locks_test.py b/tests/locks_test.py index 5f1c180a..a2e03381 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -8,7 +8,6 @@ from tulip import futures from tulip import locks from tulip import tasks -from tulip import test_utils class LockTests(unittest.TestCase): @@ -57,19 +56,22 @@ def test_acquire(self): def c1(result): if (yield from lock.acquire()): result.append(1) + return True @tasks.coroutine def c2(result): if (yield from lock.acquire()): result.append(2) + return True @tasks.coroutine def c3(result): if (yield from lock.acquire()): result.append(3) + return True - tasks.Task(c1(result)) - tasks.Task(c2(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -81,7 +83,7 @@ def c3(result): self.event_loop.run_once() self.assertEqual([1], result) - tasks.Task(c3(result)) + t3 = tasks.Task(c3(result)) lock.release() self.event_loop.run_once() @@ -91,6 +93,13 @@ def c3(result): self.event_loop.run_once() self.assertEqual([1, 2, 3], result) + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + def test_acquire_timeout(self): lock = locks.Lock() self.assertTrue( @@ -210,18 +219,25 @@ def c3(result): if (yield from ev.wait()): result.append(3) - tasks.Task(c1(result)) - tasks.Task(c2(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) self.event_loop.run_once() self.assertEqual([], result) - tasks.Task(c3(result)) + t3 = tasks.Task(c3(result)) ev.set() self.event_loop.run_once() self.assertEqual([3, 1, 2], result) + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + def test_wait_on_set(self): ev = locks.EventWaiter() ev.set() @@ -287,8 +303,9 @@ def test_clear_with_waiters(self): def c1(result): if (yield from ev.wait()): result.append(1) + return True - tasks.Task(c1(result)) + t = tasks.Task(c1(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -304,11 +321,13 @@ def c1(result): self.assertEqual([1], result) self.assertEqual(0, len(ev._waiters)) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + -class ConditionTests(test_utils.LogTrackingTestCase): +class ConditionTests(unittest.TestCase): def setUp(self): - super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) @@ -324,22 +343,25 @@ def c1(result): yield from cond.acquire() if (yield from cond.wait()): result.append(1) + return True @tasks.coroutine def c2(result): yield from cond.acquire() if (yield from cond.wait()): result.append(2) + return True @tasks.coroutine def c3(result): yield from cond.acquire() if (yield from cond.wait()): result.append(3) + return True - tasks.Task(c1(result)) - tasks.Task(c2(result)) - tasks.Task(c3(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -372,6 +394,13 @@ def c3(result): self.assertEqual([1, 2, 3], result) self.assertTrue(cond.locked()) + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + def test_wait_timeout(self): cond = locks.Condition() self.event_loop.run_until_complete(cond.acquire()) @@ -397,8 +426,6 @@ def test_wait_cancel(self): self.assertTrue(cond.locked()) def test_wait_unacquired(self): - self.suppress_log_errors() - cond = locks.Condition() self.assertRaises( RuntimeError, @@ -419,8 +446,9 @@ def c1(result): if (yield from cond.wait_for(predicate)): result.append(1) cond.release() + return True - tasks.Task(c1(result)) + t = tasks.Task(c1(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -438,6 +466,9 @@ def c1(result): self.event_loop.run_once() self.assertEqual([1], result) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + def test_wait_for_timeout(self): cond = locks.Condition() @@ -476,8 +507,6 @@ def c1(result): self.assertTrue(0.08 < total_time < 0.12) def test_wait_for_unacquired(self): - self.suppress_log_errors() - cond = locks.Condition() # predicate can return true immediately @@ -500,6 +529,7 @@ def c1(result): if (yield from cond.wait()): result.append(1) cond.release() + return True @tasks.coroutine def c2(result): @@ -507,6 +537,7 @@ def c2(result): if (yield from cond.wait()): result.append(2) cond.release() + return True @tasks.coroutine def c3(result): @@ -514,10 +545,11 @@ def c3(result): if (yield from cond.wait()): result.append(3) cond.release() + return True - tasks.Task(c1(result)) - tasks.Task(c2(result)) - tasks.Task(c3(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -535,6 +567,13 @@ def c3(result): self.event_loop.run_once() self.assertEqual([1, 2, 3], result) + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + def test_notify_all(self): cond = locks.Condition() @@ -546,6 +585,7 @@ def c1(result): if (yield from cond.wait()): result.append(1) cond.release() + return True @tasks.coroutine def c2(result): @@ -553,9 +593,10 @@ def c2(result): if (yield from cond.wait()): result.append(2) cond.release() + return True - tasks.Task(c1(result)) - tasks.Task(c2(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) self.event_loop.run_once() self.assertEqual([], result) @@ -566,6 +607,11 @@ def c2(result): self.event_loop.run_once() self.assertEqual([1, 2], result) + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + def test_notify_unacquired(self): cond = locks.Condition() self.assertRaises(RuntimeError, cond.notify) @@ -626,25 +672,29 @@ def test_acquire(self): def c1(result): yield from sem.acquire() result.append(1) + return True @tasks.coroutine def c2(result): yield from sem.acquire() result.append(2) + return True @tasks.coroutine def c3(result): yield from sem.acquire() result.append(3) + return True @tasks.coroutine def c4(result): yield from sem.acquire() result.append(4) + return True - tasks.Task(c1(result)) - tasks.Task(c2(result)) - tasks.Task(c3(result)) + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) self.event_loop.run_once() self.assertEqual([1], result) @@ -652,7 +702,7 @@ def c4(result): self.assertEqual(2, len(sem._waiters)) self.assertEqual(0, sem._value) - tasks.Task(c4(result)) + t4 = tasks.Task(c4(result)) sem.release() sem.release() @@ -665,6 +715,14 @@ def c4(result): self.assertEqual(1, len(sem._waiters)) self.assertEqual(0, sem._value) + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + def test_acquire_timeout(self): sem = locks.Semaphore() self.event_loop.run_until_complete(sem.acquire()) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 7a92ad08..7b4d5aa3 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -309,7 +309,8 @@ def test_write_to_self(self): def test_process_events(self): self.event_loop._process_events([]) - def test_start_serving(self): + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): pf = unittest.mock.Mock() call_soon = self.event_loop.call_soon = unittest.mock.Mock() @@ -335,3 +336,4 @@ def test_start_serving(self): fut.result.side_effect = OSError() loop(fut) self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/queues_test.py b/tests/queues_test.py index 8c1c0afb..714465d9 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -99,10 +99,11 @@ def putter(): for i in range(3): yield from q.put(i) have_been_put.append(i) + return True @tasks.coroutine def test(): - tasks.Task(putter()) + t = tasks.Task(putter()) yield from tasks.sleep(0.01) # The putter is blocked after putting two items. @@ -115,6 +116,9 @@ def test(): self.assertEqual(1, q.get_nowait()) self.assertEqual(2, q.get_nowait()) + self.assertTrue(t.done()) + self.assertTrue(t.result()) + self.event_loop.run_until_complete(test()) @@ -178,9 +182,12 @@ def queue_get(): q.put_nowait(1) self.assertEqual(1, (yield from q.get())) - tasks.Task(q.put(2)) + t = tasks.Task(q.put(2)) self.assertEqual(2, (yield from q.get())) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + self.event_loop.run_until_complete(queue_get()) def test_get_timeout_cancelled(self): @@ -275,13 +282,16 @@ def test_put_timeout_cancelled(self): @tasks.coroutine def queue_put(): yield from q.put(1, timeout=0.01) + return True @tasks.coroutine def test(): - tasks.Task(queue_put()) return (yield from q.get()) + t = tasks.Task(queue_put()) self.assertEqual(1, self.event_loop.run_until_complete(test())) + self.assertTrue(t.done()) + self.assertTrue(t.result()) class LifoQueueTests(_QueueTestBase): diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 031afeb9..e03ef1f3 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -168,6 +168,7 @@ def test_sock_sendall_nodata(self): f = self.event_loop.sock_sendall(sock, b'') self.assertIsInstance(f, futures.Future) self.assertTrue(f.done()) + self.assertIsNone(f.result()) self.assertFalse(self.event_loop._sock_sendall.called) def test__sock_sendall_canceled_fut(self): @@ -220,6 +221,7 @@ def test__sock_sendall(self): self.event_loop._sock_sendall(f, False, sock, b'data') self.assertTrue(f.done()) + self.assertIsNone(f.result()) def test__sock_sendall_partial(self): sock = unittest.mock.Mock() @@ -267,6 +269,7 @@ def test__sock_connect(self): self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.assertTrue(f.done()) + self.assertIsNone(f.result()) self.assertTrue(sock.connect.called) def test__sock_connect_canceled_fut(self): diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 996c0130..1457e4ed 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -129,9 +129,11 @@ def test_context_manager(self): self.assertFalse(s._fd_to_key) self.assertFalse(s._fileobj_to_key) - def test_key_from_fd(self): + @unittest.mock.patch('tulip.selectors.tulip_log') + def test_key_from_fd(self, m_log): s = selectors._BaseSelector() key = s.register(1, selectors.EVENT_READ) self.assertIs(key, s._key_from_fd(1)) self.assertIsNone(s._key_from_fd(10)) + m_log.warning.assert_called_with('No key found for fd %r', 10) diff --git a/tests/streams_test.py b/tests/streams_test.py index f7e2992b..832c3119 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -5,20 +5,17 @@ from tulip import events from tulip import streams from tulip import tasks -from tulip import test_utils -class StreamReaderTests(test_utils.LogTrackingTestCase): +class StreamReaderTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' def setUp(self): - super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): - super().tearDown() self.event_loop.close() def test_feed_empty_data(self): @@ -124,8 +121,6 @@ def cb(): self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) def test_readline_limit_with_existing_data(self): - self.suppress_log_errors() - stream = streams.StreamReader(3) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -145,8 +140,6 @@ def test_readline_limit_with_existing_data(self): self.assertEqual(2, stream.byte_count) def test_readline_limit(self): - self.suppress_log_errors() - stream = streams.StreamReader(7) def cb(): diff --git a/tests/tasks_test.py b/tests/tasks_test.py index fa22d62c..ae9264cc 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -8,7 +8,6 @@ from tulip import events from tulip import futures from tulip import tasks -from tulip import test_utils class Dummy: @@ -20,16 +19,14 @@ def __call__(self, *args): pass -class TaskTests(test_utils.LogTrackingTestCase): +class TaskTests(unittest.TestCase): def setUp(self): - super().setUp() self.event_loop = events.new_event_loop() events.set_event_loop(self.event_loop) def tearDown(self): self.event_loop.close() - super().tearDown() def test_task_class(self): @tasks.coroutine @@ -156,8 +153,6 @@ def coro(): self.assertFalse(t.cancel()) def test_future_timeout_catch(self): - self.suppress_log_errors() - @tasks.coroutine def coro(): yield from tasks.sleep(10.0) @@ -272,6 +267,9 @@ def test_wait_first_completed(self): done, pending = self.event_loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) def test_wait_really_done(self): # there is possibility that some tasks in the pending list @@ -293,10 +291,13 @@ def coro2(): done, pending = self.event_loop.run_until_complete(task) self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) def test_wait_first_exception(self): - self.suppress_log_errors() - + # first_exception, task already has exception a = tasks.sleep(10.0) @tasks.coroutine @@ -311,8 +312,23 @@ def exc(): self.assertEqual({b}, done) self.assertEqual({a}, pending) + def test_wait_first_exception_in_wait(self): + # first_exception, exception during waiting + a = tasks.sleep(10.0) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01) + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + def test_wait_with_exception(self): - self.suppress_log_errors() a = tasks.sleep(0.1) @tasks.coroutine @@ -386,7 +402,6 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) def test_as_completed_with_timeout(self): - self.suppress_log_errors() a = tasks.sleep(0.1, 'a') b = tasks.sleep(0.15, 'b') @@ -477,8 +492,7 @@ def coro(): self.assertIsNone(task._fut_waiter) self.assertTrue(fut.cancelled()) - @unittest.mock.patch('tulip.tasks.tulip_log') - def test_step_in_completed_task(self, m_logging): + def test_step_in_completed_task(self): @tasks.coroutine def notmuch(): return 'ko' @@ -488,8 +502,7 @@ def notmuch(): self.assertRaises(AssertionError, task._step) - @unittest.mock.patch('tulip.tasks.tulip_log') - def test_step_result(self, m_logging): + def test_step_result(self): @tasks.coroutine def notmuch(): yield None @@ -501,7 +514,6 @@ def notmuch(): def test_step_result_future(self): # If coroutine returns future, task waits on this future. - self.suppress_log_warnings() class Fut(futures.Future): def __init__(self, *args): @@ -520,7 +532,7 @@ def wait_for_future(): nonlocal result result = yield from fut - wait_for_future() + t = wait_for_future() self.event_loop.run_once() self.assertTrue(fut.cb_added) @@ -528,6 +540,8 @@ def wait_for_future(): fut.set_result(res) self.event_loop.run_once() self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) def test_step_result_concurrent_future(self): # Coroutine returns concurrent.futures.Future @@ -557,8 +571,6 @@ def notmuch(): self.assertIs(res, task.result()) def test_step_with_baseexception(self): - self.suppress_log_errors() - @tasks.coroutine def notmutch(): raise BaseException() @@ -570,8 +582,6 @@ def notmutch(): self.assertIsInstance(task.exception(), BaseException) def test_baseexception_during_cancel(self): - self.suppress_log_errors() - @tasks.coroutine def sleeper(): yield from tasks.sleep(10) @@ -610,7 +620,6 @@ def fn2(): self.assertTrue(tasks.iscoroutinefunction(fn2)) def test_yield_vs_yield_from(self): - self.suppress_log_errors() fut = futures.Future() @tasks.task @@ -625,12 +634,9 @@ def wait_for_future(): self.assertIs(fut.exception(), cm.exception) def test_yield_vs_yield_from_generator(self): - self.suppress_log_errors() - fut = futures.Future() - @tasks.coroutine def coro(): - yield from fut + yield @tasks.task def wait_for_future(): @@ -642,7 +648,6 @@ def wait_for_future(): self.event_loop.run_until_complete, task) def test_coroutine_non_gen_function(self): - @tasks.coroutine def func(): return 'test' @@ -667,9 +672,10 @@ def coro(): fut.set_result('test') t1 = tasks.Task(func()) - tasks.Task(coro()) + t2 = tasks.Task(coro()) res = self.event_loop.run_until_complete(t1) self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) if __name__ == '__main__': diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index d7af7ecc..686f2e41 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -193,6 +193,7 @@ def test_ctor_with_waiter(self, m_fcntl): unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol, fut) self.event_loop.call_soon.assert_called_with(fut.set_result, None) + fut.cancel() @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') @@ -337,6 +338,7 @@ def test_ctor_with_waiter(self, m_fcntl): unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol, fut) self.event_loop.call_soon.assert_called_with(fut.set_result, None) + fut.cancel() @unittest.mock.patch('fcntl.fcntl') def test_can_write_eof(self, m_fcntl): diff --git a/tulip/base_events.py b/tulip/base_events.py index d9e2316b..66397421 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -328,12 +328,6 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.setblocking(False) - if ssl: - import ssl as sslmod - sslcontext = sslmod.SSLContext(sslmod.PROTOCOL_SSLv23) - sock = sslcontext.wrap_socket(sock, server_side=False, - do_handshake_on_connect=False) - protocol = protocol_factory() waiter = futures.Future() if ssl: @@ -490,7 +484,7 @@ def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): return future # Don't wrap our own type of Future. - new_future = futures.Future() + new_future = futures.Future(event_loop=self) future.add_done_callback( lambda future: self.call_soon_threadsafe(new_future._copy_state, future)) @@ -554,3 +548,12 @@ def _run_once(self, timeout=None): handle = self._ready.popleft() if not handle.cancelled: handle.run() + + # Future.__del__ uses log level + _log_level = logging.WARNING + + def set_log_level(self, val): + self._log_level = val + + def get_log_level(self): + return self._log_level diff --git a/tulip/futures.py b/tulip/futures.py index 39137aa6..a778a51c 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -6,8 +6,12 @@ ] import concurrent.futures._base +import io +import logging +import traceback from . import events +from .log import tulip_log # States for Future. _PENDING = 'PENDING' @@ -19,6 +23,8 @@ CancelledError = concurrent.futures.CancelledError TimeoutError = concurrent.futures.TimeoutError +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + class InvalidStateError(Error): """The operation is not allowed in this state.""" @@ -52,9 +58,14 @@ class Future: _result = None _exception = None _timeout_handle = None + _event_loop = None _blocking = False # proper use of future (yield vs yield from) + # result of the future has to be requested + _debug_stack = None + _debug_result_requested = False + def __init__(self, *, event_loop=None, timeout=None): """Initialize the future. @@ -72,6 +83,12 @@ def __init__(self, *, event_loop=None, timeout=None): self._timeout_handle = self._event_loop.call_later( timeout, self.cancel) + if __debug__: + if self._event_loop.get_log_level() <= STACK_DEBUG: + out = io.StringIO() + traceback.print_stack(file=out) + self._debug_stack = out.getvalue() + def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -151,6 +168,8 @@ def result(self, timeout=0): the future is done and has an exception set, this exception is raised. Timeout values other than 0 are not supported. """ + if __debug__: + self._debug_result_requested = True if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: @@ -169,6 +188,8 @@ def exception(self, timeout=0): CancelledError. If the future isn't done yet, raises InvalidStateError. Timeout values other than 0 are not supported. """ + if __debug__: + self._debug_result_requested = True if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: @@ -253,3 +274,36 @@ def __iter__(self): yield self # This tells Task to wait for completion. assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. + + if __debug__: + def __del__(self): + if (not self._debug_result_requested and + self._state != _CANCELLED and + self._event_loop is not None): + + level = self._event_loop.get_log_level() + if level > logging.WARNING: + return + + r_self = repr(self) + + if self._state == _PENDING: + tulip_log.error( + 'Future abandoned before completion: %s', r_self) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) + + else: + exc = self._exception + if exc is not None: + tulip_log.exception( + 'Future raised an exception and ' + 'nobody caught it: %s', r_self, + exc_info=(exc.__class__, exc, exc.__traceback__)) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) + else: + tulip_log.error( + 'Future result has not been requested: %s', r_self) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) diff --git a/tulip/http/server.py b/tulip/http/server.py index 2eb6f98b..d722863d 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -40,7 +40,7 @@ class ServerHttpProtocol(tulip.Protocol): _request_count = 0 _request_handle = None - def __init__(self, log=logging, debug=False): + def __init__(self, *, log=logging, debug=False): self.log = log self.debug = debug diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 2e93132e..a649d8c8 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -116,7 +116,7 @@ def _accept_connection(self, protocol_factory, sock, ssl=False): if ssl: sslcontext = None if isinstance(ssl, bool) else ssl self._make_ssl_transport( - conn, protocol_factory(), sslcontext, futures.Future(), + conn, protocol_factory(), sslcontext, None, server_side=True, extra={'addr': addr}) else: self._make_socket_transport( @@ -435,7 +435,7 @@ def _call_connection_lost(self, exc): class _SelectorSslTransport(transports.Transport): - def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter, + def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter=None, server_side=False, extra=None): super().__init__(extra) @@ -467,18 +467,21 @@ def _on_handshake(self): return except Exception as exc: self._sslsock.close() - self._waiter.set_exception(exc) + if self._waiter is not None: + self._waiter.set_exception(exc) return except BaseException as exc: self._sslsock.close() - self._waiter.set_exception(exc) + if self._waiter is not None: + self._waiter.set_exception(exc) raise self._event_loop.remove_reader(fd) self._event_loop.remove_writer(fd) self._event_loop.add_reader(fd, self._on_ready) self._event_loop.add_writer(fd, self._on_ready) self._event_loop.call_soon(self._protocol.connection_made, self) - self._event_loop.call_soon(self._waiter.set_result, None) + if self._waiter is not None: + self._event_loop.call_soon(self._waiter.set_result, None) def _on_ready(self): # Because of renegotiations (?), there's no difference between diff --git a/tulip/tasks.py b/tulip/tasks.py index 3d7acc79..67c7b60d 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -11,7 +11,6 @@ import time from . import futures -from .log import tulip_log def coroutine(func): @@ -95,11 +94,10 @@ def cancel(self): self._must_cancel = True # _step() will call super().cancel() to call the callbacks. if self._fut_waiter is not None: - assert not self._fut_waiter.done(), 'Assume it is a race condition.' - self._fut_waiter.cancel() + return self._fut_waiter.cancel() else: self._event_loop.call_soon(self._step_maybe) - return True + return True def cancelled(self): return self._must_cancel or super().cancelled() @@ -136,13 +134,11 @@ def _step(self, value=None, exc=None): super().cancel() else: self.set_exception(exc) - tulip_log.exception('Exception in task') except BaseException as exc: if self._must_cancel: super().cancel() else: self.set_exception(exc) - tulip_log.exception('BaseException in task') raise else: # XXX No check for self._must_cancel here? @@ -239,14 +235,14 @@ def _wait(fs, timeout=None, return_when=ALL_COMPLETED): # Will always be cancelled eventually. bail = futures.Future(timeout=timeout) - def _on_completion(f): - pending.remove(f) - done.add(f) + def _on_completion(fut): + pending.remove(fut) + done.add(fut) if (not pending or return_when == FIRST_COMPLETED or (return_when == FIRST_EXCEPTION and - not f.cancelled() and - f.exception() is not None)): + not fut.cancelled() and + fut.exception() is not None)): bail.cancel() for f in pending: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index d6219143..3c91e99f 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -2,6 +2,7 @@ import cgi import contextlib +import gc import email.parser import http.server import json @@ -14,7 +15,6 @@ import threading import traceback import urllib.parse -import unittest try: import ssl except ImportError: # pragma: no cover @@ -31,24 +31,6 @@ from socket import socketpair # pragma: no cover -class LogTrackingTestCase(unittest.TestCase): - - def setUp(self): - self._logger = logging.getLogger() - self._log_level = self._logger.getEffectiveLevel() - - def tearDown(self): - self._logger.setLevel(self._log_level) - - def suppress_log_errors(self): # pragma: no cover - if self._log_level >= logging.WARNING: - self._logger.setLevel(logging.CRITICAL) - - def suppress_log_warnings(self): # pragma: no cover - if self._log_level >= logging.WARNING: - self._logger.setLevel(logging.ERROR) - - @contextlib.contextmanager def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False, router=None): @@ -110,6 +92,7 @@ def handle_request(self, info, message): def run(loop, fut): thread_loop = tulip.new_event_loop() + thread_loop.set_log_level(logging.CRITICAL) tulip.set_event_loop(thread_loop) sock = thread_loop.run_until_complete( @@ -122,6 +105,7 @@ def run(loop, fut): thread_loop.run_until_complete(waiter) thread_loop.stop() + gc.collect() fut = tulip.Future() server_thread = threading.Thread(target=run, args=(loop, fut)) From 1cfe3deac0b542c8d3df065a85ac802ea964e77c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 9 Apr 2013 12:21:17 -0700 Subject: [PATCH 0409/1502] Prevent spurious warning about abandoned Future. --- tests/base_events_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index e9383264..e991f835 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -132,7 +132,7 @@ def cb(): AssertionError, self.event_loop.run_in_executor, None, events.Timer(10, cb, ())) - def test_run_once_in_executor_canceled(self): + def test_run_once_in_executor_cancelled(self): def cb(): pass h = events.Handle(cb, ()) @@ -143,7 +143,7 @@ def cb(): self.assertTrue(f.done()) self.assertIsNone(f.result()) - def test_run_once_in_executor(self): + def test_run_once_in_executor_plain(self): def cb(): pass h = events.Handle(cb, ()) @@ -162,6 +162,8 @@ def cb(): self.assertIs(f, res) self.assertTrue(executor.submit.called) + f.cancel() # Don't complain about abandoned Future. + def test_run_once(self): self.event_loop._run_once = unittest.mock.Mock() self.event_loop._run_once.side_effect = base_events._StopError From 21ebc8946f6c0d2144201b9435184cd2e601bb50 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 13 Apr 2013 20:05:41 +0300 Subject: [PATCH 0410/1502] Fix issue #15: use less aggressive tail recusion. --- tests/events_test.py | 62 ++++++++++++++++++++++++++--------- tests/selector_events_test.py | 22 +++++++++---- tests/unix_events_test.py | 15 ++++++--- tulip/selector_events.py | 12 +++---- tulip/unix_events.py | 6 ++-- 5 files changed, 81 insertions(+), 36 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index c28d88eb..ab8c968d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -17,6 +17,8 @@ import unittest import unittest.mock + +from tulip import futures from tulip import events from tulip import transports from tulip import protocols @@ -26,10 +28,13 @@ class MyProto(protocols.Protocol): + done = None - def __init__(self): + def __init__(self, create_future=False): self.state = 'INITIAL' self.nbytes = 0 + if create_future: + self.done = futures.Future() def connection_made(self, transport): self.transport = transport @@ -49,13 +54,18 @@ def eof_received(self): def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' + if self.done: + self.done.set_result(None) class MyDatagramProto(protocols.DatagramProtocol): + done = None - def __init__(self): + def __init__(self, create_future=False): self.state = 'INITIAL' self.nbytes = 0 + if create_future: + self.done = futures.Future() def connection_made(self, transport): self.transport = transport @@ -72,14 +82,19 @@ def connection_refused(self, exc): def connection_lost(self, exc): assert self.state == 'INITIALIZED', self.state self.state = 'CLOSED' + if self.done: + self.done.set_result(None) class MyReadPipeProto(protocols.Protocol): + done = None - def __init__(self): + def __init__(self, create_future=False): self.state = ['INITIAL'] self.nbytes = 0 self.transport = None + if create_future: + self.done = futures.Future() def connection_made(self, transport): self.transport = transport @@ -98,13 +113,18 @@ def eof_received(self): def connection_lost(self, exc): assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state self.state.append('CLOSED') + if self.done: + self.done.set_result(None) class MyWritePipeProto(protocols.Protocol): + done = None - def __init__(self): + def __init__(self, create_future=False): self.state = 'INITIAL' self.transport = None + if create_future: + self.done = futures.Future() def connection_made(self, transport): self.transport = transport @@ -114,6 +134,8 @@ def connection_made(self, transport): def connection_lost(self, exc): assert self.state == 'CONNECTED', self.state self.state = 'CLOSED' + if self.done: + self.done.set_result(None) class EventLoopTestsMixin: @@ -709,7 +731,7 @@ def connection_made(self, transport): def factory(): nonlocal proto - proto = MyProto() + proto = MyProto(create_future=True) return proto here = os.path.dirname(__file__) @@ -745,7 +767,7 @@ def factory(): # close connection proto.transport.close() - + self.event_loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) # the client socket must be closed after to avoid ECONNRESET upon @@ -753,11 +775,18 @@ def factory(): client.close() def test_start_serving_sock(self): + proto = futures.Future() + + class TestMyProto(MyProto): + def __init__(self): + super().__init__() + proto.set_result(self) + sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) - f = self.event_loop.start_serving(MyProto, sock=sock_ob) + f = self.event_loop.start_serving(TestMyProto, sock=sock_ob) sock = self.event_loop.run_until_complete(f) self.assertIs(sock, sock_ob) @@ -766,8 +795,7 @@ def test_start_serving_sock(self): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.event_loop.run_once() # This is quite mysterious, but necessary. - self.event_loop.run_once() + self.event_loop.run_until_complete(proto) sock.close() client.close() @@ -842,6 +870,9 @@ def test_create_datagram_endpoint_addr_error(self): def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): + def __init__(self): + super().__init__(create_future=True) + def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) @@ -852,7 +883,8 @@ def datagram_received(self, data, addr): host, port = s_transport.get_extra_info('addr') coro = self.event_loop.create_datagram_endpoint( - MyDatagramProto, remote_addr=(host, port)) + lambda: MyDatagramProto(create_future=True), + remote_addr=(host, port)) transport, client = self.event_loop.run_until_complete(coro) self.assertEqual('INITIALIZED', client.state) @@ -871,7 +903,7 @@ def datagram_received(self, data, addr): # close connection transport.close() - + self.event_loop.run_until_complete(client.done) self.assertEqual('CLOSED', client.state) server.transport.close() @@ -975,7 +1007,7 @@ def test_read_pipe(self): def factory(): nonlocal proto - proto = MyReadPipeProto() + proto = MyReadPipeProto(create_future=True) return proto rpipe, wpipe = os.pipe() @@ -1002,7 +1034,7 @@ def connect(): self.assertEqual(5, proto.nbytes) os.close(wpipe) - self.event_loop.run_once() + self.event_loop.run_until_complete(proto.done) self.assertEqual( ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) # extra info is available @@ -1016,7 +1048,7 @@ def test_write_pipe(self): def factory(): nonlocal proto - proto = MyWritePipeProto() + proto = MyWritePipeProto(create_future=True) return proto rpipe, wpipe = os.pipe() @@ -1052,7 +1084,7 @@ def connect(): # close connection proto.transport.close() - self.event_loop.run_once() + self.event_loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index e03ef1f3..e0227e6a 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -722,6 +722,7 @@ def test_write_ready_closing(self): transport._write_ready() self.sock.send.assert_called_with(data) self.event_loop.remove_writer.assert_called_with(7) + self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) def test_write_ready_no_data(self): @@ -803,7 +804,8 @@ def test_fatal_error(self, m_exc): self.assertEqual([], transport._buffer) self.event_loop.remove_reader.assert_called_with(7) self.event_loop.remove_writer.assert_called_with(7) - self.protocol.connection_lost.assert_called_with(exc) + self.event_loop.call_soon.assert_called_with( + transport._call_connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', transport) def test_connection_lost(self): @@ -911,14 +913,16 @@ def test_fatal_error(self, m_exc): self.assertEqual([], self.transport._buffer) self.assertTrue(self.event_loop.remove_writer.called) self.assertTrue(self.event_loop.remove_reader.called) - self.protocol.connection_lost.assert_called_with(exc) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', self.transport) def test_close(self): self.transport.close() self.assertTrue(self.transport._closing) self.assertTrue(self.event_loop.remove_reader.called) - self.protocol.connection_lost.assert_called_with(None) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_lost, None) def test_close_write_buffer(self): self.transport._buffer.append(b'data') @@ -944,7 +948,7 @@ def test_on_ready_recv_eof(self): self.assertTrue(self.event_loop.remove_reader.called) self.assertTrue(self.event_loop.remove_writer.called) self.assertTrue(self.sslsock.close.called) - self.assertTrue(self.protocol.connection_lost.called) + self.protocol.connection_lost.assert_called_with(None) def test_on_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError @@ -1006,7 +1010,8 @@ def test_on_ready_send_closing(self): self.transport._on_ready() self.assertTrue(self.sslsock.close.called) self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.protocol.connection_lost.called) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_lost, None) def test_on_ready_send_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError @@ -1219,6 +1224,7 @@ def test_sendto_ready_closing(self): transport._sendto_ready() self.sock.sendto.assert_called_with(data, ()) self.event_loop.remove_writer.assert_called_with(7) + self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) def test_sendto_ready_no_data(self): @@ -1281,7 +1287,8 @@ def test_close(self): self.assertTrue(transport._closing) self.event_loop.remove_reader.assert_called_with(7) - self.protocol.connection_lost.assert_called_with(None) + self.event_loop.call_soon.assert_called_with( + transport._call_connection_lost, None) def test_close_write_buffer(self): transport = _SelectorDatagramTransport( @@ -1304,7 +1311,8 @@ def test_fatal_error(self, m_exc): self.assertEqual([], list(transport._buffer)) self.event_loop.remove_writer.assert_called_with(7) self.event_loop.remove_reader.assert_called_with(7) - self.protocol.connection_lost.assert_called_with(exc) + self.event_loop.call_soon.assert_called_with( + transport._call_connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', transport) @unittest.mock.patch('tulip.log.tulip_log.exception') diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 686f2e41..b7d03c5c 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -215,8 +215,8 @@ def test__read_ready_eof(self, m_fcntl, m_read): tr._read_ready() m_read.assert_called_with(5, tr.max_size) - self.protocol.eof_received.assert_called_with() self.event_loop.remove_reader.assert_called_with(5) + self.protocol.eof_received.assert_called_with() @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') @@ -294,7 +294,8 @@ def test__close(self, m_fcntl, m_read): tr._close(err) self.assertTrue(tr._closing) self.event_loop.remove_reader.assert_called_with(5) - self.protocol.connection_lost.assert_called_with(err) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, err) @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost(self, m_fcntl): @@ -485,7 +486,8 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): self.event_loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.protocol.connection_lost.assert_called_with(err) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.write') @@ -502,6 +504,7 @@ def test__write_ready_closing(self, m_fcntl, m_write): self.event_loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') @@ -515,7 +518,8 @@ def test_abort(self, m_fcntl, m_write): self.event_loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.protocol.connection_lost.assert_called_with(None) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost(self, m_fcntl): @@ -563,7 +567,8 @@ def test_write_eof(self, m_fcntl): tr.write_eof() self.assertTrue(tr._closing) - self.protocol.connection_lost.assert_called_with(None) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) @unittest.mock.patch('fcntl.fcntl') def test_write_eof_pending(self, m_fcntl): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index a649d8c8..f5b7ae37 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -413,7 +413,7 @@ def close(self): self._closing = True self._event_loop.remove_reader(self._sock.fileno()) if not self._buffer: - self._call_connection_lost(None) + self._event_loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): # should be called from exception handler only @@ -424,7 +424,7 @@ def _close(self, exc): self._event_loop.remove_writer(self._sock.fileno()) self._event_loop.remove_reader(self._sock.fileno()) self._buffer.clear() - self._call_connection_lost(exc) + self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: @@ -567,7 +567,7 @@ def close(self): self._closing = True self._event_loop.remove_reader(self._sslsock.fileno()) if not self._buffer: - self._protocol.connection_lost(None) + self._event_loop.call_soon(self._protocol.connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) @@ -577,7 +577,7 @@ def _close(self, exc): self._event_loop.remove_writer(self._sslsock.fileno()) self._event_loop.remove_reader(self._sslsock.fileno()) self._buffer = [] - self._protocol.connection_lost(exc) + self._event_loop.call_soon(self._protocol.connection_lost, exc) class _SelectorDatagramTransport(transports.DatagramTransport): @@ -678,7 +678,7 @@ def close(self): self._closing = True self._event_loop.remove_reader(self._fileno) if not self._buffer: - self._call_connection_lost(None) + self._event_loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) @@ -690,7 +690,7 @@ def _close(self, exc): self._event_loop.remove_reader(self._fileno) if self._address and isinstance(exc, ConnectionRefusedError): self._protocol.connection_refused(exc) - self._call_connection_lost(exc) + self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 3073ab64..4022b9b4 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -195,7 +195,7 @@ def _fatal_error(self, exc): def _close(self, exc): self._closing = True self._event_loop.remove_reader(self._fileno) - self._call_connection_lost(exc) + self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: @@ -273,7 +273,7 @@ def write_eof(self): assert self._pipe self._closing = True if not self._buffer: - self._call_connection_lost(None) + self._event_loop.call_soon(self._call_connection_lost, None) def close(self): if not self._closing: @@ -292,7 +292,7 @@ def _close(self, exc=None): self._closing = True self._buffer.clear() self._event_loop.remove_writer(self._fileno) - self._call_connection_lost(exc) + self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: From c5d0c83981e593b8b76b719e1859d5c5ffe7ede1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sat, 13 Apr 2013 21:02:37 -0700 Subject: [PATCH 0411/1502] cancellation support in task.sleep() --- tests/tasks_test.py | 38 +++++++++++++++++++++++++++++--------- tulip/tasks.py | 15 +++++++-------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index ae9264cc..98b08f98 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -236,8 +236,8 @@ def task(): self.assertTrue(0.08 <= t1-t0 <= 0.12) def test_wait(self): - a = tasks.sleep(0.1) - b = tasks.sleep(0.15) + a = tasks.Task(tasks.sleep(0.1)) + b = tasks.Task(tasks.sleep(0.15)) @tasks.coroutine def foo(): @@ -259,8 +259,8 @@ def foo(): # TODO: Test different return_when values. def test_wait_first_completed(self): - a = tasks.sleep(10.0) - b = tasks.sleep(0.1) + a = tasks.Task(tasks.sleep(10.0)) + b = tasks.Task(tasks.sleep(0.1)) task = tasks.Task(tasks.wait( [b, a], return_when=tasks.FIRST_COMPLETED)) @@ -298,7 +298,7 @@ def coro2(): def test_wait_first_exception(self): # first_exception, task already has exception - a = tasks.sleep(10.0) + a = tasks.Task(tasks.sleep(10.0)) @tasks.coroutine def exc(): @@ -314,7 +314,7 @@ def exc(): def test_wait_first_exception_in_wait(self): # first_exception, exception during waiting - a = tasks.sleep(10.0) + a = tasks.Task(tasks.sleep(10.0)) @tasks.coroutine def exc(): @@ -329,7 +329,7 @@ def exc(): self.assertEqual({a}, pending) def test_wait_with_exception(self): - a = tasks.sleep(0.1) + a = tasks.Task(tasks.sleep(0.1)) @tasks.coroutine def sleeper(): @@ -356,8 +356,8 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) def test_wait_with_timeout(self): - a = tasks.sleep(0.1) - b = tasks.sleep(0.15) + a = tasks.Task(tasks.sleep(0.1)) + b = tasks.Task(tasks.sleep(0.15)) @tasks.coroutine def foo(): @@ -440,6 +440,26 @@ def sleeper(dt, arg): self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') + def test_sleep_cancel(self): + t = tasks.Task(tasks.sleep(10.0, 'yeah')) + + handle = None + orig_call_later = self.event_loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + self.event_loop.call_later = call_later + self.event_loop.run_once() + + self.assertFalse(handle.cancelled) + + t.cancel() + self.event_loop.run_once() + self.assertTrue(handle.cancelled) + def test_task_cancel_sleeping_task(self): sleepfut = None diff --git a/tulip/tasks.py b/tulip/tasks.py index 67c7b60d..70762c57 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -318,13 +318,12 @@ def _wrap_coroutines(fs): return wrapped +@coroutine def sleep(when, result=None): - """Return a Future that completes after a given time (in seconds). - - It's okay to cancel the Future. - - Undocumented feature: sleep(when, x) sets the Future's result to x. - """ + """Coroutine that completes after a given time (in seconds).""" future = futures.Future() - future._event_loop.call_later(when, future.set_result, result) - return future + h = future._event_loop.call_later(when, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() From 718f424e2e4f7cc234673f52321352f65f2b04c8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 14 Apr 2013 15:52:22 -0700 Subject: [PATCH 0412/1502] Add note to subprocess module indicating it is a hack. --- tests/subprocess_test.py | 4 ++++ tulip/subprocess_transport.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py index 09aaed52..9d1ff93a 100644 --- a/tests/subprocess_test.py +++ b/tests/subprocess_test.py @@ -1,3 +1,7 @@ +# NOTE: This is a hack. Andrew Svetlov is working in a proper +# subprocess management transport for use with +# connect_{read,write}_pipe(). + """Tests for subprocess_transport.py.""" import logging diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index 734a5fa7..5e4d6550 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -1,3 +1,7 @@ +# NOTE: This is a hack. Andrew Svetlov is working in a proper +# subprocess management transport for use with +# connect_{read,write}_pipe(). + import fcntl import os import traceback From 1696ae9be246e8cbc585c6f5340642754c8e7694 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 15 Apr 2013 03:35:31 +0300 Subject: [PATCH 0413/1502] Comment out unused _DISCONNECTED set --- tulip/selector_events.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index f5b7ae37..95f8f7d4 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -21,13 +21,16 @@ # Errno values indicating the connection was disconnected. -_DISCONNECTED = frozenset((errno.ECONNRESET, - errno.ENOTCONN, - errno.ESHUTDOWN, - errno.ECONNABORTED, - errno.EPIPE, - errno.EBADF, - )) +# Comment out _DISCONNECTED as never used +# TODO: make sure that errors has processed properly +# for now we have no exception clsses for ENOTCONN and EBADF +# _DISCONNECTED = frozenset((errno.ECONNRESET, +# errno.ENOTCONN, +# errno.ESHUTDOWN, +# errno.ECONNABORTED, +# errno.EPIPE, +# errno.EBADF, +# )) class BaseSelectorEventLoop(base_events.BaseEventLoop): From 2d1b8eecfae716e7ba14a5a6458821459000fb75 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 15 Apr 2013 20:38:29 +0300 Subject: [PATCH 0414/1502] Add _conn_lost logic to unit pipe, extract magic 5 into constants module --- tests/unix_events_test.py | 14 +++++++++++++- tulip/constants.py | 4 ++++ tulip/proactor_events.py | 3 ++- tulip/selector_events.py | 7 ++++--- tulip/unix_events.py | 10 ++++++++++ 5 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 tulip/constants.py diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index b7d03c5c..b96719a8 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -406,9 +406,10 @@ def test_write_again(self, m_fcntl, m_write): self.event_loop.add_writer.assert_called_with(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) + @unittest.mock.patch('tulip.unix_events.tulip_log') @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') - def test_write_err(self, m_fcntl, m_write): + def test_write_err(self, m_fcntl, m_write, m_log): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) @@ -420,6 +421,16 @@ def test_write_err(self, m_fcntl, m_write): self.assertFalse(self.event_loop.called) self.assertEqual([], tr._buffer) tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') @@ -489,6 +500,7 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 6f38db7d..ba1389d2 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -5,6 +5,7 @@ """ from . import base_events +from . import constants from . import transports from .log import tulip_log @@ -59,7 +60,7 @@ def write(self, data): if not data: return if self._conn_lost: - if self._conn_lost >= 5: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: tulip_log.warning('socket.send() raised exception.') self._conn_lost += 1 return diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 95f8f7d4..450f668a 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -13,6 +13,7 @@ ssl = None from . import base_events +from . import constants from . import events from . import futures from . import selectors @@ -360,7 +361,7 @@ def write(self, data): return if self._conn_lost: - if self._conn_lost >= 5: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: tulip_log.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -553,7 +554,7 @@ def write(self, data): return if self._conn_lost: - if self._conn_lost >= 5: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: tulip_log.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -621,7 +622,7 @@ def sendto(self, data, addr=None): assert addr in (None, self._address) if self._conn_lost and self._address: - if self._conn_lost >= 5: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: tulip_log.warning('socket.send() raised exception.') self._conn_lost += 1 return diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 4022b9b4..e7f5af27 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -11,6 +11,7 @@ except ImportError: # pragma: no cover signal = None +from . import constants from . import events from . import selector_events from . import transports @@ -215,6 +216,7 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): _set_nonblocking(self._fileno) self._protocol = protocol self._buffer = [] + self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. self._event_loop.call_soon(self._protocol.connection_made, self) if waiter is not None: @@ -226,6 +228,12 @@ def write(self, data): if not data: return + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + if not self._buffer: # Attempt to send it right away first. try: @@ -233,6 +241,7 @@ def write(self, data): except BlockingIOError: n = 0 except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) return if n == len(data): @@ -253,6 +262,7 @@ def _write_ready(self): except BlockingIOError: self._buffer.append(data) except Exception as exc: + self._conn_lost += 1 self._fatal_error(exc) else: if n == len(data): From 15061ca56558e6d2eff27b0d4fce79f46e8e65d7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 16 Apr 2013 14:56:10 +0300 Subject: [PATCH 0415/1502] Process InterruptedError for unix pipes --- tulip/unix_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index e7f5af27..73ada428 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -167,7 +167,7 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): def _read_ready(self): try: data = os.read(self._fileno, self.max_size) - except BlockingIOError: + except (BlockingIOError, InterruptedError): pass except OSError as exc: self._fatal_error(exc) @@ -238,7 +238,7 @@ def write(self, data): # Attempt to send it right away first. try: n = os.write(self._fileno, data) - except BlockingIOError: + except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: self._conn_lost += 1 @@ -259,7 +259,7 @@ def _write_ready(self): self._buffer.clear() try: n = os.write(self._fileno, data) - except BlockingIOError: + except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: self._conn_lost += 1 From 00d69d3f9993b81302eb1de09b01c657550e0f71 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 17 Apr 2013 15:28:47 -0700 Subject: [PATCH 0416/1502] websocket protocol parser + multiprocess srv.py example --- examples/mpsrv.py | 291 ++++++++++++++++ srv.py | 3 +- tests/http_websocket_test.py | 347 ++++++++++++++++++ tests/parsers_test.py | 657 +++++++++++++++++++++++++++++++++++ tulip/__init__.py | 2 + tulip/http/websocket.py | 175 ++++++++++ tulip/parsers.py | 418 ++++++++++++++++++++++ 7 files changed, 1891 insertions(+), 2 deletions(-) create mode 100644 examples/mpsrv.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/parsers_test.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/parsers.py diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100644 index 00000000..f1a53d41 --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,291 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import logging +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, request_info, message): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), request_info.method, + request_info.uri, request_info.version)) + + path = request_info.uri + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + loop.set_log_level(logging.CRITICAL) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving(lambda: HttpServer(debug=True), sock=self.sock) + x = loop.run_until_complete(f) + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.loop.set_log_level(logging.CRITICAL) + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/srv.py b/srv.py index 8fd6ccf6..a018daa9 100755 --- a/srv.py +++ b/srv.py @@ -90,8 +90,7 @@ def handle_request(self, request_info, message): with open(path, 'rb') as fp: chunk = fp.read(8196) while chunk: - if not response.write(chunk): - break + response.write(chunk): chunk = fp.read(8196) except OSError: response.write(b'Cannot open') diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..73c77d09 --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,347 @@ +"""Tests for http/websocket.py""" + +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..cd4e1ca2 --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,657 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class ParserBufferTests(unittest.TestCase): + + def test_feed_data(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(buf.size, 4) + self.assertEqual(len(buf), 4) + self.assertEqual(buf, b'data') + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf.shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf.shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_shrink_feed_data(self): + stream = parsers.StreamBuffer(2) + stream.feed_data(b'data') + self.assertEqual(bytes(stream._buffer), b'data') + + stream._buffer.offset = 2 + stream.feed_data(b'1') + self.assertEqual(bytes(stream._buffer), b'ta1') + self.assertEqual(3, len(stream._buffer)) + self.assertEqual(3, stream._buffer.size) + self.assertEqual(0, stream._buffer.offset) + + def test_read(self): + buf = parsers.ParserBuffer() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = parsers.ParserBuffer() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = parsers.ParserBuffer() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_params(self): + buf = parsers.ParserBuffer() + p = buf.readuntil(b'') + self.assertRaises(AssertionError, next, p) + + p = buf.readuntil('\n') + self.assertRaises(AssertionError, next, p) + + def test_readuntil_limit(self): + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', limit=4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', limit=4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', limit=4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', limit=4, exc=CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', limit=4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil_params(self): + buf = parsers.ParserBuffer() + p = buf.skipuntil(b'') + self.assertRaises(AssertionError, next, p) + + p = buf.skipuntil('\n') + self.assertRaises(AssertionError, next, p) + + def test_skipuntil(self): + buf = parsers.ParserBuffer() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.readline() + try: + next(p) + except StopIteration as exc: + res = exc.value + self.assertEqual(b'', bytes(buf)) + self.assertEqual(b'456\n', res) + + def test_readline_limit(self): + buf = parsers.ParserBuffer() + p = buf.readline(limit=4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readline(limit=4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readline(limit=4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + def test_readline(self): + buf = parsers.ParserBuffer() + p = buf.readline(limit=4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer() + buf = parsers.ParserBuffer() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer() + buf = parsers.ParserBuffer() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer() + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readline())) + out.feed_data((yield from buf.readline())) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer() + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer() + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer() + read_task = tasks.Task(buffer.read()) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer() + read_task = tasks.Task(buffer.read()) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer() + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer() + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer() + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer() + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read()) + t2 = tasks.Task(set_err()) + + self.loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) diff --git a/tulip/__init__.py b/tulip/__init__.py index faf307fb..9de84cb0 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -7,6 +7,7 @@ from .events import * from .locks import * from .transports import * +from .parsers import * from .protocols import * from .streams import * from .tasks import * @@ -21,6 +22,7 @@ events.__all__ + locks.__all__ + transports.__all__ + + parsers.__all__ + protocols.__all__ + streams.__all__ + tasks.__all__) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..a80745ed --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,175 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import collections +import struct + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..9d8151de --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,418 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self, buffer_size=5120): + self._buffer = ParserBuffer() + self._buffer_size = buffer_size + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._buffer.shrink() + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + # shrink buffer + if (self._buffer.offset and len(self._buffer) > self._buffer_size): + self._buffer.shrink() + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + self._buffer.shrink() + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer() + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + assert self._parser is not None, 'Paser is not set.' + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self): + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._waiter is not None: + self._waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future() + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readline(self, limit=2**16, exc=ValueError): + """readline() reads until \n string.""" + + while True: + new_line = self.find(b'\n', self.offset) + if new_line >= 0: + end = new_line + 1 + size = end - self.offset + if size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + new_line = self.find(stop, self.offset) + if new_line >= 0: + end = new_line + stop_len + size = end - self.offset + if size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - end - self.offset + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readline(limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) From f045f7fe3e12384abcf1eb2568259ab27143afc7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 22 Apr 2013 20:49:18 +0300 Subject: [PATCH 0417/1502] Add support for disconnecting notifications for write pipe. --- tests/events_test.py | 45 +++++++- tests/unix_events_test.py | 220 ++++++++++++++++++++++---------------- tulip/unix_events.py | 25 +++++ 3 files changed, 195 insertions(+), 95 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index ab8c968d..eafcfd67 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1052,6 +1052,7 @@ def factory(): return proto rpipe, wpipe = os.pipe() + self.addCleanup(os.close, rpipe) pipeobj = io.open(wpipe, 'wb', 1024) @tasks.task @@ -1077,8 +1078,6 @@ def connect(): self.assertEqual(b'2345', data) self.assertEqual('CONNECTED', proto.state) - os.close(rpipe) - # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) @@ -1087,6 +1086,48 @@ def connect(): self.event_loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(create_future=True) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + self.event_loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + if sys.platform == 'win32': from tulip import windows_events diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index b96719a8..ea832029 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -2,6 +2,9 @@ import errno import io +import os +import stat +import tempfile import unittest import unittest.mock @@ -179,16 +182,18 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor(self, m_fcntl): + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) self.event_loop.add_reader.assert_called_with(5, tr._read_ready) self.event_loop.call_soon.assert_called_with( self.protocol.connection_made, tr) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor_with_waiter(self, m_fcntl): + def test_ctor_with_waiter(self): fut = futures.Future() unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol, fut) @@ -196,8 +201,7 @@ def test_ctor_with_waiter(self, m_fcntl): fut.cancel() @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready(self, m_fcntl, m_read): + def test__read_ready(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) m_read.return_value = b'data' @@ -207,8 +211,7 @@ def test__read_ready(self, m_fcntl, m_read): self.protocol.data_received.assert_called_with(b'data') @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_eof(self, m_fcntl, m_read): + def test__read_ready_eof(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) m_read.return_value = b'' @@ -219,8 +222,7 @@ def test__read_ready_eof(self, m_fcntl, m_read): self.protocol.eof_received.assert_called_with() @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_blocked(self, m_fcntl, m_read): + def test__read_ready_blocked(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) self.event_loop.reset_mock() @@ -232,8 +234,7 @@ def test__read_ready_blocked(self, m_fcntl, m_read): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) err = OSError() @@ -246,50 +247,40 @@ def test__read_ready_error(self, m_fcntl, m_read, m_logexc): m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_pause(self, m_fcntl, m_read): + def test_pause(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - tr.pause() self.event_loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_resume(self, m_fcntl, m_read): + def test_resume(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - tr.resume() self.event_loop.add_reader.assert_called_with(5, tr._read_ready) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_close(self, m_fcntl, m_read): + def test_close(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - tr._close = unittest.mock.Mock() tr.close() tr._close.assert_called_with(None) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_close_already_closing(self, m_fcntl, m_read): + def test_close_already_closing(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - tr._closing = True tr._close = unittest.mock.Mock() tr.close() self.assertFalse(tr._close.called) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__close(self, m_fcntl, m_read): + def test__close(self, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - err = object() tr._close(err) self.assertTrue(tr._closing) @@ -297,21 +288,17 @@ def test__close(self, m_fcntl, m_read): self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, err) - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost(self, m_fcntl): + def test__call_connection_lost(self): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - err = None tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost_with_err(self, m_fcntl): + def test__call_connection_lost_with_err(self): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) - err = OSError() tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) @@ -326,33 +313,43 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor(self, m_fcntl): + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + self.fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = self.fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(self.fstat_patcher.stop) + + def test_ctor(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) self.event_loop.call_soon.assert_called_with( self.protocol.connection_made, tr) + self.assertTrue(tr._enable_read_hack) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor_with_waiter(self, m_fcntl): + def test_ctor_with_waiter(self): fut = futures.Future() - unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol, fut) + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) self.event_loop.call_soon.assert_called_with(fut.set_result, None) fut.cancel() + self.assertTrue(tr._enable_read_hack) - @unittest.mock.patch('fcntl.fcntl') - def test_can_write_eof(self, m_fcntl): + def test_can_write_eof(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) self.assertTrue(tr.can_write_eof()) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write(self, m_fcntl, m_write): + def test_write(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - m_write.return_value = 4 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -360,22 +357,18 @@ def test_write(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_no_data(self, m_fcntl, m_write): + def test_write_no_data(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr.write(b'') self.assertFalse(m_write.called) self.assertFalse(self.event_loop.add_writer.called) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_partial(self, m_fcntl, m_write): + def test_write_partial(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - m_write.return_value = 2 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -383,11 +376,9 @@ def test_write_partial(self, m_fcntl, m_write): self.assertEqual([b'ta'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_buffer(self, m_fcntl, m_write): + def test_write_buffer(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'previous'] tr.write(b'data') self.assertFalse(m_write.called) @@ -395,11 +386,9 @@ def test_write_buffer(self, m_fcntl, m_write): self.assertEqual([b'previous', b'data'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_again(self, m_fcntl, m_write): + def test_write_again(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - m_write.side_effect = BlockingIOError() tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -408,11 +397,9 @@ def test_write_again(self, m_fcntl, m_write): @unittest.mock.patch('tulip.unix_events.tulip_log') @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_err(self, m_fcntl, m_write, m_log): + def test_write_err(self, m_write, m_log): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - err = OSError() m_write.side_effect = err tr._fatal_error = unittest.mock.Mock() @@ -432,9 +419,18 @@ def test_write_err(self, m_fcntl, m_write, m_log): m_log.warning.assert_called_with( 'os.write(pipe, data) raised exception.') + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._read_ready() + self.event_loop.remove_writer.assert_called_with(5) + self.event_loop.remove_reader.assert_called_with(5) + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with(tr._call_connection_lost, + None) + @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready(self, m_fcntl, m_write): + def test__write_ready(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] @@ -445,11 +441,9 @@ def test__write_ready(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_partial(self, m_fcntl, m_write): + def test__write_ready_partial(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] m_write.return_value = 3 tr._write_ready() @@ -458,11 +452,9 @@ def test__write_ready_partial(self, m_fcntl, m_write): self.assertEqual([b'a'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_again(self, m_fcntl, m_write): + def test__write_ready_again(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] m_write.side_effect = BlockingIOError() tr._write_ready() @@ -471,11 +463,9 @@ def test__write_ready_again(self, m_fcntl, m_write): self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_empty(self, m_fcntl, m_write): + def test__write_ready_empty(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] m_write.return_value = 0 tr._write_ready() @@ -485,11 +475,9 @@ def test__write_ready_empty(self, m_fcntl, m_write): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] m_write.side_effect = err = OSError() tr._write_ready() @@ -501,13 +489,12 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) self.assertEqual(1, tr._conn_lost) + self.event_loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_closing(self, m_fcntl, m_write): + def test__write_ready_closing(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._closing = True tr._buffer = [b'da', b'ta'] m_write.return_value = 4 @@ -517,13 +504,28 @@ def test__write_ready_closing(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() + self.event_loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_abort(self, m_fcntl, m_write): + def test__write_ready_closing_regular_file(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._enable_read_hack = False + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + self.assertFalse(self.event_loop.remove_reader.called) + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) tr._buffer = [b'da', b'ta'] tr.abort() self.assertFalse(m_write.called) @@ -532,61 +534,93 @@ def test_abort(self, m_fcntl, m_write): self.assertTrue(tr._closing) self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, None) + self.event_loop.remove_reader.assert_called_with(5) - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost(self, m_fcntl): + @unittest.mock.patch('os.write') + def test_abort_closing_regular_file(self, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._enable_read_hack = False + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.assertFalse(self.event_loop.remove_reader.called) + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) err = None tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost_with_err(self, m_fcntl): + def test__call_connection_lost_with_err(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - err = OSError() tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test_close(self, m_fcntl): + def test_close(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr.write_eof = unittest.mock.Mock() tr.close() tr.write_eof.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test_close_closing(self, m_fcntl): + def test_close_closing(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr.write_eof = unittest.mock.Mock() tr._closing = True tr.close() self.assertFalse(tr.write_eof.called) - @unittest.mock.patch('fcntl.fcntl') - def test_write_eof(self, m_fcntl): + def test_write_eof(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr.write_eof() + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + self.event_loop.remove_reader.assert_called_with(5) + def test_write_eof_dont_remove_reader_for_regular_file(self): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._enable_read_hack = False tr.write_eof() self.assertTrue(tr._closing) self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, None) + self.assertFalse(self.event_loop.remove_reader.called) - @unittest.mock.patch('fcntl.fcntl') - def test_write_eof_pending(self, m_fcntl): + def test_write_eof_pending(self): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) tr._buffer = [b'data'] tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.protocol.connection_lost.called) + + +class UntxWritePipeRegularFileTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + def test_ctor_with_regular_file(self): + with tempfile.TemporaryFile() as f: + tr = unix_events._UnixWritePipeTransport(self.event_loop, f, + self.protocol) + self.assertFalse(self.event_loop.add_reader.called) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + self.assertFalse(tr._enable_read_hack) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 73ada428..3e1bc098 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -4,6 +4,7 @@ import fcntl import os import socket +import stat import sys try: @@ -218,10 +219,28 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. + # Do nothing if it is a regular file. + # Enable hack only if pipe is FIFO object. + # Look on twisted.internet.process:ProcessWriter.__init__ + if stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + self._enable_read_hack = True + else: + # If the pipe is not a unix pipe, then the read hack is never + # applicable. This case arises when _UnixWritePipeTransport + # is used by subprocess and stdout/stderr + # are redirected to a normal file. + self._enable_read_hack = False + + if self._enable_read_hack: + self._event_loop.add_reader(self._fileno, self._read_ready) self._event_loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._event_loop.call_soon(waiter.set_result, None) + def _read_ready(self): + # pipe was closed by peer + self._close() + def write(self, data): assert isinstance(data, bytes), repr(data) assert not self._closing @@ -267,6 +286,8 @@ def _write_ready(self): else: if n == len(data): self._event_loop.remove_writer(self._fileno) + if self._enable_read_hack: + self._event_loop.remove_reader(self._fileno) if self._closing: self._call_connection_lost(None) return @@ -283,6 +304,8 @@ def write_eof(self): assert self._pipe self._closing = True if not self._buffer: + if self._enable_read_hack: + self._event_loop.remove_reader(self._fileno) self._event_loop.call_soon(self._call_connection_lost, None) def close(self): @@ -302,6 +325,8 @@ def _close(self, exc=None): self._closing = True self._buffer.clear() self._event_loop.remove_writer(self._fileno) + if self._enable_read_hack: + self._event_loop.remove_reader(self._fileno) self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): From e223f4ab588632861c6cca5ec7fa664ef4c15046 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 22 Apr 2013 18:50:49 -0700 Subject: [PATCH 0418/1502] http protocol parser --- examples/mpsrv.py | 7 +- examples/websocket.html | 90 ++++ examples/wsclient.py | 72 +++ examples/wssrv.py | 311 +++++++++++++ srv.py | 9 +- tests/events_test.py | 48 +- tests/http_client_functional_test.py | 4 +- tests/http_client_test.py | 33 +- tests/http_parser_test.py | 496 ++++++++++++++++++++ tests/http_protocol_test.py | 645 ++------------------------- tests/http_server_test.py | 27 +- tests/http_websocket_test.py | 81 +++- tests/http_wsgi_test.py | 34 +- tests/parsers_test.py | 406 ++++++++--------- tests/unix_events_test.py | 220 ++++----- tulip/http/client.py | 82 ++-- tulip/http/protocol.py | 622 ++++++++++---------------- tulip/http/server.py | 105 +++-- tulip/http/websocket.py | 54 ++- tulip/http/wsgi.py | 36 +- tulip/parsers.py | 56 +-- tulip/selector_events.py | 1 - tulip/streams.py | 6 +- tulip/test_utils.py | 29 +- tulip/unix_events.py | 25 -- 25 files changed, 1850 insertions(+), 1649 deletions(-) create mode 100644 examples/websocket.html create mode 100644 examples/wsclient.py create mode 100644 examples/wssrv.py create mode 100644 tests/http_parser_test.py diff --git a/examples/mpsrv.py b/examples/mpsrv.py index f1a53d41..8664cd46 100644 --- a/examples/mpsrv.py +++ b/examples/mpsrv.py @@ -27,12 +27,11 @@ class HttpServer(tulip.http.ServerHttpProtocol): @tulip.coroutine - def handle_request(self, request_info, message): + def handle_request(self, message, payload): print('{}: method = {!r}; path = {!r}; version = {!r}'.format( - os.getpid(), request_info.method, - request_info.uri, request_info.version)) + os.getpid(), message.method, message.path, message.version)) - path = request_info.uri + path = message.path if (not (path.isprintable() and path.startswith('/')) or '/.' in path): path = None diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100644 index 00000000..64b48514 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,72 @@ +"""websocket cmd client for wssrv.py example.""" +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop): + name = input('Please enter your name: ').encode() + + url = 'http://localhost:8080/' + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.websocket_parser()) + writer = websocket.websocket_writer(response.transport) + + # input reader + loop.add_reader( + sys.stdin.fileno(), + lambda: writer.send(name + b': ' + sys.stdin.readline().encode())) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.opcode == websocket.OPCODE_PING: + writer.pong() + elif msg.opcode == websocket.OPCODE_TEXT: + print(msg.data.strip()) + elif msg.opcode == websocket.OPCODE_CLOSE: + break + + yield from dispatch() + + +if __name__ == '__main__': + loop = tulip.get_event_loop() + loop.add_signal_handler(signal.SIGINT, loop.stop) + loop.run_until_complete(start_client(loop)) diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100644 index 00000000..2d94fd36 --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import logging +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + self.close() + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + loop.set_log_level(logging.CRITICAL) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + sock = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), sock.getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.loop.set_log_level(logging.CRITICAL) + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/srv.py b/srv.py index a018daa9..56392dd5 100755 --- a/srv.py +++ b/srv.py @@ -19,11 +19,12 @@ class HttpServer(tulip.http.ServerHttpProtocol): - def handle_request(self, request_info, message): + @tulip.coroutine + def handle_request(self, message, payload): print('method = {!r}; path = {!r}; version = {!r}'.format( - request_info.method, request_info.uri, request_info.version)) + message.method, message.path, message.version)) - path = request_info.uri + path = message.path if (not (path.isprintable() and path.startswith('/')) or '/.' in path): print('bad path', repr(path)) @@ -90,7 +91,7 @@ def handle_request(self, request_info, message): with open(path, 'rb') as fp: chunk = fp.read(8196) while chunk: - response.write(chunk): + response.write(chunk) chunk = fp.read(8196) except OSError: response.write(b'Cannot open') diff --git a/tests/events_test.py b/tests/events_test.py index eafcfd67..e928cdf0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -446,7 +446,8 @@ def test_sock_client_ops(self): self.event_loop.run_until_complete( self.event_loop.sock_connect(sock, httpd.address)) self.event_loop.run_until_complete( - self.event_loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + self.event_loop.sock_sendall( + sock, b'GET / HTTP/1.0\r\n\r\n')) data = self.event_loop.run_until_complete( self.event_loop.sock_recv(sock, 1024)) # consume data @@ -1052,7 +1053,6 @@ def factory(): return proto rpipe, wpipe = os.pipe() - self.addCleanup(os.close, rpipe) pipeobj = io.open(wpipe, 'wb', 1024) @tasks.task @@ -1078,6 +1078,8 @@ def connect(): self.assertEqual(b'2345', data) self.assertEqual('CONNECTED', proto.state) + os.close(rpipe) + # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) @@ -1086,48 +1088,6 @@ def connect(): self.event_loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) - @unittest.skipUnless(sys.platform != 'win32', - "Don't support pipes for Windows") - def test_write_pipe_disconnect_on_close(self): - proto = None - transport = None - - def factory(): - nonlocal proto - proto = MyWritePipeProto(create_future=True) - return proto - - rpipe, wpipe = os.pipe() - pipeobj = io.open(wpipe, 'wb', 1024) - - @tasks.task - def connect(): - nonlocal transport - t, p = yield from self.event_loop.connect_write_pipe(factory, - pipeobj) - self.assertIs(p, proto) - self.assertIs(t, proto.transport) - self.assertEqual('CONNECTED', proto.state) - transport = t - - self.event_loop.run_until_complete(connect()) - - transport.write(b'1') - self.event_loop.run_once() - data = os.read(rpipe, 1024) - self.assertEqual(b'1', data) - - transport.write(b'2345') - self.event_loop.run_once() - data = os.read(rpipe, 1024) - self.assertEqual(b'2345', data) - self.assertEqual('CONNECTED', proto.state) - - os.close(rpipe) - - self.event_loop.run_until_complete(proto.done) - self.assertEqual('CLOSED', proto.state) - if sys.platform == 'win32': from tulip import windows_events diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 1ee3ab8c..5eb65702 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -1,5 +1,6 @@ """Http client functional tests.""" +import gc import io import os.path import http.cookies @@ -19,6 +20,7 @@ def setUp(self): def tearDown(self): self.loop.close() + gc.collect() def test_HTTP_200_OK_METHOD(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -320,7 +322,7 @@ def test_cookies(self): self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.content.read()) - self.assertIn(b'"Cookie": "test1=123; test3=456"', content) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) def test_chunked(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: diff --git a/tests/http_client_test.py b/tests/http_client_test.py index 973a6cb0..77a2a7ef 100644 --- a/tests/http_client_test.py +++ b/tests/http_client_test.py @@ -8,26 +8,7 @@ import tulip import tulip.http -from tulip.http.client import HttpProtocol, HttpRequest, HttpResponse - - -class HttpProtocolTests(unittest.TestCase): - - def test_protocol(self): - transport = unittest.mock.Mock() - - p = HttpProtocol() - p.connection_made(transport) - self.assertIs(p.transport, transport) - self.assertIsInstance(p.stream, tulip.http.HttpStreamReader) - - p.data_received(b'data') - self.assertEqual(4, p.stream.byte_count) - - p.eof_received() - self.assertTrue(p.stream.eof) - - p.connection_lost(None) +from tulip.http.client import HttpRequest, HttpResponse class HttpResponseTests(unittest.TestCase): @@ -37,16 +18,16 @@ def setUp(self): tulip.set_event_loop(self.loop) self.transport = unittest.mock.Mock() - self.stream = tulip.http.HttpStreamReader(self.transport) + self.stream = tulip.StreamBuffer() self.response = HttpResponse('get', 'http://python.org') def tearDown(self): self.loop.close() def test_close(self): - self.response._transport = self.transport + self.response.transport = self.transport self.response.close() - self.assertIsNone(self.response._transport) + self.assertIsNone(self.response.transport) self.assertTrue(self.transport.close.called) self.response.close() self.response.close() @@ -65,7 +46,7 @@ def setUp(self): tulip.set_event_loop(self.loop) self.transport = unittest.mock.Mock() - self.stream = tulip.http.HttpStreamReader(self.transport) + self.stream = tulip.StreamBuffer() def tearDown(self): self.loop.close() @@ -173,11 +154,11 @@ def test_basic_auth_err(self): def test_no_content_length(self): req = HttpRequest('get', 'http://python.org') req.send(self.transport) - self.assertEqual(0, req.headers.get('Content-Length')) + self.assertEqual('0', req.headers.get('Content-Length')) req = HttpRequest('head', 'http://python.org') req.send(self.transport) - self.assertEqual(0, req.headers.get('Content-Length')) + self.assertEqual('0', req.headers.get('Content-Length')) def test_path_is_not_double_encoded(self): req = HttpRequest('get', "http://0.0.0.0/get/test case") diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..71951be7 --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,496 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 272c4c3e..9455426a 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -1,614 +1,12 @@ """Tests for http/protocol.py""" -import http.client import unittest import unittest.mock import zlib -import tulip from tulip.http import protocol -class HttpStreamReaderTests(unittest.TestCase): - - def setUp(self): - self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) - - self.transport = unittest.mock.Mock() - self.stream = protocol.HttpStreamReader() - - def tearDown(self): - self.loop.close() - - def test_request_line(self): - self.stream.feed_data(b'get /path HTTP/1.1\r\n') - self.assertEqual( - ('GET', '/path', (1, 1)), - self.loop.run_until_complete(self.stream.read_request_line())) - - def test_request_line_two_slashes(self): - self.stream.feed_data(b'get //path HTTP/1.1\r\n') - self.assertEqual( - ('GET', '//path', (1, 1)), - self.loop.run_until_complete(self.stream.read_request_line())) - - def test_request_line_non_ascii(self): - self.stream.feed_data(b'get /path\xd0\xb0 HTTP/1.1\r\n') - - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete(self.stream.read_request_line()) - - self.assertEqual( - b'get /path\xd0\xb0 HTTP/1.1\r\n', cm.exception.args[0]) - - def test_request_line_bad_status_line(self): - self.stream.feed_data(b'\r\n') - self.assertRaises( - http.client.BadStatusLine, - self.loop.run_until_complete, - self.stream.read_request_line()) - - def test_request_line_bad_method(self): - self.stream.feed_data(b'!12%()+=~$ /get HTTP/1.1\r\n') - self.assertRaises( - http.client.BadStatusLine, - self.loop.run_until_complete, - self.stream.read_request_line()) - - def test_request_line_bad_version(self): - self.stream.feed_data(b'GET //get HT/11\r\n') - self.assertRaises( - http.client.BadStatusLine, - self.loop.run_until_complete, - self.stream.read_request_line()) - - def test_response_status_bad_status_line(self): - self.stream.feed_data(b'\r\n') - self.assertRaises( - http.client.BadStatusLine, - self.loop.run_until_complete, - self.stream.read_response_status()) - - def test_response_status_bad_status_line_eof(self): - self.stream.feed_eof() - self.assertRaises( - http.client.BadStatusLine, - self.loop.run_until_complete, - self.stream.read_response_status()) - - def test_response_status_bad_status_non_ascii(self): - self.stream.feed_data(b'HTTP/1.1 200 \xd0\xb0\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete(self.stream.read_response_status()) - - self.assertEqual(b'HTTP/1.1 200 \xd0\xb0\r\n', cm.exception.args[0]) - - def test_response_status_bad_version(self): - self.stream.feed_data(b'HT/11 200 Ok\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete(self.stream.read_response_status()) - - self.assertEqual('HT/11 200 Ok', cm.exception.args[0]) - - def test_response_status_no_reason(self): - self.stream.feed_data(b'HTTP/1.1 200\r\n') - - v, s, r = self.loop.run_until_complete( - self.stream.read_response_status()) - self.assertEqual(v, (1, 1)) - self.assertEqual(s, 200) - self.assertEqual(r, '') - - def test_response_status_bad(self): - self.stream.feed_data(b'HTT/1\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - self.stream.read_response_status()) - - self.assertIn('HTT/1', str(cm.exception)) - - def test_response_status_bad_code_under_100(self): - self.stream.feed_data(b'HTTP/1.1 99 test\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - self.stream.read_response_status()) - - self.assertIn('HTTP/1.1 99 test', str(cm.exception)) - - def test_response_status_bad_code_above_999(self): - self.stream.feed_data(b'HTTP/1.1 9999 test\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - self.stream.read_response_status()) - - self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) - - def test_response_status_bad_code_not_int(self): - self.stream.feed_data(b'HTTP/1.1 ttt test\r\n') - with self.assertRaises(http.client.BadStatusLine) as cm: - self.loop.run_until_complete( - self.stream.read_response_status()) - - self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) - - def test_read_headers(self): - self.stream.feed_data(b'test: line\r\n' - b' continue\r\n' - b'test2: data\r\n' - b'\r\n') - - headers = self.loop.run_until_complete(self.stream.read_headers()) - self.assertEqual(headers, - [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) - - def test_read_headers_size(self): - self.stream.feed_data(b'test: line\r\n') - self.stream.feed_data(b' continue\r\n') - self.stream.feed_data(b'test2: data\r\n') - self.stream.feed_data(b'\r\n') - - self.stream.MAX_HEADERS = 5 - self.assertRaises( - http.client.LineTooLong, - self.loop.run_until_complete, - self.stream.read_headers()) - - def test_read_headers_invalid_header(self): - self.stream.feed_data(b'test line\r\n') - - with self.assertRaises(ValueError) as cm: - self.loop.run_until_complete(self.stream.read_headers()) - - self.assertIn("Invalid header b'test line'", str(cm.exception)) - - def test_read_headers_invalid_name(self): - self.stream.feed_data(b'test[]: line\r\n') - - with self.assertRaises(ValueError) as cm: - self.loop.run_until_complete(self.stream.read_headers()) - - self.assertIn("Invalid header name b'TEST[]'", str(cm.exception)) - - def test_read_headers_headers_size(self): - self.stream.MAX_HEADERFIELD_SIZE = 5 - self.stream.feed_data(b'test: line data data\r\ndata\r\n') - - with self.assertRaises(http.client.LineTooLong) as cm: - self.loop.run_until_complete(self.stream.read_headers()) - - self.assertIn("limit request headers fields size", str(cm.exception)) - - def test_read_headers_continuation_headers_size(self): - self.stream.MAX_HEADERFIELD_SIZE = 5 - self.stream.feed_data(b'test: line\r\n test\r\n') - - with self.assertRaises(http.client.LineTooLong) as cm: - self.loop.run_until_complete(self.stream.read_headers()) - - self.assertIn("limit request headers fields size", str(cm.exception)) - - def test_read_message_should_close(self): - self.stream.feed_data( - b'Host: example.com\r\nConnection: close\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - self.assertTrue(msg.should_close) - - def test_read_message_should_close_http11(self): - self.stream.feed_data( - b'Host: example.com\r\n\r\n') - - msg = self.loop.run_until_complete( - self.stream.read_message(version=(1, 1))) - self.assertFalse(msg.should_close) - - def test_read_message_should_close_http10(self): - self.stream.feed_data( - b'Host: example.com\r\n\r\n') - - msg = self.loop.run_until_complete( - self.stream.read_message(version=(1, 0))) - self.assertTrue(msg.should_close) - - def test_read_message_should_close_keep_alive(self): - self.stream.feed_data( - b'Host: example.com\r\nConnection: keep-alive\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - self.assertFalse(msg.should_close) - - def test_read_message_content_length_broken(self): - self.stream.feed_data( - b'Host: example.com\r\nContent-Length: qwe\r\n\r\n') - - self.assertRaises( - http.client.HTTPException, - self.loop.run_until_complete, - self.stream.read_message()) - - def test_read_message_content_length_wrong(self): - self.stream.feed_data( - b'Host: example.com\r\nContent-Length: -1\r\n\r\n') - - self.assertRaises( - http.client.HTTPException, - self.loop.run_until_complete, - self.stream.read_message()) - - def test_read_message_content_length(self): - self.stream.feed_data( - b'Host: example.com\r\nContent-Length: 2\r\n\r\n12') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'12', payload) - - def test_read_message_content_length_no_val(self): - self.stream.feed_data(b'Host: example.com\r\n\r\n12') - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=False)) - - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'', payload) - - _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) - - def test_read_message_deflate(self): - self.stream.feed_data( - ('Host: example.com\r\nContent-Length: {}\r\n' - 'Content-Encoding: deflate\r\n\r\n'.format( - len(self._COMPRESSED)).encode())) - self.stream.feed_data(self._COMPRESSED) - - msg = self.loop.run_until_complete(self.stream.read_message()) - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'data', payload) - - def test_read_message_deflate_disabled(self): - self.stream.feed_data( - ('Host: example.com\r\nContent-Encoding: deflate\r\n' - 'Content-Length: {}\r\n\r\n'.format( - len(self._COMPRESSED)).encode())) - self.stream.feed_data(self._COMPRESSED) - - msg = self.loop.run_until_complete( - self.stream.read_message(compression=False)) - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(self._COMPRESSED, payload) - - def test_read_message_deflate_unknown(self): - self.stream.feed_data( - ('Host: example.com\r\nContent-Encoding: compress\r\n' - 'Content-Length: {}\r\n\r\n'.format( - len(self._COMPRESSED)).encode())) - self.stream.feed_data(self._COMPRESSED) - - msg = self.loop.run_until_complete( - self.stream.read_message(compression=False)) - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(self._COMPRESSED, payload) - - def test_read_message_websocket(self): - self.stream.feed_data( - b'Host: example.com\r\nSec-Websocket-Key1: 13\r\n\r\n1234567890') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'12345678', payload) - - def test_read_message_chunked(self): - self.stream.feed_data( - b'Host: example.com\r\nTransfer-Encoding: chunked\r\n\r\n') - self.stream.feed_data( - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'dataline', payload) - - def test_read_message_readall_eof(self): - self.stream.feed_data( - b'Host: example.com\r\n\r\n') - self.stream.feed_data(b'data') - self.stream.feed_data(b'line') - self.stream.feed_eof() - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - payload = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'dataline', payload) - - def test_read_message_payload(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 8\r\n\r\n') - self.stream.feed_data(b'data') - self.stream.feed_data(b'data') - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - data = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'datadata', data) - - def test_read_message_payload_eof(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 4\r\n\r\n') - self.stream.feed_data(b'da') - self.stream.feed_eof() - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, msg.payload.read()) - - def test_read_message_length_payload_zero(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 0\r\n\r\n') - self.stream.feed_data(b'data') - - msg = self.loop.run_until_complete(self.stream.read_message()) - data = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'', data) - - def test_read_message_length_payload_incomplete(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 8\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'data') - self.stream.feed_eof() - return (yield from msg.payload.read()) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, coro()) - - def test_read_message_eof_payload(self): - self.stream.feed_data(b'Host: example.com\r\n\r\n') - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'data') - self.stream.feed_eof() - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'data', data) - - def test_read_message_length_payload(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 4\r\n\r\n') - self.stream.feed_data(b'da') - self.stream.feed_data(b't') - self.stream.feed_data(b'ali') - self.stream.feed_data(b'ne') - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - self.assertIsInstance(msg.payload, tulip.StreamReader) - - data = self.loop.run_until_complete(msg.payload.read()) - self.assertEqual(b'data', data) - self.assertEqual(b'line', b''.join(self.stream.buffer)) - - def test_read_message_length_payload_extra(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Length: 4\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'da') - self.stream.feed_data(b't') - self.stream.feed_data(b'ali') - self.stream.feed_data(b'ne') - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'data', data) - self.assertEqual(b'line', b''.join(self.stream.buffer)) - - def test_parse_length_payload_eof_exc(self): - parser = self.stream._parse_length_payload(4) - next(parser) - - stream = tulip.StreamReader() - parser.send(stream) - self.stream._parser = parser - self.stream.feed_data(b'da') - - @tulip.coroutine - def eof(): - self.stream.feed_eof() - - t1 = tulip.Task(stream.read()) - t2 = tulip.Task(eof()) - - self.loop.run_until_complete(tulip.wait([t1, t2])) - self.assertRaises(http.client.IncompleteRead, t1.result) - self.assertIsNone(self.stream._parser) - - def test_read_message_deflate_payload(self): - comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - - data = b''.join([comp.compress(b'data'), comp.flush()]) - - self.stream.feed_data( - b'Host: example.com\r\n' - b'Content-Encoding: deflate\r\n' + - ('Content-Length: {}\r\n\r\n'.format(len(data)).encode())) - - msg = self.loop.run_until_complete( - self.stream.read_message(readall=True)) - - @tulip.coroutine - def coro(): - self.stream.feed_data(data) - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'data', data) - - def test_read_message_chunked_payload(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Transfer-Encoding: chunked\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data( - b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'dataline', data) - - def test_read_message_chunked_payload_chunks(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Transfer-Encoding: chunked\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'4\r\ndata\r') - self.stream.feed_data(b'\n4') - self.stream.feed_data(b'\r') - self.stream.feed_data(b'\n') - self.stream.feed_data(b'line\r\n0\r\n') - self.stream.feed_data(b'test\r\n\r\n') - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'dataline', data) - - def test_read_message_chunked_payload_incomplete(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Transfer-Encoding: chunked\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'4\r\ndata\r\n') - self.stream.feed_eof() - return (yield from msg.payload.read()) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, coro()) - - def test_read_message_chunked_payload_extension(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Transfer-Encoding: chunked\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data( - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n\r\n') - return (yield from msg.payload.read()) - - data = self.loop.run_until_complete(coro()) - self.assertEqual(b'dataline', data) - - def test_read_message_chunked_payload_size_error(self): - self.stream.feed_data( - b'Host: example.com\r\n' - b'Transfer-Encoding: chunked\r\n\r\n') - - msg = self.loop.run_until_complete(self.stream.read_message()) - - @tulip.coroutine - def coro(): - self.stream.feed_data(b'blah\r\n') - return (yield from msg.payload.read()) - - self.assertRaises( - http.client.IncompleteRead, - self.loop.run_until_complete, coro()) - - def test_deflate_stream_set_exception(self): - stream = tulip.StreamReader() - dstream = protocol.DeflateStream(stream, 'deflate') - - exc = ValueError() - dstream.set_exception(exc) - self.assertIs(exc, stream.exception()) - - def test_deflate_stream_feed_data(self): - stream = tulip.StreamReader() - dstream = protocol.DeflateStream(stream, 'deflate') - - dstream.zlib = unittest.mock.Mock() - dstream.zlib.decompress.return_value = b'line' - - dstream.feed_data(b'data') - self.assertEqual([b'line'], list(stream.buffer)) - - def test_deflate_stream_feed_data_err(self): - stream = tulip.StreamReader() - dstream = protocol.DeflateStream(stream, 'deflate') - - exc = ValueError() - dstream.zlib = unittest.mock.Mock() - dstream.zlib.decompress.side_effect = exc - - dstream.feed_data(b'data') - self.assertIsInstance(stream.exception(), http.client.IncompleteRead) - - def test_deflate_stream_feed_eof(self): - stream = tulip.StreamReader() - dstream = protocol.DeflateStream(stream, 'deflate') - - dstream.zlib = unittest.mock.Mock() - dstream.zlib.flush.return_value = b'line' - - dstream.feed_eof() - self.assertEqual([b'line'], list(stream.buffer)) - self.assertTrue(stream.eof) - - def test_deflate_stream_feed_eof_err(self): - stream = tulip.StreamReader() - dstream = protocol.DeflateStream(stream, 'deflate') - - dstream.zlib = unittest.mock.Mock() - dstream.zlib.flush.return_value = b'line' - dstream.zlib.eof = False - - dstream.feed_eof() - self.assertIsInstance(stream.exception(), http.client.IncompleteRead) - - class HttpMessageTests(unittest.TestCase): def setUp(self): @@ -644,7 +42,7 @@ def test_force_chunked(self): self.assertTrue(msg.chunked) def test_keep_alive(self): - msg = protocol.Response(self.transport, 200) + msg = protocol.Response(self.transport, 200, close=True) self.assertFalse(msg.keep_alive()) msg.keepalive = True self.assertTrue(msg.keep_alive()) @@ -654,17 +52,17 @@ def test_keep_alive(self): def test_add_header(self): msg = protocol.Response(self.transport, 200) - self.assertEqual([], msg.headers) + self.assertEqual([], list(msg.headers)) msg.add_header('content-type', 'plain/html') - self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) def test_add_headers(self): msg = protocol.Response(self.transport, 200) - self.assertEqual([], msg.headers) + self.assertEqual([], list(msg.headers)) msg.add_headers(('content-type', 'plain/html')) - self.assertEqual([('CONTENT-TYPE', 'plain/html')], msg.headers) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) def test_add_headers_length(self): msg = protocol.Response(self.transport, 200) @@ -684,16 +82,16 @@ def test_add_headers_upgrade_websocket(self): msg = protocol.Response(self.transport, 200) msg.add_headers(('upgrade', 'test')) - self.assertEqual([], msg.headers) + self.assertEqual([], list(msg.headers)) msg.add_headers(('upgrade', 'websocket')) - self.assertEqual([('UPGRADE', 'websocket')], msg.headers) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) def test_add_headers_connection_keepalive(self): msg = protocol.Response(self.transport, 200) msg.add_headers(('connection', 'keep-alive')) - self.assertEqual([], msg.headers) + self.assertEqual([], list(msg.headers)) self.assertTrue(msg.keepalive) msg.add_headers(('connection', 'close')) @@ -703,58 +101,67 @@ def test_add_headers_hop_headers(self): msg = protocol.Response(self.transport, 200) msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) - self.assertEqual([], msg.headers) + self.assertEqual([], list(msg.headers)) def test_default_headers(self): msg = protocol.Response(self.transport, 200) + msg._add_default_headers() - headers = [r for r, _ in msg._default_headers()] + headers = [r for r, _ in msg.headers] self.assertIn('DATE', headers) self.assertIn('CONNECTION', headers) def test_default_headers_server(self): msg = protocol.Response(self.transport, 200) + msg._add_default_headers() - headers = [r for r, _ in msg._default_headers()] + headers = [r for r, _ in msg.headers] self.assertIn('SERVER', headers) def test_default_headers_useragent(self): msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() - headers = [r for r, _ in msg._default_headers()] + headers = [r for r, _ in msg.headers] self.assertNotIn('SERVER', headers) self.assertIn('USER-AGENT', headers) def test_default_headers_chunked(self): msg = protocol.Response(self.transport, 200) + msg._add_default_headers() - headers = [r for r, _ in msg._default_headers()] + headers = [r for r, _ in msg.headers] self.assertNotIn('TRANSFER-ENCODING', headers) + msg = protocol.Response(self.transport, 200) msg.force_chunked() + msg._add_default_headers() - headers = [r for r, _ in msg._default_headers()] + headers = [r for r, _ in msg.headers] self.assertIn('TRANSFER-ENCODING', headers) def test_default_headers_connection_upgrade(self): msg = protocol.Response(self.transport, 200) msg.upgrade = True + msg._add_default_headers() - headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] self.assertEqual([('CONNECTION', 'upgrade')], headers) def test_default_headers_connection_close(self): msg = protocol.Response(self.transport, 200) msg.force_close() + msg._add_default_headers() - headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] self.assertEqual([('CONNECTION', 'close')], headers) def test_default_headers_connection_keep_alive(self): msg = protocol.Response(self.transport, 200) msg.keepalive = True + msg._add_default_headers() - headers = [r for r in msg._default_headers() if r[0] == 'CONNECTION'] + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] self.assertEqual([('CONNECTION', 'keep-alive')], headers) def test_send_headers(self): diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 299e950d..c0f09603 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -48,16 +48,16 @@ def test_data_received(self): srv.connection_made(unittest.mock.Mock()) srv.data_received(b'123') - self.assertEqual(b'123', b''.join(srv.stream.buffer)) + self.assertEqual(b'123', bytes(srv.stream._buffer)) srv.data_received(b'456') - self.assertEqual(b'123456', b''.join(srv.stream.buffer)) + self.assertEqual(b'123456', bytes(srv.stream._buffer)) def test_eof_received(self): srv = server.ServerHttpProtocol() srv.connection_made(unittest.mock.Mock()) srv.eof_received() - self.assertTrue(srv.stream.eof) + self.assertTrue(srv.stream._eof) def test_connection_lost(self): srv = server.ServerHttpProtocol() @@ -85,9 +85,10 @@ def test_handle_error(self): srv = server.ServerHttpProtocol() srv.connection_made(transport) - srv.handle_error(404) + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) @unittest.mock.patch('tulip.http.server.traceback') def test_handle_error_traceback_exc(self, m_trace): @@ -144,22 +145,22 @@ def test_handle(self): self.loop.run_until_complete(srv._request_handle) self.assertTrue(handle.called) - self.assertIsNone(srv._request_handle) def test_handle_coro(self): transport = unittest.mock.Mock() srv = server.ServerHttpProtocol() - srv.connection_made(transport) called = False @tulip.coroutine - def coro(rline, message): + def coro(message, payload): nonlocal called called = True srv.eof_received() + srv.close() srv.handle_request = coro + srv.connection_made(transport) srv.stream.feed_data( b'GET / HTTP/1.0\r\n' @@ -211,7 +212,7 @@ def side_effect(*args): srv.close() srv.handle_error.side_effect = side_effect - srv.stream.feed_data(b'GET / HT/asd\r\n') + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') self.loop.run_until_complete(srv._request_handle) self.assertTrue(srv.handle_error.called) @@ -235,3 +236,13 @@ def test_handle_500(self): self.assertTrue(srv.handle_error.called) self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.connection_lost(None) + close = srv.close = unittest.mock.Mock() + + srv.handle_error(300) + self.assertTrue(close.called) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py index 73c77d09..bd89b75b 100644 --- a/tests/http_websocket_test.py +++ b/tests/http_websocket_test.py @@ -1,11 +1,14 @@ """Tests for http/websocket.py""" +import base64 +import hashlib +import os import struct import unittest import unittest.mock import tulip -from tulip.http import websocket +from tulip.http import websocket, protocol, errors class WebsocketParserTests(unittest.TestCase): @@ -345,3 +348,79 @@ def test_close(self): self.writer.close(1001, b'msg') self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index e2757c91..1145f273 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -21,10 +21,13 @@ def setUp(self): self.transport.get_extra_info.return_value = '127.0.0.1' self.payload = b'data' - self.info = protocol.RequestLine('GET', '/path', (1, 0)) self.headers = [] - self.message = protocol.RawHttpMessage( - self.headers, b'data', True, 'deflate') + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() def tearDown(self): self.loop.close() @@ -38,7 +41,7 @@ def _make_one(self, **kw): srv = wsgi.WSGIServerHttpProtocol(self.wsgi, **kw) srv.stream = self.stream srv.transport = self.transport - return srv.create_wsgi_environ(self.info, self.message, self.payload) + return srv.create_wsgi_environ(self.message, self.payload) def test_environ(self): environ = self._make_one() @@ -81,7 +84,8 @@ def test_environ_host_header(self): self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') def test_environ_host_port_header(self): - self.info = protocol.RequestLine('GET', '/path', (1, 1)) + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') self.headers.append(('HOST', 'python.org:443')) environ = self._make_one() @@ -120,7 +124,7 @@ def test_wsgi_response(self): srv.stream = self.stream srv.transport = self.transport - resp = srv.create_wsgi_response(self.info, self.message) + resp = srv.create_wsgi_response(self.message) self.assertIsInstance(resp, wsgi.WsgiResponse) def test_wsgi_response_start_response(self): @@ -128,7 +132,7 @@ def test_wsgi_response_start_response(self): srv.stream = self.stream srv.transport = self.transport - resp = srv.create_wsgi_response(self.info, self.message) + resp = srv.create_wsgi_response(self.message) resp.start_response( '200 OK', [('CONTENT-TYPE', 'text/plain')]) self.assertEqual(resp.status, '200 OK') @@ -139,7 +143,7 @@ def test_wsgi_response_start_response_exc(self): srv.stream = self.stream srv.transport = self.transport - resp = srv.create_wsgi_response(self.info, self.message) + resp = srv.create_wsgi_response(self.message) resp.start_response( '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) self.assertEqual(resp.status, '200 OK') @@ -150,7 +154,7 @@ def test_wsgi_response_start_response_exc_status(self): srv.stream = self.stream srv.transport = self.transport - resp = srv.create_wsgi_response(self.info, self.message) + resp = srv.create_wsgi_response(self.message) resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) self.assertRaises( @@ -186,7 +190,7 @@ def wsgi_app(env, start): srv.transport = self.transport self.loop.run_until_complete( - srv.handle_request(self.info, self.message)) + srv.handle_request(self.message, self.payload)) content = b''.join( [c[1][0] for c in self.transport.write.mock_calls]) @@ -202,16 +206,16 @@ def wsgi_app(env, start): stream = tulip.StreamReader() stream.feed_data(b'data') stream.feed_eof() - self.message = protocol.RawHttpMessage( - self.headers, stream, True, 'deflate') - self.info = protocol.RequestLine('GET', '/path', (1, 1)) + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) srv.stream = self.stream srv.transport = self.transport self.loop.run_until_complete( - srv.handle_request(self.info, self.message)) + srv.handle_request(self.message, self.payload)) content = b''.join( [c[1][0] for c in self.transport.write.mock_calls]) @@ -229,7 +233,7 @@ def wsgi_app(env, start): srv.transport = self.transport self.loop.run_until_complete( - srv.handle_request(self.info, self.message)) + srv.handle_request(self.message, self.payload)) content = b''.join( [c[1][0] for c in self.transport.write.mock_calls]) diff --git a/tests/parsers_test.py b/tests/parsers_test.py index cd4e1ca2..8a6f2927 100644 --- a/tests/parsers_test.py +++ b/tests/parsers_test.py @@ -8,234 +8,6 @@ from tulip import tasks -class ParserBufferTests(unittest.TestCase): - - def test_feed_data(self): - buf = parsers.ParserBuffer() - buf.feed_data(b'') - self.assertEqual(len(buf), 0) - - buf.feed_data(b'data') - self.assertEqual(buf.size, 4) - self.assertEqual(len(buf), 4) - self.assertEqual(buf, b'data') - - def test_shrink(self): - buf = parsers.ParserBuffer() - buf.feed_data(b'data') - - buf.shrink() - self.assertEqual(bytes(buf), b'data') - - buf.offset = 2 - buf.shrink() - self.assertEqual(bytes(buf), b'ta') - self.assertEqual(2, len(buf)) - self.assertEqual(2, buf.size) - self.assertEqual(0, buf.offset) - - def test_shrink_feed_data(self): - stream = parsers.StreamBuffer(2) - stream.feed_data(b'data') - self.assertEqual(bytes(stream._buffer), b'data') - - stream._buffer.offset = 2 - stream.feed_data(b'1') - self.assertEqual(bytes(stream._buffer), b'ta1') - self.assertEqual(3, len(stream._buffer)) - self.assertEqual(3, stream._buffer.size) - self.assertEqual(0, stream._buffer.offset) - - def test_read(self): - buf = parsers.ParserBuffer() - p = buf.read(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - self.assertEqual(res, b'123') - self.assertEqual(b'4', bytes(buf)) - - def test_readsome(self): - buf = parsers.ParserBuffer() - p = buf.readsome(3) - next(p) - try: - p.send(b'1') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, b'1') - - p = buf.readsome(2) - next(p) - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, b'23') - self.assertEqual(b'4', bytes(buf)) - - def test_skip(self): - buf = parsers.ParserBuffer() - p = buf.skip(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - self.assertIsNone(res) - self.assertEqual(b'4', bytes(buf)) - - def test_readuntil_params(self): - buf = parsers.ParserBuffer() - p = buf.readuntil(b'') - self.assertRaises(AssertionError, next, p) - - p = buf.readuntil('\n') - self.assertRaises(AssertionError, next, p) - - def test_readuntil_limit(self): - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', limit=4) - next(p) - p.send(b'1') - p.send(b'234') - self.assertRaises(ValueError, p.send, b'5') - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', limit=4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', limit=4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - class CustomExc(Exception): - pass - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', limit=4, exc=CustomExc) - next(p) - self.assertRaises(CustomExc, p.send, b'12345\n6') - - def test_readuntil(self): - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', limit=4) - next(p) - p.send(b'123') - try: - p.send(b'\n456') - except StopIteration as exc: - res = exc.value - - self.assertEqual(res, b'123\n') - self.assertEqual(b'456', bytes(buf)) - - def test_skipuntil_params(self): - buf = parsers.ParserBuffer() - p = buf.skipuntil(b'') - self.assertRaises(AssertionError, next, p) - - p = buf.skipuntil('\n') - self.assertRaises(AssertionError, next, p) - - def test_skipuntil(self): - buf = parsers.ParserBuffer() - p = buf.skipuntil(b'\n') - next(p) - p.send(b'123') - try: - p.send(b'\n456\n') - except StopIteration: - pass - self.assertEqual(b'456\n', bytes(buf)) - - p = buf.readline() - try: - next(p) - except StopIteration as exc: - res = exc.value - self.assertEqual(b'', bytes(buf)) - self.assertEqual(b'456\n', res) - - def test_readline_limit(self): - buf = parsers.ParserBuffer() - p = buf.readline(limit=4) - next(p) - p.send(b'1') - p.send(b'234') - self.assertRaises(ValueError, p.send, b'5') - - buf = parsers.ParserBuffer() - p = buf.readline(limit=4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - buf = parsers.ParserBuffer() - p = buf.readline(limit=4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - def test_readline(self): - buf = parsers.ParserBuffer() - p = buf.readline(limit=4) - next(p) - p.send(b'123') - try: - p.send(b'\n456') - except StopIteration as exc: - res = exc.value - - self.assertEqual(res, b'123\n') - self.assertEqual(b'456', bytes(buf)) - - def test_lines_parser(self): - out = parsers.DataBuffer() - buf = parsers.ParserBuffer() - p = parsers.lines_parser() - next(p) - p.send((out, buf)) - - for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): - p.send(d) - - self.assertEqual( - [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(out._buffer)) - try: - p.throw(parsers.EofStream()) - except parsers.EofStream: - pass - - self.assertEqual(bytes(buf), b'data') - - def test_chunks_parser(self): - out = parsers.DataBuffer() - buf = parsers.ParserBuffer() - p = parsers.chunks_parser(5) - next(p) - p.send((out, buf)) - - for d in (b'line1', b'lin', b'e2d', b'ata'): - p.send(d) - - self.assertEqual( - [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) - try: - p.throw(parsers.EofStream()) - except parsers.EofStream: - pass - - self.assertEqual(bytes(buf), b'data') - - class StreamBufferTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' @@ -380,8 +152,8 @@ def test_set_parser_feed_existing_stop(self): def lines_parser(): out, buf = yield try: - out.feed_data((yield from buf.readline())) - out.feed_data((yield from buf.readline())) + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) finally: out.feed_eof() @@ -390,8 +162,7 @@ def lines_parser(): stream.feed_data(b'\r\nline2\r\ndata') s = stream.set_parser(lines_parser()) - self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) self.assertEqual(b'data', bytes(stream._buffer)) self.assertIsNone(stream._parser) self.assertTrue(s._eof) @@ -655,3 +426,174 @@ def test_connection_lost_exc(self): exc = ValueError() proto.connection_lost(exc) self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.readuntil(b'\n') + try: + next(p) + except StopIteration as exc: + res = exc.value + self.assertEqual(b'', bytes(buf)) + self.assertEqual(b'456\n', res) + + def test_lines_parser(self): + out = parsers.DataBuffer() + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer() + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index ea832029..b96719a8 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -2,9 +2,6 @@ import errno import io -import os -import stat -import tempfile import unittest import unittest.mock @@ -182,18 +179,16 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - fcntl_patcher = unittest.mock.patch('fcntl.fcntl') - fcntl_patcher.start() - self.addCleanup(fcntl_patcher.stop) - - def test_ctor(self): + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) self.event_loop.add_reader.assert_called_with(5, tr._read_ready) self.event_loop.call_soon.assert_called_with( self.protocol.connection_made, tr) - def test_ctor_with_waiter(self): + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): fut = futures.Future() unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol, fut) @@ -201,7 +196,8 @@ def test_ctor_with_waiter(self): fut.cancel() @unittest.mock.patch('os.read') - def test__read_ready(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) m_read.return_value = b'data' @@ -211,7 +207,8 @@ def test__read_ready(self, m_read): self.protocol.data_received.assert_called_with(b'data') @unittest.mock.patch('os.read') - def test__read_ready_eof(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) m_read.return_value = b'' @@ -222,7 +219,8 @@ def test__read_ready_eof(self, m_read): self.protocol.eof_received.assert_called_with() @unittest.mock.patch('os.read') - def test__read_ready_blocked(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) self.event_loop.reset_mock() @@ -234,7 +232,8 @@ def test__read_ready_blocked(self, m_read): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.read') - def test__read_ready_error(self, m_read, m_logexc): + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) err = OSError() @@ -247,40 +246,50 @@ def test__read_ready_error(self, m_read, m_logexc): m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.read') - def test_pause(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + tr.pause() self.event_loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.read') - def test_resume(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + tr.resume() self.event_loop.add_reader.assert_called_with(5, tr._read_ready) @unittest.mock.patch('os.read') - def test_close(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + tr._close = unittest.mock.Mock() tr.close() tr._close.assert_called_with(None) @unittest.mock.patch('os.read') - def test_close_already_closing(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + tr._closing = True tr._close = unittest.mock.Mock() tr.close() self.assertFalse(tr._close.called) @unittest.mock.patch('os.read') - def test__close(self, m_read): + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + err = object() tr._close(err) self.assertTrue(tr._closing) @@ -288,17 +297,21 @@ def test__close(self, m_read): self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, err) - def test__call_connection_lost(self): + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + err = None tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - def test__call_connection_lost_with_err(self): + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, self.protocol) + err = OSError() tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) @@ -313,43 +326,33 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - fcntl_patcher = unittest.mock.patch('fcntl.fcntl') - fcntl_patcher.start() - self.addCleanup(fcntl_patcher.stop) - - self.fstat_patcher = unittest.mock.patch('os.fstat') - m_fstat = self.fstat_patcher.start() - st = unittest.mock.Mock() - st.st_mode = stat.S_IFIFO - m_fstat.return_value = st - self.addCleanup(self.fstat_patcher.stop) - - def test_ctor(self): + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - self.event_loop.add_reader.assert_called_with(5, tr._read_ready) self.event_loop.call_soon.assert_called_with( self.protocol.connection_made, tr) - self.assertTrue(tr._enable_read_hack) - def test_ctor_with_waiter(self): + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): fut = futures.Future() - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol, fut) - self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) self.event_loop.call_soon.assert_called_with(fut.set_result, None) fut.cancel() - self.assertTrue(tr._enable_read_hack) - def test_can_write_eof(self): + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) self.assertTrue(tr.can_write_eof()) @unittest.mock.patch('os.write') - def test_write(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + m_write.return_value = 4 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -357,18 +360,22 @@ def test_write(self, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - def test_write_no_data(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr.write(b'') self.assertFalse(m_write.called) self.assertFalse(self.event_loop.add_writer.called) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - def test_write_partial(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + m_write.return_value = 2 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -376,9 +383,11 @@ def test_write_partial(self, m_write): self.assertEqual([b'ta'], tr._buffer) @unittest.mock.patch('os.write') - def test_write_buffer(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._buffer = [b'previous'] tr.write(b'data') self.assertFalse(m_write.called) @@ -386,9 +395,11 @@ def test_write_buffer(self, m_write): self.assertEqual([b'previous', b'data'], tr._buffer) @unittest.mock.patch('os.write') - def test_write_again(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + m_write.side_effect = BlockingIOError() tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -397,9 +408,11 @@ def test_write_again(self, m_write): @unittest.mock.patch('tulip.unix_events.tulip_log') @unittest.mock.patch('os.write') - def test_write_err(self, m_write, m_log): + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write, m_log): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + err = OSError() m_write.side_effect = err tr._fatal_error = unittest.mock.Mock() @@ -419,18 +432,9 @@ def test_write_err(self, m_write, m_log): m_log.warning.assert_called_with( 'os.write(pipe, data) raised exception.') - def test__read_ready(self): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) - tr._read_ready() - self.event_loop.remove_writer.assert_called_with(5) - self.event_loop.remove_reader.assert_called_with(5) - self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with(tr._call_connection_lost, - None) - @unittest.mock.patch('os.write') - def test__write_ready(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] @@ -441,9 +445,11 @@ def test__write_ready(self, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - def test__write_ready_partial(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] m_write.return_value = 3 tr._write_ready() @@ -452,9 +458,11 @@ def test__write_ready_partial(self, m_write): self.assertEqual([b'a'], tr._buffer) @unittest.mock.patch('os.write') - def test__write_ready_again(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] m_write.side_effect = BlockingIOError() tr._write_ready() @@ -463,9 +471,11 @@ def test__write_ready_again(self, m_write): self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('os.write') - def test__write_ready_empty(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] m_write.return_value = 0 tr._write_ready() @@ -475,9 +485,11 @@ def test__write_ready_empty(self, m_write): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.write') - def test__write_ready_err(self, m_write, m_logexc): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] m_write.side_effect = err = OSError() tr._write_ready() @@ -489,28 +501,13 @@ def test__write_ready_err(self, m_write, m_logexc): tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) self.assertEqual(1, tr._conn_lost) - self.event_loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.write') - def test__write_ready_closing(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._closing = True - tr._buffer = [b'da', b'ta'] - m_write.return_value = 4 - tr._write_ready() - m_write.assert_called_with(5, b'data') - self.event_loop.remove_writer.assert_called_with(5) - self.assertEqual([], tr._buffer) - self.protocol.connection_lost.assert_called_with(None) - self.pipe.close.assert_called_with() - self.event_loop.remove_reader.assert_called_with(5) - @unittest.mock.patch('os.write') - def test__write_ready_closing_regular_file(self, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) - tr._enable_read_hack = False tr._closing = True tr._buffer = [b'da', b'ta'] m_write.return_value = 4 @@ -520,27 +517,13 @@ def test__write_ready_closing_regular_file(self, m_write): self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() - self.assertFalse(self.event_loop.remove_reader.called) @unittest.mock.patch('os.write') - def test_abort(self, m_write): + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] - tr.abort() - self.assertFalse(m_write.called) - self.event_loop.remove_writer.assert_called_with(5) - self.assertEqual([], tr._buffer) - self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) - self.event_loop.remove_reader.assert_called_with(5) - @unittest.mock.patch('os.write') - def test_abort_closing_regular_file(self, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) - tr._enable_read_hack = False tr._buffer = [b'da', b'ta'] tr.abort() self.assertFalse(m_write.called) @@ -549,78 +532,61 @@ def test_abort_closing_regular_file(self, m_write): self.assertTrue(tr._closing) self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, None) - self.assertFalse(self.event_loop.remove_reader.called) - def test__call_connection_lost(self): + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + err = None tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - def test__call_connection_lost_with_err(self): + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + err = OSError() tr._call_connection_lost(err) self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - def test_close(self): + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr.write_eof = unittest.mock.Mock() tr.close() tr.write_eof.assert_called_with() - def test_close_closing(self): + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) + tr.write_eof = unittest.mock.Mock() tr._closing = True tr.close() self.assertFalse(tr.write_eof.called) - def test_write_eof(self): + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) - tr.write_eof() - self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) - self.event_loop.remove_reader.assert_called_with(5) - def test_write_eof_dont_remove_reader_for_regular_file(self): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) - tr._enable_read_hack = False tr.write_eof() self.assertTrue(tr._closing) self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, None) - self.assertFalse(self.event_loop.remove_reader.called) - def test_write_eof_pending(self): + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, self.protocol) tr._buffer = [b'data'] tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.protocol.connection_lost.called) - - -class UntxWritePipeRegularFileTests(unittest.TestCase): - - def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) - self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - - def test_ctor_with_regular_file(self): - with tempfile.TemporaryFile() as f: - tr = unix_events._UnixWritePipeTransport(self.event_loop, f, - self.protocol) - self.assertFalse(self.event_loop.add_reader.called) - self.event_loop.call_soon.assert_called_with( - self.protocol.connection_made, tr) - self.assertFalse(tr._enable_read_hack) diff --git a/tulip/http/client.py b/tulip/http/client.py index 0fff6e86..fe58e658 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -24,7 +24,7 @@ import urllib.parse import tulip -from tulip.http import protocol +import tulip.http @tulip.coroutine @@ -116,10 +116,10 @@ def request(method, url, *, @tulip.coroutine def start(req, loop): transport, p = yield from loop.create_connection( - HttpProtocol, req.host, req.port, ssl=req.ssl) + tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) try: resp = req.send(transport) - yield from resp.start(p.stream, transport) + yield from resp.start(p, transport) except: transport.close() raise @@ -127,25 +127,6 @@ def start(req, loop): return resp -class HttpProtocol(tulip.Protocol): - - stream = None - transport = None - - def connection_made(self, transport): - self.transport = transport - self.stream = protocol.HttpStreamReader() - - def data_received(self, data): - self.stream.feed_data(data) - - def eof_received(self): - self.stream.feed_eof() - - def connection_lost(self, exc): - pass - - class HttpRequest: GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} @@ -325,7 +306,7 @@ def send(self, transport): self.headers['content-type'] = ( 'application/x-www-form-urlencoded') if 'content-length' not in self.headers and not chunked: - self.headers['content-length'] = len(self.body) + self.headers['content-length'] = str(len(self.body)) # files (multipart/form-data) elif files: @@ -382,7 +363,7 @@ def send(self, transport): request.add_chunking_filter(8196) else: chunked = False - self.headers['content-length'] = len(self.body) + self.headers['content-length'] = str(len(self.body)) request.add_headers(*self.headers.items()) request.send_headers() @@ -406,9 +387,8 @@ class HttpResponse(http.client.HTTPMessage): reason = None # Reason-Phrase content = None # payload stream - - _content = None - _transport = None + stream = None # input stream + transport = None # current transport def __init__(self, method, url, host=''): super().__init__() @@ -416,6 +396,7 @@ def __init__(self, method, url, host=''): self.method = method self.url = url self.host = host + self._content = None def __repr__(self): out = io.StringIO() @@ -426,41 +407,54 @@ def __repr__(self): def start(self, stream, transport): """Start response processing.""" - self._transport = transport + self.stream = stream + self.transport = transport - # read status - self.version, self.status, self.reason = ( - yield from stream.read_response_status()) + httpstream = stream.set_parser(tulip.http.http_response_parser()) - # does the body have a fixed length? (of zero) - length = None - if (self.status == http.client.NO_CONTENT or - self.status == http.client.NOT_MODIFIED or - 100 <= self.status < 200 or self.method == "HEAD"): - length = 0 + # read response + message = yield from httpstream.read() - # http message - message = yield from stream.read_message(length=length) + # response status + self.version = message.version + self.status = message.code + self.reason = message.reason # headers for hdr, val in message.headers: self.add_header(hdr, val) # payload - self.content = message.payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(message)) return self def close(self): - if self._transport is not None: - self._transport.close() - self._transport = None + if self.transport is not None: + self.transport.close() + self.transport = None @tulip.coroutine def read(self, decode=False): """Read response payload. Decode known types of content.""" if self._content is None: - self._content = yield from self.content.read() + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size data = self._content diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 378feff1..a7fc567f 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -1,8 +1,9 @@ """Http related helper utils.""" -__all__ = ['HttpStreamReader', - 'HttpMessage', 'Request', 'Response', - 'RawHttpMessage', 'RequestLine', 'ResponseStatus'] +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] import collections import functools @@ -10,87 +11,50 @@ import itertools import re import sys -import time import zlib from wsgiref.handlers import format_date_time import tulip -from . import errors +from tulip.http import errors METHRE = re.compile('[A-Z0-9$-_.]+') VERSRE = re.compile('HTTP/(\d+).(\d+)') -HDRRE = re.compile(b"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") -CONTINUATION = (b' ', b'\t') -RESPONSES = http.server.BaseHTTPRequestHandler.responses - -RequestLine = collections.namedtuple( - 'RequestLine', ['method', 'uri', 'version']) - - -ResponseStatus = collections.namedtuple( - 'ResponseStatus', ['version', 'code', 'reason']) - - -RawHttpMessage = collections.namedtuple( - 'RawHttpMessage', ['headers', 'payload', 'should_close', 'compression']) - - -class HttpStreamReader(tulip.StreamReader): - - MAX_HEADERS = 32768 - MAX_HEADERFIELD_SIZE = 8190 - - # if _parser is set, feed_data and feed_eof sends data into - # _parser instead of self. is it being used as stream redirection for - # _parse_chunked_payload, _parse_length_payload and _parse_eof_payload - _parser = None +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() - def feed_data(self, data): - """_parser is a generator, if _parser is set, feed_data sends - incoming data into the generator untile generator stops.""" - if self._parser: - try: - self._parser.send(data) - except StopIteration as exc: - self._parser = None - if exc.value: - self.feed_data(exc.value) - else: - super().feed_data(data) +RESPONSES = http.server.BaseHTTPRequestHandler.responses - def feed_eof(self): - """_parser is a generator, if _parser is set feed_eof throws - StreamEofException into this generator.""" - if self._parser: - try: - self._parser.throw(StreamEofException()) - except StopIteration: - self._parser = None - super().feed_eof() +RawRequestMessage = collections.namedtuple( + 'RawRequestLine', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) - @tulip.coroutine - def read_request_line(self): - """Read request status line. Exception errors.BadStatusLine - could be raised in case of any errors in status line. - Returns three values (method, uri, version) - Example: +RawResponseMessage = collections.namedtuple( + 'RawResponseStatus', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) - GET /path HTTP/1.1 - >> yield from reader.read_request_line() - ('GET', '/path', (1, 1)) +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield - """ - bline = yield from self.readline() - try: - line = bline.decode('ascii').rstrip() - except UnicodeDecodeError: - raise errors.BadStatusLine(bline) from None + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + # request line + line = lines[0] try: - method, uri, version = line.split(None, 2) + method, path, version = line.split(None, 2) except ValueError: raise errors.BadStatusLine(line) from None @@ -105,33 +69,37 @@ def read_request_line(self): raise errors.BadStatusLine(version) version = (int(match.group(1)), int(match.group(2))) - return RequestLine(method, uri, version) - - @tulip.coroutine - def read_response_status(self): - """Read response status line. Exception errors.BadStatusLine - could be raised in case of any errors in status line. - Returns three values (version, status_code, reason) - - Example: - - HTTP/1.1 200 Ok - - >> yield from reader.read_response_status() - ((1, 1), 200, 'Ok') - - """ - bline = yield from self.readline() - if not bline: - # Presumably, the server closed the connection before - # sending a valid response. - raise errors.BadStatusLine(bline) - - try: - line = bline.decode('ascii').rstrip() - except UnicodeDecodeError: - raise errors.BadStatusLine(bline) from None - + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] try: version, status = line.split(None, 1) except ValueError: @@ -157,307 +125,210 @@ def read_response_status(self): if status < 100 or status > 999: raise errors.BadStatusLine(line) - return ResponseStatus(version, status, reason.strip()) + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) - @tulip.coroutine - def read_headers(self): - """Read and parses RFC2822 headers from a stream. + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None - Line continuations are supported. Returns list of header name - and value pairs. Header name is in upper case. - """ - size = 0 - headers = [] - line = yield from self.readline() +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. - while line not in (b'\r\n', b'\n'): - header_length = len(line) + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() - # Parse initial header name : value pair. - sep_pos = line.find(b':') - if sep_pos < 0: - raise ValueError('Invalid header {}'.format(line.strip())) + lines_idx = 1 + line = lines[1] - name, value = line[:sep_pos], line[sep_pos+1:] - name = name.rstrip(b' \t').upper() - if HDRRE.search(name): - raise ValueError('Invalid header name {}'.format(name)) + while line not in ('\r\n', '\n'): + header_length = len(line) - name = name.strip().decode('ascii', 'surrogateescape') - value = [value.lstrip()] + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None - # next line - line = yield from self.readline() + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) - # consume continuation lines - continuation = line.startswith(CONTINUATION) + # next line + lines_idx += 1 + line = lines[lines_idx] - if continuation: - while continuation: - header_length += len(line) - if header_length > self.MAX_HEADERFIELD_SIZE: - raise errors.LineTooLong( - 'limit request headers fields size') - value.append(line) + # consume continuation lines + continuation = line[0] in CONTINUATION - line = yield from self.readline() - continuation = line.startswith(CONTINUATION) - else: - if header_length > self.MAX_HEADERFIELD_SIZE: + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: raise errors.LineTooLong( 'limit request headers fields size') + value.append(line) - # total headers size - size += header_length - if size >= self.MAX_HEADERS: - raise errors.LineTooLong('limit request headers fields') - - headers.append( - (name, - b''.join(value).rstrip().decode('ascii', 'surrogateescape'))) - - return headers - - def _parse_chunked_payload(self): - """Chunked transfer encoding parser.""" - stream = yield - - try: - data = bytearray() - - while True: - # read line - if b'\n' not in data: - data.extend((yield)) - continue - - line, data = data.split(b'\n', 1) - - # Read the next chunk size from the file - i = line.find(b';') - if i >= 0: - line = line[:i] # strip chunk-extensions - try: - size = int(line, 16) - except ValueError: - raise errors.IncompleteRead(b'') from None + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') - if size == 0: - break + value = value.strip() - # read chunk - while len(data) < size: - data.extend((yield)) + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc - # feed stream - stream.feed_data(data[:size]) + headers.append((name, value)) - data = data[size:] + return headers, close_conn, encoding - # toss the CRLF at the end of the chunk - while len(data) < 2: - data.extend((yield)) - data = data[2:] +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield - # read and discard trailer up to the CRLF terminator - while True: - if b'\n' in data: - line, data = data.split(b'\n', 1) - if line in (b'\r', b''): - break - else: - data.extend((yield)) - - # stream eof - stream.feed_eof() - return data + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 - except StreamEofException: - stream.set_exception(errors.IncompleteRead(b'')) - except errors.IncompleteRead as exc: - stream.set_exception(exc) + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) - def _parse_length_payload(self, length): - """Read specified amount of bytes.""" - stream = yield + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + elif length is not None: try: - data = bytearray() - while length: - data.extend((yield)) - - data_len = len(data) - if data_len <= length: - stream.feed_data(data) - data = bytearray() - length -= data_len - else: - stream.feed_data(data[:length]) - data = data[length:] - length = 0 + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None - stream.feed_eof() - return data - except StreamEofException: - stream.set_exception(errors.IncompleteRead(b'')) + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) - def _parse_eof_payload(self): - """Read all bytes untile eof.""" - stream = yield + out.feed_eof() - try: - while True: - stream.feed_data((yield)) - except StreamEofException: - stream.feed_eof() - - @tulip.coroutine - def read_message(self, version=(1, 1), - length=None, compression=True, readall=False): - """Read RFC2822 headers and message payload from a stream. - - read_message() automatically decompress gzip and deflate content - encoding. To prevent decompression pass compression=False. - - Returns tuple of headers, payload stream, should close flag, - compression type. - """ - # load headers - headers = yield from self.read_headers() - - # payload params - chunked = False - encoding = None - close_conn = None - for name, value in headers: - if name == 'CONTENT-LENGTH': - length = value - elif name == 'TRANSFER-ENCODING': - chunked = value.lower() == 'chunked' - elif name == 'SEC-WEBSOCKET-KEY1': - length = 8 - elif name == 'CONNECTION': - v = value.lower() - if v == 'close': - close_conn = True - elif v == 'keep-alive': - close_conn = False - elif compression and name == 'CONTENT-ENCODING': - enc = value.lower() - if enc in ('gzip', 'deflate'): - encoding = enc - - if close_conn is None: - close_conn = version <= (1, 0) - - # payload parser - if chunked: - parser = self._parse_chunked_payload() - - elif length is not None: +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() try: - length = int(length) + size = int(line, 16) except ValueError: - raise errors.InvalidHeader('CONTENT-LENGTH') from None + raise errors.IncompleteRead(b'') from None - if length < 0: - raise errors.InvalidHeader('CONTENT-LENGTH') + if size == 0: # eof marker + break - parser = self._parse_length_payload(length) - else: - if readall: - parser = self._parse_eof_payload() - else: - parser = self._parse_length_payload(0) + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) - next(parser) + # toss the CRLF at the end of the chunk + yield from buf.skip(2) - payload = stream = tulip.StreamReader() + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') - # payload decompression wrapper - if encoding is not None: - stream = DeflateStream(stream, encoding) + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None - try: - # initialize payload parser with stream, stream is being - # used by parser as destination stream - parser.send(stream) - except StopIteration: - pass - else: - # feed existing buffer to payload parser - self.byte_count = 0 - while self.buffer: - try: - parser.send(self.buffer.popleft()) - except StopIteration as exc: - parser = None - - # parser is done - buf = b''.join(self.buffer) - self.buffer.clear() - - # re-add remaining data back to buffer - if exc.value: - self.feed_data(exc.value) - - if buf: - self.feed_data(buf) - - break - - # parser still require more data - if parser is not None: - if self.eof: - try: - parser.throw(StreamEofException()) - except StopIteration as exc: - pass - else: - self._parser = parser - return RawHttpMessage(headers, payload, close_conn, encoding) +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None -class StreamEofException(Exception): - """Internal exception: eof received.""" +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) -class DeflateStream: + +class DeflateBuffer: """DeflateStream decomress stream and feed data into specified stream.""" - def __init__(self, stream, encoding): - self.stream = stream + def __init__(self, out, encoding): + self.out = out zlib_mode = (16 + zlib.MAX_WBITS if encoding == 'gzip' else -zlib.MAX_WBITS) self.zlib = zlib.decompressobj(wbits=zlib_mode) - def set_exception(self, exc): - self.stream.set_exception(exc) - def feed_data(self, chunk): try: chunk = self.zlib.decompress(chunk) except: - self.stream.set_exception(errors.IncompleteRead(b'')) + raise errors.IncompleteRead(b'') from None if chunk: - self.stream.feed_data(chunk) + self.out.feed_data(chunk) def feed_eof(self): - self.stream.feed_data(self.zlib.flush()) + self.out.feed_data(self.zlib.flush()) if not self.zlib.eof: - self.stream.set_exception(errors.IncompleteRead(b'')) - - self.stream.feed_eof() - + raise errors.IncompleteRead(b'') -EOF_MARKER = object() -EOL_MARKER = object() + self.out.feed_eof() def wrap_payload_filter(func): @@ -606,22 +477,26 @@ def __init__(self, transport, version, close): self.transport = transport self.version = version self.closing = close - self.keepalive = False + self.keepalive = None self.chunked = False self.length = None self.upgrade = False - self.headers = [] + self.headers = collections.deque() self.headers_sent = False def force_close(self): self.closing = True + self.keepalive = False def force_chunked(self): self.chunked = True def keep_alive(self): - return self.keepalive and not self.closing + if self.keepalive is None: + return not self.closing + else: + return self.keepalive def is_headers_sent(self): return self.headers_sent @@ -638,14 +513,14 @@ def add_header(self, name, value): self.length = int(value) if name == 'CONNECTION': - val = value.lower().strip() + val = value.lower() # handle websocket - if val == 'upgrade': + if 'upgrade' in val: self.upgrade = True # connection keep-alive - elif val == 'close': + elif 'close' in val: self.keepalive = False - elif val == 'keep-alive': + elif 'keep-alive' in val: self.keepalive = True elif name == 'UPGRADE': @@ -688,31 +563,28 @@ def send_headers(self): next(self.writer) - # status line - self.transport.write(self.status_line.encode('ascii')) + self._add_default_headers() - # send headers - self.transport.write( - ('{}\r\n\r\n'.format('\r\n'.join( - ('{}: {}'.format(k, v) for k, v in - itertools.chain(self._default_headers(), self.headers)))) - ).encode('ascii')) + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) - def _default_headers(self): + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): # set the connection header if self.upgrade: connection = 'upgrade' - elif self.keep_alive(): + elif not self.closing if self.keepalive is None else self.keepalive: connection = 'keep-alive' else: connection = 'close' - headers = [('CONNECTION', connection)] - if self.chunked: - headers.append(('TRANSFER-ENCODING', 'chunked')) + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) - return headers + self.headers.appendleft(('CONNECTION', connection)) def write(self, chunk): """write() writes chunk of data to a steram by using different writers. @@ -739,7 +611,7 @@ def write(self, chunk): def write_eof(self): self.write(EOF_MARKER) try: - self.writer.throw(StreamEofException()) + self.writer.throw(tulip.EofStream()) except StopIteration: pass @@ -748,7 +620,7 @@ def _write_chunked_payload(self): while True: try: chunk = yield - except StreamEofException: + except tulip.EofStream: self.transport.write(b'0\r\n\r\n') break @@ -761,7 +633,7 @@ def _write_length_payload(self, length): while True: try: chunk = yield - except StreamEofException: + except tulip.EofStream: break if length: @@ -777,7 +649,7 @@ def _write_eof_payload(self): while True: try: chunk = yield - except StreamEofException: + except tulip.EofStream: break self.transport.write(chunk) @@ -848,32 +720,28 @@ def __init__(self, transport, status, http_version=(1, 1), close=False): super().__init__(transport, http_version, close) self.status = status - self.status_line = 'HTTP/{0[0]}.{0[1]} {1} {2}\r\n'.format( - http_version, status, RESPONSES[status][0]) + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) - def _default_headers(self): - headers = super()._default_headers() - headers.extend((('DATE', format_date_time(time.time())), - ('SERVER', self.SERVER_SOFTWARE))) - - return headers + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) class Request(HttpMessage): HOP_HEADERS = () - def __init__(self, transport, method, uri, + def __init__(self, transport, method, path, http_version=(1, 1), close=False): super().__init__(transport, http_version, close) self.method = method - self.uri = uri + self.path = path self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( - method, uri, http_version) - - def _default_headers(self): - headers = super()._default_headers() - headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) + method, path, http_version) - return headers + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py index d722863d..8c816e94 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -8,9 +8,8 @@ import traceback import tulip -import tulip.http +from tulip.http import errors -from . import errors RESPONSES = http.server.BaseHTTPRequestHandler.responses DEFAULT_ERROR_MESSAGE = """ @@ -40,13 +39,14 @@ class ServerHttpProtocol(tulip.Protocol): _request_count = 0 _request_handle = None - def __init__(self, *, log=logging, debug=False): + def __init__(self, *, log=logging, debug=False, **kwargs): + self.__dict__.update(kwargs) self.log = log self.debug = debug def connection_made(self, transport): self.transport = transport - self.stream = tulip.http.HttpStreamReader() + self.stream = tulip.StreamBuffer() self._request_handle = self.start() def data_received(self, data): @@ -63,7 +63,7 @@ def eof_received(self): def close(self): self._closing = True - def log_access(self, status, info, message, *args, **kw): + def log_access(self, status, message, *args, **kw): pass def log_debug(self, *args, **kw): @@ -82,18 +82,23 @@ def start(self): or response handling. In case of any error connection is being closed. """ - while True: + while self._request_handle is not None: info = None message = None self._request_count += 1 try: - info = yield from self.stream.read_request_line() - message = yield from self.stream.read_message(info.version) + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) - handler = self.handle_request(info, message) + message = yield from httpstream.read() + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) if (inspect.isgenerator(handler) or - isinstance(handler, tulip.Future)): + isinstance(handler, tulip.Future)): yield from handler except tulip.CancelledError: @@ -108,50 +113,52 @@ def start(self): self.transport.close() break - self._request_handle = None - - def handle_error(self, status=500, info=None, - message=None, exc=None, headers=None): + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): """Handle errors. Returns http response with specific status code. Logs additional information. It always closes current connection.""" - - if status == 500: - self.log_exception("Error handling request") - try: - reason, msg = RESPONSES[status] - except KeyError: - status = 500 - reason, msg = '???', '' - - if self.debug and exc is not None: - try: - tb = traceback.format_exc() - msg += '

Traceback:

\n
{}
'.format(tb) - except: - pass - - self.log_access(status, info, message) + if self._request_handle is None: + # client has been disconnected during writing. + return - html = DEFAULT_ERROR_MESSAGE.format( - status=status, reason=reason, message=msg) + if status == 500: + self.log_exception("Error handling request") - response = tulip.http.Response(self.transport, status, close=True) - response.add_headers( - ('Content-Type', 'text/html'), - ('Content-Length', str(len(html)))) - if headers is not None: - response.add_headers(*headers) - response.send_headers() - - response.write(html.encode('ascii')) - response.write_eof() - - self.close() - - def handle_request(self, info, message): + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.close() + + def handle_request(self, message, payload): """Handle a single http request. Subclass should override this method. By default it always @@ -161,7 +168,7 @@ def handle_request(self, info, message): message: tulip.http.RawHttpMessage instance """ response = tulip.http.Response( - self.transport, 404, http_version=info.version, close=True) + self.transport, 404, http_version=message.version, close=True) body = b'Page Not Found!' @@ -173,4 +180,4 @@ def handle_request(self, info, message): response.write_eof() self.close() - self.log_access(404, info, message) + self.log_access(404, message) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py index a80745ed..71784bcc 100644 --- a/tulip/http/websocket.py +++ b/tulip/http/websocket.py @@ -1,10 +1,15 @@ """WebSocket protocol versions 13 and 8.""" -__all__ = ['WebSocketParser', 'WebSocketWriter', 'Message', 'WebSocketError', +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] +import base64 +import binascii import collections +import hashlib import struct +from tulip.http import errors # Frame opcodes defined in the spec. OPCODE_CONTINUATION = 0x0 @@ -14,6 +19,10 @@ MSG_PING = OPCODE_PING = 0x9 MSG_PONG = OPCODE_PONG = 0xa +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) @@ -173,3 +182,46 @@ def close(self, code=1000, message=b''): self._send_frame( struct.pack('!H%ds' % len(message), code, message), opcode=OPCODE_CLOSE) + + +def do_handshake(message, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + do any IO.""" + headers = dict(((hdr, val) + for hdr, val in message.headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException('No WebSocket UPGRADE hdr: {}'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py index 5957f9c8..a612a0aa 100644 --- a/tulip/http/wsgi.py +++ b/tulip/http/wsgi.py @@ -39,11 +39,11 @@ def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): self.is_ssl = is_ssl self.readpayload = readpayload - def create_wsgi_response(self, info, message): - return WsgiResponse(self.transport, info, message) + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) - def create_wsgi_environ(self, info, message, payload): - uri_parts = urlsplit(info.uri) + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) url_scheme = 'https' if self.is_ssl else 'http' environ = { @@ -57,10 +57,10 @@ def create_wsgi_environ(self, info, message, payload): 'wsgi.file_wrapper': FileWrapper, 'wsgi.url_scheme': url_scheme, 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, - 'REQUEST_METHOD': info.method, + 'REQUEST_METHOD': message.method, 'QUERY_STRING': uri_parts.query or '', - 'RAW_URI': info.uri, - 'SERVER_PROTOCOL': 'HTTP/%s.%s' % info.version + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version } # authors should be aware that REMOTE_HOST and REMOTE_ADDR @@ -86,7 +86,7 @@ def create_wsgi_environ(self, info, message, payload): environ['CONTENT_LENGTH'] = hdr_value continue - key = 'HTTP_' + hdr_name.replace('-', '_') + key = 'HTTP_%s' % hdr_name.replace('-', '_') if key in environ: hdr_value = '%s,%s' % (environ[key], hdr_value) @@ -140,16 +140,19 @@ def create_wsgi_environ(self, info, message, payload): return environ @tulip.coroutine - def handle_request(self, info, message): + def handle_request(self, message, payload): """Handle a single HTTP request""" if self.readpayload: - payload = io.BytesIO((yield from message.payload.read())) - else: - payload = message.payload + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + payload = wsgiinput - environ = self.create_wsgi_environ(info, message, payload) - response = self.create_wsgi_response(info, message) + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) riter = self.wsgi(environ, response.start_response) if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): @@ -195,9 +198,8 @@ class WsgiResponse: status = None - def __init__(self, transport, info, message): + def __init__(self, transport, message): self.transport = transport - self.info = info self.message = message def start_response(self, status, headers, exc_info=None): @@ -213,7 +215,7 @@ def start_response(self, status, headers, exc_info=None): self.status = status self.response = tulip.http.Response( self.transport, status_code, - self.info.version, self.message.should_close) + self.message.version, self.message.should_close) self.response.add_headers(*headers) self.response._send_headers = True return self.response.write diff --git a/tulip/parsers.py b/tulip/parsers.py index 9d8151de..0b599635 100644 --- a/tulip/parsers.py +++ b/tulip/parsers.py @@ -81,9 +81,8 @@ class StreamBuffer: unset_parser() sends EofStream into parser and then removes it. """ - def __init__(self, buffer_size=5120): + def __init__(self): self._buffer = ParserBuffer() - self._buffer_size = buffer_size self._eof = False self._parser = None self._parser_buffer = None @@ -96,9 +95,7 @@ def set_exception(self, exc): self._exception = exc if self._parser_buffer is not None: - self._buffer.shrink() self._parser_buffer.set_exception(exc) - self._parser = None self._parser_buffer = None @@ -120,10 +117,6 @@ def feed_data(self, data): else: self._buffer.feed_data(data) - # shrink buffer - if (self._buffer.offset and len(self._buffer) > self._buffer_size): - self._buffer.shrink() - def feed_eof(self): """send eof to all parsers, recursively.""" if self._parser: @@ -140,7 +133,6 @@ def feed_eof(self): self._parser_buffer = None self._eof = True - self._buffer.shrink() def set_parser(self, p): """set parser to stream. return parser's DataStream.""" @@ -224,8 +216,11 @@ def exception(self): def set_exception(self, exc): self._exception = exc - if self._waiter is not None: - self._waiter.set_exception(exc) + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) def feed_data(self, data): self._buffer.append(data) @@ -241,7 +236,7 @@ def feed_eof(self): waiter = self._waiter if waiter is not None: self._waiter = None - waiter.set_result(True) + waiter.set_result(False) @tasks.coroutine def read(self): @@ -273,7 +268,7 @@ def __init__(self, *args): self._writer = self._feed_data() next(self._writer) - def shrink(self): + def _shrink(self): if self.offset: del self[:self.offset] self.offset = 0 @@ -287,6 +282,10 @@ def _feed_data(self): self.size += chunk_len self.extend(chunk) + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + def feed_data(self, data): self._writer.send(data) @@ -318,27 +317,6 @@ def readsome(self, size=None): self._writer.send((yield)) - def readline(self, limit=2**16, exc=ValueError): - """readline() reads until \n string.""" - - while True: - new_line = self.find(b'\n', self.offset) - if new_line >= 0: - end = new_line + 1 - size = end - self.offset - if size > limit: - raise exc('Line is too long.') - - start, self.offset = self.offset, end - self.size = self.size - size - - return self[start:end] - else: - if self.size > limit: - raise exc('Line is too long.') - - self._writer.send((yield)) - def readuntil(self, stop, limit=None, exc=ValueError): assert isinstance(stop, bytes) and stop, \ 'bytes is required: {!r}'.format(stop) @@ -346,11 +324,11 @@ def readuntil(self, stop, limit=None, exc=ValueError): stop_len = len(stop) while True: - new_line = self.find(stop, self.offset) - if new_line >= 0: - end = new_line + stop_len + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len size = end - self.offset - if size > limit: + if limit is not None and size > limit: raise exc('Line is too long.') start, self.offset = self.offset, end @@ -404,7 +382,7 @@ def lines_parser(limit=2**16, exc=ValueError): out, buf = yield while True: - out.feed_data((yield from buf.readline(limit, exc))) + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) def chunks_parser(size=8196): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 450f668a..801f805e 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -5,7 +5,6 @@ """ import collections -import errno import socket try: import ssl diff --git a/tulip/streams.py b/tulip/streams.py index 8d7f6236..51028ca7 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -24,8 +24,10 @@ def exception(self): def set_exception(self, exc): self._exception = exc - if self.waiter is not None: - self.waiter.set_exception(exc) + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) def feed_eof(self): self.eof = True diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 3c91e99f..c3870c0f 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -57,20 +57,23 @@ def url(self, *suffix): class TestHttpServer(tulip.http.ServerHttpProtocol): - def handle_request(self, info, message): + def handle_request(self, message, payload): if properties.get('noresponse', False): return if router is not None: - payload = io.BytesIO((yield from message.payload.read())) - rob = router( - properties, self.transport, - info, message.headers, payload, message.compression) + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router(properties, self.transport, message, bytes(body)) rob.dispatch() else: response = tulip.http.Response( - self.transport, 200, info.version) + self.transport, 200, message.version) text = b'Test message' response.add_header('Content-type', 'text/plain') @@ -124,19 +127,19 @@ class Router: _response_version = "1.1" _responses = http.server.BaseHTTPRequestHandler.responses - def __init__(self, props, transport, rline, headers, body, cmode): + def __init__(self, props, transport, message, payload): # headers self._headers = http.client.HTTPMessage() - for hdr, val in headers: + for hdr, val in message.headers: self._headers.add_header(hdr, val) self._props = props self._transport = transport - self._method = rline.method - self._uri = rline.uri - self._version = rline.version - self._compression = cmode - self._body = body.read() + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload url = urllib.parse.urlsplit(self._uri) self._path = url.path diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 3e1bc098..73ada428 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -4,7 +4,6 @@ import fcntl import os import socket -import stat import sys try: @@ -219,28 +218,10 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. - # Do nothing if it is a regular file. - # Enable hack only if pipe is FIFO object. - # Look on twisted.internet.process:ProcessWriter.__init__ - if stat.S_ISFIFO(os.fstat(self._fileno).st_mode): - self._enable_read_hack = True - else: - # If the pipe is not a unix pipe, then the read hack is never - # applicable. This case arises when _UnixWritePipeTransport - # is used by subprocess and stdout/stderr - # are redirected to a normal file. - self._enable_read_hack = False - - if self._enable_read_hack: - self._event_loop.add_reader(self._fileno, self._read_ready) self._event_loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._event_loop.call_soon(waiter.set_result, None) - def _read_ready(self): - # pipe was closed by peer - self._close() - def write(self, data): assert isinstance(data, bytes), repr(data) assert not self._closing @@ -286,8 +267,6 @@ def _write_ready(self): else: if n == len(data): self._event_loop.remove_writer(self._fileno) - if self._enable_read_hack: - self._event_loop.remove_reader(self._fileno) if self._closing: self._call_connection_lost(None) return @@ -304,8 +283,6 @@ def write_eof(self): assert self._pipe self._closing = True if not self._buffer: - if self._enable_read_hack: - self._event_loop.remove_reader(self._fileno) self._event_loop.call_soon(self._call_connection_lost, None) def close(self): @@ -325,8 +302,6 @@ def _close(self, exc=None): self._closing = True self._buffer.clear() self._event_loop.remove_writer(self._fileno) - if self._enable_read_hack: - self._event_loop.remove_reader(self._fileno) self._event_loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): From ea96f38e766be15476223e78d33458cffa048f87 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Tue, 23 Apr 2013 04:25:06 +0200 Subject: [PATCH 0419/1502] get rid of bare except clauses --- tulip/http/client.py | 6 +++--- tulip/http/protocol.py | 2 +- tulip/proactor_events.py | 2 +- tulip/selector_events.py | 2 +- tulip/subprocess_transport.py | 2 +- tulip/test_utils.py | 2 +- tulip/winsocketpair.py | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tulip/http/client.py b/tulip/http/client.py index fe58e658..ddaf25e2 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -120,7 +120,7 @@ def start(req, loop): try: resp = req.send(transport) yield from resp.start(p, transport) - except: + except Exception: transport.close() raise @@ -159,7 +159,7 @@ def __init__(self, method, url, *, v = [l.strip() for l in version.split('.', 1)] try: version = int(v[0]), int(v[1]) - except: + except ValueError: raise ValueError( 'Can not parse http version number: {}' .format(version)) from None @@ -196,7 +196,7 @@ def __init__(self, method, url, *, netloc, port_s = netloc.split(':', 1) try: port = int(port_s) - except: + except ValueError: raise ValueError( 'Port number could not be converted.') from None else: diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index a7fc567f..0a2959cf 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -317,7 +317,7 @@ def __init__(self, out, encoding): def feed_data(self, chunk): try: chunk = self.zlib.decompress(chunk) - except: + except Exception: raise errors.IncompleteRead(b'') from None if chunk: diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index ba1389d2..22dfa303 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -168,7 +168,7 @@ def _loop_self_reading(self, f=None): if f is not None: f.result() # may raise f = self._proactor.recv(self._ssock, 4096) - except: + except Exception: self.close() raise else: diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 801f805e..2dd4b748 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -108,7 +108,7 @@ def _accept_connection(self, protocol_factory, sock, ssl=False): conn.setblocking(False) except (BlockingIOError, InterruptedError): pass # False alarm. - except: + except Exception: # Bad error. Stop serving. self.remove_reader(sock.fileno()) sock.close() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index 5e4d6550..01bf7339 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -37,7 +37,7 @@ def __init__(self, protocol, args): os.dup2(wstdout, 1) # TODO: What to do with stderr? os.execv(args[0], args) - except: + except Exception: try: traceback.print_traceback() finally: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index c3870c0f..557e0b0e 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -161,7 +161,7 @@ def dispatch(self): # pragma: no cover if match is not None: try: return getattr(self, fn)(match) - except: + except Exception: out = io.StringIO() traceback.print_exc(file=out) self._response(500, out.getvalue()) diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py index bd1e0928..59c8aecc 100644 --- a/tulip/winsocketpair.py +++ b/tulip/winsocketpair.py @@ -24,7 +24,7 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): csock.connect((addr, port)) except (BlockingIOError, InterruptedError): pass - except: + except Exception: lsock.close() csock.close() raise From 9d13b3f9d0fda80b77780291e51bece21ff952a3 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Tue, 23 Apr 2013 11:32:33 +0200 Subject: [PATCH 0420/1502] use a bare except clause around fork as per guido post-commit review: https://code.google.com/p/tulip/source/detail?r=502e7fdb40f4c2124b6c0fe3d56a3304b5e26c9d --- tulip/subprocess_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index 01bf7339..5e4d6550 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -37,7 +37,7 @@ def __init__(self, protocol, args): os.dup2(wstdout, 1) # TODO: What to do with stderr? os.execv(args[0], args) - except Exception: + except: try: traceback.print_traceback() finally: From 64319a97034eecb04a79087ba9798a555bb22dc3 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Tue, 23 Apr 2013 12:24:47 +0200 Subject: [PATCH 0421/1502] provide a fileno() method for selectors for which this makes sense (epoll and kqueue) --- tests/selectors_test.py | 4 ++++ tulip/selectors.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 1457e4ed..19f422f6 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -137,3 +137,7 @@ def test_key_from_fd(self, m_log): self.assertIs(key, s._key_from_fd(1)) self.assertIsNone(s._key_from_fd(10)) m_log.warning.assert_called_with('No key found for fd %r', 10) + + if hasattr(selectors.Selector, 'fileno'): + def test_fileno(self): + self.assertIsInstance(selectors.Selector().fileno(), int) diff --git a/tulip/selectors.py b/tulip/selectors.py index bd81e554..8e6add5e 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -309,6 +309,9 @@ def __init__(self): super().__init__() self._epoll = epoll() + def fileno(self): + return self._epoll.fileno() + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) epoll_events = 0 @@ -359,6 +362,9 @@ def __init__(self): super().__init__() self._kqueue = kqueue() + def fileno(self): + return self._kqueue.fileno() + def unregister(self, fileobj): key = super().unregister(fileobj) if key.events & EVENT_READ: From 440716dce029418f2ddf9b6ac14f47f16e6ef12f Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Tue, 23 Apr 2013 17:44:13 +0200 Subject: [PATCH 0422/1502] get back to using a bare exception clause in those points where the exception gets re-raised --- tulip/http/client.py | 2 +- tulip/proactor_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/http/client.py b/tulip/http/client.py index ddaf25e2..94455e2f 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -120,7 +120,7 @@ def start(req, loop): try: resp = req.send(transport) yield from resp.start(p, transport) - except Exception: + except: transport.close() raise diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 22dfa303..ba1389d2 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -168,7 +168,7 @@ def _loop_self_reading(self, f=None): if f is not None: f.result() # may raise f = self._proactor.recv(self._ssock, 4096) - except Exception: + except: self.close() raise else: From 116374302029129e899e3c79d89aa10625c0be43 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Wed, 24 Apr 2013 14:33:37 +0200 Subject: [PATCH 0423/1502] make start_serving() listen on multiple interfaces (discussion at https://groups.google.com/forum/?fromgroups=#!topic/python-tulip/uccY8LEskwA) --- examples/mpsrv.py | 2 +- examples/wssrv.py | 4 +-- srv.py | 4 +-- tests/events_test.py | 36 ++++++++++++++++++--- tulip/base_events.py | 77 ++++++++++++++++++++++++++++---------------- tulip/events.py | 32 +++++++++++++++++- tulip/test_utils.py | 4 +-- 7 files changed, 118 insertions(+), 41 deletions(-) diff --git a/examples/mpsrv.py b/examples/mpsrv.py index 8664cd46..daf55b1a 100644 --- a/examples/mpsrv.py +++ b/examples/mpsrv.py @@ -124,7 +124,7 @@ def stop(): loop.add_signal_handler(signal.SIGINT, stop) f = loop.start_serving(lambda: HttpServer(debug=True), sock=self.sock) - x = loop.run_until_complete(f) + x = loop.run_until_complete(f[0]) print('Starting srv worker process {} on {}'.format( os.getpid(), x.getsockname())) diff --git a/examples/wssrv.py b/examples/wssrv.py index 2d94fd36..8459f458 100644 --- a/examples/wssrv.py +++ b/examples/wssrv.py @@ -135,12 +135,12 @@ def stop(): @tulip.task def start_server(self, writer): - sock = yield from self.loop.start_serving( + socks = yield from self.loop.start_serving( lambda: HttpServer( debug=True, parent=writer, clients=self.clients), sock=self.sock) print('Starting srv worker process {} on {}'.format( - os.getpid(), sock.getsockname())) + os.getpid(), socks[0].getsockname())) @tulip.task def heartbeat(self): diff --git a/srv.py b/srv.py index 56392dd5..f93e1d7e 100755 --- a/srv.py +++ b/srv.py @@ -149,8 +149,8 @@ def main(): loop = tulip.get_event_loop() f = loop.start_serving( lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) - x = loop.run_until_complete(f) - print('serving on', x.getsockname()) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) loop.run_forever() diff --git a/tests/events_test.py b/tests/events_test.py index e928cdf0..ff6b0e4e 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -14,8 +14,10 @@ import sys import threading import time +import errno import unittest import unittest.mock +from test.support import find_unused_port from tulip import futures @@ -689,7 +691,9 @@ def factory(): return proto f = self.event_loop.start_serving(factory, '0.0.0.0', 0) - sock = self.event_loop.run_until_complete(f) + socks = self.event_loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() @@ -744,7 +748,7 @@ def factory(): f = self.event_loop.start_serving( factory, '127.0.0.1', 0, ssl=sslcontext) - sock = self.event_loop.run_until_complete(f) + sock = self.event_loop.run_until_complete(f)[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') @@ -788,7 +792,7 @@ def __init__(self): sock_ob.bind(('0.0.0.0', 0)) f = self.event_loop.start_serving(TestMyProto, sock=sock_ob) - sock = self.event_loop.run_until_complete(f) + sock = self.event_loop.run_until_complete(f)[0] self.assertIs(sock, sock_ob) host, port = sock.getsockname() @@ -797,12 +801,34 @@ def __init__(self): client.connect(('127.0.0.1', port)) client.send(b'xxx') self.event_loop.run_until_complete(proto) - sock.close() client.close() + f = self.event_loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(socket.error) as cm: + self.event_loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + port = find_unused_port() + f = self.event_loop.start_serving(MyProto, host=None, port=port) + socks = self.event_loop.run_until_complete(f) + with socket.socket() as client: + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.event_loop.run_once() + with socket.socket(socket.AF_INET6) as client: + client.connect(('::1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.event_loop.run_once() + for s in socks: + s.close() + def test_stop_serving(self): f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) - sock = self.event_loop.run_until_complete(f) + sock = self.event_loop.run_until_complete(f)[0] host, port = sock.getsockname() client = socket.socket() diff --git a/tulip/base_events.py b/tulip/base_events.py index 66397421..ca83a56c 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -20,6 +20,8 @@ import logging import socket import time +import os +import sys from . import events from . import futures @@ -413,45 +415,64 @@ def create_datagram_endpoint(self, protocol_factory, # TODO: Or create_server()? @tasks.task def start_serving(self, protocol_factory, host=None, port=None, *, - family=0, proto=0, flags=0, backlog=100, sock=None, - ssl=False): + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=False, reuse_address=None): """XXX""" if host is not None or port is not None: if sock is not None: raise ValueError( "host, port and sock can not be specified at the same time") + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == "": + host = None + infos = yield from self.getaddrinfo( host, port, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) - + type=socket.SOCK_STREAM, proto=0, flags=flags) if not infos: raise socket.error('getaddrinfo() returned empty list') - # TODO: Maybe we want to bind every address in the list - # instead of the first one that works? - exceptions = [] - for family, type, proto, cname, address in infos: - sock = socket.socket(family=family, type=type, proto=proto) - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - except socket.error as exc: - sock.close() - exceptions.append(exc) - else: - break - else: - raise exceptions[0] - - elif sock is None: - raise ValueError( - "host and port was not specified and no sock specified") - - sock.listen(backlog) - sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl) - return sock + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, "IPPROTO_IPV6"): + sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except socket.error as err: + raise socket.error(err.errno, "error while attempting " + "to bind on address %r: %s" \ + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + "host and port was not specified and no sock specified") + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets @tasks.coroutine def connect_read_pipe(self, protocol_factory, pipe): diff --git a/tulip/events.py b/tulip/events.py index 68cd7211..63e4b790 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -12,6 +12,7 @@ import sys import threading +import socket from .log import tulip_log @@ -182,7 +183,36 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise NotImplementedError def start_serving(self, protocol_factory, host=None, port=None, *, - family=0, proto=0, flags=0, sock=None): + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=False, reuse_address=None): + """Creates a TCP server bound to host and port and return + a list of socket objects which will later be handled by + protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to True to enable SSL over the accepted + connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ raise NotImplementedError def stop_serving(self, sock): diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 557e0b0e..e6d1069b 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -98,13 +98,13 @@ def run(loop, fut): thread_loop.set_log_level(logging.CRITICAL) tulip.set_event_loop(thread_loop) - sock = thread_loop.run_until_complete( + socks = thread_loop.run_until_complete( thread_loop.start_serving( TestHttpServer, host, port, ssl=sslcontext)) waiter = tulip.Future() loop.call_soon_threadsafe( - fut.set_result, (thread_loop, waiter, sock.getsockname())) + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) thread_loop.run_until_complete(waiter) thread_loop.stop() From f3221bc134681830c9f925fba8544041d358d6d6 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 09:33:55 -0700 Subject: [PATCH 0424/1502] delay task cancelation until wait future is not done --- tests/tasks_test.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ tulip/tasks.py | 22 ++++++++++++++------- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 98b08f98..57f5bede 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -78,6 +78,7 @@ def test_task_repr(self): def notmuch(): yield from [] return 'abc' + t = notmuch() t.add_done_callback(Dummy()) self.assertEqual(repr(t), 'Task()') @@ -138,6 +139,52 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) + def test_cancel_yield(self): + @tasks.task + def task(): + yield + yield + return 12 + + t = task() + self.event_loop.run_once() # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_done_future(self): + fut1 = futures.Future() + fut2 = futures.Future() + fut3 = futures.Future() + + @tasks.task + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + yield from fut3 + + t = task() + self.event_loop.run_once() + fut1.set_result(None) + t.cancel() + self.event_loop.run_once() # process fut1 result, delay cancel + self.assertFalse(t.done()) + self.event_loop.run_once() # cancel fut2, but coro still alive + self.assertFalse(t.done()) + self.event_loop.run_once() # cancel fut3 + self.assertTrue(t.done()) + + self.assertEqual(fut1.result(), None) + self.assertTrue(fut2.cancelled()) + self.assertTrue(fut3.cancelled()) + self.assertTrue(t.cancelled()) + def test_future_timeout(self): @tasks.coroutine def coro(): diff --git a/tulip/tasks.py b/tulip/tasks.py index 70762c57..119b5ee6 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -65,6 +65,9 @@ def task_wrapper(*args, **kwds): return task_wrapper +_marker = object() + + class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -89,7 +92,7 @@ def __repr__(self): return res def cancel(self): - if self.done(): + if self.done() or self._must_cancel: return False self._must_cancel = True # _step() will call super().cancel() to call the callbacks. @@ -107,16 +110,18 @@ def _step_maybe(self): if not self.done(): return self._step() - def _step(self, value=None, exc=None): + def _step(self, value=_marker, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - self._fut_waiter = None - # We'll call either coro.throw(exc) or coro.send(value). - if self._must_cancel: + # Task cancel has to be delayed if current waiter future is done. + if self._must_cancel and exc is None and value is _marker: exc = futures.CancelledError + coro = self._coro + value = None if value is _marker else value + self._fut_waiter = None try: if exc is not None: result = coro.throw(exc) @@ -141,7 +146,6 @@ def _step(self, value=None, exc=None): self.set_exception(exc) raise else: - # XXX No check for self._must_cancel here? if isinstance(result, futures.Future): if not result._blocking: result.set_exception( @@ -153,6 +157,10 @@ def _step(self, value=None, exc=None): result.add_done_callback(self._wakeup) self._fut_waiter = result + # task cancellation has been delayed. + if self._must_cancel: + self._fut_waiter.cancel() + elif isinstance(result, concurrent.futures.Future): # This ought to be more efficient than wrap_future(), # because we don't create an extra Future. @@ -175,7 +183,7 @@ def _step(self, value=None, exc=None): RuntimeError( 'Task received bad yield: {!r}'.format(result))) else: - self._event_loop.call_soon(self._step) + self._event_loop.call_soon(self._step_maybe) def _wakeup(self, future): try: From a2449dfc9f89699131730193efa785bdd52a1b78 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 09:44:03 -0700 Subject: [PATCH 0425/1502] Rename Selector class to DefaultSelector. --- tests/selectors_test.py | 4 ++-- tulip/selector_events.py | 2 +- tulip/selectors.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 19f422f6..f933f35e 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -138,6 +138,6 @@ def test_key_from_fd(self, m_log): self.assertIsNone(s._key_from_fd(10)) m_log.warning.assert_called_with('No key found for fd %r', 10) - if hasattr(selectors.Selector, 'fileno'): + if hasattr(selectors.DefaultSelector, 'fileno'): def test_fileno(self): - self.assertIsInstance(selectors.Selector().fileno(), int) + self.assertIsInstance(selectors.DefaultSelector().fileno(), int) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 2dd4b748..e9388f91 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -43,7 +43,7 @@ def __init__(self, selector=None): super().__init__() if selector is None: - selector = selectors.Selector() + selector = selectors.DefaultSelector() tulip_log.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._make_self_pipe() diff --git a/tulip/selectors.py b/tulip/selectors.py index 8e6add5e..4e671444 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -415,10 +415,10 @@ def close(self): # Choose the best implementation: roughly, epoll|kqueue > poll > select. # select() also can't accept a FD > FD_SETSIZE (usually around 1024) if 'KqueueSelector' in globals(): - Selector = KqueueSelector + DefaultSelector = KqueueSelector elif 'EpollSelector' in globals(): - Selector = EpollSelector + DefaultSelector = EpollSelector elif 'PollSelector' in globals(): - Selector = PollSelector + DefaultSelector = PollSelector else: - Selector = SelectSelector + DefaultSelector = SelectSelector From 516f1b1b7452743fd194c4e06b4215754c6dd0d1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 10:55:24 -0700 Subject: [PATCH 0426/1502] separate test for EADDRINUSE error --- tests/events_test.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index ff6b0e4e..0566af68 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -783,8 +783,8 @@ def test_start_serving_sock(self): proto = futures.Future() class TestMyProto(MyProto): - def __init__(self): - super().__init__() + def connection_made(self, transport): + super().connection_made(transport) proto.set_result(self) sock_ob = socket.socket(type=socket.SOCK_STREAM) @@ -801,8 +801,18 @@ def __init__(self): client.connect(('127.0.0.1', port)) client.send(b'xxx') self.event_loop.run_until_complete(proto) + sock.close() client.close() + def test_start_serving_addrinuse(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f)[0] + host, port = sock.getsockname() + f = self.event_loop.start_serving(MyProto, host=host, port=port) with self.assertRaises(socket.error) as cm: self.event_loop.run_until_complete(f) From 6d91d497636a395899603be064692f0b013e2093 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 12:43:41 -0700 Subject: [PATCH 0427/1502] windows tests fixes --- tests/events_test.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 0566af68..9afdff60 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -820,19 +820,33 @@ def test_start_serving_addrinuse(self): @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') def test_start_serving_dual_stack(self): + f_proto = futures.Future() + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + port = find_unused_port() - f = self.event_loop.start_serving(MyProto, host=None, port=port) + f = self.event_loop.start_serving(TestMyProto, host=None, port=port) socks = self.event_loop.run_until_complete(f) - with socket.socket() as client: - client.connect(('127.0.0.1', port)) - client.send(b'xxx') - self.event_loop.run_once() - self.event_loop.run_once() - with socket.socket(socket.AF_INET6) as client: - client.connect(('::1', port)) - client.send(b'xxx') - self.event_loop.run_once() - self.event_loop.run_once() + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.event_loop.run_until_complete(f_proto) + proto.transport.close() + self.event_loop.run_once() # windows, issue #35 + client.close() + + f_proto = futures.Future() + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.event_loop.run_until_complete(f_proto) + proto.transport.close() + self.event_loop.run_once() # windows, issue #35 + client.close() + for s in socks: s.close() @@ -873,16 +887,16 @@ def test_start_serving_no_getaddrinfo(self): def test_start_serving_cant_bind(self, m_socket): class Err(socket.error): - pass + strerror = 'error' m_socket.error = socket.error m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] m_sock = m_socket.socket.return_value = unittest.mock.Mock() - m_sock.setsockopt.side_effect = Err + m_sock.bind.side_effect = Err fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertRaises(OSError, self.event_loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) @unittest.mock.patch('tulip.base_events.socket') From 3d91fefe57a03361d6339d344532f0e4b3209413 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 13:13:22 -0700 Subject: [PATCH 0428/1502] wsclient example fixes --- examples/wsclient.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/wsclient.py b/examples/wsclient.py index 64b48514..3ef310c0 100644 --- a/examples/wsclient.py +++ b/examples/wsclient.py @@ -42,8 +42,8 @@ def start_client(loop): raise ValueError("Handshake error - Invalid challenge response") # switch to websocket protocol - stream = response.stream.set_parser(websocket.websocket_parser()) - writer = websocket.websocket_writer(response.transport) + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) # input reader loop.add_reader( @@ -56,11 +56,11 @@ def dispatch(): msg = yield from stream.read() if msg is None: break - elif msg.opcode == websocket.OPCODE_PING: + elif msg.tp == websocket.MSG_PING: writer.pong() - elif msg.opcode == websocket.OPCODE_TEXT: + elif msg.tp == websocket.MSG_TEXT: print(msg.data.strip()) - elif msg.opcode == websocket.OPCODE_CLOSE: + elif msg.tp == websocket.MSG_CLOSE: break yield from dispatch() From f57efa4285bdff25c25546e846f329345819e1fc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 13:25:51 -0700 Subject: [PATCH 0429/1502] Move example programs into examples directory. --- crawl.py => examples/crawl.py | 0 curl.py => examples/curl.py | 0 srv.py => examples/srv.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename crawl.py => examples/crawl.py (100%) rename curl.py => examples/curl.py (100%) rename srv.py => examples/srv.py (100%) diff --git a/crawl.py b/examples/crawl.py similarity index 100% rename from crawl.py rename to examples/crawl.py diff --git a/curl.py b/examples/curl.py similarity index 100% rename from curl.py rename to examples/curl.py diff --git a/srv.py b/examples/srv.py similarity index 100% rename from srv.py rename to examples/srv.py From d412b1d462c531185dc8ae6a8eddc6b730bd8487 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 13:46:15 -0700 Subject: [PATCH 0430/1502] added host and port arguments to wsclient.py example --- examples/wsclient.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) mode change 100644 => 100755 examples/wsclient.py diff --git a/examples/wsclient.py b/examples/wsclient.py old mode 100644 new mode 100755 index 3ef310c0..8458d78e --- a/examples/wsclient.py +++ b/examples/wsclient.py @@ -1,4 +1,6 @@ +#!/usr/bin/env python3 """websocket cmd client for wssrv.py example.""" +import argparse import base64 import hashlib import os @@ -12,10 +14,9 @@ WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" -def start_client(loop): +def start_client(loop, url): name = input('Please enter your name: ').encode() - url = 'http://localhost:8080/' sec_key = base64.b64encode(os.urandom(16)) # send request @@ -66,7 +67,25 @@ def dispatch(): yield from dispatch() +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + loop = tulip.get_event_loop() + loop.set_log_level(50) loop.add_signal_handler(signal.SIGINT, loop.stop) - loop.run_until_complete(start_client(loop)) + tulip.Task(start_client(loop, url)) + loop.run_forever() From 68cb30f6f28d656de52613e970266c1c32241a19 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 15:20:58 -0700 Subject: [PATCH 0431/1502] Make wsclient work on OSX in an Emacs shell window. Also make it stop when receiving EOF. --- examples/wsclient.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/wsclient.py b/examples/wsclient.py index 8458d78e..5598c38c 100755 --- a/examples/wsclient.py +++ b/examples/wsclient.py @@ -10,6 +10,7 @@ import tulip import tulip.http from tulip.http import websocket +import tulip.selectors WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -47,9 +48,13 @@ def start_client(loop, url): writer = websocket.WebSocketWriter(response.transport) # input reader - loop.add_reader( - sys.stdin.fileno(), - lambda: writer.send(name + b': ' + sys.stdin.readline().encode())) + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) @tulip.coroutine def dispatch(): @@ -84,7 +89,9 @@ def dispatch(): url = 'http://{}:{}'.format(args.host, args.port) - loop = tulip.get_event_loop() + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + loop.set_log_level(50) loop.add_signal_handler(signal.SIGINT, loop.stop) tulip.Task(start_client(loop, url)) From 48097dc869e47d03fe342a1f21f101dca7999b0d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 15:50:37 -0700 Subject: [PATCH 0432/1502] Delete old stuff. --- old/Makefile | 16 -- old/echoclt.py | 79 ------- old/echosvr.py | 60 ----- old/http_client.py | 78 ------- old/http_server.py | 68 ------ old/main.py | 134 ------------ old/p3time.py | 47 ---- old/polling.py | 535 --------------------------------------------- old/scheduling.py | 354 ------------------------------ old/sockets.py | 348 ----------------------------- old/transports.py | 496 ----------------------------------------- old/xkcd.py | 18 -- old/yyftime.py | 75 ------- 13 files changed, 2308 deletions(-) delete mode 100644 old/Makefile delete mode 100644 old/echoclt.py delete mode 100644 old/echosvr.py delete mode 100644 old/http_client.py delete mode 100644 old/http_server.py delete mode 100644 old/main.py delete mode 100644 old/p3time.py delete mode 100644 old/polling.py delete mode 100644 old/scheduling.py delete mode 100644 old/sockets.py delete mode 100644 old/transports.py delete mode 100755 old/xkcd.py delete mode 100644 old/yyftime.py diff --git a/old/Makefile b/old/Makefile deleted file mode 100644 index d352cd70..00000000 --- a/old/Makefile +++ /dev/null @@ -1,16 +0,0 @@ -PYTHON=python3 - -main: - $(PYTHON) main.py -v - -echo: - $(PYTHON) echosvr.py -v - -profile: - $(PYTHON) -m profile -s time main.py - -time: - $(PYTHON) p3time.py - -ytime: - $(PYTHON) yyftime.py diff --git a/old/echoclt.py b/old/echoclt.py deleted file mode 100644 index c24c573e..00000000 --- a/old/echoclt.py +++ /dev/null @@ -1,79 +0,0 @@ -#!/usr/bin/env python3.3 -"""Example echo client.""" - -# Stdlib imports. -import logging -import socket -import sys -import time - -# Local imports. -import scheduling -import sockets - - -def echoclient(host, port): - """COROUTINE""" - testdata = b'hi hi hi ha ha ha\n' - try: - trans = yield from sockets.create_transport(host, port, - af=socket.AF_INET) - except OSError: - return False - try: - ok = yield from trans.send(testdata) - if ok: - response = yield from trans.recv(100) - ok = response == testdata.upper() - return ok - finally: - trans.close() - - -def doit(n): - """COROUTINE""" - t0 = time.time() - tasks = set() - for i in range(n): - t = scheduling.Task(echoclient('127.0.0.1', 1111), 'client-%d' % i) - tasks.add(t) - ok = 0 - bad = 0 - for t in tasks: - try: - yield from t - except Exception: - bad += 1 - else: - ok += 1 - t1 = time.time() - print('ok: ', ok) - print('bad:', bad) - print('dt: ', round(t1-t0, 6)) - - -def main(): - # Initialize logging. - if '-d' in sys.argv: - level = logging.DEBUG - elif '-v' in sys.argv: - level = logging.INFO - elif '-q' in sys.argv: - level = logging.ERROR - else: - level = logging.WARN - logging.basicConfig(level=level) - - # Get integer from command line. - n = 1 - for arg in sys.argv[1:]: - if not arg.startswith('-'): - n = int(arg) - break - - # Run scheduler, starting it off with doit(). - scheduling.run(doit(n)) - - -if __name__ == '__main__': - main() diff --git a/old/echosvr.py b/old/echosvr.py deleted file mode 100644 index 4085f4c6..00000000 --- a/old/echosvr.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3.3 -"""Example echo server.""" - -# Stdlib imports. -import logging -import socket -import sys - -# Local imports. -import scheduling -import sockets - - -def handler(conn, addr): - """COROUTINE: Handle one connection.""" - logging.info('Accepting connection from %r', addr) - trans = sockets.SocketTransport(conn) - rdr = sockets.BufferedReader(trans) - while True: - line = yield from rdr.readline() - logging.debug('Received: %r from %r', line, addr) - if not line: - break - yield from trans.send(line.upper()) - logging.debug('Closing %r', addr) - trans.close() - - -def doit(): - """COROUTINE: Set the wheels in motion.""" - # Set up listener. - listener = yield from sockets.create_listener('localhost', 1111, - af=socket.AF_INET, - backlog=100) - logging.info('Listening on %r', listener.sock.getsockname()) - - # Loop accepting connections. - while True: - conn, addr = yield from listener.accept() - t = scheduling.Task(handler(conn, addr)) - - -def main(): - # Initialize logging. - if '-d' in sys.argv: - level = logging.DEBUG - elif '-v' in sys.argv: - level = logging.INFO - elif '-q' in sys.argv: - level = logging.ERROR - else: - level = logging.WARN - logging.basicConfig(level=level) - - # Run scheduler, starting it off with doit(). - scheduling.run(doit()) - - -if __name__ == '__main__': - main() diff --git a/old/http_client.py b/old/http_client.py deleted file mode 100644 index 8937ba20..00000000 --- a/old/http_client.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Crummy HTTP client. - -This is not meant as an example of how to write a good client. -""" - -# Stdlib. -import re -import time - -# Local. -import sockets - - -def urlfetch(host, port=None, path='/', method='GET', - body=None, hdrs=None, encoding='utf-8', ssl=None, af=0): - """COROUTINE: Make an HTTP 1.0 request.""" - t0 = time.time() - if port is None: - if ssl: - port = 443 - else: - port = 80 - trans = yield from sockets.create_transport(host, port, ssl=ssl, af=af) - yield from trans.send(method.encode(encoding) + b' ' + - path.encode(encoding) + b' HTTP/1.0\r\n') - if hdrs: - kwds = dict(hdrs) - else: - kwds = {} - if 'host' not in kwds: - kwds['host'] = host - if body is not None: - kwds['content_length'] = len(body) - for header, value in kwds.items(): - yield from trans.send(header.replace('_', '-').encode(encoding) + - b': ' + value.encode(encoding) + b'\r\n') - - yield from trans.send(b'\r\n') - if body is not None: - yield from trans.send(body) - - # Read HTTP response line. - rdr = sockets.BufferedReader(trans) - resp = yield from rdr.readline() - m = re.match(br'(?ix) http/(\d\.\d) \s+ (\d\d\d) \s+ ([^\r]*)\r?\n\Z', - resp) - if not m: - trans.close() - raise IOError('No valid HTTP response: %r' % resp) - http_version, status, message = m.groups() - - # Read HTTP headers. - headers = [] - hdict = {} - while True: - line = yield from rdr.readline() - if not line.strip(): - break - m = re.match(br'([^\s:]+):\s*([^\r]*)\r?\n\Z', line) - if not m: - raise IOError('Invalid header: %r' % line) - header, value = m.groups() - headers.append((header, value)) - hdict[header.decode(encoding).lower()] = value.decode(encoding) - - # Read response body. - content_length = hdict.get('content-length') - if content_length is not None: - size = int(content_length) # TODO: Catch errors. - assert size >= 0, size - else: - size = 2**20 # Protective limit (1 MB). - data = yield from rdr.readexactly(size) - trans.close() # Can this block? - t1 = time.time() - result = (host, port, path, int(status), len(data), round(t1-t0, 3)) -## print(result) - return result diff --git a/old/http_server.py b/old/http_server.py deleted file mode 100644 index 2b1e3dd6..00000000 --- a/old/http_server.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python3.3 -"""Simple HTTP server. - -This currenty exists just so we can benchmark this thing! -""" - -# Stdlib imports. -import logging -import re -import socket -import sys - -# Local imports. -import scheduling -import sockets - - -def handler(conn, addr): - """COROUTINE: Handle one connection.""" - ##logging.info('Accepting connection from %r', addr) - trans = sockets.SocketTransport(conn) - rdr = sockets.BufferedReader(trans) - - # Read but ignore request line. - request_line = yield from rdr.readline() - - # Consume headers but don't interpret them. - while True: - header_line = yield from rdr.readline() - if not header_line.strip(): - break - - # Always send an empty 200 response and close. - yield from trans.send(b'HTTP/1.0 200 Ok\r\n\r\n') - trans.close() - - -def doit(): - """COROUTINE: Set the wheels in motion.""" - # Set up listener. - listener = yield from sockets.create_listener('localhost', 8080, - af=socket.AF_INET) - logging.info('Listening on %r', listener.sock.getsockname()) - - # Loop accepting connections. - while True: - conn, addr = yield from listener.accept() - t = scheduling.Task(handler(conn, addr)) - - -def main(): - # Initialize logging. - if '-d' in sys.argv: - level = logging.DEBUG - elif '-v' in sys.argv: - level = logging.INFO - elif '-q' in sys.argv: - level = logging.ERROR - else: - level = logging.WARN - logging.basicConfig(level=level) - - # Run scheduler, starting it off with doit(). - scheduling.run(doit()) - - -if __name__ == '__main__': - main() diff --git a/old/main.py b/old/main.py deleted file mode 100644 index c1f9d0a8..00000000 --- a/old/main.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3.3 -"""Example HTTP client using yield-from coroutines (PEP 380). - -Requires Python 3.3. - -There are many micro-optimizations possible here, but that's not the point. - -Some incomplete laundry lists: - -TODO: -- Take test urls from command line. -- Move urlfetch to a separate module. -- Profiling. -- Docstrings. -- Unittests. - -FUNCTIONALITY: -- Connection pool (keep connection open). -- Chunked encoding (request and response). -- Pipelining, e.g. zlib (request and response). -- Automatic encoding/decoding. -""" - -__author__ = 'Guido van Rossum ' - -# Standard library imports (keep in alphabetic order). -import logging -import os -import time -import socket -import sys - -# Local imports (keep in alphabetic order). -import scheduling -import http_client - - - -def doit2(): - argses = [ - ('localhost', 8080, '/'), - ('127.0.0.1', 8080, '/home'), - ('python.org', 80, '/'), - ('xkcd.com', 443, '/'), - ] - results = yield from scheduling.map_over( - lambda args: http_client.urlfetch(*args), argses, timeout=2) - for res in results: - print('-->', res) - return [] - - -def doit(): - TIMEOUT = 2 - tasks = set() - - # This references NDB's default test service. - # (Sadly the service is single-threaded.) - task1 = scheduling.Task(http_client.urlfetch('localhost', 8080, path='/'), - 'root', timeout=TIMEOUT) - tasks.add(task1) - task2 = scheduling.Task(http_client.urlfetch('127.0.0.1', 8080, - path='/home'), - 'home', timeout=TIMEOUT) - tasks.add(task2) - - # Fetch python.org home page. - task3 = scheduling.Task(http_client.urlfetch('python.org', 80, path='/'), - 'python', timeout=TIMEOUT) - tasks.add(task3) - - # Fetch XKCD home page using SSL. (Doesn't like IPv6.) - task4 = scheduling.Task(http_client.urlfetch('xkcd.com', ssl=True, path='/', - af=socket.AF_INET), - 'xkcd', timeout=TIMEOUT) - tasks.add(task4) - -## # Fetch many links from python.org (/x.y.z). -## for x in '123': -## for y in '0123456789': -## path = '/{}.{}'.format(x, y) -## g = http_client.urlfetch('82.94.164.162', 80, -## path=path, hdrs={'host': 'python.org'}) -## t = scheduling.Task(g, path, timeout=2) -## tasks.add(t) - -## print(tasks) - yield from scheduling.Task(scheduling.sleep(1), timeout=0.2).wait() - winners = yield from scheduling.wait_any(tasks) - print('And the winners are:', [w.name for w in winners]) - tasks = yield from scheduling.wait_all(tasks) - print('And the players were:', [t.name for t in tasks]) - return tasks - - -def logtimes(real): - utime, stime, cutime, cstime, unused = os.times() - logging.info('real %10.3f', real) - logging.info('user %10.3f', utime + cutime) - logging.info('sys %10.3f', stime + cstime) - - -def main(): - t0 = time.time() - - # Initialize logging. - if '-d' in sys.argv: - level = logging.DEBUG - elif '-v' in sys.argv: - level = logging.INFO - elif '-q' in sys.argv: - level = logging.ERROR - else: - level = logging.WARN - logging.basicConfig(level=level) - - # Run scheduler, starting it off with doit(). - task = scheduling.run(doit()) - if task.exception: - print('Exception:', repr(task.exception)) - if isinstance(task.exception, AssertionError): - raise task.exception - else: - for t in task.result: - print(t.name + ':', - repr(t.exception) if t.exception else t.result) - - # Report real, user, sys times. - t1 = time.time() - logtimes(t1-t0) - - -if __name__ == '__main__': - main() diff --git a/old/p3time.py b/old/p3time.py deleted file mode 100644 index 35e14c96..00000000 --- a/old/p3time.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Compare timing of plain vs. yield-from calls.""" - -import gc -import time - -def plain(n): - if n <= 0: - return 1 - l = plain(n-1) - r = plain(n-1) - return l + 1 + r - -def coroutine(n): - if n <= 0: - return 1 - l = yield from coroutine(n-1) - r = yield from coroutine(n-1) - return l + 1 + r - -def submain(depth): - t0 = time.time() - k = plain(depth) - t1 = time.time() - fmt = ' {} {} {:-9,.5f}' - delta0 = t1-t0 - print(('plain' + fmt).format(depth, k, delta0)) - - t0 = time.time() - try: - g = coroutine(depth) - while True: - next(g) - except StopIteration as err: - k = err.value - t1 = time.time() - delta1 = t1-t0 - print(('coro.' + fmt).format(depth, k, delta1)) - if delta0: - print(('relat' + fmt).format(depth, k, delta1/delta0)) - -def main(reasonable=16): - gc.disable() - for depth in range(reasonable): - submain(depth) - -if __name__ == '__main__': - main() diff --git a/old/polling.py b/old/polling.py deleted file mode 100644 index 6586efcc..00000000 --- a/old/polling.py +++ /dev/null @@ -1,535 +0,0 @@ -"""Event loop and related classes. - -The event loop can be broken up into a pollster (the part responsible -for telling us when file descriptors are ready) and the event loop -proper, which wraps a pollster with functionality for scheduling -callbacks, immediately or at a given time in the future. - -Whenever a public API takes a callback, subsequent positional -arguments will be passed to the callback if/when it is called. This -avoids the proliferation of trivial lambdas implementing closures. -Keyword arguments for the callback are not supported; this is a -conscious design decision, leaving the door open for keyword arguments -to modify the meaning of the API call itself. - -There are several implementations of the pollster part, several using -esoteric system calls that exist only on some platforms. These are: - -- kqueue (most BSD systems) -- epoll (newer Linux systems) -- poll (most UNIX systems) -- select (all UNIX systems, and Windows) -- TODO: Support IOCP on Windows and some UNIX platforms. - -NOTE: We don't use select on systems where any of the others is -available, because select performs poorly as the number of file -descriptors goes up. The ranking is roughly: - - 1. kqueue, epoll, IOCP - 2. poll - 3. select - -TODO: -- Optimize the various pollsters. -- Unittests. -""" - -import collections -import concurrent.futures -import heapq -import logging -import os -import select -import threading -import time - - -class PollsterBase: - """Base class for all polling implementations. - - This defines an interface to register and unregister readers and - writers for specific file descriptors, and an interface to get a - list of events. There's also an interface to check whether any - readers or writers are currently registered. - """ - - def __init__(self): - super().__init__() - self.readers = {} # {fd: token, ...}. - self.writers = {} # {fd: token, ...}. - - def pollable(self): - """Return True if any readers or writers are currently registered.""" - return bool(self.readers or self.writers) - - # Subclasses are expected to extend the add/remove methods. - - def register_reader(self, fd, token): - """Add or update a reader for a file descriptor.""" - self.readers[fd] = token - - def register_writer(self, fd, token): - """Add or update a writer for a file descriptor.""" - self.writers[fd] = token - - def unregister_reader(self, fd): - """Remove the reader for a file descriptor.""" - del self.readers[fd] - - def unregister_writer(self, fd): - """Remove the writer for a file descriptor.""" - del self.writers[fd] - - def poll(self, timeout=None): - """Poll for events. A subclass must implement this. - - If timeout is omitted or None, this blocks until at least one - event is ready. Otherwise, timeout gives a maximum time to - wait (in seconds as an int or float) -- the method returns as - soon as at least one event is ready or when the timeout is - expired. For a non-blocking poll, pass 0. - - The return value is a list of events; it is empty when the - timeout expired before any events were ready. Each event - is a token previously passed to register_reader/writer(). - """ - raise NotImplementedError - - -class SelectPollster(PollsterBase): - """Pollster implementation using select.""" - - def poll(self, timeout=None): - readable, writable, _ = select.select(self.readers, self.writers, - [], timeout) - events = [] - events += (self.readers[fd] for fd in readable) - events += (self.writers[fd] for fd in writable) - return events - - -class PollPollster(PollsterBase): - """Pollster implementation using poll.""" - - def __init__(self): - super().__init__() - self._poll = select.poll() - - def _update(self, fd): - assert isinstance(fd, int), fd - flags = 0 - if fd in self.readers: - flags |= select.POLLIN - if fd in self.writers: - flags |= select.POLLOUT - if flags: - self._poll.register(fd, flags) - else: - self._poll.unregister(fd) - - def register_reader(self, fd, callback, *args): - super().register_reader(fd, callback, *args) - self._update(fd) - - def register_writer(self, fd, callback, *args): - super().register_writer(fd, callback, *args) - self._update(fd) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - self._update(fd) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - self._update(fd) - - def poll(self, timeout=None): - # Timeout is in seconds, but poll() takes milliseconds. - msecs = None if timeout is None else int(round(1000 * timeout)) - events = [] - for fd, flags in self._poll.poll(msecs): - if flags & (select.POLLIN | select.POLLHUP): - if fd in self.readers: - events.append(self.readers[fd]) - if flags & (select.POLLOUT | select.POLLHUP): - if fd in self.writers: - events.append(self.writers[fd]) - return events - - -class EPollPollster(PollsterBase): - """Pollster implementation using epoll.""" - - def __init__(self): - super().__init__() - self._epoll = select.epoll() - - def _update(self, fd): - assert isinstance(fd, int), fd - eventmask = 0 - if fd in self.readers: - eventmask |= select.EPOLLIN - if fd in self.writers: - eventmask |= select.EPOLLOUT - if eventmask: - try: - self._epoll.register(fd, eventmask) - except IOError: - self._epoll.modify(fd, eventmask) - else: - self._epoll.unregister(fd) - - def register_reader(self, fd, callback, *args): - super().register_reader(fd, callback, *args) - self._update(fd) - - def register_writer(self, fd, callback, *args): - super().register_writer(fd, callback, *args) - self._update(fd) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - self._update(fd) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - self._update(fd) - - def poll(self, timeout=None): - if timeout is None: - timeout = -1 # epoll.poll() uses -1 to mean "wait forever". - events = [] - for fd, eventmask in self._epoll.poll(timeout): - if eventmask & select.EPOLLIN: - if fd in self.readers: - events.append(self.readers[fd]) - if eventmask & select.EPOLLOUT: - if fd in self.writers: - events.append(self.writers[fd]) - return events - - -class KqueuePollster(PollsterBase): - """Pollster implementation using kqueue.""" - - def __init__(self): - super().__init__() - self._kqueue = select.kqueue() - - def register_reader(self, fd, callback, *args): - if fd not in self.readers: - kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return super().register_reader(fd, callback, *args) - - def register_writer(self, fd, callback, *args): - if fd not in self.writers: - kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return super().register_writer(fd, callback, *args) - - def unregister_reader(self, fd): - super().unregister_reader(fd) - kev = select.kevent(fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - - def unregister_writer(self, fd): - super().unregister_writer(fd) - kev = select.kevent(fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - - def poll(self, timeout=None): - events = [] - max_ev = len(self.readers) + len(self.writers) - for kev in self._kqueue.control(None, max_ev, timeout): - fd = kev.ident - flag = kev.filter - if flag == select.KQ_FILTER_READ and fd in self.readers: - events.append(self.readers[fd]) - elif flag == select.KQ_FILTER_WRITE and fd in self.writers: - events.append(self.writers[fd]) - return events - - -# Pick the best pollster class for the platform. -if hasattr(select, 'kqueue'): - best_pollster = KqueuePollster -elif hasattr(select, 'epoll'): - best_pollster = EPollPollster -elif hasattr(select, 'poll'): - best_pollster = PollPollster -else: - best_pollster = SelectPollster - - -class DelayedCall: - """Object returned by callback registration methods.""" - - def __init__(self, when, callback, args, kwds=None): - self.when = when - self.callback = callback - self.args = args - self.kwds = kwds - self.cancelled = False - - def cancel(self): - self.cancelled = True - - def __lt__(self, other): - return self.when < other.when - - def __le__(self, other): - return self.when <= other.when - - def __eq__(self, other): - return self.when == other.when - - -class EventLoop: - """Event loop functionality. - - This defines public APIs call_soon(), call_later(), run_once() and - run(). It also wraps Pollster APIs register_reader(), - register_writer(), remove_reader(), remove_writer() with - add_reader() etc. - - This class's instance variables are not part of its API. - """ - - def __init__(self, pollster=None): - super().__init__() - if pollster is None: - logging.info('Using pollster: %s', best_pollster.__name__) - pollster = best_pollster() - self.pollster = pollster - self.ready = collections.deque() # [(callback, args), ...] - self.scheduled = [] # [(when, callback, args), ...] - - def add_reader(self, fd, callback, *args): - """Add a reader callback. Return a DelayedCall instance.""" - dcall = DelayedCall(None, callback, args) - self.pollster.register_reader(fd, dcall) - return dcall - - def remove_reader(self, fd): - """Remove a reader callback.""" - self.pollster.unregister_reader(fd) - - def add_writer(self, fd, callback, *args): - """Add a writer callback. Return a DelayedCall instance.""" - dcall = DelayedCall(None, callback, args) - self.pollster.register_writer(fd, dcall) - return dcall - - def remove_writer(self, fd): - """Remove a writer callback.""" - self.pollster.unregister_writer(fd) - - def add_callback(self, dcall): - """Add a DelayedCall to ready or scheduled.""" - if dcall.cancelled: - return - if dcall.when is None: - self.ready.append(dcall) - else: - heapq.heappush(self.scheduled, dcall) - - def call_soon(self, callback, *args): - """Arrange for a callback to be called as soon as possible. - - This operates as a FIFO queue, callbacks are called in the - order in which they are registered. Each callback will be - called exactly once. - - Any positional arguments after the callback will be passed to - the callback when it is called. - """ - dcall = DelayedCall(None, callback, args) - self.ready.append(dcall) - return dcall - - def call_later(self, when, callback, *args): - """Arrange for a callback to be called at a given time. - - Return an object with a cancel() method that can be used to - cancel the call. - - The time can be an int or float, expressed in seconds. - - If when is small enough (~11 days), it's assumed to be a - relative time, meaning the call will be scheduled that many - seconds in the future; otherwise it's assumed to be a posix - timestamp as returned by time.time(). - - Each callback will be called exactly once. If two callbacks - are scheduled for exactly the same time, it undefined which - will be called first. - - Any positional arguments after the callback will be passed to - the callback when it is called. - """ - if when < 10000000: - when += time.time() - dcall = DelayedCall(when, callback, args) - heapq.heappush(self.scheduled, dcall) - return dcall - - def run_once(self): - """Run one full iteration of the event loop. - - This calls all currently ready callbacks, polls for I/O, - schedules the resulting callbacks, and finally schedules - 'call_later' callbacks. - """ - # TODO: Break each of these into smaller pieces. - # TODO: Pass in a timeout or deadline or something. - # TODO: Refactor to separate the callbacks from the readers/writers. - # TODO: As step 4, run everything scheduled by steps 1-3. - # TODO: An alternative API would be to do the *minimal* amount - # of work, e.g. one callback or one I/O poll. - - # This is the only place where callbacks are actually *called*. - # All other places just add them to ready. - # TODO: Ensure this loop always finishes, even if some - # callbacks keeps registering more callbacks. - while self.ready: - dcall = self.ready.popleft() - if not dcall.cancelled: - try: - if dcall.kwds: - dcall.callback(*dcall.args, **dcall.kwds) - else: - dcall.callback(*dcall.args) - except Exception: - logging.exception('Exception in callback %s %r', - dcall.callback, dcall.args) - - # Remove delayed calls that were cancelled from head of queue. - while self.scheduled and self.scheduled[0].cancelled: - heapq.heappop(self.scheduled) - - # Inspect the poll queue. - if self.pollster.pollable(): - if self.scheduled: - when = self.scheduled[0].when - timeout = max(0, when - time.time()) - else: - timeout = None - t0 = time.time() - events = self.pollster.poll(timeout) - t1 = time.time() - argstr = '' if timeout is None else ' %.3f' % timeout - if t1-t0 >= 1: - level = logging.INFO - else: - level = logging.DEBUG - logging.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) - for dcall in events: - self.add_callback(dcall) - - # Handle 'later' callbacks that are ready. - now = time.time() - while self.scheduled: - dcall = self.scheduled[0] - if dcall.when > now: - break - dcall = heapq.heappop(self.scheduled) - self.call_soon(dcall.callback, *dcall.args) - - def run(self): - """Run the event loop until there is no work left to do. - - This keeps going as long as there are either readable and - writable file descriptors, or scheduled callbacks (of either - variety). - """ - while self.ready or self.scheduled or self.pollster.pollable(): - self.run_once() - - -MAX_WORKERS = 5 # Default max workers when creating an executor. - - -class ThreadRunner: - """Helper to submit work to a thread pool and wait for it. - - This is the glue between the single-threaded callback-based async - world and the threaded world. Use it to call functions that must - block and don't have an async alternative (e.g. getaddrinfo()). - - The only public API is submit(). - """ - - def __init__(self, eventloop, executor=None): - self.eventloop = eventloop - self.executor = executor # Will be constructed lazily. - self.pipe_read_fd, self.pipe_write_fd = os.pipe() - self.active_count = 0 - - def read_callback(self): - """Semi-permanent callback while at least one future is active.""" - assert self.active_count > 0, self.active_count - data = os.read(self.pipe_read_fd, 8192) # Traditional buffer size. - self.active_count -= len(data) - if self.active_count == 0: - self.eventloop.remove_reader(self.pipe_read_fd) - assert self.active_count >= 0, self.active_count - - def submit(self, func, *args, executor=None, callback=None): - """Submit a function to the thread pool. - - This returns a concurrent.futures.Future instance. The caller - should not wait for that, but rather use the callback argument.. - """ - if executor is None: - executor = self.executor - if executor is None: - # Lazily construct a default executor. - # TODO: Should this be shared between threads? - executor = concurrent.futures.ThreadPoolExecutor(MAX_WORKERS) - self.executor = executor - assert self.active_count >= 0, self.active_count - future = executor.submit(func, *args) - if self.active_count == 0: - self.eventloop.add_reader(self.pipe_read_fd, self.read_callback) - self.active_count += 1 - def done_callback(fut): - if callback is not None: - self.eventloop.call_soon(callback, fut) - # TODO: Wake up the pipe in call_soon()? - os.write(self.pipe_write_fd, b'x') - future.add_done_callback(done_callback) - return future - - -class Context(threading.local): - """Thread-local context. - - We use this to avoid having to explicitly pass around an event loop - or something to hold the current task. - - TODO: Add an API so frameworks can substitute a different notion - of context more easily. - """ - - def __init__(self, eventloop=None, threadrunner=None): - # Default event loop and thread runner are lazily constructed - # when first accessed. - self._eventloop = eventloop - self._threadrunner = threadrunner - self.current_task = None # For the benefit of scheduling.py. - - @property - def eventloop(self): - if self._eventloop is None: - self._eventloop = EventLoop() - return self._eventloop - - @property - def threadrunner(self): - if self._threadrunner is None: - self._threadrunner = ThreadRunner(self.eventloop) - return self._threadrunner - - -context = Context() # Thread-local! diff --git a/old/scheduling.py b/old/scheduling.py deleted file mode 100644 index 3864571d..00000000 --- a/old/scheduling.py +++ /dev/null @@ -1,354 +0,0 @@ -#!/usr/bin/env python3.3 -"""Example coroutine scheduler, PEP-380-style ('yield from '). - -Requires Python 3.3. - -There are likely micro-optimizations possible here, but that's not the point. - -TODO: -- Docstrings. -- Unittests. - -PATTERNS TO TRY: -- Various synchronization primitives (Lock, RLock, Event, Condition, - Semaphore, BoundedSemaphore, Barrier). -""" - -__author__ = 'Guido van Rossum ' - -# Standard library imports (keep in alphabetic order). -from concurrent.futures import CancelledError, TimeoutError -import logging -import time -import types - -# Local imports (keep in alphabetic order). -import polling - - -context = polling.context - - -class Task: - """Wrapper around a stack of generators. - - This is a bit like a Future, but with a different interface. - - TODO: - - wait for result. - """ - - def __init__(self, gen, name=None, *, timeout=None): - assert isinstance(gen, types.GeneratorType), repr(gen) - self.gen = gen - self.name = name or gen.__name__ - self.timeout = timeout - self.eventloop = context.eventloop - self.canceleer = None - if timeout is not None: - self.canceleer = self.eventloop.call_later(timeout, self.cancel) - self.blocked = False - self.unblocker = None - self.cancelled = False - self.must_cancel = False - self.alive = True - self.result = None - self.exception = None - self.done_callbacks = [] - # Start the task immediately. - self.eventloop.call_soon(self.step) - - def add_done_callback(self, done_callback): - # For better or for worse, the callback will always be called - # with the task as an argument, like concurrent.futures.Future. - # TODO: Call it right away if task is no longer alive. - dcall = polling.DelayedCall(None, done_callback, (self,)) - self.done_callbacks.append(dcall) - self.done_callbacks = [dc for dc in self.done_callbacks - if not dc.cancelled] - return dcall - - def __repr__(self): - parts = [self.name] - is_current = (self is context.current_task) - if self.blocked: - parts.append('blocking' if is_current else 'blocked') - elif self.alive: - parts.append('running' if is_current else 'runnable') - if self.must_cancel: - parts.append('must_cancel') - if self.cancelled: - parts.append('cancelled') - if self.exception is not None: - parts.append('exception=%r' % self.exception) - elif not self.alive: - parts.append('result=%r' % (self.result,)) - if self.timeout is not None: - parts.append('timeout=%.3f' % self.timeout) - return 'Task<' + ', '.join(parts) + '>' - - def cancel(self): - if self.alive: - if not self.must_cancel and not self.cancelled: - self.must_cancel = True - if self.blocked: - self.unblock() - - def step(self): - assert self.alive, self - try: - context.current_task = self - if self.must_cancel: - self.must_cancel = False - self.cancelled = True - self.gen.throw(CancelledError()) - else: - next(self.gen) - except StopIteration as exc: - self.alive = False - self.result = exc.value - except Exception as exc: - self.alive = False - self.exception = exc - logging.debug('Uncaught exception in %s', self, - exc_info=True, stack_info=True) - except BaseException as exc: - self.alive = False - self.exception = exc - raise - else: - if not self.blocked: - self.eventloop.call_soon(self.step) - finally: - context.current_task = None - if not self.alive: - # Cancel timeout callback if set. - if self.canceleer is not None: - self.canceleer.cancel() - # Schedule done_callbacks. - for dcall in self.done_callbacks: - self.eventloop.add_callback(dcall) - - def block(self, unblock_callback=None, *unblock_args): - assert self is context.current_task, self - assert self.alive, self - assert not self.blocked, self - self.blocked = True - self.unblocker = (unblock_callback, unblock_args) - - def unblock_if_alive(self, unused=None): - # Ignore optional argument so we can be a Future's done_callback. - if self.alive: - self.unblock() - - def unblock(self, unused=None): - # Ignore optional argument so we can be a Future's done_callback. - assert self.alive, self - assert self.blocked, self - self.blocked = False - unblock_callback, unblock_args = self.unblocker - if unblock_callback is not None: - try: - unblock_callback(*unblock_args) - except Exception: - logging.error('Exception in unblocker in task %r', self.name) - raise - finally: - self.unblocker = None - self.eventloop.call_soon(self.step) - - def block_io(self, fd, flag): - assert isinstance(fd, int), repr(fd) - assert flag in ('r', 'w'), repr(flag) - if flag == 'r': - self.block(self.eventloop.remove_reader, fd) - self.eventloop.add_reader(fd, self.unblock) - else: - self.block(self.eventloop.remove_writer, fd) - self.eventloop.add_writer(fd, self.unblock) - - def wait(self): - """COROUTINE: Wait until this task is finished.""" - current_task = context.current_task - assert self is not current_task, (self, current_task) # How confusing! - if not self.alive: - return - current_task.block() - self.add_done_callback(current_task.unblock) - yield - - def __iter__(self): - """COROUTINE: Wait, then return result or raise exception. - - This adds a little magic so you can say - - x = yield from Task(gen()) - - and it is equivalent to - - x = yield from gen() - - but with the option to add a timeout (and only a tad slower). - """ - if self.alive: - yield from self.wait() - assert not self.alive - if self.exception is not None: - raise self.exception - return self.result - - -def run(arg=None): - """Run the event loop until it's out of work. - - If you pass a generator, it will be spawned for you. - You can also pass a task (already started). - Returns the task. - """ - t = None - if arg is not None: - if isinstance(arg, Task): - t = arg - else: - t = Task(arg) - context.eventloop.run() - if t is not None and t.exception is not None: - logging.error('Uncaught exception in startup task: %r', - t.exception) - return t - - -def sleep(secs): - """COROUTINE: Sleep for some time (a float in seconds).""" - current_task = context.current_task - unblocker = context.eventloop.call_later(secs, current_task.unblock) - current_task.block(unblocker.cancel) - yield - - -def block_r(fd): - """COROUTINE: Block until a file descriptor is ready for reading.""" - context.current_task.block_io(fd, 'r') - yield - - -def block_w(fd): - """COROUTINE: Block until a file descriptor is ready for writing.""" - context.current_task.block_io(fd, 'w') - yield - - -def call_in_thread(func, *args, executor=None): - """COROUTINE: Run a function in a thread.""" - task = context.current_task - eventloop = context.eventloop - future = context.threadrunner.submit(func, *args, - executor=executor, - callback=task.unblock_if_alive) - task.block(future.cancel) - yield - assert future.done() - return future.result() - - -def wait_for(count, tasks): - """COROUTINE: Wait for the first N of a set of tasks to complete. - - May return more than N if more than N are immediately ready. - - NOTE: Tasks that were cancelled or raised are also considered ready. - """ - assert tasks - assert all(isinstance(task, Task) for task in tasks) - tasks = set(tasks) - assert 1 <= count <= len(tasks) - current_task = context.current_task - assert all(task is not current_task for task in tasks) - todo = set() - done = set() - dcalls = [] - def wait_for_callback(task): - nonlocal todo, done, current_task, count, dcalls - todo.remove(task) - if len(done) < count: - done.add(task) - if len(done) == count: - for dcall in dcalls: - dcall.cancel() - current_task.unblock() - for task in tasks: - if task.alive: - todo.add(task) - else: - done.add(task) - if len(done) < count: - for task in todo: - dcall = task.add_done_callback(wait_for_callback) - dcalls.append(dcall) - current_task.block() - yield - return done - - -def wait_any(tasks): - """COROUTINE: Wait for the first of a set of tasks to complete.""" - return wait_for(1, tasks) - - -def wait_all(tasks): - """COROUTINE: Wait for all of a set of tasks to complete.""" - return wait_for(len(tasks), tasks) - - -def map_over(gen, *args, timeout=None): - """COROUTINE: map a generator over one or more iterables. - - E.g. map_over(foo, xs, ys) runs - - Task(foo(x, y)) for x, y in zip(xs, ys) - - and returns a list of all results (in that order). However if any - task raises an exception, the remaining tasks are cancelled and - the exception is propagated. - """ - # gen is a generator function. - tasks = [Task(gobj, timeout=timeout) for gobj in map(gen, *args)] - return (yield from par_tasks(tasks)) - - -def par(*args): - """COROUTINE: Wait for generators, return a list of results. - - Raises as soon as one of the tasks raises an exception (and then - remaining tasks are cancelled). - - This differs from par_tasks() in two ways: - - takes *args instead of list of args - - each arg may be a generator or a task - """ - tasks = [] - for arg in args: - if not isinstance(arg, Task): - # TODO: assert arg is a generator or an iterator? - arg = Task(arg) - tasks.append(arg) - return (yield from par_tasks(tasks)) - - -def par_tasks(tasks): - """COROUTINE: Wait for a list of tasks, return a list of results. - - Raises as soon as one of the tasks raises an exception (and then - remaining tasks are cancelled). - """ - todo = set(tasks) - while todo: - ts = yield from wait_any(todo) - for t in ts: - assert not t.alive, t - todo.remove(t) - if t.exception is not None: - for other in todo: - other.cancel() - raise t.exception - return [t.result for t in tasks] diff --git a/old/sockets.py b/old/sockets.py deleted file mode 100644 index a5005dc3..00000000 --- a/old/sockets.py +++ /dev/null @@ -1,348 +0,0 @@ -"""Socket wrappers to go with scheduling.py. - -Classes: - -- SocketTransport: a transport implementation wrapping a socket. -- SslTransport: a transport implementation wrapping SSL around a socket. -- BufferedReader: a buffer wrapping the read end of a transport. - -Functions (all coroutines): - -- connect(): connect a socket. -- getaddrinfo(): look up an address. -- create_connection(): look up address and return a connected socket for it. -- create_transport(): look up address and return a connected transport. - -TODO: -- Improve transport abstraction. -- Make a nice protocol abstraction. -- Unittests. -- A write() call that isn't a generator (needed so you can substitute it - for sys.stderr, pass it to logging.StreamHandler, etc.). -""" - -__author__ = 'Guido van Rossum ' - -# Stdlib imports. -import errno -import socket -import ssl - -# Local imports. -import scheduling - -# Errno values indicating the connection was disconnected. -_DISCONNECTED = frozenset((errno.ECONNRESET, - errno.ENOTCONN, - errno.ESHUTDOWN, - errno.ECONNABORTED, - errno.EPIPE, - errno.EBADF, - )) - -# Errno values indicating the socket isn't ready for I/O just yet. -_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) - - -class SocketTransport: - """Transport wrapping a socket. - - The socket must already be connected in non-blocking mode. - """ - - def __init__(self, sock): - self.sock = sock - - def recv(self, n): - """COROUTINE: Read up to n bytes, blocking as needed. - - Always returns at least one byte, except if the socket was - closed or disconnected and there's no more data; then it - returns b''. - """ - assert n >= 0, n - while True: - try: - return self.sock.recv(n) - except socket.error as err: - if err.errno in _TRYAGAIN: - pass - elif err.errno in _DISCONNECTED: - return b'' - else: - raise # Unexpected, propagate. - yield from scheduling.block_r(self.sock.fileno()) - - def send(self, data): - """COROUTINE; Send data to the socket, blocking until all written. - - Return True if all went well, False if socket was disconnected. - """ - while data: - try: - n = self.sock.send(data) - except socket.error as err: - if err.errno in _TRYAGAIN: - pass - elif err.errno in _DISCONNECTED: - return False - else: - raise # Unexpected, propagate. - else: - assert 0 <= n <= len(data), (n, len(data)) - if n == len(data): - break - data = data[n:] - continue - yield from scheduling.block_w(self.sock.fileno()) - - return True - - def close(self): - """Close the socket. (Not a coroutine.)""" - self.sock.close() - - -class SslTransport: - """Transport wrapping a socket in SSL. - - The socket must already be connected at the TCP level in - non-blocking mode. - """ - - def __init__(self, rawsock, sslcontext=None): - self.rawsock = rawsock - self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self.sslsock = self.sslcontext.wrap_socket( - self.rawsock, do_handshake_on_connect=False) - - def do_handshake(self): - """COROUTINE: Finish the SSL handshake.""" - while True: - try: - self.sslsock.do_handshake() - except ssl.SSLWantReadError: - yield from scheduling.block_r(self.sslsock.fileno()) - except ssl.SSLWantWriteError: - yield from scheduling.block_w(self.sslsock.fileno()) - else: - break - - def recv(self, n): - """COROUTINE: Read up to n bytes. - - This blocks until at least one byte is read, or until EOF. - """ - while True: - try: - return self.sslsock.recv(n) - except ssl.SSLWantReadError: - yield from scheduling.block_r(self.sslsock.fileno()) - except ssl.SSLWantWriteError: - yield from scheduling.block_w(self.sslsock.fileno()) - except socket.error as err: - if err.errno in _TRYAGAIN: - yield from scheduling.block_r(self.sslsock.fileno()) - elif err.errno in _DISCONNECTED: - # Can this happen? - return b'' - else: - raise # Unexpected, propagate. - - def send(self, data): - """COROUTINE: Send data to the socket, blocking as needed.""" - while data: - try: - n = self.sslsock.send(data) - except ssl.SSLWantReadError: - yield from scheduling.block_r(self.sslsock.fileno()) - except ssl.SSLWantWriteError: - yield from scheduling.block_w(self.sslsock.fileno()) - except socket.error as err: - if err.errno in _TRYAGAIN: - yield from scheduling.block_w(self.sslsock.fileno()) - elif err.errno in _DISCONNECTED: - return False - else: - raise # Unexpected, propagate. - if n == len(data): - break - data = data[n:] - - return True - - def close(self): - """Close the socket. (Not a coroutine.) - - This also closes the raw socket. - """ - self.sslsock.close() - - # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ... - - -class BufferedReader: - """A buffered reader wrapping a transport.""" - - def __init__(self, trans, limit=8192): - self.trans = trans - self.limit = limit - self.buffer = b'' - self.eof = False - - def read(self, n): - """COROUTINE: Read up to n bytes, blocking at most once.""" - assert n >= 0, n - if not self.buffer and not self.eof: - yield from self._fillbuffer(max(n, self.limit)) - return self._getfrombuffer(n) - - def readexactly(self, n): - """COUROUTINE: Read exactly n bytes, or until EOF.""" - blocks = [] - count = 0 - while count < n: - block = yield from self.read(n - count) - if not block: - break - blocks.append(block) - count += len(block) - return b''.join(blocks) - - def readline(self): - """COROUTINE: Read up to newline or limit, whichever comes first.""" - end = self.buffer.find(b'\n') + 1 # Point past newline, or 0. - while not end and not self.eof and len(self.buffer) < self.limit: - anchor = len(self.buffer) - yield from self._fillbuffer(self.limit) - end = self.buffer.find(b'\n', anchor) + 1 - if not end: - end = len(self.buffer) - if end > self.limit: - end = self.limit - return self._getfrombuffer(end) - - def _getfrombuffer(self, n): - """Read up to n bytes without blocking (not a coroutine).""" - if n >= len(self.buffer): - result, self.buffer = self.buffer, b'' - else: - result, self.buffer = self.buffer[:n], self.buffer[n:] - return result - - def _fillbuffer(self, n): - """COROUTINE: Fill buffer with one (up to) n bytes from transport.""" - assert not self.eof, '_fillbuffer called at eof' - data = yield from self.trans.recv(n) - if data: - self.buffer += data - else: - self.eof = True - - -def connect(sock, address): - """COROUTINE: Connect a socket to an address.""" - try: - sock.connect(address) - except socket.error as err: - if err.errno != errno.EINPROGRESS: - raise - yield from scheduling.block_w(sock.fileno()) - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - raise IOError(err, 'Connection refused') - - -def getaddrinfo(host, port, af=0, socktype=0, proto=0): - """COROUTINE: Look up an address and return a list of infos for it. - - Each info is a tuple (af, socktype, protocol, canonname, address). - """ - infos = yield from scheduling.call_in_thread(socket.getaddrinfo, - host, port, af, - socktype, proto) - return infos - - -def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM, proto=0): - """COROUTINE: Look up address and create a socket connected to it.""" - infos = yield from getaddrinfo(host, port, af, socktype, proto) - if not infos: - raise IOError('getaddrinfo() returned an empty list') - exc = None - for af, socktype, proto, cname, address in infos: - sock = None - try: - sock = socket.socket(af, socktype, proto) - sock.setblocking(False) - yield from connect(sock, address) - break - except socket.error as err: - if sock is not None: - sock.close() - if exc is None: - exc = err - else: - raise exc - return sock - - -def create_transport(host, port, af=0, ssl=None): - """COROUTINE: Look up address and create a transport connected to it.""" - if ssl is None: - ssl = (port == 443) - sock = yield from create_connection(host, port, af) - if ssl: - trans = SslTransport(sock) - yield from trans.do_handshake() - else: - trans = SocketTransport(sock) - return trans - - -class Listener: - """Wrapper for a listening socket.""" - - def __init__(self, sock): - self.sock = sock - - def accept(self): - """COROUTINE: Accept a connection.""" - while True: - try: - conn, addr = self.sock.accept() - except socket.error as err: - if err.errno in _TRYAGAIN: - yield from scheduling.block_r(self.sock.fileno()) - else: - raise # Unexpected, propagate. - else: - conn.setblocking(False) - return conn, addr - - -def create_listener(host, port, af=0, socktype=0, proto=0, - backlog=5, reuse_addr=True): - """COROUTINE: Look up address and create a listener for it.""" - infos = yield from getaddrinfo(host, port, af, socktype, proto) - if not infos: - raise IOError('getaddrinfo() returned an empty list') - exc = None - for af, socktype, proto, cname, address in infos: - sock = None - try: - sock = socket.socket(af, socktype, proto) - sock.setblocking(False) - if reuse_addr: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind(address) - sock.listen(backlog) - break - except socket.error as err: - if sock is not None: - sock.close() - if exc is None: - exc = err - else: - raise exc - return Listener(sock) diff --git a/old/transports.py b/old/transports.py deleted file mode 100644 index 19095bf4..00000000 --- a/old/transports.py +++ /dev/null @@ -1,496 +0,0 @@ -"""Transports and Protocols, actually. - -Inspired by Twisted, PEP 3153 and github.com/lvh/async-pep. - -THIS IS NOT REAL CODE! IT IS JUST AN EXPERIMENT. -""" - -# Stdlib imports. -import collections -import errno -import logging -import socket -import ssl -import sys -import time - -# Local imports. -import polling -import scheduling -import sockets - -# Errno values indicating the connection was disconnected. -_DISCONNECTED = frozenset((errno.ECONNRESET, - errno.ENOTCONN, - errno.ESHUTDOWN, - errno.ECONNABORTED, - errno.EPIPE, - errno.EBADF, - )) - -# Errno values indicating the socket isn't ready for I/O just yet. -_TRYAGAIN = frozenset((errno.EAGAIN, errno.EWOULDBLOCK)) - - -class Transport: - """ABC representing a transport. - - There may be many implementations. The user never instantiates - this directly; they call some utility function, passing it a - protocol, and the utility function will call the protocol's - connection_made() method with a transport (or it will call - connection_lost() with an exception if it fails to create the - desired transport). - - The implementation here raises NotImplemented for every method - except writelines(), which calls write() in a loop. - """ - - def write(self, data): - """Write some data (bytes) to the transport. - - This does not block; it buffers the data and arranges for it - to be sent out asynchronously. - """ - raise NotImplementedError - - def writelines(self, list_of_data): - """Write a list (or any iterable) of data (bytes) to the transport. - - The default implementation just calls write() for each item in - the list/iterable. - """ - for data in list_of_data: - self.write(data) - - def close(self): - """Closes the transport. - - Buffered data will be flushed asynchronously. No more data will - be received. When all buffered data is flushed, the protocol's - connection_lost() method is called with None as its argument. - """ - raise NotImplementedError - - def abort(self): - """Closes the transport immediately. - - Buffered data will be lost. No more data will be received. - The protocol's connection_lost() method is called with None as - its argument. - """ - raise NotImplementedError - - def half_close(self): - """Closes the write end after flushing buffered data. - - Data may still be received. - - TODO: What's the use case for this? How to implement it? - Should it call shutdown(SHUT_WR) after all the data is flushed? - Is there no use case for closing the other half first? - """ - raise NotImplementedError - - def pause(self): - """Pause the receiving end. - - No data will be received until resume() is called. - """ - raise NotImplementedError - - def resume(self): - """Resume the receiving end. - - Cancels a pause() call, resumes receiving data. - """ - raise NotImplementedError - - -class Protocol: - """ABC representing a protocol. - - The user should implement this interface. They can inherit from - this class but don't need to. The implementations here do - nothing. - - When the user wants to requests a transport, they pass a protocol - instance to a utility function. - - When the connection is made successfully, connection_made() is - called with a suitable transport object. Then data_received() - will be called 0 or more times with data (bytes) received from the - transport; finally, connection_list() will be called exactly once - with either an exception object or None as an argument. - - If the utility function does not succeed in creating a transport, - it will call connection_lost() with an exception object. - - State machine of calls: - - start -> [CM -> DR*] -> CL -> end - """ - - def connection_made(self, transport): - """Called when a connection is made. - - The argument is the transport representing the connection. - To send data, call its write() or writelines() method. - To receive data, wait for data_received() calls. - When the connection is closed, connection_lost() is called. - """ - - def data_received(self, data): - """Called when some data is received. - - The argument is a bytes object. - - TODO: Should we allow it to be a bytesarray or some other - memory buffer? - """ - - def connection_lost(self, exc): - """Called when the connection is lost or closed. - - Also called when we fail to make a connection at all (in that - case connection_made() will not be called). - - The argument is an exception object or None (the latter - meaning a regular EOF is received or the connection was - aborted or closed). - """ - - -# TODO: The rest is platform specific and should move elsewhere. - -class UnixSocketTransport(Transport): - - def __init__(self, eventloop, protocol, sock): - self._eventloop = eventloop - self._protocol = protocol - self._sock = sock - self._buffer = collections.deque() # For write(). - self._write_closed = False - - def _on_readable(self): - try: - data = self._sock.recv(8192) - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._bad_error(exc) - else: - if not data: - self._eventloop.remove_reader(self._sock.fileno()) - self._sock.close() - self._protocol.connection_lost(None) - else: - self._protocol.data_received(data) # XXX call_soon()? - - def write(self, data): - assert isinstance(data, bytes) - assert not self._write_closed - if not data: - # Silly, but it happens. - return - if self._buffer: - # We've already registered a callback, just buffer the data. - self._buffer.append(data) - # Consider pausing if the total length of the buffer is - # truly huge. - return - - # TODO: Refactor so there's more sharing between this and - # _on_writable(). - - # There's no callback registered yet. It's quite possible - # that the kernel has buffer space for our data, so try to - # write now. Since the socket is non-blocking it will - # give us an error in _TRYAGAIN if it doesn't have enough - # space for even one more byte; it will return the number - # of bytes written if it can write at least one byte. - try: - n = self._sock.send(data) - except socket.error as exc: - # An error. - if exc.errno not in _TRYAGAIN: - self._bad_error(exc) - return - # The kernel doesn't have room for more data right now. - n = 0 - else: - # Wrote at least one byte. - if n == len(data): - # Wrote it all. Done! - if self._write_closed: - self._sock.shutdown(socket.SHUT_WR) - return - # Throw away the data that was already written. - # TODO: Do this without copying the data? - data = data[n:] - self._buffer.append(data) - self._eventloop.add_writer(self._sock.fileno(), self._on_writable) - - def _on_writable(self): - while self._buffer: - data = self._buffer[0] - # TODO: Join small amounts of data? - try: - n = self._sock.send(data) - except socket.error as exc: - # Error handling is the same as in write(). - if exc.errno not in _TRYAGAIN: - self._bad_error(exc) - return - if n < len(data): - self._buffer[0] = data[n:] - return - self._buffer.popleft() - self._eventloop.remove_writer(self._sock.fileno()) - if self._write_closed: - self._sock.shutdown(socket.SHUT_WR) - - def abort(self): - self._bad_error(None) - - def _bad_error(self, exc): - # A serious error. Close the socket etc. - fd = self._sock.fileno() - # TODO: Record whether we have a writer and/or reader registered. - try: - self._eventloop.remove_writer(fd) - except Exception: - pass - try: - self._eventloop.remove_reader(fd) - except Exception: - pass - self._sock.close() - self._protocol.connection_lost(exc) # XXX call_soon()? - - def half_close(self): - self._write_closed = True - - -class UnixSslTransport(Transport): - - # TODO: Refactor Socket and Ssl transport to share some code. - # (E.g. buffering.) - - # TODO: Consider using coroutines instead of callbacks, it seems - # much easier that way. - - def __init__(self, eventloop, protocol, rawsock, sslcontext=None): - self._eventloop = eventloop - self._protocol = protocol - self._rawsock = rawsock - self._sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self._sslsock = self._sslcontext.wrap_socket( - self._rawsock, do_handshake_on_connect=False) - - self._buffer = collections.deque() # For write(). - self._write_closed = False - - # Try the handshake now. Likely it will raise EAGAIN, then it - # will take care of registering the appropriate callback. - self._on_handshake() - - def _bad_error(self, exc): - # A serious error. Close the socket etc. - fd = self._sslsock.fileno() - # TODO: Record whether we have a writer and/or reader registered. - try: - self._eventloop.remove_writer(fd) - except Exception: - pass - try: - self._eventloop.remove_reader(fd) - except Exception: - pass - self._sslsock.close() - self._protocol.connection_lost(exc) # XXX call_soon()? - - def _on_handshake(self): - fd = self._sslsock.fileno() - try: - self._sslsock.do_handshake() - except ssl.SSLWantReadError: - self._eventloop.add_reader(fd, self._on_handshake) - return - except ssl.SSLWantWriteError: - self._eventloop.add_writable(fd, self._on_handshake) - return - # TODO: What if it raises another error? - try: - self._eventloop.remove_reader(fd) - except Exception: - pass - try: - self._eventloop.remove_writer(fd) - except Exception: - pass - self._protocol.connection_made(self) - self._eventloop.add_reader(fd, self._on_ready) - self._eventloop.add_writer(fd, self._on_ready) - - def _on_ready(self): - # Because of renegotiations (?), there's no difference between - # readable and writable. We just try both. XXX This may be - # incorrect; we probably need to keep state about what we - # should do next. - - # Maybe we're already closed... - fd = self._sslsock.fileno() - if fd < 0: - return - - # First try reading. - try: - data = self._sslsock.recv(8192) - except ssl.SSLWantReadError: - pass - except ssl.SSLWantWriteError: - pass - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._bad_error(exc) - return - else: - if data: - self._protocol.data_received(data) - else: - # TODO: Don't close when self._buffer is non-empty. - assert not self._buffer - self._eventloop.remove_reader(fd) - self._eventloop.remove_writer(fd) - self._sslsock.close() - self._protocol.connection_lost(None) - return - - # Now try writing, if there's anything to write. - if not self._buffer: - return - - data = self._buffer[0] - try: - n = self._sslsock.send(data) - except ssl.SSLWantReadError: - pass - except ssl.SSLWantWriteError: - pass - except socket.error as exc: - if exc.errno not in _TRYAGAIN: - self._bad_error(exc) - return - else: - if n == len(data): - self._buffer.popleft() - # Could try again, but let's just have the next callback do it. - else: - self._buffer[0] = data[n:] - - def write(self, data): - assert isinstance(data, bytes) - assert not self._write_closed - if not data: - return - self._buffer.append(data) - # We could optimize, but the callback can do this for now. - - def half_close(self): - self._write_closed = True - # Just set the flag. Calling shutdown() on the ssl socket - # breaks something, causing recv() to return binary data. - - -def make_connection(protocol, host, port=None, af=0, socktype=0, proto=0, - use_ssl=None): - # TODO: Pass in a protocol factory, not a protocol. - # What should be the exact sequence of events? - # - socket - # - transport - # - protocol - # - tell transport about protocol - # - tell protocol about transport - # Or should the latter two be reversed? Does it matter? - if port is None: - port = 443 if use_ssl else 80 - if use_ssl is None: - use_ssl = (port == 443) - if not socktype: - socktype = socket.SOCK_STREAM - eventloop = polling.context.eventloop - - def on_socket_connected(task): - assert not task.alive - if task.exception is not None: - # TODO: Call some callback. - raise task.exception - sock = task.result - assert sock is not None - logging.debug('on_socket_connected') - if use_ssl: - # You can pass an ssl.SSLContext object as use_ssl, - # or a bool. - if isinstance(use_ssl, bool): - sslcontext = None - else: - sslcontext = use_ssl - transport = UnixSslTransport(eventloop, protocol, sock, sslcontext) - else: - transport = UnixSocketTransport(eventloop, protocol, sock) - # TODO: Should the ransport make the following calls? - protocol.connection_made(transport) # XXX call_soon()? - # Don't do this before connection_made() is called. - eventloop.add_reader(sock.fileno(), transport._on_readable) - - coro = sockets.create_connection(host, port, af, socktype, proto) - task = scheduling.Task(coro) - task.add_done_callback(on_socket_connected) - - -def main(): # Testing... - - # Initialize logging. - if '-d' in sys.argv: - level = logging.DEBUG - elif '-v' in sys.argv: - level = logging.INFO - elif '-q' in sys.argv: - level = logging.ERROR - else: - level = logging.WARN - logging.basicConfig(level=level) - - host = 'xkcd.com' - if sys.argv[1:] and '.' in sys.argv[-1]: - host = sys.argv[-1] - - t0 = time.time() - - class TestProtocol(Protocol): - def connection_made(self, transport): - logging.info('Connection made at %.3f secs', time.time() - t0) - self.transport = transport - self.transport.write(b'GET / HTTP/1.0\r\nHost: ' + - host.encode('ascii') + - b'\r\n\r\n') - self.transport.half_close() - def data_received(self, data): - logging.info('Received %d bytes at t=%.3f', - len(data), time.time() - t0) - logging.debug('Received %r', data) - def connection_lost(self, exc): - logging.debug('Connection lost: %r', exc) - self.t1 = time.time() - logging.info('Total time %.3f secs', self.t1 - t0) - - tp = TestProtocol() - logging.debug('tp = %r', tp) - make_connection(tp, host, use_ssl=('-S' in sys.argv)) - logging.info('Running...') - polling.context.eventloop.run() - logging.info('Done.') - - -if __name__ == '__main__': - main() diff --git a/old/xkcd.py b/old/xkcd.py deleted file mode 100755 index 474009d0..00000000 --- a/old/xkcd.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3.3 -"""Minimal synchronous SSL demo, connecting to xkcd.com.""" - -import socket, ssl - -s = socket.socket() -s.connect(('xkcd.com', 443)) -ss = ssl.wrap_socket(s) - -ss.send(b'GET / HTTP/1.0\r\n\r\n') - -while True: - data = ss.recv(1000000) - print(data) - if not data: - break - -ss.close() diff --git a/old/yyftime.py b/old/yyftime.py deleted file mode 100644 index f55234b9..00000000 --- a/old/yyftime.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Compare timing of yield-from vs. yield calls.""" - -import gc -import time - -def coroutine(n): - if n <= 0: - return 1 - l = yield from coroutine(n-1) - r = yield from coroutine(n-1) - return l + 1 + r - -def run_coro(depth): - t0 = time.time() - try: - g = coroutine(depth) - while True: - next(g) - except StopIteration as err: - k = err.value - t1 = time.time() - print('coro', depth, k, round(t1-t0, 6)) - return t1-t0 - -class Future: - - def __init__(self, g): - self.g = g - - def wait(self): - value = None - try: - while True: - f = self.g.send(value) - f.wait() - value = f.value - except StopIteration as err: - self.value = err.value - - - -def task(func): # Decorator - def wrapper(*args): - g = func(*args) - f = Future(g) - return f - return wrapper - -@task -def oldstyle(n): - if n <= 0: - return 1 - l = yield oldstyle(n-1) - r = yield oldstyle(n-1) - return l + 1 + r - -def run_olds(depth): - t0 = time.time() - f = oldstyle(depth) - f.wait() - k = f.value - t1 = time.time() - print('olds', depth, k, round(t1-t0, 6)) - return t1-t0 - -def main(): - gc.disable() - for depth in range(16): - tc = run_coro(depth) - to = run_olds(depth) - if tc: - print('ratio', round(to/tc, 2)) - -if __name__ == '__main__': - main() From 53794d39afc5ce651c29d6b12327e0f356d0b6cb Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 16:32:25 -0700 Subject: [PATCH 0433/1502] fix ParserBuffer.skipuntil() size calculation --- tests/parsers_test.py | 7 +++---- tulip/parsers.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/parsers_test.py b/tests/parsers_test.py index 8a6f2927..083e141c 100644 --- a/tests/parsers_test.py +++ b/tests/parsers_test.py @@ -551,13 +551,12 @@ def test_skipuntil(self): pass self.assertEqual(b'456\n', bytes(buf)) - p = buf.readuntil(b'\n') + p = buf.skipuntil(b'\n') try: next(p) - except StopIteration as exc: - res = exc.value + except StopIteration: + pass self.assertEqual(b'', bytes(buf)) - self.assertEqual(b'456\n', res) def test_lines_parser(self): out = parsers.DataBuffer() diff --git a/tulip/parsers.py b/tulip/parsers.py index 0b599635..689fa4c8 100644 --- a/tulip/parsers.py +++ b/tulip/parsers.py @@ -361,7 +361,7 @@ def skipuntil(self, stop): stop_line = self.find(stop, self.offset) if stop_line >= 0: end = stop_line + stop_len - self.size = self.size - end - self.offset + self.size = self.size - (end - self.offset) self.offset = end return else: From 9d94248857ba1b97d10eb28d7e9068961f8c1503 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 20:20:15 -0700 Subject: [PATCH 0434/1502] Get rid of EventLoop.run(). --- tests/events_test.py | 67 +++++++++++++++++++++-------------- tests/subprocess_test.py | 5 ++- tests/tasks_test.py | 10 +++--- tulip/base_events.py | 27 +++----------- tulip/events.py | 22 ++++++------ tulip/subprocess_transport.py | 15 +++++++- 6 files changed, 78 insertions(+), 68 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 9afdff60..f03023ab 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -152,9 +152,6 @@ def tearDown(self): gc.collect() super().tearDown() - def test_run(self): - self.event_loop.run() # Returns immediately. - def test_run_nesting(self): @tasks.coroutine def coro(): @@ -202,13 +199,14 @@ def test_call_later(self): def callback(arg): results.append(arg) + self.event_loop.stop() self.event_loop.call_later(0.1, callback, 'hello world') t0 = time.monotonic() - self.event_loop.run() + self.event_loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(0.09 <= t1-t0 <= 0.12) def test_call_repeatedly(self): results = [] @@ -218,7 +216,7 @@ def callback(arg): self.event_loop.call_repeatedly(0.03, callback, 'ho') self.event_loop.call_later(0.1, self.event_loop.stop) - self.event_loop.run() + self.event_loop.run_forever() self.assertEqual(results, ['ho', 'ho', 'ho']) def test_call_soon(self): @@ -226,9 +224,10 @@ def test_call_soon(self): def callback(arg1, arg2): results.append((arg1, arg2)) + self.event_loop.stop() self.event_loop.call_soon(callback, 'hello', 'world') - self.event_loop.run() + self.event_loop.run_forever() self.assertEqual(results, [('hello', 'world')]) def test_call_soon_with_handle(self): @@ -236,10 +235,11 @@ def test_call_soon_with_handle(self): def callback(): results.append('yeah') + self.event_loop.stop() handle = events.Handle(callback, ()) self.assertIs(self.event_loop.call_soon(handle), handle) - self.event_loop.run() + self.event_loop.run_forever() self.assertEqual(results, ['yeah']) def test_call_soon_threadsafe(self): @@ -247,15 +247,17 @@ def test_call_soon_threadsafe(self): def callback(arg): results.append(arg) + if len(results) >= 2: + self.event_loop.stop() - def run(): + def run_in_thread(): self.event_loop.call_soon_threadsafe(callback, 'hello') - t = threading.Thread(target=run) + t = threading.Thread(target=run_in_thread) self.event_loop.call_later(0.1, callback, 'world') t0 = time.monotonic() t.start() - self.event_loop.run() + self.event_loop.run_forever() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) @@ -266,10 +268,12 @@ def test_call_soon_threadsafe_same_thread(self): def callback(arg): results.append(arg) + if len(results) >= 2: + self.event_loop.stop() self.event_loop.call_later(0.1, callback, 'world') self.event_loop.call_soon_threadsafe(callback, 'hello') - self.event_loop.run() + self.event_loop.run_forever() self.assertEqual(results, ['hello', 'world']) def test_call_soon_threadsafe_with_handle(self): @@ -277,6 +281,8 @@ def test_call_soon_threadsafe_with_handle(self): def callback(arg): results.append(arg) + if len(results) >= 2: + self.event_loop.stop() handle = events.Handle(callback, ('hello',)) @@ -289,7 +295,7 @@ def run(): t0 = time.monotonic() t.start() - self.event_loop.run() + self.event_loop.run_forever() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) @@ -343,7 +349,8 @@ def reader(): self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') self.event_loop.call_later(0.15, w.close) - self.event_loop.run() + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_with_handle(self): @@ -369,7 +376,8 @@ def reader(): self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') self.event_loop.call_later(0.15, w.close) - self.event_loop.run() + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_reader_callback_cancel(self): @@ -392,7 +400,8 @@ def reader(): self.event_loop.call_later(0.05, w.send, b'abc') self.event_loop.call_later(0.1, w.send, b'def') self.event_loop.call_later(0.15, w.close) - self.event_loop.run() + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): @@ -404,7 +413,8 @@ def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) self.event_loop.call_later(0.1, remove_writer) - self.event_loop.run() + self.event_loop.call_later(0.11, self.event_loop.stop) + self.event_loop.run_forever() w.close() data = r.recv(256*1024) r.close() @@ -420,7 +430,8 @@ def remove_writer(): self.assertTrue(self.event_loop.remove_writer(w.fileno())) self.event_loop.call_later(0.1, remove_writer) - self.event_loop.run() + self.event_loop.call_later(0.11, self.event_loop.stop) + self.event_loop.run_forever() w.close() data = r.recv(256*1024) r.close() @@ -433,9 +444,10 @@ def test_writer_callback_cancel(self): def sender(): w.send(b'x'*256) handle.cancel() + self.event_loop.stop() handle = self.event_loop.add_writer(w.fileno(), sender) - self.event_loop.run() + self.event_loop.run_forever() w.close() data = r.recv(1024) r.close() @@ -589,11 +601,13 @@ def my_handler(*args): def test_create_connection(self): with test_utils.run_test_server(self.event_loop) as httpd: - f = self.event_loop.create_connection(MyProto, *httpd.address) + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), + *httpd.address) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run() + self.event_loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) def test_create_connection_sock(self): @@ -615,11 +629,12 @@ def test_create_connection_sock(self): else: assert False, 'Can not create socket.' - f = self.event_loop.create_connection(MyProto, sock=sock) + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), sock=sock) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run() + self.event_loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') @@ -627,14 +642,14 @@ def test_create_ssl_connection(self): with test_utils.run_test_server( self.event_loop, use_ssl=True) as httpd: f = self.event_loop.create_connection( - MyProto, *httpd.address, ssl=True) + lambda: MyProto(create_future=True), *httpd.address, ssl=True) tr, pr = self.event_loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) self.assertTrue( hasattr(tr.get_extra_info('socket'), 'getsockname')) - self.event_loop.run() + self.event_loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) def test_create_connection_host_port_sock(self): @@ -1340,8 +1355,6 @@ class AbstractEventLoopTests(unittest.TestCase): def test_not_imlemented(self): f = unittest.mock.Mock() ev_loop = events.AbstractEventLoop() - self.assertRaises( - NotImplementedError, ev_loop.run) self.assertRaises( NotImplementedError, ev_loop.run_forever) self.assertRaises( diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py index 9d1ff93a..28ab6623 100644 --- a/tests/subprocess_test.py +++ b/tests/subprocess_test.py @@ -8,6 +8,7 @@ import unittest from tulip import events +from tulip import futures from tulip import protocols from tulip import subprocess_transport @@ -17,6 +18,7 @@ class MyProto(protocols.Protocol): def __init__(self): self.state = 'INITIAL' self.nbytes = 0 + self.done = futures.Future() def connection_made(self, transport): self.transport = transport @@ -37,6 +39,7 @@ def eof_received(self): def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' + self.done.set_result(None) class FutureTests(unittest.TestCase): @@ -51,7 +54,7 @@ def tearDown(self): def test_unix_subprocess(self): p = MyProto() subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) - self.event_loop.run() + self.event_loop.run_until_complete(p.done) if __name__ == '__main__': diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 57f5bede..583674e8 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -33,7 +33,7 @@ def test_task_class(self): def notmuch(): return 'ok' t = tasks.Task(notmuch()) - self.event_loop.run() + self.event_loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t._event_loop, self.event_loop) @@ -48,7 +48,7 @@ def notmuch(): yield from [] return 'ko' t = notmuch() - self.event_loop.run() + self.event_loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -57,7 +57,7 @@ def test_task_decorator_func(self): def notmuch(): return 'ko' t = notmuch() - self.event_loop.run() + self.event_loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -69,7 +69,7 @@ def test_task_decorator_fut(self): def notmuch(): return fut t = notmuch() - self.event_loop.run() + self.event_loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -481,7 +481,7 @@ def sleeper(dt, arg): t = tasks.Task(sleeper(0.1, 'yeah')) t0 = time.monotonic() - self.event_loop.run() + self.event_loop.run_until_complete(t) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.09) self.assertTrue(t.done()) diff --git a/tulip/base_events.py b/tulip/base_events.py index ca83a56c..81b01253 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -94,23 +94,16 @@ def is_running(self): """Returns running status of event loop.""" return self._running - def run(self): - """Run the event loop until nothing left to do or stop() called. - - This keeps going as long as there are either readable and - writable file descriptors, or scheduled callbacks (of either - variety). + def run_forever(self): + """Run until stop() is called. - TODO: Give this a timeout too? + TODO: Maybe rename to run(). """ if self._running: raise RuntimeError('Event loop is running.') - self._running = True try: - while (self._ready or - self._scheduled or - self._selector.registered_count() > 1): + while True: try: self._run_once() except _StopError: @@ -118,18 +111,6 @@ def run(self): finally: self._running = False - def run_forever(self): - """Run until stop() is called. - - This only makes sense over run() if you have another thread - scheduling callbacks using call_soon_threadsafe(). - """ - handle = self.call_repeatedly(24*3600, lambda: None) - try: - self.run() - finally: - handle.cancel() - def run_once(self, timeout=0): """Run through all callbacks and all I/O polls once. diff --git a/tulip/events.py b/tulip/events.py index 63e4b790..e615d06c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -115,21 +115,21 @@ def __ne__(self, other): class AbstractEventLoop: """Abstract event loop.""" - # TODO: Rename run() -> run_until_idle(), run_forever() -> run(). + def run_forever(self): + """Run the event loop until stop() is called. - def run(self): - """Run the event loop. Block until there is nothing left to do.""" + TODO: Rename to run(). + """ raise NotImplementedError - def run_forever(self): - """Run the event loop. Block until stop() is called.""" - raise NotImplementedError + def run_once(self, timeout=None): + """Run one complete cycle of the event loop. - def run_once(self, timeout=None): # NEW! - """Run one complete cycle of the event loop.""" + TODO: Deprecate this. + """ raise NotImplementedError - def run_until_complete(self, future, timeout=None): # NEW! + def run_until_complete(self, future, timeout=None): """Run the event loop until a Future is done. Return the Future's result, or raise its exception. @@ -140,7 +140,7 @@ def run_until_complete(self, future, timeout=None): # NEW! """ raise NotImplementedError - def stop(self): # NEW! + def stop(self): """Stop the event loop as soon as reasonable. Exactly how soon that is may depend on the implementation, but @@ -153,7 +153,7 @@ def stop(self): # NEW! def call_later(self, delay, callback, *args): raise NotImplementedError - def call_repeatedly(self, interval, callback, *args): # NEW! + def call_repeatedly(self, interval, callback, *args): raise NotImplementedError def call_soon(self, callback, *args): diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index 5e4d6550..e790d285 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -82,11 +82,12 @@ def write_eof(self): self._event_loop.remove_writer(self._wstdin) os.close(self._wstdin) self._wstdin = -1 + self._maybe_cleanup() def close(self): if not self._eof: self.write_eof() - # XXX What else? + self._maybe_cleanup() def _fatal_error(self, exc): tulip_log.error('Fatal error: %r', exc) @@ -98,6 +99,16 @@ def _fatal_error(self, exc): self._wstdin = -1 self._eof = True self._buffer = None + self._maybe_cleanup(exc) + + _conn_lost_called = False + + def _maybe_cleanup(self, exc=None): + if (self._wstdin < 0 and + self._rstdout < 0 and + not self._conn_lost_called): + self._conn_lost_called = True + self._event_loop.call_soon(self._protocol.connection_lost, exc) def _stdin_callback(self): data = b''.join(self._buffer) @@ -116,6 +127,7 @@ def _stdin_callback(self): if self._eof: os.close(self._wstdin) self._wstdin = -1 + self._maybe_cleanup() return elif n > 0: @@ -136,6 +148,7 @@ def _stdout_callback(self): os.close(self._rstdout) self._rstdout = -1 self._event_loop.call_soon(self._protocol.eof_received) + self._maybe_cleanup() def _setnonblocking(fd): From c5af1bf2045032e203ce462a389a9d18ee44ab76 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 20:41:35 -0700 Subject: [PATCH 0435/1502] tcp echo example; better cmd arguments for udp echo example --- examples/mpsrv.py | 0 examples/tcp_echo.py | 111 +++++++++++++++++++++++++++++++++++++++++++ examples/udp_echo.py | 75 +++++++++++++++++++---------- examples/wssrv.py | 0 4 files changed, 162 insertions(+), 24 deletions(-) mode change 100644 => 100755 examples/mpsrv.py create mode 100755 examples/tcp_echo.py mode change 100644 => 100755 examples/udp_echo.py mode change 100644 => 100755 examples/wssrv.py diff --git a/examples/mpsrv.py b/examples/mpsrv.py old mode 100644 new mode 100755 diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..16e3fb65 --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost') + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/udp_echo.py b/examples/udp_echo.py old mode 100644 new mode 100755 index 5d1e02ec..d7bde29a --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -1,15 +1,13 @@ -"""UDP echo example. - -Start server: - - >> python ./udp_echo.py --server - -""" - +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse import sys import tulip - -ADDRESS = ('127.0.0.1', 10000) +import logging +try: + import signal +except ImportError: + signal = None class MyServerUdpEchoProtocol: @@ -31,7 +29,7 @@ def connection_lost(self, exc): class MyClientUdpEchoProtocol: - message = 'This is the message. It will be repeated.' + message = 'This is the message. It will be echoed.' def connection_made(self, transport): self.transport = transport @@ -52,22 +50,51 @@ def connection_lost(self, exc): loop.stop() -def start_server(): - loop = tulip.get_event_loop() - tulip.Task(loop.create_datagram_endpoint( - MyServerUdpEchoProtocol, local_addr=ADDRESS)) - loop.run_forever() +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) -def start_client(): - loop = tulip.get_event_loop() - tulip.Task(loop.create_datagram_endpoint( - MyClientUdpEchoProtocol, remote_addr=ADDRESS)) - loop.run_forever() +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') if __name__ == '__main__': - if '--server' in sys.argv: - start_server() + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() else: - start_client() + loop = tulip.get_event_loop() + loop.set_log_level(logging.CRITICAL) + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py old mode 100644 new mode 100755 From 3d7a38a77c130ac35a23cf0322233afbf3c7b49e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 24 Apr 2013 21:26:39 -0700 Subject: [PATCH 0436/1502] Make time range check more lenient. --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index f03023ab..d68436e5 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -192,7 +192,7 @@ def run(): t1 = time.monotonic() t.join() self.assertTrue(called) - self.assertTrue(0.09 < t1-t0 <= 0.12) + self.assertTrue(0.09 < t1-t0 <= 0.15) def test_call_later(self): results = [] From 34e5f627c539efa5145be48b3ef2af250d9944cb Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 24 Apr 2013 21:37:42 -0700 Subject: [PATCH 0437/1502] protocol parser example --- examples/tcp_protocol_parser.py | 172 ++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100755 examples/tcp_protocol_parser.py diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..f05518b1 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import logging +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + loop.set_log_level(logging.CRITICAL) + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) From 2863c7276a1638b00236ce9079c8d6bcaf5a57d8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 25 Apr 2013 11:25:21 -0700 Subject: [PATCH 0438/1502] Allow V as alias for VERBOSE, e.g. "make V=0". --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 274da4c8..bca4ed9b 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,8 @@ # Some simple testing tasks (sorry, UNIX only). PYTHON=python3 -VERBOSE=1 +VERBOSE=$(V) +V= 1 FLAGS= test: From f09e7fb0c5b92f8dcd0bafd79736385210594486 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 25 Apr 2013 12:16:57 -0700 Subject: [PATCH 0439/1502] add closing logic to _ProactorSocketTransport._loop_writing() --- tests/proactor_events_test.py | 28 ++++++++++++++++++++++++---- tulip/proactor_events.py | 5 +++-- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 7b4d5aa3..8f26eb4c 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -152,6 +152,20 @@ def test_loop_writing_stop(self): tr._loop_writing(fut) self.assertIsNone(tr._write_fut) + def test_loop_writing_closing(self): + fut = tulip.Future() + fut.set_result(1) + + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + self.event_loop.reset_mock() + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + def test_abort(self): tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) @@ -162,15 +176,21 @@ def test_abort(self): def test_close(self): tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) - tr._write_fut = unittest.mock.Mock() + self.event_loop.reset_mock() tr.close() - - tr._write_fut.cancel.assert_called_with() self.event_loop.call_soon.assert_called_with( tr._call_connection_lost, None) self.assertTrue(tr._closing) - def test_close_2(self): + def test_close_write_fut(self): + tr = _ProactorSocketTransport( + self.event_loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + self.event_loop.reset_mock() + tr.close() + self.assertFalse(self.event_loop.call_soon.called) + + def test_close_buffer(self): tr = _ProactorSocketTransport( self.event_loop, self.sock, self.protocol) tr._buffer = [b'data'] diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index ba1389d2..751f8ee1 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -77,6 +77,9 @@ def _loop_writing(self, f=None): self._buffer = [] if not data: self._write_fut = None + if self._closing: + self._event_loop.call_soon( + self._call_connection_lost, None) return self._write_fut = self._event_loop._proactor.send(self._sock, data) except OSError as exc: @@ -92,8 +95,6 @@ def abort(self): def close(self): self._closing = True - if self._write_fut: - self._write_fut.cancel() if not self._buffer: self._event_loop.call_soon(self._call_connection_lost, None) From dd4960655efa8b0c658d0f82864efd8a369b81f5 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 25 Apr 2013 16:43:39 -0700 Subject: [PATCH 0440/1502] do not schedule _call_connection_lost if we have active _write_fut --- tulip/proactor_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 751f8ee1..e7c31728 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -95,7 +95,7 @@ def abort(self): def close(self): self._closing = True - if not self._buffer: + if not self._buffer and self._write_fut is None: self._event_loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): From cd6bd06e4b355d8eee4bf07e4dc1c37957ff65ab Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 25 Apr 2013 16:47:33 -0700 Subject: [PATCH 0441/1502] skip ConnectionResetError during transport closing --- tests/events_test.py | 2 -- tulip/proactor_events.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index d68436e5..47cf4e82 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -850,7 +850,6 @@ def connection_made(self, transport): client.send(b'xxx') proto = self.event_loop.run_until_complete(f_proto) proto.transport.close() - self.event_loop.run_once() # windows, issue #35 client.close() f_proto = futures.Future() @@ -859,7 +858,6 @@ def connection_made(self, transport): client.send(b'xxx') proto = self.event_loop.run_until_complete(f_proto) proto.transport.close() - self.event_loop.run_once() # windows, issue #35 client.close() for s in socks: diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index e7c31728..e45109c6 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -41,7 +41,7 @@ def _loop_reading(self, fut=None): return self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) - except ConnectionAbortedError as exc: + except (ConnectionAbortedError, ConnectionResetError) as exc: if not self._closing: self._fatal_error(exc) except OSError as exc: From 1b0b15c5ced07c3cf33b606be06670e3ea9820de Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 25 Apr 2013 17:55:44 -0700 Subject: [PATCH 0442/1502] replace event_loop with loop in proactor --- tests/proactor_events_test.py | 188 ++++++++++++++-------------------- tulip/proactor_events.py | 21 ++-- tulip/windows_events.py | 2 +- 3 files changed, 90 insertions(+), 121 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 8f26eb4c..959cd2ba 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -12,26 +12,24 @@ class ProactorSocketTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock() + self.loop = unittest.mock.Mock() self.sock = unittest.mock.Mock(socket.socket) self.protocol = unittest.mock.Mock(tulip.Protocol) def test_ctor(self): fut = tulip.Future() tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol, fut) - self.event_loop.call_soon.mock_calls[0].assert_called_with( - tr._loop_reading) - self.event_loop.call_soon.mock_calls[1].assert_called_with( + self.loop, self.sock, self.protocol, fut) + self.loop.call_soon.mock_calls[0].assert_called_with(tr._loop_reading) + self.loop.call_soon.mock_calls[1].assert_called_with( self.protocol.connection_made, tr) - self.event_loop.call_soon.mock_calls[2].assert_called_with( + self.loop.call_soon.mock_calls[2].assert_called_with( fut.set_result, None) def test_loop_reading(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._loop_reading() - self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) self.assertFalse(self.protocol.data_received.called) self.assertFalse(self.protocol.eof_received.called) @@ -39,75 +37,65 @@ def test_loop_reading_data(self): res = tulip.Future() res.set_result(b'data') - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._read_fut = res tr._loop_reading(res) - self.event_loop._proactor.recv.assert_called_with(self.sock, 4096) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) self.protocol.data_received.assert_called_with(b'data') def test_loop_reading_no_data(self): res = tulip.Future() res.set_result(b'') - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) self.assertRaises(AssertionError, tr._loop_reading, res) tr._read_fut = res tr._loop_reading(res) - self.assertFalse(self.event_loop._proactor.recv.called) + self.assertFalse(self.loop._proactor.recv.called) self.assertTrue(self.protocol.eof_received.called) def test_loop_reading_aborted(self): - err = self.event_loop._proactor.recv.side_effect = ( - ConnectionAbortedError()) + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with(err) def test_loop_reading_aborted_closing(self): - self.event_loop._proactor.recv.side_effect = ( - ConnectionAbortedError()) + self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._closing = True tr._fatal_error = unittest.mock.Mock() tr._loop_reading() self.assertFalse(tr._fatal_error.called) def test_loop_reading_exception(self): - err = self.event_loop._proactor.recv.side_effect = (OSError()) + err = self.loop._proactor.recv.side_effect = (OSError()) - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with(err) def test_write(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._loop_writing = unittest.mock.Mock() tr.write(b'data') self.assertEqual(tr._buffer, [b'data']) self.assertTrue(tr._loop_writing.called) def test_write_no_data(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr.write(b'') self.assertFalse(tr._buffer) def test_write_more(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._write_fut = unittest.mock.Mock() tr._loop_writing = unittest.mock.Mock() tr.write(b'data') @@ -115,19 +103,17 @@ def test_write_more(self): self.assertFalse(tr._loop_writing.called) def test_loop_writing(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'da', b'ta'] tr._loop_writing() - self.event_loop._proactor.send.assert_called_with(self.sock, b'data') - self.event_loop._proactor.send.return_value.add_done_callback.\ + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_loop_writing_err(self, m_log): - err = self.event_loop._proactor.send.side_effect = OSError() - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr._buffer = [b'da', b'ta'] tr._loop_writing() @@ -146,8 +132,7 @@ def test_loop_writing_stop(self): fut = tulip.Future() fut.set_result(b'data') - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._write_fut = fut tr._loop_writing(fut) self.assertIsNone(tr._write_fut) @@ -156,53 +141,44 @@ def test_loop_writing_closing(self): fut = tulip.Future() fut.set_result(1) - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + self.loop.reset_mock() tr._write_fut = fut tr.close() tr._loop_writing(fut) self.assertIsNone(tr._write_fut) - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) def test_abort(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr.abort() tr._fatal_error.assert_called_with(None) def test_close(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + self.loop.reset_mock() tr.close() - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertTrue(tr._closing) def test_close_write_fut(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._write_fut = unittest.mock.Mock() - self.event_loop.reset_mock() + self.loop.reset_mock() tr.close() - self.assertFalse(self.event_loop.call_soon.called) + self.assertFalse(self.loop.call_soon.called) def test_close_buffer(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] - self.event_loop.reset_mock() + self.loop.reset_mock() tr.close() - - self.assertFalse(self.event_loop.call_soon.called) + self.assertFalse(self.loop.call_soon.called) @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error(self, m_logging): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] read_fut = tr._read_fut = unittest.mock.Mock() write_fut = tr._write_fut = unittest.mock.Mock() @@ -210,24 +186,20 @@ def test_fatal_error(self, m_logging): read_fut.cancel.assert_called_with() write_fut.cancel.assert_called_with() - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertEqual([], tr._buffer) @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error_2(self, m_logging): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] tr._fatal_error(None) - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertEqual([], tr._buffer) def test_call_connection_lost(self): - tr = _ProactorSocketTransport( - self.event_loop, self.sock, self.protocol) + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._call_connection_lost(None) self.assertTrue(self.protocol.connection_lost.called) self.assertTrue(self.sock.close.called) @@ -245,52 +217,52 @@ class EventLoop(BaseProactorEventLoop): def _socketpair(s): return (self.ssock, self.csock) - self.event_loop = EventLoop(self.proactor) + self.loop = EventLoop(self.proactor) @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') def test_ctor(self, socketpair, call_soon): ssock, csock = socketpair.return_value = ( unittest.mock.Mock(), unittest.mock.Mock()) - event_loop = BaseProactorEventLoop(self.proactor) - self.assertIs(event_loop._ssock, ssock) - self.assertIs(event_loop._csock, csock) - self.assertEqual(event_loop._internal_fds, 1) - call_soon.assert_called_with(event_loop._loop_self_reading) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) def test_close_self_pipe(self): - self.event_loop._close_self_pipe() - self.assertEqual(self.event_loop._internal_fds, 0) + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) self.assertTrue(self.ssock.close.called) self.assertTrue(self.csock.close.called) - self.assertIsNone(self.event_loop._ssock) - self.assertIsNone(self.event_loop._csock) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) def test_close(self): - self.event_loop._close_self_pipe = unittest.mock.Mock() - self.event_loop.close() - self.assertTrue(self.event_loop._close_self_pipe.called) + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) self.assertTrue(self.proactor.close.called) - self.assertIsNone(self.event_loop._proactor) + self.assertIsNone(self.loop._proactor) - self.event_loop._close_self_pipe.reset_mock() - self.event_loop.close() - self.assertFalse(self.event_loop._close_self_pipe.called) + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) def test_sock_recv(self): - self.event_loop.sock_recv(self.sock, 1024) + self.loop.sock_recv(self.sock, 1024) self.proactor.recv.assert_called_with(self.sock, 1024) def test_sock_sendall(self): - self.event_loop.sock_sendall(self.sock, b'data') + self.loop.sock_sendall(self.sock, b'data') self.proactor.send.assert_called_with(self.sock, b'data') def test_sock_connect(self): - self.event_loop.sock_connect(self.sock, 123) + self.loop.sock_connect(self.sock, 123) self.proactor.connect.assert_called_with(self.sock, 123) def test_sock_accept(self): - self.event_loop.sock_accept(self.sock) + self.loop.sock_accept(self.sock) self.proactor.accept.assert_called_with(self.sock) def test_socketpair(self): @@ -298,43 +270,42 @@ def test_socketpair(self): NotImplementedError, BaseProactorEventLoop, self.proactor) def test_make_socket_transport(self): - tr = self.event_loop._make_socket_transport( - self.sock, unittest.mock.Mock()) + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) self.assertIsInstance(tr, _ProactorSocketTransport) def test_loop_self_reading(self): - self.event_loop._loop_self_reading() + self.loop._loop_self_reading() self.proactor.recv.assert_called_with(self.ssock, 4096) self.proactor.recv.return_value.add_done_callback.assert_called_with( - self.event_loop._loop_self_reading) + self.loop._loop_self_reading) def test_loop_self_reading_fut(self): fut = unittest.mock.Mock() - self.event_loop._loop_self_reading(fut) + self.loop._loop_self_reading(fut) self.assertTrue(fut.result.called) self.proactor.recv.assert_called_with(self.ssock, 4096) self.proactor.recv.return_value.add_done_callback.assert_called_with( - self.event_loop._loop_self_reading) + self.loop._loop_self_reading) def test_loop_self_reading_exception(self): - self.event_loop.close = unittest.mock.Mock() + self.loop.close = unittest.mock.Mock() self.proactor.recv.side_effect = OSError() - self.assertRaises(OSError, self.event_loop._loop_self_reading) - self.assertTrue(self.event_loop.close.called) + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) def test_write_to_self(self): - self.event_loop._write_to_self() + self.loop._write_to_self() self.csock.send.assert_called_with(b'x') def test_process_events(self): - self.event_loop._process_events([]) + self.loop._process_events([]) @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_start_serving(self, m_log): pf = unittest.mock.Mock() - call_soon = self.event_loop.call_soon = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() - self.event_loop._start_serving(pf, self.sock) + self.loop._start_serving(pf, self.sock) self.assertTrue(call_soon.called) # callback @@ -346,11 +317,10 @@ def test_start_serving(self, m_log): fut = unittest.mock.Mock() fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) - make_transport = self.event_loop._make_socket_transport = ( - unittest.mock.Mock()) + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() loop(fut) self.assertTrue(fut.result.called) - self.assertTrue(make_transport.called) + self.assertTrue(make_tr.called) # exception fut.result.side_effect = OSError() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index e45109c6..cc889b55 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -12,10 +12,10 @@ class _ProactorSocketTransport(transports.Transport): - def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(extra) self._extra['socket'] = sock - self._event_loop = event_loop + self._loop = loop self._sock = sock self._protocol = protocol self._buffer = [] @@ -23,10 +23,10 @@ def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): self._write_fut = None self._conn_lost = 0 self._closing = False # Set when close() called. - self._event_loop.call_soon(self._protocol.connection_made, self) - self._event_loop.call_soon(self._loop_reading) + self._loop.call_soon(self._protocol.connection_made, self) + self._loop.call_soon(self._loop_reading) if waiter is not None: - self._event_loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter.set_result, None) def _loop_reading(self, fut=None): data = None @@ -40,7 +40,7 @@ def _loop_reading(self, fut=None): self._read_fut = None return - self._read_fut = self._event_loop._proactor.recv(self._sock, 4096) + self._read_fut = self._loop._proactor.recv(self._sock, 4096) except (ConnectionAbortedError, ConnectionResetError) as exc: if not self._closing: self._fatal_error(exc) @@ -78,10 +78,9 @@ def _loop_writing(self, f=None): if not data: self._write_fut = None if self._closing: - self._event_loop.call_soon( - self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None) return - self._write_fut = self._event_loop._proactor.send(self._sock, data) + self._write_fut = self._loop._proactor.send(self._sock, data) except OSError as exc: self._conn_lost += 1 self._fatal_error(exc) @@ -96,7 +95,7 @@ def abort(self): def close(self): self._closing = True if not self._buffer and self._write_fut is None: - self._event_loop.call_soon(self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) @@ -106,7 +105,7 @@ def _fatal_error(self, exc): self._read_fut.cancel() self._write_fut = self._read_fut = None self._buffer = [] - self._event_loop.call_soon(self._call_connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 2ec8561c..9feaa543 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -75,7 +75,7 @@ def accept(self, listener): ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(): - addr = ov.getresult() + ov.getresult() buf = struct.pack('@P', listener.fileno()) conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_ACCEPT_CONTEXT, From f1bfb26d5383b119c668decce9d14f24cce6b583 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 26 Apr 2013 10:39:08 -0700 Subject: [PATCH 0443/1502] ssl socket transport tests refactoring --- tests/selector_events_test.py | 232 +++++++++++++++++++--------------- tulip/selector_events.py | 36 +++--- 2 files changed, 151 insertions(+), 117 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index e0227e6a..9cde8b94 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -822,7 +822,7 @@ def test_connection_lost(self): class SelectorSslTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 self.protocol = unittest.mock.Mock(spec_set=Protocol) @@ -830,217 +830,251 @@ def setUp(self): self.sslsock.fileno.return_value = 1 self.sslcontext = unittest.mock.Mock() self.sslcontext.wrap_socket.return_value = self.sslsock - self.waiter = futures.Future() - self.transport = _SelectorSslTransport( - self.event_loop, self.sock, - self.protocol, self.sslcontext, self.waiter) - self.event_loop.reset_mock() + def _make_one(self, create_waiter=None): + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) + self.loop.reset_mock() self.sock.reset_mock() self.protocol.reset_mock() + self.sslsock.reset_mock() self.sslcontext.reset_mock() + return transport def test_on_handshake(self): - self.transport._on_handshake() + tr = self._make_one() + tr._waiter = futures.Future() + tr._on_handshake() self.assertTrue(self.sslsock.do_handshake.called) - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertEqual( - (1, self.transport._on_ready,), - self.event_loop.add_reader.call_args[0]) - self.assertEqual( - (1, self.transport._on_ready,), - self.event_loop.add_writer.call_args[0]) + self.assertTrue(self.loop.remove_reader.called) + self.assertTrue(self.loop.remove_writer.called) + self.assertEqual((1, tr._on_ready,), + self.loop.add_reader.call_args[0]) + self.assertEqual((1, tr._on_ready,), + self.loop.add_writer.call_args[0]) + self.assertEqual((tr._waiter.set_result, None), + self.loop.call_soon.call_args[0]) + tr._waiter.cancel() def test_on_handshake_reader_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError - self.transport._on_handshake() - self.assertEqual( - (1, self.transport._on_handshake,), - self.event_loop.add_reader.call_args[0]) + transport = self._make_one() + transport._on_handshake() + self.assertEqual((1, transport._on_handshake,), + self.loop.add_reader.call_args[0]) def test_on_handshake_writer_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError - self.transport._on_handshake() - self.assertEqual( - (1, self.transport._on_handshake,), - self.event_loop.add_writer.call_args[0]) + transport = self._make_one() + transport._on_handshake() + self.assertEqual((1, transport._on_handshake,), + self.loop.add_writer.call_args[0]) def test_on_handshake_exc(self): - self.sslsock.do_handshake.side_effect = ValueError - self.transport._on_handshake() + exc = ValueError() + self.sslsock.do_handshake.side_effect = exc + transport = self._make_one() + transport._waiter = futures.Future() + transport._on_handshake() self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) def test_on_handshake_base_exc(self): - self.sslsock.do_handshake.side_effect = BaseException - self.assertRaises(BaseException, self.transport._on_handshake) + transport = self._make_one() + transport._waiter = futures.Future() + exc = BaseException() + self.sslsock.do_handshake.side_effect = exc + self.assertRaises(BaseException, transport._on_handshake) self.assertTrue(self.sslsock.close.called) + self.assertTrue(transport._waiter.done()) + self.assertIs(exc, transport._waiter.exception()) def test_write_no_data(self): - self.transport._buffer.append(b'data') - self.transport.write(b'') - self.assertEqual([b'data'], self.transport._buffer) + transport = self._make_one() + transport._buffer.append(b'data') + transport.write(b'') + self.assertEqual([b'data'], transport._buffer) def test_write_str(self): - self.assertRaises(AssertionError, self.transport.write, 'str') + transport = self._make_one() + self.assertRaises(AssertionError, transport.write, 'str') def test_write_closing(self): - self.transport.close() - self.assertRaises(AssertionError, self.transport.write, b'data') + transport = self._make_one() + transport.close() + self.assertRaises(AssertionError, transport.write, b'data') @unittest.mock.patch('tulip.selector_events.tulip_log') def test_write_exception(self, m_log): - self.transport._conn_lost = 1 - self.transport.write(b'data') - self.assertEqual(self.transport._buffer, []) - self.transport.write(b'data') - self.transport.write(b'data') - self.transport.write(b'data') - self.transport.write(b'data') + transport = self._make_one() + transport._conn_lost = 1 + transport.write(b'data') + self.assertEqual(transport._buffer, []) + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') + transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') def test_abort(self): - self.transport._close = unittest.mock.Mock() - self.transport.abort() - self.transport._close.assert_called_with(None) + transport = self._make_one() + transport._close = unittest.mock.Mock() + transport.abort() + transport._close.assert_called_with(None) @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error(self, m_exc): exc = OSError() - self.transport._buffer.append(b'data') - self.transport._fatal_error(exc) + transport = self._make_one() + transport._buffer.append(b'data') + transport._fatal_error(exc) - self.assertEqual([], self.transport._buffer) - self.assertTrue(self.event_loop.remove_writer.called) - self.assertTrue(self.event_loop.remove_reader.called) - self.event_loop.call_soon.assert_called_with( + self.assertEqual([], transport._buffer) + self.assertTrue(self.loop.remove_writer.called) + self.assertTrue(self.loop.remove_reader.called) + self.loop.call_soon.assert_called_with( self.protocol.connection_lost, exc) - m_exc.assert_called_with('Fatal error for %s', self.transport) + m_exc.assert_called_with('Fatal error for %s', transport) def test_close(self): - self.transport.close() - self.assertTrue(self.transport._closing) - self.assertTrue(self.event_loop.remove_reader.called) - self.event_loop.call_soon.assert_called_with( + transport = self._make_one() + transport.close() + self.assertTrue(transport._closing) + self.assertTrue(self.loop.remove_reader.called) + self.loop.call_soon.assert_called_with( self.protocol.connection_lost, None) - def test_close_write_buffer(self): - self.transport._buffer.append(b'data') - self.transport.close() + def test_close_write_buffer1(self): + transport = self._make_one() + transport._buffer.append(b'data') + transport.close() - self.assertTrue(self.event_loop.remove_reader.called) - self.assertFalse(self.event_loop.call_soon.called) + self.assertTrue(self.loop.remove_reader.called) + self.assertFalse(self.loop.call_soon.called) def test_on_ready_closed(self): self.sslsock.fileno.return_value = -1 - self.transport._on_ready() + transport = self._make_one() + transport._on_ready() self.assertFalse(self.sslsock.recv.called) def test_on_ready_recv(self): self.sslsock.recv.return_value = b'data' - self.transport._on_ready() + transport = self._make_one() + transport._on_ready() self.assertTrue(self.sslsock.recv.called) self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) def test_on_ready_recv_eof(self): self.sslsock.recv.return_value = b'' - self.transport._on_ready() - self.assertTrue(self.event_loop.remove_reader.called) - self.assertTrue(self.event_loop.remove_writer.called) + transport = self._make_one() + transport._on_ready() + self.assertTrue(self.loop.remove_reader.called) + self.assertTrue(self.loop.remove_writer.called) self.assertTrue(self.sslsock.close.called) self.protocol.connection_lost.assert_called_with(None) def test_on_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError - self.transport._on_ready() + transport = self._make_one() + transport._on_ready() self.assertTrue(self.sslsock.recv.called) self.assertFalse(self.protocol.data_received.called) self.sslsock.recv.side_effect = ssl.SSLWantWriteError - self.transport._on_ready() + transport._on_ready() self.assertFalse(self.protocol.data_received.called) self.sslsock.recv.side_effect = BlockingIOError - self.transport._on_ready() + transport._on_ready() self.assertFalse(self.protocol.data_received.called) def test_on_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() - self.transport._fatal_error = unittest.mock.Mock() - self.transport._on_ready() - self.transport._fatal_error.assert_called_with(err) + transport = self._make_one() + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) def test_on_ready_send(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 4 - self.transport._buffer = [b'data'] - self.transport._on_ready() + transport = self._make_one() + transport._buffer = [b'data'] + transport._on_ready() + self.assertEqual([], transport._buffer) self.assertTrue(self.sslsock.send.called) - self.assertEqual([], self.transport._buffer) def test_on_ready_send_none(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 0 - self.transport._buffer = [b'data1', b'data2'] - self.transport._on_ready() + transport = self._make_one() + transport._buffer = [b'data1', b'data2'] + transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'data1data2'], self.transport._buffer) + self.assertEqual([b'data1data2'], transport._buffer) def test_on_ready_send_partial(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 2 - self.transport._buffer = [b'data1', b'data2'] - self.transport._on_ready() + transport = self._make_one() + transport._buffer = [b'data1', b'data2'] + transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'ta1data2'], self.transport._buffer) + self.assertEqual([b'ta1data2'], transport._buffer) def test_on_ready_send_closing_partial(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 2 - self.transport._buffer = [b'data1', b'data2'] - self.transport._on_ready() + transport = self._make_one() + transport._buffer = [b'data1', b'data2'] + transport._on_ready() self.assertTrue(self.sslsock.send.called) self.assertFalse(self.sslsock.close.called) def test_on_ready_send_closing(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 4 - self.transport.close() - self.transport._buffer = [b'data'] - self.transport._on_ready() + transport = self._make_one() + transport.close() + transport._buffer = [b'data'] + transport._on_ready() self.assertTrue(self.sslsock.close.called) - self.assertTrue(self.event_loop.remove_writer.called) - self.event_loop.call_soon.assert_called_with( + self.assertTrue(self.loop.remove_writer.called) + self.loop.call_soon.assert_called_with( self.protocol.connection_lost, None) def test_on_ready_send_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError - self.transport._buffer = [b'data'] + transport = self._make_one() + transport._buffer = [b'data'] self.sslsock.send.side_effect = ssl.SSLWantReadError - self.transport._on_ready() + transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'data'], self.transport._buffer) + self.assertEqual([b'data'], transport._buffer) self.sslsock.send.side_effect = ssl.SSLWantWriteError - self.transport._on_ready() - self.assertEqual([b'data'], self.transport._buffer) + transport._on_ready() + self.assertEqual([b'data'], transport._buffer) self.sslsock.send.side_effect = BlockingIOError() - self.transport._on_ready() - self.assertEqual([b'data'], self.transport._buffer) + transport._on_ready() + self.assertEqual([b'data'], transport._buffer) def test_on_ready_send_exc(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError err = self.sslsock.send.side_effect = OSError() - self.transport._buffer = [b'data'] - self.transport._fatal_error = unittest.mock.Mock() - self.transport._on_ready() - self.transport._fatal_error.assert_called_with(err) - self.assertEqual([], self.transport._buffer) - self.assertEqual(self.transport._conn_lost, 1) + transport = self._make_one() + transport._buffer = [b'data'] + transport._fatal_error = unittest.mock.Mock() + transport._on_ready() + transport._fatal_error.assert_called_with(err) + self.assertEqual([], transport._buffer) + self.assertEqual(transport._conn_lost, 1) class SelectorDatagramTransportTests(unittest.TestCase): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index e9388f91..6090b2ee 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -438,11 +438,11 @@ def _call_connection_lost(self, exc): class _SelectorSslTransport(transports.Transport): - def __init__(self, event_loop, rawsock, protocol, sslcontext, waiter=None, + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, server_side=False, extra=None): super().__init__(extra) - self._event_loop = event_loop + self._loop = loop self._rawsock = rawsock self._protocol = protocol sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) @@ -463,10 +463,10 @@ def _on_handshake(self): try: self._sslsock.do_handshake() except ssl.SSLWantReadError: - self._event_loop.add_reader(fd, self._on_handshake) + self._loop.add_reader(fd, self._on_handshake) return except ssl.SSLWantWriteError: - self._event_loop.add_writer(fd, self._on_handshake) + self._loop.add_writer(fd, self._on_handshake) return except Exception as exc: self._sslsock.close() @@ -478,13 +478,13 @@ def _on_handshake(self): if self._waiter is not None: self._waiter.set_exception(exc) raise - self._event_loop.remove_reader(fd) - self._event_loop.remove_writer(fd) - self._event_loop.add_reader(fd, self._on_ready) - self._event_loop.add_writer(fd, self._on_ready) - self._event_loop.call_soon(self._protocol.connection_made, self) + self._loop.remove_reader(fd) + self._loop.remove_writer(fd) + self._loop.add_reader(fd, self._on_ready) + self._loop.add_writer(fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: - self._event_loop.call_soon(self._waiter.set_result, None) + self._loop.call_soon(self._waiter.set_result, None) def _on_ready(self): # Because of renegotiations (?), there's no difference between @@ -514,8 +514,8 @@ def _on_ready(self): else: # TODO: Don't close when self._buffer is non-empty. assert not self._buffer - self._event_loop.remove_reader(fd) - self._event_loop.remove_writer(fd) + self._loop.remove_reader(fd) + self._loop.remove_writer(fd) self._sslsock.close() self._protocol.connection_lost(None) return @@ -542,7 +542,7 @@ def _on_ready(self): if n < len(data): self._buffer.append(data[n:]) elif self._closing: - self._event_loop.remove_writer(self._sslsock.fileno()) + self._loop.remove_writer(self._sslsock.fileno()) self._sslsock.close() self._protocol.connection_lost(None) @@ -568,19 +568,19 @@ def abort(self): def close(self): self._closing = True - self._event_loop.remove_reader(self._sslsock.fileno()) + self._loop.remove_reader(self._sslsock.fileno()) if not self._buffer: - self._event_loop.call_soon(self._protocol.connection_lost, None) + self._loop.call_soon(self._protocol.connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc): - self._event_loop.remove_writer(self._sslsock.fileno()) - self._event_loop.remove_reader(self._sslsock.fileno()) + self._loop.remove_writer(self._sslsock.fileno()) + self._loop.remove_reader(self._sslsock.fileno()) self._buffer = [] - self._event_loop.call_soon(self._protocol.connection_lost, exc) + self._loop.call_soon(self._protocol.connection_lost, exc) class _SelectorDatagramTransport(transports.DatagramTransport): From d06ba1118e20c4cfc82155ac45c6323d8c192641 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Apr 2013 11:29:38 -0700 Subject: [PATCH 0444/1502] Name abstract classes Abstract*. Move a few things around. --- examples/tcp_protocol_parser.py | 2 +- tests/events_test.py | 4 ++-- tulip/base_events.py | 14 +++++--------- tulip/events.py | 34 ++++++++++++++++++++------------- tulip/unix_events.py | 3 ++- tulip/windows_events.py | 2 +- 6 files changed, 32 insertions(+), 27 deletions(-) diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py index f05518b1..340f1766 100755 --- a/examples/tcp_protocol_parser.py +++ b/examples/tcp_protocol_parser.py @@ -24,7 +24,7 @@ def my_protocol_parser(): parsers are implemented as a state machine. more details in tulip/parsers.py - existing parsers: + existing parsers: * http protocol parsers tulip/http/protocol.py * websocket parser tulip/http/websocket.py """ diff --git a/tests/events_test.py b/tests/events_test.py index 47cf4e82..d98324d5 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1435,7 +1435,7 @@ def test_empty(self): class PolicyTests(unittest.TestCase): def test_event_loop_policy(self): - policy = events.EventLoopPolicy() + policy = events.AbstractEventLoopPolicy() self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.new_event_loop) @@ -1477,7 +1477,7 @@ def test_set_event_loop(self): def test_get_event_loop_policy(self): policy = events.get_event_loop_policy() - self.assertIsInstance(policy, events.EventLoopPolicy) + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) self.assertIs(policy, events.get_event_loop_policy()) def test_set_event_loop_policy(self): diff --git a/tulip/base_events.py b/tulip/base_events.py index 81b01253..f3935057 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -90,15 +90,8 @@ def _process_events(self, event_list): """Process selector events.""" raise NotImplementedError - def is_running(self): - """Returns running status of event loop.""" - return self._running - def run_forever(self): - """Run until stop() is called. - - TODO: Maybe rename to run(). - """ + """Run until stop() is called.""" if self._running: raise RuntimeError('Event loop is running.') self._running = True @@ -175,6 +168,10 @@ def stop(self): """ self.call_soon(_raise_stop_error) + def is_running(self): + """Returns running status of event loop.""" + return self._running + def call_later(self, delay, callback, *args): """Arrange for a callback to be called at a given time. @@ -393,7 +390,6 @@ def create_datagram_endpoint(self, protocol_factory, sock, protocol, r_addr, extra={'addr': l_addr}) return transport, protocol - # TODO: Or create_server()? @tasks.task def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, diff --git a/tulip/events.py b/tulip/events.py index e615d06c..48cc2c81 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -4,7 +4,7 @@ - Only the main thread has a default event loop. """ -__all__ = ['EventLoopPolicy', 'DefaultEventLoopPolicy', +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', @@ -115,11 +115,10 @@ def __ne__(self, other): class AbstractEventLoop: """Abstract event loop.""" - def run_forever(self): - """Run the event loop until stop() is called. + # Running and stopping the event loop. - TODO: Rename to run(). - """ + def run_forever(self): + """Run the event loop until stop() is called.""" raise NotImplementedError def run_once(self, timeout=None): @@ -148,7 +147,14 @@ def stop(self): """ raise NotImplementedError - # Methods returning Handles for scheduling callbacks. + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) def call_later(self, delay, callback, *args): raise NotImplementedError @@ -156,20 +162,20 @@ def call_later(self, delay, callback, *args): def call_repeatedly(self, interval, callback, *args): raise NotImplementedError - def call_soon(self, callback, *args): - return self.call_later(0, callback, *args) + # Methods for interacting with threads. def call_soon_threadsafe(self, callback, *args): raise NotImplementedError - # Methods returning Futures for interacting with threads. - def wrap_future(self, future): raise NotImplementedError def run_in_executor(self, executor, callback, *args): raise NotImplementedError + def set_default_executor(self, executor): + raise NotImplementedError + # Network I/O methods returning Futures. def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): @@ -293,7 +299,7 @@ def remove_signal_handler(self, sig): raise NotImplementedError -class EventLoopPolicy: +class AbstractEventLoopPolicy: """Abstract policy for accessing the event loop.""" def get_event_loop(self): @@ -309,7 +315,7 @@ def new_event_loop(self): raise NotImplementedError -class DefaultEventLoopPolicy(threading.local, EventLoopPolicy): +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): """Default policy implementation for accessing the event loop. In this policy, each thread has its own event loop. However, we @@ -336,6 +342,7 @@ def get_event_loop(self): def set_event_loop(self, event_loop): """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. assert event_loop is None or isinstance(event_loop, AbstractEventLoop) self._event_loop = event_loop @@ -371,7 +378,8 @@ def get_event_loop_policy(): def set_event_loop_policy(policy): """XXX""" global _event_loop_policy - assert policy is None or isinstance(policy, EventLoopPolicy) + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) _event_loop_policy = policy diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 73ada428..87514ef1 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -87,7 +87,8 @@ def _handle_signal(self, sig, arg): def remove_signal_handler(self, sig): """Remove a handler for a signal. UNIX only. - Return True if a signal handler was removed, False if not.""" + Return True if a signal handler was removed, False if not. + """ self._check_signal(sig) try: del self._signal_handlers[sig] diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 9feaa543..bede2b5e 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -13,7 +13,7 @@ from .log import tulip_log -__all__ = ['SelectorEventLoop', 'ProactorEventLoop'] +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] NULL = 0 From 67ad76a6d987c091b01a27f1afd0bc765bed3497 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Apr 2013 11:41:29 -0700 Subject: [PATCH 0445/1502] Add TODO about optimizing modify() if events unchanged. --- tulip/selectors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tulip/selectors.py b/tulip/selectors.py index 4e671444..388df25f 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -125,6 +125,8 @@ def modify(self, fileobj, events, data=None): except KeyError: raise ValueError("{!r} is not registered".format(fileobj)) if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. self.unregister(fileobj) return self.register(fileobj, events, data) else: From 04855d6a8efb03510fb69ff4847990c45113d597 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 26 Apr 2013 13:03:29 -0700 Subject: [PATCH 0446/1502] I like 79 as the max line length. --- check.py | 4 ++-- tulip/base_events.py | 10 ++++++---- tulip/tasks.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/check.py b/check.py index d28b31f7..9ab6bcc0 100644 --- a/check.py +++ b/check.py @@ -1,4 +1,4 @@ -"""Search for lines > 80 chars or with trailing whitespace.""" +"""Search for lines >= 80 chars or with trailing whitespace.""" import sys, os @@ -32,7 +32,7 @@ def process(fn): for i, line in enumerate(f): line = line.rstrip('\n') sline = line.rstrip() - if len(line) > 80 or line != sline or not isascii(line): + if len(line) >= 80 or line != sline or not isascii(line): print('{}:{:d}:{}{}'.format( fn, i+1, sline, '_' * (len(line) - len(sline)))) finally: diff --git a/tulip/base_events.py b/tulip/base_events.py index f3935057..5544495a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -268,7 +268,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, if host is not None or port is not None: if sock is not None: raise ValueError( - "host, port and sock can not be specified at the same time") + "host/port and sock can not be specified at the same time") infos = yield from self.getaddrinfo( host, port, family=family, @@ -360,7 +360,8 @@ def create_datagram_endpoint(self, protocol_factory, exceptions = [] - for (family, proto), (local_address, remote_address) in addr_pairs_info: + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: sock = None l_addr = None r_addr = None @@ -398,7 +399,7 @@ def start_serving(self, protocol_factory, host=None, port=None, *, if host is not None or port is not None: if sock is not None: raise ValueError( - "host, port and sock can not be specified at the same time") + "host/port and sock can not be specified at the same time") AF_INET6 = getattr(socket, 'AF_INET6', 0) if reuse_address is None: @@ -426,7 +427,8 @@ def start_serving(self, protocol_factory, host=None, port=None, *, # default on Linux) which makes a single socket # listen on both address families. if af == AF_INET6 and hasattr(socket, "IPPROTO_IPV6"): - sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, True) try: sock.bind(sa) diff --git a/tulip/tasks.py b/tulip/tasks.py index 119b5ee6..541043be 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -181,7 +181,7 @@ def _step(self, value=_marker, exc=None): self._event_loop.call_soon( self._step, None, RuntimeError( - 'Task received bad yield: {!r}'.format(result))) + 'Task got bad yield: {!r}'.format(result))) else: self._event_loop.call_soon(self._step_maybe) From 5a54b869dbcfbe8206c4e9355f3f2a7d38d0d04d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Apr 2013 10:52:13 -0700 Subject: [PATCH 0447/1502] Rename Timer to TimerHandle. --- tests/base_events_test.py | 24 +++++++++++++----------- tests/events_test.py | 13 +++++++------ tulip/base_events.py | 8 ++++---- tulip/events.py | 12 ++++++------ 4 files changed, 30 insertions(+), 27 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index e991f835..aeb1cd8c 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -55,8 +55,8 @@ def test_add_callback_handle(self): def test_add_callback_timer(self): when = time.monotonic() - h1 = events.Timer(when, lambda: False, ()) - h2 = events.Timer(when+10.0, lambda: False, ()) + h1 = events.TimerHandle(when, lambda: False, ()) + h2 = events.TimerHandle(when+10.0, lambda: False, ()) self.event_loop._add_callback(h2) self.event_loop._add_callback(h1) @@ -109,7 +109,7 @@ def cb(): pass h = self.event_loop.call_later(10.0, cb) - self.assertIsInstance(h, events.Timer) + self.assertIsInstance(h, events.TimerHandle) self.assertIn(h, self.event_loop._scheduled) self.assertNotIn(h, self.event_loop._ready) @@ -130,7 +130,7 @@ def cb(): None, events.Handle(cb, ()), ('',)) self.assertRaises( AssertionError, self.event_loop.run_in_executor, - None, events.Timer(10, cb, ())) + None, events.TimerHandle(10, cb, ())) def test_run_once_in_executor_cancelled(self): def cb(): @@ -171,8 +171,8 @@ def test_run_once(self): self.assertTrue(self.event_loop._run_once.called) def test__run_once(self): - h1 = events.Timer(time.monotonic() + 0.1, lambda: True, ()) - h2 = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) h1.cancel() @@ -187,7 +187,7 @@ def test__run_once(self): self.assertTrue(self.event_loop._process_events.called) def test__run_once_timeout(self): - h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._scheduled.append(h) @@ -196,7 +196,7 @@ def test__run_once_timeout(self): def test__run_once_timeout_with_ready(self): # If event loop has ready callbacks, select timeout is always 0. - h = events.Timer(time.monotonic() + 10.0, lambda: True, ()) + h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._scheduled.append(h) @@ -221,14 +221,16 @@ def monotonic(): m_logging.INFO = logging.INFO m_logging.DEBUG = logging.DEBUG - self.event_loop._scheduled.append(events.Timer(11.0, lambda: True, ())) + self.event_loop._scheduled.append(events.TimerHandle(11.0, + lambda: True, ())) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._run_once() self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.event_loop._scheduled = [events.Timer(11.0, lambda:True, ())] + self.event_loop._scheduled = [events.TimerHandle(11.0, + lambda:True, ())] self.event_loop._run_once() self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) @@ -241,7 +243,7 @@ def cb(event_loop): processed = True handle = event_loop.call_soon(lambda: True) - h = events.Timer(time.monotonic() - 1, cb, (self.event_loop,)) + h = events.TimerHandle(time.monotonic() - 1, cb, (self.event_loop,)) self.event_loop._process_events = unittest.mock.Mock() self.event_loop._scheduled.append(h) diff --git a/tests/events_test.py b/tests/events_test.py index d98324d5..aea1606f 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1292,7 +1292,7 @@ def callback(*args): args = () when = time.monotonic() - h = events.Timer(when, callback, args) + h = events.TimerHandle(when, callback, args) self.assertIs(h.callback, callback) self.assertIs(h.args, args) self.assertFalse(h.cancelled) @@ -1306,7 +1306,8 @@ def callback(*args): r = repr(h) self.assertTrue(r.endswith('())')) - self.assertRaises(AssertionError, events.Timer, None, callback, args) + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) def test_timer_comparison(self): def callback(*args): @@ -1314,8 +1315,8 @@ def callback(*args): when = time.monotonic() - h1 = events.Timer(when, callback, ()) - h2 = events.Timer(when, callback, ()) + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) self.assertFalse(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) @@ -1330,8 +1331,8 @@ def callback(*args): h2.cancel() self.assertFalse(h1 == h2) - h1 = events.Timer(when, callback, ()) - h2 = events.Timer(when + 10.0, callback, ()) + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) self.assertTrue(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) diff --git a/tulip/base_events.py b/tulip/base_events.py index 5544495a..58c90b3c 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -197,7 +197,7 @@ def call_later(self, delay, callback, *args): if delay <= 0: return self.call_soon(callback, *args) - handle = events.Timer(time.monotonic() + delay, callback, args) + handle = events.TimerHandle(time.monotonic() + delay, callback, args) heapq.heappush(self._scheduled, handle) return handle @@ -210,7 +210,7 @@ def wrapper(): handle._when = time.monotonic() + interval heapq.heappush(self._scheduled, handle) - handle = events.Timer(time.monotonic() + interval, wrapper, ()) + handle = events.TimerHandle(time.monotonic() + interval, wrapper, ()) heapq.heappush(self._scheduled, handle) return handle @@ -237,7 +237,7 @@ def call_soon_threadsafe(self, callback, *args): def run_in_executor(self, executor, callback, *args): if isinstance(callback, events.Handle): assert not args - assert not isinstance(callback, events.Timer) + assert not isinstance(callback, events.TimerHandle) if callback.cancelled: f = futures.Future() f.set_result(None) @@ -475,7 +475,7 @@ def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" if handle.cancelled: return - if isinstance(handle, events.Timer): + if isinstance(handle, events.TimerHandle): heapq.heappush(self._scheduled, handle) else: self._ready.append(handle) diff --git a/tulip/events.py b/tulip/events.py index 48cc2c81..9fc3d1ed 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,7 @@ """ __all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', - 'AbstractEventLoop', 'Timer', 'Handle', 'make_handle', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -61,7 +61,7 @@ def make_handle(callback, args): return Handle(callback, args) -class Timer(Handle): +class TimerHandle(Handle): """Object returned by timed callback registration methods.""" def __init__(self, when, callback, args): @@ -71,9 +71,9 @@ def __init__(self, when, callback, args): self._when = when def __repr__(self): - res = 'Timer({}, {}, {})'.format(self._when, - self._callback, - self._args) + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) if self._cancelled: res += '' @@ -100,7 +100,7 @@ def __ge__(self, other): return self.__eq__(other) def __eq__(self, other): - if isinstance(other, Timer): + if isinstance(other, TimerHandle): return (self._when == other._when and self._callback == other._callback and self._args == other._args and From d012afaf2cdd8e00341c4e2c82f84ea621257200 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 29 Apr 2013 11:27:20 -0700 Subject: [PATCH 0448/1502] Disallow ssl=True for servers. Prefer ssl=None over ssl=False. --- examples/srv.py | 2 +- tulip/base_events.py | 4 ++-- tulip/events.py | 8 ++++---- tulip/proactor_events.py | 2 +- tulip/selector_events.py | 14 +++++++++----- tulip/test_utils.py | 2 +- 6 files changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/srv.py b/examples/srv.py index f93e1d7e..aad0875d 100755 --- a/examples/srv.py +++ b/examples/srv.py @@ -144,7 +144,7 @@ def main(): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.load_cert_chain(certfile, keyfile) else: - sslcontext = False + sslcontext = None loop = tulip.get_event_loop() f = loop.start_serving( diff --git a/tulip/base_events.py b/tulip/base_events.py index 58c90b3c..4d890722 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -263,7 +263,7 @@ def getnameinfo(self, sockaddr, flags=0): @tasks.coroutine def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=False, family=0, proto=0, flags=0, sock=None): + ssl=None, family=0, proto=0, flags=0, sock=None): """XXX""" if host is not None or port is not None: if sock is not None: @@ -394,7 +394,7 @@ def create_datagram_endpoint(self, protocol_factory, @tasks.task def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=False, reuse_address=None): + sock=None, backlog=100, ssl=None, reuse_address=None): """XXX""" if host is not None or port is not None: if sock is not None: diff --git a/tulip/events.py b/tulip/events.py index 9fc3d1ed..c8f2401c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -185,12 +185,12 @@ def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError def create_connection(self, protocol_factory, host=None, port=None, *, - family=0, proto=0, flags=0, sock=None): + ssl=None, family=0, proto=0, flags=0, sock=None): raise NotImplementedError def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=False, reuse_address=None): + sock=None, backlog=100, ssl=None, reuse_address=None): """Creates a TCP server bound to host and port and return a list of socket objects which will later be handled by protocol_factory. @@ -211,8 +211,8 @@ def start_serving(self, protocol_factory, host=None, port=None, *, backlog is the maximum number of queued connections passed to listen() (defaults to 100). - ssl can be set to True to enable SSL over the accepted - connections. + ssl can be set to an SSLContext to enable SSL over the + accepted connections. reuse_address tells the kernel to reuse a local socket in TIME_WAIT state, without waiting for its natural timeout to diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index cc889b55..c142ff45 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -177,7 +177,7 @@ def _loop_self_reading(self, f=None): def _write_to_self(self): self._csock.send(b'x') - def _start_serving(self, protocol_factory, sock, ssl=False): + def _start_serving(self, protocol_factory, sock, ssl=None): assert not ssl, 'IocpEventLoop imcompatible with SSL.' def loop(f=None): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 6090b2ee..0dd492aa 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -98,11 +98,11 @@ def _write_to_self(self): except (BlockingIOError, InterruptedError): pass - def _start_serving(self, protocol_factory, sock, ssl=False): + def _start_serving(self, protocol_factory, sock, ssl=None): self.add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, ssl) - def _accept_connection(self, protocol_factory, sock, ssl=False): + def _accept_connection(self, protocol_factory, sock, ssl=None): try: conn, addr = sock.accept() conn.setblocking(False) @@ -117,9 +117,8 @@ def _accept_connection(self, protocol_factory, sock, ssl=False): tulip_log.exception('Accept failed') else: if ssl: - sslcontext = None if isinstance(ssl, bool) else ssl self._make_ssl_transport( - conn, protocol_factory(), sslcontext, None, + conn, protocol_factory(), ssl, None, server_side=True, extra={'addr': addr}) else: self._make_socket_transport( @@ -445,7 +444,12 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, self._loop = loop self._rawsock = rawsock self._protocol = protocol - sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + if server_side: + assert isinstance(sslcontext, + ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) self._sslcontext = sslcontext self._waiter = waiter sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, diff --git a/tulip/test_utils.py b/tulip/test_utils.py index e6d1069b..0bd0dfc7 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -91,7 +91,7 @@ def handle_request(self, message, payload): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.load_cert_chain(certfile, keyfile) else: - sslcontext = False + sslcontext = None def run(loop, fut): thread_loop = tulip.new_event_loop() From 7af799dab508c741d38d7b39081bebeb997b0874 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 29 Apr 2013 12:59:30 -0700 Subject: [PATCH 0449/1502] Disable 'Future result has not been requested' warn msg for tasks --- tests/futures_test.py | 11 +++++++++++ tulip/futures.py | 3 ++- tulip/tasks.py | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index 9e2c4dea..a425ed16 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -194,6 +194,17 @@ def test_del_done(self, log): log.error.mock_calls[-1].assert_called_with( 'Future result has not been requested: %r', r_fut) + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_done_skip(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + fut._debug_warn_result_requested = False + next(iter(fut)) + fut.set_result(1) + del fut + self.assertFalse(log.error.called) + @unittest.mock.patch('tulip.futures.tulip_log') def test_del_exc(self, log): self.loop.set_log_level(futures.STACK_DEBUG) diff --git a/tulip/futures.py b/tulip/futures.py index a778a51c..e3baf15e 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -65,6 +65,7 @@ class Future: # result of the future has to be requested _debug_stack = None _debug_result_requested = False + _debug_warn_result_requested = True def __init__(self, *, event_loop=None, timeout=None): """Initialize the future. @@ -302,7 +303,7 @@ def __del__(self): exc_info=(exc.__class__, exc, exc.__traceback__)) if (self._debug_stack and level <= STACK_DEBUG): tulip_log.error(self._debug_stack) - else: + elif self._debug_warn_result_requested: tulip_log.error( 'Future result has not been requested: %s', r_self) if (self._debug_stack and level <= STACK_DEBUG): diff --git a/tulip/tasks.py b/tulip/tasks.py index 541043be..0b1ab1f3 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -71,6 +71,9 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" + # disable "Future result has not been requested" warning message. + _debug_warn_result_requested = False + def __init__(self, coro, event_loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__(event_loop=event_loop, timeout=timeout) From 5bf796cd6398007f7fbc4c580775fab85b487ebe Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 29 Apr 2013 13:01:10 -0700 Subject: [PATCH 0450/1502] enable debug logging in examples; fix mpsrv.py example --- examples/mpsrv.py | 5 +---- examples/srv.py | 5 ++++- examples/tcp_protocol_parser.py | 2 -- examples/udp_echo.py | 2 -- examples/wsclient.py | 1 - examples/wssrv.py | 3 --- 6 files changed, 5 insertions(+), 13 deletions(-) diff --git a/examples/mpsrv.py b/examples/mpsrv.py index daf55b1a..d6f82d3a 100755 --- a/examples/mpsrv.py +++ b/examples/mpsrv.py @@ -4,7 +4,6 @@ import argparse import email.message import os -import logging import socket import signal import time @@ -116,7 +115,6 @@ def start(self): # start server self.loop = loop = tulip.new_event_loop() tulip.set_event_loop(loop) - loop.set_log_level(logging.CRITICAL) def stop(): self.loop.stop() @@ -124,7 +122,7 @@ def stop(): loop.add_signal_handler(signal.SIGINT, stop) f = loop.start_serving(lambda: HttpServer(debug=True), sock=self.sock) - x = loop.run_until_complete(f[0]) + x = loop.run_until_complete(f)[0] print('Starting srv worker process {} on {}'.format( os.getpid(), x.getsockname())) @@ -256,7 +254,6 @@ class Superviser: def __init__(self, args): self.loop = tulip.get_event_loop() - self.loop.set_log_level(logging.CRITICAL) self.args = args self.workers = [] diff --git a/examples/srv.py b/examples/srv.py index aad0875d..2c2ae91e 100755 --- a/examples/srv.py +++ b/examples/srv.py @@ -151,7 +151,10 @@ def main(): lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) socks = loop.run_until_complete(f) print('serving on', socks[0].getsockname()) - loop.run_forever() + try: + loop.run_forever() + except KeyboardInterrupt: + pass if __name__ == '__main__': diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py index 340f1766..a0258613 100755 --- a/examples/tcp_protocol_parser.py +++ b/examples/tcp_protocol_parser.py @@ -2,7 +2,6 @@ """Protocol parser example.""" import argparse import collections -import logging import tulip try: import signal @@ -162,7 +161,6 @@ def start_server(loop, host, port): ARGS.print_help() else: loop = tulip.get_event_loop() - loop.set_log_level(logging.CRITICAL) if signal is not None: loop.add_signal_handler(signal.SIGINT, loop.stop) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index d7bde29a..0347bfbd 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -3,7 +3,6 @@ import argparse import sys import tulip -import logging try: import signal except ImportError: @@ -88,7 +87,6 @@ def start_client(loop, addr): ARGS.print_help() else: loop = tulip.get_event_loop() - loop.set_log_level(logging.CRITICAL) if signal is not None: loop.add_signal_handler(signal.SIGINT, loop.stop) diff --git a/examples/wsclient.py b/examples/wsclient.py index 5598c38c..f5b2ef58 100755 --- a/examples/wsclient.py +++ b/examples/wsclient.py @@ -92,7 +92,6 @@ def dispatch(): loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) tulip.set_event_loop(loop) - loop.set_log_level(50) loop.add_signal_handler(signal.SIGINT, loop.stop) tulip.Task(start_client(loop, url)) loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py index 8459f458..2befec1d 100755 --- a/examples/wssrv.py +++ b/examples/wssrv.py @@ -3,7 +3,6 @@ import argparse import os -import logging import socket import signal import time @@ -120,7 +119,6 @@ def start(self): # start server self.loop = loop = tulip.new_event_loop() tulip.set_event_loop(loop) - loop.set_log_level(logging.CRITICAL) def stop(): self.loop.stop() @@ -277,7 +275,6 @@ class Superviser: def __init__(self, args): self.loop = tulip.get_event_loop() - self.loop.set_log_level(logging.CRITICAL) self.args = args self.workers = [] From c41f9a9d5995ff00f2bb24b24d286365e5a6c570 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Mon, 29 Apr 2013 22:38:08 +0200 Subject: [PATCH 0451/1502] 'make test' will now crash in case of syntax error in test files --- runtests.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runtests.py b/runtests.py index 0ec5ba31..cd334423 100644 --- a/runtests.py +++ b/runtests.py @@ -88,6 +88,8 @@ def list_dir(prefix, dir): try: loader = importlib.machinery.SourceFileLoader(modname, sourcefile) mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise except Exception as err: print("Skipping '{}': {}".format(modname, err), file=sys.stderr) From d3e5a467b8981bfd0a23e4370b58372cc75f22c0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 30 Apr 2013 09:40:02 -0700 Subject: [PATCH 0452/1502] Rename sleep() argument from when to delay. --- tulip/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 0b1ab1f3..8dfd73a3 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -330,10 +330,10 @@ def _wrap_coroutines(fs): @coroutine -def sleep(when, result=None): +def sleep(delay, result=None): """Coroutine that completes after a given time (in seconds).""" future = futures.Future() - h = future._event_loop.call_later(when, future.set_result, result) + h = future._event_loop.call_later(delay, future.set_result, result) try: return (yield from future) finally: From 9d115a25c073f07524a5906376f63b490b1a675a Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 2 May 2013 13:46:32 -0700 Subject: [PATCH 0453/1502] simple http client connection pooling --- examples/crawl.py | 8 +- tests/http_client_functional_test.py | 93 +++++++++++++++++++ tests/http_parser_test.py | 14 +++ tests/http_session_test.py | 133 +++++++++++++++++++++++++++ tulip/http/__init__.py | 2 + tulip/http/client.py | 103 ++++++++++++++------- tulip/http/session.py | 101 ++++++++++++++++++++ tulip/parsers.py | 3 +- tulip/test_utils.py | 7 +- 9 files changed, 425 insertions(+), 39 deletions(-) create mode 100644 tests/http_session_test.py create mode 100644 tulip/http/session.py diff --git a/examples/crawl.py b/examples/crawl.py index 723d5305..ac9c25e9 100755 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -21,12 +21,17 @@ def __init__(self, rooturl, loop, maxtasks=100): self.tasks = set() self.sem = tulip.Semaphore(maxtasks) + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + @tulip.task def run(self): self.addurls([(self.rooturl, '')]) # Set initial work. yield from tulip.sleep(1) while self.busy: yield from tulip.sleep(1) + + self.session.close() self.loop.stop() @tulip.task @@ -52,7 +57,8 @@ def process(self, url): self.todo.remove(url) self.busy.add(url) try: - resp = yield from tulip.http.request('get', url) + resp = yield from tulip.http.request( + 'get', url, session=self.session) except Exception as exc: print('...', url, 'has error', repr(str(exc))) self.done[url] = False diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 5eb65702..31d0df86 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -34,6 +34,7 @@ def test_HTTP_200_OK_METHOD(self): self.assertEqual(r.status, 200) self.assertIn('"method": "%s"' % meth.upper(), content) self.assertEqual(content1, content2) + r.close() def test_HTTP_302_REDIRECT_GET(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -324,6 +325,15 @@ def test_cookies(self): content = self.loop.run_until_complete(r.content.read()) self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'))) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + def test_chunked(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( @@ -347,6 +357,64 @@ def test_request_conn_error(self): self.loop.run_until_complete, client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('cookies'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + class Functional(test_utils.Router): @@ -392,3 +460,28 @@ def chunked(self, match): resp = self._start_response(200) resp.add_chunking_filter(100) self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py index 71951be7..c8dfb1bc 100644 --- a/tests/http_parser_test.py +++ b/tests/http_parser_test.py @@ -24,6 +24,20 @@ def test_parse_headers(self): self.assertIsNone(close) self.assertIsNone(compression) + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + def test_conn_close(self): headers, close, compression = protocol.parse_headers( ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..1b86c56d --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,133 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py index b2a0a26d..a1432dee 100644 --- a/tulip/http/__init__.py +++ b/tulip/http/__init__.py @@ -4,6 +4,7 @@ from .errors import * from .protocol import * from .server import * +from .session import * from .wsgi import * @@ -11,4 +12,5 @@ errors.__all__ + protocol.__all__ + server.__all__ + + session.__all__ + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py index 94455e2f..4c797b8c 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -41,7 +41,8 @@ def request(method, url, *, version=(1, 1), timeout=None, compress=None, - chunked=None): + chunked=None, + session=None): """Constructs and sends a request. Returns response object. method: http method @@ -62,6 +63,8 @@ def request(method, url, *, with deflate encoding. chunked: Boolean or Integer. Set to chunk size for chunked transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. Usage: @@ -82,9 +85,14 @@ def request(method, url, *, cookies=cookies, files=files, auth=auth, encoding=encoding, version=version, compress=compress, chunked=chunked) + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + # connection timeout try: - resp = yield from tulip.Task(start(req, loop), timeout=timeout) + resp = yield from tulip.Task(conn, timeout=timeout) except tulip.CancelledError: raise tulip.TimeoutError from None @@ -117,6 +125,7 @@ def request(method, url, *, def start(req, loop): transport, p = yield from loop.create_connection( tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + try: resp = req.send(transport) yield from resp.start(p, transport) @@ -250,18 +259,7 @@ def __init__(self, method, url, *, # cookies if cookies: - c = http.cookies.SimpleCookie() - if 'cookie' in self.headers: - c.load(self.headers.get('cookie', '')) - del self.headers['cookie'] - - for name, value in cookies.items(): - if isinstance(value, http.cookies.Morsel): - dict.__setitem__(c, name, value) - else: - c[name] = value - - self.headers['cookie'] = c.output(header='', sep=';').strip() + self.update_cookies(cookies) # auth if auth: @@ -274,24 +272,15 @@ def __init__(self, method, url, *, else: raise ValueError("Only basic auth is supported") - self._params = (chunked, compress, files, data, encoding) - - def send(self, transport): - chunked, compress, files, data, encoding = self._params - - request = tulip.http.Request( - transport, self.method, self.path, self.version) - # Content-encoding enc = self.headers.get('Content-Encoding', '').lower() if enc: chunked = True # enable chunked, no need to deal with length - request.add_compression_filter(enc) + compress = enc elif compress: chunked = True # enable chunked, no need to deal with length compress = compress if isinstance(compress, str) else 'deflate' self.headers['Content-Encoding'] = compress - request.add_compression_filter(compress) # form data (x-www-form-urlencoded) if isinstance(data, dict): @@ -354,17 +343,48 @@ def send(self, transport): if 'content-length' in self.headers: del self.headers['content-length'] if 'chunked' not in te: - self.headers['Transfer-encoding'] = 'chunked' + self.headers['transfer-encoding'] = 'chunked' - chunk_size = chunked if type(chunked) is int else 8196 - request.add_chunking_filter(chunk_size) + chunked = chunked if type(chunked) is int else 8196 else: if 'chunked' in te: - request.add_chunking_filter(8196) + chunked = 8196 else: - chunked = False + chunked = None self.headers['content-length'] = str(len(self.body)) + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + request.add_headers(*self.headers.items()) request.send_headers() @@ -381,11 +401,15 @@ def send(self, transport): class HttpResponse(http.client.HTTPMessage): + message = None # RawResponseStatus object + # from the Status-Line of the response version = None # HTTP-Version status = None # Status-Code reason = None # Reason-Phrase + cookies = None # Response cookies (Set-Cookie) + content = None # payload stream stream = None # input stream transport = None # current transport @@ -398,6 +422,9 @@ def __init__(self, method, url, host=''): self.host = host self._content = None + def __del__(self): + self.close() + def __repr__(self): out = io.StringIO() print(''.format( @@ -413,20 +440,26 @@ def start(self, stream, transport): httpstream = stream.set_parser(tulip.http.http_response_parser()) # read response - message = yield from httpstream.read() + self.message = yield from httpstream.read() # response status - self.version = message.version - self.status = message.code - self.reason = message.reason + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason # headers - for hdr, val in message.headers: + for hdr, val in self.message.headers: self.add_header(hdr, val) # payload self.content = stream.set_parser( - tulip.http.http_payload_parser(message)) + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) return self diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..baf19dba --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,101 @@ +"""client session support.""" + +__all__ = ['Session'] + +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/parsers.py b/tulip/parsers.py index 689fa4c8..f5a7845a 100644 --- a/tulip/parsers.py +++ b/tulip/parsers.py @@ -165,7 +165,8 @@ def set_parser(self, p): def unset_parser(self): """unset parser, send eof to the parser and then remove it.""" - assert self._parser is not None, 'Paser is not set.' + if self._parser is None: + return try: self._parser.throw(EofStream()) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 0bd0dfc7..3cd5ba95 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -81,8 +81,7 @@ def handle_request(self, message, payload): response.send_headers() response.write(text) response.write_eof() - - self.transport.close() + self.transport.close() if use_ssl: here = os.path.join(os.path.dirname(__file__), '..', 'tests') @@ -254,3 +253,7 @@ def _response(self, response, body=None, headers=None, chunked=False): # write payload response.write(client.str_to_bytes(body)) response.write_eof() + + # keep-alive + if not response.keep_alive(): + self._transport.close() From 5598720d8107c5f98ff346ba43404c6fa69dacc0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 2 May 2013 14:16:49 -0700 Subject: [PATCH 0454/1502] Add TODO about using getaddrinfo() to _sock_connect(). --- tulip/selector_events.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 0dd492aa..488dad18 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -263,6 +263,13 @@ def sock_connect(self, sock, address): return fut def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) fd = sock.fileno() if registered: self.remove_writer(fd) From e623ecad9587104a959f4a31bdbdda3d81fc9221 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 2 May 2013 15:02:42 -0700 Subject: [PATCH 0455/1502] code cleanup and some coverage --- .hgeol | 4 + .hgignore | 12 + Makefile | 32 + NOTES | 176 +++ README | 21 + TODO | 163 +++ check.py | 41 + examples/crawl.py | 104 ++ examples/curl.py | 24 + examples/mpsrv.py | 287 +++++ examples/srv.py | 161 +++ examples/tcp_echo.py | 111 ++ examples/tcp_protocol_parser.py | 170 +++ examples/udp_echo.py | 98 ++ examples/websocket.html | 90 ++ examples/wsclient.py | 97 ++ examples/wssrv.py | 308 ++++++ overlapped.c | 997 +++++++++++++++++ runtests.py | 200 ++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 288 +++++ tests/events_test.py | 1501 ++++++++++++++++++++++++++ tests/futures_test.py | 308 ++++++ tests/http_client_functional_test.py | 487 +++++++++ tests/http_client_test.py | 289 +++++ tests/http_parser_test.py | 510 +++++++++ tests/http_protocol_test.py | 384 +++++++ tests/http_server_test.py | 248 +++++ tests/http_session_test.py | 133 +++ tests/http_websocket_test.py | 426 ++++++++ tests/http_wsgi_test.py | 241 +++++ tests/locks_test.py | 803 ++++++++++++++ tests/parsers_test.py | 598 ++++++++++ tests/proactor_events_test.py | 329 ++++++ tests/queues_test.py | 380 +++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1368 +++++++++++++++++++++++ tests/selectors_test.py | 143 +++ tests/streams_test.py | 293 +++++ tests/subprocess_test.py | 61 ++ tests/tasks_test.py | 749 +++++++++++++ tests/transports_test.py | 45 + tests/unix_events_test.py | 592 ++++++++++ tests/winsocketpair_test.py | 26 + tulip/TODO | 26 + tulip/__init__.py | 28 + tulip/base_events.py | 560 ++++++++++ tulip/constants.py | 4 + tulip/events.py | 398 +++++++ tulip/futures.py | 310 ++++++ tulip/http/__init__.py | 16 + tulip/http/client.py | 560 ++++++++++ tulip/http/errors.py | 44 + tulip/http/protocol.py | 747 +++++++++++++ tulip/http/server.py | 183 ++++ tulip/http/session.py | 101 ++ tulip/http/websocket.py | 227 ++++ tulip/http/wsgi.py | 221 ++++ tulip/locks.py | 434 ++++++++ tulip/log.py | 6 + tulip/parsers.py | 397 +++++++ tulip/proactor_events.py | 199 ++++ tulip/protocols.py | 78 ++ tulip/queues.py | 291 +++++ tulip/selector_events.py | 706 ++++++++++++ tulip/selectors.py | 426 ++++++++ tulip/streams.py | 147 +++ tulip/subprocess_transport.py | 156 +++ tulip/tasks.py | 340 ++++++ tulip/test_utils.py | 259 +++++ tulip/transports.py | 134 +++ tulip/unix_events.py | 312 ++++++ tulip/windows_events.py | 157 +++ tulip/winsocketpair.py | 34 + 76 files changed, 20844 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100755 examples/crawl.py create mode 100755 examples/curl.py create mode 100755 examples/mpsrv.py create mode 100755 examples/srv.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/tcp_protocol_parser.py create mode 100755 examples/udp_echo.py create mode 100644 examples/websocket.html create mode 100755 examples/wsclient.py create mode 100755 examples/wssrv.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_parser_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_session_test.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/parsers_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/subprocess_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/winsocketpair_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/session.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/parsers.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/subprocess_transport.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/winsocketpair.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..bca4ed9b --- /dev/null +++ b/Makefile @@ -0,0 +1,32 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 1 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..9ab6bcc0 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100755 index 00000000..ac9c25e9 --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + + self.session.close() + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request( + 'get', url, session=self.session) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/examples/curl.py b/examples/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/examples/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100755 index 00000000..d6f82d3a --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving(lambda: HttpServer(debug=True), sock=self.sock) + x = loop.run_until_complete(f)[0] + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/examples/srv.py b/examples/srv.py new file mode 100755 index 00000000..2c2ae91e --- /dev/null +++ b/examples/srv.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('method = {!r}; path = {!r}; version = {!r}'.format( + message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..16e3fb65 --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost') + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..a0258613 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100755 index 00000000..f5b2ef58 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""websocket cmd client for wssrv.py example.""" +import argparse +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket +import tulip.selectors + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop, url): + name = input('Please enter your name: ').encode() + + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) + + # input reader + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_TEXT: + print(msg.data.strip()) + elif msg.tp == websocket.MSG_CLOSE: + break + + yield from dispatch() + + +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + + loop.add_signal_handler(signal.SIGINT, loop.stop) + tulip.Task(start_client(loop, url)) + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100755 index 00000000..2befec1d --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + self.close() + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + self.close() + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + socks = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), socks[0].getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..c9f6ec9f --- /dev/null +++ b/overlapped.c @@ -0,0 +1,997 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + self->type = TYPE_NOT_STARTED; + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..cd334423 --- /dev/null +++ b/runtests.py @@ -0,0 +1,200 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + result = unittest.TextTestRunner(verbosity=v).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Distribute (http://packages.python.org/distribute/) + + What worked for me: + - download http://python-distribute.org/distribute_setup.py + * curl -O http://python-distribute.org/distribute_setup.py + - python3 distribute_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: {}\n".format(sdir)) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..aeb1cd8c --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,288 @@ +"""Tests for base_events.py""" + +import concurrent.futures +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = base_events.BaseEventLoop() + self.event_loop._selector = unittest.mock.Mock() + self.event_loop._selector.registered_count.return_value = 1 + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.event_loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.event_loop._process_events, []) + self.assertRaises( + NotImplementedError, self.event_loop._write_to_self) + self.assertRaises( + NotImplementedError, self.event_loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.event_loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.event_loop._make_write_pipe_transport, m, m) + + def test_add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertIn(h, self.event_loop._ready) + + def test_add_callback_timer(self): + when = time.monotonic() + + h1 = events.TimerHandle(when, lambda: False, ()) + h2 = events.TimerHandle(when+10.0, lambda: False, ()) + + self.event_loop._add_callback(h2) + self.event_loop._add_callback(h1) + self.assertEqual([h1, h2], self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.event_loop._add_callback(h) + self.assertFalse(self.event_loop._scheduled) + self.assertFalse(self.event_loop._ready) + + def test_wrap_future(self): + f = futures.Future(event_loop=self.event_loop) + self.assertIs(self.event_loop.wrap_future(f), f) + f.cancel() + + def test_wrap_future_concurrent(self): + f = concurrent.futures.Future() + fut = self.event_loop.wrap_future(f) + self.assertIsInstance(fut, futures.Future) + fut.cancel() + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.event_loop.set_default_executor(executor) + self.assertIs(executor, self.event_loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.event_loop.run_in_executor = unittest.mock.Mock() + self.event_loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.event_loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.event_loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.event_loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.event_loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.event_loop._scheduled) + self.assertNotIn(h, self.event_loop._ready) + + def test_call_later_no_delay(self): + def cb(): + pass + + h = self.event_loop.call_later(0, cb) + self.assertIn(h, self.event_loop._ready) + self.assertNotIn(h, self.event_loop._scheduled) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.event_loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.event_loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future() + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.event_loop.set_default_executor(executor) + + res = self.event_loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.event_loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test_run_once(self): + self.event_loop._run_once = unittest.mock.Mock() + self.event_loop._run_once.side_effect = base_events._StopError + self.event_loop.run_once() + self.assertTrue(self.event_loop._run_once.called) + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h1) + self.event_loop._scheduled.append(h2) + self.event_loop._run_once() + + t = self.event_loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1) + self.assertEqual([h2], self.event_loop._scheduled) + self.assertTrue(self.event_loop._process_events.called) + + def test__run_once_timeout(self): + h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once(1.0) + self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + + def test__run_once_timeout_with_ready(self): + # If event loop has ready callbacks, select timeout is always 0. + h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._ready.append(h) + self.event_loop._run_once(1.0) + + self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.event_loop._scheduled.append(events.TimerHandle(11.0, + lambda: True, ())) + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.event_loop._scheduled = [events.TimerHandle(11.0, + lambda:True, ())] + self.event_loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(event_loop): + nonlocal processed, handle + processed = True + handle = event_loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.event_loop,)) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop._scheduled.append(h) + self.event_loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.event_loop._ready)) + + def test_run_until_complete_assertion(self): + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, 'blah') + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise socket.error(errors[idx]) + + m_socket.socket = _socket + m_socket.error = socket.error + + self.event_loop.getaddrinfo = getaddrinfo + + task = tasks.Task( + self.event_loop.create_connection(MyProto, 'example.com', 80)) + task._step() + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..94760f53 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1501 @@ +"""Tests for events.py.""" + +import concurrent.futures +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.transport = None + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.event_loop = self.create_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + gc.collect() + super().tearDown() + + def test_run_nesting(self): + @tasks.coroutine + def coro(): + self.assertTrue(self.event_loop.is_running()) + self.event_loop.run_until_complete(tasks.sleep(0.1)) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) + + def test_run_once_nesting(self): + @tasks.coroutine + def coro(): + tasks.sleep(0.1) + self.event_loop.run_once() + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, coro()) + + def test_run_once_block(self): + called = False + + def callback(): + nonlocal called + called = True + + def run(): + time.sleep(0.1) + self.event_loop.call_soon_threadsafe(callback) + + self.event_loop.run_once(0) # windows iocp + + t = threading.Thread(target=run) + t0 = time.monotonic() + t.start() + self.event_loop.run_once(None) + t1 = time.monotonic() + t.join() + self.assertTrue(called) + self.assertTrue(0.09 < t1-t0 <= 0.15) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.event_loop.stop() + + self.event_loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.event_loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12) + + def test_call_repeatedly(self): + results = [] + + def callback(arg): + results.append(arg) + + self.event_loop.call_repeatedly(0.03, callback, 'ho') + self.event_loop.call_later(0.1, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(results, ['ho', 'ho', 'ho']) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.event_loop.stop() + + self.event_loop.call_soon(callback, 'hello', 'world') + self.event_loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_with_handle(self): + results = [] + + def callback(): + results.append('yeah') + self.event_loop.stop() + + handle = events.Handle(callback, ()) + self.assertIs(self.event_loop.call_soon(handle), handle) + self.event_loop.run_forever() + self.assertEqual(results, ['yeah']) + + def test_call_soon_threadsafe(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.event_loop.stop() + + def run_in_thread(): + self.event_loop.call_soon_threadsafe(callback, 'hello') + + t = threading.Thread(target=run_in_thread) + self.event_loop.call_later(0.1, callback, 'world') + t0 = time.monotonic() + t.start() + self.event_loop.run_forever() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.event_loop.stop() + + self.event_loop.call_later(0.1, callback, 'world') + self.event_loop.call_soon_threadsafe(callback, 'hello') + self.event_loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_with_handle(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.event_loop.stop() + + handle = events.Handle(callback, ('hello',)) + + def run(): + self.assertIs( + self.event_loop.call_soon_threadsafe(handle), handle) + + t = threading.Thread(target=run) + self.event_loop.call_later(0.1, callback, 'world') + + t0 = time.monotonic() + t.start() + self.event_loop.run_forever() + t1 = time.monotonic() + t.join() + self.assertEqual(results, ['hello', 'world']) + self.assertTrue(t1-t0 >= 0.09) + + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = self.event_loop.wrap_future(f1) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'oi') + + def test_run_in_executor(self): + def run(arg): + time.sleep(0.1) + return arg + f2 = self.event_loop.run_in_executor(None, run, 'yo') + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_run_in_executor_with_handle(self): + def run(arg): + time.sleep(0.01) + return arg + handle = events.Handle(run, ('yo',)) + f2 = self.event_loop.run_in_executor(None, handle) + res = self.event_loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_with_handle(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.event_loop.remove_reader(r.fileno())) + r.close() + + handle = events.Handle(reader, ()) + self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) + + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_reader_callback_cancel(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + return + if data: + bytes_read.append(data) + if sum(len(b) for b in bytes_read) >= 6: + handle.cancel() + if not data: + r.close() + + handle = self.event_loop.add_reader(r.fileno(), reader) + self.event_loop.call_later(0.05, w.send, b'abc') + self.event_loop.call_later(0.1, w.send, b'def') + self.event_loop.call_later(0.15, w.close) + self.event_loop.call_later(0.16, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.call_later(0.11, self.event_loop.stop) + self.event_loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_with_handle(self): + r, w = test_utils.socketpair() + w.setblocking(False) + handle = events.Handle(w.send, (b'x'*(256*1024),)) + self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) + + def remove_writer(): + self.assertTrue(self.event_loop.remove_writer(w.fileno())) + + self.event_loop.call_later(0.1, remove_writer) + self.event_loop.call_later(0.11, self.event_loop.stop) + self.event_loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertTrue(len(data) >= 200) + + def test_writer_callback_cancel(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def sender(): + w.send(b'x'*256) + handle.cancel() + self.event_loop.stop() + + handle = self.event_loop.add_writer(w.fileno(), sender) + self.event_loop.run_forever() + w.close() + data = r.recv(1024) + r.close() + self.assertTrue(data == b'x'*256) + + def test_sock_client_ops(self): + with test_utils.run_test_server(self.event_loop, host='') as httpd: + sock = socket.socket() + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, httpd.address)) + self.event_loop.run_until_complete( + self.event_loop.sock_sendall( + sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + # consume data + self.event_loop.run_until_complete( + self.event_loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.event_loop.sock_accept(listener) + conn, addr = self.event_loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.event_loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.event_loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.event_loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + self.event_loop.run_once() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipIf(sys.platform == 'win32', 'Unix only') + def test_cancel_signal_handler(self): + # Cancelling the handler should remove it (eventually). + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) + handle.cancel() + os.kill(os.getpid(), signal.SIGINT) + self.event_loop.run_once() + self.assertEqual(caught, 0) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.event_loop.add_signal_handler( + signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.event_loop.call_later(0.15, self.event_loop.stop) + self.event_loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.event_loop) as httpd: + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), + *httpd.address) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run_until_complete(pr.done) + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.event_loop) as httpd: + sock = None + infos = self.event_loop.run_until_complete( + self.event_loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.event_loop.run_until_complete( + self.event_loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), sock=sock) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.event_loop.run_until_complete(pr.done) + self.assertTrue(pr.nbytes > 0) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.event_loop, use_ssl=True) as httpd: + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), *httpd.address, ssl=True) + tr, pr = self.event_loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.event_loop.run_until_complete(pr.done) + self.assertTrue(pr.nbytes > 0) + + def test_create_connection_host_port_sock(self): + coro = self.event_loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.event_loop.create_connection(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_connection_mutiple_errors(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + self.event_loop.getaddrinfo = getaddrinfo + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.event_loop.start_serving(factory, '0.0.0.0', 0) + socks = self.event_loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.event_loop.run_once(0.001) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.event_loop.run_once(0.001) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(create_future=True) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.event_loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.event_loop.run_until_complete(f)[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.event_loop.create_connection( + ClientMyProto, host, port, ssl=True) + client, pr = self.event_loop.run_until_complete(f_c) + + client.write(b'xxx') + self.event_loop.run_once() + self.assertIsInstance(proto, MyProto) + self.event_loop.run_once() + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.event_loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + def test_start_serving_sock(self): + proto = futures.Future() + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f)[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + self.event_loop.run_until_complete(proto) + sock.close() + client.close() + + def test_start_serving_addrinuse(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.event_loop.start_serving(MyProto, sock=sock_ob) + sock = self.event_loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + f = self.event_loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(socket.error) as cm: + self.event_loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + f_proto = futures.Future() + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + port = find_unused_port() + f = self.event_loop.start_serving(TestMyProto, host=None, port=port) + socks = self.event_loop.run_until_complete(f) + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.event_loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future() + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.event_loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + for s in socks: + s.close() + + def test_stop_serving(self): + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.event_loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.event_loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + + def test_start_serving_host_port_sock(self): + fut = self.event_loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.event_loop.start_serving(MyProto) + self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(socket.error): + strerror = 'error' + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + coro = self.event_loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(self): + super().__init__(create_future=True) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.event_loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.event_loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.event_loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True), + remote_addr=(host, port)) + transport, client = self.event_loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + self.event_loop.run_once(None) + self.assertEqual(3, server.nbytes) + self.event_loop.run_once(None) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + self.event_loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_create_datagram_endpoint_connect_err(self): + self.event_loop.sock_connect = unittest.mock.Mock() + self.event_loop.sock_connect.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.event_loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.event_loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.event_loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.event_loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_internal_fds(self): + event_loop = self.create_event_loop() + if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, event_loop._internal_fds) + event_loop.close() + self.assertEqual(0, event_loop._internal_fds) + self.assertIsNone(event_loop._csock) + self.assertIsNone(event_loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(create_future=True) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.task + def connect(): + t, p = yield from self.event_loop.connect_read_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.event_loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + self.event_loop.run_once() + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + self.event_loop.run_once() + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.event_loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(create_future=True) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.event_loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.event_loop.run_until_complete(connect()) + + transport.write(b'1') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + self.event_loop.run_once() + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.event_loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_with_handle(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_accept_connection_retry(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_accept_connection_exception(self): + raise unittest.SkipTest( + "IocpEventLoop does not have _accept_connection()") + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_no_connection(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_cant_bind(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_noaddr_nofamily(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_socket_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_create_datagram_endpoint_connect_err(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + def test_stop_serving(self): + raise unittest.SkipTest( + "IocpEventLoop does not support stop_serving()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + h2 = events.make_handle(h1, ()) + self.assertIs(h1, h2) + + self.assertRaises( + AssertionError, events.make_handle, h1, (1, 2)) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h.run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h.callback, callback) + self.assertIs(h.args, args) + self.assertFalse(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h.cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_imlemented(self): + f = unittest.mock.Mock() + ev_loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, ev_loop.run_forever) + self.assertRaises( + NotImplementedError, ev_loop.run_once) + self.assertRaises( + NotImplementedError, ev_loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, ev_loop.stop) + self.assertRaises( + NotImplementedError, ev_loop.is_running) + self.assertRaises( + NotImplementedError, ev_loop.call_later, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_repeatedly, None, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, ev_loop.wrap_future, f) + self.assertRaises( + NotImplementedError, ev_loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, ev_loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, ev_loop.create_connection, f) + self.assertRaises( + NotImplementedError, ev_loop.start_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.stop_serving, f) + self.assertRaises( + NotImplementedError, ev_loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, ev_loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, ev_loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, ev_loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, ev_loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, ev_loop.sock_accept, f) + self.assertRaises( + NotImplementedError, ev_loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, ev_loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, ev_loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._event_loop) + + event_loop = policy.get_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + self.assertIs(policy._event_loop, event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy.get_event_loop()) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + event_loop = policy.new_event_loop() + self.assertIsInstance(event_loop, events.AbstractEventLoop) + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_event_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + event_loop = policy.new_event_loop() + policy.set_event_loop(event_loop) + self.assertIs(event_loop, policy.get_event_loop()) + self.assertIsNot(old_event_loop, policy.get_event_loop()) + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..adc283d7 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,308 @@ +"""Tests for futures.py.""" + +import logging +import unittest +import unittest.mock + +from tulip import events +from tulip import futures + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = events.get_event_loop() + + def test_initial_state(self): + f = futures.Future() + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_event_loop_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future() + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.result) + self.assertRaises(futures.InvalidTimeoutError, f.result, 10) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future() + self.assertRaises(futures.InvalidStateError, f.exception) + self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertFalse(f.running()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future() + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future() + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future() + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future() + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future() + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future() + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future() + f.set_result(10) + + newf = futures.Future() + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future() + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future() + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future() + f_cancelled.cancel() + + newf_cancelled = futures.Future() + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future() + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_norm_level(self, log): + self.loop.set_log_level(logging.CRITICAL) + + fut = futures.Future() + del fut + self.assertFalse(log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_normal(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + fut.set_result(True) + fut.result() + del fut + self.assertFalse(log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_not_done(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + r_fut = repr(fut) + del fut + log.error.mock_calls[-1].assert_called_with( + 'Future abandoned before completion: %r', r_fut) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_done(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + next(iter(fut)) + fut.set_result(1) + r_fut = repr(fut) + del fut + log.error.mock_calls[-1].assert_called_with( + 'Future result has not been requested: %r', r_fut) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_done_skip(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + fut = futures.Future() + fut._debug_warn_result_requested = False + next(iter(fut)) + fut.set_result(1) + del fut + self.assertFalse(log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_del_exc(self, log): + self.loop.set_log_level(futures.STACK_DEBUG) + + exc = ValueError() + fut = futures.Future() + fut.set_exception(exc) + r_fut = repr(fut) + del fut + log.exception.mock_calls[-1].assert_called_with( + 'Future raised an exception and nobody caught it: %r', r_fut, + exc_info=(ValueError, exc, None)) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, future): + fn(future) + + def set_log_level(self, val): + pass + + def get_log_level(self): + return logging.CRITICAL + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(event_loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..31d0df86 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,487 @@ +"""Http client functional tests.""" + +import gc +import io +import os.path +import http.cookies +import unittest + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + gc.collect() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth))) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + r.close() + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2))) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'))) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'})) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'})) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate')) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data])) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'))) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'))) + self.assertEqual(r.status, 200) + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'))) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'))) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), timeout=0.1)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('cookies'), session=s)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..77a2a7ef --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpRequest, HttpResponse + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response.transport = self.transport + self.response.close() + self.assertIsNone(self.response.transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..c8dfb1bc --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,510 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..9455426a --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,384 @@ +"""Tests for http/protocol.py""" + +import unittest +import unittest.mock +import zlib + +from tulip.http import protocol + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200, close=True) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], list(msg.headers)) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], list(msg.headers)) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..c0f09603 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,248 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors + + +class HttpServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_http_status_exception(self): + exc = errors.HttpStatusException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol() + self.assertIsNone(srv._request_handle) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handle) + + def test_data_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', bytes(srv.stream._buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', bytes(srv.stream._buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream._eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol() + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + handle = srv._request_handle + srv.connection_lost(None) + + self.assertIsNone(srv._request_handle) + self.assertTrue(handle.cancelled()) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handle) + + def test_close(self): + srv = server.ServerHttpProtocol() + self.assertFalse(srv._closing) + + srv.close() + self.assertTrue(srv._closing) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(handle.called) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + + called = False + + @tulip.coroutine + def coro(message, payload): + nonlocal called + called = True + srv.eof_received() + srv.close() + + srv.handle_request = coro + srv.connection_made(transport) + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(called) + + def test_handle_close(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + srv.close() + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.task + def cancel(): + srv._request_handle.cancel() + + srv.close() + self.loop.run_until_complete( + tulip.wait([srv._request_handle, cancel()])) + self.assertTrue(log.debug.called) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + + def side_effect(*args): + srv.close() + srv.handle_error.side_effect = side_effect + + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') + + self.loop.run_until_complete(srv._request_handle) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + srv.close() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handle) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol() + srv.connection_made(transport) + srv.connection_lost(None) + close = srv.close = unittest.mock.Mock() + + srv.handle_error(300) + self.assertTrue(close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..1b86c56d --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,133 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..bd89b75b --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,426 @@ +"""Tests for http/websocket.py""" + +import base64 +import hashlib +import os +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket, protocol, errors + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, self.message, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..1145f273 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,241 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol + + +class HttpWsgiServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(self.loop) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.payload = b'data' + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() + + def tearDown(self): + self.loop.close() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future() + f1.set_result(b'data') + fut = tulip.Future() + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader() + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..a2e03381 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,803 @@ +"""Tests for lock.py""" + +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + lock = locks.Lock() + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.event_loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock() + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock() + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_once() + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result)) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + lock.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_timeout(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete( + lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + + self.event_loop.call_later(0.01, lock.release) + acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire()) + tasks.Task(lock.acquire()) + acquire_task = tasks.Task(lock.acquire(0.01)) + tasks.Task(lock.acquire()) + + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock() + self.assertTrue( + self.event_loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire()) + self.event_loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock() + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock() + self.event_loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock() + + @tasks.task + def acquire_lock(): + return (yield from lock) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock() + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + ev = locks.EventWaiter() + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + t3 = tasks.Task(c3(result)) + + ev.set() + self.event_loop.run_once() + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.EventWaiter() + ev.set() + + res = self.event_loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + ev = locks.EventWaiter() + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + ev = locks.EventWaiter() + self.event_loop.call_later(0.01, ev.set) + acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + + def test_wait_timeout_mixed(self): + ev = locks.EventWaiter() + tasks.Task(ev.wait()) + tasks.Task(ev.wait()) + acquire_task = tasks.Task(ev.wait(0.1)) + tasks.Task(ev.wait()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter() + + wait = tasks.Task(ev.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter() + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter() + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result)) + self.event_loop.run_once() + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_wait(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue( + self.event_loop.run_until_complete(cond.acquire())) + cond.notify() + self.event_loop.run_once() + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_timeout(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + t0 = time.monotonic() + wait = self.event_loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_cancel(self): + cond = locks.Condition() + self.event_loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait()) + self.event_loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition() + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + presult = True + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_timeout(self): + cond = locks.Condition() + + result = [] + + predicate = unittest.mock.Mock() + predicate.return_value = False + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.1)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result)) + + t0 = time.monotonic() + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + def test_wait_for_unacquired(self): + cond = locks.Condition() + + # predicate can return true immediately + res = self.event_loop.run_until_complete( + cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition() + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + self.event_loop.run_once() + self.assertEqual([1], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition() + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + + self.event_loop.run_once() + self.assertEqual([], result) + + self.event_loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + self.event_loop.run_once() + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition() + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_repr(self): + sem = locks.Semaphore() + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore() + self.assertEqual(1, sem._value) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + res = self.event_loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3) + result = [] + + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue( + self.event_loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result)) + t2 = tasks.Task(c2(result)) + t3 = tasks.Task(c3(result)) + + self.event_loop.run_once() + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result)) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + self.event_loop.run_once() + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + def test_acquire_timeout(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + self.event_loop.call_later(0.01, sem.release) + acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + + def test_acquire_timeout_mixed(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire()) + tasks.Task(sem.acquire()) + acquire_task = tasks.Task(sem.acquire(0.1)) + tasks.Task(sem.acquire()) + + t0 = time.monotonic() + acquired = self.event_loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + total_time = (time.monotonic() - t0) + self.assertTrue(0.08 < total_time < 0.12) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire()) + self.event_loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore() + self.event_loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2) + + @tasks.task + def acquire_lock(): + return (yield from sem) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.event_loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..083e141c --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,598 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer() + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer() + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer() + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer() + read_task = tasks.Task(buffer.read()) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer() + read_task = tasks.Task(buffer.read()) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer() + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer() + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer() + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer() + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read()) + t2 = tasks.Task(set_err()) + + self.loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.skipuntil(b'\n') + try: + next(p) + except StopIteration: + pass + self.assertEqual(b'', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer() + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer() + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..959cd2ba --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,329 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = unittest.mock.Mock() + self.sock = unittest.mock.Mock(socket.socket) + self.protocol = unittest.mock.Mock(tulip.Protocol) + + def test_ctor(self): + fut = tulip.Future() + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + self.loop.call_soon.mock_calls[0].assert_called_with(tr._loop_reading) + self.loop.call_soon.mock_calls[1].assert_called_with( + self.protocol.connection_made, tr) + self.loop.call_soon.mock_calls[2].assert_called_with( + fut.set_result, None) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future() + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future() + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + self.assertEqual(tr._conn_lost, 1) + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future() + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future() + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + self.loop.reset_mock() + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr.abort() + tr._fatal_error.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + self.loop.reset_mock() + tr.close() + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + self.assertTrue(tr._closing) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + self.loop.reset_mock() + tr.close() + self.assertFalse(self.loop.call_soon.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + self.loop.reset_mock() + tr.close() + self.assertFalse(self.loop.call_soon.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._fatal_error(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error_2(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._fatal_error(None) + + self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..714465d9 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,380 @@ +"""Tests for queues.py""" + +import unittest +import queue + +from tulip import events +from tulip import locks +from tulip import queues +from tulip import tasks + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + q = queues.Queue() + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + @unittest.mock.patch('tulip.selectors.tulip_log') + def test_key_from_fd(self, m_log): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) + m_log.warning.assert_called_with('No key found for fd %r', 10) + + if hasattr(selectors.DefaultSelector, 'fileno'): + def test_fileno(self): + self.assertIsInstance(selectors.DefaultSelector().fileno(), int) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..832c3119 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,293 @@ +"""Tests for streams.py.""" + +import unittest + +from tulip import events +from tulip import streams +from tulip import tasks + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader() + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader() + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(30)) + + def cb(): + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader() + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.event_loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(1024)) + + def cb(): + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader() + read_task = tasks.Task(stream.read(-1)) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader() + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline()) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.event_loop.call_soon(cb) + + line = self.event_loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + stream = streams.StreamReader(7) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.event_loop.call_soon(cb) + + self.assertRaises( + ValueError, self.event_loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.event_loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader() + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader() + stream.feed_eof() + + line = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + self.event_loop.run_until_complete(stream.readline()) + + data = self.event_loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader() + stream.feed_data(self.DATA) + + data = self.event_loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.event_loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader() + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader() + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n)) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.event_loop.call_soon(cb) + + data = self.event_loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader() + stream.feed_data(b'line\n') + + data = self.event_loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, + self.event_loop.run_until_complete, + stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader() + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline()) + t2 = tasks.Task(set_err()) + + self.event_loop.run_until_complete(tasks.wait([t1, t2])) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py new file mode 100644 index 00000000..28ab6623 --- /dev/null +++ b/tests/subprocess_test.py @@ -0,0 +1,61 @@ +# NOTE: This is a hack. Andrew Svetlov is working in a proper +# subprocess management transport for use with +# connect_{read,write}_pipe(). + +"""Tests for subprocess_transport.py.""" + +import logging +import unittest + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import subprocess_transport + + +class MyProto(protocols.Protocol): + + def __init__(self): + self.state = 'INITIAL' + self.nbytes = 0 + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write_eof() + + def data_received(self, data): + logging.info('received: %r', data) + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + self.done.set_result(None) + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_unix_subprocess(self): + p = MyProto() + subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) + self.event_loop.run_until_complete(p.done) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..583674e8 --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,749 @@ +"""Tests for tasks.py.""" + +import concurrent.futures +import time +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.event_loop = events.new_event_loop() + events.set_event_loop(self.event_loop) + + def tearDown(self): + self.event_loop.close() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.Task(notmuch()) + self.event_loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._event_loop, self.event_loop) + + event_loop = events.new_event_loop() + t = tasks.Task(notmuch(), event_loop=event_loop) + self.assertIs(t._event_loop, event_loop) + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_func(self): + @tasks.task + def notmuch(): + return 'ko' + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_fut(self): + fut = futures.Future() + fut.set_result('ko') + + @tasks.task + def notmuch(): + return fut + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_repr(self): + @tasks.task + def notmuch(): + yield from [] + return 'abc' + + t = notmuch() + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = notmuch() + self.event_loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro()) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.task + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.task + def inner1(): + return 42 + + @tasks.task + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.event_loop.run_until_complete(t), 1042) + + def test_cancel(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 12 + + t = task() + self.event_loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.task + def task(): + yield + yield + return 12 + + t = task() + self.event_loop.run_once() # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_done_future(self): + fut1 = futures.Future() + fut2 = futures.Future() + fut3 = futures.Future() + + @tasks.task + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + yield from fut3 + + t = task() + self.event_loop.run_once() + fut1.set_result(None) + t.cancel() + self.event_loop.run_once() # process fut1 result, delay cancel + self.assertFalse(t.done()) + self.event_loop.run_once() # cancel fut2, but coro still alive + self.assertFalse(t.done()) + self.event_loop.run_once() # cancel fut3 + self.assertTrue(t.done()) + + self.assertEqual(fut1.result(), None) + self.assertTrue(fut2.cancelled()) + self.assertTrue(fut3.cancelled()) + self.assertTrue(t.cancelled()) + + def test_future_timeout(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + t = tasks.Task(coro(), timeout=0.1) + + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_future_timeout_catch(self): + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0) + return 12 + + class Cancelled(Exception): + pass + + @tasks.coroutine + def coro2(): + try: + yield from tasks.Task(coro(), timeout=0.1) + except futures.CancelledError: + raise Cancelled() + + self.assertRaises( + Cancelled, self.event_loop.run_until_complete, coro2()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + return 12 + + t = tasks.Task(task()) + self.assertRaises( + futures.CancelledError, + self.event_loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1) + x += 1 + if x == 2: + self.event_loop.stop() + + t = tasks.Task(task()) + t0 = time.monotonic() + self.assertRaises( + futures.InvalidStateError, + self.event_loop.run_until_complete, t) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.18 <= t1-t0 <= 0.22) + self.assertEqual(x, 2) + + def test_timeout(self): + @tasks.task + def task(): + yield from tasks.sleep(10.0) + return 42 + + t = task() + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.event_loop.run_until_complete, t, 0.1) + t1 = time.monotonic() + self.assertFalse(t.done()) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_timeout_not(self): + @tasks.task + def task(): + yield from tasks.sleep(0.1) + return 42 + + t = task() + t0 = time.monotonic() + r = self.event_loop.run_until_complete(t, 10.0) + t1 = time.monotonic() + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertTrue(0.08 <= t1-t0 <= 0.12) + + def test_wait(self): + a = tasks.Task(tasks.sleep(0.1)) + b = tasks.Task(tasks.sleep(0.15)) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertEqual(res, 42) + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + # TODO: Test different return_when values. + + def test_wait_first_completed(self): + a = tasks.Task(tasks.sleep(10.0)) + b = tasks.Task(tasks.sleep(0.1)) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1()) + b = tasks.Task(coro2()) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0)) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.Task(tasks.wait( + [b, a], return_when=tasks.FIRST_EXCEPTION)) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_first_exception_in_wait(self): + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0)) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01) + raise ZeroDivisionError('err') + + b = tasks.Task(exc()) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION) + + done, pending = self.event_loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + + def test_wait_with_exception(self): + a = tasks.Task(tasks.sleep(0.1)) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper()) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_wait_with_timeout(self): + a = tasks.Task(tasks.sleep(0.1)) + b = tasks.Task(tasks.sleep(0.15)) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + t0 = time.monotonic() + self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.1) + self.assertTrue(t1-t0 <= 0.13) + + def test_as_completed(self): + @tasks.coroutine + def sleeper(dt, x): + yield from tasks.sleep(dt) + return x + + a = sleeper(0.1, 'a') + b = sleeper(0.1, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a]): + values.append((yield from f)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.14) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 <= 0.01) + + def test_as_completed_with_timeout(self): + a = tasks.sleep(0.1, 'a') + b = tasks.sleep(0.15, 'b') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + t0 = time.monotonic() + res = self.event_loop.run_until_complete(tasks.Task(foo())) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.11) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + + def test_sleep(self): + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2) + res = yield from tasks.sleep(dt/2, arg) + return res + + t = tasks.Task(sleeper(0.1, 'yeah')) + t0 = time.monotonic() + self.event_loop.run_until_complete(t) + t1 = time.monotonic() + self.assertTrue(t1-t0 >= 0.09) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + + def test_sleep_cancel(self): + t = tasks.Task(tasks.sleep(10.0, 'yeah')) + + handle = None + orig_call_later = self.event_loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + self.event_loop.call_later = call_later + self.event_loop.run_once() + + self.assertFalse(handle.cancelled) + + t.cancel() + self.event_loop.run_once() + self.assertTrue(handle.cancelled) + + def test_task_cancel_sleeping_task(self): + sleepfut = None + + @tasks.task + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt) + try: + time.monotonic() + yield from sleepfut + finally: + time.monotonic() + + @tasks.task + def doit(): + sleeper = sleep(5000) + self.event_loop.call_later(0.1, sleeper.cancel) + try: + time.monotonic() + yield from sleeper + except futures.CancelledError: + time.monotonic() + return 'cancelled' + else: + return 'slept in' + + t0 = time.monotonic() + doer = doit() + self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + t1 = time.monotonic() + self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + + def test_task_cancel_waiter_future(self): + fut = futures.Future() + + @tasks.task + def coro(): + try: + yield from fut + except futures.CancelledError: + pass + + task = coro() + self.event_loop.run_once() + self.assertIs(task._fut_waiter, fut) + + task.cancel() + self.assertRaises( + futures.CancelledError, self.event_loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch()) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.event_loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args): + self.cb_added = False + super().__init__(*args) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut() + result = None + + @tasks.task + def wait_for_future(): + nonlocal result + result = yield from fut + + t = wait_for_future() + self.event_loop.run_once() + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_result_concurrent_future(self): + # Coroutine returns concurrent.futures.Future + + class Fut(concurrent.futures.Future): + def __init__(self): + self.cb_added = False + super().__init__() + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + c_fut = Fut() + + @tasks.coroutine + def notmuch(): + return (yield c_fut) + + task = tasks.Task(notmuch()) + self.event_loop.run_once() + self.assertTrue(c_fut.cb_added) + + res = object() + c_fut.set_result(res) + self.event_loop.run_once() + self.assertIs(res, task.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch()) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch()) + self.event_loop.run_once() + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, self.event_loop.run_once) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future() + + @tasks.task + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.event_loop.run_until_complete(task) + + self.assertTrue(fut.done()) + self.assertIs(fut.exception(), cm.exception) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.task + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.event_loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.event_loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future() + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func()) + t2 = tasks.Task(coro()) + res = self.event_loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..4b24b50b --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,45 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..b96719a8 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,592 @@ +"""Tests for unix_events.py.""" + +import errno +import io +import unittest +import unittest.mock + +try: + import signal +except ImportError: + signal = None + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unix_events.SelectorEventLoop() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.event_loop._check_signal, '1') + self.assertRaises( + ValueError, self.event_loop._check_signal, signal.NSIG + 1) + + unix_events.signal = None + + def restore_signal(): + unix_events.signal = signal + self.addCleanup(restore_signal) + + self.assertRaises( + RuntimeError, self.event_loop._check_signal, signal.SIGINT) + + def test_handle_signal_no_handler(self): + self.event_loop._handle_signal(signal.NSIG + 1, ()) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.assertIsInstance(h, events.Handle) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.event_loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) + self.event_loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.event_loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + fut.cancel() + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_eof(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.event_loop.remove_reader.assert_called_with(5) + self.protocol.eof_received.assert_called_with() + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_blocked(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.reset_mock() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_pause(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.pause() + self.event_loop.remove_reader.assert_called_with(5) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_resume(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.resume() + self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test_close_already_closing(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + @unittest.mock.patch('fcntl.fcntl') + def test__close(self, m_fcntl, m_read): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.event_loop.remove_reader.assert_called_with(5) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, err) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.event_loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + + @unittest.mock.patch('fcntl.fcntl') + def test_ctor_with_waiter(self, m_fcntl): + fut = futures.Future() + unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol, fut) + self.event_loop.call_soon.assert_called_with(fut.set_result, None) + fut.cancel() + + @unittest.mock.patch('fcntl.fcntl') + def test_can_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_no_data(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_buffer(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.assertFalse(self.event_loop.add_writer.called) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_write_err(self, m_fcntl, m_write, m_log): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.called) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_partial(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_again(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_empty(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.event_loop.remove_writer.called) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, err) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test__write_ready_closing(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + @unittest.mock.patch('fcntl.fcntl') + def test_abort(self, m_fcntl, m_write): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.event_loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test__call_connection_lost_with_err(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + @unittest.mock.patch('fcntl.fcntl') + def test_close_closing(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.event_loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + + @unittest.mock.patch('fcntl.fcntl') + def test_write_eof_pending(self, m_fcntl): + tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, + self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py new file mode 100644 index 00000000..381fb227 --- /dev/null +++ b/tests/winsocketpair_test.py @@ -0,0 +1,26 @@ +"""Tests for winsocketpair.py""" + +import unittest +import unittest.mock + +from tulip import winsocketpair + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = winsocketpair.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.winsocketpair.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..9de84cb0 --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,28 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .parsers import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + parsers.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..e21ea978 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,560 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_once(self, timeout=0): + """Run through all callbacks and all I/O polls once. + + Calling stop() will break out of this too. + """ + if self._running: + raise RuntimeError('Event loop is running.') + + self._running = True + try: + self._run_once(timeout) + except _StopError: + pass + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + if not isinstance(future, futures.Future): + if tasks.iscoroutine(future): + future = tasks.Task(future) + else: + assert False, 'A Future or coroutine is required' + + handle_called = False + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + future.add_done_callback(_raise_stop_error) + + if timeout is None: + self.run_forever() + else: + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + if handle_called: + raise futures.TimeoutError + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return an object with a cancel() method that can be used to + cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Callbacks scheduled in the past are passed on to call_soon(), + so these will be called in the order in which they were + registered rather than by time due. This is so you can't + cheat and insert yourself at the front of the ready queue by + using a negative time. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if delay <= 0: + return self.call_soon(callback, *args) + + handle = events.TimerHandle(time.monotonic() + delay, callback, args) + heapq.heappush(self._scheduled, handle) + return handle + + def call_repeatedly(self, interval, callback, *args): + """Call a callback every 'interval' seconds.""" + assert interval > 0, 'Interval must be > 0: {!r}'.format(interval) + + # TODO: What if callback is already a Handle? + def wrapper(): + callback(*args) # If this fails, the chain is broken. + handle._when = time.monotonic() + interval + heapq.heappush(self._scheduled, handle) + + handle = events.TimerHandle(time.monotonic() + interval, wrapper, ()) + heapq.heappush(self._scheduled, handle) + return handle + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback.cancelled: + f = futures.Future() + f.set_result(None) + return f + callback, args = callback.callback, callback.args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return self.wrap_future(executor.submit(callback, *args)) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + yield from self.sock_connect(sock, address) + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise socket.error('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future() + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except socket.error as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + @tasks.task + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise socket.error('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except socket.error as err: + raise socket.error(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future() + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + if handle.cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def wrap_future(self, future): + """XXX""" + if isinstance(future, futures.Future): + return future # Don't wrap our own type of Future. + new_future = futures.Future(event_loop=self) + future.add_done_callback( + lambda future: + self.call_soon_threadsafe(new_future._copy_state, future)) + return new_future + + def _run_once(self, timeout=None): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # TODO: Break each of these into smaller pieces. + # TODO: Refactor to separate the callbacks from the readers/writers. + # TODO: An alternative API would be to do the *minimal* amount + # of work, e.g. one callback or one I/O poll. + + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0].cancelled: + heapq.heappop(self._scheduled) + + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0].when + deadline = max(0, when - time.monotonic()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + t0 = time.monotonic() + event_list = self._selector.select(timeout) + t1 = time.monotonic() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = time.monotonic() + while self._scheduled: + handle = self._scheduled[0] + if handle.when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle.cancelled: + handle.run() + + # Future.__del__ uses log level + _log_level = logging.WARNING + + def set_log_level(self, val): + self._log_level = val + + def get_log_level(self): + return self._log_level diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..c8f2401c --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,398 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import sys +import threading +import socket + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + @property + def callback(self): + return self._callback + + @property + def args(self): + return self._args + + @property + def cancelled(self): + return self._cancelled + + def cancel(self): + self._cancelled = True + + def run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + + +def make_handle(callback, args): + if isinstance(callback, Handle): + assert not args + return callback + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + @property + def when(self): + return self._when + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_once(self, timeout=None): + """Run one complete cycle of the event loop. + + TODO: Deprecate this. + """ + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_repeatedly(self, interval, callback, *args): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def wrap_future(self, future): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """Creates a TCP server bound to host and port and return + a list of socket objects which will later be handled by + protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + #def spawn_subprocess(self, protocol_factory, pipe): + # raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return a Handle. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, event_loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _event_loop = None + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._event_loop is None and + threading.current_thread().name == 'MainThread'): + self._event_loop = self.new_event_loop() + return self._event_loop + + def set_event_loop(self, event_loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + assert event_loop is None or isinstance(event_loop, AbstractEventLoop) + self._event_loop = event_loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(event_loop): + """XXX""" + get_event_loop_policy().set_event_loop(event_loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..e3baf15e --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,310 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', + ] + +import concurrent.futures._base +import io +import logging +import traceback + +from . import events +from .log import tulip_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout_handle = None + _event_loop = None + + _blocking = False # proper use of future (yield vs yield from) + + # result of the future has to be requested + _debug_stack = None + _debug_result_requested = False + _debug_warn_result_requested = True + + def __init__(self, *, event_loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if event_loop is None: + self._event_loop = events.get_event_loop() + else: + self._event_loop = event_loop + self._callbacks = [] + + if timeout is not None: + self._timeout_handle = self._event_loop.call_later( + timeout, self.cancel) + + if __debug__: + if self._event_loop.get_log_level() <= STACK_DEBUG: + out = io.StringIO() + traceback.print_stack(file=out) + self._debug_stack = out.getvalue() + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._event_loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + def running(self): + """Always return False. + + This method is for compatibility with concurrent.futures; we don't + have a running state. + """ + return False # We don't have a running state. + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self, timeout=0): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + Timeout values other than 0 are not supported. + """ + if __debug__: + self._debug_result_requested = True + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + if self._exception is not None: + raise self._exception + return self._result + + def exception(self, timeout=0): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. Timeout values other than 0 are not supported. + """ + if __debug__: + self._debug_result_requested = True + if timeout != 0: + raise InvalidTimeoutError + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._event_loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """ Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + if __debug__: + def __del__(self): + if (not self._debug_result_requested and + self._state != _CANCELLED and + self._event_loop is not None): + + level = self._event_loop.get_log_level() + if level > logging.WARNING: + return + + r_self = repr(self) + + if self._state == _PENDING: + tulip_log.error( + 'Future abandoned before completion: %s', r_self) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) + + else: + exc = self._exception + if exc is not None: + tulip_log.exception( + 'Future raised an exception and ' + 'nobody caught it: %s', r_self, + exc_info=(exc.__class__, exc, exc.__traceback__)) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) + elif self._debug_warn_result_requested: + tulip_log.error( + 'Future result has not been requested: %s', r_self) + if (self._debug_stack and level <= STACK_DEBUG): + tulip_log.error(self._debug_stack) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..a1432dee --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,16 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .session import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + session.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..4c797b8c --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,560 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +import tulip.http + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None, + session=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + + # connection timeout + try: + resp = yield from tulip.Task(conn, timeout=timeout) + except tulip.CancelledError: + raise tulip.TimeoutError from None + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + + try: + resp = req.send(transport) + yield from resp.start(p, transport) + except: + transport.close() + raise + + return resp + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except ValueError: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except ValueError: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + self.update_cookies(cookies) + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + compress = enc + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = str(len(self.body)) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['transfer-encoding'] = 'chunked' + + chunked = chunked if type(chunked) is int else 8196 + else: + if 'chunked' in te: + chunked = 8196 + else: + chunked = None + self.headers['content-length'] = str(len(self.body)) + + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + message = None # RawResponseStatus object + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + cookies = None # Response cookies (Set-Cookie) + + content = None # payload stream + stream = None # input stream + transport = None # current transport + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + self._content = None + + def __del__(self): + self.close() + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self.stream = stream + self.transport = transport + + httpstream = stream.set_parser(tulip.http.http_response_parser()) + + # read response + self.message = yield from httpstream.read() + + # response status + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason + + # headers + for hdr, val in self.message.headers: + self.add_header(hdr, val) + + # payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) + + return self + + def close(self): + if self.transport is not None: + self.transport.close() + self.transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..24032337 --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,44 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpStatusException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + + +class HttpStatusException(HttpException): + + def __init__(self, code, headers=None, message=''): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..0a2959cf --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,747 @@ +"""Http related helper utils.""" + +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from tulip.http import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() + +RESPONSES = http.server.BaseHTTPRequestHandler.responses + + +RawRequestMessage = collections.namedtuple( + 'RawRequestLine', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) + + +RawResponseMessage = collections.namedtuple( + 'RawResponseStatus', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) + + +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield + + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + # request line + line = lines[0] + try: + method, path, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None + + +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() + + lines_idx = 1 + line = lines[1] + + while line not in ('\r\n', '\n'): + header_length = len(line) + + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None + + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line[0] in CONTINUATION + + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') + + value = value.strip() + + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + headers.append((name, value)) + + return headers, close_conn, encoding + + +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield + + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) + + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) + + out.feed_eof() + + +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: # eof marker + break + + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) + + # toss the CRLF at the end of the chunk + yield from buf.skip(2) + + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) + + +class DeflateBuffer: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, out, encoding): + self.out = out + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except Exception: + raise errors.IncompleteRead(b'') from None + + if chunk: + self.out.feed_data(chunk) + + def feed_eof(self): + self.out.feed_data(self.zlib.flush()) + if not self.zlib.eof: + raise errors.IncompleteRead(b'') + + self.out.feed_eof() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + self.keepalive = None + + self.chunked = False + self.length = None + self.upgrade = False + self.headers = collections.deque() + self.headers_sent = False + + def force_close(self): + self.closing = True + self.keepalive = False + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + if self.keepalive is None: + return not self.closing + else: + return self.keepalive + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower() + # handle websocket + if 'upgrade' in val: + self.upgrade = True + # connection keep-alive + elif 'close' in val: + self.keepalive = False + elif 'keep-alive' in val: + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + self._add_default_headers() + + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) + + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif not self.closing if self.keepalive is None else self.keepalive: + connection = 'keep-alive' + else: + connection = 'close' + + if self.chunked: + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) + + self.headers.appendleft(('CONNECTION', connection)) + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(tulip.EofStream()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except tulip.EofStream: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except tulip.EofStream: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except tulip.EofStream: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, path, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.path = path + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, path, http_version) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..8c816e94 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,183 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +from tulip.http import errors + + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + """ + _closing = False + _request_count = 0 + _request_handle = None + + def __init__(self, *, log=logging, debug=False, **kwargs): + self.__dict__.update(kwargs) + self.log = log + self.debug = debug + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.StreamBuffer() + self._request_handle = self.start() + + def data_received(self, data): + self.stream.feed_data(data) + + def connection_lost(self, exc): + if self._request_handle is not None: + self._request_handle.cancel() + self._request_handle = None + + def eof_received(self): + self.stream.feed_eof() + + def close(self): + self._closing = True + + def log_access(self, status, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.task + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. In case of any error connection is being closed. + """ + + while self._request_handle is not None: + info = None + message = None + self._request_count += 1 + + try: + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) + + message = yield from httpstream.read() + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._closing: + self.transport.close() + break + + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + try: + if self._request_handle is None: + # client has been disconnected during writing. + return + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.close() + + def handle_request(self, message, payload): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=message.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.close() + self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..baf19dba --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,101 @@ +"""client session support.""" + +__all__ = ['Session'] + +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..71784bcc --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,227 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import base64 +import binascii +import collections +import hashlib +import struct +from tulip.http import errors + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) + + +def do_handshake(message, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + do any IO.""" + headers = dict(((hdr, val) + for hdr, val in message.headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException('No WebSocket UPGRADE hdr: {}'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..a612a0aa --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,221 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) + + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, message, payload): + """Handle a single HTTP request""" + + if self.readpayload: + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + payload = wsgiinput + + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if not resp.keep_alive(): + self.close() + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, message): + self.transport = transport + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + self.response = tulip.http.Response( + self.transport, status_code, + self.message.version, self.message.should_close) + self.response.add_headers(*headers) + self.response._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..ff841442 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,434 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections +import time + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self): + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self): + self._waiters = collections.deque() + self._value = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self): + super().__init__() + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = time.monotonic() + waittime + else: + waittime = endtime - time.monotonic() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + self._event_loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..f5a7845a --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,397 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self): + self._buffer = ParserBuffer() + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer() + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + if self._parser is None: + return + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self): + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(False) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future() + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def _shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len + size = end - self.offset + if limit is not None and size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - (end - self.offset) + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..c142ff45 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,199 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +from . import base_events +from . import constants +from . import transports +from .log import tulip_log + + +class _ProactorSocketTransport(transports.Transport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._loop = loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._loop.call_soon(self._protocol.connection_made, self) + self._loop.call_soon(self._loop_reading) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + + data = fut.result() # deliver data later in "finally" clause + if not data: + self._read_fut = None + return + + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except (ConnectionAbortedError, ConnectionResetError) as exc: + if not self._closing: + self._fatal_error(exc) + except OSError as exc: + self._fatal_error(exc) + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if not self._write_fut: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + self._write_fut = None + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + self._write_fut = self._loop._proactor.send(self._sock, data) + except OSError as exc: + self._conn_lost += 1 + self._fatal_error(exc) + else: + self._write_fut.add_done_callback(self._loop_writing) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._fatal_error(None) + + def close(self): + self._closing = True + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: # XXX + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..593ee745 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,78 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..a87a8557 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,291 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0): + self._event_loop = events.get_event_loop() + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise queue.Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise queue.Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise queue.Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise queue.Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise queue.Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._event_loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future( + event_loop=self._event_loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise queue.Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise queue.Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise queue.Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0): + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter() + self._finished.set() + super().__init__(maxsize=maxsize) + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..0dd492aa --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,706 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +# Errno values indicating the connection was disconnected. +# Comment out _DISCONNECTED as never used +# TODO: make sure that errors has processed properly +# for now we have no exception clsses for ENOTCONN and EBADF +# _DISCONNECTED = frozenset((errno.ECONNRESET, +# errno.ENOTCONN, +# errno.ESHUTDOWN, +# errno.ECONNABORTED, +# errno.EPIPE, +# errno.EBADF, +# )) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + return handle + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback. Return a Handle instance.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + return handle + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future() + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + fut.set_result(data) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future() + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future() + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise socket.error(err, 'Connect call failed') + fut.set_result(None) + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future() + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + fut.set_result((conn, address)) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader.cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer.cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorSocketTransport(transports.Transport): + + def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._sock.fileno(), self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._sock.fileno()) + self._protocol.eof_received() + + def write(self, data): + assert isinstance(data, (bytes, bytearray)), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except socket.error as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._sock.fileno()) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._sock.fileno()) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._event_loop.remove_writer(self._sock.fileno()) + self._event_loop.remove_reader(self._sock.fileno()) + self._buffer.clear() + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + + +class _SelectorSslTransport(transports.Transport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, extra=None): + super().__init__(extra) + + self._loop = loop + self._rawsock = rawsock + self._protocol = protocol + if server_side: + assert isinstance(sslcontext, + ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + self._sslcontext = sslcontext + self._waiter = waiter + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + self._sslsock = sslsock + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._extra['socket'] = sslsock + + self._on_handshake() + + def _on_handshake(self): + fd = self._sslsock.fileno() + try: + self._sslsock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(fd, self._on_handshake) + return + except Exception as exc: + self._sslsock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sslsock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(fd) + self._loop.remove_writer(fd) + self._loop.add_reader(fd, self._on_ready) + self._loop.add_writer(fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # Maybe we're already closed... + fd = self._sslsock.fileno() + if fd < 0: + return + + # First try reading. + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + # TODO: Don't close when self._buffer is non-empty. + assert not self._buffer + self._loop.remove_reader(fd) + self._loop.remove_writer(fd) + self._sslsock.close() + self._protocol.connection_lost(None) + return + + # Now try writing, if there's anything to write. + if not self._buffer: + return + + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sslsock.send(data) + except ssl.SSLWantReadError: + n = 0 + except ssl.SSLWantWriteError: + n = 0 + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + elif self._closing: + self._loop.remove_writer(self._sslsock.fileno()) + self._sslsock.close() + self._protocol.connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._loop.remove_reader(self._sslsock.fileno()) + if not self._buffer: + self._loop.call_soon(self._protocol.connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._loop.remove_writer(self._sslsock.fileno()) + self._loop.remove_reader(self._sslsock.fileno()) + self._buffer = [] + self._loop.call_soon(self._protocol.connection_lost, exc) + + +class _SelectorDatagramTransport(transports.DatagramTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, sock, protocol, address=None, extra=None): + super().__init__(extra) + self._extra['socket'] = sock + self._event_loop = event_loop + self._sock = sock + self._fileno = sock.fileno() + self._protocol = protocol + self._address = address + self._buffer = collections.deque() + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes) + assert not self._closing + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._conn_lost += 1 + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._event_loop.add_writer(self._fileno, self._sendto_ready) + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._conn_lost += 1 + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + + if not self._buffer: + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + + def abort(self): + self._close(None) + + def close(self): + self._closing = True + self._event_loop.remove_reader(self._fileno) + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.remove_reader(self._fileno) + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..388df25f --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,426 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warning('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..51028ca7 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,147 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader'] + +import collections + +from . import futures +from . import tasks + + +class StreamReader: + + def __init__(self, limit=2**16): + self.limit = limit # Max line length. (Security feature.) + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future() + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future() + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py new file mode 100644 index 00000000..e790d285 --- /dev/null +++ b/tulip/subprocess_transport.py @@ -0,0 +1,156 @@ +# NOTE: This is a hack. Andrew Svetlov is working in a proper +# subprocess management transport for use with +# connect_{read,write}_pipe(). + +import fcntl +import os +import traceback + +from . import transports +from . import events +from .log import tulip_log + + +class UnixSubprocessTransport(transports.Transport): + """Transport class managing a subprocess. + + TODO: Separate this into something that just handles pipe I/O, + and something else that handles pipe setup, fork, and exec. + """ + + def __init__(self, protocol, args): + self._protocol = protocol # Not a factory! :-) + self._args = args # args[0] must be full path of binary. + self._event_loop = events.get_event_loop() + self._buffer = [] + self._eof = False + rstdin, self._wstdin = os.pipe() + self._rstdout, wstdout = os.pipe() + + # TODO: This is incredibly naive. Should look at + # subprocess.py for all the precautions around fork/exec. + pid = os.fork() + if not pid: + # Child. + try: + os.dup2(rstdin, 0) + os.dup2(wstdout, 1) + # TODO: What to do with stderr? + os.execv(args[0], args) + except: + try: + traceback.print_traceback() + finally: + os._exit(127) + + # Parent. + os.close(rstdin) + os.close(wstdout) + _setnonblocking(self._wstdin) + _setnonblocking(self._rstdout) + self._event_loop.call_soon(self._protocol.connection_made, self) + self._event_loop.add_reader(self._rstdout, self._stdout_callback) + + def write(self, data): + assert not self._eof + assert isinstance(data, bytes), repr(data) + if not data: + return + + if not self._buffer: + # Attempt to write it right away first. + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + if n == len(data): + return + elif n: + data = data[n:] + self._event_loop.add_writer(self._wstdin, self._stdin_callback) + self._buffer.append(data) + + def write_eof(self): + assert not self._eof + assert self._wstdin >= 0 + self._eof = True + if not self._buffer: + self._event_loop.remove_writer(self._wstdin) + os.close(self._wstdin) + self._wstdin = -1 + self._maybe_cleanup() + + def close(self): + if not self._eof: + self.write_eof() + self._maybe_cleanup() + + def _fatal_error(self, exc): + tulip_log.error('Fatal error: %r', exc) + if self._rstdout >= 0: + os.close(self._rstdout) + self._rstdout = -1 + if self._wstdin >= 0: + os.close(self._wstdin) + self._wstdin = -1 + self._eof = True + self._buffer = None + self._maybe_cleanup(exc) + + _conn_lost_called = False + + def _maybe_cleanup(self, exc=None): + if (self._wstdin < 0 and + self._rstdout < 0 and + not self._conn_lost_called): + self._conn_lost_called = True + self._event_loop.call_soon(self._protocol.connection_lost, exc) + + def _stdin_callback(self): + data = b''.join(self._buffer) + assert data, "Data shold not be empty" + + self._buffer = [] + try: + n = os.write(self._wstdin, data) + except BlockingIOError: + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n >= len(data): + self._event_loop.remove_writer(self._wstdin) + if self._eof: + os.close(self._wstdin) + self._wstdin = -1 + self._maybe_cleanup() + return + + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def _stdout_callback(self): + try: + data = os.read(self._rstdout, 1024) + except BlockingIOError: + pass + else: + if data: + self._event_loop.call_soon(self._protocol.data_received, data) + else: + self._event_loop.remove_reader(self._rstdout) + os.close(self._rstdout) + self._rstdout = -1 + self._event_loop.call_soon(self._protocol.eof_received) + self._maybe_cleanup() + + +def _setnonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..8dfd73a3 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,340 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'as_completed', 'sleep', + ] + +import concurrent.futures +import functools +import inspect +import time + +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + if inspect.isgeneratorfunction(func): + coro = func + else: + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + def task_wrapper(*args, **kwds): + return Task(coro(*args, **kwds)) + + return task_wrapper + + +_marker = object() + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # disable "Future result has not been requested" warning message. + _debug_warn_result_requested = False + + def __init__(self, coro, event_loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(event_loop=event_loop, timeout=timeout) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._event_loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done() or self._must_cancel: + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + if self._fut_waiter is not None: + return self._fut_waiter.cancel() + else: + self._event_loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=_marker, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + + # We'll call either coro.throw(exc) or coro.send(value). + # Task cancel has to be delayed if current waiter future is done. + if self._must_cancel and exc is None and value is _marker: + exc = futures.CancelledError + + coro = self._coro + value = None if value is _marker else value + self._fut_waiter = None + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + if not result._blocking: + result.set_exception( + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + + # task cancellation has been delayed. + if self._must_cancel: + self._fut_waiter.cancel() + + elif isinstance(result, concurrent.futures.Future): + # This ought to be more efficient than wrap_future(), + # because we don't create an extra Future. + result.add_done_callback( + lambda future: + self._event_loop.call_soon_threadsafe( + self._wakeup, future)) + else: + if inspect.isgenerator(result): + self._event_loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + if result is not None: + self._event_loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + else: + self._event_loop.call_soon(self._step_maybe) + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +# Even though this *is* a @coroutine, we don't mark it as such! +def wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + fs = _wrap_coroutines(fs) + return _wait(fs, timeout, return_when) + + +@coroutine +def _wait(fs, timeout=None, return_when=ALL_COMPLETED): + """Internal helper: Like wait() but does not wrap coroutines.""" + done, pending = set(), set() + + errors = 0 + for f in fs: + if f.done(): + done.add(f) + if not f.cancelled() and f.exception() is not None: + errors += 1 + else: + pending.add(f) + + if (not pending or + timeout is not None and timeout <= 0 or + return_when == FIRST_COMPLETED and done or + return_when == FIRST_EXCEPTION and errors): + return done, pending + + # Will always be cancelled eventually. + bail = futures.Future(timeout=timeout) + + def _on_completion(fut): + pending.remove(fut) + done.add(fut) + if (not pending or + return_when == FIRST_COMPLETED or + (return_when == FIRST_EXCEPTION and + not fut.cancelled() and + fut.exception() is not None)): + bail.cancel() + + for f in pending: + f.remove_done_callback(_on_completion) + + for f in pending: + f.add_done_callback(_on_completion) + try: + yield from bail + except futures.CancelledError: + pass + + really_done = set(f for f in pending if f.done()) + if really_done: + done.update(really_done) + pending.difference_update(really_done) + + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + deadline = None + if timeout is not None: + deadline = time.monotonic() + timeout + + done = None # Make nonlocal happy. + fs = _wrap_coroutines(fs) + + while fs: + if deadline is not None: + timeout = deadline - time.monotonic() + + @coroutine + def _wait_for_some(): + nonlocal done, fs + done, fs = yield from _wait(fs, timeout=timeout, + return_when=FIRST_COMPLETED) + if not done: + fs = set() + raise futures.TimeoutError() + return done.pop().result() # May raise. + + yield Task(_wait_for_some()) + for f in done: + yield f + + +def _wrap_coroutines(fs): + """Internal helper to process an iterator of Futures and coroutines. + + Returns a set of Futures. + """ + wrapped = set() + for f in fs: + if not isinstance(f, futures.Future): + assert iscoroutine(f) + f = Task(f) + wrapped.add(f) + return wrapped + + +@coroutine +def sleep(delay, result=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future() + h = future._event_loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..3cd5ba95 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,259 @@ +"""Utilities shared by tests.""" + +import cgi +import contextlib +import gc +import email.parser +import http.server +import json +import logging +import io +import os +import re +import socket +import sys +import threading +import traceback +import urllib.parse +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client + + +if sys.platform == 'win32': # pragma: no cover + from .winsocketpair import socketpair +else: + from socket import socketpair # pragma: no cover + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def handle_request(self, message, payload): + if properties.get('noresponse', False): + return + + if router is not None: + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router(properties, self.transport, message, bytes(body)) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + self.transport.close() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + thread_loop.set_log_level(logging.CRITICAL) + tulip.set_event_loop(thread_loop) + + socks = thread_loop.run_until_complete( + thread_loop.start_serving( + TestHttpServer, host, port, ssl=sslcontext)) + + waiter = tulip.Future() + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) + + thread_loop.run_until_complete(waiter) + thread_loop.stop() + gc.collect() + + fut = tulip.Future() + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, props, transport, message, payload): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in message.headers: + self._headers.add_header(hdr, val) + + self._props = props + self._transport = transport + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except Exception: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() + + # keep-alive + if not response.keep_alive(): + self._transport.close() diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..a9ec07a0 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,134 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..87514ef1 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,312 @@ +"""Selector eventloop for Unix with signal handling.""" + +import errno +import fcntl +import os +import socket +import sys + +try: + import signal +except ImportError: # pragma: no cover + signal = None + +from . import constants +from . import events +from . import selector_events +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + return handle + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle.cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self.call_soon_threadsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if signal is None: + raise RuntimeError('Signals are not supported') + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._event_loop.add_reader(self._fileno, self._read_ready) + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._event_loop.remove_reader(self._fileno) + self._protocol.eof_received() + + def pause(self): + self._event_loop.remove_reader(self._fileno) + + def resume(self): + self._event_loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._event_loop.remove_reader(self._fileno) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._event_loop = event_loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._event_loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._event_loop.call_soon(waiter.set_result, None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._event_loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, "Data should not be empty" + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + else: + if n == len(data): + self._event_loop.remove_writer(self._fileno) + if self._closing: + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._event_loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + self._buffer.clear() + self._event_loop.remove_writer(self._fileno) + self._event_loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..bede2b5e --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,157 @@ +"""Selector and proactor eventloops for Windows.""" + +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import winsocketpair +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return winsocketpair.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return winsocketpair.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSARecv(conn.fileno(), nbytes, flags) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + ov.WSASend(conn.fileno(), buf, flags) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + _overlapped.BindLocal(conn.fileno(), len(address)) + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = futures.Future() + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py new file mode 100644 index 00000000..59c8aecc --- /dev/null +++ b/tulip/winsocketpair.py @@ -0,0 +1,34 @@ +"""A socket pair usable as a self-pipe, for Windows. + +Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. +""" + +import socket +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('winsocketpair is win32 only') + + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """Emulate the Unix socketpair() function on Windows.""" + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) From 7548bf3d130f8f1bab5ab0af3e67ca6809afb816 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 2 May 2013 15:51:03 -0700 Subject: [PATCH 0456/1502] Remove call_repeatedly(). Add time() and call_at(). add_*_handler() no longer returns Handles, call_soon() no longer accepts them. --- tests/base_events_test.py | 41 +++++---- tests/events_test.py | 169 +--------------------------------- tests/selector_events_test.py | 55 +++++++---- tests/unix_events_test.py | 7 +- tulip/base_events.py | 58 +++++------- tulip/events.py | 15 ++- tulip/selector_events.py | 8 +- tulip/unix_events.py | 4 +- 8 files changed, 102 insertions(+), 255 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index aeb1cd8c..6dfaeecf 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -45,25 +45,14 @@ def test_not_implemented(self): NotImplementedError, self.event_loop._make_write_pipe_transport, m, m) - def test_add_callback_handle(self): + def test__add_callback_handle(self): h = events.Handle(lambda: False, ()) self.event_loop._add_callback(h) self.assertFalse(self.event_loop._scheduled) self.assertIn(h, self.event_loop._ready) - def test_add_callback_timer(self): - when = time.monotonic() - - h1 = events.TimerHandle(when, lambda: False, ()) - h2 = events.TimerHandle(when+10.0, lambda: False, ()) - - self.event_loop._add_callback(h2) - self.event_loop._add_callback(h1) - self.assertEqual([h1, h2], self.event_loop._scheduled) - self.assertFalse(self.event_loop._ready) - - def test_add_callback_cancelled_handle(self): + def test__add_callback_cancelled_handle(self): h = events.Handle(lambda: False, ()) h.cancel() @@ -113,13 +102,29 @@ def cb(): self.assertIn(h, self.event_loop._scheduled) self.assertNotIn(h, self.event_loop._ready) - def test_call_later_no_delay(self): + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.event_loop._process_events = unittest.mock.Mock() + self.event_loop.call_later(-1, cb, 'a') + self.event_loop.call_later(-2, cb, 'b') + self.event_loop.run_once() + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): def cb(): - pass + self.event_loop.stop() - h = self.event_loop.call_later(0, cb) - self.assertIn(h, self.event_loop._ready) - self.assertNotIn(h, self.event_loop._scheduled) + self.event_loop._process_events = unittest.mock.Mock() + when = self.event_loop.time() + 0.1 + self.event_loop.call_at(when, cb) + t0 = self.event_loop.time() + self.event_loop.run_forever() + t1 = self.event_loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) def test_run_once_in_executor_handle(self): def cb(): diff --git a/tests/events_test.py b/tests/events_test.py index aea1606f..cc694d33 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -208,17 +208,6 @@ def callback(arg): self.assertEqual(results, ['hello world']) self.assertTrue(0.09 <= t1-t0 <= 0.12) - def test_call_repeatedly(self): - results = [] - - def callback(arg): - results.append(arg) - - self.event_loop.call_repeatedly(0.03, callback, 'ho') - self.event_loop.call_later(0.1, self.event_loop.stop) - self.event_loop.run_forever() - self.assertEqual(results, ['ho', 'ho', 'ho']) - def test_call_soon(self): results = [] @@ -230,18 +219,6 @@ def callback(arg1, arg2): self.event_loop.run_forever() self.assertEqual(results, [('hello', 'world')]) - def test_call_soon_with_handle(self): - results = [] - - def callback(): - results.append('yeah') - self.event_loop.stop() - - handle = events.Handle(callback, ()) - self.assertIs(self.event_loop.call_soon(handle), handle) - self.event_loop.run_forever() - self.assertEqual(results, ['yeah']) - def test_call_soon_threadsafe(self): results = [] @@ -276,31 +253,6 @@ def callback(arg): self.event_loop.run_forever() self.assertEqual(results, ['hello', 'world']) - def test_call_soon_threadsafe_with_handle(self): - results = [] - - def callback(arg): - results.append(arg) - if len(results) >= 2: - self.event_loop.stop() - - handle = events.Handle(callback, ('hello',)) - - def run(): - self.assertIs( - self.event_loop.call_soon_threadsafe(handle), handle) - - t = threading.Thread(target=run) - self.event_loop.call_later(0.1, callback, 'world') - - t0 = time.monotonic() - t.start() - self.event_loop.run_forever() - t1 = time.monotonic() - t.join() - self.assertEqual(results, ['hello', 'world']) - self.assertTrue(t1-t0 >= 0.09) - def test_wrap_future(self): def run(arg): time.sleep(0.1) @@ -319,15 +271,6 @@ def run(arg): res = self.event_loop.run_until_complete(f2) self.assertEqual(res, 'yo') - def test_run_in_executor_with_handle(self): - def run(arg): - time.sleep(0.01) - return arg - handle = events.Handle(run, ('yo',)) - f2 = self.event_loop.run_in_executor(None, handle) - res = self.event_loop.run_until_complete(f2) - self.assertEqual(res, 'yo') - def test_reader_callback(self): r, w = test_utils.socketpair() bytes_read = [] @@ -353,57 +296,6 @@ def reader(): self.event_loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') - def test_reader_callback_with_handle(self): - r, w = test_utils.socketpair() - bytes_read = [] - - def reader(): - try: - data = r.recv(1024) - except BlockingIOError: - # Spurious readiness notifications are possible - # at least on Linux -- see man select. - return - if data: - bytes_read.append(data) - else: - self.assertTrue(self.event_loop.remove_reader(r.fileno())) - r.close() - - handle = events.Handle(reader, ()) - self.assertIs(handle, self.event_loop.add_reader(r.fileno(), handle)) - - self.event_loop.call_later(0.05, w.send, b'abc') - self.event_loop.call_later(0.1, w.send, b'def') - self.event_loop.call_later(0.15, w.close) - self.event_loop.call_later(0.16, self.event_loop.stop) - self.event_loop.run_forever() - self.assertEqual(b''.join(bytes_read), b'abcdef') - - def test_reader_callback_cancel(self): - r, w = test_utils.socketpair() - bytes_read = [] - - def reader(): - try: - data = r.recv(1024) - except BlockingIOError: - return - if data: - bytes_read.append(data) - if sum(len(b) for b in bytes_read) >= 6: - handle.cancel() - if not data: - r.close() - - handle = self.event_loop.add_reader(r.fileno(), reader) - self.event_loop.call_later(0.05, w.send, b'abc') - self.event_loop.call_later(0.1, w.send, b'def') - self.event_loop.call_later(0.15, w.close) - self.event_loop.call_later(0.16, self.event_loop.stop) - self.event_loop.run_forever() - self.assertEqual(b''.join(bytes_read), b'abcdef') - def test_writer_callback(self): r, w = test_utils.socketpair() w.setblocking(False) @@ -420,39 +312,6 @@ def remove_writer(): r.close() self.assertTrue(len(data) >= 200) - def test_writer_callback_with_handle(self): - r, w = test_utils.socketpair() - w.setblocking(False) - handle = events.Handle(w.send, (b'x'*(256*1024),)) - self.assertIs(self.event_loop.add_writer(w.fileno(), handle), handle) - - def remove_writer(): - self.assertTrue(self.event_loop.remove_writer(w.fileno())) - - self.event_loop.call_later(0.1, remove_writer) - self.event_loop.call_later(0.11, self.event_loop.stop) - self.event_loop.run_forever() - w.close() - data = r.recv(256*1024) - r.close() - self.assertTrue(len(data) >= 200) - - def test_writer_callback_cancel(self): - r, w = test_utils.socketpair() - w.setblocking(False) - - def sender(): - w.send(b'x'*256) - handle.cancel() - self.event_loop.stop() - - handle = self.event_loop.add_writer(w.fileno(), sender) - self.event_loop.run_forever() - w.close() - data = r.recv(1024) - r.close() - self.assertTrue(data == b'x'*256) - def test_sock_client_ops(self): with test_utils.run_test_server(self.event_loop) as httpd: sock = socket.socket() @@ -549,21 +408,6 @@ def my_handler(): # Removing again returns False. self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) - @unittest.skipIf(sys.platform == 'win32', 'Unix only') - def test_cancel_signal_handler(self): - # Cancelling the handler should remove it (eventually). - caught = 0 - - def my_handler(): - nonlocal caught - caught += 1 - - handle = self.event_loop.add_signal_handler(signal.SIGINT, my_handler) - handle.cancel() - os.kill(os.getpid(), signal.SIGINT) - self.event_loop.run_once() - self.assertEqual(caught, 0) - @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. @@ -1172,14 +1016,10 @@ def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") def test_reader_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") - def test_reader_callback_with_handle(self): - raise unittest.SkipTest("IocpEventLoop does not have add_reader()") def test_writer_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def test_writer_callback_with_handle(self): - raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_accept_connection_retry(self): raise unittest.SkipTest( "IocpEventLoop does not have _accept_connection()") @@ -1268,11 +1108,8 @@ def test_make_handle(self): def callback(*args): return args h1 = events.Handle(callback, ()) - h2 = events.make_handle(h1, ()) - self.assertIs(h1, h2) - self.assertRaises( - AssertionError, events.make_handle, h1, (1, 2)) + AssertionError, events.make_handle, h1, ()) @unittest.mock.patch('tulip.events.tulip_log') def test_callback_with_exception(self, log): @@ -1364,10 +1201,10 @@ def test_not_imlemented(self): NotImplementedError, ev_loop.stop) self.assertRaises( NotImplementedError, ev_loop.call_later, None, None) - self.assertRaises( - NotImplementedError, ev_loop.call_repeatedly, None, None) self.assertRaises( NotImplementedError, ev_loop.call_soon, None) + self.assertRaises( + NotImplementedError, ev_loop.time) self.assertRaises( NotImplementedError, ev_loop.call_soon_threadsafe, None) self.assertRaises( diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 9cde8b94..b70e6e50 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -383,38 +383,47 @@ def test__sock_accept_exception(self): def test_add_reader(self): self.event_loop._selector.get_info.side_effect = KeyError - h = self.event_loop.add_reader(1, lambda: True) + cb = lambda: True + self.event_loop.add_reader(1, cb) self.assertTrue(self.event_loop._selector.register.called) - self.assertEqual( - (1, selectors.EVENT_READ, (h, None)), - self.event_loop._selector.register.call_args[0]) + fd, mask, (r, w) = self.event_loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_READ, mask) + self.assertEqual(cb, r.callback) + self.assertEqual(None, w) def test_add_reader_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() self.event_loop._selector.get_info.return_value = ( selectors.EVENT_WRITE, (reader, writer)) - h = self.event_loop.add_reader(1, lambda: True) + cb = lambda: True + self.event_loop.add_reader(1, cb) self.assertTrue(reader.cancel.called) self.assertFalse(self.event_loop._selector.register.called) self.assertTrue(self.event_loop._selector.modify.called) - self.assertEqual( - (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), - self.event_loop._selector.modify.call_args[0]) + fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r.callback) + self.assertEqual(writer, w) def test_add_reader_existing_writer(self): writer = unittest.mock.Mock() self.event_loop._selector.get_info.return_value = ( selectors.EVENT_WRITE, (None, writer)) - h = self.event_loop.add_reader(1, lambda: True) + cb = lambda: True + self.event_loop.add_reader(1, cb) self.assertFalse(self.event_loop._selector.register.called) self.assertTrue(self.event_loop._selector.modify.called) - self.assertEqual( - (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (h, writer)), - self.event_loop._selector.modify.call_args[0]) + fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(cb, r.callback) + self.assertEqual(writer, w) def test_remove_reader(self): self.event_loop._selector.get_info.return_value = ( @@ -443,26 +452,32 @@ def test_remove_reader_unknown(self): def test_add_writer(self): self.event_loop._selector.get_info.side_effect = KeyError - h = self.event_loop.add_writer(1, lambda: True) + cb = lambda: True + self.event_loop.add_writer(1, cb) self.assertTrue(self.event_loop._selector.register.called) - self.assertEqual( - (1, selectors.EVENT_WRITE, (None, h)), - self.event_loop._selector.register.call_args[0]) + fd, mask, (r, w) = self.event_loop._selector.register.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE, mask) + self.assertEqual(None, r) + self.assertEqual(cb, w.callback) def test_add_writer_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() self.event_loop._selector.get_info.return_value = ( selectors.EVENT_READ, (reader, writer)) - h = self.event_loop.add_writer(1, lambda: True) + cb = lambda: True + self.event_loop.add_writer(1, cb) self.assertTrue(writer.cancel.called) self.assertFalse(self.event_loop._selector.register.called) self.assertTrue(self.event_loop._selector.modify.called) - self.assertEqual( - (1, selectors.EVENT_WRITE | selectors.EVENT_READ, (reader, h)), - self.event_loop._selector.modify.call_args[0]) + fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertEqual(1, fd) + self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) + self.assertEqual(reader, r) + self.assertEqual(cb, w.callback) def test_remove_writer(self): self.event_loop._selector.get_info.return_value = ( diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index b96719a8..536ac0a7 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -54,8 +54,11 @@ def test_add_signal_handler_setup_error(self, m_signal): def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG - h = self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) - self.assertIsInstance(h, events.Handle) + cb = lambda: True + self.event_loop.add_signal_handler(signal.SIGHUP, cb) + h = self.event_loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h.callback, cb) @unittest.mock.patch('tulip.unix_events.signal') def test_add_signal_handler_install_error(self, m_signal): diff --git a/tulip/base_events.py b/tulip/base_events.py index 4d890722..2d7fa22d 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -172,11 +172,15 @@ def is_running(self): """Returns running status of event loop.""" return self._running + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + def call_later(self, delay, callback, *args): """Arrange for a callback to be called at a given time. - Return an object with a cancel() method that can be used to - cancel the call. + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. The delay can be an int or float, expressed in seconds. It is always a relative time. @@ -185,34 +189,16 @@ def call_later(self, delay, callback, *args): are scheduled for exactly the same time, it undefined which will be called first. - Callbacks scheduled in the past are passed on to call_soon(), - so these will be called in the order in which they were - registered rather than by time due. This is so you can't - cheat and insert yourself at the front of the ready queue by - using a negative time. - Any positional arguments after the callback will be passed to the callback when it is called. """ - if delay <= 0: - return self.call_soon(callback, *args) - - handle = events.TimerHandle(time.monotonic() + delay, callback, args) - heapq.heappush(self._scheduled, handle) - return handle - - def call_repeatedly(self, interval, callback, *args): - """Call a callback every 'interval' seconds.""" - assert interval > 0, 'Interval must be > 0: {!r}'.format(interval) - # TODO: What if callback is already a Handle? - def wrapper(): - callback(*args) # If this fails, the chain is broken. - handle._when = time.monotonic() + interval - heapq.heappush(self._scheduled, handle) + return self.call_at(self.time() + delay, callback, *args) - handle = events.TimerHandle(time.monotonic() + interval, wrapper, ()) - heapq.heappush(self._scheduled, handle) - return handle + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. @@ -473,6 +459,7 @@ def connect_write_pipe(self, protocol_factory, pipe): def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' if handle.cancelled: return if isinstance(handle, events.TimerHandle): @@ -480,6 +467,11 @@ def _add_callback(self, handle): else: self._ready.append(handle) + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): @@ -497,11 +489,6 @@ def _run_once(self, timeout=None): schedules the resulting callbacks, and finally schedules 'call_later' callbacks. """ - # TODO: Break each of these into smaller pieces. - # TODO: Refactor to separate the callbacks from the readers/writers. - # TODO: An alternative API would be to do the *minimal* amount - # of work, e.g. one callback or one I/O poll. - # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0].cancelled: heapq.heappop(self._scheduled) @@ -511,15 +498,16 @@ def _run_once(self, timeout=None): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0].when - deadline = max(0, when - time.monotonic()) + deadline = max(0, when - self.time()) if timeout is None: timeout = deadline else: timeout = min(timeout, deadline) - t0 = time.monotonic() + # TODO: Instrumentation only in debug mode? + t0 = self.time() event_list = self._selector.select(timeout) - t1 = time.monotonic() + t1 = self.time() argstr = '' if timeout is None else '{:.3f}'.format(timeout) if t1-t0 >= 1: level = logging.INFO @@ -529,7 +517,7 @@ def _run_once(self, timeout=None): self._process_events(event_list) # Handle 'later' callbacks that are ready. - now = time.monotonic() + now = self.time() while self._scheduled: handle = self._scheduled[0] if handle.when > now: diff --git a/tulip/events.py b/tulip/events.py index c8f2401c..25ec15a1 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -55,9 +55,8 @@ def run(self): def make_handle(callback, args): - if isinstance(callback, Handle): - assert not args - return callback + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' return Handle(callback, args) @@ -83,6 +82,9 @@ def __repr__(self): def when(self): return self._when + def __hash__(self): + return hash(self._when) + def __lt__(self, other): return self._when < other._when @@ -159,7 +161,10 @@ def call_soon(self, callback, *args): def call_later(self, delay, callback, *args): raise NotImplementedError - def call_repeatedly(self, interval, callback, *args): + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): raise NotImplementedError # Methods for interacting with threads. @@ -260,7 +265,7 @@ def connect_write_pipe(self, protocol_factory, pipe): # raise NotImplementedError # Ready-based callback registration methods. - # The add_*() methods return a Handle. + # The add_*() methods return None. # The remove_*() methods return True if something was removed, # False if there was nothing to delete. diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 488dad18..5c182fa4 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -126,7 +126,7 @@ def _accept_connection(self, protocol_factory, sock, ssl=None): # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): - """Add a reader callback. Return a Handle instance.""" + """Add a reader callback.""" handle = events.make_handle(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) @@ -139,8 +139,6 @@ def add_reader(self, fd, callback, *args): if reader is not None: reader.cancel() - return handle - def remove_reader(self, fd): """Remove a reader callback.""" try: @@ -161,7 +159,7 @@ def remove_reader(self, fd): return False def add_writer(self, fd, callback, *args): - """Add a writer callback. Return a Handle instance.""" + """Add a writer callback..""" handle = events.make_handle(callback, args) try: mask, (reader, writer) = self._selector.get_info(fd) @@ -174,8 +172,6 @@ def add_writer(self, fd, callback, *args): if writer is not None: writer.cancel() - return handle - def remove_writer(self, fd): """Remove a writer callback.""" try: diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 87514ef1..35926406 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -72,8 +72,6 @@ def add_signal_handler(self, sig, callback, *args): else: raise - return handle - def _handle_signal(self, sig, arg): """Internal helper that is the actual signal handler.""" handle = self._signal_handlers.get(sig) @@ -82,7 +80,7 @@ def _handle_signal(self, sig, arg): if handle.cancelled: self.remove_signal_handler(sig) # Remove it properly. else: - self.call_soon_threadsafe(handle) + self._add_callback_signalsafe(handle) def remove_signal_handler(self, sig): """Remove a handler for a signal. UNIX only. From 7ae111d53a41eea470d1c47e496c25ba0502719b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 2 May 2013 16:32:34 -0700 Subject: [PATCH 0457/1502] queues tests --- tests/queues_test.py | 68 +++++++++++++++++++++++++++++++------------- tulip/queues.py | 15 +++++----- 2 files changed, 56 insertions(+), 27 deletions(-) diff --git a/tests/queues_test.py b/tests/queues_test.py index 714465d9..c47d1f74 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -4,6 +4,7 @@ import queue from tulip import events +from tulip import futures from tulip import locks from tulip import queues from tulip import tasks @@ -12,11 +13,11 @@ class _QueueTestBase(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() class QueueBasicTests(_QueueTestBase): @@ -41,7 +42,7 @@ def add_getter(): yield from tasks.sleep(0.1) self.assertTrue('_getters[1]' in fn(q)) - self.event_loop.run_until_complete(add_getter()) + self.loop.run_until_complete(add_getter()) @tasks.coroutine def add_putter(): @@ -53,12 +54,20 @@ def add_putter(): yield from tasks.sleep(0.1) self.assertTrue('_putters[1]' in fn(q)) - self.event_loop.run_until_complete(add_putter()) + self.loop.run_until_complete(add_putter()) q = queues.Queue() q.put_nowait(1) self.assertTrue('_queue=[1]' in fn(q)) + def test_ctor_loop(self): + loop = unittest.mock.Mock() + q = queues.Queue(loop=loop) + self.assertIs(q._loop, loop) + + q = queues.Queue() + self.assertIs(q._loop, events.get_event_loop()) + def test_repr(self): self._test_repr_or_str(repr, True) @@ -119,7 +128,7 @@ def test(): self.assertTrue(t.done()) self.assertTrue(t.result()) - self.event_loop.run_until_complete(test()) + self.loop.run_until_complete(test()) class QueueGetTests(_QueueTestBase): @@ -132,8 +141,20 @@ def test_blocking_get(self): def queue_get(): return (yield from q.get()) - res = self.event_loop.run_until_complete(queue_get()) + res = self.loop.run_until_complete(queue_get()) + self.assertEqual(1, res) + + def test_get_with_putters(self): + q = queues.Queue(1) + q.put_nowait(1) + + waiter = futures.Future() + q._putters.append((2, waiter)) + + res = self.loop.run_until_complete(q.get()) self.assertEqual(1, res) + self.assertTrue(waiter.done()) + self.assertIsNone(waiter.result()) def test_blocking_get_wait(self): q = queues.Queue() @@ -150,7 +171,7 @@ def queue_get(): @tasks.coroutine def queue_put(): - self.event_loop.call_later(0.01, q.put_nowait, 1) + self.loop.call_later(0.01, q.put_nowait, 1) queue_get_task = tasks.Task(queue_get()) yield from started.wait() self.assertFalse(finished) @@ -158,7 +179,7 @@ def queue_put(): self.assertTrue(finished) return res - res = self.event_loop.run_until_complete(queue_put()) + res = self.loop.run_until_complete(queue_put()) self.assertEqual(1, res) def test_nonblocking_get(self): @@ -188,7 +209,7 @@ def queue_get(): self.assertTrue(t.done()) self.assertIsNone(t.result()) - self.event_loop.run_until_complete(queue_get()) + self.loop.run_until_complete(queue_get()) def test_get_timeout_cancelled(self): q = queues.Queue() @@ -204,7 +225,7 @@ def test(): q.put_nowait(1) return (yield from get_task) - self.assertEqual(1, self.event_loop.run_until_complete(test())) + self.assertEqual(1, self.loop.run_until_complete(test())) class QueuePutTests(_QueueTestBase): @@ -217,7 +238,7 @@ def queue_put(): # No maxsize, won't block. yield from q.put(1) - self.event_loop.run_until_complete(queue_put()) + self.loop.run_until_complete(queue_put()) def test_blocking_put_wait(self): q = queues.Queue(maxsize=1) @@ -234,14 +255,14 @@ def queue_put(): @tasks.coroutine def queue_get(): - self.event_loop.call_later(0.01, q.get_nowait) + self.loop.call_later(0.01, q.get_nowait) queue_put_task = tasks.Task(queue_put()) yield from started.wait() self.assertFalse(finished) yield from queue_put_task self.assertTrue(finished) - self.event_loop.run_until_complete(queue_get()) + self.loop.run_until_complete(queue_get()) def test_nonblocking_put(self): q = queues.Queue() @@ -274,7 +295,7 @@ def queue_put(): q.put_nowait(3) self.assertEqual(3, q.get_nowait()) - self.event_loop.run_until_complete(queue_put()) + self.loop.run_until_complete(queue_put()) def test_put_timeout_cancelled(self): q = queues.Queue() @@ -289,7 +310,7 @@ def test(): return (yield from q.get()) t = tasks.Task(queue_put()) - self.assertEqual(1, self.event_loop.run_until_complete(test())) + self.assertEqual(1, self.loop.run_until_complete(test())) self.assertTrue(t.done()) self.assertTrue(t.result()) @@ -320,7 +341,7 @@ class JoinableQueueTests(_QueueTestBase): def test_task_done_underflow(self): q = queues.JoinableQueue() - self.assertRaises(q.task_done) + self.assertRaises(ValueError, q.task_done) def test_task_done(self): q = queues.JoinableQueue() @@ -348,7 +369,7 @@ def test(): yield from q.join() - self.event_loop.run_until_complete(test()) + self.loop.run_until_complete(test()) self.assertEqual(sum(range(100)), accumulator) def test_join_empty_queue(self): @@ -362,7 +383,7 @@ def join(): yield from q.join() yield from q.join() - self.event_loop.run_until_complete(join()) + self.loop.run_until_complete(join()) def test_join_timeout(self): q = queues.JoinableQueue() @@ -373,7 +394,14 @@ def join(): yield from q.join(0.1) # Join completes in ~ 0.1 seconds, although no one calls task_done(). - self.event_loop.run_until_complete(join()) + self.loop.run_until_complete(join()) + + def test_format(self): + q = queues.JoinableQueue() + self.assertEqual(q._format(), 'maxsize=0') + + q._unfinished_tasks = 2 + self.assertEqual(q._format(), 'maxsize=0 tasks=2') if __name__ == '__main__': diff --git a/tulip/queues.py b/tulip/queues.py index a87a8557..9bcdea61 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -25,8 +25,11 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - def __init__(self, maxsize=0): - self._event_loop = events.get_event_loop() + def __init__(self, maxsize=0, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop self._maxsize = maxsize # Futures. @@ -118,8 +121,7 @@ def put(self, item, timeout=None): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - waiter = futures.Future( - event_loop=self._event_loop, timeout=timeout) + waiter = futures.Future(event_loop=self._loop, timeout=timeout) self._putters.append((item, waiter)) try: @@ -172,15 +174,14 @@ def get(self, timeout=None): # run, we need to defer the put for a tick to ensure that # getters and putters alternate perfectly. See # ChannelTest.test_wait. - self._event_loop.call_soon(putter.set_result, None) + self._loop.call_soon(putter.set_result, None) return self._get() elif self.qsize(): return self._get() else: - waiter = futures.Future( - event_loop=self._event_loop, timeout=timeout) + waiter = futures.Future(event_loop=self._loop, timeout=timeout) self._getters.append(waiter) try: From e4835d96ec69667beb24632a118615d5663bb9c7 Mon Sep 17 00:00:00 2001 From: Giampaolo Rodola' Date: Fri, 3 May 2013 02:50:09 +0200 Subject: [PATCH 0458/1502] provide a 'local_addr' parameter for create_connection() (review at https://codereview.appspot.com/8997043/) --- tests/base_events_test.py | 2 +- tests/events_test.py | 31 ++++++++++++++++++++++++++----- tulip/base_events.py | 34 ++++++++++++++++++++++++++++++---- tulip/events.py | 3 ++- 4 files changed, 59 insertions(+), 11 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 6dfaeecf..7c8771db 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -288,6 +288,6 @@ def _socket(*args, **kw): task = tasks.Task( self.event_loop.create_connection(MyProto, 'example.com', 80)) - task._step() + yield from tasks.wait(task) exc = task.exception() self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py index ace36d3e..80240500 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -506,15 +506,16 @@ def test_create_connection_no_host_port_sock(self): self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): - getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() - getaddrinfo.return_value = [] - + @tasks.task + def getaddrinfo(*args, **kw): + yield from [] + self.event_loop.getaddrinfo = getaddrinfo coro = self.event_loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_connect_err(self): - @tasks.coroutine + @tasks.task def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] @@ -527,7 +528,7 @@ def getaddrinfo(*args, **kw): socket.error, self.event_loop.run_until_complete, coro) def test_create_connection_mutiple_errors(self): - @tasks.coroutine + @tasks.task def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80)), @@ -541,6 +542,26 @@ def getaddrinfo(*args, **kw): self.assertRaises( socket.error, self.event_loop.run_until_complete, coro) + def test_create_connection_local_addr(self): + with test_utils.run_test_server(self.event_loop) as httpd: + port = find_unused_port() + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.event_loop.run_until_complete(f) + expected = pr.transport.get_extra_info('socket').getsockname()[1] + self.assertEqual(port, expected) + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server(self.event_loop) as httpd: + f = self.event_loop.create_connection( + lambda: MyProto(create_future=True), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(socket.error) as cm: + self.event_loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + def test_start_serving(self): proto = None diff --git a/tulip/base_events.py b/tulip/base_events.py index 5d725993..bf00dc23 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -249,25 +249,51 @@ def getnameinfo(self, sockaddr, flags=0): @tasks.coroutine def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None): + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): """XXX""" if host is not None or port is not None: if sock is not None: raise ValueError( 'host/port and sock can not be specified at the same time') - infos = yield from self.getaddrinfo( - host, port, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f1 = self.getaddrinfo(host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo(*local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs) + infos = f1.result() if not infos: raise socket.error('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise socket.error('getaddrinfo() returned empty list') exceptions = [] for family, type, proto, cname, address in infos: try: sock = socket.socket(family=family, type=type, proto=proto) sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except socket.error as exc: + exc = socket.error(exc.errno, "error while " \ + "attempting to bind on address " \ + "%r: %s" % (laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + continue yield from self.sock_connect(sock, address) except socket.error as exc: if sock is not None: diff --git a/tulip/events.py b/tulip/events.py index 25ec15a1..46854b5d 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -190,7 +190,8 @@ def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None): + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr): raise NotImplementedError def start_serving(self, protocol_factory, host=None, port=None, *, From 7031493ad50f729853438dd369932126ec79b1d3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 2 May 2013 20:01:47 -0700 Subject: [PATCH 0459/1502] Fix failing test due to missing default for local_addr. --- tests/events_test.py | 2 +- tulip/events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 80240500..99d97ed7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1209,7 +1209,7 @@ def callback(*args): class AbstractEventLoopTests(unittest.TestCase): - def test_not_imlemented(self): + def test_not_implemented(self): f = unittest.mock.Mock() ev_loop = events.AbstractEventLoop() self.assertRaises( diff --git a/tulip/events.py b/tulip/events.py index 46854b5d..72a660f0 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -191,7 +191,7 @@ def getnameinfo(self, sockaddr, flags=0): def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr): + local_addr=None): raise NotImplementedError def start_serving(self, protocol_factory, host=None, port=None, *, From 43eb83993ebed20a073dd021035e231f72ba8364 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 3 May 2013 12:16:03 -0700 Subject: [PATCH 0460/1502] add 'loop' parameter to locks ctor; remove run_once() from locks tests --- tests/locks_test.py | 221 ++++++++++++++++++++++++-------------------- tulip/locks.py | 33 ++++--- tulip/queues.py | 2 +- tulip/test_utils.py | 7 ++ 4 files changed, 148 insertions(+), 115 deletions(-) diff --git a/tests/locks_test.py b/tests/locks_test.py index a2e03381..f3bd01e0 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -8,16 +8,25 @@ from tulip import futures from tulip import locks from tulip import tasks +from tulip.test_utils import run_once class LockTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock() + self.assertIs(lock._loop, events.get_event_loop()) def test_repr(self): lock = locks.Lock() @@ -27,7 +36,7 @@ def test_repr(self): def acquire_lock(): yield from lock - self.event_loop.run_until_complete(acquire_lock()) + self.loop.run_until_complete(acquire_lock()) self.assertTrue(repr(lock).endswith('[locked]>')) def test_lock(self): @@ -37,7 +46,7 @@ def test_lock(self): def acquire_lock(): return (yield from lock) - res = self.event_loop.run_until_complete(acquire_lock()) + res = self.loop.run_until_complete(acquire_lock()) self.assertTrue(res) self.assertTrue(lock.locked()) @@ -49,8 +58,7 @@ def test_acquire(self): lock = locks.Lock() result = [] - self.assertTrue( - self.event_loop.run_until_complete(lock.acquire())) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) @tasks.coroutine def c1(result): @@ -73,24 +81,24 @@ def c3(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) lock.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) t3 = tasks.Task(c3(result)) lock.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2], result) lock.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -102,47 +110,44 @@ def c3(result): def test_acquire_timeout(self): lock = locks.Lock() - self.assertTrue( - self.event_loop.run_until_complete(lock.acquire())) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) t0 = time.monotonic() - acquired = self.event_loop.run_until_complete( - lock.acquire(timeout=0.1)) + acquired = self.loop.run_until_complete(lock.acquire(timeout=0.1)) self.assertFalse(acquired) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) lock = locks.Lock() - self.event_loop.run_until_complete(lock.acquire()) + self.loop.run_until_complete(lock.acquire()) - self.event_loop.call_later(0.01, lock.release) - acquired = self.event_loop.run_until_complete(lock.acquire(10.1)) + self.loop.call_later(0.01, lock.release) + acquired = self.loop.run_until_complete(lock.acquire(10.1)) self.assertTrue(acquired) def test_acquire_timeout_mixed(self): lock = locks.Lock() - self.event_loop.run_until_complete(lock.acquire()) + self.loop.run_until_complete(lock.acquire()) tasks.Task(lock.acquire()) tasks.Task(lock.acquire()) acquire_task = tasks.Task(lock.acquire(0.01)) tasks.Task(lock.acquire()) - acquired = self.event_loop.run_until_complete(acquire_task) + acquired = self.loop.run_until_complete(acquire_task) self.assertFalse(acquired) self.assertEqual(3, len(lock._waiters)) def test_acquire_cancel(self): lock = locks.Lock() - self.assertTrue( - self.event_loop.run_until_complete(lock.acquire())) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) task = tasks.Task(lock.acquire()) - self.event_loop.call_soon(task.cancel) + self.loop.call_soon(task.cancel) self.assertRaises( futures.CancelledError, - self.event_loop.run_until_complete, task) + self.loop.run_until_complete, task) self.assertFalse(lock._waiters) def test_release_not_acquired(self): @@ -152,7 +157,7 @@ def test_release_not_acquired(self): def test_release_no_waiters(self): lock = locks.Lock() - self.event_loop.run_until_complete(lock.acquire()) + self.loop.run_until_complete(lock.acquire()) self.assertTrue(lock.locked()) lock.release() @@ -165,7 +170,7 @@ def test_context_manager(self): def acquire_lock(): return (yield from lock) - with self.event_loop.run_until_complete(acquire_lock()): + with self.loop.run_until_complete(acquire_lock()): self.assertTrue(lock.locked()) self.assertFalse(lock.locked()) @@ -185,11 +190,19 @@ def test_context_manager_no_yield(self): class EventWaiterTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.EventWaiter(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.EventWaiter() + self.assertIs(ev._loop, events.get_event_loop()) def test_repr(self): ev = locks.EventWaiter() @@ -222,13 +235,13 @@ def c3(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) t3 = tasks.Task(c3(result)) ev.set() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([3, 1, 2], result) self.assertTrue(t1.done()) @@ -242,21 +255,21 @@ def test_wait_on_set(self): ev = locks.EventWaiter() ev.set() - res = self.event_loop.run_until_complete(ev.wait()) + res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) def test_wait_timeout(self): ev = locks.EventWaiter() t0 = time.monotonic() - res = self.event_loop.run_until_complete(ev.wait(0.1)) + res = self.loop.run_until_complete(ev.wait(0.1)) self.assertFalse(res) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) ev = locks.EventWaiter() - self.event_loop.call_later(0.01, ev.set) - acquired = self.event_loop.run_until_complete(ev.wait(10.1)) + self.loop.call_later(0.01, ev.set) + acquired = self.loop.run_until_complete(ev.wait(10.1)) self.assertTrue(acquired) def test_wait_timeout_mixed(self): @@ -267,7 +280,7 @@ def test_wait_timeout_mixed(self): tasks.Task(ev.wait()) t0 = time.monotonic() - acquired = self.event_loop.run_until_complete(acquire_task) + acquired = self.loop.run_until_complete(acquire_task) self.assertFalse(acquired) total_time = (time.monotonic() - t0) @@ -279,10 +292,10 @@ def test_wait_cancel(self): ev = locks.EventWaiter() wait = tasks.Task(ev.wait()) - self.event_loop.call_soon(wait.cancel) + self.loop.call_soon(wait.cancel) self.assertRaises( futures.CancelledError, - self.event_loop.run_until_complete, wait) + self.loop.run_until_complete, wait) self.assertFalse(ev._waiters) def test_clear(self): @@ -306,7 +319,7 @@ def c1(result): return True t = tasks.Task(c1(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) ev.set() @@ -317,7 +330,7 @@ def c1(result): ev.set() self.assertEqual(1, len(ev._waiters)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) self.assertEqual(0, len(ev._waiters)) @@ -328,11 +341,11 @@ def c1(result): class ConditionTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() def test_wait(self): cond = locks.Condition() @@ -363,34 +376,33 @@ def c3(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) self.assertFalse(cond.locked()) - self.assertTrue( - self.event_loop.run_until_complete(cond.acquire())) + self.assertTrue(self.loop.run_until_complete(cond.acquire())) cond.notify() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) self.assertTrue(cond.locked()) cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.notify(2) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2], result) self.assertTrue(cond.locked()) cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(cond.locked()) @@ -403,10 +415,10 @@ def c3(result): def test_wait_timeout(self): cond = locks.Condition() - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) t0 = time.monotonic() - wait = self.event_loop.run_until_complete(cond.wait(0.1)) + wait = self.loop.run_until_complete(cond.wait(0.1)) self.assertFalse(wait) self.assertTrue(cond.locked()) @@ -415,13 +427,13 @@ def test_wait_timeout(self): def test_wait_cancel(self): cond = locks.Condition() - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) wait = tasks.Task(cond.wait()) - self.event_loop.call_soon(wait.cancel) + self.loop.call_soon(wait.cancel) self.assertRaises( futures.CancelledError, - self.event_loop.run_until_complete, wait) + self.loop.run_until_complete, wait) self.assertFalse(cond._condition_waiters) self.assertTrue(cond.locked()) @@ -429,7 +441,7 @@ def test_wait_unacquired(self): cond = locks.Condition() self.assertRaises( RuntimeError, - self.event_loop.run_until_complete, cond.wait()) + self.loop.run_until_complete, cond.wait()) def test_wait_for(self): cond = locks.Condition() @@ -450,20 +462,20 @@ def c1(result): t = tasks.Task(c1(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) presult = True - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) self.assertTrue(t.done()) @@ -490,16 +502,16 @@ def c1(result): t0 = time.monotonic() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) - self.event_loop.run_until_complete(wait_for) + self.loop.run_until_complete(wait_for) self.assertEqual([2], result) self.assertEqual(3, predicate.call_count) @@ -510,13 +522,12 @@ def test_wait_for_unacquired(self): cond = locks.Condition() # predicate can return true immediately - res = self.event_loop.run_until_complete( - cond.wait_for(lambda: [1, 2, 3])) + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) self.assertEqual([1, 2, 3], res) self.assertRaises( RuntimeError, - self.event_loop.run_until_complete, + self.loop.run_until_complete, cond.wait_for(lambda: False)) def test_notify(self): @@ -551,20 +562,20 @@ def c3(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.notify(2048) cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -598,13 +609,13 @@ def c2(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([], result) - self.event_loop.run_until_complete(cond.acquire()) + self.loop.run_until_complete(cond.acquire()) cond.notify_all() cond.release() - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1, 2], result) self.assertTrue(t1.done()) @@ -624,17 +635,25 @@ def test_notify_all_unacquired(self): class SemaphoreTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore() + self.assertIs(sem._loop, events.get_event_loop()) def test_repr(self): sem = locks.Semaphore() self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) self.assertTrue(repr(sem).endswith('[locked]>')) def test_semaphore(self): @@ -645,7 +664,7 @@ def test_semaphore(self): def acquire_lock(): return (yield from sem) - res = self.event_loop.run_until_complete(acquire_lock()) + res = self.loop.run_until_complete(acquire_lock()) self.assertTrue(res) self.assertTrue(sem.locked()) @@ -662,10 +681,8 @@ def test_acquire(self): sem = locks.Semaphore(3) result = [] - self.assertTrue( - self.event_loop.run_until_complete(sem.acquire())) - self.assertTrue( - self.event_loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) self.assertFalse(sem.locked()) @tasks.coroutine @@ -696,7 +713,7 @@ def c4(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual([1], result) self.assertTrue(sem.locked()) self.assertEqual(2, len(sem._waiters)) @@ -708,7 +725,7 @@ def c4(result): sem.release() self.assertEqual(2, sem._value) - self.event_loop.run_once() + run_once(self.loop) self.assertEqual(0, sem._value) self.assertEqual([1, 2, 3], result) self.assertTrue(sem.locked()) @@ -725,32 +742,32 @@ def c4(result): def test_acquire_timeout(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) t0 = time.monotonic() - acquired = self.event_loop.run_until_complete(sem.acquire(0.1)) + acquired = self.loop.run_until_complete(sem.acquire(0.1)) self.assertFalse(acquired) total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) sem = locks.Semaphore() - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) - self.event_loop.call_later(0.01, sem.release) - acquired = self.event_loop.run_until_complete(sem.acquire(10.1)) + self.loop.call_later(0.01, sem.release) + acquired = self.loop.run_until_complete(sem.acquire(10.1)) self.assertTrue(acquired) def test_acquire_timeout_mixed(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) tasks.Task(sem.acquire()) tasks.Task(sem.acquire()) acquire_task = tasks.Task(sem.acquire(0.1)) tasks.Task(sem.acquire()) t0 = time.monotonic() - acquired = self.event_loop.run_until_complete(acquire_task) + acquired = self.loop.run_until_complete(acquire_task) self.assertFalse(acquired) total_time = (time.monotonic() - t0) @@ -760,13 +777,13 @@ def test_acquire_timeout_mixed(self): def test_acquire_cancel(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) acquire = tasks.Task(sem.acquire()) - self.event_loop.call_soon(acquire.cancel) + self.loop.call_soon(acquire.cancel) self.assertRaises( futures.CancelledError, - self.event_loop.run_until_complete, acquire) + self.loop.run_until_complete, acquire) self.assertFalse(sem._waiters) def test_release_not_acquired(self): @@ -776,7 +793,7 @@ def test_release_not_acquired(self): def test_release_no_waiters(self): sem = locks.Semaphore() - self.event_loop.run_until_complete(sem.acquire()) + self.loop.run_until_complete(sem.acquire()) self.assertTrue(sem.locked()) sem.release() @@ -789,11 +806,11 @@ def test_context_manager(self): def acquire_lock(): return (yield from sem) - with self.event_loop.run_until_complete(acquire_lock()): + with self.loop.run_until_complete(acquire_lock()): self.assertFalse(sem.locked()) self.assertEqual(1, sem._value) - with self.event_loop.run_until_complete(acquire_lock()): + with self.loop.run_until_complete(acquire_lock()): self.assertTrue(sem.locked()) self.assertEqual(2, sem._value) diff --git a/tulip/locks.py b/tulip/locks.py index ff841442..d425d064 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -62,10 +62,13 @@ class Lock: """ - def __init__(self): + def __init__(self, *, loop=None): self._waiters = collections.deque() self._locked = False - self._event_loop = events.get_event_loop() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() @@ -94,7 +97,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + fut = futures.Future(event_loop=self._loop, timeout=timeout) self._waiters.append(fut) try: @@ -150,10 +153,13 @@ class EventWaiter: false. """ - def __init__(self): + def __init__(self, *, loop=None): self._waiters = collections.deque() self._value = False - self._event_loop = events.get_event_loop() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() @@ -202,7 +208,7 @@ def wait(self, timeout=None): if self._value: return True - fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + fut = futures.Future(event_loop=self._loop, timeout=timeout) self._waiters.append(fut) try: @@ -225,8 +231,8 @@ class Condition(Lock): coroutine. """ - def __init__(self): - super().__init__() + def __init__(self, *, loop=None): + super().__init__(loop=loop) self._condition_waiters = collections.deque() @@ -253,7 +259,7 @@ def wait(self, timeout=None): self.release() - fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + fut = futures.Future(event_loop=self._loop, timeout=timeout) self._condition_waiters.append(fut) try: @@ -346,7 +352,7 @@ class Semaphore: acquire() calls ValueError is raised. """ - def __init__(self, value=1, bound=False): + def __init__(self, value=1, bound=False, *, loop=None): if value < 0: raise ValueError("Semaphore initial value must be > 0") self._value = value @@ -354,7 +360,10 @@ def __init__(self, value=1, bound=False): self._bound_value = value self._waiters = collections.deque() self._locked = False - self._event_loop = events.get_event_loop() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() def __repr__(self): res = super().__repr__() @@ -386,7 +395,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._event_loop, timeout=timeout) + fut = futures.Future(event_loop=self._loop, timeout=timeout) self._waiters.append(fut) try: diff --git a/tulip/queues.py b/tulip/queues.py index 9bcdea61..575049ee 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -244,7 +244,7 @@ class JoinableQueue(Queue): def __init__(self, maxsize=0): self._unfinished_tasks = 0 - self._finished = locks.EventWaiter() + self._finished = locks.EventWaiter(loop=self._loop) self._finished.set() super().__init__(maxsize=maxsize) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 3cd5ba95..d1f05614 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -31,6 +31,13 @@ from socket import socketpair # pragma: no cover +def run_once(loop): + @tulip.task + def once(): + pass + loop.run_until_complete(once()) + + @contextlib.contextmanager def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False, router=None): From 2553129ec43dd3b5801fcd57f906ea5c88b68951 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 3 May 2013 12:30:21 -0700 Subject: [PATCH 0461/1502] added loop parameter to JoinableQueue --- tulip/queues.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tulip/queues.py b/tulip/queues.py index 575049ee..ba0b626d 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -25,7 +25,7 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - def __init__(self, maxsize=0, loop=None): + def __init__(self, maxsize=0, *, loop=None): if loop is None: self._loop = events.get_event_loop() else: @@ -242,11 +242,11 @@ def _get(self): class JoinableQueue(Queue): """A subclass of Queue with task_done() and join() methods.""" - def __init__(self, maxsize=0): + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) self._unfinished_tasks = 0 self._finished = locks.EventWaiter(loop=self._loop) self._finished.set() - super().__init__(maxsize=maxsize) def _format(self): result = Queue._format(self) From 816884d3dfdcccd1f997ba38293fd753765ce7fc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 3 May 2013 15:15:47 -0700 Subject: [PATCH 0462/1502] Better solution for logging tracebacks from abandoned futures. --- tests/futures_test.py | 60 ++++++++-------------- tulip/futures.py | 115 +++++++++++++++++++++++++----------------- tulip/tasks.py | 3 -- 3 files changed, 88 insertions(+), 90 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index adc283d7..0450e51a 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -163,68 +163,48 @@ def test(): fut.cancel() @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_norm_level(self, log): - self.loop.set_log_level(logging.CRITICAL) - + def test_tb_logger_abandoned(self, m_log): fut = futures.Future() del fut - self.assertFalse(log.error.called) + self.assertFalse(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_normal(self, log): - self.loop.set_log_level(futures.STACK_DEBUG) - + def test_tb_logger_result_unretrieved(self, m_log): fut = futures.Future() - fut.set_result(True) - fut.result() + fut.set_result(42) del fut - self.assertFalse(log.error.called) + self.assertFalse(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_not_done(self, log): - self.loop.set_log_level(futures.STACK_DEBUG) - + def test_tb_logger_result_retrieved(self, m_log): fut = futures.Future() - r_fut = repr(fut) + fut.set_result(42) + fut.result() del fut - log.error.mock_calls[-1].assert_called_with( - 'Future abandoned before completion: %r', r_fut) + self.assertFalse(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_done(self, log): - self.loop.set_log_level(futures.STACK_DEBUG) - + def test_tb_logger_exception_unretrieved(self, m_log): fut = futures.Future() - next(iter(fut)) - fut.set_result(1) - r_fut = repr(fut) + fut.set_exception(RuntimeError('boom')) del fut - log.error.mock_calls[-1].assert_called_with( - 'Future result has not been requested: %r', r_fut) + self.assertTrue(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_done_skip(self, log): - self.loop.set_log_level(futures.STACK_DEBUG) - + def test_tb_logger_exception_retrieved(self, m_log): fut = futures.Future() - fut._debug_warn_result_requested = False - next(iter(fut)) - fut.set_result(1) + fut.set_exception(RuntimeError('boom')) + fut.exception() del fut - self.assertFalse(log.error.called) + self.assertFalse(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') - def test_del_exc(self, log): - self.loop.set_log_level(futures.STACK_DEBUG) - - exc = ValueError() + def test_tb_logger_exception_result_retrieved(self, m_log): fut = futures.Future() - fut.set_exception(exc) - r_fut = repr(fut) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) del fut - log.exception.mock_calls[-1].assert_called_with( - 'Future raised an exception and nobody caught it: %r', r_fut, - exc_info=(ValueError, exc, None)) + self.assertFalse(m_log.error.called) # A fake event loop for tests. All it does is implement a call_soon method diff --git a/tulip/futures.py b/tulip/futures.py index e3baf15e..7f410d31 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -36,6 +36,66 @@ class InvalidTimeoutError(Error): # TODO: Print a nice error message. +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As a compromise, we use + extract_exception() rather than format_exception(). (We may also + have to limit how many entries we extract, but then we'd need a + public API to change the limit; so let's punt on this for now.) + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + def __init__(self, exc): + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.tb = None + + def __del__(self): + if self.tb: + tulip_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + class Future: """This class is *almost* compatible with concurrent.futures.Future. @@ -62,10 +122,7 @@ class Future: _blocking = False # proper use of future (yield vs yield from) - # result of the future has to be requested - _debug_stack = None - _debug_result_requested = False - _debug_warn_result_requested = True + _tb_logger = None def __init__(self, *, event_loop=None, timeout=None): """Initialize the future. @@ -84,12 +141,6 @@ def __init__(self, *, event_loop=None, timeout=None): self._timeout_handle = self._event_loop.call_later( timeout, self.cancel) - if __debug__: - if self._event_loop.get_log_level() <= STACK_DEBUG: - out = io.StringIO() - traceback.print_stack(file=out) - self._debug_stack = out.getvalue() - def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -169,14 +220,15 @@ def result(self, timeout=0): the future is done and has an exception set, this exception is raised. Timeout values other than 0 are not supported. """ - if __debug__: - self._debug_result_requested = True if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: raise InvalidStateError + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None if self._exception is not None: raise self._exception return self._result @@ -189,14 +241,15 @@ def exception(self, timeout=0): CancelledError. If the future isn't done yet, raises InvalidStateError. Timeout values other than 0 are not supported. """ - if __debug__: - self._debug_result_requested = True if timeout != 0: raise InvalidTimeoutError if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: raise InvalidStateError + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None return self._exception def add_done_callback(self, fn): @@ -247,6 +300,7 @@ def set_exception(self, exception): if self._state != _PENDING: raise InvalidStateError self._exception = exception + self._tb_logger = _TracebackLogger(exception) self._state = _FINISHED self._schedule_callbacks() @@ -275,36 +329,3 @@ def __iter__(self): yield self # This tells Task to wait for completion. assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. - - if __debug__: - def __del__(self): - if (not self._debug_result_requested and - self._state != _CANCELLED and - self._event_loop is not None): - - level = self._event_loop.get_log_level() - if level > logging.WARNING: - return - - r_self = repr(self) - - if self._state == _PENDING: - tulip_log.error( - 'Future abandoned before completion: %s', r_self) - if (self._debug_stack and level <= STACK_DEBUG): - tulip_log.error(self._debug_stack) - - else: - exc = self._exception - if exc is not None: - tulip_log.exception( - 'Future raised an exception and ' - 'nobody caught it: %s', r_self, - exc_info=(exc.__class__, exc, exc.__traceback__)) - if (self._debug_stack and level <= STACK_DEBUG): - tulip_log.error(self._debug_stack) - elif self._debug_warn_result_requested: - tulip_log.error( - 'Future result has not been requested: %s', r_self) - if (self._debug_stack and level <= STACK_DEBUG): - tulip_log.error(self._debug_stack) diff --git a/tulip/tasks.py b/tulip/tasks.py index 8dfd73a3..ce539ab8 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -71,9 +71,6 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - # disable "Future result has not been requested" warning message. - _debug_warn_result_requested = False - def __init__(self, coro, event_loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__(event_loop=event_loop, timeout=timeout) From 799ddeec05156d5a72254afd20116fe89856a053 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Sun, 5 May 2013 12:18:27 -0700 Subject: [PATCH 0463/1502] rename event_loop to loop; logging cleanup --- tests/base_events_test.py | 169 ++++++------ tests/events_test.py | 481 ++++++++++++++++----------------- tests/futures_test.py | 8 +- tests/selector_events_test.py | 493 +++++++++++++++++----------------- tests/streams_test.py | 74 +++-- tests/subprocess_test.py | 8 +- tests/tasks_test.py | 127 +++++---- tests/unix_events_test.py | 245 +++++++++-------- tulip/base_events.py | 19 +- tulip/events.py | 20 +- tulip/futures.py | 18 +- tulip/locks.py | 8 +- tulip/queues.py | 4 +- tulip/tasks.py | 18 +- tulip/test_utils.py | 1 - 15 files changed, 827 insertions(+), 866 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 7c8771db..b1790ff2 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -17,90 +17,90 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): - self.event_loop = base_events.BaseEventLoop() - self.event_loop._selector = unittest.mock.Mock() - self.event_loop._selector.registered_count.return_value = 1 + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + self.loop._selector.registered_count.return_value = 1 def test_not_implemented(self): m = unittest.mock.Mock() self.assertRaises( NotImplementedError, - self.event_loop._make_socket_transport, m, m) + self.loop._make_socket_transport, m, m) self.assertRaises( NotImplementedError, - self.event_loop._make_ssl_transport, m, m, m, m) + self.loop._make_ssl_transport, m, m, m, m) self.assertRaises( NotImplementedError, - self.event_loop._make_datagram_transport, m, m) + self.loop._make_datagram_transport, m, m) self.assertRaises( - NotImplementedError, self.event_loop._process_events, []) + NotImplementedError, self.loop._process_events, []) self.assertRaises( - NotImplementedError, self.event_loop._write_to_self) + NotImplementedError, self.loop._write_to_self) self.assertRaises( - NotImplementedError, self.event_loop._read_from_self) + NotImplementedError, self.loop._read_from_self) self.assertRaises( NotImplementedError, - self.event_loop._make_read_pipe_transport, m, m) + self.loop._make_read_pipe_transport, m, m) self.assertRaises( NotImplementedError, - self.event_loop._make_write_pipe_transport, m, m) + self.loop._make_write_pipe_transport, m, m) def test__add_callback_handle(self): h = events.Handle(lambda: False, ()) - self.event_loop._add_callback(h) - self.assertFalse(self.event_loop._scheduled) - self.assertIn(h, self.event_loop._ready) + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) def test__add_callback_cancelled_handle(self): h = events.Handle(lambda: False, ()) h.cancel() - self.event_loop._add_callback(h) - self.assertFalse(self.event_loop._scheduled) - self.assertFalse(self.event_loop._ready) + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) def test_wrap_future(self): - f = futures.Future(event_loop=self.event_loop) - self.assertIs(self.event_loop.wrap_future(f), f) + f = futures.Future(loop=self.loop) + self.assertIs(self.loop.wrap_future(f), f) f.cancel() def test_wrap_future_concurrent(self): f = concurrent.futures.Future() - fut = self.event_loop.wrap_future(f) + fut = self.loop.wrap_future(f) self.assertIsInstance(fut, futures.Future) fut.cancel() def test_set_default_executor(self): executor = unittest.mock.Mock() - self.event_loop.set_default_executor(executor) - self.assertIs(executor, self.event_loop._default_executor) + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) def test_getnameinfo(self): sockaddr = unittest.mock.Mock() - self.event_loop.run_in_executor = unittest.mock.Mock() - self.event_loop.getnameinfo(sockaddr) + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) self.assertEqual( (None, socket.getnameinfo, sockaddr, 0), - self.event_loop.run_in_executor.call_args[0]) + self.loop.run_in_executor.call_args[0]) def test_call_soon(self): def cb(): pass - h = self.event_loop.call_soon(cb) + h = self.loop.call_soon(cb) self.assertEqual(h._callback, cb) self.assertIsInstance(h, events.Handle) - self.assertIn(h, self.event_loop._ready) + self.assertIn(h, self.loop._ready) def test_call_later(self): def cb(): pass - h = self.event_loop.call_later(10.0, cb) + h = self.loop.call_later(10.0, cb) self.assertIsInstance(h, events.TimerHandle) - self.assertIn(h, self.event_loop._scheduled) - self.assertNotIn(h, self.event_loop._ready) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) def test_call_later_negative_delays(self): calls = [] @@ -108,22 +108,22 @@ def test_call_later_negative_delays(self): def cb(arg): calls.append(arg) - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop.call_later(-1, cb, 'a') - self.event_loop.call_later(-2, cb, 'b') - self.event_loop.run_once() + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + self.loop.run_once() self.assertEqual(calls, ['b', 'a']) def test_time_and_call_at(self): def cb(): - self.event_loop.stop() - - self.event_loop._process_events = unittest.mock.Mock() - when = self.event_loop.time() + 0.1 - self.event_loop.call_at(when, cb) - t0 = self.event_loop.time() - self.event_loop.run_forever() - t1 = self.event_loop.time() + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) def test_run_once_in_executor_handle(self): @@ -131,10 +131,10 @@ def cb(): pass self.assertRaises( - AssertionError, self.event_loop.run_in_executor, + AssertionError, self.loop.run_in_executor, None, events.Handle(cb, ()), ('',)) self.assertRaises( - AssertionError, self.event_loop.run_in_executor, + AssertionError, self.loop.run_in_executor, None, events.TimerHandle(10, cb, ())) def test_run_once_in_executor_cancelled(self): @@ -143,7 +143,7 @@ def cb(): h = events.Handle(cb, ()) h.cancel() - f = self.event_loop.run_in_executor(None, h) + f = self.loop.run_in_executor(None, h) self.assertIsInstance(f, futures.Future) self.assertTrue(f.done()) self.assertIsNone(f.result()) @@ -156,24 +156,24 @@ def cb(): executor = unittest.mock.Mock() executor.submit.return_value = f - self.event_loop.set_default_executor(executor) + self.loop.set_default_executor(executor) - res = self.event_loop.run_in_executor(None, h) + res = self.loop.run_in_executor(None, h) self.assertIs(f, res) executor = unittest.mock.Mock() executor.submit.return_value = f - res = self.event_loop.run_in_executor(executor, h) + res = self.loop.run_in_executor(executor, h) self.assertIs(f, res) self.assertTrue(executor.submit.called) f.cancel() # Don't complain about abandoned Future. def test_run_once(self): - self.event_loop._run_once = unittest.mock.Mock() - self.event_loop._run_once.side_effect = base_events._StopError - self.event_loop.run_once() - self.assertTrue(self.event_loop._run_once.called) + self.loop._run_once = unittest.mock.Mock() + self.loop._run_once.side_effect = base_events._StopError + self.loop.run_once() + self.assertTrue(self.loop._run_once.called) def test__run_once(self): h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) @@ -181,34 +181,34 @@ def test__run_once(self): h1.cancel() - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop._scheduled.append(h1) - self.event_loop._scheduled.append(h2) - self.event_loop._run_once() + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() - t = self.event_loop._selector.select.call_args[0][0] + t = self.loop._selector.select.call_args[0][0] self.assertTrue(9.99 < t < 10.1) - self.assertEqual([h2], self.event_loop._scheduled) - self.assertTrue(self.event_loop._process_events.called) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) def test__run_once_timeout(self): h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop._scheduled.append(h) - self.event_loop._run_once(1.0) - self.assertEqual((1.0,), self.event_loop._selector.select.call_args[0]) + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once(1.0) + self.assertEqual((1.0,), self.loop._selector.select.call_args[0]) def test__run_once_timeout_with_ready(self): # If event loop has ready callbacks, select timeout is always 0. h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop._scheduled.append(h) - self.event_loop._ready.append(h) - self.event_loop._run_once(1.0) + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._ready.append(h) + self.loop._run_once(1.0) - self.assertEqual((0,), self.event_loop._selector.select.call_args[0]) + self.assertEqual((0,), self.loop._selector.select.call_args[0]) @unittest.mock.patch('tulip.base_events.time') @unittest.mock.patch('tulip.base_events.tulip_log') @@ -226,40 +226,39 @@ def monotonic(): m_logging.INFO = logging.INFO m_logging.DEBUG = logging.DEBUG - self.event_loop._scheduled.append(events.TimerHandle(11.0, - lambda: True, ())) - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop._run_once() + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.event_loop._scheduled = [events.TimerHandle(11.0, - lambda:True, ())] - self.event_loop._run_once() + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) def test__run_once_schedule_handle(self): handle = None processed = False - def cb(event_loop): + def cb(loop): nonlocal processed, handle processed = True - handle = event_loop.call_soon(lambda: True) + handle = loop.call_soon(lambda: True) - h = events.TimerHandle(time.monotonic() - 1, cb, (self.event_loop,)) + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) - self.event_loop._process_events = unittest.mock.Mock() - self.event_loop._scheduled.append(h) - self.event_loop._run_once() + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() self.assertTrue(processed) - self.assertEqual([handle], list(self.event_loop._ready)) + self.assertEqual([handle], list(self.loop._ready)) def test_run_until_complete_assertion(self): self.assertRaises( - AssertionError, self.event_loop.run_until_complete, 'blah') + AssertionError, self.loop.run_until_complete, 'blah') @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors(self, m_socket): @@ -284,10 +283,10 @@ def _socket(*args, **kw): m_socket.socket = _socket m_socket.error = socket.error - self.event_loop.getaddrinfo = getaddrinfo + self.loop.getaddrinfo = getaddrinfo task = tasks.Task( - self.event_loop.create_connection(MyProto, 'example.com', 80)) + self.loop.create_connection(MyProto, 'example.com', 80)) yield from tasks.wait(task) exc = task.exception() self.assertEqual("Multiple exceptions: err1, err2", str(exc)) diff --git a/tests/events_test.py b/tests/events_test.py index 99d97ed7..3f606834 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -144,33 +144,32 @@ class EventLoopTestsMixin: def setUp(self): super().setUp() - self.event_loop = self.create_event_loop() - events.set_event_loop(self.event_loop) + self.loop = self.create_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() gc.collect() super().tearDown() def test_run_nesting(self): @tasks.coroutine def coro(): - self.assertTrue(self.event_loop.is_running()) - self.event_loop.run_until_complete(tasks.sleep(0.1)) + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(tasks.sleep(0.1)) self.assertRaises( - RuntimeError, - self.event_loop.run_until_complete, coro()) + RuntimeError, self.loop.run_until_complete, coro()) def test_run_once_nesting(self): @tasks.coroutine def coro(): tasks.sleep(0.1) - self.event_loop.run_once() + self.loop.run_once() self.assertRaises( RuntimeError, - self.event_loop.run_until_complete, coro()) + self.loop.run_until_complete, coro()) def test_run_once_block(self): called = False @@ -181,14 +180,14 @@ def callback(): def run(): time.sleep(0.1) - self.event_loop.call_soon_threadsafe(callback) + self.loop.call_soon_threadsafe(callback) - self.event_loop.run_once(0) # windows iocp + self.loop.run_once(0) # windows iocp t = threading.Thread(target=run) t0 = time.monotonic() t.start() - self.event_loop.run_once(None) + self.loop.run_once(None) t1 = time.monotonic() t.join() self.assertTrue(called) @@ -199,11 +198,11 @@ def test_call_later(self): def callback(arg): results.append(arg) - self.event_loop.stop() + self.loop.stop() - self.event_loop.call_later(0.1, callback, 'hello world') + self.loop.call_later(0.1, callback, 'hello world') t0 = time.monotonic() - self.event_loop.run_forever() + self.loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) self.assertTrue(0.09 <= t1-t0 <= 0.12) @@ -213,10 +212,10 @@ def test_call_soon(self): def callback(arg1, arg2): results.append((arg1, arg2)) - self.event_loop.stop() + self.loop.stop() - self.event_loop.call_soon(callback, 'hello', 'world') - self.event_loop.run_forever() + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() self.assertEqual(results, [('hello', 'world')]) def test_call_soon_threadsafe(self): @@ -225,16 +224,16 @@ def test_call_soon_threadsafe(self): def callback(arg): results.append(arg) if len(results) >= 2: - self.event_loop.stop() + self.loop.stop() def run_in_thread(): - self.event_loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon_threadsafe(callback, 'hello') t = threading.Thread(target=run_in_thread) - self.event_loop.call_later(0.1, callback, 'world') + self.loop.call_later(0.1, callback, 'world') t0 = time.monotonic() t.start() - self.event_loop.run_forever() + self.loop.run_forever() t1 = time.monotonic() t.join() self.assertEqual(results, ['hello', 'world']) @@ -246,11 +245,11 @@ def test_call_soon_threadsafe_same_thread(self): def callback(arg): results.append(arg) if len(results) >= 2: - self.event_loop.stop() + self.loop.stop() - self.event_loop.call_later(0.1, callback, 'world') - self.event_loop.call_soon_threadsafe(callback, 'hello') - self.event_loop.run_forever() + self.loop.call_later(0.1, callback, 'world') + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.run_forever() self.assertEqual(results, ['hello', 'world']) def test_wrap_future(self): @@ -259,16 +258,16 @@ def run(arg): return arg ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') - f2 = self.event_loop.wrap_future(f1) - res = self.event_loop.run_until_complete(f2) + f2 = self.loop.wrap_future(f1) + res = self.loop.run_until_complete(f2) self.assertEqual(res, 'oi') def test_run_in_executor(self): def run(arg): time.sleep(0.1) return arg - f2 = self.event_loop.run_in_executor(None, run, 'yo') - res = self.event_loop.run_until_complete(f2) + f2 = self.loop.run_in_executor(None, run, 'yo') + res = self.loop.run_until_complete(f2) self.assertEqual(res, 'yo') def test_reader_callback(self): @@ -285,47 +284,46 @@ def reader(): if data: bytes_read.append(data) else: - self.assertTrue(self.event_loop.remove_reader(r.fileno())) + self.assertTrue(self.loop.remove_reader(r.fileno())) r.close() - self.event_loop.add_reader(r.fileno(), reader) - self.event_loop.call_later(0.05, w.send, b'abc') - self.event_loop.call_later(0.1, w.send, b'def') - self.event_loop.call_later(0.15, w.close) - self.event_loop.call_later(0.16, self.event_loop.stop) - self.event_loop.run_forever() + self.loop.add_reader(r.fileno(), reader) + self.loop.call_later(0.05, w.send, b'abc') + self.loop.call_later(0.1, w.send, b'def') + self.loop.call_later(0.15, w.close) + self.loop.call_later(0.16, self.loop.stop) + self.loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') def test_writer_callback(self): r, w = test_utils.socketpair() w.setblocking(False) - self.event_loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) def remove_writer(): - self.assertTrue(self.event_loop.remove_writer(w.fileno())) + self.assertTrue(self.loop.remove_writer(w.fileno())) - self.event_loop.call_later(0.1, remove_writer) - self.event_loop.call_later(0.11, self.event_loop.stop) - self.event_loop.run_forever() + self.loop.call_later(0.1, remove_writer) + self.loop.call_later(0.11, self.loop.stop) + self.loop.run_forever() w.close() data = r.recv(256*1024) r.close() self.assertTrue(len(data) >= 200) def test_sock_client_ops(self): - with test_utils.run_test_server(self.event_loop) as httpd: + with test_utils.run_test_server(self.loop) as httpd: sock = socket.socket() sock.setblocking(False) - self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, httpd.address)) - self.event_loop.run_until_complete( - self.event_loop.sock_sendall( - sock, b'GET / HTTP/1.0\r\n\r\n')) - data = self.event_loop.run_until_complete( - self.event_loop.sock_recv(sock, 1024)) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) # consume data - self.event_loop.run_until_complete( - self.event_loop.sock_recv(sock, 1024)) + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) sock.close() self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) @@ -343,8 +341,8 @@ def test_sock_client_fail(self): sock = socket.socket() sock.setblocking(False) with self.assertRaises(ConnectionRefusedError): - self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, address)) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) sock.close() def test_sock_accept(self): @@ -355,8 +353,8 @@ def test_sock_accept(self): client = socket.socket() client.connect(listener.getsockname()) - f = self.event_loop.sock_accept(listener) - conn, addr = self.event_loop.run_until_complete(f) + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) self.assertEqual(conn.gettimeout(), 0) self.assertEqual(addr, client.getsockname()) self.assertEqual(client.getpeername(), listener.getsockname()) @@ -374,39 +372,39 @@ def my_handler(): # Check error behavior first. self.assertRaises( - TypeError, self.event_loop.add_signal_handler, 'boom', my_handler) + TypeError, self.loop.add_signal_handler, 'boom', my_handler) self.assertRaises( - TypeError, self.event_loop.remove_signal_handler, 'boom') + TypeError, self.loop.remove_signal_handler, 'boom') self.assertRaises( - ValueError, self.event_loop.add_signal_handler, signal.NSIG+1, + ValueError, self.loop.add_signal_handler, signal.NSIG+1, my_handler) self.assertRaises( - ValueError, self.event_loop.remove_signal_handler, signal.NSIG+1) + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) self.assertRaises( - ValueError, self.event_loop.add_signal_handler, 0, my_handler) + ValueError, self.loop.add_signal_handler, 0, my_handler) self.assertRaises( - ValueError, self.event_loop.remove_signal_handler, 0) + ValueError, self.loop.remove_signal_handler, 0) self.assertRaises( - ValueError, self.event_loop.add_signal_handler, -1, my_handler) + ValueError, self.loop.add_signal_handler, -1, my_handler) self.assertRaises( - ValueError, self.event_loop.remove_signal_handler, -1) + ValueError, self.loop.remove_signal_handler, -1) self.assertRaises( - RuntimeError, self.event_loop.add_signal_handler, signal.SIGKILL, + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, my_handler) # Removing SIGKILL doesn't raise, since we don't call signal(). - self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGKILL)) + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) # Now set a handler and handle it. - self.event_loop.add_signal_handler(signal.SIGINT, my_handler) - self.event_loop.run_once() + self.loop.add_signal_handler(signal.SIGINT, my_handler) + self.loop.run_once() os.kill(os.getpid(), signal.SIGINT) - self.event_loop.run_once() + self.loop.run_once() self.assertEqual(caught, 1) # Removing it should restore the default handler. - self.assertTrue(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) self.assertEqual(signal.getsignal(signal.SIGINT), signal.default_int_handler) # Removing again returns False. - self.assertFalse(self.event_loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_while_selecting(self): @@ -417,12 +415,11 @@ def my_handler(): nonlocal caught caught += 1 - self.event_loop.add_signal_handler( - signal.SIGALRM, my_handler) + self.loop.add_signal_handler(signal.SIGALRM, my_handler) signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - self.event_loop.call_later(0.15, self.event_loop.stop) - self.event_loop.run_forever() + self.loop.call_later(0.15, self.loop.stop) + self.loop.run_forever() self.assertEqual(caught, 1) @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') @@ -435,37 +432,35 @@ def my_handler(*args): caught += 1 self.assertEqual(args, some_args) - self.event_loop.add_signal_handler( - signal.SIGALRM, my_handler, *some_args) + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - self.event_loop.call_later(0.15, self.event_loop.stop) - self.event_loop.run_forever() + self.loop.call_later(0.15, self.loop.stop) + self.loop.run_forever() self.assertEqual(caught, 1) def test_create_connection(self): - with test_utils.run_test_server(self.event_loop) as httpd: - f = self.event_loop.create_connection( - lambda: MyProto(create_future=True), - *httpd.address) - tr, pr = self.event_loop.run_until_complete(f) + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(create_future=True), *httpd.address) + tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run_until_complete(pr.done) + self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) def test_create_connection_sock(self): - with test_utils.run_test_server(self.event_loop) as httpd: + with test_utils.run_test_server(self.loop) as httpd: sock = None - infos = self.event_loop.run_until_complete( - self.event_loop.getaddrinfo( + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( *httpd.address, type=socket.SOCK_STREAM)) for family, type, proto, cname, address in infos: try: sock = socket.socket(family=family, type=type, proto=proto) sock.setblocking(False) - self.event_loop.run_until_complete( - self.event_loop.sock_connect(sock, address)) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) except: pass else: @@ -473,59 +468,59 @@ def test_create_connection_sock(self): else: assert False, 'Can not create socket.' - f = self.event_loop.create_connection( + f = self.loop.create_connection( lambda: MyProto(create_future=True), sock=sock) - tr, pr = self.event_loop.run_until_complete(f) + tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) - self.event_loop.run_until_complete(pr.done) + self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server( - self.event_loop, use_ssl=True) as httpd: - f = self.event_loop.create_connection( + self.loop, use_ssl=True) as httpd: + f = self.loop.create_connection( lambda: MyProto(create_future=True), *httpd.address, ssl=True) - tr, pr = self.event_loop.run_until_complete(f) + tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) self.assertTrue( hasattr(tr.get_extra_info('socket'), 'getsockname')) - self.event_loop.run_until_complete(pr.done) + self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) def test_create_connection_host_port_sock(self): - coro = self.event_loop.create_connection( + coro = self.loop.create_connection( MyProto, 'example.com', 80, sock=object()) - self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_connection_no_host_port_sock(self): - coro = self.event_loop.create_connection(MyProto) - self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): @tasks.task def getaddrinfo(*args, **kw): yield from [] - self.event_loop.getaddrinfo = getaddrinfo - coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + self.loop.getaddrinfo = getaddrinfo + coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_connect_err(self): @tasks.task def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] - self.event_loop.getaddrinfo = getaddrinfo - self.event_loop.sock_connect = unittest.mock.Mock() - self.event_loop.sock_connect.side_effect = socket.error + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_connection(MyProto, 'example.com', 80) + coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_mutiple_errors(self): @tasks.task @@ -533,32 +528,32 @@ def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80)), (2, 1, 6, '', ('107.6.106.82', 80))] - self.event_loop.getaddrinfo = getaddrinfo - self.event_loop.sock_connect = unittest.mock.Mock() - self.event_loop.sock_connect.side_effect = socket.error + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_connection( + coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_local_addr(self): - with test_utils.run_test_server(self.event_loop) as httpd: + with test_utils.run_test_server(self.loop) as httpd: port = find_unused_port() - f = self.event_loop.create_connection( + f = self.loop.create_connection( lambda: MyProto(create_future=True), *httpd.address, local_addr=(httpd.address[0], port)) - tr, pr = self.event_loop.run_until_complete(f) + tr, pr = self.loop.run_until_complete(f) expected = pr.transport.get_extra_info('socket').getsockname()[1] self.assertEqual(port, expected) def test_create_connection_local_addr_in_use(self): - with test_utils.run_test_server(self.event_loop) as httpd: - f = self.event_loop.create_connection( + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( lambda: MyProto(create_future=True), *httpd.address, local_addr=httpd.address) with self.assertRaises(socket.error) as cm: - self.event_loop.run_until_complete(f) + self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) self.assertIn(str(httpd.address), cm.exception.strerror) @@ -570,8 +565,8 @@ def factory(): proto = MyProto() return proto - f = self.event_loop.start_serving(factory, '0.0.0.0', 0) - socks = self.event_loop.run_until_complete(f) + f = self.loop.start_serving(factory, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) self.assertEqual(len(socks), 1) sock = socks[0] host, port = sock.getsockname() @@ -579,12 +574,12 @@ def factory(): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.event_loop.run_once() + self.loop.run_once() self.assertIsInstance(proto, MyProto) self.assertEqual('INITIAL', proto.state) - self.event_loop.run_once() + self.loop.run_once() self.assertEqual('CONNECTED', proto.state) - self.event_loop.run_once(0.001) # windows iocp + self.loop.run_once(0.001) # windows iocp self.assertEqual(3, proto.nbytes) # extra info is available @@ -596,7 +591,7 @@ def factory(): # close connection proto.transport.close() - self.event_loop.run_once(0.001) # windows iocp + self.loop.run_once(0.001) # windows iocp self.assertEqual('CLOSED', proto.state) @@ -625,21 +620,20 @@ def factory(): certfile=os.path.join(here, 'sample.crt'), keyfile=os.path.join(here, 'sample.key')) - f = self.event_loop.start_serving( + f = self.loop.start_serving( factory, '127.0.0.1', 0, ssl=sslcontext) - sock = self.event_loop.run_until_complete(f)[0] + sock = self.loop.run_until_complete(f)[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') - f_c = self.event_loop.create_connection( - ClientMyProto, host, port, ssl=True) - client, pr = self.event_loop.run_until_complete(f_c) + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') - self.event_loop.run_once() + self.loop.run_once() self.assertIsInstance(proto, MyProto) - self.event_loop.run_once() + self.loop.run_once() self.assertEqual('CONNECTED', proto.state) self.assertEqual(3, proto.nbytes) @@ -652,7 +646,7 @@ def factory(): # close connection proto.transport.close() - self.event_loop.run_until_complete(proto.done) + self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) # the client socket must be closed after to avoid ECONNRESET upon @@ -671,8 +665,8 @@ def connection_made(self, transport): sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) - f = self.event_loop.start_serving(TestMyProto, sock=sock_ob) - sock = self.event_loop.run_until_complete(f)[0] + f = self.loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] self.assertIs(sock, sock_ob) host, port = sock.getsockname() @@ -680,7 +674,7 @@ def connection_made(self, transport): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.event_loop.run_until_complete(proto) + self.loop.run_until_complete(proto) sock.close() client.close() @@ -689,13 +683,13 @@ def test_start_serving_addrinuse(self): sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) - f = self.event_loop.start_serving(MyProto, sock=sock_ob) - sock = self.event_loop.run_until_complete(f)[0] + f = self.loop.start_serving(MyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] host, port = sock.getsockname() - f = self.event_loop.start_serving(MyProto, host=host, port=port) + f = self.loop.start_serving(MyProto, host=host, port=port) with self.assertRaises(socket.error) as cm: - self.event_loop.run_until_complete(f) + self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') @@ -708,12 +702,12 @@ def connection_made(self, transport): f_proto.set_result(self) port = find_unused_port() - f = self.event_loop.start_serving(TestMyProto, host=None, port=port) - socks = self.event_loop.run_until_complete(f) + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - proto = self.event_loop.run_until_complete(f_proto) + proto = self.loop.run_until_complete(f_proto) proto.transport.close() client.close() @@ -721,7 +715,7 @@ def connection_made(self, transport): client = socket.socket(socket.AF_INET6) client.connect(('::1', port)) client.send(b'xxx') - proto = self.event_loop.run_until_complete(f_proto) + proto = self.loop.run_until_complete(f_proto) proto.transport.close() client.close() @@ -729,8 +723,8 @@ def connection_made(self, transport): s.close() def test_stop_serving(self): - f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) - sock = self.event_loop.run_until_complete(f)[0] + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + sock = self.loop.run_until_complete(f)[0] host, port = sock.getsockname() client = socket.socket() @@ -738,28 +732,27 @@ def test_stop_serving(self): client.send(b'xxx') client.close() - self.event_loop.stop_serving(sock) + self.loop.stop_serving(sock) client = socket.socket() self.assertRaises( ConnectionRefusedError, client.connect, ('127.0.0.1', port)) def test_start_serving_host_port_sock(self): - fut = self.event_loop.start_serving( + fut = self.loop.start_serving( MyProto, '0.0.0.0', 0, sock=object()) - self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) def test_start_serving_no_host_port_sock(self): - fut = self.event_loop.start_serving(MyProto) - self.assertRaises(ValueError, self.event_loop.run_until_complete, fut) + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) def test_start_serving_no_getaddrinfo(self): - getaddrinfo = self.event_loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - f = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises( - socket.error, self.event_loop.run_until_complete, f) + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(socket.error, self.loop.run_until_complete, f) @unittest.mock.patch('tulip.base_events.socket') def test_start_serving_cant_bind(self, m_socket): @@ -773,8 +766,8 @@ class Err(socket.error): m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err - fut = self.event_loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises(OSError, self.event_loop.run_until_complete, fut) + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) @unittest.mock.patch('tulip.base_events.socket') @@ -782,20 +775,20 @@ def test_create_datagram_endpoint_no_addrinfo(self, m_socket): m_socket.error = socket.error m_socket.getaddrinfo.return_value = [] - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_addr_error(self): - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr='localhost') self.assertRaises( - AssertionError, self.event_loop.run_until_complete, coro) - coro = self.event_loop.create_datagram_endpoint( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 1, 2, 3)) self.assertRaises( - AssertionError, self.event_loop.run_until_complete, coro) + AssertionError, self.loop.run_until_complete, coro) def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): @@ -806,21 +799,21 @@ def datagram_received(self, data, addr): super().datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( TestMyDatagramProto, local_addr=('127.0.0.1', 0)) - s_transport, server = self.event_loop.run_until_complete(coro) + s_transport, server = self.loop.run_until_complete(coro) host, port = s_transport.get_extra_info('addr') - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( lambda: MyDatagramProto(create_future=True), remote_addr=(host, port)) - transport, client = self.event_loop.run_until_complete(coro) + transport, client = self.loop.run_until_complete(coro) self.assertEqual('INITIALIZED', client.state) transport.sendto(b'xxx') - self.event_loop.run_once(None) + self.loop.run_once(None) self.assertEqual(3, server.nbytes) - self.event_loop.run_once(None) + self.loop.run_once(None) # received self.assertEqual(8, client.nbytes) @@ -832,18 +825,18 @@ def datagram_received(self, data, addr): # close connection transport.close() - self.event_loop.run_until_complete(client.done) + self.loop.run_until_complete(client.done) self.assertEqual('CLOSED', client.state) server.transport.close() def test_create_datagram_endpoint_connect_err(self): - self.event_loop.sock_connect = unittest.mock.Mock() - self.event_loop.sock_connect.side_effect = socket.error + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): @@ -851,39 +844,39 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): m_socket.getaddrinfo = socket.getaddrinfo m_socket.socket.side_effect = socket.error - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_no_matching_family(self): - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) self.assertRaises( - ValueError, self.event_loop.run_until_complete, coro) + ValueError, self.loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): m_socket.error = socket.error m_socket.socket.return_value.setblocking.side_effect = socket.error - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - socket.error, self.event_loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) self.assertTrue( m_socket.socket.return_value.close.called) def test_create_datagram_endpoint_noaddr_nofamily(self): - coro = self.event_loop.create_datagram_endpoint( + coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol) - self.assertRaises(ValueError, self.event_loop.run_until_complete, coro) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): @@ -896,17 +889,17 @@ class Err(socket.error): m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err - fut = self.event_loop.create_datagram_endpoint( + fut = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('127.0.0.1', 0), family=socket.AF_INET) - self.assertRaises(Err, self.event_loop.run_until_complete, fut) + self.assertRaises(Err, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) def test_accept_connection_retry(self): sock = unittest.mock.Mock() sock.accept.side_effect = BlockingIOError() - self.event_loop._accept_connection(MyProto, sock) + self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) @unittest.mock.patch('tulip.selector_events.tulip_log') @@ -914,20 +907,20 @@ def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.accept.side_effect = OSError() - self.event_loop._accept_connection(MyProto, sock) + self.loop._accept_connection(MyProto, sock) self.assertTrue(sock.close.called) self.assertTrue(m_log.exception.called) def test_internal_fds(self): - event_loop = self.create_event_loop() - if not isinstance(event_loop, selector_events.BaseSelectorEventLoop): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): return - self.assertEqual(1, event_loop._internal_fds) - event_loop.close() - self.assertEqual(0, event_loop._internal_fds) - self.assertIsNone(event_loop._csock) - self.assertIsNone(event_loop._ssock) + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") @@ -944,26 +937,25 @@ def factory(): @tasks.task def connect(): - t, p = yield from self.event_loop.connect_read_pipe(factory, - pipeobj) + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) self.assertEqual(0, proto.nbytes) - self.event_loop.run_until_complete(connect()) + self.loop.run_until_complete(connect()) os.write(wpipe, b'1') - self.event_loop.run_once() + self.loop.run_once() self.assertEqual(1, proto.nbytes) os.write(wpipe, b'2345') - self.event_loop.run_once() + self.loop.run_once() self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) self.assertEqual(5, proto.nbytes) os.close(wpipe) - self.event_loop.run_until_complete(proto.done) + self.loop.run_until_complete(proto.done) self.assertEqual( ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) # extra info is available @@ -986,22 +978,21 @@ def factory(): @tasks.task def connect(): nonlocal transport - t, p = yield from self.event_loop.connect_write_pipe(factory, - pipeobj) + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual('CONNECTED', proto.state) transport = t - self.event_loop.run_until_complete(connect()) + self.loop.run_until_complete(connect()) transport.write(b'1') - self.event_loop.run_once() + self.loop.run_once() data = os.read(rpipe, 1024) self.assertEqual(b'1', data) transport.write(b'2345') - self.event_loop.run_once() + self.loop.run_once() data = os.read(rpipe, 1024) self.assertEqual(b'2345', data) self.assertEqual('CONNECTED', proto.state) @@ -1013,7 +1004,7 @@ def connect(): # close connection proto.transport.close() - self.event_loop.run_until_complete(proto.done) + self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) @@ -1211,70 +1202,70 @@ class AbstractEventLoopTests(unittest.TestCase): def test_not_implemented(self): f = unittest.mock.Mock() - ev_loop = events.AbstractEventLoop() + loop = events.AbstractEventLoop() self.assertRaises( - NotImplementedError, ev_loop.run_forever) + NotImplementedError, loop.run_forever) self.assertRaises( - NotImplementedError, ev_loop.run_once) + NotImplementedError, loop.run_once) self.assertRaises( - NotImplementedError, ev_loop.run_until_complete, None) + NotImplementedError, loop.run_until_complete, None) self.assertRaises( - NotImplementedError, ev_loop.stop) + NotImplementedError, loop.stop) self.assertRaises( - NotImplementedError, ev_loop.is_running) + NotImplementedError, loop.is_running) self.assertRaises( - NotImplementedError, ev_loop.call_later, None, None) + NotImplementedError, loop.call_later, None, None) self.assertRaises( - NotImplementedError, ev_loop.call_soon, None) + NotImplementedError, loop.call_soon, None) self.assertRaises( - NotImplementedError, ev_loop.time) + NotImplementedError, loop.time) self.assertRaises( - NotImplementedError, ev_loop.call_soon_threadsafe, None) + NotImplementedError, loop.call_soon_threadsafe, None) self.assertRaises( - NotImplementedError, ev_loop.wrap_future, f) + NotImplementedError, loop.wrap_future, f) self.assertRaises( - NotImplementedError, ev_loop.run_in_executor, f, f) + NotImplementedError, loop.run_in_executor, f, f) self.assertRaises( - NotImplementedError, ev_loop.set_default_executor, f) + NotImplementedError, loop.set_default_executor, f) self.assertRaises( - NotImplementedError, ev_loop.getaddrinfo, 'localhost', 8080) + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) self.assertRaises( - NotImplementedError, ev_loop.getnameinfo, ('localhost', 8080)) + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) self.assertRaises( - NotImplementedError, ev_loop.create_connection, f) + NotImplementedError, loop.create_connection, f) self.assertRaises( - NotImplementedError, ev_loop.start_serving, f) + NotImplementedError, loop.start_serving, f) self.assertRaises( - NotImplementedError, ev_loop.stop_serving, f) + NotImplementedError, loop.stop_serving, f) self.assertRaises( - NotImplementedError, ev_loop.create_datagram_endpoint, f) + NotImplementedError, loop.create_datagram_endpoint, f) self.assertRaises( - NotImplementedError, ev_loop.add_reader, 1, f) + NotImplementedError, loop.add_reader, 1, f) self.assertRaises( - NotImplementedError, ev_loop.remove_reader, 1) + NotImplementedError, loop.remove_reader, 1) self.assertRaises( - NotImplementedError, ev_loop.add_writer, 1, f) + NotImplementedError, loop.add_writer, 1, f) self.assertRaises( - NotImplementedError, ev_loop.remove_writer, 1) + NotImplementedError, loop.remove_writer, 1) self.assertRaises( - NotImplementedError, ev_loop.sock_recv, f, 10) + NotImplementedError, loop.sock_recv, f, 10) self.assertRaises( - NotImplementedError, ev_loop.sock_sendall, f, 10) + NotImplementedError, loop.sock_sendall, f, 10) self.assertRaises( - NotImplementedError, ev_loop.sock_connect, f, f) + NotImplementedError, loop.sock_connect, f, f) self.assertRaises( - NotImplementedError, ev_loop.sock_accept, f) + NotImplementedError, loop.sock_accept, f) self.assertRaises( - NotImplementedError, ev_loop.add_signal_handler, 1, f) + NotImplementedError, loop.add_signal_handler, 1, f) self.assertRaises( - NotImplementedError, ev_loop.remove_signal_handler, 1) + NotImplementedError, loop.remove_signal_handler, 1) self.assertRaises( - NotImplementedError, ev_loop.remove_signal_handler, 1) + NotImplementedError, loop.remove_signal_handler, 1) self.assertRaises( - NotImplementedError, ev_loop.connect_read_pipe, f, + NotImplementedError, loop.connect_read_pipe, f, unittest.mock.sentinel.pipe) self.assertRaises( - NotImplementedError, ev_loop.connect_write_pipe, f, + NotImplementedError, loop.connect_write_pipe, f, unittest.mock.sentinel.pipe) @@ -1305,13 +1296,13 @@ def test_event_loop_policy(self): def test_get_event_loop(self): policy = events.DefaultEventLoopPolicy() - self.assertIsNone(policy._event_loop) + self.assertIsNone(policy._loop) - event_loop = policy.get_event_loop() - self.assertIsInstance(event_loop, events.AbstractEventLoop) + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) - self.assertIs(policy._event_loop, event_loop) - self.assertIs(event_loop, policy.get_event_loop()) + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) @unittest.mock.patch('tulip.events.threading') def test_get_event_loop_thread(self, m_threading): diff --git a/tests/futures_test.py b/tests/futures_test.py index 0450e51a..44955182 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -213,12 +213,6 @@ class _FakeEventLoop: def call_soon(self, fn, future): fn(future) - def set_log_level(self, val): - pass - - def get_log_level(self): - return logging.CRITICAL - class FutureDoneCallbackTests(unittest.TestCase): @@ -229,7 +223,7 @@ def bag_appender(future): return bag_appender def _new_future(self): - return futures.Future(event_loop=_FakeEventLoop()) + return futures.Future(loop=_FakeEventLoop()) def test_callbacks_invoked_on_set_result(self): bag = [] diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index b70e6e50..3232fa46 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -30,86 +30,83 @@ def _make_self_pipe(self): class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): - self.event_loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) def test_make_socket_transport(self): m = unittest.mock.Mock() - self.event_loop.add_reader = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() self.assertIsInstance( - self.event_loop._make_socket_transport(m, m), - _SelectorSocketTransport) + self.loop._make_socket_transport(m, m), _SelectorSocketTransport) def test_make_ssl_transport(self): m = unittest.mock.Mock() - self.event_loop.add_reader = unittest.mock.Mock() - self.event_loop.add_writer = unittest.mock.Mock() - self.event_loop.remove_reader = unittest.mock.Mock() - self.event_loop.remove_writer = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() self.assertIsInstance( - self.event_loop._make_ssl_transport(m, m, m, m), - _SelectorSslTransport) + self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) def test_close(self): - ssock = self.event_loop._ssock + ssock = self.loop._ssock ssock.fileno.return_value = 7 - csock = self.event_loop._csock + csock = self.loop._csock csock.fileno.return_value = 1 - remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() - - self.event_loop._selector.close() - self.event_loop._selector = selector = unittest.mock.Mock() - self.event_loop.close() - self.assertIsNone(self.event_loop._selector) - self.assertIsNone(self.event_loop._csock) - self.assertIsNone(self.event_loop._ssock) + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = selector = unittest.mock.Mock() + self.loop.close() + self.assertIsNone(self.loop._selector) + self.assertIsNone(self.loop._csock) + self.assertIsNone(self.loop._ssock) selector.close.assert_called_with() ssock.close.assert_called_with() csock.close.assert_called_with() remove_reader.assert_called_with(7) - self.event_loop.close() - self.event_loop.close() + self.loop.close() + self.loop.close() def test_close_no_selector(self): - ssock = self.event_loop._ssock - csock = self.event_loop._csock - remove_reader = self.event_loop.remove_reader = unittest.mock.Mock() - - self.event_loop._selector.close() - self.event_loop._selector = None - self.event_loop.close() - self.assertIsNone(self.event_loop._selector) + ssock = self.loop._ssock + csock = self.loop._csock + remove_reader = self.loop.remove_reader = unittest.mock.Mock() + + self.loop._selector.close() + self.loop._selector = None + self.loop.close() + self.assertIsNone(self.loop._selector) self.assertFalse(ssock.close.called) self.assertFalse(csock.close.called) self.assertFalse(remove_reader.called) def test_socketpair(self): - self.assertRaises(NotImplementedError, self.event_loop._socketpair) + self.assertRaises(NotImplementedError, self.loop._socketpair) def test_read_from_self_tryagain(self): - self.event_loop._ssock.recv.side_effect = BlockingIOError - self.assertIsNone(self.event_loop._read_from_self()) + self.loop._ssock.recv.side_effect = BlockingIOError + self.assertIsNone(self.loop._read_from_self()) def test_read_from_self_exception(self): - self.event_loop._ssock.recv.side_effect = OSError - self.assertRaises(OSError, self.event_loop._read_from_self) + self.loop._ssock.recv.side_effect = OSError + self.assertRaises(OSError, self.loop._read_from_self) def test_write_to_self_tryagain(self): - self.event_loop._csock.send.side_effect = BlockingIOError - self.assertIsNone(self.event_loop._write_to_self()) + self.loop._csock.send.side_effect = BlockingIOError + self.assertIsNone(self.loop._write_to_self()) def test_write_to_self_exception(self): - self.event_loop._csock.send.side_effect = OSError() - self.assertRaises(OSError, self.event_loop._write_to_self) + self.loop._csock.send.side_effect = OSError() + self.assertRaises(OSError, self.loop._write_to_self) def test_sock_recv(self): sock = unittest.mock.Mock() - self.event_loop._sock_recv = unittest.mock.Mock() + self.loop._sock_recv = unittest.mock.Mock() - f = self.event_loop.sock_recv(sock, 1024) + f = self.loop.sock_recv(sock, 1024) self.assertIsInstance(f, futures.Future) - self.event_loop._sock_recv.assert_called_with( - f, False, sock, 1024) + self.loop._sock_recv.assert_called_with(f, False, sock, 1024) def test__sock_recv_canceled_fut(self): sock = unittest.mock.Mock() @@ -117,7 +114,7 @@ def test__sock_recv_canceled_fut(self): f = futures.Future() f.cancel() - self.event_loop._sock_recv(f, False, sock, 1024) + self.loop._sock_recv(f, False, sock, 1024) self.assertFalse(sock.recv.called) def test__sock_recv_unregister(self): @@ -127,9 +124,9 @@ def test__sock_recv_unregister(self): f = futures.Future() f.cancel() - self.event_loop.remove_reader = unittest.mock.Mock() - self.event_loop._sock_recv(f, True, sock, 1024) - self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_recv(f, True, sock, 1024) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_recv_tryagain(self): f = futures.Future() @@ -137,10 +134,10 @@ def test__sock_recv_tryagain(self): sock.fileno.return_value = 10 sock.recv.side_effect = BlockingIOError - self.event_loop.add_reader = unittest.mock.Mock() - self.event_loop._sock_recv(f, False, sock, 1024) - self.assertEqual((10, self.event_loop._sock_recv, f, True, sock, 1024), - self.event_loop.add_reader.call_args[0]) + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_recv(f, False, sock, 1024) + self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), + self.loop.add_reader.call_args[0]) def test__sock_recv_exception(self): f = futures.Future() @@ -148,28 +145,28 @@ def test__sock_recv_exception(self): sock.fileno.return_value = 10 err = sock.recv.side_effect = OSError() - self.event_loop._sock_recv(f, False, sock, 1024) + self.loop._sock_recv(f, False, sock, 1024) self.assertIs(err, f.exception()) def test_sock_sendall(self): sock = unittest.mock.Mock() - self.event_loop._sock_sendall = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() - f = self.event_loop.sock_sendall(sock, b'data') + f = self.loop.sock_sendall(sock, b'data') self.assertIsInstance(f, futures.Future) self.assertEqual( (f, False, sock, b'data'), - self.event_loop._sock_sendall.call_args[0]) + self.loop._sock_sendall.call_args[0]) def test_sock_sendall_nodata(self): sock = unittest.mock.Mock() - self.event_loop._sock_sendall = unittest.mock.Mock() + self.loop._sock_sendall = unittest.mock.Mock() - f = self.event_loop.sock_sendall(sock, b'') + f = self.loop.sock_sendall(sock, b'') self.assertIsInstance(f, futures.Future) self.assertTrue(f.done()) self.assertIsNone(f.result()) - self.assertFalse(self.event_loop._sock_sendall.called) + self.assertFalse(self.loop._sock_sendall.called) def test__sock_sendall_canceled_fut(self): sock = unittest.mock.Mock() @@ -177,7 +174,7 @@ def test__sock_sendall_canceled_fut(self): f = futures.Future() f.cancel() - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop._sock_sendall(f, False, sock, b'data') self.assertFalse(sock.send.called) def test__sock_sendall_unregister(self): @@ -187,9 +184,9 @@ def test__sock_sendall_unregister(self): f = futures.Future() f.cancel() - self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_sendall(f, True, sock, b'data') - self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, True, sock, b'data') + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_sendall_tryagain(self): f = futures.Future() @@ -197,11 +194,11 @@ def test__sock_sendall_tryagain(self): sock.fileno.return_value = 10 sock.send.side_effect = BlockingIOError - self.event_loop.add_writer = unittest.mock.Mock() - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') self.assertEqual( - (10, self.event_loop._sock_sendall, f, True, sock, b'data'), - self.event_loop.add_writer.call_args[0]) + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) def test__sock_sendall_exception(self): f = futures.Future() @@ -209,7 +206,7 @@ def test__sock_sendall_exception(self): sock.fileno.return_value = 10 err = sock.send.side_effect = OSError() - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop._sock_sendall(f, False, sock, b'data') self.assertIs(f.exception(), err) def test__sock_sendall(self): @@ -219,7 +216,7 @@ def test__sock_sendall(self): sock.fileno.return_value = 10 sock.send.return_value = 4 - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop._sock_sendall(f, False, sock, b'data') self.assertTrue(f.done()) self.assertIsNone(f.result()) @@ -230,12 +227,12 @@ def test__sock_sendall_partial(self): sock.fileno.return_value = 10 sock.send.return_value = 2 - self.event_loop.add_writer = unittest.mock.Mock() - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') self.assertFalse(f.done()) self.assertEqual( - (10, self.event_loop._sock_sendall, f, True, sock, b'ta'), - self.event_loop.add_writer.call_args[0]) + (10, self.loop._sock_sendall, f, True, sock, b'ta'), + self.loop.add_writer.call_args[0]) def test__sock_sendall_none(self): sock = unittest.mock.Mock() @@ -244,22 +241,22 @@ def test__sock_sendall_none(self): sock.fileno.return_value = 10 sock.send.return_value = 0 - self.event_loop.add_writer = unittest.mock.Mock() - self.event_loop._sock_sendall(f, False, sock, b'data') + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') self.assertFalse(f.done()) self.assertEqual( - (10, self.event_loop._sock_sendall, f, True, sock, b'data'), - self.event_loop.add_writer.call_args[0]) + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) def test_sock_connect(self): sock = unittest.mock.Mock() - self.event_loop._sock_connect = unittest.mock.Mock() + self.loop._sock_connect = unittest.mock.Mock() - f = self.event_loop.sock_connect(sock, ('127.0.0.1', 8080)) + f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, futures.Future) self.assertEqual( (f, False, sock, ('127.0.0.1', 8080)), - self.event_loop._sock_connect.call_args[0]) + self.loop._sock_connect.call_args[0]) def test__sock_connect(self): f = futures.Future() @@ -267,7 +264,7 @@ def test__sock_connect(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.assertTrue(f.done()) self.assertIsNone(f.result()) self.assertTrue(sock.connect.called) @@ -278,7 +275,7 @@ def test__sock_connect_canceled_fut(self): f = futures.Future() f.cancel() - self.event_loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) self.assertFalse(sock.connect.called) def test__sock_connect_unregister(self): @@ -288,9 +285,9 @@ def test__sock_connect_unregister(self): f = futures.Future() f.cancel() - self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) - self.assertEqual((10,), self.event_loop.remove_writer.call_args[0]) + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_connect_tryagain(self): f = futures.Future() @@ -298,14 +295,14 @@ def test__sock_connect_tryagain(self): sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.EAGAIN - self.event_loop.add_writer = unittest.mock.Mock() - self.event_loop.remove_writer = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertEqual( - (10, self.event_loop._sock_connect, f, + (10, self.loop._sock_connect, f, True, sock, ('127.0.0.1', 8080)), - self.event_loop.add_writer.call_args[0]) + self.loop.add_writer.call_args[0]) def test__sock_connect_exception(self): f = futures.Future() @@ -313,18 +310,18 @@ def test__sock_connect_exception(self): sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.ENOTCONN - self.event_loop.remove_writer = unittest.mock.Mock() - self.event_loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop.remove_writer = unittest.mock.Mock() + self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertIsInstance(f.exception(), socket.error) def test_sock_accept(self): sock = unittest.mock.Mock() - self.event_loop._sock_accept = unittest.mock.Mock() + self.loop._sock_accept = unittest.mock.Mock() - f = self.event_loop.sock_accept(sock) + f = self.loop.sock_accept(sock) self.assertIsInstance(f, futures.Future) self.assertEqual( - (f, False, sock), self.event_loop._sock_accept.call_args[0]) + (f, False, sock), self.loop._sock_accept.call_args[0]) def test__sock_accept(self): f = futures.Future() @@ -335,7 +332,7 @@ def test__sock_accept(self): sock.fileno.return_value = 10 sock.accept.return_value = conn, ('127.0.0.1', 1000) - self.event_loop._sock_accept(f, False, sock) + self.loop._sock_accept(f, False, sock) self.assertTrue(f.done()) self.assertEqual((conn, ('127.0.0.1', 1000)), f.result()) self.assertEqual((False,), conn.setblocking.call_args[0]) @@ -346,7 +343,7 @@ def test__sock_accept_canceled_fut(self): f = futures.Future() f.cancel() - self.event_loop._sock_accept(f, False, sock) + self.loop._sock_accept(f, False, sock) self.assertFalse(sock.accept.called) def test__sock_accept_unregister(self): @@ -356,9 +353,9 @@ def test__sock_accept_unregister(self): f = futures.Future() f.cancel() - self.event_loop.remove_reader = unittest.mock.Mock() - self.event_loop._sock_accept(f, True, sock) - self.assertEqual((10,), self.event_loop.remove_reader.call_args[0]) + self.loop.remove_reader = unittest.mock.Mock() + self.loop._sock_accept(f, True, sock) + self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_accept_tryagain(self): f = futures.Future() @@ -366,11 +363,11 @@ def test__sock_accept_tryagain(self): sock.fileno.return_value = 10 sock.accept.side_effect = BlockingIOError - self.event_loop.add_reader = unittest.mock.Mock() - self.event_loop._sock_accept(f, False, sock) + self.loop.add_reader = unittest.mock.Mock() + self.loop._sock_accept(f, False, sock) self.assertEqual( - (10, self.event_loop._sock_accept, f, True, sock), - self.event_loop.add_reader.call_args[0]) + (10, self.loop._sock_accept, f, True, sock), + self.loop.add_reader.call_args[0]) def test__sock_accept_exception(self): f = futures.Future() @@ -378,16 +375,16 @@ def test__sock_accept_exception(self): sock.fileno.return_value = 10 err = sock.accept.side_effect = OSError() - self.event_loop._sock_accept(f, False, sock) + self.loop._sock_accept(f, False, sock) self.assertIs(err, f.exception()) def test_add_reader(self): - self.event_loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_info.side_effect = KeyError cb = lambda: True - self.event_loop.add_reader(1, cb) + self.loop.add_reader(1, cb) - self.assertTrue(self.event_loop._selector.register.called) - fd, mask, (r, w) = self.event_loop._selector.register.call_args[0] + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_READ, mask) self.assertEqual(cb, r.callback) @@ -396,15 +393,15 @@ def test_add_reader(self): def test_add_reader_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_WRITE, (reader, writer)) cb = lambda: True - self.event_loop.add_reader(1, cb) + self.loop.add_reader(1, cb) self.assertTrue(reader.cancel.called) - self.assertFalse(self.event_loop._selector.register.called) - self.assertTrue(self.event_loop._selector.modify.called) - fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) self.assertEqual(cb, r.callback) @@ -412,51 +409,51 @@ def test_add_reader_existing(self): def test_add_reader_existing_writer(self): writer = unittest.mock.Mock() - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_WRITE, (None, writer)) cb = lambda: True - self.event_loop.add_reader(1, cb) + self.loop.add_reader(1, cb) - self.assertFalse(self.event_loop._selector.register.called) - self.assertTrue(self.event_loop._selector.modify.called) - fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) self.assertEqual(cb, r.callback) self.assertEqual(writer, w) def test_remove_reader(self): - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_READ, (None, None)) - self.assertFalse(self.event_loop.remove_reader(1)) + self.assertFalse(self.loop.remove_reader(1)) - self.assertTrue(self.event_loop._selector.unregister.called) + self.assertTrue(self.loop._selector.unregister.called) def test_remove_reader_read_write(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) self.assertTrue( - self.event_loop.remove_reader(1)) + self.loop.remove_reader(1)) - self.assertFalse(self.event_loop._selector.unregister.called) + self.assertFalse(self.loop._selector.unregister.called) self.assertEqual( (1, selectors.EVENT_WRITE, (None, writer)), - self.event_loop._selector.modify.call_args[0]) + self.loop._selector.modify.call_args[0]) def test_remove_reader_unknown(self): - self.event_loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_info.side_effect = KeyError self.assertFalse( - self.event_loop.remove_reader(1)) + self.loop.remove_reader(1)) def test_add_writer(self): - self.event_loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_info.side_effect = KeyError cb = lambda: True - self.event_loop.add_writer(1, cb) + self.loop.add_writer(1, cb) - self.assertTrue(self.event_loop._selector.register.called) - fd, mask, (r, w) = self.event_loop._selector.register.call_args[0] + self.assertTrue(self.loop._selector.register.called) + fd, mask, (r, w) = self.loop._selector.register.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE, mask) self.assertEqual(None, r) @@ -465,110 +462,110 @@ def test_add_writer(self): def test_add_writer_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_READ, (reader, writer)) cb = lambda: True - self.event_loop.add_writer(1, cb) + self.loop.add_writer(1, cb) self.assertTrue(writer.cancel.called) - self.assertFalse(self.event_loop._selector.register.called) - self.assertTrue(self.event_loop._selector.modify.called) - fd, mask, (r, w) = self.event_loop._selector.modify.call_args[0] + self.assertFalse(self.loop._selector.register.called) + self.assertTrue(self.loop._selector.modify.called) + fd, mask, (r, w) = self.loop._selector.modify.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) self.assertEqual(reader, r) self.assertEqual(cb, w.callback) def test_remove_writer(self): - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_WRITE, (None, None)) - self.assertFalse(self.event_loop.remove_writer(1)) + self.assertFalse(self.loop.remove_writer(1)) - self.assertTrue(self.event_loop._selector.unregister.called) + self.assertTrue(self.loop._selector.unregister.called) def test_remove_writer_read_write(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.event_loop._selector.get_info.return_value = ( + self.loop._selector.get_info.return_value = ( selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) self.assertTrue( - self.event_loop.remove_writer(1)) + self.loop.remove_writer(1)) - self.assertFalse(self.event_loop._selector.unregister.called) + self.assertFalse(self.loop._selector.unregister.called) self.assertEqual( (1, selectors.EVENT_READ, (reader, None)), - self.event_loop._selector.modify.call_args[0]) + self.loop._selector.modify.call_args[0]) def test_remove_writer_unknown(self): - self.event_loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_info.side_effect = KeyError self.assertFalse( - self.event_loop.remove_writer(1)) + self.loop.remove_writer(1)) def test_process_events_read(self): reader = unittest.mock.Mock() reader.cancelled = False - self.event_loop._add_callback = unittest.mock.Mock() - self.event_loop._process_events( + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( ((1, selectors.EVENT_READ, (reader, None)),)) - self.assertTrue(self.event_loop._add_callback.called) - self.event_loop._add_callback.assert_called_with(reader) + self.assertTrue(self.loop._add_callback.called) + self.loop._add_callback.assert_called_with(reader) def test_process_events_read_cancelled(self): reader = unittest.mock.Mock() reader.cancelled = True - self.event_loop.remove_reader = unittest.mock.Mock() - self.event_loop._process_events( + self.loop.remove_reader = unittest.mock.Mock() + self.loop._process_events( ((1, selectors.EVENT_READ, (reader, None)),)) - self.event_loop.remove_reader.assert_called_with(1) + self.loop.remove_reader.assert_called_with(1) def test_process_events_write(self): writer = unittest.mock.Mock() writer.cancelled = False - self.event_loop._add_callback = unittest.mock.Mock() - self.event_loop._process_events( + self.loop._add_callback = unittest.mock.Mock() + self.loop._process_events( ((1, selectors.EVENT_WRITE, (None, writer)),)) - self.event_loop._add_callback.assert_called_with(writer) + self.loop._add_callback.assert_called_with(writer) def test_process_events_write_cancelled(self): writer = unittest.mock.Mock() writer.cancelled = True - self.event_loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() - self.event_loop._process_events( + self.loop._process_events( ((1, selectors.EVENT_WRITE, (None, writer)),)) - self.event_loop.remove_writer.assert_called_with(1) + self.loop.remove_writer.assert_called_with(1) class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 self.protocol = unittest.mock.Mock(Protocol) def test_ctor(self): tr = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) - self.event_loop.add_reader.assert_called_with(7, tr._read_ready) - self.event_loop.call_soon.assert_called_with( + self.loop, self.sock, self.protocol) + self.loop.add_reader.assert_called_with(7, tr._read_ready) + self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) def test_ctor_with_waiter(self): fut = futures.Future() _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol, fut) - self.assertEqual(2, self.event_loop.call_soon.call_count) + self.loop, self.sock, self.protocol, fut) + self.assertEqual(2, self.loop.call_soon.call_count) self.assertEqual(fut.set_result, - self.event_loop.call_soon.call_args[0][0]) + self.loop.call_soon.call_args[0][0]) def test_read_ready(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.sock.recv.return_value = b'data' transport._read_ready() @@ -577,12 +574,12 @@ def test_read_ready(self): def test_read_ready_eof(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.sock.recv.return_value = b'' transport._read_ready() - self.assertTrue(self.event_loop.remove_reader.called) + self.assertTrue(self.loop.remove_reader.called) self.protocol.eof_received.assert_called_with() @unittest.mock.patch('logging.exception') @@ -590,7 +587,7 @@ def test_read_ready_tryagain(self, m_exc): self.sock.recv.side_effect = BlockingIOError transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -601,7 +598,7 @@ def test_read_ready_err(self, m_exc): err = self.sock.recv.side_effect = OSError() transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._read_ready() @@ -609,7 +606,7 @@ def test_read_ready_err(self, m_exc): def test_abort(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._close = unittest.mock.Mock() transport.abort() @@ -620,13 +617,13 @@ def test_write(self): self.sock.send.return_value = len(data) transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.write(data) self.sock.send.assert_called_with(data) def test_write_no_data(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(b'data') transport.write(b'') self.assertFalse(self.sock.send.called) @@ -634,7 +631,7 @@ def test_write_no_data(self): def test_write_buffer(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(b'data1') transport.write(b'data2') self.assertFalse(self.sock.send.called) @@ -645,12 +642,12 @@ def test_write_partial(self): self.sock.send.return_value = 2 transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.event_loop.add_writer.called) + self.assertTrue(self.loop.add_writer.called) self.assertEqual( - transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + transport._write_ready, self.loop.add_writer.call_args[0][1]) self.assertEqual([b'ta'], transport._buffer) @@ -660,10 +657,10 @@ def test_write_partial_none(self): self.sock.fileno.return_value = 7 transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.write(data) - self.event_loop.add_writer.assert_called_with( + self.loop.add_writer.assert_called_with( 7, transport._write_ready) self.assertEqual([b'data'], transport._buffer) @@ -672,12 +669,12 @@ def test_write_tryagain(self): data = b'data' transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.event_loop.add_writer.called) + self.assertTrue(self.loop.add_writer.called) self.assertEqual( - transport._write_ready, self.event_loop.add_writer.call_args[0][1]) + transport._write_ready, self.loop.add_writer.call_args[0][1]) self.assertEqual([b'data'], transport._buffer) @@ -687,7 +684,7 @@ def test_write_exception(self, m_log): data = b'data' transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport.write(data) transport._fatal_error.assert_called_with(err) @@ -705,12 +702,12 @@ def test_write_exception(self, m_log): def test_write_str(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.assertRaises(AssertionError, transport.write, 'str') def test_write_closing(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.close() self.assertRaises(AssertionError, transport.write, b'data') @@ -719,30 +716,30 @@ def test_write_ready(self): self.sock.send.return_value = len(data) transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(data) transport._write_ready() self.assertTrue(self.sock.send.called) self.assertEqual(self.sock.send.call_args[0], (data,)) - self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.loop.remove_writer.called) def test_write_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._closing = True transport._buffer.append(data) transport._write_ready() self.sock.send.assert_called_with(data) - self.event_loop.remove_writer.assert_called_with(7) + self.loop.remove_writer.assert_called_with(7) self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) def test_write_ready_no_data(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.assertRaises(AssertionError, transport._write_ready) def test_write_ready_partial(self): @@ -750,10 +747,10 @@ def test_write_ready_partial(self): self.sock.send.return_value = 2 transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(data) transport._write_ready() - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'ta'], transport._buffer) def test_write_ready_partial_none(self): @@ -761,28 +758,28 @@ def test_write_ready_partial_none(self): self.sock.send.return_value = 0 transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(data) transport._write_ready() - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'data'], transport._buffer) def test_write_ready_tryagain(self): self.sock.send.side_effect = BlockingIOError transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer = [b'data1', b'data2'] transport._write_ready() - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'data1data2'], transport._buffer) def test_write_ready_exception(self): err = self.sock.send.side_effect = OSError() transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._buffer.append(b'data') transport._write_ready() @@ -791,42 +788,42 @@ def test_write_ready_exception(self): def test_close(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.close() self.assertTrue(transport._closing) - self.event_loop.remove_reader.assert_called_with(7) + self.loop.remove_reader.assert_called_with(7) self.protocol.connection_lost(None) def test_close_write_buffer(self): transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() + self.loop, self.sock, self.protocol) + self.loop.reset_mock() transport._buffer.append(b'data') transport.close() - self.assertTrue(self.event_loop.remove_reader.called) - self.assertFalse(self.event_loop.call_soon.called) + self.assertTrue(self.loop.remove_reader.called) + self.assertFalse(self.loop.call_soon.called) @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error(self, m_exc): exc = OSError() transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append(b'data') transport._fatal_error(exc) self.assertEqual([], transport._buffer) - self.event_loop.remove_reader.assert_called_with(7) - self.event_loop.remove_writer.assert_called_with(7) - self.event_loop.call_soon.assert_called_with( + self.loop.remove_reader.assert_called_with(7) + self.loop.remove_writer.assert_called_with(7) + self.loop.call_soon.assert_called_with( transport._call_connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', transport) def test_connection_lost(self): exc = object() transport = _SelectorSocketTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._call_connection_lost(exc) self.protocol.connection_lost.assert_called_with(exc) @@ -1095,14 +1092,14 @@ def test_on_ready_send_exc(self): class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) self.sock = unittest.mock.Mock(spec_set=socket.socket) self.sock.fileno.return_value = 7 self.protocol = unittest.mock.Mock(spec_set=DatagramProtocol) def test_read_ready(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) transport._read_ready() @@ -1112,7 +1109,7 @@ def test_read_ready(self): def test_read_ready_tryagain(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.sock.recvfrom.side_effect = BlockingIOError transport._fatal_error = unittest.mock.Mock() @@ -1122,7 +1119,7 @@ def test_read_ready_tryagain(self): def test_read_ready_err(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) err = self.sock.recvfrom.side_effect = OSError() transport._fatal_error = unittest.mock.Mock() @@ -1132,7 +1129,7 @@ def test_read_ready_err(self): def test_abort(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._close = unittest.mock.Mock() transport.abort() @@ -1141,7 +1138,7 @@ def test_abort(self): def test_sendto(self): data = b'data' transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.sendto(data, ('0.0.0.0', 1234)) self.assertTrue(self.sock.sendto.called) self.assertEqual( @@ -1149,7 +1146,7 @@ def test_sendto(self): def test_sendto_no_data(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append((b'data', ('0.0.0.0', 12345))) transport.sendto(b'', ()) self.assertFalse(self.sock.sendto.called) @@ -1158,7 +1155,7 @@ def test_sendto_no_data(self): def test_sendto_buffer(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append((b'data1', ('0.0.0.0', 12345))) transport.sendto(b'data2', ('0.0.0.0', 12345)) self.assertFalse(self.sock.sendto.called) @@ -1173,13 +1170,13 @@ def test_sendto_tryagain(self): self.sock.sendto.side_effect = BlockingIOError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.sendto(data, ('0.0.0.0', 12345)) - self.assertTrue(self.event_loop.add_writer.called) + self.assertTrue(self.loop.add_writer.called) self.assertEqual( transport._sendto_ready, - self.event_loop.add_writer.call_args[0][1]) + self.loop.add_writer.call_args[0][1]) self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) @@ -1190,7 +1187,7 @@ def test_sendto_exception(self, m_log): err = self.sock.sendto.side_effect = OSError() transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport.sendto(data, ()) @@ -1212,7 +1209,7 @@ def test_sendto_connection_refused(self): self.sock.sendto.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport.sendto(data, ()) @@ -1225,7 +1222,7 @@ def test_sendto_connection_refused_connected(self): self.sock.send.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) transport._fatal_error = unittest.mock.Mock() transport.sendto(data) @@ -1234,18 +1231,18 @@ def test_sendto_connection_refused_connected(self): def test_sendto_str(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) self.assertRaises(AssertionError, transport.sendto, 'str', ()) def test_sendto_connected_addr(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) self.assertRaises( AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) def test_sendto_closing(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.close() self.assertRaises(AssertionError, transport.sendto, b'data', ()) @@ -1254,44 +1251,44 @@ def test_sendto_ready(self): self.sock.sendto.return_value = len(data) transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append((data, ('0.0.0.0', 12345))) transport._sendto_ready() self.assertTrue(self.sock.sendto.called) self.assertEqual( self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) - self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.loop.remove_writer.called) def test_sendto_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._closing = True transport._buffer.append((data, ())) transport._sendto_ready() self.sock.sendto.assert_called_with(data, ()) - self.event_loop.remove_writer.assert_called_with(7) + self.loop.remove_writer.assert_called_with(7) self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) def test_sendto_ready_no_data(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._sendto_ready() self.assertFalse(self.sock.sendto.called) - self.assertTrue(self.event_loop.remove_writer.called) + self.assertTrue(self.loop.remove_writer.called) def test_sendto_ready_tryagain(self): self.sock.sendto.side_effect = BlockingIOError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.extend([(b'data1', ()), (b'data2', ())]) transport._sendto_ready() - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual( [(b'data1', ()), (b'data2', ())], list(transport._buffer)) @@ -1300,7 +1297,7 @@ def test_sendto_ready_exception(self): err = self.sock.sendto.side_effect = OSError() transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1311,7 +1308,7 @@ def test_sendto_ready_connection_refused(self): self.sock.sendto.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1322,7 +1319,7 @@ def test_sendto_ready_connection_refused_connection(self): self.sock.send.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) transport._fatal_error = unittest.mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1331,43 +1328,43 @@ def test_sendto_ready_connection_refused_connection(self): def test_close(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport.close() self.assertTrue(transport._closing) - self.event_loop.remove_reader.assert_called_with(7) - self.event_loop.call_soon.assert_called_with( + self.loop.remove_reader.assert_called_with(7) + self.loop.call_soon.assert_called_with( transport._call_connection_lost, None) def test_close_write_buffer(self): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._buffer.append((b'data', ())) transport.close() - self.event_loop.remove_reader.assert_called_with(7) + self.loop.remove_reader.assert_called_with(7) self.assertFalse(self.protocol.connection_lost.called) @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error(self, m_exc): exc = OSError() transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) - self.event_loop.reset_mock() + self.loop, self.sock, self.protocol) + self.loop.reset_mock() transport._buffer.append((b'data', ())) transport._fatal_error(exc) self.assertEqual([], list(transport._buffer)) - self.event_loop.remove_writer.assert_called_with(7) - self.event_loop.remove_reader.assert_called_with(7) - self.event_loop.call_soon.assert_called_with( + self.loop.remove_writer.assert_called_with(7) + self.loop.remove_reader.assert_called_with(7) + self.loop.call_soon.assert_called_with( transport._call_connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', transport) @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol, ('0.0.0.0', 1)) + self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) err = ConnectionRefusedError() transport._fatal_error(err) self.protocol.connection_refused.assert_called_with(err) @@ -1376,7 +1373,7 @@ def test_fatal_error_connected(self, m_exc): def test_transport_closing(self): exc = object() transport = _SelectorDatagramTransport( - self.event_loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol) transport._call_connection_lost(exc) self.protocol.connection_lost.assert_called_with(exc) diff --git a/tests/streams_test.py b/tests/streams_test.py index 832c3119..ab148bdd 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -12,11 +12,11 @@ class StreamReaderTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() def test_feed_empty_data(self): stream = streams.StreamReader() @@ -35,7 +35,7 @@ def test_read_zero(self): stream = streams.StreamReader() stream.feed_data(self.DATA) - data = self.event_loop.run_until_complete(stream.read(0)) + data = self.loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) @@ -46,9 +46,9 @@ def test_read(self): def cb(): stream.feed_data(self.DATA) - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - data = self.event_loop.run_until_complete(read_task) + data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) @@ -58,7 +58,7 @@ def test_read_line_breaks(self): stream.feed_data(b'line1') stream.feed_data(b'line2') - data = self.event_loop.run_until_complete(stream.read(5)) + data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) self.assertEqual(5, stream.byte_count) @@ -70,9 +70,9 @@ def test_read_eof(self): def cb(): stream.feed_eof() - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - data = self.event_loop.run_until_complete(read_task) + data = self.loop.run_until_complete(read_task) self.assertEqual(b'', data) self.assertFalse(stream.byte_count) @@ -85,9 +85,9 @@ def cb(): stream.feed_data(b'chunk1\n') stream.feed_data(b'chunk2') stream.feed_eof() - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - data = self.event_loop.run_until_complete(read_task) + data = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) self.assertFalse(stream.byte_count) @@ -96,13 +96,12 @@ def test_read_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete(stream.read(2)) + data = self.loop.run_until_complete(stream.read(2)) self.assertEqual(b'li', data) stream.set_exception(ValueError()) self.assertRaises( - ValueError, - self.event_loop.run_until_complete, stream.read(2)) + ValueError, self.loop.run_until_complete, stream.read(2)) def test_readline(self): # Read one line. @@ -114,9 +113,9 @@ def cb(): stream.feed_data(b'chunk2 ') stream.feed_data(b'chunk3 ') stream.feed_data(b'\n chunk4') - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - line = self.event_loop.run_until_complete(read_task) + line = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) @@ -126,7 +125,7 @@ def test_readline_limit_with_existing_data(self): stream.feed_data(b'ne1\nline2\n') self.assertRaises( - ValueError, self.event_loop.run_until_complete, stream.readline()) + ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual([b'line2\n'], list(stream.buffer)) stream = streams.StreamReader(3) @@ -135,7 +134,7 @@ def test_readline_limit_with_existing_data(self): stream.feed_data(b'li') self.assertRaises( - ValueError, self.event_loop.run_until_complete, stream.readline()) + ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual([b'li'], list(stream.buffer)) self.assertEqual(2, stream.byte_count) @@ -147,10 +146,10 @@ def cb(): stream.feed_data(b'chunk2') stream.feed_data(b'chunk3\n') stream.feed_eof() - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) self.assertRaises( - ValueError, self.event_loop.run_until_complete, stream.readline()) + ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual([b'chunk3\n'], list(stream.buffer)) self.assertEqual(7, stream.byte_count) @@ -159,7 +158,7 @@ def test_readline_line_byte_count(self): stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) - line = self.event_loop.run_until_complete(stream.readline()) + line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) @@ -169,23 +168,23 @@ def test_readline_eof(self): stream.feed_data(b'some data') stream.feed_eof() - line = self.event_loop.run_until_complete(stream.readline()) + line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'some data', line) def test_readline_empty_eof(self): stream = streams.StreamReader() stream.feed_eof() - line = self.event_loop.run_until_complete(stream.readline()) + line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): stream = streams.StreamReader() stream.feed_data(self.DATA) - self.event_loop.run_until_complete(stream.readline()) + self.loop.run_until_complete(stream.readline()) - data = self.event_loop.run_until_complete(stream.read(7)) + data = self.loop.run_until_complete(stream.read(7)) self.assertEqual(b'line2\nl', data) self.assertEqual( @@ -196,24 +195,23 @@ def test_readline_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete(stream.readline()) + data = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line\n', data) stream.set_exception(ValueError()) self.assertRaises( - ValueError, - self.event_loop.run_until_complete, stream.readline()) + ValueError, self.loop.run_until_complete, stream.readline()) def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). stream = streams.StreamReader() stream.feed_data(self.DATA) - data = self.event_loop.run_until_complete(stream.readexactly(0)) + data = self.loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) - data = self.event_loop.run_until_complete(stream.readexactly(-1)) + data = self.loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) self.assertEqual(len(self.DATA), stream.byte_count) @@ -228,9 +226,9 @@ def cb(): stream.feed_data(self.DATA) stream.feed_data(self.DATA) stream.feed_data(self.DATA) - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - data = self.event_loop.run_until_complete(read_task) + data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) self.assertEqual(len(self.DATA), stream.byte_count) @@ -243,9 +241,9 @@ def test_readexactly_eof(self): def cb(): stream.feed_data(self.DATA) stream.feed_eof() - self.event_loop.call_soon(cb) + self.loop.call_soon(cb) - data = self.event_loop.run_until_complete(read_task) + data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) self.assertFalse(stream.byte_count) @@ -253,14 +251,12 @@ def test_readexactly_exception(self): stream = streams.StreamReader() stream.feed_data(b'line\n') - data = self.event_loop.run_until_complete(stream.readexactly(2)) + data = self.loop.run_until_complete(stream.readexactly(2)) self.assertEqual(b'li', data) stream.set_exception(ValueError()) self.assertRaises( - ValueError, - self.event_loop.run_until_complete, - stream.readexactly(2)) + ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): stream = streams.StreamReader() @@ -284,7 +280,7 @@ def readline(): t1 = tasks.Task(stream.readline()) t2 = tasks.Task(set_err()) - self.event_loop.run_until_complete(tasks.wait([t1, t2])) + self.loop.run_until_complete(tasks.wait([t1, t2])) self.assertRaises(ValueError, t1.result) diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py index 28ab6623..01d72f7c 100644 --- a/tests/subprocess_test.py +++ b/tests/subprocess_test.py @@ -45,16 +45,16 @@ def connection_lost(self, exc): class FutureTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() def test_unix_subprocess(self): p = MyProto() subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR']) - self.event_loop.run_until_complete(p.done) + self.loop.run_until_complete(p.done) if __name__ == '__main__': diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 583674e8..d3def0c9 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -22,25 +22,25 @@ def __call__(self, *args): class TaskTests(unittest.TestCase): def setUp(self): - self.event_loop = events.new_event_loop() - events.set_event_loop(self.event_loop) + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) def tearDown(self): - self.event_loop.close() + self.loop.close() def test_task_class(self): @tasks.coroutine def notmuch(): return 'ok' t = tasks.Task(notmuch()) - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') - self.assertIs(t._event_loop, self.event_loop) + self.assertIs(t._loop, self.loop) - event_loop = events.new_event_loop() - t = tasks.Task(notmuch(), event_loop=event_loop) - self.assertIs(t._event_loop, event_loop) + loop = events.new_event_loop() + t = tasks.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) def test_task_decorator(self): @tasks.task @@ -48,7 +48,7 @@ def notmuch(): yield from [] return 'ko' t = notmuch() - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -57,7 +57,7 @@ def test_task_decorator_func(self): def notmuch(): return 'ko' t = notmuch() - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -69,7 +69,7 @@ def test_task_decorator_fut(self): def notmuch(): return fut t = notmuch() - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') @@ -85,10 +85,10 @@ def notmuch(): t.cancel() # Does not take immediate effect! self.assertEqual(repr(t), 'Task()') self.assertRaises(futures.CancelledError, - self.event_loop.run_until_complete, t) + self.loop.run_until_complete, t) self.assertEqual(repr(t), 'Task()') t = notmuch() - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") def test_task_repr_custom(self): @@ -123,7 +123,7 @@ def inner2(): return 1000 t = outer() - self.assertEqual(self.event_loop.run_until_complete(t), 1042) + self.assertEqual(self.loop.run_until_complete(t), 1042) def test_cancel(self): @tasks.task @@ -132,10 +132,9 @@ def task(): return 12 t = task() - self.event_loop.call_soon(t.cancel) + self.loop.call_soon(t.cancel) self.assertRaises( - futures.CancelledError, - self.event_loop.run_until_complete, t) + futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) self.assertFalse(t.cancel()) @@ -147,11 +146,10 @@ def task(): return 12 t = task() - self.event_loop.run_once() # start coro + self.loop.run_once() # start coro t.cancel() self.assertRaises( - futures.CancelledError, - self.event_loop.run_until_complete, t) + futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) self.assertFalse(t.cancel()) @@ -170,14 +168,14 @@ def task(): yield from fut3 t = task() - self.event_loop.run_once() + self.loop.run_once() fut1.set_result(None) t.cancel() - self.event_loop.run_once() # process fut1 result, delay cancel + self.loop.run_once() # process fut1 result, delay cancel self.assertFalse(t.done()) - self.event_loop.run_once() # cancel fut2, but coro still alive + self.loop.run_once() # cancel fut2, but coro still alive self.assertFalse(t.done()) - self.event_loop.run_once() # cancel fut3 + self.loop.run_once() # cancel fut3 self.assertTrue(t.done()) self.assertEqual(fut1.result(), None) @@ -195,7 +193,7 @@ def coro(): self.assertRaises( futures.CancelledError, - self.event_loop.run_until_complete, t) + self.loop.run_until_complete, t) self.assertTrue(t.done()) self.assertFalse(t.cancel()) @@ -216,7 +214,7 @@ def coro2(): raise Cancelled() self.assertRaises( - Cancelled, self.event_loop.run_until_complete, coro2()) + Cancelled, self.loop.run_until_complete, coro2()) def test_cancel_in_coro(self): @tasks.coroutine @@ -226,8 +224,7 @@ def task(): t = tasks.Task(task()) self.assertRaises( - futures.CancelledError, - self.event_loop.run_until_complete, t) + futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) self.assertFalse(t.cancel()) @@ -241,13 +238,12 @@ def task(): yield from tasks.sleep(0.1) x += 1 if x == 2: - self.event_loop.stop() + self.loop.stop() t = tasks.Task(task()) t0 = time.monotonic() self.assertRaises( - futures.InvalidStateError, - self.event_loop.run_until_complete, t) + futures.InvalidStateError, self.loop.run_until_complete, t) t1 = time.monotonic() self.assertFalse(t.done()) self.assertTrue(0.18 <= t1-t0 <= 0.22) @@ -262,8 +258,7 @@ def task(): t = task() t0 = time.monotonic() self.assertRaises( - futures.TimeoutError, - self.event_loop.run_until_complete, t, 0.1) + futures.TimeoutError, self.loop.run_until_complete, t, 0.1) t1 = time.monotonic() self.assertFalse(t.done()) self.assertTrue(0.08 <= t1-t0 <= 0.12) @@ -276,7 +271,7 @@ def task(): t = task() t0 = time.monotonic() - r = self.event_loop.run_until_complete(t, 10.0) + r = self.loop.run_until_complete(t, 10.0) t1 = time.monotonic() self.assertTrue(t.done()) self.assertEqual(r, 42) @@ -294,13 +289,13 @@ def foo(): return 42 t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) self.assertEqual(res, 42) # Doing it again should take no time and exercise a different path. t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. @@ -311,7 +306,7 @@ def test_wait_first_completed(self): task = tasks.Task(tasks.wait( [b, a], return_when=tasks.FIRST_COMPLETED)) - done, pending = self.event_loop.run_until_complete(task) + done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) self.assertFalse(a.done()) @@ -336,7 +331,7 @@ def coro2(): task = tasks.Task( tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) - done, pending = self.event_loop.run_until_complete(task) + done, pending = self.loop.run_until_complete(task) self.assertEqual({a, b}, done) self.assertTrue(a.done()) self.assertIsNone(a.result()) @@ -355,7 +350,7 @@ def exc(): task = tasks.Task(tasks.wait( [b, a], return_when=tasks.FIRST_EXCEPTION)) - done, pending = self.event_loop.run_until_complete(task) + done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) @@ -371,7 +366,7 @@ def exc(): b = tasks.Task(exc()) task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION) - done, pending = self.event_loop.run_until_complete(task) + done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) @@ -394,11 +389,11 @@ def foo(): self.assertEqual(len(errors), 1) t0 = time.monotonic() - self.event_loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) t0 = time.monotonic() - self.event_loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) @@ -413,7 +408,7 @@ def foo(): self.assertEqual(pending, set([b])) t0 = time.monotonic() - self.event_loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) @@ -436,7 +431,7 @@ def foo(): return values t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) self.assertTrue('a' in res[:2]) @@ -444,7 +439,7 @@ def foo(): self.assertEqual(res[2], 'c') # Doing it again should take no time and exercise a different path. t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) @@ -464,7 +459,7 @@ def foo(): return values t0 = time.monotonic() - res = self.event_loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo())) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.11) self.assertEqual(len(res), 2, res) @@ -481,7 +476,7 @@ def sleeper(dt, arg): t = tasks.Task(sleeper(0.1, 'yeah')) t0 = time.monotonic() - self.event_loop.run_until_complete(t) + self.loop.run_until_complete(t) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.09) self.assertTrue(t.done()) @@ -491,20 +486,20 @@ def test_sleep_cancel(self): t = tasks.Task(tasks.sleep(10.0, 'yeah')) handle = None - orig_call_later = self.event_loop.call_later + orig_call_later = self.loop.call_later def call_later(self, delay, callback, *args): nonlocal handle handle = orig_call_later(self, delay, callback, *args) return handle - self.event_loop.call_later = call_later - self.event_loop.run_once() + self.loop.call_later = call_later + self.loop.run_once() self.assertFalse(handle.cancelled) t.cancel() - self.event_loop.run_once() + self.loop.run_once() self.assertTrue(handle.cancelled) def test_task_cancel_sleeping_task(self): @@ -523,7 +518,7 @@ def sleep(dt): @tasks.task def doit(): sleeper = sleep(5000) - self.event_loop.call_later(0.1, sleeper.cancel) + self.loop.call_later(0.1, sleeper.cancel) try: time.monotonic() yield from sleeper @@ -535,7 +530,7 @@ def doit(): t0 = time.monotonic() doer = doit() - self.assertEqual(self.event_loop.run_until_complete(doer), 'cancelled') + self.assertEqual(self.loop.run_until_complete(doer), 'cancelled') t1 = time.monotonic() self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) @@ -550,12 +545,12 @@ def coro(): pass task = coro() - self.event_loop.run_once() + self.loop.run_once() self.assertIs(task._fut_waiter, fut) task.cancel() self.assertRaises( - futures.CancelledError, self.event_loop.run_until_complete, task) + futures.CancelledError, self.loop.run_until_complete, task) self.assertIsNone(task._fut_waiter) self.assertTrue(fut.cancelled()) @@ -577,7 +572,7 @@ def notmuch(): return 'ko' self.assertRaises( - RuntimeError, self.event_loop.run_until_complete, notmuch()) + RuntimeError, self.loop.run_until_complete, notmuch()) def test_step_result_future(self): # If coroutine returns future, task waits on this future. @@ -600,12 +595,12 @@ def wait_for_future(): result = yield from fut t = wait_for_future() - self.event_loop.run_once() + self.loop.run_once() self.assertTrue(fut.cb_added) res = object() fut.set_result(res) - self.event_loop.run_once() + self.loop.run_once() self.assertIs(res, result) self.assertTrue(t.done()) self.assertIsNone(t.result()) @@ -629,12 +624,12 @@ def notmuch(): return (yield c_fut) task = tasks.Task(notmuch()) - self.event_loop.run_once() + self.loop.run_once() self.assertTrue(c_fut.cb_added) res = object() c_fut.set_result(res) - self.event_loop.run_once() + self.loop.run_once() self.assertIs(res, task.result()) def test_step_with_baseexception(self): @@ -661,12 +656,12 @@ def notmutch(): raise BaseException() task = tasks.Task(notmutch()) - self.event_loop.run_once() + self.loop.run_once() task.cancel() self.assertFalse(task.done()) - self.assertRaises(BaseException, self.event_loop.run_once) + self.assertRaises(BaseException, self.loop.run_once) self.assertTrue(task.done()) self.assertTrue(task.cancelled()) @@ -695,7 +690,7 @@ def wait_for_future(): task = wait_for_future() with self.assertRaises(RuntimeError) as cm: - self.event_loop.run_until_complete(task) + self.loop.run_until_complete(task) self.assertTrue(fut.done()) self.assertIs(fut.exception(), cm.exception) @@ -712,7 +707,7 @@ def wait_for_future(): task = wait_for_future() self.assertRaises( RuntimeError, - self.event_loop.run_until_complete, task) + self.loop.run_until_complete, task) def test_coroutine_non_gen_function(self): @tasks.coroutine @@ -724,7 +719,7 @@ def func(): coro = func() self.assertTrue(tasks.iscoroutine(coro)) - res = self.event_loop.run_until_complete(coro) + res = self.loop.run_until_complete(coro) self.assertEqual(res, 'test') def test_coroutine_non_gen_function_return_future(self): @@ -740,7 +735,7 @@ def coro(): t1 = tasks.Task(func()) t2 = tasks.Task(coro()) - res = self.event_loop.run_until_complete(t1) + res = self.loop.run_until_complete(t1) self.assertEqual(res, 'test') self.assertIsNone(t2.result()) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 536ac0a7..fd5115d3 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -20,13 +20,13 @@ class SelectorEventLoopTests(unittest.TestCase): def setUp(self): - self.event_loop = unix_events.SelectorEventLoop() + self.loop = unix_events.SelectorEventLoop() def test_check_signal(self): self.assertRaises( - TypeError, self.event_loop._check_signal, '1') + TypeError, self.loop._check_signal, '1') self.assertRaises( - ValueError, self.event_loop._check_signal, signal.NSIG + 1) + ValueError, self.loop._check_signal, signal.NSIG + 1) unix_events.signal = None @@ -35,10 +35,10 @@ def restore_signal(): self.addCleanup(restore_signal) self.assertRaises( - RuntimeError, self.event_loop._check_signal, signal.SIGINT) + RuntimeError, self.loop._check_signal, signal.SIGINT) def test_handle_signal_no_handler(self): - self.event_loop._handle_signal(signal.NSIG + 1, ()) + self.loop._handle_signal(signal.NSIG + 1, ()) @unittest.mock.patch('tulip.unix_events.signal') def test_add_signal_handler_setup_error(self, m_signal): @@ -47,7 +47,7 @@ def test_add_signal_handler_setup_error(self, m_signal): self.assertRaises( RuntimeError, - self.event_loop.add_signal_handler, + self.loop.add_signal_handler, signal.SIGINT, lambda: True) @unittest.mock.patch('tulip.unix_events.signal') @@ -55,8 +55,8 @@ def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG cb = lambda: True - self.event_loop.add_signal_handler(signal.SIGHUP, cb) - h = self.event_loop._signal_handlers.get(signal.SIGHUP) + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) self.assertTrue(isinstance(h, events.Handle)) self.assertEqual(h.callback, cb) @@ -75,7 +75,7 @@ class Err(OSError): self.assertRaises( Err, - self.event_loop.add_signal_handler, + self.loop.add_signal_handler, signal.SIGINT, lambda: True) @unittest.mock.patch('tulip.unix_events.signal') @@ -87,10 +87,10 @@ class Err(OSError): errno = errno.EINVAL m_signal.signal.side_effect = Err - self.event_loop._signal_handlers[signal.SIGHUP] = lambda: True + self.loop._signal_handlers[signal.SIGHUP] = lambda: True self.assertRaises( RuntimeError, - self.event_loop.add_signal_handler, + self.loop.add_signal_handler, signal.SIGINT, lambda: True) self.assertFalse(m_logging.info.called) self.assertEqual(1, m_signal.set_wakeup_fd.call_count) @@ -105,7 +105,7 @@ class Err(OSError): self.assertRaises( RuntimeError, - self.event_loop.add_signal_handler, + self.loop.add_signal_handler, signal.SIGINT, lambda: True) self.assertFalse(m_logging.info.called) self.assertEqual(2, m_signal.set_wakeup_fd.call_count) @@ -114,10 +114,10 @@ class Err(OSError): def test_remove_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG - self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) self.assertTrue( - self.event_loop.remove_signal_handler(signal.SIGHUP)) + self.loop.remove_signal_handler(signal.SIGHUP)) self.assertTrue(m_signal.set_wakeup_fd.called) self.assertTrue(m_signal.signal.called) self.assertEqual( @@ -128,12 +128,12 @@ def test_remove_signal_handler_2(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.SIGINT = signal.SIGINT - self.event_loop.add_signal_handler(signal.SIGINT, lambda: True) - self.event_loop._signal_handlers[signal.SIGHUP] = object() + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() m_signal.set_wakeup_fd.reset_mock() self.assertTrue( - self.event_loop.remove_signal_handler(signal.SIGINT)) + self.loop.remove_signal_handler(signal.SIGINT)) self.assertFalse(m_signal.set_wakeup_fd.called) self.assertTrue(m_signal.signal.called) self.assertEqual( @@ -144,65 +144,65 @@ def test_remove_signal_handler_2(self, m_signal): @unittest.mock.patch('tulip.unix_events.tulip_log') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG - self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) m_signal.set_wakeup_fd.side_effect = ValueError - self.event_loop.remove_signal_handler(signal.SIGHUP) + self.loop.remove_signal_handler(signal.SIGHUP) self.assertTrue(m_logging.info) @unittest.mock.patch('tulip.unix_events.signal') def test_remove_signal_handler_error(self, m_signal): m_signal.NSIG = signal.NSIG - self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) m_signal.signal.side_effect = OSError self.assertRaises( - OSError, self.event_loop.remove_signal_handler, signal.SIGHUP) + OSError, self.loop.remove_signal_handler, signal.SIGHUP) @unittest.mock.patch('tulip.unix_events.signal') def test_remove_signal_handler_error2(self, m_signal): m_signal.NSIG = signal.NSIG - self.event_loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) class Err(OSError): errno = errno.EINVAL m_signal.signal.side_effect = Err self.assertRaises( - RuntimeError, self.event_loop.remove_signal_handler, signal.SIGHUP) + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) @unittest.mock.patch('fcntl.fcntl') def test_ctor(self, m_fcntl): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) - self.event_loop.add_reader.assert_called_with(5, tr._read_ready) - self.event_loop.call_soon.assert_called_with( + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_reader.assert_called_with(5, tr._read_ready) + self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) @unittest.mock.patch('fcntl.fcntl') def test_ctor_with_waiter(self, m_fcntl): fut = futures.Future() - unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol, fut) - self.event_loop.call_soon.assert_called_with(fut.set_result, None) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.call_soon.assert_called_with(fut.set_result, None) fut.cancel() @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test__read_ready(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) m_read.return_value = b'data' tr._read_ready() @@ -212,21 +212,21 @@ def test__read_ready(self, m_fcntl, m_read): @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test__read_ready_eof(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) m_read.return_value = b'' tr._read_ready() m_read.assert_called_with(5, tr.max_size) - self.event_loop.remove_reader.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) self.protocol.eof_received.assert_called_with() @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test__read_ready_blocked(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) - self.event_loop.reset_mock() + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.reset_mock() m_read.side_effect = BlockingIOError tr._read_ready() @@ -237,8 +237,8 @@ def test__read_ready_blocked(self, m_fcntl, m_read): @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test__read_ready_error(self, m_fcntl, m_read, m_logexc): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) err = OSError() m_read.side_effect = err tr._close = unittest.mock.Mock() @@ -251,26 +251,26 @@ def test__read_ready_error(self, m_fcntl, m_read, m_logexc): @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test_pause(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) tr.pause() - self.event_loop.remove_reader.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test_resume(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) tr.resume() - self.event_loop.add_reader.assert_called_with(5, tr._read_ready) + self.loop.add_reader.assert_called_with(5, tr._read_ready) @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test_close(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) tr._close = unittest.mock.Mock() tr.close() @@ -279,8 +279,8 @@ def test_close(self, m_fcntl, m_read): @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test_close_already_closing(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) tr._closing = True tr._close = unittest.mock.Mock() @@ -290,20 +290,19 @@ def test_close_already_closing(self, m_fcntl, m_read): @unittest.mock.patch('os.read') @unittest.mock.patch('fcntl.fcntl') def test__close(self, m_fcntl, m_read): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) err = object() tr._close(err) self.assertTrue(tr._closing) - self.event_loop.remove_reader.assert_called_with(5) - self.event_loop.call_soon.assert_called_with( - tr._call_connection_lost, err) + self.loop.remove_reader.assert_called_with(5) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, err) @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost(self, m_fcntl): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) err = None tr._call_connection_lost(err) @@ -312,8 +311,8 @@ def test__call_connection_lost(self, m_fcntl): @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost_with_err(self, m_fcntl): - tr = unix_events._UnixReadPipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) err = OSError() tr._call_connection_lost(err) @@ -324,104 +323,104 @@ def test__call_connection_lost_with_err(self, m_fcntl): class UnixWritePipeTransportTests(unittest.TestCase): def setUp(self): - self.event_loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) @unittest.mock.patch('fcntl.fcntl') def test_ctor(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) - self.event_loop.call_soon.assert_called_with( + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) @unittest.mock.patch('fcntl.fcntl') def test_ctor_with_waiter(self, m_fcntl): fut = futures.Future() - unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol, fut) - self.event_loop.call_soon.assert_called_with(fut.set_result, None) + unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.call_soon.assert_called_with(fut.set_result, None) fut.cancel() @unittest.mock.patch('fcntl.fcntl') def test_can_write_eof(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) self.assertTrue(tr.can_write_eof()) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) m_write.return_value = 4 tr.write(b'data') m_write.assert_called_with(5, b'data') - self.assertFalse(self.event_loop.add_writer.called) + self.assertFalse(self.loop.add_writer.called) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write_no_data(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr.write(b'') self.assertFalse(m_write.called) - self.assertFalse(self.event_loop.add_writer.called) + self.assertFalse(self.loop.add_writer.called) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write_partial(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) m_write.return_value = 2 tr.write(b'data') m_write.assert_called_with(5, b'data') - self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.loop.add_writer.assert_called_with(5, tr._write_ready) self.assertEqual([b'ta'], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write_buffer(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'previous'] tr.write(b'data') self.assertFalse(m_write.called) - self.assertFalse(self.event_loop.add_writer.called) + self.assertFalse(self.loop.add_writer.called) self.assertEqual([b'previous', b'data'], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write_again(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) m_write.side_effect = BlockingIOError() tr.write(b'data') m_write.assert_called_with(5, b'data') - self.event_loop.add_writer.assert_called_with(5, tr._write_ready) + self.loop.add_writer.assert_called_with(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('tulip.unix_events.tulip_log') @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_write_err(self, m_fcntl, m_write, m_log): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) err = OSError() m_write.side_effect = err tr._fatal_error = unittest.mock.Mock() tr.write(b'data') m_write.assert_called_with(5, b'data') - self.assertFalse(self.event_loop.called) + self.assertFalse(self.loop.called) self.assertEqual([], tr._buffer) tr._fatal_error.assert_called_with(err) self.assertEqual(1, tr._conn_lost) @@ -438,69 +437,69 @@ def test_write_err(self, m_fcntl, m_write, m_log): @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] m_write.return_value = 4 tr._write_ready() m_write.assert_called_with(5, b'data') - self.event_loop.remove_writer.assert_called_with(5) + self.loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready_partial(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] m_write.return_value = 3 tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'a'], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready_again(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] m_write.side_effect = BlockingIOError() tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready_empty(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] m_write.return_value = 0 tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.event_loop.remove_writer.called) + self.assertFalse(self.loop.remove_writer.called) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready_err(self, m_fcntl, m_write, m_logexc): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] m_write.side_effect = err = OSError() tr._write_ready() m_write.assert_called_with(5, b'data') - self.event_loop.remove_writer.assert_called_with(5) + self.loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with( + self.loop.call_soon.assert_called_with( tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) self.assertEqual(1, tr._conn_lost) @@ -508,15 +507,15 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test__write_ready_closing(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._closing = True tr._buffer = [b'da', b'ta'] m_write.return_value = 4 tr._write_ready() m_write.assert_called_with(5, b'data') - self.event_loop.remove_writer.assert_called_with(5) + self.loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() @@ -524,22 +523,22 @@ def test__write_ready_closing(self, m_fcntl, m_write): @unittest.mock.patch('os.write') @unittest.mock.patch('fcntl.fcntl') def test_abort(self, m_fcntl, m_write): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] tr.abort() self.assertFalse(m_write.called) - self.event_loop.remove_writer.assert_called_with(5) + self.loop.remove_writer.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with( + self.loop.call_soon.assert_called_with( tr._call_connection_lost, None) @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) err = None tr._call_connection_lost(err) @@ -548,8 +547,8 @@ def test__call_connection_lost(self, m_fcntl): @unittest.mock.patch('fcntl.fcntl') def test__call_connection_lost_with_err(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) err = OSError() tr._call_connection_lost(err) @@ -558,8 +557,8 @@ def test__call_connection_lost_with_err(self, m_fcntl): @unittest.mock.patch('fcntl.fcntl') def test_close(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr.write_eof = unittest.mock.Mock() tr.close() @@ -567,8 +566,8 @@ def test_close(self, m_fcntl): @unittest.mock.patch('fcntl.fcntl') def test_close_closing(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr.write_eof = unittest.mock.Mock() tr._closing = True @@ -577,18 +576,18 @@ def test_close_closing(self, m_fcntl): @unittest.mock.patch('fcntl.fcntl') def test_write_eof(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr.write_eof() self.assertTrue(tr._closing) - self.event_loop.call_soon.assert_called_with( + self.loop.call_soon.assert_called_with( tr._call_connection_lost, None) @unittest.mock.patch('fcntl.fcntl') def test_write_eof_pending(self, m_fcntl): - tr = unix_events._UnixWritePipeTransport(self.event_loop, self.pipe, - self.protocol) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) tr._buffer = [b'data'] tr.write_eof() self.assertTrue(tr._closing) diff --git a/tulip/base_events.py b/tulip/base_events.py index bf00dc23..9d71d96f 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -288,9 +288,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.bind(laddr) break except socket.error as exc: - exc = socket.error(exc.errno, "error while " \ - "attempting to bind on address " \ - "%r: %s" % (laddr, exc.strerror.lower())) + exc = socket.error( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) exceptions.append(exc) else: continue @@ -502,7 +504,7 @@ def wrap_future(self, future): """XXX""" if isinstance(future, futures.Future): return future # Don't wrap our own type of Future. - new_future = futures.Future(event_loop=self) + new_future = futures.Future(loop=self) future.add_done_callback( lambda future: self.call_soon_threadsafe(new_future._copy_state, future)) @@ -562,12 +564,3 @@ def _run_once(self, timeout=None): handle = self._ready.popleft() if not handle.cancelled: handle.run() - - # Future.__del__ uses log level - _log_level = logging.WARNING - - def set_log_level(self, val): - self._log_level = val - - def get_log_level(self): - return self._log_level diff --git a/tulip/events.py b/tulip/events.py index 72a660f0..63861262 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -312,7 +312,7 @@ def get_event_loop(self): """XXX""" raise NotImplementedError - def set_event_loop(self, event_loop): + def set_event_loop(self, loop): """XXX""" raise NotImplementedError @@ -334,23 +334,23 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): associated). """ - _event_loop = None + _loop = None def get_event_loop(self): """Get the event loop. This may be None or an instance of EventLoop. """ - if (self._event_loop is None and + if (self._loop is None and threading.current_thread().name == 'MainThread'): - self._event_loop = self.new_event_loop() - return self._event_loop + self._loop = self.new_event_loop() + return self._loop - def set_event_loop(self, event_loop): + def set_event_loop(self, loop): """Set the event loop.""" # TODO: The isinstance() test violates the PEP. - assert event_loop is None or isinstance(event_loop, AbstractEventLoop) - self._event_loop = event_loop + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop def new_event_loop(self): """Create a new event loop. @@ -394,9 +394,9 @@ def get_event_loop(): return get_event_loop_policy().get_event_loop() -def set_event_loop(event_loop): +def set_event_loop(loop): """XXX""" - get_event_loop_policy().set_event_loop(event_loop) + get_event_loop_policy().set_event_loop(loop) def new_event_loop(): diff --git a/tulip/futures.py b/tulip/futures.py index 7f410d31..5161dafe 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -6,7 +6,6 @@ ] import concurrent.futures._base -import io import logging import traceback @@ -118,28 +117,27 @@ class Future: _result = None _exception = None _timeout_handle = None - _event_loop = None + _loop = None _blocking = False # proper use of future (yield vs yield from) _tb_logger = None - def __init__(self, *, event_loop=None, timeout=None): + def __init__(self, *, loop=None, timeout=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event loop object used by the future. If it's not provided, the future uses the default event loop. """ - if event_loop is None: - self._event_loop = events.get_event_loop() + if loop is None: + self._loop = events.get_event_loop() else: - self._event_loop = event_loop + self._loop = loop self._callbacks = [] if timeout is not None: - self._timeout_handle = self._event_loop.call_later( - timeout, self.cancel) + self._timeout_handle = self._loop.call_later(timeout, self.cancel) def __repr__(self): res = self.__class__.__name__ @@ -190,7 +188,7 @@ def _schedule_callbacks(self): self._callbacks[:] = [] for callback in callbacks: - self._event_loop.call_soon(callback, self) + self._loop.call_soon(callback, self) def cancelled(self): """Return True if the future was cancelled.""" @@ -260,7 +258,7 @@ def add_done_callback(self, fn): scheduled with call_soon. """ if self._state != _PENDING: - self._event_loop.call_soon(fn, self) + self._loop.call_soon(fn, self) else: self._callbacks.append(fn) diff --git a/tulip/locks.py b/tulip/locks.py index d425d064..dfe9905d 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -97,7 +97,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._loop, timeout=timeout) + fut = futures.Future(loop=self._loop, timeout=timeout) self._waiters.append(fut) try: @@ -208,7 +208,7 @@ def wait(self, timeout=None): if self._value: return True - fut = futures.Future(event_loop=self._loop, timeout=timeout) + fut = futures.Future(loop=self._loop, timeout=timeout) self._waiters.append(fut) try: @@ -259,7 +259,7 @@ def wait(self, timeout=None): self.release() - fut = futures.Future(event_loop=self._loop, timeout=timeout) + fut = futures.Future(loop=self._loop, timeout=timeout) self._condition_waiters.append(fut) try: @@ -395,7 +395,7 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(event_loop=self._loop, timeout=timeout) + fut = futures.Future(loop=self._loop, timeout=timeout) self._waiters.append(fut) try: diff --git a/tulip/queues.py b/tulip/queues.py index ba0b626d..8bb35066 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -121,7 +121,7 @@ def put(self, item, timeout=None): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - waiter = futures.Future(event_loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop, timeout=timeout) self._putters.append((item, waiter)) try: @@ -181,7 +181,7 @@ def get(self, timeout=None): elif self.qsize(): return self._get() else: - waiter = futures.Future(event_loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop, timeout=timeout) self._getters.append(waiter) try: diff --git a/tulip/tasks.py b/tulip/tasks.py index ce539ab8..6f9e6efd 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -71,13 +71,13 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro, event_loop=None, timeout=None): + def __init__(self, coro, *, loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__(event_loop=event_loop, timeout=timeout) + super().__init__(loop=loop, timeout=timeout) self._coro = coro self._fut_waiter = None self._must_cancel = False - self._event_loop.call_soon(self._step) + self._loop.call_soon(self._step) def __repr__(self): res = super().__repr__() @@ -99,7 +99,7 @@ def cancel(self): if self._fut_waiter is not None: return self._fut_waiter.cancel() else: - self._event_loop.call_soon(self._step_maybe) + self._loop.call_soon(self._step_maybe) return True def cancelled(self): @@ -166,11 +166,11 @@ def _step(self, value=_marker, exc=None): # because we don't create an extra Future. result.add_done_callback( lambda future: - self._event_loop.call_soon_threadsafe( + self._loop.call_soon_threadsafe( self._wakeup, future)) else: if inspect.isgenerator(result): - self._event_loop.call_soon( + self._loop.call_soon( self._step, None, RuntimeError( 'yield was used instead of yield from for ' @@ -178,12 +178,12 @@ def _step(self, value=_marker, exc=None): self, result))) else: if result is not None: - self._event_loop.call_soon( + self._loop.call_soon( self._step, None, RuntimeError( 'Task got bad yield: {!r}'.format(result))) else: - self._event_loop.call_soon(self._step_maybe) + self._loop.call_soon(self._step_maybe) def _wakeup(self, future): try: @@ -330,7 +330,7 @@ def _wrap_coroutines(fs): def sleep(delay, result=None): """Coroutine that completes after a given time (in seconds).""" future = futures.Future() - h = future._event_loop.call_later(delay, future.set_result, result) + h = future._loop.call_later(delay, future.set_result, result) try: return (yield from future) finally: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index d1f05614..29983a52 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -101,7 +101,6 @@ def handle_request(self, message, payload): def run(loop, fut): thread_loop = tulip.new_event_loop() - thread_loop.set_log_level(logging.CRITICAL) tulip.set_event_loop(thread_loop) socks = thread_loop.run_until_complete( From 47b5770d1c17ad60d1e0ae6159dad49681bdf100 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 6 May 2013 09:40:41 -0700 Subject: [PATCH 0464/1502] Improvement to _TracebackLogger. Bonus: actually break the cycle! --- tests/futures_test.py | 11 ++++++++--- tulip/base_events.py | 1 + tulip/events.py | 1 + tulip/futures.py | 28 ++++++++++++++++++++++------ tulip/tasks.py | 16 ++++++++++------ 5 files changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index 44955182..0fd482f8 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -15,7 +15,11 @@ def _fakefunc(f): class FutureTests(unittest.TestCase): def setUp(self): - self.loop = events.get_event_loop() + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() def test_initial_state(self): f = futures.Future() @@ -188,6 +192,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): fut = futures.Future() fut.set_exception(RuntimeError('boom')) del fut + self.loop.run_once() self.assertTrue(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') @@ -210,8 +215,8 @@ def test_tb_logger_exception_result_retrieved(self, m_log): # A fake event loop for tests. All it does is implement a call_soon method # that immediately invokes the given function. class _FakeEventLoop: - def call_soon(self, fn, future): - fn(future) + def call_soon(self, fn, *args): + fn(*args) class FutureDoneCallbackTests(unittest.TestCase): diff --git a/tulip/base_events.py b/tulip/base_events.py index 9d71d96f..7d20f654 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -564,3 +564,4 @@ def _run_once(self, timeout=None): handle = self._ready.popleft() if not handle.cancelled: handle.run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/events.py b/tulip/events.py index 63861262..6efcd76e 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -52,6 +52,7 @@ def run(self): except Exception: tulip_log.exception('Exception in callback %s %r', self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. def make_handle(callback, args): diff --git a/tulip/futures.py b/tulip/futures.py index 5161dafe..e2c0245e 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -73,20 +73,33 @@ class itself, but instead to have a reference to a helper object references the traceback, which references stack frames, which may reference the Future, which references the _TracebackLogger, and then the _TracebackLogger would be included in a cycle, which is - what we're trying to avoid! As a compromise, we use - extract_exception() rather than format_exception(). (We may also - have to limit how many entries we extract, but then we'd need a - public API to change the limit; so let's punt on this for now.) + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. PS. I don't claim credit for this solution. I first heard of it in a discussion about closing files when they are collected. """ + __slots__ = ['exc', 'tb'] + def __init__(self, exc): - self.tb = traceback.format_exception(exc.__class__, exc, - exc.__traceback__) + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) def clear(self): + self.exc = None self.tb = None def __del__(self): @@ -301,6 +314,9 @@ def set_exception(self, exception): self._tb_logger = _TracebackLogger(exception) self._state = _FINISHED self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) # Truly internal methods. diff --git a/tulip/tasks.py b/tulip/tasks.py index 6f9e6efd..946fae56 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -162,12 +162,10 @@ def _step(self, value=_marker, exc=None): self._fut_waiter.cancel() elif isinstance(result, concurrent.futures.Future): - # This ought to be more efficient than wrap_future(), - # because we don't create an extra Future. - result.add_done_callback( - lambda future: - self._loop.call_soon_threadsafe( - self._wakeup, future)) + # Don't use a lambda here; mysteriously it creates an + # unnecessary memory cycle. + result.add_done_callback(self._wakeup_from_thread) + else: if inspect.isgenerator(result): self._loop.call_soon( @@ -184,6 +182,11 @@ def _step(self, value=_marker, exc=None): 'Task got bad yield: {!r}'.format(result))) else: self._loop.call_soon(self._step_maybe) + self = None + + def _wakeup_from_thread(self, future): + # Helper to wake up a task from a thread. + self._loop.call_soon_threadsafe(self._wakeup, future) def _wakeup(self, future): try: @@ -192,6 +195,7 @@ def _wakeup(self, future): self._step(None, exc) else: self._step(value, None) + self = None # Needed to break cycles when an exception occurs. # wait() and as_completed() similar to those in PEP 3148. From f1531b51ddf060e3e21af3d1b9dfbb35ec7de2aa Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 7 May 2013 11:22:54 -0700 Subject: [PATCH 0465/1502] tests cleanup --- tests/base_events_test.py | 313 ++++++++++++++++++++++++++++++++++++++ tests/events_test.py | 208 +------------------------ tests/futures_test.py | 1 - tests/unix_events_test.py | 8 + tulip/base_events.py | 8 +- 5 files changed, 333 insertions(+), 205 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index b1790ff2..b5852d3c 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -52,6 +52,12 @@ def test__add_callback_handle(self): self.assertFalse(self.loop._scheduled) self.assertIn(h, self.loop._ready) + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + def test__add_callback_cancelled_handle(self): h = events.Handle(lambda: False, ()) h.cancel() @@ -260,6 +266,75 @@ def test_run_until_complete_assertion(self): self.assertRaises( AssertionError, self.loop.run_until_complete, 'blah') + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors(self, m_socket): @@ -290,3 +365,241 @@ def _socket(*args, **kw): yield from tasks.wait(task) exc = task.exception() self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.task + def getaddrinfo(*args, **kw): + yield from [] + self.loop.getaddrinfo = getaddrinfo + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.task + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + def test_create_connection_mutiple(self): + @tasks.task + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(socket.error): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors_local_addr(self, m_socket): + m_socket.error = socket.error + + def bind(addr): + if addr[0] == '0.0.0.1': + err = socket.error('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.task + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + self.loop.getaddrinfo = getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(socket.error) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + + def test_create_connection_no_local_addr(self): + @tasks.task + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + self.loop.getaddrinfo = getaddrinfo + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + def test_start_serving_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.task + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + self.loop.getaddrinfo = getaddrinfo + fut = self.loop.start_serving(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_start_serving_host_port_sock(self): + fut = self.loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(socket.error, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(socket.error): + strerror = 'error' + + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.error = socket.error + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = socket.error + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.error = socket.error + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = socket.error + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + socket.error, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(socket.error): + pass + + m_socket.error = socket.error + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/events_test.py b/tests/events_test.py index 3f606834..9e9eec74 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -491,52 +491,6 @@ def test_create_ssl_connection(self): self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) - def test_create_connection_host_port_sock(self): - coro = self.loop.create_connection( - MyProto, 'example.com', 80, sock=object()) - self.assertRaises(ValueError, self.loop.run_until_complete, coro) - - def test_create_connection_no_host_port_sock(self): - coro = self.loop.create_connection(MyProto) - self.assertRaises(ValueError, self.loop.run_until_complete, coro) - - def test_create_connection_no_getaddrinfo(self): - @tasks.task - def getaddrinfo(*args, **kw): - yield from [] - self.loop.getaddrinfo = getaddrinfo - coro = self.loop.create_connection(MyProto, 'example.com', 80) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - def test_create_connection_connect_err(self): - @tasks.task - def getaddrinfo(*args, **kw): - yield from [] - return [(2, 1, 6, '', ('107.6.106.82', 80))] - self.loop.getaddrinfo = getaddrinfo - self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error - - coro = self.loop.create_connection(MyProto, 'example.com', 80) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - def test_create_connection_mutiple_errors(self): - @tasks.task - def getaddrinfo(*args, **kw): - yield from [] - return [(2, 1, 6, '', ('107.6.106.82', 80)), - (2, 1, 6, '', ('107.6.106.82', 80))] - self.loop.getaddrinfo = getaddrinfo - self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error - - coro = self.loop.create_connection( - MyProto, 'example.com', 80, family=socket.AF_INET) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - def test_create_connection_local_addr(self): with test_utils.run_test_server(self.loop) as httpd: port = find_unused_port() @@ -738,58 +692,6 @@ def test_stop_serving(self): self.assertRaises( ConnectionRefusedError, client.connect, ('127.0.0.1', port)) - def test_start_serving_host_port_sock(self): - fut = self.loop.start_serving( - MyProto, '0.0.0.0', 0, sock=object()) - self.assertRaises(ValueError, self.loop.run_until_complete, fut) - - def test_start_serving_no_host_port_sock(self): - fut = self.loop.start_serving(MyProto) - self.assertRaises(ValueError, self.loop.run_until_complete, fut) - - def test_start_serving_no_getaddrinfo(self): - getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() - getaddrinfo.return_value = [] - - f = self.loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises(socket.error, self.loop.run_until_complete, f) - - @unittest.mock.patch('tulip.base_events.socket') - def test_start_serving_cant_bind(self, m_socket): - - class Err(socket.error): - strerror = 'error' - - m_socket.error = socket.error - m_socket.getaddrinfo.return_value = [ - (2, 1, 6, '', ('127.0.0.1', 10100))] - m_sock = m_socket.socket.return_value = unittest.mock.Mock() - m_sock.bind.side_effect = Err - - fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises(OSError, self.loop.run_until_complete, fut) - self.assertTrue(m_sock.close.called) - - @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_endpoint_no_addrinfo(self, m_socket): - m_socket.error = socket.error - m_socket.getaddrinfo.return_value = [] - - coro = self.loop.create_datagram_endpoint( - MyDatagramProto, local_addr=('localhost', 0)) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - def test_create_datagram_endpoint_addr_error(self): - coro = self.loop.create_datagram_endpoint( - MyDatagramProto, local_addr='localhost') - self.assertRaises( - AssertionError, self.loop.run_until_complete, coro) - coro = self.loop.create_datagram_endpoint( - MyDatagramProto, local_addr=('localhost', 1, 2, 3)) - self.assertRaises( - AssertionError, self.loop.run_until_complete, coro) - def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): def __init__(self): @@ -829,88 +731,6 @@ def datagram_received(self, data, addr): self.assertEqual('CLOSED', client.state) server.transport.close() - def test_create_datagram_endpoint_connect_err(self): - self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error - - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_endpoint_socket_err(self, m_socket): - m_socket.error = socket.error - m_socket.getaddrinfo = socket.getaddrinfo - m_socket.socket.side_effect = socket.error - - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, family=socket.AF_INET) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - - def test_create_datagram_endpoint_no_matching_family(self): - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, - remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) - self.assertRaises( - ValueError, self.loop.run_until_complete, coro) - - @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_endpoint_setblk_err(self, m_socket): - m_socket.error = socket.error - m_socket.socket.return_value.setblocking.side_effect = socket.error - - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, family=socket.AF_INET) - self.assertRaises( - socket.error, self.loop.run_until_complete, coro) - self.assertTrue( - m_socket.socket.return_value.close.called) - - def test_create_datagram_endpoint_noaddr_nofamily(self): - coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol) - self.assertRaises(ValueError, self.loop.run_until_complete, coro) - - @unittest.mock.patch('tulip.base_events.socket') - def test_create_datagram_endpoint_cant_bind(self, m_socket): - class Err(socket.error): - pass - - m_socket.error = socket.error - m_socket.AF_INET6 = socket.AF_INET6 - m_socket.getaddrinfo = socket.getaddrinfo - m_sock = m_socket.socket.return_value = unittest.mock.Mock() - m_sock.bind.side_effect = Err - - fut = self.loop.create_datagram_endpoint( - MyDatagramProto, - local_addr=('127.0.0.1', 0), family=socket.AF_INET) - self.assertRaises(Err, self.loop.run_until_complete, fut) - self.assertTrue(m_sock.close.called) - - def test_accept_connection_retry(self): - sock = unittest.mock.Mock() - sock.accept.side_effect = BlockingIOError() - - self.loop._accept_connection(MyProto, sock) - self.assertFalse(sock.close.called) - - @unittest.mock.patch('tulip.selector_events.tulip_log') - def test_accept_connection_exception(self, m_log): - sock = unittest.mock.Mock() - sock.accept.side_effect = OSError() - - self.loop._accept_connection(MyProto, sock) - self.assertTrue(sock.close.called) - self.assertTrue(m_log.exception.called) - def test_internal_fds(self): loop = self.create_event_loop() if not isinstance(loop, selector_events.BaseSelectorEventLoop): @@ -1032,30 +852,9 @@ def test_writer_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") - def test_accept_connection_retry(self): - raise unittest.SkipTest( - "IocpEventLoop does not have _accept_connection()") - def test_accept_connection_exception(self): - raise unittest.SkipTest( - "IocpEventLoop does not have _accept_connection()") def test_create_datagram_endpoint(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_endpoint()") - def test_create_datagram_endpoint_no_connection(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") - def test_create_datagram_endpoint_cant_bind(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") - def test_create_datagram_endpoint_noaddr_nofamily(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") - def test_create_datagram_endpoint_socket_err(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") - def test_create_datagram_endpoint_connect_err(self): - raise unittest.SkipTest( - "IocpEventLoop does not have create_datagram_endpoint()") def test_stop_serving(self): raise unittest.SkipTest( "IocpEventLoop does not support stop_serving()") @@ -1135,6 +934,11 @@ def callback(): class TimerTests(unittest.TestCase): + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + def test_timer(self): def callback(*args): return args @@ -1215,6 +1019,8 @@ def test_not_implemented(self): NotImplementedError, loop.is_running) self.assertRaises( NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) self.assertRaises( NotImplementedError, loop.call_soon, None) self.assertRaises( diff --git a/tests/futures_test.py b/tests/futures_test.py index 0fd482f8..2c81e4da 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -1,6 +1,5 @@ """Tests for futures.py.""" -import logging import unittest import unittest.mock diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index fd5115d3..e1916f30 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -40,6 +40,14 @@ def restore_signal(): def test_handle_signal_no_handler(self): self.loop._handle_signal(signal.NSIG + 1, ()) + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + @unittest.mock.patch('tulip.unix_events.signal') def test_add_signal_handler_setup_error(self, m_signal): m_signal.NSIG = signal.NSIG diff --git a/tulip/base_events.py b/tulip/base_events.py index 7d20f654..5623f6df 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -257,11 +257,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise ValueError( 'host/port and sock can not be specified at the same time') - f1 = self.getaddrinfo(host, port, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) fs = [f1] if local_addr is not None: - f2 = self.getaddrinfo(*local_addr, family=family, + f2 = self.getaddrinfo( + *local_addr, family=family, type=socket.SOCK_STREAM, proto=proto, flags=flags) fs.append(f2) else: From 7ac16478374de7ce852678e36d99adbd0547a453 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 10 May 2013 11:10:46 +0100 Subject: [PATCH 0466/1502] Make IocpProactor.connect() work with locally-bound socket. --- tulip/windows_events.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index bede2b5e..112dfc0d 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -87,7 +87,19 @@ def finish_accept(): def connect(self, conn, address): self._register_with_iocp(conn) - _overlapped.BindLocal(conn.fileno(), len(address)) + # the socket must be locally bound before calling ConnectEx() + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror == 10022: # WSAEINVAL + # the socket is probably already locally bound + try: + if conn.getsockname()[1] == 0: + raise e + except: + raise e + else: + raise ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) From 5220b6f66c1ce8211fd3069e61b2c1be072dcdcf Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 10 May 2013 16:00:38 +0300 Subject: [PATCH 0467/1502] Rename event_loop to loop into selector_events transports --- tulip/selector_events.py | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 5c182fa4..f402bec4 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -327,19 +327,19 @@ def stop_serving(self, sock): class _SelectorSocketTransport(transports.Transport): - def __init__(self, event_loop, sock, protocol, waiter=None, extra=None): + def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(extra) self._extra['socket'] = sock - self._event_loop = event_loop + self._loop = loop self._sock = sock self._protocol = protocol self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() called. - self._event_loop.add_reader(self._sock.fileno(), self._read_ready) - self._event_loop.call_soon(self._protocol.connection_made, self) + self._loop.add_reader(self._sock.fileno(), self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._event_loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter.set_result, None) def _read_ready(self): try: @@ -352,7 +352,7 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: - self._event_loop.remove_reader(self._sock.fileno()) + self._loop.remove_reader(self._sock.fileno()) self._protocol.eof_received() def write(self, data): @@ -382,7 +382,7 @@ def write(self, data): return elif n: data = data[n:] - self._event_loop.add_writer(self._sock.fileno(), self._write_ready) + self._loop.add_writer(self._sock.fileno(), self._write_ready) self._buffer.append(data) @@ -400,7 +400,7 @@ def _write_ready(self): self._fatal_error(exc) else: if n == len(data): - self._event_loop.remove_writer(self._sock.fileno()) + self._loop.remove_writer(self._sock.fileno()) if self._closing: self._call_connection_lost(None) return @@ -416,9 +416,9 @@ def abort(self): def close(self): self._closing = True - self._event_loop.remove_reader(self._sock.fileno()) + self._loop.remove_reader(self._sock.fileno()) if not self._buffer: - self._event_loop.call_soon(self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): # should be called from exception handler only @@ -426,10 +426,10 @@ def _fatal_error(self, exc): self._close(exc) def _close(self, exc): - self._event_loop.remove_writer(self._sock.fileno()) - self._event_loop.remove_reader(self._sock.fileno()) + self._loop.remove_writer(self._sock.fileno()) + self._loop.remove_reader(self._sock.fileno()) self._buffer.clear() - self._event_loop.call_soon(self._call_connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: @@ -594,10 +594,10 @@ class _SelectorDatagramTransport(transports.DatagramTransport): max_size = 256 * 1024 # max bytes we read in one eventloop iteration - def __init__(self, event_loop, sock, protocol, address=None, extra=None): + def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(extra) self._extra['socket'] = sock - self._event_loop = event_loop + self._loop = loop self._sock = sock self._fileno = sock.fileno() self._protocol = protocol @@ -605,8 +605,8 @@ def __init__(self, event_loop, sock, protocol, address=None, extra=None): self._buffer = collections.deque() self._conn_lost = 0 self._closing = False # Set when close() called. - self._event_loop.add_reader(self._fileno, self._read_ready) - self._event_loop.call_soon(self._protocol.connection_made, self) + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) def _read_ready(self): try: @@ -647,7 +647,7 @@ def sendto(self, data, addr=None): self._fatal_error(exc) return except (BlockingIOError, InterruptedError): - self._event_loop.add_writer(self._fileno, self._sendto_ready) + self._loop.add_writer(self._fileno, self._sendto_ready) except Exception as exc: self._conn_lost += 1 self._fatal_error(exc) @@ -677,7 +677,7 @@ def _sendto_ready(self): return if not self._buffer: - self._event_loop.remove_writer(self._fileno) + self._loop.remove_writer(self._fileno) if self._closing: self._call_connection_lost(None) @@ -686,9 +686,9 @@ def abort(self): def close(self): self._closing = True - self._event_loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) if not self._buffer: - self._event_loop.call_soon(self._call_connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) @@ -696,11 +696,11 @@ def _fatal_error(self, exc): def _close(self, exc): self._buffer.clear() - self._event_loop.remove_writer(self._fileno) - self._event_loop.remove_reader(self._fileno) + self._loop.remove_writer(self._fileno) + self._loop.remove_reader(self._fileno) if self._address and isinstance(exc, ConnectionRefusedError): self._protocol.connection_refused(exc) - self._event_loop.call_soon(self._call_connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: From 488470f6650b859bb7c51e0a51c6bbb3de1998bb Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 10 May 2013 19:22:08 +0100 Subject: [PATCH 0468/1502] Make error handling in sock_connect() a bit clearer and simpler. --- tulip/windows_events.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 112dfc0d..cd6c61af 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -1,5 +1,6 @@ """Selector and proactor eventloops for Windows.""" +import errno import socket import weakref import struct @@ -87,18 +88,14 @@ def finish_accept(): def connect(self, conn, address): self._register_with_iocp(conn) - # the socket must be locally bound before calling ConnectEx() + # the socket needs to be locally bound before we call ConnectEx() try: _overlapped.BindLocal(conn.fileno(), len(address)) except OSError as e: - if e.winerror == 10022: # WSAEINVAL - # the socket is probably already locally bound - try: - if conn.getsockname()[1] == 0: - raise e - except: - raise e - else: + if e.winerror != errno.WSAEINVAL: + raise + # probably already locally bound; check using getsockname() + if conn.getsockname()[1] == 0: raise ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) From 7ce473e04eef806b5b19bf3571920c103cb490dc Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 May 2013 16:01:58 -0700 Subject: [PATCH 0469/1502] Handle and TimerHandle should have no public API except for cancel() issue #40 --- tests/events_test.py | 18 +++++++++--------- tests/selector_events_test.py | 14 +++++++------- tests/tasks_test.py | 4 ++-- tests/unix_events_test.py | 2 +- tulip/base_events.py | 16 ++++++++-------- tulip/events.py | 18 +----------------- tulip/selector_events.py | 4 ++-- tulip/unix_events.py | 2 +- 8 files changed, 31 insertions(+), 47 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 9e9eec74..eefc1a66 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -896,9 +896,9 @@ def callback(*args): args = () h = events.Handle(callback, args) - self.assertIs(h.callback, callback) - self.assertIs(h.args, args) - self.assertFalse(h.cancelled) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) r = repr(h) self.assertTrue(r.startswith( @@ -907,7 +907,7 @@ def callback(*args): self.assertTrue(r.endswith('())')) h.cancel() - self.assertTrue(h.cancelled) + self.assertTrue(h._cancelled) r = repr(h) self.assertTrue(r.startswith( @@ -928,7 +928,7 @@ def callback(): raise ValueError() h = events.Handle(callback, ()) - h.run() + h._run() self.assertTrue(log.exception.called) @@ -946,15 +946,15 @@ def callback(*args): args = () when = time.monotonic() h = events.TimerHandle(when, callback, args) - self.assertIs(h.callback, callback) - self.assertIs(h.args, args) - self.assertFalse(h.cancelled) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) r = repr(h) self.assertTrue(r.endswith('())')) h.cancel() - self.assertTrue(h.cancelled) + self.assertTrue(h._cancelled) r = repr(h) self.assertTrue(r.endswith('())')) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 3232fa46..8db2e37c 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -387,7 +387,7 @@ def test_add_reader(self): fd, mask, (r, w) = self.loop._selector.register.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_READ, mask) - self.assertEqual(cb, r.callback) + self.assertEqual(cb, r._callback) self.assertEqual(None, w) def test_add_reader_existing(self): @@ -404,7 +404,7 @@ def test_add_reader_existing(self): fd, mask, (r, w) = self.loop._selector.modify.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) - self.assertEqual(cb, r.callback) + self.assertEqual(cb, r._callback) self.assertEqual(writer, w) def test_add_reader_existing_writer(self): @@ -419,7 +419,7 @@ def test_add_reader_existing_writer(self): fd, mask, (r, w) = self.loop._selector.modify.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) - self.assertEqual(cb, r.callback) + self.assertEqual(cb, r._callback) self.assertEqual(writer, w) def test_remove_reader(self): @@ -457,7 +457,7 @@ def test_add_writer(self): self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE, mask) self.assertEqual(None, r) - self.assertEqual(cb, w.callback) + self.assertEqual(cb, w._callback) def test_add_writer_existing(self): reader = unittest.mock.Mock() @@ -474,7 +474,7 @@ def test_add_writer_existing(self): self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE | selectors.EVENT_READ, mask) self.assertEqual(reader, r) - self.assertEqual(cb, w.callback) + self.assertEqual(cb, w._callback) def test_remove_writer(self): self.loop._selector.get_info.return_value = ( @@ -503,7 +503,7 @@ def test_remove_writer_unknown(self): def test_process_events_read(self): reader = unittest.mock.Mock() - reader.cancelled = False + reader._cancelled = False self.loop._add_callback = unittest.mock.Mock() self.loop._process_events( @@ -522,7 +522,7 @@ def test_process_events_read_cancelled(self): def test_process_events_write(self): writer = unittest.mock.Mock() - writer.cancelled = False + writer._cancelled = False self.loop._add_callback = unittest.mock.Mock() self.loop._process_events( diff --git a/tests/tasks_test.py b/tests/tasks_test.py index d3def0c9..16bf8f27 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -496,11 +496,11 @@ def call_later(self, delay, callback, *args): self.loop.call_later = call_later self.loop.run_once() - self.assertFalse(handle.cancelled) + self.assertFalse(handle._cancelled) t.cancel() self.loop.run_once() - self.assertTrue(handle.cancelled) + self.assertTrue(handle._cancelled) def test_task_cancel_sleeping_task(self): sleepfut = None diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index e1916f30..d5f72b9d 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -66,7 +66,7 @@ def test_add_signal_handler(self, m_signal): self.loop.add_signal_handler(signal.SIGHUP, cb) h = self.loop._signal_handlers.get(signal.SIGHUP) self.assertTrue(isinstance(h, events.Handle)) - self.assertEqual(h.callback, cb) + self.assertEqual(h._callback, cb) @unittest.mock.patch('tulip.unix_events.signal') def test_add_signal_handler_install_error(self, m_signal): diff --git a/tulip/base_events.py b/tulip/base_events.py index 5623f6df..a0f38be6 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -224,11 +224,11 @@ def run_in_executor(self, executor, callback, *args): if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) - if callback.cancelled: + if callback._cancelled: f = futures.Future() f.set_result(None) return f - callback, args = callback.callback, callback.args + callback, args = callback._callback, callback._args if executor is None: executor = self._default_executor if executor is None: @@ -490,7 +490,7 @@ def connect_write_pipe(self, protocol_factory, pipe): def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" assert isinstance(handle, events.Handle), 'A Handle is required here' - if handle.cancelled: + if handle._cancelled: return if isinstance(handle, events.TimerHandle): heapq.heappush(self._scheduled, handle) @@ -520,14 +520,14 @@ def _run_once(self, timeout=None): 'call_later' callbacks. """ # Remove delayed calls that were cancelled from head of queue. - while self._scheduled and self._scheduled[0].cancelled: + while self._scheduled and self._scheduled[0]._cancelled: heapq.heappop(self._scheduled) if self._ready: timeout = 0 elif self._scheduled: # Compute the desired timeout. - when = self._scheduled[0].when + when = self._scheduled[0]._when deadline = max(0, when - self.time()) if timeout is None: timeout = deadline @@ -550,7 +550,7 @@ def _run_once(self, timeout=None): now = self.time() while self._scheduled: handle = self._scheduled[0] - if handle.when > now: + if handle._when > now: break handle = heapq.heappop(self._scheduled) self._ready.append(handle) @@ -564,6 +564,6 @@ def _run_once(self, timeout=None): ntodo = len(self._ready) for i in range(ntodo): handle = self._ready.popleft() - if not handle.cancelled: - handle.run() + if not handle._cancelled: + handle._run() handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/events.py b/tulip/events.py index 6efcd76e..dea5c8ba 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -31,22 +31,10 @@ def __repr__(self): res += '' return res - @property - def callback(self): - return self._callback - - @property - def args(self): - return self._args - - @property - def cancelled(self): - return self._cancelled - def cancel(self): self._cancelled = True - def run(self): + def _run(self): try: self._callback(*self._args) except Exception: @@ -79,10 +67,6 @@ def __repr__(self): return res - @property - def when(self): - return self._when - def __hash__(self): return hash(self._when) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index f402bec4..1ae1202c 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -310,12 +310,12 @@ def _sock_accept(self, fut, registered, sock): def _process_events(self, event_list): for fileobj, mask, (reader, writer) in event_list: if mask & selectors.EVENT_READ and reader is not None: - if reader.cancelled: + if reader._cancelled: self.remove_reader(fileobj) else: self._add_callback(reader) if mask & selectors.EVENT_WRITE and writer is not None: - if writer.cancelled: + if writer._cancelled: self.remove_writer(fileobj) else: self._add_callback(writer) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 35926406..a8e073af 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -77,7 +77,7 @@ def _handle_signal(self, sig, arg): handle = self._signal_handlers.get(sig) if handle is None: return # Assume it's some race condition. - if handle.cancelled: + if handle._cancelled: self.remove_signal_handler(sig) # Remove it properly. else: self._add_callback_signalsafe(handle) From 58bc3d881501c0d61ec82a9a0428c3917598e3d2 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 May 2013 16:10:22 -0700 Subject: [PATCH 0470/1502] implement wrap_future() as helper function --- tests/base_events_test.py | 12 ------------ tests/events_test.py | 13 ------------- tests/futures_test.py | 18 ++++++++++++++++++ tests/tasks_test.py | 28 ---------------------------- tulip/base_events.py | 12 +----------- tulip/events.py | 3 --- tulip/futures.py | 20 +++++++++++++++++++- tulip/tasks.py | 9 --------- 8 files changed, 38 insertions(+), 77 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index b5852d3c..c8397811 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -1,6 +1,5 @@ """Tests for base_events.py""" -import concurrent.futures import logging import socket import time @@ -66,17 +65,6 @@ def test__add_callback_cancelled_handle(self): self.assertFalse(self.loop._scheduled) self.assertFalse(self.loop._ready) - def test_wrap_future(self): - f = futures.Future(loop=self.loop) - self.assertIs(self.loop.wrap_future(f), f) - f.cancel() - - def test_wrap_future_concurrent(self): - f = concurrent.futures.Future() - fut = self.loop.wrap_future(f) - self.assertIsInstance(fut, futures.Future) - fut.cancel() - def test_set_default_executor(self): executor = unittest.mock.Mock() self.loop.set_default_executor(executor) diff --git a/tests/events_test.py b/tests/events_test.py index eefc1a66..ea06abb9 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1,6 +1,5 @@ """Tests for events.py.""" -import concurrent.futures import gc import io import os @@ -252,16 +251,6 @@ def callback(arg): self.loop.run_forever() self.assertEqual(results, ['hello', 'world']) - def test_wrap_future(self): - def run(arg): - time.sleep(0.1) - return arg - ex = concurrent.futures.ThreadPoolExecutor(1) - f1 = ex.submit(run, 'oi') - f2 = self.loop.wrap_future(f1) - res = self.loop.run_until_complete(f2) - self.assertEqual(res, 'oi') - def test_run_in_executor(self): def run(arg): time.sleep(0.1) @@ -1027,8 +1016,6 @@ def test_not_implemented(self): NotImplementedError, loop.time) self.assertRaises( NotImplementedError, loop.call_soon_threadsafe, None) - self.assertRaises( - NotImplementedError, loop.wrap_future, f) self.assertRaises( NotImplementedError, loop.run_in_executor, f, f) self.assertRaises( diff --git a/tests/futures_test.py b/tests/futures_test.py index 2c81e4da..18e70c41 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -1,5 +1,7 @@ """Tests for futures.py.""" +import concurrent.futures +import time import unittest import unittest.mock @@ -210,6 +212,22 @@ def test_tb_logger_exception_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) + def test_wrap_future(self): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + res = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + + def test_wrap_future_future(self): + f1 = futures.Future() + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + # A fake event loop for tests. All it does is implement a call_soon method # that immediately invokes the given function. diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 16bf8f27..3ccaa92b 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -1,6 +1,5 @@ """Tests for tasks.py.""" -import concurrent.futures import time import unittest import unittest.mock @@ -605,33 +604,6 @@ def wait_for_future(): self.assertTrue(t.done()) self.assertIsNone(t.result()) - def test_step_result_concurrent_future(self): - # Coroutine returns concurrent.futures.Future - - class Fut(concurrent.futures.Future): - def __init__(self): - self.cb_added = False - super().__init__() - - def add_done_callback(self, fn): - self.cb_added = True - super().add_done_callback(fn) - - c_fut = Fut() - - @tasks.coroutine - def notmuch(): - return (yield c_fut) - - task = tasks.Task(notmuch()) - self.loop.run_once() - self.assertTrue(c_fut.cb_added) - - res = object() - c_fut.set_result(res) - self.loop.run_once() - self.assertIs(res, task.result()) - def test_step_with_baseexception(self): @tasks.coroutine def notmutch(): diff --git a/tulip/base_events.py b/tulip/base_events.py index a0f38be6..fcc6ce2c 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -234,7 +234,7 @@ def run_in_executor(self, executor, callback, *args): if executor is None: executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) self._default_executor = executor - return self.wrap_future(executor.submit(callback, *args)) + return futures.wrap_future(executor.submit(callback, *args), loop=self) def set_default_executor(self, executor): self._default_executor = executor @@ -502,16 +502,6 @@ def _add_callback_signalsafe(self, handle): self._add_callback(handle) self._write_to_self() - def wrap_future(self, future): - """XXX""" - if isinstance(future, futures.Future): - return future # Don't wrap our own type of Future. - new_future = futures.Future(loop=self) - future.add_done_callback( - lambda future: - self.call_soon_threadsafe(new_future._copy_state, future)) - return new_future - def _run_once(self, timeout=None): """Run one full iteration of the event loop. diff --git a/tulip/events.py b/tulip/events.py index dea5c8ba..b1b5186c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -157,9 +157,6 @@ def time(self): def call_soon_threadsafe(self, callback, *args): raise NotImplementedError - def wrap_future(self, future): - raise NotImplementedError - def run_in_executor(self, executor, callback, *args): raise NotImplementedError diff --git a/tulip/futures.py b/tulip/futures.py index e2c0245e..14edbc99 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -2,7 +2,7 @@ __all__ = ['CancelledError', 'TimeoutError', 'InvalidStateError', 'InvalidTimeoutError', - 'Future', + 'Future', 'wrap_future', ] import concurrent.futures._base @@ -343,3 +343,21 @@ def __iter__(self): yield self # This tells Task to wait for completion. assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/tasks.py b/tulip/tasks.py index 946fae56..54c594ba 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -161,11 +161,6 @@ def _step(self, value=_marker, exc=None): if self._must_cancel: self._fut_waiter.cancel() - elif isinstance(result, concurrent.futures.Future): - # Don't use a lambda here; mysteriously it creates an - # unnecessary memory cycle. - result.add_done_callback(self._wakeup_from_thread) - else: if inspect.isgenerator(result): self._loop.call_soon( @@ -184,10 +179,6 @@ def _step(self, value=_marker, exc=None): self._loop.call_soon(self._step_maybe) self = None - def _wakeup_from_thread(self, future): - # Helper to wake up a task from a thread. - self._loop.call_soon_threadsafe(self._wakeup, future) - def _wakeup(self, future): try: value = future.result() From 43f3745df2e1594bae9ff72daf430fb7887302db Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 May 2013 16:21:31 -0700 Subject: [PATCH 0471/1502] close transport after receiving eof --- examples/tcp_echo.py | 1 + tests/selector_events_test.py | 2 ++ tulip/selector_events.py | 1 + 3 files changed, 4 insertions(+) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 16e3fb65..ff40c4ab 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -38,6 +38,7 @@ def eof_received(self): def connection_lost(self, exc): print('connection lost') + self.h_timeout.cancel() class EchoClient(tulip.Protocol): diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 8db2e37c..7596223b 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -576,11 +576,13 @@ def test_read_ready_eof(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) + self.loop.reset_mock() self.sock.recv.return_value = b'' transport._read_ready() self.assertTrue(self.loop.remove_reader.called) self.protocol.eof_received.assert_called_with() + self.loop.call_soon.assert_called_with(transport.close) @unittest.mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 1ae1202c..e55edd8d 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -353,6 +353,7 @@ def _read_ready(self): self._protocol.data_received(data) else: self._loop.remove_reader(self._sock.fileno()) + self._loop.call_soon(self.close) self._protocol.eof_received() def write(self, data): From 40e28434284cb25b28d7ec6105d0543a412b858f Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 10 May 2013 18:54:00 -0700 Subject: [PATCH 0472/1502] ssl transport eof support --- tests/events_test.py | 1 - tests/selector_events_test.py | 7 +++---- tulip/selector_events.py | 36 ++++++++++++++++------------------- 3 files changed, 19 insertions(+), 25 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index ea06abb9..87be7736 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -50,7 +50,6 @@ def data_received(self, data): def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' - self.transport.close() def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 7596223b..ea0e0efd 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -983,11 +983,10 @@ def test_on_ready_recv(self): def test_on_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() + transport.close = unittest.mock.Mock() transport._on_ready() - self.assertTrue(self.loop.remove_reader.called) - self.assertTrue(self.loop.remove_writer.called) - self.assertTrue(self.sslsock.close.called) - self.protocol.connection_lost.assert_called_with(None) + transport.close.assert_called_with() + self.protocol.eof_received.assert_called_with() def test_on_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError diff --git a/tulip/selector_events.py b/tulip/selector_events.py index e55edd8d..20473ea1 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -506,27 +506,23 @@ def _on_ready(self): return # First try reading. - try: - data = self._sslsock.recv(8192) - except ssl.SSLWantReadError: - pass - except ssl.SSLWantWriteError: - pass - except (BlockingIOError, InterruptedError): - pass - except Exception as exc: - self._fatal_error(exc) - else: - if data: - self._protocol.data_received(data) + if not self._closing: + try: + data = self._sslsock.recv(8192) + except ssl.SSLWantReadError: + pass + except ssl.SSLWantWriteError: + pass + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) else: - # TODO: Don't close when self._buffer is non-empty. - assert not self._buffer - self._loop.remove_reader(fd) - self._loop.remove_writer(fd) - self._sslsock.close() - self._protocol.connection_lost(None) - return + if data: + self._protocol.data_received(data) + else: + self._protocol.eof_received() + self.close() # Now try writing, if there's anything to write. if not self._buffer: From 27d01d930e78a52780878a89247ef04d1dde651c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 13 May 2013 14:03:31 -0700 Subject: [PATCH 0473/1502] http server keep-alive support --- tests/http_client_functional_test.py | 8 ++ tests/http_parser_test.py | 14 +++ tests/http_server_test.py | 134 +++++++++++++++++++-------- tests/http_wsgi_test.py | 27 ++++++ tulip/http/server.py | 68 ++++++++++---- tulip/http/wsgi.py | 4 +- tulip/test_utils.py | 20 ++-- 7 files changed, 207 insertions(+), 68 deletions(-) diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 31d0df86..430664da 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -357,6 +357,14 @@ def test_request_conn_error(self): self.loop.run_until_complete, client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + def test_request_conn_closed(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['close'] = True + self.assertRaises( + tulip.http.HttpException, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'))) + def test_keepalive(self): from tulip.http import session s = session.Session() diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py index c8dfb1bc..91accfd7 100644 --- a/tests/http_parser_test.py +++ b/tests/http_parser_test.py @@ -382,6 +382,20 @@ def test_http_request_parser(self): self.assertEqual( ('GET', '/path', (1, 1), deque(), False, None), result) + def test_http_request_parser_eof(self): + # http_request_parser does not fail on EofStream() + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'get /path HTTP/1.1\r\n') + try: + p.throw(tulip.EofStream()) + except StopIteration: + pass + self.assertFalse(out._buffer) + def test_http_request_parser_two_slashes(self): p = protocol.http_request_parser() next(p) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index c0f09603..cba5fabc 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -6,6 +6,7 @@ import tulip from tulip.http import server from tulip.http import errors +from tulip.test_utils import run_once class HttpServerProtocolTests(unittest.TestCase): @@ -38,10 +39,10 @@ def test_handle_request(self): def test_connection_made(self): srv = server.ServerHttpProtocol() - self.assertIsNone(srv._request_handle) + self.assertIsNone(srv._request_handler) srv.connection_made(unittest.mock.Mock()) - self.assertIsNotNone(srv._request_handle) + self.assertIsNotNone(srv._request_handler) def test_data_received(self): srv = server.ServerHttpProtocol() @@ -64,31 +65,42 @@ def test_connection_lost(self): srv.connection_made(unittest.mock.Mock()) srv.data_received(b'123') - handle = srv._request_handle + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + + handle = srv._request_handler srv.connection_lost(None) - self.assertIsNone(srv._request_handle) + self.assertIsNone(srv._request_handler) self.assertTrue(handle.cancelled()) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(keep_alive_handle.cancel.called) + srv.connection_lost(None) - self.assertIsNone(srv._request_handle) + self.assertIsNone(srv._request_handler) + self.assertIsNone(srv._keep_alive_handle) - def test_close(self): + def test_srv_keep_alive(self): srv = server.ServerHttpProtocol() - self.assertFalse(srv._closing) + self.assertFalse(srv._keep_alive) - srv.close() - self.assertTrue(srv._closing) + srv.keep_alive(True) + self.assertTrue(srv._keep_alive) + + srv.keep_alive(False) + self.assertFalse(srv._keep_alive) def test_handle_error(self): transport = unittest.mock.Mock() srv = server.ServerHttpProtocol() srv.connection_made(transport) + srv.keep_alive(True) srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) self.assertIn(b'HTTP/1.1 404 Not Found', content) self.assertIn(b'X-SERVER: Tulip', content) + self.assertFalse(srv._keep_alive) @unittest.mock.patch('tulip.http.server.traceback') def test_handle_error_traceback_exc(self, m_trace): @@ -141,10 +153,10 @@ def test_handle(self): srv.stream.feed_data( b'GET / HTTP/1.0\r\n' b'Host: example.com\r\n\r\n') - srv.close() - self.loop.run_until_complete(srv._request_handle) + self.loop.run_until_complete(srv._request_handler) self.assertTrue(handle.called) + self.assertTrue(transport.close.called) def test_handle_coro(self): transport = unittest.mock.Mock() @@ -157,7 +169,6 @@ def coro(message, payload): nonlocal called called = True srv.eof_received() - srv.close() srv.handle_request = coro srv.connection_made(transport) @@ -165,26 +176,27 @@ def coro(message, payload): srv.stream.feed_data( b'GET / HTTP/1.0\r\n' b'Host: example.com\r\n\r\n') - self.loop.run_until_complete(srv._request_handle) + self.loop.run_until_complete(srv._request_handler) self.assertTrue(called) - def test_handle_close(self): + def test_handle_cancel(self): + log = unittest.mock.Mock() transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + + srv = server.ServerHttpProtocol(log=log, debug=True) srv.connection_made(transport) - handle = srv.handle_request = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - srv.close() - self.loop.run_until_complete(srv._request_handle) + @tulip.task + def cancel(): + srv._request_handler.cancel() - self.assertTrue(handle.called) - self.assertTrue(transport.close.called) + self.loop.run_until_complete( + tulip.wait([srv._request_handler, cancel()])) + self.assertTrue(log.debug.called) - def test_handle_cancel(self): + def test_handle_cancelled(self): log = unittest.mock.Mock() transport = unittest.mock.Mock() @@ -192,29 +204,26 @@ def test_handle_cancel(self): srv.connection_made(transport) srv.handle_request = unittest.mock.Mock() + run_once(self.loop) # start request_handler task - @tulip.task - def cancel(): - srv._request_handle.cancel() + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') - srv.close() - self.loop.run_until_complete( - tulip.wait([srv._request_handle, cancel()])) - self.assertTrue(log.debug.called) + r_handler = srv._request_handler + srv._request_handler = None # emulate srv.connection_lost() + + self.assertIsNone(self.loop.run_until_complete(r_handler)) def test_handle_400(self): transport = unittest.mock.Mock() srv = server.ServerHttpProtocol() srv.connection_made(transport) srv.handle_error = unittest.mock.Mock() - - def side_effect(*args): - srv.close() - srv.handle_error.side_effect = side_effect - + srv.keep_alive(True) srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') - self.loop.run_until_complete(srv._request_handle) + self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) self.assertTrue(400, srv.handle_error.call_args[0][0]) self.assertTrue(transport.close.called) @@ -227,12 +236,11 @@ def test_handle_500(self): handle = srv.handle_request = unittest.mock.Mock() handle.side_effect = ValueError srv.handle_error = unittest.mock.Mock() - srv.close() srv.stream.feed_data( b'GET / HTTP/1.0\r\n' b'Host: example.com\r\n\r\n') - self.loop.run_until_complete(srv._request_handle) + self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) self.assertTrue(500, srv.handle_error.call_args[0][0]) @@ -240,9 +248,53 @@ def test_handle_500(self): def test_handle_error_no_handle_task(self): transport = unittest.mock.Mock() srv = server.ServerHttpProtocol() + srv.keep_alive(True) srv.connection_made(transport) srv.connection_lost(None) - close = srv.close = unittest.mock.Mock() srv.handle_error(300) - self.assertTrue(close.called) + self.assertFalse(srv._keep_alive) + + def test_keep_alive(self): + srv = server.ServerHttpProtocol(keep_alive=0.1) + transport = unittest.mock.Mock() + closed = False + + def close(): + nonlocal closed + closed = True + srv.connection_lost(None) + self.loop.stop() + + transport.close = close + + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.1\r\n' + b'CONNECTION: keep-alive\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_forever() + self.assertTrue(handle.called) + self.assertTrue(closed) + + def test_keep_alive_close_existing(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(keep_alive=15) + srv.connection_made(transport) + + self.assertIsNone(srv._keep_alive_handle) + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(keep_alive_handle.cancel.called) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(transport.close.called) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index 1145f273..bedfca6d 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -221,6 +221,7 @@ def wsgi_app(env, start): [c[1][0] for c in self.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertFalse(srv._keep_alive) def test_handle_request_io(self): @@ -239,3 +240,29 @@ def wsgi_app(env, start): [c[1][0] for c in self.transport.write.mock_calls]) self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) self.assertTrue(content.endswith(b'data')) + + def test_handle_request_keep_alive(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader() + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, False, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertTrue(srv._keep_alive) diff --git a/tulip/http/server.py b/tulip/http/server.py index 8c816e94..72fb15ee 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -34,34 +34,52 @@ class ServerHttpProtocol(tulip.Protocol): ServerHttpProtocol handles errors in incoming request, like bad status line, bad headers or incomplete payload. If any error occurs, connection gets closed. + + log: custom logging object + debug: enable debug mode + keep_alive: number of seconds before closing keep alive connection + loop: event loop object """ - _closing = False _request_count = 0 - _request_handle = None + _request_handler = None + _keep_alive = False # keep transport open + _keep_alive_handle = None # keep alive timer handle - def __init__(self, *, log=logging, debug=False, **kwargs): + def __init__(self, *, log=logging, debug=False, + keep_alive=None, loop=None, **kwargs): self.__dict__.update(kwargs) self.log = log self.debug = debug + self._keep_alive_period = keep_alive # number of seconds to keep alive + + if keep_alive and loop is None: + loop = tulip.get_event_loop() + self._loop = loop + def connection_made(self, transport): self.transport = transport self.stream = tulip.StreamBuffer() - self._request_handle = self.start() + self._request_handler = self.start() def data_received(self, data): self.stream.feed_data(data) - def connection_lost(self, exc): - if self._request_handle is not None: - self._request_handle.cancel() - self._request_handle = None - def eof_received(self): self.stream.feed_eof() - def close(self): - self._closing = True + def connection_lost(self, exc): + self.stream.feed_eof() + + if self._request_handler is not None: + self._request_handler.cancel() + self._request_handler = None + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + def keep_alive(self, val): + self._keep_alive = val def log_access(self, status, message, *args, **kw): pass @@ -79,13 +97,15 @@ def start(self): It reads request line, request headers and request payload, then calls handle_request() method. Subclass has to override handle_request(). start() handles various excetions in request - or response handling. In case of any error connection is being closed. + or response handling. Connection is being closed always unless + keep_alive(True) specified. """ - while self._request_handle is not None: + while True: info = None message = None self._request_count += 1 + self._keep_alive = False try: httpstream = self.stream.set_parser( @@ -93,6 +113,11 @@ def start(self): message = yield from httpstream.read() + # cancel keep-alive timer + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + payload = self.stream.set_parser( tulip.http.http_payload_parser(message)) @@ -109,8 +134,15 @@ def start(self): except Exception as exc: self.handle_error(500, info, message, exc) finally: - if self._closing: - self.transport.close() + if self._request_handler: + if self._keep_alive and self._keep_alive_period: + self._keep_alive_handle = self._loop.call_later( + self._keep_alive_period, self.transport.close) + else: + self.transport.close() + self._request_handler = None + break + else: break def handle_error(self, status=500, @@ -120,7 +152,7 @@ def handle_error(self, status=500, Returns http response with specific status code. Logs additional information. It always closes current connection.""" try: - if self._request_handle is None: + if self._request_handler is None: # client has been disconnected during writing. return @@ -156,7 +188,7 @@ def handle_error(self, status=500, response.write(html.encode('ascii')) response.write_eof() finally: - self.close() + self.keep_alive(False) def handle_request(self, message, payload): """Handle a single http request. @@ -179,5 +211,5 @@ def handle_request(self, message, payload): response.write(body) response.write_eof() - self.close() + self.keep_alive(False) self.log_access(404, message) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py index a612a0aa..f36bde63 100644 --- a/tulip/http/wsgi.py +++ b/tulip/http/wsgi.py @@ -170,8 +170,8 @@ def handle_request(self, message, payload): if hasattr(riter, 'close'): riter.close() - if not resp.keep_alive(): - self.close() + if resp.keep_alive(): + self.keep_alive(True) class FileWrapper: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 29983a52..f5854f4e 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -65,9 +65,12 @@ def url(self, *suffix): class TestHttpServer(tulip.http.ServerHttpProtocol): def handle_request(self, message, payload): - if properties.get('noresponse', False): + if properties.get('close', False): return + if properties.get('noresponse', False): + yield from tulip.sleep(99999) + if router is not None: body = bytearray() chunk = yield from payload.read() @@ -75,7 +78,9 @@ def handle_request(self, message, payload): body.extend(chunk) chunk = yield from payload.read() - rob = router(properties, self.transport, message, bytes(body)) + rob = router( + self, properties, + self.transport, message, bytes(body)) rob.dispatch() else: @@ -88,7 +93,6 @@ def handle_request(self, message, payload): response.send_headers() response.write(text) response.write_eof() - self.transport.close() if use_ssl: here = os.path.join(os.path.dirname(__file__), '..', 'tests') @@ -105,7 +109,8 @@ def run(loop, fut): socks = thread_loop.run_until_complete( thread_loop.start_serving( - TestHttpServer, host, port, ssl=sslcontext)) + lambda: TestHttpServer(keep_alive=0.5), + host, port, ssl=sslcontext)) waiter = tulip.Future() loop.call_soon_threadsafe( @@ -132,12 +137,13 @@ class Router: _response_version = "1.1" _responses = http.server.BaseHTTPRequestHandler.responses - def __init__(self, props, transport, message, payload): + def __init__(self, srv, props, transport, message, payload): # headers self._headers = http.client.HTTPMessage() for hdr, val in message.headers: self._headers.add_header(hdr, val) + self._srv = srv self._props = props self._transport = transport self._method = message.method @@ -261,5 +267,5 @@ def _response(self, response, body=None, headers=None, chunked=False): response.write_eof() # keep-alive - if not response.keep_alive(): - self._transport.close() + if response.keep_alive(): + self._srv.keep_alive(True) From ed7023672088eff7c6a93a1cc89e4c07844ef520 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 14 May 2013 10:57:53 -0700 Subject: [PATCH 0474/1502] close iocp transport after receiving eof --- tests/proactor_events_test.py | 2 ++ tulip/proactor_events.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 959cd2ba..4a48885d 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -52,10 +52,12 @@ def test_loop_reading_no_data(self): self.assertRaises(AssertionError, tr._loop_reading, res) + tr.close = unittest.mock.Mock() tr._read_fut = res tr._loop_reading(res) self.assertFalse(self.loop._proactor.recv.called) self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) def test_loop_reading_aborted(self): err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index c142ff45..c0a1abb8 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -53,6 +53,7 @@ def _loop_reading(self, fut=None): self._protocol.data_received(data) elif data is not None: self._protocol.eof_received() + self.close() def write(self, data): assert isinstance(data, bytes), repr(data) From 79673cb080ca58aac75ece8b48bb585a4d562213 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 15 May 2013 18:16:57 -0700 Subject: [PATCH 0475/1502] Default "make test" to non-verbose. Add vtest to force verbise. --- Makefile | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index bca4ed9b..11fe52ca 100644 --- a/Makefile +++ b/Makefile @@ -2,12 +2,15 @@ PYTHON=python3 VERBOSE=$(V) -V= 1 +V= 0 FLAGS= test: $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + testloop: while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done From a4d9460d3d0bce372ff977a493d38565a2797704 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 17 May 2013 16:40:19 +0100 Subject: [PATCH 0476/1502] Correct handling of EOF for Overlapped.ReadFile(). --- overlapped.c | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/overlapped.c b/overlapped.c index c9f6ec9f..3a2c1208 100644 --- a/overlapped.c +++ b/overlapped.c @@ -248,6 +248,18 @@ overlapped_BindLocal(PyObject *self, PyObject *args) Py_RETURN_NONE; } +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + /* * A Python object wrapping an OVERLAPPED structure and other useful data * for overlapped I/O @@ -484,7 +496,7 @@ Overlapped_ReadFile(OverlappedObject *self, PyObject *args) self->error = err = ret ? ERROR_SUCCESS : GetLastError(); switch (err) { case ERROR_BROKEN_PIPE: - self->type = TYPE_NOT_STARTED; + mark_as_completed(&self->overlapped); Py_RETURN_NONE; case ERROR_SUCCESS: case ERROR_MORE_DATA: @@ -543,7 +555,7 @@ Overlapped_WSARecv(OverlappedObject *self, PyObject *args) self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); switch (err) { case ERROR_BROKEN_PIPE: - self->type = TYPE_NOT_STARTED; + mark_as_completed(&self->overlapped); Py_RETURN_NONE; case ERROR_SUCCESS: case ERROR_MORE_DATA: From 36913258c62499c26830309ebb5271d27e87b11d Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 17 May 2013 16:03:33 -0700 Subject: [PATCH 0477/1502] fix socket leaks in create_connection() --- tests/base_events_test.py | 1 + tests/events_test.py | 4 ++++ tulip/base_events.py | 2 ++ tulip/test_utils.py | 4 ++++ 4 files changed, 11 insertions(+) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index c8397811..9153fe7f 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -426,6 +426,7 @@ def getaddrinfo(*args, **kw): self.loop.run_until_complete(coro) self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(m_socket.socket.return_value.close.called) def test_create_connection_no_local_addr(self): @tasks.task diff --git a/tests/events_test.py b/tests/events_test.py index 87be7736..6e3f9b0d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -436,6 +436,7 @@ def test_create_connection(self): self.assertTrue(isinstance(pr, protocols.Protocol)) self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) + tr.close() def test_create_connection_sock(self): with test_utils.run_test_server(self.loop) as httpd: @@ -463,6 +464,7 @@ def test_create_connection_sock(self): self.assertTrue(isinstance(pr, protocols.Protocol)) self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) + tr.close() @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): @@ -478,6 +480,7 @@ def test_create_ssl_connection(self): hasattr(tr.get_extra_info('socket'), 'getsockname')) self.loop.run_until_complete(pr.done) self.assertTrue(pr.nbytes > 0) + tr.close() def test_create_connection_local_addr(self): with test_utils.run_test_server(self.loop) as httpd: @@ -488,6 +491,7 @@ def test_create_connection_local_addr(self): tr, pr = self.loop.run_until_complete(f) expected = pr.transport.get_extra_info('socket').getsockname()[1] self.assertEqual(port, expected) + tr.close() def test_create_connection_local_addr_in_use(self): with test_utils.run_test_server(self.loop) as httpd: diff --git a/tulip/base_events.py b/tulip/base_events.py index fcc6ce2c..e39807c4 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -297,6 +297,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, laddr, exc.strerror.lower())) exceptions.append(exc) else: + sock.close() + sock = None continue yield from self.sock_connect(sock, address) except socket.error as exc: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index f5854f4e..970a926e 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -117,7 +117,11 @@ def run(loop, fut): fut.set_result, (thread_loop, waiter, socks[0].getsockname())) thread_loop.run_until_complete(waiter) + + for s in socks: + s.close() thread_loop.stop() + thread_loop.close() gc.collect() fut = tulip.Future() From 0256ddcf484969dc989f9f780bd6f3dca34dcba4 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 17 May 2013 20:11:05 -0700 Subject: [PATCH 0478/1502] refactor ssl transport closing --- tests/events_test.py | 2 ++ tests/selector_events_test.py | 28 ++++++++++++++-------------- tulip/selector_events.py | 16 +++++++++++----- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 6e3f9b0d..b4cad7d5 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -599,6 +599,8 @@ def factory(): # recv()/send() on the serving socket client.close() + self.loop.stop_serving(sock) + def test_start_serving_sock(self): proto = futures.Future() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index ea0e0efd..2ac4a428 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -948,7 +948,7 @@ def test_fatal_error(self, m_exc): self.assertTrue(self.loop.remove_writer.called) self.assertTrue(self.loop.remove_reader.called) self.loop.call_soon.assert_called_with( - self.protocol.connection_lost, exc) + transport._call_connection_lost, exc) m_exc.assert_called_with('Fatal error for %s', transport) def test_close(self): @@ -956,16 +956,6 @@ def test_close(self): transport.close() self.assertTrue(transport._closing) self.assertTrue(self.loop.remove_reader.called) - self.loop.call_soon.assert_called_with( - self.protocol.connection_lost, None) - - def test_close_write_buffer1(self): - transport = self._make_one() - transport._buffer.append(b'data') - transport.close() - - self.assertTrue(self.loop.remove_reader.called) - self.assertFalse(self.loop.call_soon.called) def test_on_ready_closed(self): self.sslsock.fileno.return_value = -1 @@ -1052,11 +1042,21 @@ def test_on_ready_send_closing(self): transport = self._make_one() transport.close() transport._buffer = [b'data'] + transport._call_connection_lost = unittest.mock.Mock() transport._on_ready() - self.assertTrue(self.sslsock.close.called) self.assertTrue(self.loop.remove_writer.called) - self.loop.call_soon.assert_called_with( - self.protocol.connection_lost, None) + self.assertTrue(transport._call_connection_lost.called) + + def test_on_ready_send_closing_empty_buffer(self): + self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.send.return_value = 4 + transport = self._make_one() + transport.close() + transport._buffer = [] + transport._call_connection_lost = unittest.mock.Mock() + transport._on_ready() + self.assertTrue(self.loop.remove_writer.called) + self.assertTrue(transport._call_connection_lost.called) def test_on_ready_send_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 20473ea1..f495b8fc 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -526,6 +526,9 @@ def _on_ready(self): # Now try writing, if there's anything to write. if not self._buffer: + if self._closing: + self._loop.remove_writer(self._sslsock.fileno()) + self._call_connection_lost(None) return data = b''.join(self._buffer) @@ -547,8 +550,7 @@ def _on_ready(self): self._buffer.append(data[n:]) elif self._closing: self._loop.remove_writer(self._sslsock.fileno()) - self._sslsock.close() - self._protocol.connection_lost(None) + self._call_connection_lost(None) def write(self, data): assert isinstance(data, bytes), repr(data) @@ -573,8 +575,6 @@ def abort(self): def close(self): self._closing = True self._loop.remove_reader(self._sslsock.fileno()) - if not self._buffer: - self._loop.call_soon(self._protocol.connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) @@ -584,7 +584,13 @@ def _close(self, exc): self._loop.remove_writer(self._sslsock.fileno()) self._loop.remove_reader(self._sslsock.fileno()) self._buffer = [] - self._loop.call_soon(self._protocol.connection_lost, exc) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sslsock.close() class _SelectorDatagramTransport(transports.DatagramTransport): From 7850afc377fa3cede2148034a704e02635264dac Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 17 May 2013 21:32:56 -0700 Subject: [PATCH 0479/1502] make socket transport's abort() and close() idempotent --- tests/proactor_events_test.py | 9 +++++++ tests/selector_events_test.py | 48 +++++++++++++++++++++++++++++++++++ tulip/proactor_events.py | 5 ++++ tulip/selector_events.py | 15 +++++++++++ 4 files changed, 77 insertions(+) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 4a48885d..e6eacfb6 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -157,6 +157,11 @@ def test_abort(self): tr.abort() tr._fatal_error.assert_called_with(None) + tr._fatal_error.reset_mock() + tr._closing = True + tr.abort() + self.assertFalse(tr._fatal_error.called) + def test_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) self.loop.reset_mock() @@ -164,6 +169,10 @@ def test_close(self): self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertTrue(tr._closing) + self.loop.reset_mock() + tr.close() + self.assertFalse(self.loop.call_soon.called) + def test_close_write_fut(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._write_fut = unittest.mock.Mock() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 2ac4a428..e8445f02 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -797,6 +797,25 @@ def test_close(self): self.loop.remove_reader.assert_called_with(7) self.protocol.connection_lost(None) + self.loop.reset_mock() + transport.close() + self.assertFalse(self.loop.remove_reader.called) + + def test__close(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer = [b'1'] + transport._close(None) + + self.assertTrue(transport._closing) + self.assertEqual(transport._buffer, []) + self.loop.remove_reader.assert_called_with(7) + self.loop.remove_writer.assert_called_with(7) + + self.loop.reset_mock() + transport._close(None) + self.assertFalse(self.loop.remove_reader.called) + def test_close_write_buffer(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) @@ -957,6 +976,20 @@ def test_close(self): self.assertTrue(transport._closing) self.assertTrue(self.loop.remove_reader.called) + self.loop.reset_mock() + transport.close() + self.assertFalse(self.loop.remove_reader.called) + + def test__close(self): + transport = self._make_one() + transport._close(None) + self.assertTrue(transport._closing) + self.assertTrue(self.loop.remove_reader.called) + + self.loop.reset_mock() + transport._close(None) + self.assertFalse(self.loop.remove_reader.called) + def test_on_ready_closed(self): self.sslsock.fileno.return_value = -1 transport = self._make_one() @@ -1337,6 +1370,10 @@ def test_close(self): self.loop.call_soon.assert_called_with( transport._call_connection_lost, None) + self.loop.reset_mock() + transport.close() + self.assertFalse(self.loop.remove_reader.called) + def test_close_write_buffer(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) @@ -1346,6 +1383,17 @@ def test_close_write_buffer(self): self.loop.remove_reader.assert_called_with(7) self.assertFalse(self.protocol.connection_lost.called) + def test__close(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._close(None) + self.assertTrue(transport._closing) + self.assertTrue(self.loop.remove_reader.called) + + self.loop.reset_mock() + transport._close(None) + self.assertFalse(self.loop.remove_reader.called) + @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error(self, m_exc): exc = OSError() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index c0a1abb8..cebfbe89 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -91,15 +91,20 @@ def _loop_writing(self, f=None): # TODO: write_eof(), can_write_eof(). def abort(self): + if self._closing: + return self._fatal_error(None) def close(self): + if self._closing: + return self._closing = True if not self._buffer and self._write_fut is None: self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) + self._closing = True if self._write_fut: self._write_fut.cancel() if self._read_fut: # XXX diff --git a/tulip/selector_events.py b/tulip/selector_events.py index f495b8fc..407690d0 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -416,6 +416,8 @@ def abort(self): self._close(None) def close(self): + if self._closing: + return self._closing = True self._loop.remove_reader(self._sock.fileno()) if not self._buffer: @@ -427,6 +429,9 @@ def _fatal_error(self, exc): self._close(exc) def _close(self, exc): + if self._closing: + return + self._closing = True self._loop.remove_writer(self._sock.fileno()) self._loop.remove_reader(self._sock.fileno()) self._buffer.clear() @@ -573,6 +578,8 @@ def abort(self): self._close(None) def close(self): + if self._closing: + return self._closing = True self._loop.remove_reader(self._sslsock.fileno()) @@ -581,6 +588,9 @@ def _fatal_error(self, exc): self._close(exc) def _close(self, exc): + if self._closing: + return + self._closing = True self._loop.remove_writer(self._sslsock.fileno()) self._loop.remove_reader(self._sslsock.fileno()) self._buffer = [] @@ -688,6 +698,8 @@ def abort(self): self._close(None) def close(self): + if self._closing: + return self._closing = True self._loop.remove_reader(self._fileno) if not self._buffer: @@ -698,6 +710,9 @@ def _fatal_error(self, exc): self._close(exc) def _close(self, exc): + if self._closing: + return + self._closing = True self._buffer.clear() self._loop.remove_writer(self._fileno) self._loop.remove_reader(self._fileno) From d7abbce4dca2f91c1c47c7ac8f1c026c1bfeaf56 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 17 May 2013 21:55:55 -0700 Subject: [PATCH 0480/1502] call transport.close() in finally --- tests/selector_events_test.py | 4 ++-- tulip/proactor_events.py | 6 ++++-- tulip/selector_events.py | 13 ++++++++----- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index e8445f02..8d28c405 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -575,14 +575,14 @@ def test_read_ready(self): def test_read_ready_eof(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() self.loop.reset_mock() self.sock.recv.return_value = b'' transport._read_ready() - self.assertTrue(self.loop.remove_reader.called) self.protocol.eof_received.assert_called_with() - self.loop.call_soon.assert_called_with(transport.close) + transport.close.assert_called_with() @unittest.mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index cebfbe89..7f23ac3b 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -52,8 +52,10 @@ def _loop_reading(self, fut=None): if data: self._protocol.data_received(data) elif data is not None: - self._protocol.eof_received() - self.close() + try: + self._protocol.eof_received() + finally: + self.close() def write(self, data): assert isinstance(data, bytes), repr(data) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 407690d0..15124c07 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -352,9 +352,10 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: - self._loop.remove_reader(self._sock.fileno()) - self._loop.call_soon(self.close) - self._protocol.eof_received() + try: + self._protocol.eof_received() + finally: + self.close() def write(self, data): assert isinstance(data, (bytes, bytearray)), repr(data) @@ -526,8 +527,10 @@ def _on_ready(self): if data: self._protocol.data_received(data) else: - self._protocol.eof_received() - self.close() + try: + self._protocol.eof_received() + finally: + self.close() # Now try writing, if there's anything to write. if not self._buffer: From 2a80b98070e5fbfc96dc3b6e4f00ae558d18bdd7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 20 May 2013 13:14:59 -0700 Subject: [PATCH 0481/1502] http server examples fixes #43 --- examples/mpsrv.py | 6 ++++-- examples/srv.py | 6 ++++-- examples/wssrv.py | 7 ++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/mpsrv.py b/examples/mpsrv.py index d6f82d3a..6b1ebb8f 100755 --- a/examples/mpsrv.py +++ b/examples/mpsrv.py @@ -100,7 +100,8 @@ def handle_request(self, message, payload): response.write(b'Cannot open') response.write_eof() - self.close() + if response.keep_alive(): + self.keep_alive(True) class ChildProcess: @@ -121,7 +122,8 @@ def stop(): os._exit(0) loop.add_signal_handler(signal.SIGINT, stop) - f = loop.start_serving(lambda: HttpServer(debug=True), sock=self.sock) + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) x = loop.run_until_complete(f)[0] print('Starting srv worker process {} on {}'.format( os.getpid(), x.getsockname())) diff --git a/examples/srv.py b/examples/srv.py index 2c2ae91e..e01e407c 100755 --- a/examples/srv.py +++ b/examples/srv.py @@ -97,7 +97,8 @@ def handle_request(self, message, payload): response.write(b'Cannot open') response.write_eof() - self.close() + if response.keep_alive(): + self.keep_alive(True) ARGS = argparse.ArgumentParser(description="Run simple http server.") @@ -148,7 +149,8 @@ def main(): loop = tulip.get_event_loop() f = loop.start_serving( - lambda: HttpServer(debug=True), args.host, args.port, ssl=sslcontext) + lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, + ssl=sslcontext) socks = loop.run_until_complete(f) print('serving on', socks[0].getsockname()) try: diff --git a/examples/wssrv.py b/examples/wssrv.py index 2befec1d..aecce9f7 100755 --- a/examples/wssrv.py +++ b/examples/wssrv.py @@ -84,7 +84,6 @@ def handle_request(self, message, payload): for wsc in self.clients: wsc.send(b'Someone disconnected.') - self.close() else: # send html page with js chat response = tulip.http.Response(self.transport, 200) @@ -103,7 +102,8 @@ def handle_request(self, message, payload): response.write(b'Cannot open') response.write_eof() - self.close() + if response.keep_alive(): + self.keep_alive(True) class ChildProcess: @@ -135,7 +135,8 @@ def stop(): def start_server(self, writer): socks = yield from self.loop.start_serving( lambda: HttpServer( - debug=True, parent=writer, clients=self.clients), + debug=True, keep_alive=75, + parent=writer, clients=self.clients), sock=self.sock) print('Starting srv worker process {} on {}'.format( os.getpid(), socks[0].getsockname())) From 15f7b0c003321281c73dea4eafe2d2a84444b9d0 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 20 May 2013 21:16:12 -0700 Subject: [PATCH 0482/1502] socket transports close procedure refactoring --- tests/proactor_events_test.py | 34 ++-- tests/selector_events_test.py | 295 ++++++++++++---------------------- tulip/proactor_events.py | 14 +- tulip/selector_events.py | 268 +++++++++++------------------- 4 files changed, 230 insertions(+), 381 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index e6eacfb6..24205532 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -120,7 +120,7 @@ def test_loop_writing_err(self, m_log): tr._buffer = [b'da', b'ta'] tr._loop_writing() tr._fatal_error.assert_called_with(err) - self.assertEqual(tr._conn_lost, 1) + tr._conn_lost = 1 tr.write(b'data') tr.write(b'data') @@ -153,14 +153,9 @@ def test_loop_writing_closing(self): def test_abort(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._fatal_error = unittest.mock.Mock() - tr.abort() - tr._fatal_error.assert_called_with(None) - - tr._fatal_error.reset_mock() - tr._closing = True + tr._force_close = unittest.mock.Mock() tr.abort() - self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(None) def test_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -168,6 +163,7 @@ def test_close(self): tr.close() self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) self.loop.reset_mock() tr.close() @@ -189,22 +185,36 @@ def test_close_buffer(self): @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] read_fut = tr._read_fut = unittest.mock.Mock() write_fut = tr._write_fut = unittest.mock.Mock() - tr._fatal_error(None) + tr._force_close(None) read_fut.cancel.assert_called_with() write_fut.cancel.assert_called_with() self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) - @unittest.mock.patch('tulip.proactor_events.tulip_log') - def test_fatal_error_2(self, m_logging): + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + self.loop.reset_mock() + tr._force_close(None) + self.assertFalse(self.loop.call_soon.called) + + def test_fatal_error_2(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] - tr._fatal_error(None) + tr._force_close(None) self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) self.assertEqual([], tr._buffer) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 8d28c405..f1bae06b 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -14,6 +14,7 @@ from tulip.events import AbstractEventLoop from tulip.protocols import DatagramProtocol, Protocol from tulip.selector_events import BaseSelectorEventLoop +from tulip.selector_events import _SelectorTransport from tulip.selector_events import _SelectorSslTransport from tulip.selector_events import _SelectorSocketTransport from tulip.selector_events import _SelectorDatagramTransport @@ -539,6 +540,83 @@ def test_process_events_write_cancelled(self): self.loop.remove_writer.assert_called_with(1) +class SelectorTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.sock = unittest.mock.Mock(socket.socket) + self.sock.fileno.return_value = 7 + self.protocol = unittest.mock.Mock(Protocol) + + def test_ctor(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.assertIs(tr._loop, self.loop) + self.assertIs(tr._sock, self.sock) + self.assertIs(tr._sock_fd, 7) + + def test_abort(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr.close() + + self.assertTrue(tr._closing) + self.loop.remove_reader.assert_called_with(7) + self.protocol.connection_lost(None) + self.assertEqual(tr._conn_lost, 1) + + self.loop.reset_mock() + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertFalse(self.loop.remove_reader.called) + + def test_close_write_buffer(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.loop.reset_mock() + tr._buffer.append(b'data') + tr.close() + + self.assertTrue(self.loop.remove_reader.called) + self.assertFalse(self.loop.call_soon.called) + + def test_force_close(self): + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._buffer = [b'1'] + tr._force_close(None) + + self.assertTrue(tr._closing) + self.assertEqual(tr._buffer, []) + self.loop.remove_reader.assert_called_with(7) + self.loop.remove_writer.assert_called_with(7) + + self.loop.reset_mock() + tr._force_close(None) + self.assertFalse(self.loop.remove_reader.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + def test_fatal_error(self, m_exc): + exc = OSError() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(exc) + + m_exc.assert_called_with('Fatal error for %s', tr) + tr._force_close.assert_called_with(exc) + + def test_connection_lost(self): + exc = object() + tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr._call_connection_lost(exc) + + self.protocol.connection_lost.assert_called_with(exc) + self.sock.close.assert_called_with() + + class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): @@ -606,14 +684,6 @@ def test_read_ready_err(self, m_exc): transport._fatal_error.assert_called_with(err) - def test_abort(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._close = unittest.mock.Mock() - - transport.abort() - transport._close.assert_called_with(None) - def test_write(self): data = b'data' self.sock.send.return_value = len(data) @@ -690,7 +760,7 @@ def test_write_exception(self, m_log): transport._fatal_error = unittest.mock.Mock() transport.write(data) transport._fatal_error.assert_called_with(err) - self.assertEqual(transport._conn_lost, 1) + transport._conn_lost = 1 self.sock.reset_mock() transport.write(data) @@ -711,7 +781,9 @@ def test_write_closing(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport.close() - self.assertRaises(AssertionError, transport.write, b'data') + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) def test_write_ready(self): data = b'data' @@ -786,69 +858,6 @@ def test_write_ready_exception(self): transport._buffer.append(b'data') transport._write_ready() transport._fatal_error.assert_called_with(err) - self.assertEqual(transport._conn_lost, 1) - - def test_close(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport.close() - - self.assertTrue(transport._closing) - self.loop.remove_reader.assert_called_with(7) - self.protocol.connection_lost(None) - - self.loop.reset_mock() - transport.close() - self.assertFalse(self.loop.remove_reader.called) - - def test__close(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._buffer = [b'1'] - transport._close(None) - - self.assertTrue(transport._closing) - self.assertEqual(transport._buffer, []) - self.loop.remove_reader.assert_called_with(7) - self.loop.remove_writer.assert_called_with(7) - - self.loop.reset_mock() - transport._close(None) - self.assertFalse(self.loop.remove_reader.called) - - def test_close_write_buffer(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - self.loop.reset_mock() - transport._buffer.append(b'data') - transport.close() - - self.assertTrue(self.loop.remove_reader.called) - self.assertFalse(self.loop.call_soon.called) - - @unittest.mock.patch('tulip.log.tulip_log.exception') - def test_fatal_error(self, m_exc): - exc = OSError() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._buffer.append(b'data') - transport._fatal_error(exc) - - self.assertEqual([], transport._buffer) - self.loop.remove_reader.assert_called_with(7) - self.loop.remove_writer.assert_called_with(7) - self.loop.call_soon.assert_called_with( - transport._call_connection_lost, exc) - m_exc.assert_called_with('Fatal error for %s', transport) - - def test_connection_lost(self): - exc = object() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._call_connection_lost(exc) - - self.protocol.connection_lost.assert_called_with(exc) - self.sock.close.assert_called_with() @unittest.skipIf(ssl is None, 'No ssl module') @@ -936,7 +945,9 @@ def test_write_str(self): def test_write_closing(self): transport = self._make_one() transport.close() - self.assertRaises(AssertionError, transport.write, b'data') + self.assertEqual(transport._conn_lost, 1) + transport.write(b'data') + self.assertEqual(transport._conn_lost, 2) @unittest.mock.patch('tulip.selector_events.tulip_log') def test_write_exception(self, m_log): @@ -950,52 +961,6 @@ def test_write_exception(self, m_log): transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') - def test_abort(self): - transport = self._make_one() - transport._close = unittest.mock.Mock() - transport.abort() - transport._close.assert_called_with(None) - - @unittest.mock.patch('tulip.log.tulip_log.exception') - def test_fatal_error(self, m_exc): - exc = OSError() - transport = self._make_one() - transport._buffer.append(b'data') - transport._fatal_error(exc) - - self.assertEqual([], transport._buffer) - self.assertTrue(self.loop.remove_writer.called) - self.assertTrue(self.loop.remove_reader.called) - self.loop.call_soon.assert_called_with( - transport._call_connection_lost, exc) - m_exc.assert_called_with('Fatal error for %s', transport) - - def test_close(self): - transport = self._make_one() - transport.close() - self.assertTrue(transport._closing) - self.assertTrue(self.loop.remove_reader.called) - - self.loop.reset_mock() - transport.close() - self.assertFalse(self.loop.remove_reader.called) - - def test__close(self): - transport = self._make_one() - transport._close(None) - self.assertTrue(transport._closing) - self.assertTrue(self.loop.remove_reader.called) - - self.loop.reset_mock() - transport._close(None) - self.assertFalse(self.loop.remove_reader.called) - - def test_on_ready_closed(self): - self.sslsock.fileno.return_value = -1 - transport = self._make_one() - transport._on_ready() - self.assertFalse(self.sslsock.recv.called) - def test_on_ready_recv(self): self.sslsock.recv.return_value = b'data' transport = self._make_one() @@ -1120,7 +1085,19 @@ def test_on_ready_send_exc(self): transport._on_ready() transport._fatal_error.assert_called_with(err) self.assertEqual([], transport._buffer) - self.assertEqual(transport._conn_lost, 1) + + def test_close(self): + tr = self._make_one() + tr.close() + + self.assertTrue(tr._closing) + self.loop.remove_reader.assert_called_with(1) + self.assertEqual(tr._conn_lost, 1) + + self.loop.reset_mock() + tr.close() + self.assertEqual(tr._conn_lost, 1) + self.assertFalse(self.loop.remove_reader.called) class SelectorDatagramTransportTests(unittest.TestCase): @@ -1161,14 +1138,6 @@ def test_read_ready_err(self): transport._fatal_error.assert_called_with(err) - def test_abort(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - transport._close = unittest.mock.Mock() - - transport.abort() - transport._close.assert_called_with(None) - def test_sendto(self): data = b'data' transport = _SelectorDatagramTransport( @@ -1225,9 +1194,9 @@ def test_sendto_exception(self, m_log): transport._fatal_error = unittest.mock.Mock() transport.sendto(data, ()) - self.assertEqual(transport._conn_lost, 1) self.assertTrue(transport._fatal_error.called) transport._fatal_error.assert_called_with(err) + transport._conn_lost = 1 transport._address = ('123',) transport.sendto(data) @@ -1260,7 +1229,6 @@ def test_sendto_connection_refused_connected(self): transport._fatal_error = unittest.mock.Mock() transport.sendto(data) - self.assertEqual(transport._conn_lost, 1) self.assertTrue(transport._fatal_error.called) def test_sendto_str(self): @@ -1276,9 +1244,11 @@ def test_sendto_connected_addr(self): def test_sendto_closing(self): transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + self.loop, self.sock, self.protocol, address=(1,)) transport.close() - self.assertRaises(AssertionError, transport.sendto, b'data', ()) + self.assertEqual(transport._conn_lost, 1) + transport.sendto(b'data', (1,)) + self.assertEqual(transport._conn_lost, 2) def test_sendto_ready(self): data = b'data' @@ -1360,56 +1330,6 @@ def test_sendto_ready_connection_refused_connection(self): self.assertTrue(transport._fatal_error.called) - def test_close(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - transport.close() - - self.assertTrue(transport._closing) - self.loop.remove_reader.assert_called_with(7) - self.loop.call_soon.assert_called_with( - transport._call_connection_lost, None) - - self.loop.reset_mock() - transport.close() - self.assertFalse(self.loop.remove_reader.called) - - def test_close_write_buffer(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - transport._buffer.append((b'data', ())) - transport.close() - - self.loop.remove_reader.assert_called_with(7) - self.assertFalse(self.protocol.connection_lost.called) - - def test__close(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - transport._close(None) - self.assertTrue(transport._closing) - self.assertTrue(self.loop.remove_reader.called) - - self.loop.reset_mock() - transport._close(None) - self.assertFalse(self.loop.remove_reader.called) - - @unittest.mock.patch('tulip.log.tulip_log.exception') - def test_fatal_error(self, m_exc): - exc = OSError() - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - self.loop.reset_mock() - transport._buffer.append((b'data', ())) - transport._fatal_error(exc) - - self.assertEqual([], list(transport._buffer)) - self.loop.remove_writer.assert_called_with(7) - self.loop.remove_reader.assert_called_with(7) - self.loop.call_soon.assert_called_with( - transport._call_connection_lost, exc) - m_exc.assert_called_with('Fatal error for %s', transport) - @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( @@ -1418,12 +1338,3 @@ def test_fatal_error_connected(self, m_exc): transport._fatal_error(err) self.protocol.connection_refused.assert_called_with(err) m_exc.assert_called_with('Fatal error for %s', transport) - - def test_transport_closing(self): - exc = object() - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) - transport._call_connection_lost(exc) - - self.protocol.connection_lost.assert_called_with(exc) - self.sock.close.assert_called_with() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 7f23ac3b..0791639b 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -59,9 +59,9 @@ def _loop_reading(self, fut=None): def write(self, data): assert isinstance(data, bytes), repr(data) - assert not self._closing if not data: return + if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: tulip_log.warning('socket.send() raised exception.') @@ -85,7 +85,6 @@ def _loop_writing(self, f=None): return self._write_fut = self._loop._proactor.send(self._sock, data) except OSError as exc: - self._conn_lost += 1 self._fatal_error(exc) else: self._write_fut.add_done_callback(self._loop_writing) @@ -93,20 +92,25 @@ def _loop_writing(self, f=None): # TODO: write_eof(), can_write_eof(). def abort(self): - if self._closing: - return - self._fatal_error(None) + self._force_close(None) def close(self): if self._closing: return self._closing = True + self._conn_lost += 1 if not self._buffer and self._write_fut is None: self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return self._closing = True + self._conn_lost += 1 if self._write_fut: self._write_fut.cancel() if self._read_fut: # XXX diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 15124c07..fc60cf2d 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -325,18 +325,60 @@ def stop_serving(self, sock): sock.close() -class _SelectorSocketTransport(transports.Transport): +class _SelectorTransport(transports.Transport): - def __init__(self, loop, sock, protocol, waiter=None, extra=None): + def __init__(self, loop, sock, protocol, extra): super().__init__(extra) self._extra['socket'] = sock self._loop = loop self._sock = sock + self._sock_fd = sock.fileno() self._protocol = protocol self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() called. - self._loop.add_reader(self._sock.fileno(), self._read_ready) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_writer(self._sock_fd) + self._loop.remove_reader(self._sock_fd) + self._buffer.clear() + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -358,8 +400,7 @@ def _read_ready(self): self.close() def write(self, data): - assert isinstance(data, (bytes, bytearray)), repr(data) - assert not self._closing + assert isinstance(data, bytes), repr(data) if not data: return @@ -376,7 +417,6 @@ def write(self, data): except (BlockingIOError, InterruptedError): n = 0 except socket.error as exc: - self._conn_lost += 1 self._fatal_error(exc) return @@ -384,13 +424,13 @@ def write(self, data): return elif n: data = data[n:] - self._loop.add_writer(self._sock.fileno(), self._write_ready) + self._loop.add_writer(self._sock_fd, self._write_ready) self._buffer.append(data) def _write_ready(self): data = b''.join(self._buffer) - assert data, "Data should not be empty" + assert data, 'Data should not be empty' self._buffer.clear() try: @@ -398,7 +438,6 @@ def _write_ready(self): except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: - self._conn_lost += 1 self._fatal_error(exc) else: if n == len(data): @@ -411,91 +450,51 @@ def _write_ready(self): self._buffer.append(data) # Try again later. - # TODO: write_eof(), can_write_eof(). - def abort(self): - self._close(None) - - def close(self): - if self._closing: - return - self._closing = True - self._loop.remove_reader(self._sock.fileno()) - if not self._buffer: - self._loop.call_soon(self._call_connection_lost, None) - - def _fatal_error(self, exc): - # should be called from exception handler only - tulip_log.exception('Fatal error for %s', self) - self._close(exc) - - def _close(self, exc): - if self._closing: - return - self._closing = True - self._loop.remove_writer(self._sock.fileno()) - self._loop.remove_reader(self._sock.fileno()) - self._buffer.clear() - self._loop.call_soon(self._call_connection_lost, exc) - - def _call_connection_lost(self, exc): - try: - self._protocol.connection_lost(exc) - finally: - self._sock.close() - - -class _SelectorSslTransport(transports.Transport): +class _SelectorSslTransport(_SelectorTransport): def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, server_side=False, extra=None): - super().__init__(extra) - - self._loop = loop - self._rawsock = rawsock - self._protocol = protocol if server_side: - assert isinstance(sslcontext, - ssl.SSLContext), 'Must pass an SSLContext' + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' else: # Client-side may pass ssl=True to use a default context. sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) - self._sslcontext = sslcontext - self._waiter = waiter sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, do_handshake_on_connect=False) - self._sslsock = sslsock - self._buffer = [] - self._conn_lost = 0 - self._closing = False # Set when close() called. - self._extra['socket'] = sslsock + + super().__init__(loop, sslsock, protocol, extra) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext self._on_handshake() def _on_handshake(self): - fd = self._sslsock.fileno() try: - self._sslsock.do_handshake() + self._sock.do_handshake() except ssl.SSLWantReadError: - self._loop.add_reader(fd, self._on_handshake) + self._loop.add_reader(self._sock_fd, self._on_handshake) return except ssl.SSLWantWriteError: - self._loop.add_writer(fd, self._on_handshake) + self._loop.add_writer(self._sock_fd, self._on_handshake) return except Exception as exc: - self._sslsock.close() + self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) return except BaseException as exc: - self._sslsock.close() + self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) raise - self._loop.remove_reader(fd) - self._loop.remove_writer(fd) - self._loop.add_reader(fd, self._on_ready) - self._loop.add_writer(fd, self._on_ready) + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) @@ -506,20 +505,12 @@ def _on_ready(self): # incorrect; we probably need to keep state about what we # should do next. - # Maybe we're already closed... - fd = self._sslsock.fileno() - if fd < 0: - return - # First try reading. if not self._closing: try: - data = self._sslsock.recv(8192) - except ssl.SSLWantReadError: - pass - except ssl.SSLWantWriteError: - pass - except (BlockingIOError, InterruptedError): + data = self._sock.recv(8192) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): pass except Exception as exc: self._fatal_error(exc) @@ -533,36 +524,27 @@ def _on_ready(self): self.close() # Now try writing, if there's anything to write. - if not self._buffer: - if self._closing: - self._loop.remove_writer(self._sslsock.fileno()) - self._call_connection_lost(None) - return + if self._buffer: + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return - data = b''.join(self._buffer) - self._buffer = [] - try: - n = self._sslsock.send(data) - except ssl.SSLWantReadError: - n = 0 - except ssl.SSLWantWriteError: - n = 0 - except (BlockingIOError, InterruptedError): - n = 0 - except Exception as exc: - self._conn_lost += 1 - self._fatal_error(exc) - return + if n < len(data): + self._buffer.append(data[n:]) - if n < len(data): - self._buffer.append(data[n:]) - elif self._closing: - self._loop.remove_writer(self._sslsock.fileno()) + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) self._call_connection_lost(None) def write(self, data): assert isinstance(data, bytes), repr(data) - assert not self._closing if not data: return @@ -575,53 +557,26 @@ def write(self, data): self._buffer.append(data) # We could optimize, but the callback can do this for now. - # TODO: write_eof(), can_write_eof(). - - def abort(self): - self._close(None) - def close(self): if self._closing: return self._closing = True - self._loop.remove_reader(self._sslsock.fileno()) - - def _fatal_error(self, exc): - tulip_log.exception('Fatal error for %s', self) - self._close(exc) + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) - def _close(self, exc): - if self._closing: - return - self._closing = True - self._loop.remove_writer(self._sslsock.fileno()) - self._loop.remove_reader(self._sslsock.fileno()) - self._buffer = [] - self._loop.call_soon(self._call_connection_lost, exc) - - def _call_connection_lost(self, exc): - try: - self._protocol.connection_lost(exc) - finally: - self._sslsock.close() + # TODO: write_eof(), can_write_eof(). -class _SelectorDatagramTransport(transports.DatagramTransport): +class _SelectorDatagramTransport(_SelectorTransport): max_size = 256 * 1024 # max bytes we read in one eventloop iteration def __init__(self, loop, sock, protocol, address=None, extra=None): - super().__init__(extra) - self._extra['socket'] = sock - self._loop = loop - self._sock = sock - self._fileno = sock.fileno() - self._protocol = protocol + super().__init__(loop, sock, protocol, extra) + self._address = address self._buffer = collections.deque() - self._conn_lost = 0 - self._closing = False # Set when close() called. - self._loop.add_reader(self._fileno, self._read_ready) + self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) def _read_ready(self): @@ -635,8 +590,7 @@ def _read_ready(self): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - assert isinstance(data, bytes) - assert not self._closing + assert isinstance(data, bytes), repr(data) if not data: return @@ -659,13 +613,11 @@ def sendto(self, data, addr=None): return except ConnectionRefusedError as exc: if self._address: - self._conn_lost += 1 self._fatal_error(exc) return except (BlockingIOError, InterruptedError): - self._loop.add_writer(self._fileno, self._sendto_ready) + self._loop.add_writer(self._sock_fd, self._sendto_ready) except Exception as exc: - self._conn_lost += 1 self._fatal_error(exc) return @@ -681,50 +633,22 @@ def _sendto_ready(self): self._sock.sendto(data, addr) except ConnectionRefusedError as exc: if self._address: - self._conn_lost += 1 self._fatal_error(exc) return except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break except Exception as exc: - self._conn_lost += 1 self._fatal_error(exc) return if not self._buffer: - self._loop.remove_writer(self._fileno) + self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) - def abort(self): - self._close(None) - - def close(self): - if self._closing: - return - self._closing = True - self._loop.remove_reader(self._fileno) - if not self._buffer: - self._loop.call_soon(self._call_connection_lost, None) - - def _fatal_error(self, exc): - tulip_log.exception('Fatal error for %s', self) - self._close(exc) - - def _close(self, exc): - if self._closing: - return - self._closing = True - self._buffer.clear() - self._loop.remove_writer(self._fileno) - self._loop.remove_reader(self._fileno) + def _force_close(self, exc): if self._address and isinstance(exc, ConnectionRefusedError): self._protocol.connection_refused(exc) - self._loop.call_soon(self._call_connection_lost, exc) - def _call_connection_lost(self, exc): - try: - self._protocol.connection_lost(exc) - finally: - self._sock.close() + super()._force_close(exc) From 5c3a4d596b4077753532a800b9eef1005a9cbda5 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 21 May 2013 10:22:43 -0700 Subject: [PATCH 0483/1502] fix unclosed sockets in tests --- tests/base_events_test.py | 1 + tests/events_test.py | 37 ++++++++++++++++++++++------ tests/http_client_functional_test.py | 3 +++ tests/http_session_test.py | 3 +++ tests/queues_test.py | 1 + tests/tasks_test.py | 1 + tests/unix_events_test.py | 4 +++ tulip/test_utils.py | 12 +++++++++ tulip/unix_events.py | 2 +- 9 files changed, 56 insertions(+), 8 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 9153fe7f..599b29f0 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -19,6 +19,7 @@ def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = unittest.mock.Mock() self.loop._selector.registered_count.return_value = 1 + events.set_event_loop(self.loop) def test_not_implemented(self): m = unittest.mock.Mock() diff --git a/tests/events_test.py b/tests/events_test.py index b4cad7d5..d23548e1 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -146,18 +146,25 @@ def setUp(self): events.set_event_loop(self.loop) def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_once(self.loop) + self.loop.close() gc.collect() super().tearDown() def test_run_nesting(self): @tasks.coroutine - def coro(): + def coro1(): + yield + + @tasks.coroutine + def coro2(): self.assertTrue(self.loop.is_running()) - self.loop.run_until_complete(tasks.sleep(0.1)) + self.loop.run_until_complete(coro1()) self.assertRaises( - RuntimeError, self.loop.run_until_complete, coro()) + RuntimeError, self.loop.run_until_complete, coro2()) def test_run_once_nesting(self): @tasks.coroutine @@ -545,6 +552,9 @@ def factory(): # recv()/send() on the serving socket client.close() + # close start_serving socks + self.loop.stop_serving(sock) + @unittest.skipIf(ssl is None, 'No ssl module') def test_start_serving_ssl(self): proto = None @@ -599,6 +609,7 @@ def factory(): # recv()/send() on the serving socket client.close() + # stop serving self.loop.stop_serving(sock) def test_start_serving_sock(self): @@ -623,10 +634,14 @@ def connection_made(self, transport): client.connect(('127.0.0.1', port)) client.send(b'xxx') self.loop.run_until_complete(proto) - sock.close() client.close() - def test_start_serving_addrinuse(self): + # wait until connection get closed + test_utils.run_once(self.loop) + + self.loop.stop_serving(sock) + + def test_start_serving_addr_in_use(self): sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) @@ -640,6 +655,8 @@ def test_start_serving_addrinuse(self): self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.loop.stop_serving(sock) + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') def test_start_serving_dual_stack(self): f_proto = futures.Future() @@ -668,11 +685,12 @@ def connection_made(self, transport): client.close() for s in socks: - s.close() + self.loop.stop_serving(s) def test_stop_serving(self): f = self.loop.start_serving(MyProto, '0.0.0.0', 0) - sock = self.loop.run_until_complete(f)[0] + socks = self.loop.run_until_complete(f) + sock = socks[0] host, port = sock.getsockname() client = socket.socket() @@ -685,6 +703,7 @@ def test_stop_serving(self): client = socket.socket() self.assertRaises( ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): @@ -1101,6 +1120,7 @@ def test_get_event_loop(self): self.assertIs(policy._loop, loop) self.assertIs(loop, policy.get_event_loop()) + loop.close() @unittest.mock.patch('tulip.events.threading') def test_get_event_loop_thread(self, m_threading): @@ -1115,6 +1135,7 @@ def test_new_event_loop(self): event_loop = policy.new_event_loop() self.assertIsInstance(event_loop, events.AbstractEventLoop) + event_loop.close() def test_set_event_loop(self): policy = events.DefaultEventLoopPolicy() @@ -1126,6 +1147,8 @@ def test_set_event_loop(self): policy.set_event_loop(event_loop) self.assertIs(event_loop, policy.get_event_loop()) self.assertIsNot(old_event_loop, policy.get_event_loop()) + event_loop.close() + old_event_loop.close() def test_get_event_loop_policy(self): policy = events.get_event_loop_policy() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 430664da..125927f4 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -19,6 +19,9 @@ def setUp(self): tulip.set_event_loop(self.loop) def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_once(self.loop) + self.loop.close() gc.collect() diff --git a/tests/http_session_test.py b/tests/http_session_test.py index 1b86c56d..84f601a6 100644 --- a/tests/http_session_test.py +++ b/tests/http_session_test.py @@ -21,6 +21,9 @@ def setUp(self): self.stream = tulip.StreamBuffer() self.response = HttpResponse('get', 'http://python.org') + def tearDown(self): + self.loop.close() + def test_del(self): session = Session() close = session.close = unittest.mock.Mock() diff --git a/tests/queues_test.py b/tests/queues_test.py index c47d1f74..6227f3d0 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -1,6 +1,7 @@ """Tests for queues.py""" import unittest +import unittest.mock import queue from tulip import events diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 3ccaa92b..2c175592 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -40,6 +40,7 @@ def notmuch(): loop = events.new_event_loop() t = tasks.Task(notmuch(), loop=loop) self.assertIs(t._loop, loop) + loop.close() def test_task_decorator(self): @tasks.task diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index d5f72b9d..e5304f71 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -21,6 +21,10 @@ class SelectorEventLoopTests(unittest.TestCase): def setUp(self): self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() def test_check_signal(self): self.assertRaises( diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 970a926e..3eacbade 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -42,6 +42,7 @@ def once(): def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False, router=None): properties = {} + transports = [] class HttpServer: @@ -64,6 +65,10 @@ def url(self, *suffix): class TestHttpServer(tulip.http.ServerHttpProtocol): + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + def handle_request(self, message, payload): if properties.get('close', False): return @@ -118,8 +123,15 @@ def run(loop, fut): thread_loop.run_until_complete(waiter) + # close opened trnsports + for tr in transports: + tr.close() + for s in socks: s.close() + + run_once(thread_loop) # call close callbacks + thread_loop.stop() thread_loop.close() gc.collect() diff --git a/tulip/unix_events.py b/tulip/unix_events.py index a8e073af..b5825950 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -253,7 +253,7 @@ def write(self, data): def _write_ready(self): data = b''.join(self._buffer) - assert data, "Data should not be empty" + assert data, 'Data should not be empty' self._buffer.clear() try: From aa7dc16fb988c8fa17f3948c28d8b574f8065297 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 21 May 2013 11:04:46 -0700 Subject: [PATCH 0484/1502] Add async() function, by Aaron Griffith. --- tests/base_events_test.py | 4 +-- tests/tasks_test.py | 51 +++++++++++++++++++++++++++++++++++++++ tulip/base_events.py | 7 +----- tulip/futures.py | 2 ++ tulip/tasks.py | 25 +++++++++++++++---- 5 files changed, 76 insertions(+), 13 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 599b29f0..4e820ad8 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -251,9 +251,9 @@ def cb(loop): self.assertTrue(processed) self.assertEqual([handle], list(self.loop._ready)) - def test_run_until_complete_assertion(self): + def test_run_until_complete_type_error(self): self.assertRaises( - AssertionError, self.loop.run_until_complete, 'blah') + TypeError, self.loop.run_until_complete, 'blah') class MyProto(protocols.Protocol): diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 2c175592..114c52cd 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -72,6 +72,57 @@ def notmuch(): self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') + + def test_async_coroutine(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.async(notmuch()) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + + def test_async_future(self): + f_orig = futures.Future() + f_orig.set_result('ko') + + f = tasks.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + with self.assertRaises(ValueError): + loop = events.new_event_loop() + f = tasks.async(f_orig, loop=loop) + f = tasks.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t_orig = tasks.Task(notmuch()) + t = tasks.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + with self.assertRaises(ValueError): + loop = events.new_event_loop() + t = tasks.async(t_orig, loop=loop) + t = tasks.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + tasks.async('ok') def test_task_repr(self): @tasks.task diff --git a/tulip/base_events.py b/tulip/base_events.py index e39807c4..57c386b1 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -132,12 +132,7 @@ def run_until_complete(self, future, timeout=None): Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ - if not isinstance(future, futures.Future): - if tasks.iscoroutine(future): - future = tasks.Task(future) - else: - assert False, 'A Future or coroutine is required' - + future = tasks.async(future) handle_called = False def stop_loop(): diff --git a/tulip/futures.py b/tulip/futures.py index 14edbc99..142004aa 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -129,6 +129,7 @@ class Future: _state = _PENDING _result = None _exception = None + _timeout = None _timeout_handle = None _loop = None @@ -150,6 +151,7 @@ def __init__(self, *, loop=None, timeout=None): self._callbacks = [] if timeout is not None: + self._timeout = timeout self._timeout_handle = self._loop.call_later(timeout, self.cancel) def __repr__(self): diff --git a/tulip/tasks.py b/tulip/tasks.py index 54c594ba..1b9317c9 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'task', 'Task', +__all__ = ['coroutine', 'task', 'async', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'as_completed', 'sleep', ] @@ -65,6 +65,24 @@ def task_wrapper(*args, **kwds): return task_wrapper +def async(coro_or_future, *, loop=None, timeout=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if ((loop != None and loop != coro_or_future._loop) or + (timeout != None and timeout != coro_or_future._timeout)): + raise ValueError( + 'loop and timeout arguments must agree with Future') + + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop, timeout=timeout) + else: + raise TypeError('A Future or coroutine is required') + + _marker = object() @@ -314,10 +332,7 @@ def _wrap_coroutines(fs): """ wrapped = set() for f in fs: - if not isinstance(f, futures.Future): - assert iscoroutine(f) - f = Task(f) - wrapped.add(f) + wrapped.add(async(f)) return wrapped From 5778effa924cdd2202d9631a533a40cff55a8356 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Wed, 22 May 2013 11:53:15 +0100 Subject: [PATCH 0485/1502] Add examples/windows_subprocess.py. --- examples/windows_subprocess.py | 249 +++++++++++++++++++++++++++++++++ tulip/proactor_events.py | 12 +- tulip/windows_events.py | 12 +- 3 files changed, 267 insertions(+), 6 deletions(-) create mode 100644 examples/windows_subprocess.py diff --git a/examples/windows_subprocess.py b/examples/windows_subprocess.py new file mode 100644 index 00000000..a839613b --- /dev/null +++ b/examples/windows_subprocess.py @@ -0,0 +1,249 @@ +""" +Example of asynchronous interaction with a subprocess on Windows. + +This requires use of overlapped pipe handles and (a modified) iocp proactor. +""" + +import itertools +import msvcrt +import os +import subprocess +import sys +import tempfile +import _winapi + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import _overlapped +from tulip import windows_events +from tulip import streams +from tulip import protocols + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter=itertools.count() + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) + +# +# Return a write-only transport wrapping a writable pipe +# + +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + return loop._make_socket_transport(file, protocol, write_only=True) + +# +# Wrap a readable pipe in a stream +# + +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader() + protocol = _StreamReaderProtocol(stream_reader) + transport = loop._make_socket_transport(file, protocol) + return stream_reader + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + def data_received(self, data): + self.stream_reader.feed_data(data) + def eof_received(self): + self.stream_reader.feed_eof() + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + stdin = connect_write_pipe(p.stdin) + stdout = connect_read_pipe(p.stdout) + stderr = connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while True: + done, pending = yield from tulip.wait( + registered, timeout, tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + +if __name__ == '__main__': + loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(loop) + loop.run_until_complete(main(loop)) + loop.close() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 0791639b..f0eed149 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -12,7 +12,8 @@ class _ProactorSocketTransport(transports.Transport): - def __init__(self, loop, sock, protocol, waiter=None, extra=None): + def __init__(self, loop, sock, protocol, waiter=None, extra=None, + write_only=False): super().__init__(extra) self._extra['socket'] = sock self._loop = loop @@ -24,7 +25,8 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None): self._conn_lost = 0 self._closing = False # Set when close() called. self._loop.call_soon(self._protocol.connection_made, self) - self._loop.call_soon(self._loop_reading) + if not write_only: + self._loop.call_soon(self._loop_reading) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -135,8 +137,10 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): - return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, + write_only=False): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra, + write_only) def close(self): if self._proactor is not None: diff --git a/tulip/windows_events.py b/tulip/windows_events.py index cd6c61af..0a2f07e0 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -60,13 +60,21 @@ def select(self, timeout=None): def recv(self, conn, nbytes, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) - ov.WSARecv(conn.fileno(), nbytes, flags) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(handle, nbytes) return self._register(ov, conn, ov.getresult) def send(self, conn, buf, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) - ov.WSASend(conn.fileno(), buf, flags) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(handle, buf) return self._register(ov, conn, ov.getresult) def accept(self, listener): From 1a6f1672a0e2727cc421fd1a32c89c17980c0757 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 22 May 2013 14:57:15 -0700 Subject: [PATCH 0486/1502] Output pause/resume/discard by Jeff Quast (improved by us). --- tests/selector_events_test.py | 71 ++++++++++++++++++++++++++++++++++- tests/transports_test.py | 3 ++ tulip/selector_events.py | 25 +++++++++++- tulip/transports.py | 15 ++++++++ 4 files changed, 111 insertions(+), 3 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index f1bae06b..656178af 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -622,7 +622,7 @@ class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) self.sock = unittest.mock.Mock(socket.socket) - self.sock.fileno.return_value = 7 + self.sock_fd = self.sock.fileno.return_value = 7 self.protocol = unittest.mock.Mock(Protocol) def test_ctor(self): @@ -631,6 +631,7 @@ def test_ctor(self): self.loop.add_reader.assert_called_with(7, tr._read_ready) self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) + self.assertTrue(tr._writing) def test_ctor_with_waiter(self): fut = futures.Future() @@ -709,6 +710,14 @@ def test_write_buffer(self): self.assertFalse(self.sock.send.called) self.assertEqual([b'data1', b'data2'], transport._buffer) + def test_write_paused(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._writing = False + transport.write(b'data') + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._buffer, [b'data']) + def test_write_partial(self): data = b'data' self.sock.send.return_value = 2 @@ -797,6 +806,15 @@ def test_write_ready(self): self.assertEqual(self.sock.send.call_args[0], (data,)) self.assertTrue(self.loop.remove_writer.called) + def test_write_ready_paused(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._writing = False + transport._buffer.append(b'data') + transport._write_ready() + self.assertFalse(self.sock.send.called) + self.assertEqual(transport._buffer, [b'data']) + def test_write_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) @@ -859,6 +877,57 @@ def test_write_ready_exception(self): transport._write_ready() transport._fatal_error.assert_called_with(err) + def test_pause_writing(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.pause_writing() + self.assertFalse(transport._writing) + self.loop.remove_writer.assert_called_with(self.sock_fd) + + self.loop.reset_mock() + transport.pause_writing() + self.assertFalse(self.loop.remove_writer.called) + + def test_pause_writing_no_buffer(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.pause_writing() + self.assertFalse(transport._writing) + self.assertFalse(self.loop.remove_writer.called) + + def test_resume_writing(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append(b'data') + transport.resume_writing() + self.assertFalse(self.loop.add_writer.called) + + transport._writing = False + transport.resume_writing() + self.assertTrue(transport._writing) + self.loop.add_writer.assert_called_with( + self.sock_fd, transport._write_ready) + + def test_resume_writing_no_buffer(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._writing = False + transport.resume_writing() + self.assertTrue(transport._writing) + self.assertFalse(self.loop.add_writer.called) + + def test_discard_output(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.discard_output() + self.assertFalse(self.loop.remove_writer.called) + + transport._buffer.append(b'data') + transport.discard_output() + self.assertEqual(transport._buffer, []) + self.loop.remove_writer.assert_called_with(self.sock_fd) + @unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(unittest.TestCase): diff --git a/tests/transports_test.py b/tests/transports_test.py index 4b24b50b..b1c932f0 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -37,6 +37,9 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, transport.resume) self.assertRaises(NotImplementedError, transport.close) self.assertRaises(NotImplementedError, transport.abort) + self.assertRaises(NotImplementedError, transport.pause_writing) + self.assertRaises(NotImplementedError, transport.resume_writing) + self.assertRaises(NotImplementedError, transport.discard_output) def test_dgram_not_implemented(self): transport = transports.DatagramTransport() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index fc60cf2d..ec2fd1c9 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -336,6 +336,7 @@ def __init__(self, loop, sock, protocol, extra): self._protocol = protocol self._buffer = [] self._conn_lost = 0 + self._writing = True self._closing = False # Set when close() called. def abort(self): @@ -410,7 +411,7 @@ def write(self, data): self._conn_lost += 1 return - if not self._buffer: + if not self._buffer and self._writing: # Attempt to send it right away first. try: n = self._sock.send(data) @@ -429,6 +430,9 @@ def write(self, data): self._buffer.append(data) def _write_ready(self): + if not self._writing: + return # transmission off + data = b''.join(self._buffer) assert data, 'Data should not be empty' @@ -441,7 +445,7 @@ def _write_ready(self): self._fatal_error(exc) else: if n == len(data): - self._loop.remove_writer(self._sock.fileno()) + self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) return @@ -450,6 +454,23 @@ def _write_ready(self): self._buffer.append(data) # Try again later. + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + class _SelectorSslTransport(_SelectorTransport): diff --git a/tulip/transports.py b/tulip/transports.py index a9ec07a0..2b34bc59 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -79,6 +79,21 @@ def can_write_eof(self): """Return True if this protocol supports write_eof(), False if not.""" raise NotImplementedError + def pause_writing(self): + """Pause transmission on the transport. + + Subsequent writes are deferred until resume_writing() is called. + """ + raise NotImplementedError + + def resume_writing(self): + """Resume transmission on the transport. """ + raise NotImplementedError + + def discard_output(self): + """Discard any buffered data awaiting transmission on the transport.""" + raise NotImplementedError + def abort(self): """Closes the transport immediately. From 1379167d6b722d1b9c5f77c648d1893bf5332f6f Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 23 May 2013 12:22:46 -0700 Subject: [PATCH 0487/1502] added stop_serving() to iocp event loop --- tests/events_test.py | 7 ------- tests/proactor_events_test.py | 20 ++++++++++++++++++++ tests/tasks_test.py | 16 ++++++++-------- tulip/proactor_events.py | 13 +++++++++++++ tulip/test_utils.py | 6 +++--- tulip/windows_events.py | 26 ++++++++++++++++++-------- 6 files changed, 62 insertions(+), 26 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index d23548e1..6273a2cf 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -633,12 +633,8 @@ def connection_made(self, transport): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.loop.run_until_complete(proto) client.close() - # wait until connection get closed - test_utils.run_once(self.loop) - self.loop.stop_serving(sock) def test_start_serving_addr_in_use(self): @@ -868,9 +864,6 @@ def test_writer_callback_cancel(self): def test_create_datagram_endpoint(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_endpoint()") - def test_stop_serving(self): - raise unittest.SkipTest( - "IocpEventLoop does not support stop_serving()") else: from tulip import selectors from tulip import unix_events diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 24205532..468a7ede 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -348,3 +348,23 @@ def test_start_serving(self, m_log): loop(fut) self.assertTrue(self.sock.close.called) self.assertTrue(m_log.exception.called) + + def test_start_serving_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + self.sock.reset_mock() + fut = tulip.Future() + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop.stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor.stop_serving.assert_called_with(sock) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 114c52cd..7d76ac58 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -72,7 +72,7 @@ def notmuch(): self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') - + def test_async_coroutine(self): @tasks.coroutine def notmuch(): @@ -82,27 +82,27 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t._loop, self.loop) - + loop = events.new_event_loop() t = tasks.async(notmuch(), loop=loop) self.assertIs(t._loop, loop) - + def test_async_future(self): f_orig = futures.Future() f_orig.set_result('ko') - + f = tasks.async(f_orig) self.loop.run_until_complete(f) self.assertTrue(f.done()) self.assertEqual(f.result(), 'ko') self.assertIs(f, f_orig) - + with self.assertRaises(ValueError): loop = events.new_event_loop() f = tasks.async(f_orig, loop=loop) f = tasks.async(f_orig, loop=self.loop) self.assertIs(f, f_orig) - + def test_async_task(self): @tasks.coroutine def notmuch(): @@ -113,13 +113,13 @@ def notmuch(): self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t, t_orig) - + with self.assertRaises(ValueError): loop = events.new_event_loop() t = tasks.async(t_orig, loop=loop) t = tasks.async(t_orig, loop=self.loop) self.assertIs(t, t_orig) - + def test_async_neither(self): with self.assertRaises(TypeError): tasks.async('ok') diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index f0eed149..f33075ee 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -6,6 +6,7 @@ from . import base_events from . import constants +from . import futures from . import transports from .log import tulip_log @@ -42,6 +43,12 @@ def _loop_reading(self, fut=None): self._read_fut = None return + # iocp, it is possible to close transport and + # receive data from socket + if self._closing: + data = None + return + self._read_fut = self._loop._proactor.recv(self._sock, 4096) except (ConnectionAbortedError, ConnectionResetError) as exc: if not self._closing: @@ -207,9 +214,15 @@ def loop(f=None): except OSError: sock.close() tulip_log.exception('Accept failed') + except futures.CancelledError: + sock.close() else: f.add_done_callback(loop) self.call_soon(loop) def _process_events(self, event_list): pass # XXX hard work currently done in poll + + def stop_serving(self, sock): + self._proactor.stop_serving(sock) + sock.close() diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 3eacbade..1c192fc2 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -127,11 +127,11 @@ def run(loop, fut): for tr in transports: tr.close() - for s in socks: - s.close() - run_once(thread_loop) # call close callbacks + for s in socks: + thread_loop.stop_serving(s) + thread_loop.stop() thread_loop.close() gc.collect() diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 0a2f07e0..8a789f1f 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -147,16 +147,26 @@ def _poll(self, timeout=None): return address = status[3] f, ov, obj, callback = self._cache.pop(address) - try: - value = callback() - except OSError as e: - f.set_exception(e) - self._results.append(f) - else: - f.set_result(value) - self._results.append(f) + if not f.cancelled(): + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) ms = 0 + def stop_serving(self, obj): + for (f, ov, ob, callback) in self._cache.values(): + if ob is obj: + f.cancel() + try: + ov.cancel() + except OSError: + pass + def close(self): for (f, ov, obj, callback) in self._cache.values(): try: From 6d9e2793a09bbeacc52fc8d107ed9269025d5d38 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 23 May 2013 15:58:35 -0700 Subject: [PATCH 0488/1502] Make _wait() simpler and stupider, and more correct. --- tests/tasks_test.py | 19 +++++++++ tulip/tasks.py | 96 +++++++++++++++++---------------------------- 2 files changed, 55 insertions(+), 60 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 7d76ac58..df5ca2ac 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -351,6 +351,15 @@ def foo(): self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set())) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0)], return_when=-1)) + def test_wait_first_completed(self): a = tasks.Task(tasks.sleep(10.0)) b = tasks.Task(tasks.sleep(0.1)) @@ -464,6 +473,16 @@ def foo(): self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) + def test_wait_concurrent_complete(self): + a = tasks.Task(tasks.sleep(0.1)) + b = tasks.Task(tasks.sleep(0.15)) + + done, pending = self.loop.run_until_complete( + tasks.wait([b, a], timeout=0.1)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + def test_as_completed(self): @tasks.coroutine def sleeper(dt, x): diff --git a/tulip/tasks.py b/tulip/tasks.py index 1b9317c9..1cddde25 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -67,12 +67,12 @@ def task_wrapper(*args, **kwds): def async(coro_or_future, *, loop=None, timeout=None): """Wrap a coroutine in a future. - + If the argument is a Future, it is returned directly. """ if isinstance(coro_or_future, futures.Future): - if ((loop != None and loop != coro_or_future._loop) or - (timeout != None and timeout != coro_or_future._timeout)): + if ((loop is not None and loop is not coro_or_future._loop) or + (timeout is not None and timeout != coro_or_future._timeout)): raise ValueError( 'loop and timeout arguments must agree with Future') @@ -214,9 +214,9 @@ def _wakeup(self, future): ALL_COMPLETED = concurrent.futures.ALL_COMPLETED -# Even though this *is* a @coroutine, we don't mark it as such! +@coroutine def wait(fs, timeout=None, return_when=ALL_COMPLETED): - """Wait for the Futures and and coroutines given by fs to complete. + """Wait for the Futures and coroutines given by fs to complete. Coroutines will be wrapped in Tasks. @@ -229,58 +229,45 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ - fs = _wrap_coroutines(fs) - return _wait(fs, timeout, return_when) + fs = set(map(async, fs)) + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when)) @coroutine -def _wait(fs, timeout=None, return_when=ALL_COMPLETED): - """Internal helper: Like wait() but does not wrap coroutines.""" - done, pending = set(), set() - - errors = 0 - for f in fs: - if f.done(): - done.add(f) - if not f.cancelled() and f.exception() is not None: - errors += 1 - else: - pending.add(f) - - if (not pending or - timeout is not None and timeout <= 0 or - return_when == FIRST_COMPLETED and done or - return_when == FIRST_EXCEPTION and errors): - return done, pending - - # Will always be cancelled eventually. - bail = futures.Future(timeout=timeout) +def _wait(fs, timeout, return_when): + """Internal helper for wait(return_when=FIRST_COMPLETED). - def _on_completion(fut): - pending.remove(fut) - done.add(fut) - if (not pending or + The fs argument must be a set of Futures. + The timeout argument is like for wait(). + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(timeout=timeout) + counter = len(fs) + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or return_when == FIRST_COMPLETED or - (return_when == FIRST_EXCEPTION and - not fut.cancelled() and - fut.exception() is not None)): - bail.cancel() - - for f in pending: - f.remove_done_callback(_on_completion) - - for f in pending: + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + waiter.cancel() + for f in fs: f.add_done_callback(_on_completion) try: - yield from bail + yield from waiter except futures.CancelledError: pass - - really_done = set(f for f in pending if f.done()) - if really_done: - done.update(really_done) - pending.difference_update(really_done) - + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) return done, pending @@ -304,7 +291,7 @@ def as_completed(fs, timeout=None): deadline = time.monotonic() + timeout done = None # Make nonlocal happy. - fs = _wrap_coroutines(fs) + fs = set(map(async, fs)) while fs: if deadline is not None: @@ -325,17 +312,6 @@ def _wait_for_some(): yield f -def _wrap_coroutines(fs): - """Internal helper to process an iterator of Futures and coroutines. - - Returns a set of Futures. - """ - wrapped = set() - for f in fs: - wrapped.add(async(f)) - return wrapped - - @coroutine def sleep(delay, result=None): """Coroutine that completes after a given time (in seconds).""" From 41bb209c56396be6f6887219cb04023545cf34a5 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 24 May 2013 13:30:03 +0100 Subject: [PATCH 0489/1502] Add tiny timeout to run_once() in test to keep Windows happy. --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index 6273a2cf..7bfa8730 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -527,7 +527,7 @@ def factory(): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.loop.run_once() + self.loop.run_once(0.001) self.assertIsInstance(proto, MyProto) self.assertEqual('INITIAL', proto.state) self.loop.run_once() From f563879154262040520d72d8e73a5157b0debcf5 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 24 May 2013 16:06:20 +0100 Subject: [PATCH 0490/1502] Issue 45: Ensure eof_received() is not invoked after close() called. --- tulip/proactor_events.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index f33075ee..41f3f82b 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -37,18 +37,19 @@ def _loop_reading(self, fut=None): try: if fut is not None: assert fut is self._read_fut - + self._read_fut = None data = fut.result() # deliver data later in "finally" clause - if not data: - self._read_fut = None - return - # iocp, it is possible to close transport and - # receive data from socket if self._closing: + # since close() has been called we ignore any read data data = None return + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read self._read_fut = self._loop._proactor.recv(self._sock, 4096) except (ConnectionAbortedError, ConnectionResetError) as exc: if not self._closing: From a0b8b28359dbd8847c6b4edb938c924d9ec5ecab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 24 May 2013 10:29:51 -0700 Subject: [PATCH 0491/1502] Rewrite as_completed() to be more right. --- tests/tasks_test.py | 21 ++++++++++++++++++ tulip/tasks.py | 52 +++++++++++++++++++++++---------------------- 2 files changed, 48 insertions(+), 25 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index df5ca2ac..a65ecc89 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -537,6 +537,27 @@ def foo(): self.assertEqual(res[1][0], 2) self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + def test_as_completed_reverse_wait(self): + a = tasks.sleep(0.05, 'a') + b = tasks.sleep(0.10, 'b') + fs = {a, b} + futs = list(tasks.as_completed(fs)) + self.assertEqual(len(futs), 2) + x = self.loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + y = self.loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + + def test_as_completed_concurrent(self): + a = tasks.sleep(0.05, 'a') + b = tasks.sleep(0.05, 'b') + fs = {a, b} + futs = list(tasks.as_completed(fs)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs) + done, pending = self.loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + def test_sleep(self): @tasks.coroutine def sleeper(dt, arg): diff --git a/tulip/tasks.py b/tulip/tasks.py index 1cddde25..2ef53e1d 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -5,11 +5,12 @@ 'wait', 'as_completed', 'sleep', ] +import collections import concurrent.futures import functools import inspect -import time +from . import events from . import futures @@ -286,30 +287,31 @@ def as_completed(fs, timeout=None): Note: The futures 'f' are not necessarily members of fs. """ - deadline = None - if timeout is not None: - deadline = time.monotonic() + timeout - - done = None # Make nonlocal happy. - fs = set(map(async, fs)) - - while fs: - if deadline is not None: - timeout = deadline - time.monotonic() - - @coroutine - def _wait_for_some(): - nonlocal done, fs - done, fs = yield from _wait(fs, timeout=timeout, - return_when=FIRST_COMPLETED) - if not done: - fs = set() - raise futures.TimeoutError() - return done.pop().result() # May raise. - - yield Task(_wait_for_some()) - for f in done: - yield f + loop = events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait(todo, timeout, FIRST_COMPLETED) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() @coroutine From cb7a8228ff76c581bf52538140e53125e892b48b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 24 May 2013 11:20:21 -0700 Subject: [PATCH 0492/1502] accept loop in wait(), sleep(), as_completed() --- tulip/tasks.py | 66 +++++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 2ef53e1d..09b923dd 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,8 +1,8 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'task', 'async', 'Task', +__all__ = ['coroutine', 'task', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', - 'wait', 'as_completed', 'sleep', + 'wait', 'as_completed', 'sleep', 'async', ] import collections @@ -66,24 +66,6 @@ def task_wrapper(*args, **kwds): return task_wrapper -def async(coro_or_future, *, loop=None, timeout=None): - """Wrap a coroutine in a future. - - If the argument is a Future, it is returned directly. - """ - if isinstance(coro_or_future, futures.Future): - if ((loop is not None and loop is not coro_or_future._loop) or - (timeout is not None and timeout != coro_or_future._timeout)): - raise ValueError( - 'loop and timeout arguments must agree with Future') - - return coro_or_future - elif iscoroutine(coro_or_future): - return Task(coro_or_future, loop=loop, timeout=timeout) - else: - raise TypeError('A Future or coroutine is required') - - _marker = object() @@ -216,7 +198,7 @@ def _wakeup(self, future): @coroutine -def wait(fs, timeout=None, return_when=ALL_COMPLETED): +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): """Wait for the Futures and coroutines given by fs to complete. Coroutines will be wrapped in Tasks. @@ -230,24 +212,28 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ - fs = set(map(async, fs)) if not fs: raise ValueError('Set of coroutines/Futures is empty.') + + loop = loop if loop is not None else events.get_event_loop() + fs = set(async(f, loop=loop) for f in fs) + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): raise ValueError('Invalid return_when value: {}'.format(return_when)) - return (yield from _wait(fs, timeout, return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) @coroutine -def _wait(fs, timeout, return_when): +def _wait(fs, timeout, return_when, loop): """Internal helper for wait(return_when=FIRST_COMPLETED). The fs argument must be a set of Futures. The timeout argument is like for wait(). """ assert fs, 'Set of Futures is empty.' - waiter = futures.Future(timeout=timeout) + waiter = futures.Future(loop=loop, timeout=timeout) counter = len(fs) + def _on_completion(f): nonlocal counter counter -= 1 @@ -256,6 +242,7 @@ def _on_completion(f): return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): waiter.cancel() + for f in fs: f.add_done_callback(_on_completion) try: @@ -273,7 +260,7 @@ def _on_completion(f): # This is *not* a @coroutine! It is just an iterator (yielding Futures). -def as_completed(fs, timeout=None): +def as_completed(fs, *, loop=None, timeout=None): """Return an iterator whose values, when waited for, are Futures. This differs from PEP 3148; the proper way to use this is: @@ -287,7 +274,7 @@ def as_completed(fs, timeout=None): Note: The futures 'f' are not necessarily members of fs. """ - loop = events.get_event_loop() + loop = loop if loop is not None else events.get_event_loop() deadline = None if timeout is None else loop.time() + timeout todo = set(async(f, loop=loop) for f in fs) completed = collections.deque() @@ -300,7 +287,8 @@ def _wait_for_one(): timeout = deadline - loop.time() if timeout < 0: raise futures.TimeoutError() - done, pending = yield from _wait(todo, timeout, FIRST_COMPLETED) + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) # Multiple callers might be waiting for the same events # and getting the same outcome. Dedupe by updating todo. for f in done: @@ -315,11 +303,29 @@ def _wait_for_one(): @coroutine -def sleep(delay, result=None): +def sleep(delay, result=None, *, loop=None): """Coroutine that completes after a given time (in seconds).""" - future = futures.Future() + future = futures.Future(loop=loop) h = future._loop.call_later(delay, future.set_result, result) try: return (yield from future) finally: h.cancel() + + +def async(coro_or_future, *, loop=None, timeout=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if ((loop is not None and loop is not coro_or_future._loop) or + (timeout is not None and timeout != coro_or_future._timeout)): + raise ValueError( + 'loop and timeout arguments must agree with Future') + + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop, timeout=timeout) + else: + raise TypeError('A Future or coroutine is required') From 6a2acd9abe1a24210e29f39b064f298283d49822 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Sat, 25 May 2013 16:09:38 +0100 Subject: [PATCH 0493/1502] Refactor stop_serving() for proactors to be efficient. --- tulip/windows_events.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 8a789f1f..b7e30913 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -46,6 +46,7 @@ def __init__(self, concurrency=0xffffffff): _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) self._cache = {} self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() def registered_count(self): return len(self._cache) @@ -147,7 +148,9 @@ def _poll(self, timeout=None): return address = status[3] f, ov, obj, callback = self._cache.pop(address) - if not f.cancelled(): + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): try: value = callback() except OSError as e: @@ -159,13 +162,10 @@ def _poll(self, timeout=None): ms = 0 def stop_serving(self, obj): - for (f, ov, ob, callback) in self._cache.values(): - if ob is obj: - f.cancel() - try: - ov.cancel() - except OSError: - pass + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop.stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) def close(self): for (f, ov, obj, callback) in self._cache.values(): From 1832c949d10bb1a8cc55b222181ac024b2de24df Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 27 May 2013 13:18:26 +0100 Subject: [PATCH 0494/1502] Cancelling a future wrapping an overlapped op now cancels overlapped op. --- tests/events_test.py | 33 +++++++++++++++++++++++++++++++++ tulip/windows_events.py | 20 +++++++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index 7bfa8730..bd7bb642 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -836,6 +836,39 @@ def connect(): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), timeout=1) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertTrue(elapsed < 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop.stop_serving(r) + + r.close() + w.close() + if sys.platform == 'win32': from tulip import windows_events diff --git a/tulip/windows_events.py b/tulip/windows_events.py index b7e30913..6e61b2c7 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -23,6 +23,24 @@ ERROR_CONNECTION_ABORTED = 1236 +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov): + super().__init__() + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + class SelectorEventLoop(selector_events.BaseSelectorEventLoop): def _socketpair(self): return winsocketpair.socketpair() @@ -124,7 +142,7 @@ def _register_with_iocp(self, obj): _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) def _register(self, ov, obj, callback): - f = futures.Future() + f = _OverlappedFuture(ov) self._cache[ov.address] = (f, ov, obj, callback) return f From cc2fc685a382e175609fe565730fef1d87c443b0 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 28 May 2013 10:04:11 -0700 Subject: [PATCH 0495/1502] handle ConnectionResetError by Jeff Quast #31 --- examples/tcp_echo.py | 3 ++- tests/proactor_events_test.py | 19 +++++++++++++++ tests/selector_events_test.py | 44 +++++++++++++++++++++++++++++++++++ tulip/proactor_events.py | 4 +++- tulip/selector_events.py | 4 ++++ 5 files changed, 72 insertions(+), 2 deletions(-) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index ff40c4ab..39db5cca 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -37,7 +37,7 @@ def eof_received(self): pass def connection_lost(self, exc): - print('connection lost') + print('connection lost:', exc) self.h_timeout.cancel() @@ -60,6 +60,7 @@ def eof_received(self): pass def connection_lost(self, exc): + print('connection lost:', exc) tulip.get_event_loop().stop() diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 468a7ede..14764e49 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -76,6 +76,25 @@ def test_loop_reading_aborted_closing(self): tr._loop_reading() self.assertFalse(tr._fatal_error.called) + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + def test_loop_reading_exception(self): err = self.loop._proactor.recv.side_effect = (OSError()) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 656178af..a33c0ae6 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -201,6 +201,18 @@ def test__sock_sendall_tryagain(self): (10, self.loop._sock_sendall, f, True, sock, b'data'), self.loop.add_writer.call_args[0]) + def test__sock_sendall_interrupted(self): + f = futures.Future() + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.send.side_effect = InterruptedError + + self.loop.add_writer = unittest.mock.Mock() + self.loop._sock_sendall(f, False, sock, b'data') + self.assertEqual( + (10, self.loop._sock_sendall, f, True, sock, b'data'), + self.loop.add_writer.call_args[0]) + def test__sock_sendall_exception(self): f = futures.Future() sock = unittest.mock.Mock() @@ -674,6 +686,27 @@ def test_read_ready_tryagain(self, m_exc): self.assertFalse(transport._fatal_error.called) + @unittest.mock.patch('logging.exception') + def test_read_ready_tryagain_interrupted(self, m_exc): + self.sock.recv.side_effect = InterruptedError + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + + @unittest.mock.patch('logging.exception') + def test_read_ready_conn_reset(self, m_exc): + err = self.sock.recv.side_effect = ConnectionResetError() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport._force_close = unittest.mock.Mock() + transport._read_ready() + transport._force_close.assert_called_with(err) + @unittest.mock.patch('logging.exception') def test_read_ready_err(self, m_exc): err = self.sock.recv.side_effect = OSError() @@ -1045,6 +1078,13 @@ def test_on_ready_recv_eof(self): transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() + def test_on_ready_recv_conn_reset(self): + err = self.sslsock.recv.side_effect = ConnectionResetError() + transport = self._make_one() + transport._force_close = unittest.mock.Mock() + transport._on_ready() + transport._force_close.assert_called_with(err) + def test_on_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError transport = self._make_one() @@ -1060,6 +1100,10 @@ def test_on_ready_recv_retry(self): transport._on_ready() self.assertFalse(self.protocol.data_received.called) + self.sslsock.recv.side_effect = InterruptedError + transport._on_ready() + self.assertFalse(self.protocol.data_received.called) + def test_on_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 41f3f82b..41de8d42 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -51,9 +51,11 @@ def _loop_reading(self, fut=None): # reschedule a new read self._read_fut = self._loop._proactor.recv(self._sock, 4096) - except (ConnectionAbortedError, ConnectionResetError) as exc: + except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) except OSError as exc: self._fatal_error(exc) else: diff --git a/tulip/selector_events.py b/tulip/selector_events.py index ec2fd1c9..cdd04623 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -389,6 +389,8 @@ def _read_ready(self): data = self._sock.recv(16*1024) except (BlockingIOError, InterruptedError): pass + except ConnectionResetError as exc: + self._force_close(exc) except Exception as exc: self._fatal_error(exc) else: @@ -533,6 +535,8 @@ def _on_ready(self): except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): pass + except ConnectionResetError as exc: + self._force_close(exc) except Exception as exc: self._fatal_error(exc) else: From 73e91d3322bd0700cd88d66bd3de75f7d53029d9 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 30 May 2013 14:09:00 +0100 Subject: [PATCH 0496/1502] Move various Windows specific things to a windows_utils submodule. Also rename examples/windows_subprocess.py to examples/child_process.py and make it work on Unix. --- examples/child_process.py | 129 ++++++++++++++ tests/windows_utils_test.py | 132 ++++++++++++++ tests/winsocketpair_test.py | 26 --- tulip/proactor_events.py | 13 +- tulip/test_utils.py | 2 +- tulip/windows_events.py | 6 +- .../windows_utils.py | 166 ++++++------------ tulip/winsocketpair.py | 34 ---- 8 files changed, 324 insertions(+), 184 deletions(-) create mode 100644 examples/child_process.py create mode 100644 tests/windows_utils_test.py delete mode 100644 tests/winsocketpair_test.py rename examples/windows_subprocess.py => tulip/windows_utils.py (50%) delete mode 100644 tulip/winsocketpair.py diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..e21a925a --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,129 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + return loop._make_write_pipe_transport(file, protocol) + +# +# Wrap a readable pipe in a stream +# + +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader() + protocol = _StreamReaderProtocol(stream_reader) + transport = loop._make_read_pipe_transport(file, protocol) + return stream_reader + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + def data_received(self, data): + self.stream_reader.feed_data(data) + def eof_received(self): + self.stream_reader.feed_eof() + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + stdin = connect_write_pipe(p.stdin) + stdout = connect_read_pipe(p.stdout) + stderr = connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..d94716a2 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for winsocketpair.py""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils +from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertTrue(len(out) > 0) + self.assertTrue(len(err) > 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tests/winsocketpair_test.py b/tests/winsocketpair_test.py deleted file mode 100644 index 381fb227..00000000 --- a/tests/winsocketpair_test.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Tests for winsocketpair.py""" - -import unittest -import unittest.mock - -from tulip import winsocketpair - - -class WinsocketpairTests(unittest.TestCase): - - def test_winsocketpair(self): - ssock, csock = winsocketpair.socketpair() - - csock.send(b'xxx') - self.assertEqual(b'xxx', ssock.recv(1024)) - - csock.close() - ssock.close() - - @unittest.mock.patch('tulip.winsocketpair.socket') - def test_winsocketpair_exc(self, m_socket): - m_socket.socket.return_value.getsockname.return_value = ('', 12345) - m_socket.socket.return_value.accept.return_value = object(), object() - m_socket.socket.return_value.connect.side_effect = OSError() - - self.assertRaises(OSError, winsocketpair.socketpair) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 41de8d42..c04e924f 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -147,10 +147,17 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, - write_only=False): + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): return _ProactorSocketTransport(self, sock, protocol, waiter, extra, - write_only) + write_only=True) def close(self): if self._proactor is not None: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 1c192fc2..6757d93c 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -26,7 +26,7 @@ if sys.platform == 'win32': # pragma: no cover - from .winsocketpair import socketpair + from .windows_utils import socketpair else: from socket import socketpair # pragma: no cover diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 6e61b2c7..08f221e2 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -9,7 +9,7 @@ from . import futures from . import proactor_events from . import selector_events -from . import winsocketpair +from . import windows_utils from . import _overlapped from .log import tulip_log @@ -43,7 +43,7 @@ def cancel(self): class SelectorEventLoop(selector_events.BaseSelectorEventLoop): def _socketpair(self): - return winsocketpair.socketpair() + return windows_utils.socketpair() class ProactorEventLoop(proactor_events.BaseProactorEventLoop): @@ -53,7 +53,7 @@ def __init__(self, proactor=None): super().__init__(proactor) def _socketpair(self): - return winsocketpair.socketpair() + return windows_utils.socketpair() class IocpProactor: diff --git a/examples/windows_subprocess.py b/tulip/windows_utils.py similarity index 50% rename from examples/windows_subprocess.py rename to tulip/windows_utils.py index a839613b..bf85f31e 100644 --- a/examples/windows_subprocess.py +++ b/tulip/windows_utils.py @@ -1,28 +1,22 @@ """ -Example of asynchronous interaction with a subprocess on Windows. - -This requires use of overlapped pipe handles and (a modified) iocp proactor. +Various Windows specific bits and pieces """ +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket import itertools import msvcrt import os import subprocess -import sys import tempfile import _winapi -try: - import tulip -except ImportError: - # tulip is not installed - sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - import tulip -from tulip import _overlapped -from tulip import windows_events -from tulip import streams -from tulip import protocols +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] # # Constants/globals @@ -32,11 +26,42 @@ PIPE = subprocess.PIPE _mmap_counter=itertools.count() +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + # # Replacement for os.pipe() using handles instead of fds # def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % (os.getpid(), next(_mmap_counter))) @@ -84,6 +109,10 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): # class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ def __init__(self, handle): self._handle = handle @@ -112,6 +141,10 @@ def __exit__(self, t, v, tb): # class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): stdin_rfd = stdout_wfd = stderr_wfd = None stdin_wh = stdout_rh = stderr_rh = None @@ -125,7 +158,8 @@ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): stderr_rh, stderr_wh = pipe(overlapped=(True, False)) stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) try: - super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, stderr=stderr_wfd, **kwds) except: for h in (stdin_wh, stdout_rh, stderr_rh): @@ -145,105 +179,3 @@ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): os.close(stdout_wfd) if stderr == PIPE: os.close(stderr_wfd) - -# -# Return a write-only transport wrapping a writable pipe -# - -def connect_write_pipe(file): - loop = tulip.get_event_loop() - protocol = protocols.Protocol() - return loop._make_socket_transport(file, protocol, write_only=True) - -# -# Wrap a readable pipe in a stream -# - -def connect_read_pipe(file): - loop = tulip.get_event_loop() - stream_reader = streams.StreamReader() - protocol = _StreamReaderProtocol(stream_reader) - transport = loop._make_socket_transport(file, protocol) - return stream_reader - -class _StreamReaderProtocol(protocols.Protocol): - def __init__(self, stream_reader): - self.stream_reader = stream_reader - def connection_lost(self, exc): - self.stream_reader.set_exception(exc) - def data_received(self, data): - self.stream_reader.feed_data(data) - def eof_received(self): - self.stream_reader.feed_eof() - -# -# Example -# - -@tulip.task -def main(loop): - # program which prints evaluation of each expression from stdin - code = r'''if 1: - import os - def writeall(fd, buf): - while buf: - n = os.write(fd, buf) - buf = buf[n:] - while True: - s = os.read(0, 1024) - if not s: - break - s = s.decode('ascii') - s = repr(eval(s)) + '\n' - s = s.encode('ascii') - writeall(1, s) - ''' - - # commands to send to input - commands = iter([b"1+1\n", - b"2**16\n", - b"1/3\n", - b"'x'*50", - b"1/0\n"]) - - # start subprocess and wrap stdin, stdout, stderr - p = Popen([sys.executable, '-c', code], - stdin=PIPE, stdout=PIPE, stderr=PIPE) - stdin = connect_write_pipe(p.stdin) - stdout = connect_read_pipe(p.stdout) - stderr = connect_read_pipe(p.stderr) - - # interact with subprocess - name = {stdout:'OUT', stderr:'ERR'} - registered = {tulip.Task(stderr.readline()): stderr, - tulip.Task(stdout.readline()): stdout} - while registered: - # write command - cmd = next(commands, None) - if cmd is None: - stdin.close() - else: - print('>>>', cmd.decode('ascii').rstrip()) - stdin.write(cmd) - - # get and print lines from stdout, stderr - timeout = None - while True: - done, pending = yield from tulip.wait( - registered, timeout, tulip.FIRST_COMPLETED) - if not done: - break - for f in done: - stream = registered.pop(f) - res = f.result() - print(name[stream], res.decode('ascii').rstrip()) - if res != b'': - registered[tulip.Task(stream.readline())] = stream - timeout = 0.0 - - -if __name__ == '__main__': - loop = windows_events.ProactorEventLoop() - tulip.set_event_loop(loop) - loop.run_until_complete(main(loop)) - loop.close() diff --git a/tulip/winsocketpair.py b/tulip/winsocketpair.py deleted file mode 100644 index 59c8aecc..00000000 --- a/tulip/winsocketpair.py +++ /dev/null @@ -1,34 +0,0 @@ -"""A socket pair usable as a self-pipe, for Windows. - -Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. -""" - -import socket -import sys - -if sys.platform != 'win32': # pragma: no cover - raise ImportError('winsocketpair is win32 only') - - -def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - """Emulate the Unix socketpair() function on Windows.""" - # We create a connected TCP socket. Note the trick with setblocking(0) - # that prevents us from having to create a thread. - lsock = socket.socket(family, type, proto) - lsock.bind(('localhost', 0)) - lsock.listen(1) - addr, port = lsock.getsockname() - csock = socket.socket(family, type, proto) - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - except Exception: - lsock.close() - csock.close() - raise - ssock, _ = lsock.accept() - csock.setblocking(True) - lsock.close() - return (ssock, csock) From 4b7322d1d1a921cabb4cea68e7b3bb30dec300a4 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 30 May 2013 14:13:14 +0100 Subject: [PATCH 0497/1502] Correct docstring for test/windows_utils_test.py. --- tests/windows_utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py index d94716a2..c43fa235 100644 --- a/tests/windows_utils_test.py +++ b/tests/windows_utils_test.py @@ -1,4 +1,4 @@ -"""Tests for winsocketpair.py""" +"""Tests for window_utils""" import sys import test.support From 27220c83b58ab44e85eecdfdb2ec566ab97b9a98 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 30 May 2013 09:45:07 -0700 Subject: [PATCH 0498/1502] rename RawRequestMessage and RawResponseMessage namedtuple for consistency (codereview 9698044) --- tulip/http/client.py | 2 +- tulip/http/protocol.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tulip/http/client.py b/tulip/http/client.py index 4c797b8c..f0984f9d 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -401,7 +401,7 @@ def send(self, transport): class HttpResponse(http.client.HTTPMessage): - message = None # RawResponseStatus object + message = None # RawResponseMessage object # from the Status-Line of the response version = None # HTTP-Version diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 0a2959cf..997d6ef3 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -28,12 +28,12 @@ RawRequestMessage = collections.namedtuple( - 'RawRequestLine', + 'RawRequestMessage', ['method', 'path', 'version', 'headers', 'should_close', 'compression']) RawResponseMessage = collections.namedtuple( - 'RawResponseStatus', + 'RawResponseMessage', ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) From 52177370a52b84256ec7348db13eb7fea6b0c7f1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 31 May 2013 15:15:10 -0700 Subject: [PATCH 0499/1502] replace run_once() in tests with test helper --- tests/events_test.py | 37 ++++++++++-------- tests/futures_test.py | 3 +- tests/http_client_functional_test.py | 2 +- tests/http_server_test.py | 4 +- tests/locks_test.py | 56 ++++++++++++++-------------- tests/tasks_test.py | 25 +++++++------ tulip/test_utils.py | 14 ++++++- 7 files changed, 80 insertions(+), 61 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index bd7bb642..0ffabbd7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -147,7 +147,7 @@ def setUp(self): def tearDown(self): # just in case if we have transport close callbacks - test_utils.run_once(self.loop) + test_utils.run_briefly(self.loop) self.loop.close() gc.collect() @@ -390,9 +390,9 @@ def my_handler(): self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) # Now set a handler and handle it. self.loop.add_signal_handler(signal.SIGINT, my_handler) - self.loop.run_once() + test_utils.run_briefly(self.loop) os.kill(os.getpid(), signal.SIGINT) - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual(caught, 1) # Removing it should restore the default handler. self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) @@ -527,12 +527,12 @@ def factory(): client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') - self.loop.run_once(0.001) + test_utils.run_briefly(self.loop) self.assertIsInstance(proto, MyProto) self.assertEqual('INITIAL', proto.state) - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual('CONNECTED', proto.state) - self.loop.run_once(0.001) # windows iocp + test_utils.run_briefly(self.loop) # windows iocp self.assertEqual(3, proto.nbytes) # extra info is available @@ -544,7 +544,7 @@ def factory(): # close connection proto.transport.close() - self.loop.run_once(0.001) # windows iocp + test_utils.run_briefly(self.loop) # windows iocp self.assertEqual('CLOSED', proto.state) @@ -587,9 +587,9 @@ def factory(): client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertIsInstance(proto, MyProto) - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual('CONNECTED', proto.state) self.assertEqual(3, proto.nbytes) @@ -722,9 +722,9 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', client.state) transport.sendto(b'xxx') - self.loop.run_once(None) + test_utils.run_briefly(self.loop) self.assertEqual(3, server.nbytes) - self.loop.run_once(None) + test_utils.run_briefly(self.loop) # received self.assertEqual(8, client.nbytes) @@ -775,11 +775,11 @@ def connect(): self.loop.run_until_complete(connect()) os.write(wpipe, b'1') - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual(1, proto.nbytes) os.write(wpipe, b'2345') - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) self.assertEqual(5, proto.nbytes) @@ -816,12 +816,12 @@ def connect(): self.loop.run_until_complete(connect()) transport.write(b'1') - self.loop.run_once() + test_utils.run_briefly(self.loop) data = os.read(rpipe, 1024) self.assertEqual(b'1', data) transport.write(b'2345') - self.loop.run_once() + test_utils.run_briefly(self.loop) data = os.read(rpipe, 1024) self.assertEqual(b'2345', data) self.assertEqual('CONNECTED', proto.state) @@ -882,18 +882,25 @@ class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): return windows_events.ProactorEventLoop() + def test_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_start_serving_ssl(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_reader_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + def test_writer_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_writer_callback_cancel(self): raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + def test_create_datagram_endpoint(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_endpoint()") diff --git a/tests/futures_test.py b/tests/futures_test.py index 18e70c41..87198cf7 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -7,6 +7,7 @@ from tulip import events from tulip import futures +from tulip import test_utils def _fakefunc(f): @@ -193,7 +194,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): fut = futures.Future() fut.set_exception(RuntimeError('boom')) del fut - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) @unittest.mock.patch('tulip.futures.tulip_log') diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 125927f4..3abc0076 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -20,7 +20,7 @@ def setUp(self): def tearDown(self): # just in case if we have transport close callbacks - test_utils.run_once(self.loop) + test_utils.run_briefly(self.loop) self.loop.close() gc.collect() diff --git a/tests/http_server_test.py b/tests/http_server_test.py index cba5fabc..ac52d1bb 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -6,7 +6,7 @@ import tulip from tulip.http import server from tulip.http import errors -from tulip.test_utils import run_once +from tulip.test_utils import run_briefly class HttpServerProtocolTests(unittest.TestCase): @@ -204,7 +204,7 @@ def test_handle_cancelled(self): srv.connection_made(transport) srv.handle_request = unittest.mock.Mock() - run_once(self.loop) # start request_handler task + run_briefly(self.loop) # start request_handler task srv.stream.feed_data( b'GET / HTTP/1.0\r\n' diff --git a/tests/locks_test.py b/tests/locks_test.py index f3bd01e0..48242bfc 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -8,7 +8,7 @@ from tulip import futures from tulip import locks from tulip import tasks -from tulip.test_utils import run_once +from tulip.test_utils import run_briefly class LockTests(unittest.TestCase): @@ -81,24 +81,24 @@ def c3(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) lock.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) t3 = tasks.Task(c3(result)) lock.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2], result) lock.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -235,13 +235,13 @@ def c3(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) t3 = tasks.Task(c3(result)) ev.set() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([3, 1, 2], result) self.assertTrue(t1.done()) @@ -319,7 +319,7 @@ def c1(result): return True t = tasks.Task(c1(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) ev.set() @@ -330,7 +330,7 @@ def c1(result): ev.set() self.assertEqual(1, len(ev._waiters)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.assertEqual(0, len(ev._waiters)) @@ -376,33 +376,33 @@ def c3(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.assertFalse(cond.locked()) self.assertTrue(self.loop.run_until_complete(cond.acquire())) cond.notify() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.assertTrue(cond.locked()) cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.notify(2) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2], result) self.assertTrue(cond.locked()) cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(cond.locked()) @@ -462,20 +462,20 @@ def c1(result): t = tasks.Task(c1(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) presult = True self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(t.done()) @@ -502,13 +502,13 @@ def c1(result): t0 = time.monotonic() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(wait_for) @@ -562,20 +562,20 @@ def c3(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.notify(2048) cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -609,13 +609,13 @@ def c2(result): t1 = tasks.Task(c1(result)) t2 = tasks.Task(c2(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify_all() cond.release() - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1, 2], result) self.assertTrue(t1.done()) @@ -713,7 +713,7 @@ def c4(result): t2 = tasks.Task(c2(result)) t3 = tasks.Task(c3(result)) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(sem.locked()) self.assertEqual(2, len(sem._waiters)) @@ -725,7 +725,7 @@ def c4(result): sem.release() self.assertEqual(2, sem._value) - run_once(self.loop) + run_briefly(self.loop) self.assertEqual(0, sem._value) self.assertEqual([1, 2, 3], result) self.assertTrue(sem.locked()) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index a65ecc89..ba0afef2 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -7,6 +7,7 @@ from tulip import events from tulip import futures from tulip import tasks +from tulip import test_utils class Dummy: @@ -197,7 +198,7 @@ def task(): return 12 t = task() - self.loop.run_once() # start coro + test_utils.run_briefly(self.loop) # start coro t.cancel() self.assertRaises( futures.CancelledError, self.loop.run_until_complete, t) @@ -219,14 +220,14 @@ def task(): yield from fut3 t = task() - self.loop.run_once() + test_utils.run_briefly(self.loop) fut1.set_result(None) t.cancel() - self.loop.run_once() # process fut1 result, delay cancel + test_utils.run_once(self.loop) # process fut1 result, delay cancel self.assertFalse(t.done()) - self.loop.run_once() # cancel fut2, but coro still alive + test_utils.run_once(self.loop) # cancel fut2, but coro still alive self.assertFalse(t.done()) - self.loop.run_once() # cancel fut3 + test_utils.run_briefly(self.loop) # cancel fut3 self.assertTrue(t.done()) self.assertEqual(fut1.result(), None) @@ -585,12 +586,12 @@ def call_later(self, delay, callback, *args): return handle self.loop.call_later = call_later - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertFalse(handle._cancelled) t.cancel() - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertTrue(handle._cancelled) def test_task_cancel_sleeping_task(self): @@ -636,7 +637,7 @@ def coro(): pass task = coro() - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) task.cancel() @@ -686,12 +687,12 @@ def wait_for_future(): result = yield from fut t = wait_for_future() - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertTrue(fut.cb_added) res = object() fut.set_result(res) - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertIs(res, result) self.assertTrue(t.done()) self.assertIsNone(t.result()) @@ -720,12 +721,12 @@ def notmutch(): raise BaseException() task = tasks.Task(notmutch()) - self.loop.run_once() + test_utils.run_briefly(self.loop) task.cancel() self.assertFalse(task.done()) - self.assertRaises(BaseException, self.loop.run_once) + self.assertRaises(BaseException, test_utils.run_briefly, self.loop) self.assertTrue(task.done()) self.assertTrue(task.cancelled()) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 6757d93c..d97dc618 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -31,13 +31,23 @@ from socket import socketpair # pragma: no cover -def run_once(loop): +def run_briefly(loop): @tulip.task def once(): pass loop.run_until_complete(once()) +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + @contextlib.contextmanager def run_test_server(loop, *, host='127.0.0.1', port=0, use_ssl=False, router=None): @@ -127,7 +137,7 @@ def run(loop, fut): for tr in transports: tr.close() - run_once(thread_loop) # call close callbacks + run_briefly(thread_loop) # call close callbacks for s in socks: thread_loop.stop_serving(s) From a4d789700b39449cc9199b2110a13b440d9d6e6e Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 5 Jun 2013 09:44:27 -0700 Subject: [PATCH 0500/1502] deprecate run_once() --- tests/base_events_test.py | 9 ++------- tests/events_test.py | 34 ---------------------------------- tulip/base_events.py | 16 ---------------- tulip/events.py | 7 ------- 4 files changed, 2 insertions(+), 64 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 4e820ad8..45a0c30b 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -11,6 +11,7 @@ from tulip import futures from tulip import protocols from tulip import tasks +from tulip import test_utils class BaseEventLoopTests(unittest.TestCase): @@ -106,7 +107,7 @@ def cb(arg): self.loop._process_events = unittest.mock.Mock() self.loop.call_later(-1, cb, 'a') self.loop.call_later(-2, cb, 'b') - self.loop.run_once() + test_utils.run_briefly(self.loop) self.assertEqual(calls, ['b', 'a']) def test_time_and_call_at(self): @@ -164,12 +165,6 @@ def cb(): f.cancel() # Don't complain about abandoned Future. - def test_run_once(self): - self.loop._run_once = unittest.mock.Mock() - self.loop._run_once.side_effect = base_events._StopError - self.loop.run_once() - self.assertTrue(self.loop._run_once.called) - def test__run_once(self): h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) diff --git a/tests/events_test.py b/tests/events_test.py index 0ffabbd7..1c589130 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -166,38 +166,6 @@ def coro2(): self.assertRaises( RuntimeError, self.loop.run_until_complete, coro2()) - def test_run_once_nesting(self): - @tasks.coroutine - def coro(): - tasks.sleep(0.1) - self.loop.run_once() - - self.assertRaises( - RuntimeError, - self.loop.run_until_complete, coro()) - - def test_run_once_block(self): - called = False - - def callback(): - nonlocal called - called = True - - def run(): - time.sleep(0.1) - self.loop.call_soon_threadsafe(callback) - - self.loop.run_once(0) # windows iocp - - t = threading.Thread(target=run) - t0 = time.monotonic() - t.start() - self.loop.run_once(None) - t1 = time.monotonic() - t.join() - self.assertTrue(called) - self.assertTrue(0.09 < t1-t0 <= 0.15) - def test_call_later(self): results = [] @@ -1055,8 +1023,6 @@ def test_not_implemented(self): loop = events.AbstractEventLoop() self.assertRaises( NotImplementedError, loop.run_forever) - self.assertRaises( - NotImplementedError, loop.run_once) self.assertRaises( NotImplementedError, loop.run_until_complete, None) self.assertRaises( diff --git a/tulip/base_events.py b/tulip/base_events.py index 57c386b1..e4eae9ba 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -104,22 +104,6 @@ def run_forever(self): finally: self._running = False - def run_once(self, timeout=0): - """Run through all callbacks and all I/O polls once. - - Calling stop() will break out of this too. - """ - if self._running: - raise RuntimeError('Event loop is running.') - - self._running = True - try: - self._run_once(timeout) - except _StopError: - pass - finally: - self._running = False - def run_until_complete(self, future, timeout=None): """Run until the Future is done, or until a timeout. diff --git a/tulip/events.py b/tulip/events.py index b1b5186c..a1a5fd3e 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -108,13 +108,6 @@ def run_forever(self): """Run the event loop until stop() is called.""" raise NotImplementedError - def run_once(self, timeout=None): - """Run one complete cycle of the event loop. - - TODO: Deprecate this. - """ - raise NotImplementedError - def run_until_complete(self, future, timeout=None): """Run the event loop until a Future is done. From f8a29e2d871ef73a7bf528c73dcaa2f90fb6d6e1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 7 Jun 2013 11:24:44 -0700 Subject: [PATCH 0501/1502] Fix issue #47 in run_until_complete(), adding tests. --- tests/events_test.py | 29 ++++++++++++++++++++++++++++- tests/tasks_test.py | 2 +- tulip/base_events.py | 19 ++++++++++++------- 3 files changed, 41 insertions(+), 9 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 1c589130..69de0085 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -153,7 +153,7 @@ def tearDown(self): gc.collect() super().tearDown() - def test_run_nesting(self): + def test_run_until_complete_nesting(self): @tasks.coroutine def coro1(): yield @@ -166,6 +166,33 @@ def coro2(): self.assertRaises( RuntimeError, self.loop.run_until_complete, coro2()) + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.010)) + t1 = self.loop.time() + self.assertTrue(0.009 <= t1-t0 <= 0.012) + + def test_run_until_complete_stopped(self): + @tasks.task + def cb(): + self.loop.stop() + yield from tasks.sleep(0.010) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_run_until_complete_timeout(self): + t0 = self.loop.time() + task = tasks.async(tasks.sleep(0.020)) + self.assertRaises(futures.TimeoutError, + self.loop.run_until_complete, + task, timeout=0.010) + t1 = self.loop.time() + self.assertTrue(0.009 <= t1-t0 <= 0.012) + self.loop.run_until_complete(task) + t2 = self.loop.time() + self.assertTrue(0.009 <= t2-t1 <= 0.012) + def test_call_later(self): results = [] diff --git a/tests/tasks_test.py b/tests/tasks_test.py index ba0afef2..7a91d0fc 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -295,7 +295,7 @@ def task(): t = tasks.Task(task()) t0 = time.monotonic() self.assertRaises( - futures.InvalidStateError, self.loop.run_until_complete, t) + RuntimeError, self.loop.run_until_complete, t) t1 = time.monotonic() self.assertFalse(t.done()) self.assertTrue(0.18 <= t1-t0 <= 0.22) diff --git a/tulip/base_events.py b/tulip/base_events.py index e4eae9ba..b4ae4595 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -117,25 +117,30 @@ def run_until_complete(self, future, timeout=None): timeout is reached or stop() is called, raise TimeoutError. """ future = tasks.async(future) - handle_called = False - - def stop_loop(): - nonlocal handle_called - handle_called = True - raise _StopError - future.add_done_callback(_raise_stop_error) + handle_called = False if timeout is None: self.run_forever() else: + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + handle = self.call_later(timeout, stop_loop) self.run_forever() handle.cancel() + future.remove_done_callback(_raise_stop_error) + if handle_called: raise futures.TimeoutError + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + return future.result() def stop(self): From 3c86d9a56e8f47c5a514df89a10f46af648fff02 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 7 Jun 2013 11:30:34 -0700 Subject: [PATCH 0502/1502] Adjust some timing constraints. I see 0.016 in practice. --- tests/events_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 69de0085..5df4598d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -170,7 +170,7 @@ def test_run_until_complete(self): t0 = self.loop.time() self.loop.run_until_complete(tasks.sleep(0.010)) t1 = self.loop.time() - self.assertTrue(0.009 <= t1-t0 <= 0.012) + self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) def test_run_until_complete_stopped(self): @tasks.task @@ -188,10 +188,10 @@ def test_run_until_complete_timeout(self): self.loop.run_until_complete, task, timeout=0.010) t1 = self.loop.time() - self.assertTrue(0.009 <= t1-t0 <= 0.012) + self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) self.loop.run_until_complete(task) t2 = self.loop.time() - self.assertTrue(0.009 <= t2-t1 <= 0.012) + self.assertTrue(0.009 <= t2-t1 <= 0.018, t1-t0) def test_call_later(self): results = [] From d2222046aa3121036d5af31af396b2327c5f9d12 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 11 Jun 2013 15:10:58 +0100 Subject: [PATCH 0503/1502] Refactor proactor socket transports to support read/write pipe transports. --- tests/windows_events_test.py | 78 +++++++++++++++++++++ tulip/proactor_events.py | 131 ++++++++++++++++++++++++----------- 2 files changed, 168 insertions(+), 41 deletions(-) create mode 100644 tests/windows_events_test.py diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..bb3a8a65 --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,78 @@ +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams +from tulip import test_utils + + +def connect_read_pipe(loop, file): + stream_reader = streams.StreamReader() + protocol = _StreamReaderProtocol(stream_reader) + transport = loop._make_read_pipe_transport(file, protocol) + return stream_reader + + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + def data_received(self, data): + self.stream_reader.feed_data(data) + def eof_received(self): + self.stream_reader.feed_eof() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(self.loop) + + def tearDown(self): + self.loop.close() + + def test_pause_resume_discard(self): + a, b = self.loop._socketpair() + trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) + reader = connect_read_pipe(self.loop, b) + f = tulip.async(reader.readline()) + + trans.write(b'msg1\n') + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg1\n') + f = tulip.async(reader.readline()) + + trans.pause_writing() + trans.write(b'msg2\n') + with self.assertRaises(tulip.TimeoutError): + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(trans._buffer, [b'msg2\n']) + + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.1) + self.assertEqual(f.result(), b'msg2\n') + f = tulip.async(reader.readline()) + + trans.pause_writing() + trans.write(b'msg3\n') + self.assertEqual(trans._buffer, [b'msg3\n']) + trans.discard_output() + self.assertEqual(trans._buffer, []) + + trans.write(b'msg4\n') + self.assertEqual(trans._buffer, [b'msg4\n']) + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg4\n') + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f, timeout=1) + self.assertEqual(f.result(), b'') diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index c04e924f..af6e00f8 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -4,6 +4,8 @@ proactor is only implemented on Windows with IOCP. """ +import socket + from . import base_events from . import constants from . import futures @@ -11,26 +13,76 @@ from .log import tulip_log -class _ProactorSocketTransport(transports.Transport): +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" - def __init__(self, loop, sock, protocol, waiter=None, extra=None, - write_only=False): + def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(extra) - self._extra['socket'] = sock + self._set_extra(sock) self._loop = loop self._sock = sock self._protocol = protocol self._buffer = [] self._read_fut = None self._write_fut = None + self._writing_disabled = False self._conn_lost = 0 self._closing = False # Set when close() called. self._loop.call_soon(self._protocol.connection_made, self) - if not write_only: - self._loop.call_soon(self._loop_reading) if waiter is not None: self._loop.call_soon(waiter.set_result, None) + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, waiter, extra) + self._loop.call_soon(self._loop_reading) + def _loop_reading(self, fut=None): data = None @@ -58,6 +110,9 @@ def _loop_reading(self, fut=None): self._force_close(exc) except OSError as exc: self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise else: self._read_fut.add_done_callback(self._loop_reading) finally: @@ -69,6 +124,11 @@ def _loop_reading(self, fut=None): finally: self.close() + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + def write(self, data): assert isinstance(data, bytes), repr(data) if not data: @@ -80,62 +140,52 @@ def write(self, data): self._conn_lost += 1 return self._buffer.append(data) - if not self._write_fut: + if self._write_fut is None and not self._writing_disabled: self._loop_writing() def _loop_writing(self, f=None): try: assert f is self._write_fut + self._write_fut = None if f: f.result() data = b''.join(self._buffer) self._buffer = [] if not data: - self._write_fut = None if self._closing: self._loop.call_soon(self._call_connection_lost, None) return - self._write_fut = self._loop._proactor.send(self._sock, data) + if not self._writing_disabled: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) except OSError as exc: self._fatal_error(exc) - else: - self._write_fut.add_done_callback(self._loop_writing) # TODO: write_eof(), can_write_eof(). def abort(self): self._force_close(None) - def close(self): - if self._closing: - return - self._closing = True - self._conn_lost += 1 - if not self._buffer and self._write_fut is None: - self._loop.call_soon(self._call_connection_lost, None) + def pause_writing(self): + self._writing_disabled = True - def _fatal_error(self, exc): - tulip_log.exception('Fatal error for %s', self) - self._force_close(exc) + def resume_writing(self): + self._writing_disabled = False + if self._buffer and self._write_fut is None: + self._loop_writing() - def _force_close(self, exc): - if self._closing: - return - self._closing = True - self._conn_lost += 1 - if self._write_fut: - self._write_fut.cancel() - if self._read_fut: # XXX - self._read_fut.cancel() - self._write_fut = self._read_fut = None - self._buffer = [] - self._loop.call_soon(self._call_connection_lost, exc) + def discard_output(self): + if self._buffer: + self._buffer = [] - def _call_connection_lost(self, exc): - try: - self._protocol.connection_lost(exc) - finally: - self._sock.close() + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock class BaseProactorEventLoop(base_events.BaseEventLoop): @@ -152,12 +202,11 @@ def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): def _make_read_pipe_transport(self, sock, protocol, waiter=None, extra=None): - return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) def _make_write_pipe_transport(self, sock, protocol, waiter=None, extra=None): - return _ProactorSocketTransport(self, sock, protocol, waiter, extra, - write_only=True) + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) def close(self): if self._proactor is not None: From ff61a3a8a2e61d8425a928d9307cabe9fb768bc0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 15 Jun 2013 17:20:44 -0700 Subject: [PATCH 0504/1502] Fix copy/paste error in assertEqual message. --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index 5df4598d..e0c82c2d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -191,7 +191,7 @@ def test_run_until_complete_timeout(self): self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) self.loop.run_until_complete(task) t2 = self.loop.time() - self.assertTrue(0.009 <= t2-t1 <= 0.018, t1-t0) + self.assertTrue(0.009 <= t2-t1 <= 0.018, t2-t1) def test_call_later(self): results = [] From b4e31bf620db1fca7d9ab438ee1967d65cbdc9f1 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 26 Jun 2013 11:55:13 -0700 Subject: [PATCH 0505/1502] readpayload option of WSGIServerHttpProtocol is broken, #48 by Aymeric Augustin --- tests/http_wsgi_test.py | 23 ++++++++++++++++++++--- tulip/http/wsgi.py | 1 + 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index bedfca6d..fc25b38c 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -20,7 +20,6 @@ def setUp(self): self.transport = unittest.mock.Mock() self.transport.get_extra_info.return_value = '127.0.0.1' - self.payload = b'data' self.headers = [] self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 0), self.headers, True, 'deflate') @@ -210,7 +209,7 @@ def wsgi_app(env, start): self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 1), self.headers, True, 'deflate') - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) srv.stream = self.stream srv.transport = self.transport @@ -254,7 +253,7 @@ def wsgi_app(env, start): self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 1), self.headers, False, 'deflate') - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, True) + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) srv.stream = self.stream srv.transport = self.transport @@ -266,3 +265,21 @@ def wsgi_app(env, start): self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) self.assertTrue(srv._keep_alive) + + def test_handle_request_readpayload(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [env['wsgi.input'].read()] + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py index f36bde63..55ddc9ff 100644 --- a/tulip/http/wsgi.py +++ b/tulip/http/wsgi.py @@ -149,6 +149,7 @@ def handle_request(self, message, payload): while chunk: wsgiinput.write(chunk) chunk = yield from payload.read() + wsgiinput.seek(0) payload = wsgiinput environ = self.create_wsgi_environ(message, payload) From ceb3aeeb78923d288d8ac6e55dda9a1f3e89a3eb Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 29 Jul 2013 10:36:33 -0700 Subject: [PATCH 0506/1502] refactor websocket handshake process, make it usable with wsgi server --- examples/wssrv.py | 2 +- tests/http_server_test.py | 4 ++-- tests/http_websocket_test.py | 29 +++++++++++++++++++++-------- tests/http_wsgi_test.py | 13 +++++++++++++ tulip/http/errors.py | 8 +++++--- tulip/http/protocol.py | 4 +++- tulip/http/websocket.py | 20 +++++++++++++------- tulip/http/wsgi.py | 11 ++++++++--- 8 files changed, 66 insertions(+), 25 deletions(-) diff --git a/examples/wssrv.py b/examples/wssrv.py index aecce9f7..f96e0855 100755 --- a/examples/wssrv.py +++ b/examples/wssrv.py @@ -41,7 +41,7 @@ def handle_request(self, message, payload): if upgrade: # websocket handshake status, headers, parser, writer = websocket.do_handshake( - message, self.transport) + message.method, message.headers, self.transport) resp = tulip.http.Response(self.transport, status) resp.add_headers(*headers) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index ac52d1bb..0d20ecca 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -18,8 +18,8 @@ def setUp(self): def tearDown(self): self.loop.close() - def test_http_status_exception(self): - exc = errors.HttpStatusException(500, message='Internal error') + def test_http_error_exception(self): + exc = errors.HttpErrorException(500, message='Internal error') self.assertEqual(exc.code, 500) self.assertEqual(exc.message, 'Internal error') diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py index bd89b75b..319538ae 100644 --- a/tests/http_websocket_test.py +++ b/tests/http_websocket_test.py @@ -358,31 +358,41 @@ def setUp(self): self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 0), self.headers, True, None) + def test_not_get(self): + self.assertRaises( + errors.HttpErrorException, + websocket.do_handshake, + 'POST', self.message.headers, self.transport) + def test_no_upgrade(self): self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) def test_no_connection(self): self.headers.extend([('UPGRADE', 'websocket'), ('CONNECTION', 'keep-alive')]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) def test_protocol_version(self): self.headers.extend([('UPGRADE', 'websocket'), ('CONNECTION', 'upgrade')]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) self.headers.extend([('UPGRADE', 'websocket'), ('CONNECTION', 'upgrade'), ('SEC-WEBSOCKET-VERSION', '1')]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) def test_protocol_key(self): self.headers.extend([('UPGRADE', 'websocket'), @@ -390,7 +400,8 @@ def test_protocol_key(self): ('SEC-WEBSOCKET-VERSION', '13')]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) self.headers.extend([('UPGRADE', 'websocket'), ('CONNECTION', 'upgrade'), @@ -398,7 +409,8 @@ def test_protocol_key(self): ('SEC-WEBSOCKET-KEY', '123')]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) sec_key = base64.b64encode(os.urandom(2)) self.headers.extend([('UPGRADE', 'websocket'), @@ -407,7 +419,8 @@ def test_protocol_key(self): ('SEC-WEBSOCKET-KEY', sec_key.decode())]) self.assertRaises( errors.BadRequestException, - websocket.do_handshake, self.message, self.transport) + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) def test_handshake(self): sec_key = base64.b64encode(os.urandom(16)).decode() @@ -417,7 +430,7 @@ def test_handshake(self): ('SEC-WEBSOCKET-VERSION', '13'), ('SEC-WEBSOCKET-KEY', sec_key)]) status, headers, parser, writer = websocket.do_handshake( - self.message, self.transport) + self.message.method, self.message.headers, self.transport) self.assertEqual(status, 101) key = base64.b64encode( diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index fc25b38c..34399665 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -161,6 +161,19 @@ def test_wsgi_response_start_response_exc_status(self): resp.start_response, '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + @unittest.mock.patch('tulip.http.wsgi.tulip') + def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '101 Switching Protocols', (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'))) + self.assertEqual(resp.status, '101 Switching Protocols') + self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) + def test_file_wrapper(self): fobj = io.BytesIO(b'data') wrapper = wsgi.FileWrapper(fobj, 2) diff --git a/tulip/http/errors.py b/tulip/http/errors.py index 24032337..f8b77e9b 100644 --- a/tulip/http/errors.py +++ b/tulip/http/errors.py @@ -1,6 +1,6 @@ """http related errors.""" -__all__ = ['HttpException', 'HttpStatusException', +__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] import http.client @@ -10,11 +10,12 @@ class HttpException(http.client.HTTPException): code = None headers = () + message = '' -class HttpStatusException(HttpException): +class HttpErrorException(HttpException): - def __init__(self, code, headers=None, message=''): + def __init__(self, code, message='', headers=None): self.code = code self.headers = headers self.message = message @@ -23,6 +24,7 @@ def __init__(self, code, headers=None, message=''): class BadRequestException(HttpException): code = 400 + message = 'Bad Request' class IncompleteRead(BadRequestException, http.client.IncompleteRead): diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 997d6ef3..30ecb83b 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -468,6 +468,8 @@ class HttpMessage: status = None status_line = b'' + upgrade = False # Connection: UPGRADE + websocket = False # Upgrade: WEBSOCKET # subclass can enable auto sending headers with write() call, # this is useful for wsgi's start_response implementation. @@ -481,7 +483,6 @@ def __init__(self, transport, version, close): self.chunked = False self.length = None - self.upgrade = False self.headers = collections.deque() self.headers_sent = False @@ -525,6 +526,7 @@ def add_header(self, name, value): elif name == 'UPGRADE': if 'websocket' in value.lower(): + self.websocket = True self.headers.append((name, value)) elif name == 'TRANSFER-ENCODING' and not self.chunked: diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py index 71784bcc..c3dd5872 100644 --- a/tulip/http/websocket.py +++ b/tulip/http/websocket.py @@ -184,16 +184,22 @@ def close(self, code=1000, message=b''): opcode=OPCODE_CLOSE) -def do_handshake(message, transport): +def do_handshake(method, headers, transport): """Prepare WebSocket handshake. It return http response code, response headers, websocket parser, websocket writer. It does not - do any IO.""" - headers = dict(((hdr, val) - for hdr, val in message.headers if hdr in WS_HDRS)) + perform any IO.""" + + # WebSocket accepts only GET + if method.upper() != 'GET': + raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) + + headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) if 'websocket' != headers.get('UPGRADE', '').lower().strip(): - raise errors.BadRequestException('No WebSocket UPGRADE hdr: {}'.format( - headers.get('UPGRADE'))) + raise errors.BadRequestException( + 'No WebSocket UPGRADE hdr: {}\n' + 'Can "Upgrade" only to "WebSocket".'.format( + headers.get('UPGRADE'))) if 'upgrade' not in headers.get('CONNECTION', '').lower(): raise errors.BadRequestException( @@ -202,7 +208,7 @@ def do_handshake(message, transport): # check supported version version = headers.get('SEC-WEBSOCKET-VERSION') - if version not in ('13', '8'): + if version not in ('13', '8', '7'): raise errors.BadRequestException( 'Unsupported version: {}'.format(version)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py index 55ddc9ff..738e100f 100644 --- a/tulip/http/wsgi.py +++ b/tulip/http/wsgi.py @@ -214,9 +214,14 @@ def start_response(self, status, headers, exc_info=None): status_code = int(status.split(' ', 1)[0]) self.status = status - self.response = tulip.http.Response( + resp = self.response = tulip.http.Response( self.transport, status_code, self.message.version, self.message.should_close) - self.response.add_headers(*headers) - self.response._send_headers = True + resp.add_headers(*headers) + + # send headers immediately for websocket connection + if status_code == 101 and resp.upgrade and resp.websocket: + resp.send_headers() + else: + resp._send_headers = True return self.response.write From 1d7b4a751d0e48fc8a4235e82c4290f1b8cc2910 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 5 Aug 2013 14:40:04 -0700 Subject: [PATCH 0507/1502] Better way of checking two sets of times. --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index e0c82c2d..f5a509e7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -191,7 +191,7 @@ def test_run_until_complete_timeout(self): self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) self.loop.run_until_complete(task) t2 = self.loop.time() - self.assertTrue(0.009 <= t2-t1 <= 0.018, t2-t1) + self.assertTrue(0.018 <= t2-t0 <= 0.028, t2-t0) def test_call_later(self): results = [] From 5ba66d329915d080d0bc0fd02bed11b7db05f9f7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 5 Aug 2013 14:43:24 -0700 Subject: [PATCH 0508/1502] Re-export Full and Empty exceptions from tulip.queues instead of requiring import queue. --- tests/queues_test.py | 9 ++++----- tulip/queues.py | 23 ++++++++++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/tests/queues_test.py b/tests/queues_test.py index 6227f3d0..5632bbff 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -2,7 +2,6 @@ import unittest import unittest.mock -import queue from tulip import events from tulip import futures @@ -190,14 +189,14 @@ def test_nonblocking_get(self): def test_nonblocking_get_exception(self): q = queues.Queue() - self.assertRaises(queue.Empty, q.get_nowait) + self.assertRaises(queues.Empty, q.get_nowait) def test_get_timeout(self): q = queues.Queue() @tasks.coroutine def queue_get(): - with self.assertRaises(queue.Empty): + with self.assertRaises(queues.Empty): return (yield from q.get(timeout=0.01)) # Get works after timeout, with blocking and non-blocking put. @@ -273,7 +272,7 @@ def test_nonblocking_put(self): def test_nonblocking_put_exception(self): q = queues.Queue(maxsize=1) q.put_nowait(1) - self.assertRaises(queue.Full, q.put_nowait, 2) + self.assertRaises(queues.Full, q.put_nowait, 2) def test_put_timeout(self): q = queues.Queue(1) @@ -281,7 +280,7 @@ def test_put_timeout(self): @tasks.coroutine def queue_put(): - with self.assertRaises(queue.Full): + with self.assertRaises(queues.Full): return (yield from q.put(1, timeout=0.01)) self.assertEqual(0, q.get_nowait()) diff --git a/tulip/queues.py b/tulip/queues.py index 8bb35066..dfe0cae7 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -1,6 +1,6 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', 'Full', 'Empty'] import collections import concurrent.futures @@ -13,6 +13,11 @@ from .tasks import coroutine +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + class Queue: """A queue, useful for coordinating producer and consumer coroutines. @@ -105,7 +110,7 @@ def put(self, item, timeout=None): If you yield from put() and timeout is None (the default), wait until a free slot is available before adding item. - If a timeout is provided, raise queue.Full if no free slot becomes + If a timeout is provided, raise Full if no free slot becomes available before the timeout. """ self._consume_done_getters(self._getters) @@ -127,7 +132,7 @@ def put(self, item, timeout=None): try: yield from waiter except concurrent.futures.CancelledError: - raise queue.Full + raise Full else: self._put(item) @@ -135,7 +140,7 @@ def put(self, item, timeout=None): def put_nowait(self, item): """Put an item into the queue without blocking. - If no free slot is immediately available, raise queue.Full. + If no free slot is immediately available, raise Full. """ self._consume_done_getters(self._getters) if self._getters: @@ -150,7 +155,7 @@ def put_nowait(self, item): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - raise queue.Full + raise Full else: self._put(item) @@ -161,7 +166,7 @@ def get(self, timeout=None): If you yield from get() and timeout is None (the default), wait until a item is available. - If a timeout is provided, raise queue.Empty if no item is available + If a timeout is provided, raise Empty if no item is available before the timeout. """ self._consume_done_putters() @@ -187,12 +192,12 @@ def get(self, timeout=None): try: return (yield from waiter) except concurrent.futures.CancelledError: - raise queue.Empty + raise Empty def get_nowait(self): """Remove and return an item from the queue. - Return an item if one is immediately available, else raise queue.Full. + Return an item if one is immediately available, else raise Full. """ self._consume_done_putters() if self._putters: @@ -207,7 +212,7 @@ def get_nowait(self): elif self.qsize(): return self._get() else: - raise queue.Empty + raise Empty class PriorityQueue(Queue): From d16ad79a6f1768ddc722d942a0cee1fb1ddb4e09 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 5 Aug 2013 15:43:47 -0700 Subject: [PATCH 0509/1502] Add open_connection() and class StreamReaderProtocol to streams.py. --- tests/streams_test.py | 32 ++++++++++++++++++++++ tulip/streams.py | 63 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/tests/streams_test.py b/tests/streams_test.py index ab148bdd..3989075a 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,10 +1,12 @@ """Tests for streams.py.""" +from unittest import mock import unittest from tulip import events from tulip import streams from tulip import tasks +from tulip import test_utils class StreamReaderTests(unittest.TestCase): @@ -18,6 +20,36 @@ def setUp(self): def tearDown(self): self.loop.close() + def test_open_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(self.loop, use_ssl=True) as httpd: + f = streams.open_connection(*httpd.address, ssl=True) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + def test_open_connection_error(self): + with test_utils.run_test_server(self.loop) as httpd: + f = streams.open_connection(*httpd.address) + reader, writer = self.loop.run_until_complete(f) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + def test_feed_empty_data(self): stream = streams.StreamReader() diff --git a/tulip/streams.py b/tulip/streams.py index 51028ca7..3aec9ea6 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -1,16 +1,75 @@ """Stream-related things.""" -__all__ = ['StreamReader'] +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] import collections +from . import events from . import futures +from . import protocols from . import tasks +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + class StreamReader: - def __init__(self, limit=2**16): + def __init__(self, limit=_DEFAULT_LIMIT): self.limit = limit # Max line length. (Security feature.) self.buffer = collections.deque() # Deque of bytes objects. self.byte_count = 0 # Bytes in buffer. From ba94d7c00a5e62593745d8aaa60e085294728003 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 6 Aug 2013 10:37:21 -0700 Subject: [PATCH 0510/1502] enable keep-alive only for http/1.1 clients --- tests/http_protocol_test.py | 8 ++++++++ tests/streams_test.py | 1 - tulip/http/protocol.py | 15 +++++++++++---- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index 9455426a..fc6d2842 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -50,6 +50,14 @@ def test_keep_alive(self): msg.force_close() self.assertFalse(msg.keep_alive()) + def test_keep_alive_http10(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + self.assertFalse(msg.keepalive) + self.assertFalse(msg.keep_alive()) + + msg = protocol.Response(self.transport, 200, http_version=(1, 1)) + self.assertIsNone(msg.keepalive) + def test_add_header(self): msg = protocol.Response(self.transport, 200) self.assertEqual([], list(msg.headers)) diff --git a/tests/streams_test.py b/tests/streams_test.py index 3989075a..669da75d 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,6 +1,5 @@ """Tests for streams.py.""" -from unittest import mock import unittest from tulip import events diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py index 30ecb83b..7081fd59 100644 --- a/tulip/http/protocol.py +++ b/tulip/http/protocol.py @@ -72,8 +72,10 @@ def http_request_parser(max_line_size=8190, # read headers headers, close, compression = parse_headers( lines, max_line_size, max_headers, max_field_size) - if close is None: - close = version <= (1, 0) + if version <= (1, 0): + close = True + elif close is None: + close = False out.feed_data( RawRequestMessage( @@ -479,7 +481,12 @@ def __init__(self, transport, version, close): self.transport = transport self.version = version self.closing = close - self.keepalive = None + + # disable keep-alive for http/1.0 + if version <= (1, 0): + self.keepalive = False + else: + self.keepalive = None self.chunked = False self.length = None @@ -521,7 +528,7 @@ def add_header(self, name, value): # connection keep-alive elif 'close' in val: self.keepalive = False - elif 'keep-alive' in val: + elif 'keep-alive' in val and self.version >= (1, 1): self.keepalive = True elif name == 'UPGRADE': From 83b1f7e1442a0c32de246fe10a9f49e39c327831 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 6 Aug 2013 12:12:13 -0700 Subject: [PATCH 0511/1502] Remove stray space at start of docstring. --- tulip/futures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/futures.py b/tulip/futures.py index 142004aa..3965e2b5 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -305,7 +305,7 @@ def set_result(self, result): self._schedule_callbacks() def set_exception(self, exception): - """ Mark the future done and set an exception. + """Mark the future done and set an exception. If the future is already done when this method is called, raises InvalidStateError. From d9611a51dea5accb35550bf824ae2155801ab156 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 6 Aug 2013 15:09:55 -0700 Subject: [PATCH 0512/1502] Opportunistically fix some cases where Future() was called without passing a loop. --- tests/base_events_test.py | 2 +- tests/http_client_functional_test.py | 26 +++++++++++++------------- tulip/base_events.py | 10 +++++----- tulip/http/client.py | 9 ++++++--- tulip/selector_events.py | 8 ++++---- tulip/test_utils.py | 4 ++-- 6 files changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 45a0c30b..f80e8c24 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -148,7 +148,7 @@ def test_run_once_in_executor_plain(self): def cb(): pass h = events.Handle(cb, ()) - f = futures.Future() + f = futures.Future(loop=self.loop) executor = unittest.mock.Mock() executor.submit.return_value = f diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 3abc0076..b87366ec 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -29,7 +29,7 @@ def test_HTTP_200_OK_METHOD(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: for meth in ('get', 'post', 'put', 'delete', 'head'): r = self.loop.run_until_complete( - client.request(meth, httpd.url('method', meth))) + client.request(meth, httpd.url('method', meth), loop=self.loop)) content1 = self.loop.run_until_complete(r.read()) content2 = self.loop.run_until_complete(r.read()) content = content1.decode() @@ -42,7 +42,7 @@ def test_HTTP_200_OK_METHOD(self): def test_HTTP_302_REDIRECT_GET(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( - client.request('get', httpd.url('redirect', 2))) + client.request('get', httpd.url('redirect', 2), loop=self.loop)) self.assertEqual(r.status, 200) self.assertEqual(2, httpd['redirects']) @@ -52,13 +52,13 @@ def test_HTTP_302_REDIRECT_NON_HTTP(self): self.assertRaises( ValueError, self.loop.run_until_complete, - client.request('get', httpd.url('redirect_err'))) + client.request('get', httpd.url('redirect_err'), loop=self.loop)) def test_HTTP_302_REDIRECT_POST(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( client.request('post', httpd.url('redirect', 2), - data={'some': 'data'})) + data={'some': 'data'}, loop=self.loop)) content = self.loop.run_until_complete(r.content.read()) content = content.decode() @@ -70,7 +70,7 @@ def test_HTTP_302_max_redirects(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( client.request('get', httpd.url('redirect', 5), - max_redirects=2)) + max_redirects=2, loop=self.loop)) self.assertEqual(r.status, 302) self.assertEqual(2, httpd['redirects']) @@ -79,7 +79,7 @@ def test_HTTP_200_GET_WITH_PARAMS(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( client.request('get', httpd.url('method', 'get'), - params={'q': 'test'})) + params={'q': 'test'}, loop=self.loop)) content = self.loop.run_until_complete(r.content.read()) content = content.decode() @@ -91,7 +91,7 @@ def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): r = self.loop.run_until_complete( client.request( 'get', httpd.url('method', 'get') + '?test=true', - params={'q': 'test'})) + params={'q': 'test'}, loop=self.loop)) content = self.loop.run_until_complete(r.content.read()) content = content.decode() @@ -102,7 +102,7 @@ def test_POST_DATA(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: url = httpd.url('method', 'post') r = self.loop.run_until_complete( - client.request('post', url, data={'some': 'data'})) + client.request('post', url, data={'some': 'data'}, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) @@ -114,7 +114,7 @@ def test_POST_DATA_DEFLATE(self): url = httpd.url('method', 'post') r = self.loop.run_until_complete( client.request('post', url, - data={'some': 'data'}, compress=True)) + data={'some': 'data'}, compress=True, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) @@ -130,7 +130,7 @@ def test_POST_FILES(self): r = self.loop.run_until_complete( client.request( 'post', url, files={'some': f}, chunked=1024, - headers={'Transfer-Encoding': 'chunked'})) + headers={'Transfer-Encoding': 'chunked'}, loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) f.seek(0) @@ -152,7 +152,7 @@ def test_POST_FILES_DEFLATE(self): with open(__file__) as f: r = self.loop.run_until_complete( client.request('post', url, files={'some': f}, - chunked=1024, compress='deflate')) + chunked=1024, compress='deflate', loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -175,7 +175,7 @@ def test_POST_FILES_STR(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f.read())])) + client.request('post', url, files=[('some', f.read())], loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -195,7 +195,7 @@ def test_POST_FILES_LIST(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f)])) + client.request('post', url, files=[('some', f)], loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) diff --git a/tulip/base_events.py b/tulip/base_events.py index b4ae4595..3cfe6625 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -116,7 +116,7 @@ def run_until_complete(self, future, timeout=None): Return the Future's result, or raise its exception. If the timeout is reached or stop() is called, raise TimeoutError. """ - future = tasks.async(future) + future = tasks.async(future, loop=self) future.add_done_callback(_raise_stop_error) handle_called = False @@ -209,7 +209,7 @@ def run_in_executor(self, executor, callback, *args): assert not args assert not isinstance(callback, events.TimerHandle) if callback._cancelled: - f = futures.Future() + f = futures.Future(loop=self) f.set_result(None) return f callback, args = callback._callback, callback._args @@ -311,7 +311,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.setblocking(False) protocol = protocol_factory() - waiter = futures.Future() + waiter = futures.Future(loop=self) if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( @@ -458,7 +458,7 @@ def start_serving(self, protocol_factory, host=None, port=None, *, @tasks.coroutine def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() - waiter = futures.Future() + waiter = futures.Future(loop=self) transport = self._make_read_pipe_transport(pipe, protocol, waiter, extra={}) yield from waiter @@ -467,7 +467,7 @@ def connect_read_pipe(self, protocol_factory, pipe): @tasks.coroutine def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() - waiter = futures.Future() + waiter = futures.Future(loop=self) transport = self._make_write_pipe_transport(pipe, protocol, waiter, extra={}) yield from waiter diff --git a/tulip/http/client.py b/tulip/http/client.py index f0984f9d..babfd7f6 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -42,7 +42,8 @@ def request(method, url, *, timeout=None, compress=None, chunked=None, - session=None): + session=None, + loop=None): """Constructs and sends a request. Returns response object. method: http method @@ -65,6 +66,7 @@ def request(method, url, *, transfer encoding. session: tulip.http.Session instance to support connection pooling and session cookies. + loop: Optional event loop. Usage: @@ -77,7 +79,8 @@ def request(method, url, *, """ redirects = 0 - loop = tulip.get_event_loop() + if loop is None: + loop = tulip.get_event_loop() while True: req = HttpRequest( @@ -92,7 +95,7 @@ def request(method, url, *, # connection timeout try: - resp = yield from tulip.Task(conn, timeout=timeout) + resp = yield from tulip.Task(conn, timeout=timeout, loop=loop) except tulip.CancelledError: raise tulip.TimeoutError from None diff --git a/tulip/selector_events.py b/tulip/selector_events.py index cdd04623..225af6f3 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -194,7 +194,7 @@ def remove_writer(self, fd): def sock_recv(self, sock, n): """XXX""" - fut = futures.Future() + fut = futures.Future(loop=self) self._sock_recv(fut, False, sock, n) return fut @@ -218,7 +218,7 @@ def _sock_recv(self, fut, registered, sock, n): def sock_sendall(self, sock, data): """XXX""" - fut = futures.Future() + fut = futures.Future(loop=self) if data: self._sock_sendall(fut, False, sock, data) else: @@ -254,7 +254,7 @@ def sock_connect(self, sock, address): # self.getaddrinfo() for you here. But verifying this is # complicated; the socket module doesn't have a pattern for # IPv6 addresses (there are too many forms, apparently). - fut = futures.Future() + fut = futures.Future(loop=self) self._sock_connect(fut, False, sock, address) return fut @@ -288,7 +288,7 @@ def _sock_connect(self, fut, registered, sock, address): def sock_accept(self, sock): """XXX""" - fut = futures.Future() + fut = futures.Future(loop=self) self._sock_accept(fut, False, sock) return fut diff --git a/tulip/test_utils.py b/tulip/test_utils.py index d97dc618..ca924491 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -127,7 +127,7 @@ def run(loop, fut): lambda: TestHttpServer(keep_alive=0.5), host, port, ssl=sslcontext)) - waiter = tulip.Future() + waiter = tulip.Future(loop=thread_loop) loop.call_soon_threadsafe( fut.set_result, (thread_loop, waiter, socks[0].getsockname())) @@ -146,7 +146,7 @@ def run(loop, fut): thread_loop.close() gc.collect() - fut = tulip.Future() + fut = tulip.Future(loop=loop) server_thread = threading.Thread(target=run, args=(loop, fut)) server_thread.start() From eb652d88448a291962a1c831db014da72398c4a1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 6 Aug 2013 16:51:34 -0700 Subject: [PATCH 0513/1502] Break long lines. --- tests/http_client_functional_test.py | 27 ++++++++++++++++++--------- tulip/queues.py | 3 ++- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index b87366ec..1be13b0f 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -29,7 +29,8 @@ def test_HTTP_200_OK_METHOD(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: for meth in ('get', 'post', 'put', 'delete', 'head'): r = self.loop.run_until_complete( - client.request(meth, httpd.url('method', meth), loop=self.loop)) + client.request(meth, httpd.url('method', meth), + loop=self.loop)) content1 = self.loop.run_until_complete(r.read()) content2 = self.loop.run_until_complete(r.read()) content = content1.decode() @@ -42,7 +43,8 @@ def test_HTTP_200_OK_METHOD(self): def test_HTTP_302_REDIRECT_GET(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( - client.request('get', httpd.url('redirect', 2), loop=self.loop)) + client.request('get', httpd.url('redirect', 2), + loop=self.loop)) self.assertEqual(r.status, 200) self.assertEqual(2, httpd['redirects']) @@ -52,7 +54,8 @@ def test_HTTP_302_REDIRECT_NON_HTTP(self): self.assertRaises( ValueError, self.loop.run_until_complete, - client.request('get', httpd.url('redirect_err'), loop=self.loop)) + client.request('get', httpd.url('redirect_err'), + loop=self.loop)) def test_HTTP_302_REDIRECT_POST(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -102,7 +105,8 @@ def test_POST_DATA(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: url = httpd.url('method', 'post') r = self.loop.run_until_complete( - client.request('post', url, data={'some': 'data'}, loop=self.loop)) + client.request('post', url, data={'some': 'data'}, + loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) @@ -114,7 +118,8 @@ def test_POST_DATA_DEFLATE(self): url = httpd.url('method', 'post') r = self.loop.run_until_complete( client.request('post', url, - data={'some': 'data'}, compress=True, loop=self.loop)) + data={'some': 'data'}, compress=True, + loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) @@ -130,7 +135,8 @@ def test_POST_FILES(self): r = self.loop.run_until_complete( client.request( 'post', url, files={'some': f}, chunked=1024, - headers={'Transfer-Encoding': 'chunked'}, loop=self.loop)) + headers={'Transfer-Encoding': 'chunked'}, + loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) f.seek(0) @@ -152,7 +158,8 @@ def test_POST_FILES_DEFLATE(self): with open(__file__) as f: r = self.loop.run_until_complete( client.request('post', url, files={'some': f}, - chunked=1024, compress='deflate', loop=self.loop)) + chunked=1024, compress='deflate', + loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -175,7 +182,8 @@ def test_POST_FILES_STR(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f.read())], loop=self.loop)) + client.request('post', url, files=[('some', f.read())], + loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -195,7 +203,8 @@ def test_POST_FILES_LIST(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f)], loop=self.loop)) + client.request('post', url, files=[('some', f)], + loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) diff --git a/tulip/queues.py b/tulip/queues.py index dfe0cae7..8214d0ec 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -1,6 +1,7 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', 'Full', 'Empty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] import collections import concurrent.futures From 58f09727aca56c7721c32a845cc787839c87298a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 7 Aug 2013 09:10:53 -0700 Subject: [PATCH 0514/1502] Add optional loop parameter to StreamReader. --- examples/child_process.py | 14 ++------- tests/streams_test.py | 62 +++++++++++++++++++-------------------- tulip/streams.py | 15 ++++++---- 3 files changed, 42 insertions(+), 49 deletions(-) diff --git a/examples/child_process.py b/examples/child_process.py index e21a925a..a799fa27 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -38,21 +38,11 @@ def connect_write_pipe(file): def connect_read_pipe(file): loop = tulip.get_event_loop() - stream_reader = streams.StreamReader() - protocol = _StreamReaderProtocol(stream_reader) + stream_reader = streams.StreamReader(loop=loop) + protocol = streams.StreamReaderProtocol(stream_reader) transport = loop._make_read_pipe_transport(file, protocol) return stream_reader -class _StreamReaderProtocol(protocols.Protocol): - def __init__(self, stream_reader): - self.stream_reader = stream_reader - def connection_lost(self, exc): - self.stream_reader.set_exception(exc) - def data_received(self, data): - self.stream_reader.feed_data(data) - def eof_received(self): - self.stream_reader.feed_eof() - # # Example # diff --git a/tests/streams_test.py b/tests/streams_test.py index 669da75d..6022a0c8 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -50,20 +50,20 @@ def test_open_connection_error(self): self.loop.run_until_complete(f) def test_feed_empty_data(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'') self.assertEqual(0, stream.byte_count) def test_feed_data_byte_count(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) self.assertEqual(len(self.DATA), stream.byte_count) def test_read_zero(self): # Read zero bytes. - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.read(0)) @@ -72,8 +72,8 @@ def test_read_zero(self): def test_read(self): # Read bytes. - stream = streams.StreamReader() - read_task = tasks.Task(stream.read(30)) + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(30), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -85,7 +85,7 @@ def cb(): def test_read_line_breaks(self): # Read bytes without line breaks. - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -96,8 +96,8 @@ def test_read_line_breaks(self): def test_read_eof(self): # Read bytes, stop at eof. - stream = streams.StreamReader() - read_task = tasks.Task(stream.read(1024)) + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(1024), loop=self.loop) def cb(): stream.feed_eof() @@ -109,8 +109,8 @@ def cb(): def test_read_until_eof(self): # Read all bytes until eof. - stream = streams.StreamReader() - read_task = tasks.Task(stream.read(-1)) + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(-1), loop=self.loop) def cb(): stream.feed_data(b'chunk1\n') @@ -124,7 +124,7 @@ def cb(): self.assertFalse(stream.byte_count) def test_read_exception(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.read(2)) @@ -136,9 +136,9 @@ def test_read_exception(self): def test_readline(self): # Read one line. - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'chunk1 ') - read_task = tasks.Task(stream.readline()) + read_task = tasks.Task(stream.readline(), loop=self.loop) def cb(): stream.feed_data(b'chunk2 ') @@ -151,7 +151,7 @@ def cb(): self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) def test_readline_limit_with_existing_data(self): - stream = streams.StreamReader(3) + stream = streams.StreamReader(3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -159,7 +159,7 @@ def test_readline_limit_with_existing_data(self): ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual([b'line2\n'], list(stream.buffer)) - stream = streams.StreamReader(3) + stream = streams.StreamReader(3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -170,7 +170,7 @@ def test_readline_limit_with_existing_data(self): self.assertEqual(2, stream.byte_count) def test_readline_limit(self): - stream = streams.StreamReader(7) + stream = streams.StreamReader(7, loop=self.loop) def cb(): stream.feed_data(b'chunk1') @@ -185,7 +185,7 @@ def cb(): self.assertEqual(7, stream.byte_count) def test_readline_line_byte_count(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -195,7 +195,7 @@ def test_readline_line_byte_count(self): self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) def test_readline_eof(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'some data') stream.feed_eof() @@ -203,14 +203,14 @@ def test_readline_eof(self): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_eof() line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) self.loop.run_until_complete(stream.readline()) @@ -223,7 +223,7 @@ def test_readline_read_byte_count(self): stream.byte_count) def test_readline_exception(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readline()) @@ -235,7 +235,7 @@ def test_readline_exception(self): def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.readexactly(0)) @@ -248,10 +248,10 @@ def test_readexactly_zero_or_less(self): def test_readexactly(self): # Read exact number of bytes. - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) n = 2 * len(self.DATA) - read_task = tasks.Task(stream.readexactly(n)) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -265,9 +265,9 @@ def cb(): def test_readexactly_eof(self): # Read exact number of bytes (eof). - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) n = 2 * len(self.DATA) - read_task = tasks.Task(stream.readexactly(n)) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -279,7 +279,7 @@ def cb(): self.assertFalse(stream.byte_count) def test_readexactly_exception(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readexactly(2)) @@ -290,7 +290,7 @@ def test_readexactly_exception(self): ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) self.assertIsNone(stream.exception()) exc = ValueError() @@ -298,7 +298,7 @@ def test_exception(self): self.assertIs(stream.exception(), exc) def test_exception_waiter(self): - stream = streams.StreamReader() + stream = streams.StreamReader(loop=self.loop) @tasks.coroutine def set_err(): @@ -308,8 +308,8 @@ def set_err(): def readline(): yield from stream.readline() - t1 = tasks.Task(stream.readline()) - t2 = tasks.Task(set_err()) + t1 = tasks.Task(stream.readline(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) self.loop.run_until_complete(tasks.wait([t1, t2])) diff --git a/tulip/streams.py b/tulip/streams.py index 3aec9ea6..511ac2e0 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -35,7 +35,7 @@ def open_connection(host=None, port=None, *, """ if loop is None: loop = events.get_event_loop() - reader = StreamReader(limit=limit) + reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader) transport, _ = yield from loop.create_connection( lambda: protocol, host, port, **kwds) @@ -69,8 +69,11 @@ def eof_received(self): class StreamReader: - def __init__(self, limit=_DEFAULT_LIMIT): + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): self.limit = limit # Max line length. (Security feature.) + if loop is None: + loop = events.get_event_loop() + self.loop = loop self.buffer = collections.deque() # Deque of bytes objects. self.byte_count = 0 # Bytes in buffer. self.eof = False # Whether we're done. @@ -141,7 +144,7 @@ def readline(self): if not_enough: assert self.waiter is None - self.waiter = futures.Future() + self.waiter = futures.Future(loop=self.loop) yield from self.waiter line = b''.join(parts) @@ -160,12 +163,12 @@ def read(self, n=-1): if n < 0: while not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = futures.Future(loop=self.loop) yield from self.waiter else: if not self.byte_count and not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = futures.Future(loop=self.loop) yield from self.waiter if n < 0 or self.byte_count <= n: @@ -200,7 +203,7 @@ def readexactly(self, n): while self.byte_count < n and not self.eof: assert not self.waiter - self.waiter = futures.Future() + self.waiter = futures.Future(loop=self.loop) yield from self.waiter return (yield from self.read(n)) From 4967fcbb25635abb2390c6acd0d01846f6a1475f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 7 Aug 2013 09:15:19 -0700 Subject: [PATCH 0515/1502] Make StreamReader.feed_*() robust if reading task is cancelled. --- tulip/streams.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tulip/streams.py b/tulip/streams.py index 511ac2e0..3203b7d6 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -96,7 +96,8 @@ def feed_eof(self): waiter = self.waiter if waiter is not None: self.waiter = None - waiter.set_result(True) + if not waiter.done(): + waiter.set_result(True) def feed_data(self, data): if not data: @@ -108,7 +109,8 @@ def feed_data(self, data): waiter = self.waiter if waiter is not None: self.waiter = None - waiter.set_result(False) + if not waiter.done(): + waiter.set_result(False) @tasks.coroutine def readline(self): From 7087a3214f114b5310c5679f6b8f5dbad6763559 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 7 Aug 2013 10:37:35 -0700 Subject: [PATCH 0516/1502] Add optional loop parameter to StreamBuffer and DataBuffer. --- tulip/parsers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tulip/parsers.py b/tulip/parsers.py index f5a7845a..43ddc2e9 100644 --- a/tulip/parsers.py +++ b/tulip/parsers.py @@ -81,7 +81,8 @@ class StreamBuffer: unset_parser() sends EofStream into parser and then removes it. """ - def __init__(self): + def __init__(self, *, loop=None): + self._loop = loop self._buffer = ParserBuffer() self._eof = False self._parser = None @@ -139,7 +140,7 @@ def set_parser(self, p): if self._parser: self.unset_parser() - out = DataBuffer() + out = DataBuffer(loop=self._loop) if self._exception: out.set_exception(self._exception) return out @@ -205,7 +206,8 @@ def connection_lost(self, exc): class DataBuffer: """DataBuffer is a destination for parsed data.""" - def __init__(self): + def __init__(self, *, loop=None): + self._loop = loop self._buffer = collections.deque() self._eof = False self._waiter = None @@ -220,7 +222,7 @@ def set_exception(self, exc): waiter = self._waiter if waiter is not None: self._waiter = None - if not waiter.cancelled(): + if not waiter.done(): waiter.set_exception(exc) def feed_data(self, data): @@ -246,7 +248,7 @@ def read(self): if not self._buffer and not self._eof: assert not self._waiter - self._waiter = futures.Future() + self._waiter = futures.Future(loop=self._loop) yield from self._waiter if self._buffer: From 732d779b2bc803ea4e1d9c6fb146e32fc5483034 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 8 Aug 2013 14:50:59 +0300 Subject: [PATCH 0517/1502] Add disconnection notification and pause for write pipes. --- tests/events_test.py | 37 ++++++ tests/unix_events_test.py | 268 ++++++++++++++++++++++++++++---------- tulip/unix_events.py | 86 +++++++++--- 3 files changed, 300 insertions(+), 91 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index f5a509e7..3c198ebf 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -831,6 +831,43 @@ def connect(): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(create_future=True) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.task + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + def test_prompt_cancellation(self): r, w = test_utils.socketpair() r.setblocking(False) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index e5304f71..8ad66308 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -2,6 +2,8 @@ import errno import io +import stat +import tempfile import unittest import unittest.mock @@ -194,16 +196,18 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor(self, m_fcntl): + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) self.loop.add_reader.assert_called_with(5, tr._read_ready) self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor_with_waiter(self, m_fcntl): + def test_ctor_with_waiter(self): fut = futures.Future() unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol, fut) @@ -211,8 +215,7 @@ def test_ctor_with_waiter(self, m_fcntl): fut.cancel() @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready(self, m_fcntl, m_read): + def test__read_ready(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) m_read.return_value = b'data' @@ -222,8 +225,7 @@ def test__read_ready(self, m_fcntl, m_read): self.protocol.data_received.assert_called_with(b'data') @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_eof(self, m_fcntl, m_read): + def test__read_ready_eof(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) m_read.return_value = b'' @@ -234,8 +236,7 @@ def test__read_ready_eof(self, m_fcntl, m_read): self.protocol.eof_received.assert_called_with() @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_blocked(self, m_fcntl, m_read): + def test__read_ready_blocked(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) self.loop.reset_mock() @@ -247,8 +248,7 @@ def test__read_ready_blocked(self, m_fcntl, m_read): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__read_ready_error(self, m_fcntl, m_read, m_logexc): + def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) err = OSError() @@ -261,8 +261,7 @@ def test__read_ready_error(self, m_fcntl, m_read, m_logexc): m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_pause(self, m_fcntl, m_read): + def test_pause(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -270,8 +269,7 @@ def test_pause(self, m_fcntl, m_read): self.loop.remove_reader.assert_called_with(5) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_resume(self, m_fcntl, m_read): + def test_resume(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -279,8 +277,7 @@ def test_resume(self, m_fcntl, m_read): self.loop.add_reader.assert_called_with(5, tr._read_ready) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_close(self, m_fcntl, m_read): + def test_close(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -289,8 +286,7 @@ def test_close(self, m_fcntl, m_read): tr._close.assert_called_with(None) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test_close_already_closing(self, m_fcntl, m_read): + def test_close_already_closing(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -300,8 +296,7 @@ def test_close_already_closing(self, m_fcntl, m_read): self.assertFalse(tr._close.called) @unittest.mock.patch('os.read') - @unittest.mock.patch('fcntl.fcntl') - def test__close(self, m_fcntl, m_read): + def test__close(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -311,8 +306,7 @@ def test__close(self, m_fcntl, m_read): self.loop.remove_reader.assert_called_with(5) self.loop.call_soon.assert_called_with(tr._call_connection_lost, err) - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost(self, m_fcntl): + def test__call_connection_lost(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -321,8 +315,7 @@ def test__call_connection_lost(self, m_fcntl): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost_with_err(self, m_fcntl): + def test__call_connection_lost_with_err(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -340,30 +333,41 @@ def setUp(self): self.pipe.fileno.return_value = 5 self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor(self, m_fcntl): + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_reader.assert_called_with(5, tr._read_ready) self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) + self.assertTrue(tr._enable_read_hack) - @unittest.mock.patch('fcntl.fcntl') - def test_ctor_with_waiter(self, m_fcntl): + def test_ctor_with_waiter(self): fut = futures.Future() - unix_events._UnixWritePipeTransport( + tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol, fut) self.loop.call_soon.assert_called_with(fut.set_result, None) + self.loop.add_reader.assert_called_with(5, tr._read_ready) + self.assertTrue(tr._enable_read_hack) fut.cancel() - @unittest.mock.patch('fcntl.fcntl') - def test_can_write_eof(self, m_fcntl): + def test_can_write_eof(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) self.assertTrue(tr.can_write_eof()) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write(self, m_fcntl, m_write): + def test_write(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -374,8 +378,7 @@ def test_write(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_no_data(self, m_fcntl, m_write): + def test_write_no_data(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -385,8 +388,7 @@ def test_write_no_data(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_partial(self, m_fcntl, m_write): + def test_write_partial(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -397,8 +399,7 @@ def test_write_partial(self, m_fcntl, m_write): self.assertEqual([b'ta'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_buffer(self, m_fcntl, m_write): + def test_write_buffer(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -409,8 +410,7 @@ def test_write_buffer(self, m_fcntl, m_write): self.assertEqual([b'previous', b'data'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_again(self, m_fcntl, m_write): + def test_write_again(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -422,8 +422,7 @@ def test_write_again(self, m_fcntl, m_write): @unittest.mock.patch('tulip.unix_events.tulip_log') @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_write_err(self, m_fcntl, m_write, m_log): + def test_write_err(self, m_write, m_log): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -446,9 +445,18 @@ def test_write_err(self, m_fcntl, m_write, m_log): m_log.warning.assert_called_with( 'os.write(pipe, data) raised exception.') + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.loop.remove_writer.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) + self.assertTrue(tr._closing) + self.loop.call_soon.assert_called_with(tr._call_connection_lost, + None) + @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready(self, m_fcntl, m_write): + def test__write_ready(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] @@ -459,8 +467,7 @@ def test__write_ready(self, m_fcntl, m_write): self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_partial(self, m_fcntl, m_write): + def test__write_ready_partial(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -472,8 +479,7 @@ def test__write_ready_partial(self, m_fcntl, m_write): self.assertEqual([b'a'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_again(self, m_fcntl, m_write): + def test__write_ready_again(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -485,8 +491,7 @@ def test__write_ready_again(self, m_fcntl, m_write): self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_empty(self, m_fcntl, m_write): + def test__write_ready_empty(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -499,8 +504,7 @@ def test__write_ready_empty(self, m_fcntl, m_write): @unittest.mock.patch('tulip.log.tulip_log.exception') @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_err(self, m_fcntl, m_write, m_logexc): + def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -509,6 +513,7 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): tr._write_ready() m_write.assert_called_with(5, b'data') self.loop.remove_writer.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) self.loop.call_soon.assert_called_with( @@ -517,8 +522,7 @@ def test__write_ready_err(self, m_fcntl, m_write, m_logexc): self.assertEqual(1, tr._conn_lost) @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test__write_ready_closing(self, m_fcntl, m_write): + def test__write_ready_closing(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -528,13 +532,13 @@ def test__write_ready_closing(self, m_fcntl, m_write): tr._write_ready() m_write.assert_called_with(5, b'data') self.loop.remove_writer.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() @unittest.mock.patch('os.write') - @unittest.mock.patch('fcntl.fcntl') - def test_abort(self, m_fcntl, m_write): + def test_abort(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -542,13 +546,13 @@ def test_abort(self, m_fcntl, m_write): tr.abort() self.assertFalse(m_write.called) self.loop.remove_writer.assert_called_with(5) + self.loop.remove_reader.assert_called_with(5) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) self.loop.call_soon.assert_called_with( tr._call_connection_lost, None) - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost(self, m_fcntl): + def test__call_connection_lost(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -557,8 +561,7 @@ def test__call_connection_lost(self, m_fcntl): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test__call_connection_lost_with_err(self, m_fcntl): + def test__call_connection_lost_with_err(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -567,8 +570,7 @@ def test__call_connection_lost_with_err(self, m_fcntl): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test_close(self, m_fcntl): + def test_close(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -576,8 +578,7 @@ def test_close(self, m_fcntl): tr.close() tr.write_eof.assert_called_with() - @unittest.mock.patch('fcntl.fcntl') - def test_close_closing(self, m_fcntl): + def test_close_closing(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -586,21 +587,146 @@ def test_close_closing(self, m_fcntl): tr.close() self.assertFalse(tr.write_eof.called) - @unittest.mock.patch('fcntl.fcntl') - def test_write_eof(self, m_fcntl): + def test_write_eof(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) tr.write_eof() self.assertTrue(tr._closing) + self.loop.remove_reader.assert_called_with(5) self.loop.call_soon.assert_called_with( tr._call_connection_lost, None) - @unittest.mock.patch('fcntl.fcntl') - def test_write_eof_pending(self, m_fcntl): + def test_write_eof_pending(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) tr._buffer = [b'data'] tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.protocol.connection_lost.called) + + def test_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_double_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_pause_resume_writing_with_nonempty_buffer(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.assertFalse(tr._writing) + self.loop.remove_writer.assert_called_with(5) + self.assertEqual([b'da', b'ta'], tr._buffer) + + tr.resume_writing() + self.assertTrue(tr._writing) + self.loop.add_writer.assert_called_with(5, tr._write_ready) + self.assertEqual([b'da', b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_on_pause(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.loop.remove_writer.reset_mock() + tr._write_ready() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.remove_writer.called) + self.assertEqual([b'da', b'ta'], tr._buffer) + self.assertFalse(tr._writing) + + def test_discard_output(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + tr.discard_output() + self.assertTrue(tr._writing) + self.loop.remove_writer.assert_called_with(5) + self.assertEqual([], tr._buffer) + + def test_discard_output_without_pending_writes(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.remove_writer.called) + self.assertEqual([], tr._buffer) + + +class UnixWritePipeRegularFileTests(unittest.TestCase): + + def setUp(self): + self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) + + def test_ctor_with_regular_file(self): + with tempfile.TemporaryFile() as f: + tr = unix_events._UnixWritePipeTransport(self.loop, f, + self.protocol) + self.assertFalse(self.loop.add_reader.called) + self.loop.call_soon.assert_called_with( + self.protocol.connection_made, tr) + self.assertFalse(tr._enable_read_hack) + + def test_write_eof(self): + with tempfile.TemporaryFile() as f: + tr = unix_events._UnixWritePipeTransport( + self.loop, f, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.remove_reader.called) + self.loop.call_soon.assert_called_with( + tr._call_connection_lost, None) + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + with tempfile.TemporaryFile() as f: + fileno = f.fileno() + tr = unix_events._UnixWritePipeTransport( + self.loop, f, self.protocol) + + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(fileno, b'data') + self.loop.remove_writer.assert_called_with(fileno) + self.assertFalse(self.loop.remove_reader.called) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(f.closed) + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + with tempfile.TemporaryFile() as f: + fileno = f.fileno() + tr = unix_events._UnixWritePipeTransport( + self.loop, f, self.protocol) + + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.loop.remove_writer.assert_called_with(fileno) + self.assertFalse(self.loop.remove_reader.called) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + self.loop.call_soon.assert_called_with( + tr._call_connection_lost, None) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index b5825950..563ff6c3 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -4,6 +4,7 @@ import fcntl import os import socket +import stat import sys try: @@ -149,19 +150,19 @@ class _UnixReadPipeTransport(transports.ReadTransport): max_size = 256 * 1024 # max bytes we read in one eventloop iteration - def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): super().__init__(extra) self._extra['pipe'] = pipe - self._event_loop = event_loop + self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() _set_nonblocking(self._fileno) self._protocol = protocol self._closing = False - self._event_loop.add_reader(self._fileno, self._read_ready) - self._event_loop.call_soon(self._protocol.connection_made, self) + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._event_loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter.set_result, None) def _read_ready(self): try: @@ -174,14 +175,14 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: - self._event_loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) self._protocol.eof_received() def pause(self): - self._event_loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) def resume(self): - self._event_loop.add_reader(self._fileno, self._read_ready) + self._loop.add_reader(self._fileno, self._read_ready) def close(self): if not self._closing: @@ -194,8 +195,8 @@ def _fatal_error(self, exc): def _close(self, exc): self._closing = True - self._event_loop.remove_reader(self._fileno) - self._event_loop.call_soon(self._call_connection_lost, exc) + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: @@ -206,10 +207,10 @@ def _call_connection_lost(self, exc): class _UnixWritePipeTransport(transports.WriteTransport): - def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): super().__init__(extra) self._extra['pipe'] = pipe - self._event_loop = event_loop + self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() _set_nonblocking(self._fileno) @@ -217,9 +218,28 @@ def __init__(self, event_loop, pipe, protocol, waiter=None, extra=None): self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. - self._event_loop.call_soon(self._protocol.connection_made, self) + self._writing = True + # Do nothing if it is a regular file. + # Enable hack only if pipe is FIFO object. + # Look on twisted.internet.process:ProcessWriter.__init__ + if stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + self._enable_read_hack = True + else: + # If the pipe is not a unix pipe, then the read hack is never + # applicable. This case arises when _UnixWritePipeTransport + # is used by subprocess and stdout/stderr + # are redirected to a normal file. + self._enable_read_hack = False + if self._enable_read_hack: + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._event_loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + self._close() def write(self, data): assert isinstance(data, bytes), repr(data) @@ -233,7 +253,7 @@ def write(self, data): self._conn_lost += 1 return - if not self._buffer: + if not self._buffer and self._writing: # Attempt to send it right away first. try: n = os.write(self._fileno, data) @@ -247,11 +267,14 @@ def write(self, data): return elif n > 0: data = data[n:] - self._event_loop.add_writer(self._fileno, self._write_ready) + self._loop.add_writer(self._fileno, self._write_ready) self._buffer.append(data) def _write_ready(self): + if not self._writing: + return + data = b''.join(self._buffer) assert data, 'Data should not be empty' @@ -265,8 +288,10 @@ def _write_ready(self): self._fatal_error(exc) else: if n == len(data): - self._event_loop.remove_writer(self._fileno) + self._loop.remove_writer(self._fileno) if self._closing: + if self._enable_read_hack: + self._loop.remove_reader(self._fileno) self._call_connection_lost(None) return elif n > 0: @@ -282,7 +307,9 @@ def write_eof(self): assert self._pipe self._closing = True if not self._buffer: - self._event_loop.call_soon(self._call_connection_lost, None) + if self._enable_read_hack: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) def close(self): if not self._closing: @@ -300,11 +327,30 @@ def _fatal_error(self, exc): def _close(self, exc=None): self._closing = True self._buffer.clear() - self._event_loop.remove_writer(self._fileno) - self._event_loop.call_soon(self._call_connection_lost, exc) + self._loop.remove_writer(self._fileno) + if self._enable_read_hack: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): try: self._protocol.connection_lost(exc) finally: self._pipe.close() + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._fileno) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._fileno, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() From 5dd85a00cffa5e757175af0e9c9af5e421abf7db Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 8 Aug 2013 08:04:48 -0700 Subject: [PATCH 0518/1502] Add loop=self.loop to all Future() constructors in tasks_test.py. --- tests/tasks_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 7a91d0fc..2d166ff2 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -63,7 +63,7 @@ def notmuch(): self.assertEqual(t.result(), 'ko') def test_task_decorator_fut(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) fut.set_result('ko') @tasks.task @@ -89,7 +89,7 @@ def notmuch(): self.assertIs(t._loop, loop) def test_async_future(self): - f_orig = futures.Future() + f_orig = futures.Future(loop=self.loop) f_orig.set_result('ko') f = tasks.async(f_orig) @@ -206,9 +206,9 @@ def task(): self.assertFalse(t.cancel()) def test_cancel_done_future(self): - fut1 = futures.Future() - fut2 = futures.Future() - fut3 = futures.Future() + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) @tasks.task def task(): @@ -627,7 +627,7 @@ def doit(): self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) def test_task_cancel_waiter_future(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) @tasks.task def coro(): @@ -747,7 +747,7 @@ def fn2(): self.assertTrue(tasks.iscoroutinefunction(fn2)) def test_yield_vs_yield_from(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) @tasks.task def wait_for_future(): @@ -788,7 +788,7 @@ def func(): self.assertEqual(res, 'test') def test_coroutine_non_gen_function_return_future(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) @tasks.coroutine def func(): From cdef58736355964efd62c4933d1c0196f0dc6d56 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 8 Aug 2013 10:33:20 -0700 Subject: [PATCH 0519/1502] Fix for examples/child_process.py by Gustavo Carneiro . --- examples/child_process.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/child_process.py b/examples/child_process.py index a799fa27..d4a035bd 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -27,21 +27,26 @@ # Return a write-only transport wrapping a writable pipe # +@tulip.coroutine def connect_write_pipe(file): loop = tulip.get_event_loop() protocol = protocols.Protocol() - return loop._make_write_pipe_transport(file, protocol) + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport # # Wrap a readable pipe in a stream # +@tulip.coroutine def connect_read_pipe(file): loop = tulip.get_event_loop() stream_reader = streams.StreamReader(loop=loop) - protocol = streams.StreamReaderProtocol(stream_reader) - transport = loop._make_read_pipe_transport(file, protocol) - return stream_reader + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + # # Example @@ -76,9 +81,10 @@ def writeall(fd, buf): # start subprocess and wrap stdin, stdout, stderr p = Popen([sys.executable, '-c', code], stdin=PIPE, stdout=PIPE, stderr=PIPE) - stdin = connect_write_pipe(p.stdin) - stdout = connect_read_pipe(p.stdout) - stderr = connect_read_pipe(p.stderr) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) # interact with subprocess name = {stdout:'OUT', stderr:'ERR'} @@ -108,6 +114,8 @@ def writeall(fd, buf): registered[tulip.Task(stream.readline())] = stream timeout = 0.0 + stdout_transport.close() + stderr_transport.close() if __name__ == '__main__': if sys.platform == 'win32': From 627ccd83a3b06bae5cbb8541590b3ed8dcf27b9f Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 9 Aug 2013 03:47:58 +0300 Subject: [PATCH 0520/1502] Add --failfast parameter to runtests.py --- runtests.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index cd334423..e254fdc1 100644 --- a/runtests.py +++ b/runtests.py @@ -39,6 +39,9 @@ nargs='?', const=1, type=int, default=0, help='verbose') ARGS.add_argument( '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') ARGS.add_argument( '-q', action="store_true", dest='quiet', help='quiet') ARGS.add_argument( @@ -138,6 +141,7 @@ def runtests(): includes = args.pattern v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() @@ -151,7 +155,7 @@ def runtests(): logger.setLevel(logging.INFO) elif v >= 4: logger.setLevel(logging.DEBUG) - result = unittest.TextTestRunner(verbosity=v).run(tests) + result = unittest.TextTestRunner(verbosity=v, failfast=failfast).run(tests) sys.exit(not result.wasSuccessful()) From 2adb05703c1e1997d7358bb857d84044033cda42 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 9 Aug 2013 11:33:05 -0700 Subject: [PATCH 0521/1502] Stop relying on get_event_loop() where possible. It now raises when there is no loop. --- tests/base_events_test.py | 48 ++++-- tests/events_test.py | 99 ++++++------ tests/futures_test.py | 64 ++++---- tests/http_client_functional_test.py | 1 + tests/http_client_test.py | 2 + tests/http_server_test.py | 1 + tests/http_session_test.py | 1 + tests/http_wsgi_test.py | 1 + tests/locks_test.py | 222 +++++++++++++++----------- tests/parsers_test.py | 36 ++--- tests/proactor_events_test.py | 12 +- tests/queues_test.py | 110 +++++++------ tests/selector_events_test.py | 52 +++--- tests/streams_test.py | 14 +- tests/subprocess_test.py | 11 +- tests/tasks_test.py | 230 ++++++++++++++------------- tests/unix_events_test.py | 6 +- tests/windows_events_test.py | 3 +- tulip/base_events.py | 40 ++++- tulip/events.py | 12 +- tulip/subprocess_transport.py | 6 +- tulip/test_utils.py | 5 +- 22 files changed, 557 insertions(+), 419 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index f80e8c24..84f26a2d 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -20,7 +20,7 @@ def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = unittest.mock.Mock() self.loop._selector.registered_count.return_value = 1 - events.set_event_loop(self.loop) + events.set_event_loop(None) def test_not_implemented(self): m = unittest.mock.Mock() @@ -314,7 +314,7 @@ class BaseEventLoopWithSelectorTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -331,6 +331,9 @@ def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('107.6.106.82', 80)), (2, 1, 6, '', ('107.6.106.82', 80))] + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + idx = -1 errors = ['err1', 'err2'] @@ -342,7 +345,7 @@ def _socket(*args, **kw): m_socket.socket = _socket m_socket.error = socket.error - self.loop.getaddrinfo = getaddrinfo + self.loop.getaddrinfo = getaddrinfo_task task = tasks.Task( self.loop.create_connection(MyProto, 'example.com', 80)) @@ -360,20 +363,24 @@ def test_create_connection_no_host_port_sock(self): self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): - @tasks.task + @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( socket.error, self.loop.run_until_complete, coro) def test_create_connection_connect_err(self): - @tasks.task + @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error @@ -381,12 +388,14 @@ def getaddrinfo(*args, **kw): self.assertRaises( socket.error, self.loop.run_until_complete, coro) - def test_create_connection_mutiple(self): - @tasks.task + def test_create_connection_multiple(self): + @tasks.coroutine def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error @@ -407,11 +416,13 @@ def bind(addr): m_socket.socket.return_value.bind = bind - @tasks.task + @tasks.coroutine def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error('Err2') @@ -425,14 +436,16 @@ def getaddrinfo(*args, **kw): self.assertTrue(m_socket.socket.return_value.close.called) def test_create_connection_no_local_addr(self): - @tasks.task + @tasks.coroutine def getaddrinfo(host, *args, **kw): if host == 'example.com': return [(2, 1, 6, '', ('107.6.106.82', 80)), (2, 1, 6, '', ('107.6.106.82', 80))] else: return [] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET, @@ -444,13 +457,16 @@ def test_start_serving_empty_host(self): # if host is empty string use None instead host = object() - @tasks.task + @tasks.coroutine def getaddrinfo(*args, **kw): nonlocal host host = args[0] yield from [] - self.loop.getaddrinfo = getaddrinfo + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task fut = self.loop.start_serving(MyProto, '', 0) self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertIsNone(host) diff --git a/tests/events_test.py b/tests/events_test.py index 3c198ebf..a34fd8a5 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -31,11 +31,11 @@ class MyProto(protocols.Protocol): done = None - def __init__(self, create_future=False): + def __init__(self, loop=None): self.state = 'INITIAL' self.nbytes = 0 - if create_future: - self.done = futures.Future() + if loop is not None: + self.done = futures.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -61,11 +61,11 @@ def connection_lost(self, exc): class MyDatagramProto(protocols.DatagramProtocol): done = None - def __init__(self, create_future=False): + def __init__(self, loop=None): self.state = 'INITIAL' self.nbytes = 0 - if create_future: - self.done = futures.Future() + if loop is not None: + self.done = futures.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -89,12 +89,12 @@ def connection_lost(self, exc): class MyReadPipeProto(protocols.Protocol): done = None - def __init__(self, create_future=False): + def __init__(self, loop=None): self.state = ['INITIAL'] self.nbytes = 0 self.transport = None - if create_future: - self.done = futures.Future() + if loop is not None: + self.done = futures.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -120,11 +120,11 @@ def connection_lost(self, exc): class MyWritePipeProto(protocols.Protocol): done = None - def __init__(self, create_future=False): + def __init__(self, loop=None): self.state = 'INITIAL' self.transport = None - if create_future: - self.done = futures.Future() + if loop is not None: + self.done = futures.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -143,7 +143,7 @@ class EventLoopTestsMixin: def setUp(self): super().setUp() self.loop = self.create_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): # just in case if we have transport close callbacks @@ -168,22 +168,22 @@ def coro2(): def test_run_until_complete(self): t0 = self.loop.time() - self.loop.run_until_complete(tasks.sleep(0.010)) + self.loop.run_until_complete(tasks.sleep(0.010, loop=self.loop)) t1 = self.loop.time() self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) def test_run_until_complete_stopped(self): - @tasks.task + @tasks.coroutine def cb(): self.loop.stop() - yield from tasks.sleep(0.010) + yield from tasks.sleep(0.010, loop=self.loop) task = cb() self.assertRaises(RuntimeError, self.loop.run_until_complete, task) def test_run_until_complete_timeout(self): t0 = self.loop.time() - task = tasks.async(tasks.sleep(0.020)) + task = tasks.async(tasks.sleep(0.020, loop=self.loop), loop=self.loop) self.assertRaises(futures.TimeoutError, self.loop.run_until_complete, task, timeout=0.010) @@ -432,7 +432,7 @@ def my_handler(*args): def test_create_connection(self): with test_utils.run_test_server(self.loop) as httpd: f = self.loop.create_connection( - lambda: MyProto(create_future=True), *httpd.address) + lambda: MyProto(loop=self.loop), *httpd.address) tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -460,7 +460,7 @@ def test_create_connection_sock(self): assert False, 'Can not create socket.' f = self.loop.create_connection( - lambda: MyProto(create_future=True), sock=sock) + lambda: MyProto(loop=self.loop), sock=sock) tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -473,7 +473,7 @@ def test_create_ssl_connection(self): with test_utils.run_test_server( self.loop, use_ssl=True) as httpd: f = self.loop.create_connection( - lambda: MyProto(create_future=True), *httpd.address, ssl=True) + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -488,7 +488,7 @@ def test_create_connection_local_addr(self): with test_utils.run_test_server(self.loop) as httpd: port = find_unused_port() f = self.loop.create_connection( - lambda: MyProto(create_future=True), + lambda: MyProto(loop=self.loop), *httpd.address, local_addr=(httpd.address[0], port)) tr, pr = self.loop.run_until_complete(f) expected = pr.transport.get_extra_info('socket').getsockname()[1] @@ -498,7 +498,7 @@ def test_create_connection_local_addr(self): def test_create_connection_local_addr_in_use(self): with test_utils.run_test_server(self.loop) as httpd: f = self.loop.create_connection( - lambda: MyProto(create_future=True), + lambda: MyProto(loop=self.loop), *httpd.address, local_addr=httpd.address) with self.assertRaises(socket.error) as cm: self.loop.run_until_complete(f) @@ -562,7 +562,7 @@ def connection_made(self, transport): def factory(): nonlocal proto - proto = MyProto(create_future=True) + proto = MyProto(loop=self.loop) return proto here = os.path.dirname(__file__) @@ -608,7 +608,7 @@ def factory(): self.loop.stop_serving(sock) def test_start_serving_sock(self): - proto = futures.Future() + proto = futures.Future(loop=self.loop) class TestMyProto(MyProto): def connection_made(self, transport): @@ -650,7 +650,7 @@ def test_start_serving_addr_in_use(self): @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') def test_start_serving_dual_stack(self): - f_proto = futures.Future() + f_proto = futures.Future(loop=self.loop) class TestMyProto(MyProto): def connection_made(self, transport): @@ -667,7 +667,7 @@ def connection_made(self, transport): proto.transport.close() client.close() - f_proto = futures.Future() + f_proto = futures.Future(loop=self.loop) client = socket.socket(socket.AF_INET6) client.connect(('::1', port)) client.send(b'xxx') @@ -698,8 +698,8 @@ def test_stop_serving(self): def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): - def __init__(self): - super().__init__(create_future=True) + def __init__(inner_self): + super().__init__(loop=self.loop) def datagram_received(self, data, addr): super().datagram_received(data, addr) @@ -711,7 +711,7 @@ def datagram_received(self, data, addr): host, port = s_transport.get_extra_info('addr') coro = self.loop.create_datagram_endpoint( - lambda: MyDatagramProto(create_future=True), + lambda: MyDatagramProto(loop=self.loop), remote_addr=(host, port)) transport, client = self.loop.run_until_complete(coro) @@ -753,13 +753,13 @@ def test_read_pipe(self): def factory(): nonlocal proto - proto = MyReadPipeProto(create_future=True) + proto = MyReadPipeProto(loop=self.loop) return proto rpipe, wpipe = os.pipe() pipeobj = io.open(rpipe, 'rb', 1024) - @tasks.task + @tasks.coroutine def connect(): t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) self.assertIs(p, proto) @@ -793,13 +793,13 @@ def test_write_pipe(self): def factory(): nonlocal proto - proto = MyWritePipeProto(create_future=True) + proto = MyWritePipeProto(loop=self.loop) return proto rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) - @tasks.task + @tasks.coroutine def connect(): nonlocal transport t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) @@ -839,13 +839,13 @@ def test_write_pipe_disconnect_on_close(self): def factory(): nonlocal proto - proto = MyWritePipeProto(create_future=True) + proto = MyWritePipeProto(loop=self.loop) return proto rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) - @tasks.task + @tasks.coroutine def connect(): nonlocal transport t, p = yield from self.loop.connect_write_pipe(factory, @@ -888,7 +888,7 @@ def main(): return res start = time.monotonic() - t = tasks.Task(main(), timeout=1) + t = tasks.Task(main(), timeout=1, loop=self.loop) self.loop.run_forever() elapsed = time.monotonic() - start @@ -1185,33 +1185,38 @@ def test_get_event_loop(self): self.assertIs(loop, policy.get_event_loop()) loop.close() + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + @unittest.mock.patch('tulip.events.threading') def test_get_event_loop_thread(self, m_threading): m_t = m_threading.current_thread.return_value = unittest.mock.Mock() m_t.name = 'Thread 1' policy = events.DefaultEventLoopPolicy() - self.assertIsNone(policy.get_event_loop()) + self.assertRaises(AssertionError, policy.get_event_loop) def test_new_event_loop(self): policy = events.DefaultEventLoopPolicy() - event_loop = policy.new_event_loop() - self.assertIsInstance(event_loop, events.AbstractEventLoop) - event_loop.close() + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() def test_set_event_loop(self): policy = events.DefaultEventLoopPolicy() - old_event_loop = policy.get_event_loop() + old_loop = policy.get_event_loop() self.assertRaises(AssertionError, policy.set_event_loop, object()) - event_loop = policy.new_event_loop() - policy.set_event_loop(event_loop) - self.assertIs(event_loop, policy.get_event_loop()) - self.assertIsNot(old_event_loop, policy.get_event_loop()) - event_loop.close() - old_event_loop.close() + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() def test_get_event_loop_policy(self): policy = events.get_event_loop_policy() diff --git a/tests/futures_test.py b/tests/futures_test.py index 87198cf7..e448ff82 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -18,25 +18,33 @@ class FutureTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() def test_initial_state(self): - f = futures.Future() + f = futures.Future(loop=self.loop) self.assertFalse(f.cancelled()) self.assertFalse(f.running()) self.assertFalse(f.done()) f.cancel() self.assertTrue(f.cancelled()) - def test_init_event_loop_positional(self): + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): # Make sure Future does't accept a positional argument self.assertRaises(TypeError, futures.Future, 42) def test_cancel(self): - f = futures.Future() + f = futures.Future(loop=self.loop) self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) self.assertFalse(f.running()) @@ -48,7 +56,7 @@ def test_cancel(self): self.assertFalse(f.cancel()) def test_result(self): - f = futures.Future() + f = futures.Future(loop=self.loop) self.assertRaises(futures.InvalidStateError, f.result) self.assertRaises(futures.InvalidTimeoutError, f.result, 10) @@ -64,7 +72,7 @@ def test_result(self): def test_exception(self): exc = RuntimeError() - f = futures.Future() + f = futures.Future(loop=self.loop) self.assertRaises(futures.InvalidStateError, f.exception) self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) @@ -79,7 +87,7 @@ def test_exception(self): self.assertFalse(f.cancel()) def test_yield_from_twice(self): - f = futures.Future() + f = futures.Future(loop=self.loop) def fixture(): yield 'A' @@ -97,32 +105,32 @@ def fixture(): self.assertEqual(next(g), ('C', 42)) # yield 'C', y. def test_repr(self): - f_pending = futures.Future() + f_pending = futures.Future(loop=self.loop) self.assertEqual(repr(f_pending), 'Future') f_pending.cancel() - f_cancelled = futures.Future() + f_cancelled = futures.Future(loop=self.loop) f_cancelled.cancel() self.assertEqual(repr(f_cancelled), 'Future') - f_result = futures.Future() + f_result = futures.Future(loop=self.loop) f_result.set_result(4) self.assertEqual(repr(f_result), 'Future') self.assertEqual(f_result.result(), 4) exc = RuntimeError() - f_exception = futures.Future() + f_exception = futures.Future(loop=self.loop) f_exception.set_exception(exc) self.assertEqual(repr(f_exception), 'Future') self.assertIs(f_exception.exception(), exc) - f_few_callbacks = futures.Future() + f_few_callbacks = futures.Future(loop=self.loop) f_few_callbacks.add_done_callback(_fakefunc) self.assertIn('Future')) @tasks.coroutine @@ -40,7 +48,7 @@ def acquire_lock(): self.assertTrue(repr(lock).endswith('[locked]>')) def test_lock(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) @tasks.coroutine def acquire_lock(): @@ -55,7 +63,7 @@ def acquire_lock(): self.assertFalse(lock.locked()) def test_acquire(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) result = [] self.assertTrue(self.loop.run_until_complete(lock.acquire())) @@ -78,8 +86,8 @@ def c3(result): result.append(3) return True - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -91,7 +99,7 @@ def c3(result): run_briefly(self.loop) self.assertEqual([1], result) - t3 = tasks.Task(c3(result)) + t3 = tasks.Task(c3(result), loop=self.loop) lock.release() run_briefly(self.loop) @@ -109,7 +117,7 @@ def c3(result): self.assertTrue(t3.result()) def test_acquire_timeout(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) t0 = time.monotonic() @@ -119,7 +127,7 @@ def test_acquire_timeout(self): total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.loop.run_until_complete(lock.acquire()) self.loop.call_later(0.01, lock.release) @@ -127,12 +135,12 @@ def test_acquire_timeout(self): self.assertTrue(acquired) def test_acquire_timeout_mixed(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.loop.run_until_complete(lock.acquire()) - tasks.Task(lock.acquire()) - tasks.Task(lock.acquire()) - acquire_task = tasks.Task(lock.acquire(0.01)) - tasks.Task(lock.acquire()) + tasks.Task(lock.acquire(), loop=self.loop) + tasks.Task(lock.acquire(), loop=self.loop) + acquire_task = tasks.Task(lock.acquire(0.01), loop=self.loop) + tasks.Task(lock.acquire(), loop=self.loop) acquired = self.loop.run_until_complete(acquire_task) self.assertFalse(acquired) @@ -140,10 +148,10 @@ def test_acquire_timeout_mixed(self): self.assertEqual(3, len(lock._waiters)) def test_acquire_cancel(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) - task = tasks.Task(lock.acquire()) + task = tasks.Task(lock.acquire(), loop=self.loop) self.loop.call_soon(task.cancel) self.assertRaises( futures.CancelledError, @@ -151,12 +159,12 @@ def test_acquire_cancel(self): self.assertFalse(lock._waiters) def test_release_not_acquired(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.assertRaises(RuntimeError, lock.release) def test_release_no_waiters(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) self.loop.run_until_complete(lock.acquire()) self.assertTrue(lock.locked()) @@ -164,9 +172,9 @@ def test_release_no_waiters(self): self.assertFalse(lock.locked()) def test_context_manager(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) - @tasks.task + @tasks.coroutine def acquire_lock(): return (yield from lock) @@ -176,7 +184,7 @@ def acquire_lock(): self.assertFalse(lock.locked()) def test_context_manager_no_yield(self): - lock = locks.Lock() + lock = locks.Lock(loop=self.loop) try: with lock: @@ -191,7 +199,7 @@ class EventWaiterTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -201,18 +209,26 @@ def test_ctor_loop(self): ev = locks.EventWaiter(loop=loop) self.assertIs(ev._loop, loop) - ev = locks.EventWaiter() - self.assertIs(ev._loop, events.get_event_loop()) + ev = locks.EventWaiter(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.EventWaiter() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) def test_repr(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) self.assertTrue(repr(ev).endswith('[unset]>')) ev.set() self.assertTrue(repr(ev).endswith('[set]>')) def test_wait(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) self.assertFalse(ev.is_set()) result = [] @@ -232,13 +248,13 @@ def c3(result): if (yield from ev.wait()): result.append(3) - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) - t3 = tasks.Task(c3(result)) + t3 = tasks.Task(c3(result), loop=self.loop) ev.set() run_briefly(self.loop) @@ -252,14 +268,14 @@ def c3(result): self.assertIsNone(t3.result()) def test_wait_on_set(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) ev.set() res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) def test_wait_timeout(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) t0 = time.monotonic() res = self.loop.run_until_complete(ev.wait(0.1)) @@ -267,17 +283,17 @@ def test_wait_timeout(self): total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) self.loop.call_later(0.01, ev.set) acquired = self.loop.run_until_complete(ev.wait(10.1)) self.assertTrue(acquired) def test_wait_timeout_mixed(self): - ev = locks.EventWaiter() - tasks.Task(ev.wait()) - tasks.Task(ev.wait()) - acquire_task = tasks.Task(ev.wait(0.1)) - tasks.Task(ev.wait()) + ev = locks.EventWaiter(loop=self.loop) + tasks.Task(ev.wait(), loop=self.loop) + tasks.Task(ev.wait(), loop=self.loop) + acquire_task = tasks.Task(ev.wait(0.1), loop=self.loop) + tasks.Task(ev.wait(), loop=self.loop) t0 = time.monotonic() acquired = self.loop.run_until_complete(acquire_task) @@ -289,9 +305,9 @@ def test_wait_timeout_mixed(self): self.assertEqual(3, len(ev._waiters)) def test_wait_cancel(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) - wait = tasks.Task(ev.wait()) + wait = tasks.Task(ev.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) self.assertRaises( futures.CancelledError, @@ -299,7 +315,7 @@ def test_wait_cancel(self): self.assertFalse(ev._waiters) def test_clear(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) self.assertFalse(ev.is_set()) ev.set() @@ -309,7 +325,7 @@ def test_clear(self): self.assertFalse(ev.is_set()) def test_clear_with_waiters(self): - ev = locks.EventWaiter() + ev = locks.EventWaiter(loop=self.loop) result = [] @tasks.coroutine @@ -318,7 +334,7 @@ def c1(result): result.append(1) return True - t = tasks.Task(c1(result)) + t = tasks.Task(c1(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -342,13 +358,29 @@ class ConditionTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + def test_wait(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) result = [] @tasks.coroutine @@ -372,9 +404,9 @@ def c3(result): result.append(3) return True - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) - t3 = tasks.Task(c3(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -414,7 +446,7 @@ def c3(result): self.assertTrue(t3.result()) def test_wait_timeout(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) t0 = time.monotonic() @@ -426,10 +458,10 @@ def test_wait_timeout(self): self.assertTrue(0.08 < total_time < 0.12) def test_wait_cancel(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) - wait = tasks.Task(cond.wait()) + wait = tasks.Task(cond.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) self.assertRaises( futures.CancelledError, @@ -438,13 +470,13 @@ def test_wait_cancel(self): self.assertTrue(cond.locked()) def test_wait_unacquired(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) self.assertRaises( RuntimeError, self.loop.run_until_complete, cond.wait()) def test_wait_for(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) presult = False def predicate(): @@ -460,7 +492,7 @@ def c1(result): cond.release() return True - t = tasks.Task(c1(result)) + t = tasks.Task(c1(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -482,7 +514,7 @@ def c1(result): self.assertTrue(t.result()) def test_wait_for_timeout(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) result = [] @@ -498,7 +530,7 @@ def c1(result): result.append(2) cond.release() - wait_for = tasks.Task(c1(result)) + wait_for = tasks.Task(c1(result), loop=self.loop) t0 = time.monotonic() @@ -519,7 +551,7 @@ def c1(result): self.assertTrue(0.08 < total_time < 0.12) def test_wait_for_unacquired(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) # predicate can return true immediately res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) @@ -531,7 +563,7 @@ def test_wait_for_unacquired(self): cond.wait_for(lambda: False)) def test_notify(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) result = [] @tasks.coroutine @@ -558,9 +590,9 @@ def c3(result): cond.release() return True - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) - t3 = tasks.Task(c3(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -586,7 +618,7 @@ def c3(result): self.assertTrue(t3.result()) def test_notify_all(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) result = [] @@ -606,8 +638,8 @@ def c2(result): cond.release() return True - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([], result) @@ -624,11 +656,11 @@ def c2(result): self.assertTrue(t2.result()) def test_notify_unacquired(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) self.assertRaises(RuntimeError, cond.notify) def test_notify_all_unacquired(self): - cond = locks.Condition() + cond = locks.Condition(loop=self.loop) self.assertRaises(RuntimeError, cond.notify_all) @@ -636,7 +668,7 @@ class SemaphoreTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -646,21 +678,29 @@ def test_ctor_loop(self): sem = locks.Semaphore(loop=loop) self.assertIs(sem._loop, loop) - sem = locks.Semaphore() - self.assertIs(sem._loop, events.get_event_loop()) + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) def test_repr(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) self.loop.run_until_complete(sem.acquire()) self.assertTrue(repr(sem).endswith('[locked]>')) def test_semaphore(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.assertEqual(1, sem._value) - @tasks.task + @tasks.coroutine def acquire_lock(): return (yield from sem) @@ -678,7 +718,7 @@ def test_semaphore_value(self): self.assertRaises(ValueError, locks.Semaphore, -1) def test_acquire(self): - sem = locks.Semaphore(3) + sem = locks.Semaphore(3, loop=self.loop) result = [] self.assertTrue(self.loop.run_until_complete(sem.acquire())) @@ -709,9 +749,9 @@ def c4(result): result.append(4) return True - t1 = tasks.Task(c1(result)) - t2 = tasks.Task(c2(result)) - t3 = tasks.Task(c3(result)) + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) run_briefly(self.loop) self.assertEqual([1], result) @@ -719,7 +759,7 @@ def c4(result): self.assertEqual(2, len(sem._waiters)) self.assertEqual(0, sem._value) - t4 = tasks.Task(c4(result)) + t4 = tasks.Task(c4(result), loop=self.loop) sem.release() sem.release() @@ -741,7 +781,7 @@ def c4(result): self.assertFalse(t4.done()) def test_acquire_timeout(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) t0 = time.monotonic() @@ -751,7 +791,7 @@ def test_acquire_timeout(self): total_time = (time.monotonic() - t0) self.assertTrue(0.08 < total_time < 0.12) - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) self.loop.call_later(0.01, sem.release) @@ -759,12 +799,12 @@ def test_acquire_timeout(self): self.assertTrue(acquired) def test_acquire_timeout_mixed(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) - tasks.Task(sem.acquire()) - tasks.Task(sem.acquire()) - acquire_task = tasks.Task(sem.acquire(0.1)) - tasks.Task(sem.acquire()) + tasks.Task(sem.acquire(), loop=self.loop) + tasks.Task(sem.acquire(), loop=self.loop) + acquire_task = tasks.Task(sem.acquire(0.1), loop=self.loop) + tasks.Task(sem.acquire(), loop=self.loop) t0 = time.monotonic() acquired = self.loop.run_until_complete(acquire_task) @@ -776,10 +816,10 @@ def test_acquire_timeout_mixed(self): self.assertEqual(3, len(sem._waiters)) def test_acquire_cancel(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) - acquire = tasks.Task(sem.acquire()) + acquire = tasks.Task(sem.acquire(), loop=self.loop) self.loop.call_soon(acquire.cancel) self.assertRaises( futures.CancelledError, @@ -787,12 +827,12 @@ def test_acquire_cancel(self): self.assertFalse(sem._waiters) def test_release_not_acquired(self): - sem = locks.Semaphore(bound=True) + sem = locks.Semaphore(bound=True, loop=self.loop) self.assertRaises(ValueError, sem.release) def test_release_no_waiters(self): - sem = locks.Semaphore() + sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) self.assertTrue(sem.locked()) @@ -800,9 +840,9 @@ def test_release_no_waiters(self): self.assertFalse(sem.locked()) def test_context_manager(self): - sem = locks.Semaphore(2) + sem = locks.Semaphore(2, loop=self.loop) - @tasks.task + @tasks.coroutine def acquire_lock(): return (yield from sem) diff --git a/tests/parsers_test.py b/tests/parsers_test.py index 083e141c..debc532c 100644 --- a/tests/parsers_test.py +++ b/tests/parsers_test.py @@ -14,7 +14,7 @@ class StreamBufferTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -31,7 +31,7 @@ def test_exception_waiter(self): stream = parsers.StreamBuffer() stream._parser = parsers.lines_parser() - buf = stream._parser_buffer = parsers.DataBuffer() + buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) exc = ValueError() stream.set_exception(exc) @@ -318,27 +318,27 @@ class DataBufferTests(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() def test_feed_data(self): - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) item = object() buffer.feed_data(item) self.assertEqual([item], list(buffer._buffer)) def test_feed_eof(self): - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) buffer.feed_eof() self.assertTrue(buffer._eof) def test_read(self): item = object() - buffer = parsers.DataBuffer() - read_task = tasks.Task(buffer.read()) + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) def cb(): buffer.feed_data(item) @@ -348,8 +348,8 @@ def cb(): self.assertIs(item, data) def test_read_eof(self): - buffer = parsers.DataBuffer() - read_task = tasks.Task(buffer.read()) + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) def cb(): buffer.feed_eof() @@ -360,7 +360,7 @@ def cb(): def test_read_until_eof(self): item = object() - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) buffer.feed_data(item) buffer.feed_eof() @@ -371,7 +371,7 @@ def test_read_until_eof(self): self.assertIsNone(data) def test_read_exception(self): - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) buffer.feed_data(object()) buffer.set_exception(ValueError()) @@ -379,7 +379,7 @@ def test_read_exception(self): ValueError, self.loop.run_until_complete, buffer.read()) def test_exception(self): - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) self.assertIsNone(buffer.exception()) exc = ValueError() @@ -387,16 +387,16 @@ def test_exception(self): self.assertIs(buffer.exception(), exc) def test_exception_waiter(self): - buffer = parsers.DataBuffer() + buffer = parsers.DataBuffer(loop=self.loop) @tasks.coroutine def set_err(): buffer.set_exception(ValueError()) - t1 = tasks.Task(buffer.read()) - t2 = tasks.Task(set_err()) + t1 = tasks.Task(buffer.read(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) - self.loop.run_until_complete(tasks.wait([t1, t2])) + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) self.assertRaises(ValueError, t1.result) @@ -559,7 +559,7 @@ def test_skipuntil(self): self.assertEqual(b'', bytes(buf)) def test_lines_parser(self): - out = parsers.DataBuffer() + out = parsers.DataBuffer(loop=self.loop) buf = self._make_one() p = parsers.lines_parser() next(p) @@ -579,7 +579,7 @@ def test_lines_parser(self): self.assertEqual(bytes(buf), b'data') def test_chunks_parser(self): - out = parsers.DataBuffer() + out = parsers.DataBuffer(loop=self.loop) buf = self._make_one() p = parsers.chunks_parser(5) next(p) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 14764e49..6b6de32f 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -17,7 +17,7 @@ def setUp(self): self.protocol = unittest.mock.Mock(tulip.Protocol) def test_ctor(self): - fut = tulip.Future() + fut = tulip.Future(loop=self.loop) tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol, fut) self.loop.call_soon.mock_calls[0].assert_called_with(tr._loop_reading) @@ -34,7 +34,7 @@ def test_loop_reading(self): self.assertFalse(self.protocol.eof_received.called) def test_loop_reading_data(self): - res = tulip.Future() + res = tulip.Future(loop=self.loop) res.set_result(b'data') tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -45,7 +45,7 @@ def test_loop_reading_data(self): self.protocol.data_received.assert_called_with(b'data') def test_loop_reading_no_data(self): - res = tulip.Future() + res = tulip.Future(loop=self.loop) res.set_result(b'') tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -150,7 +150,7 @@ def test_loop_writing_err(self, m_log): m_log.warning.assert_called_with('socket.send() raised exception.') def test_loop_writing_stop(self): - fut = tulip.Future() + fut = tulip.Future(loop=self.loop) fut.set_result(b'data') tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -159,7 +159,7 @@ def test_loop_writing_stop(self): self.assertIsNone(tr._write_fut) def test_loop_writing_closing(self): - fut = tulip.Future() + fut = tulip.Future(loop=self.loop) fut.set_result(1) tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -377,7 +377,7 @@ def test_start_serving_cancel(self): # cancelled self.sock.reset_mock() - fut = tulip.Future() + fut = tulip.Future(loop=self.loop) fut.cancel() loop(fut) self.assertTrue(self.sock.close.called) diff --git a/tests/queues_test.py b/tests/queues_test.py index 5632bbff..fb81b3fe 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -14,7 +14,7 @@ class _QueueTestBase(unittest.TestCase): def setUp(self): self.loop = events.new_event_loop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -28,35 +28,35 @@ def _test_repr_or_str(self, fn, expect_id): fn is repr or str. expect_id is True if we expect the Queue's id to appear in fn(Queue()). """ - q = queues.Queue() + q = queues.Queue(loop=self.loop) self.assertTrue(fn(q).startswith(')') t.cancel() # Does not take immediate effect! @@ -139,7 +139,7 @@ def notmuch(): self.assertRaises(futures.CancelledError, self.loop.run_until_complete, t) self.assertEqual(repr(t), 'Task()') - t = notmuch() + t = tasks.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") @@ -156,21 +156,21 @@ class MyTask(tasks.Task, T): def __repr__(self): return super().__repr__() - t = MyTask(coro()) + t = MyTask(coro(), loop=self.loop) self.assertEqual(repr(t), 'T[]()') def test_task_basics(self): - @tasks.task + @tasks.coroutine def outer(): a = yield from inner1() b = yield from inner2() return a+b - @tasks.task + @tasks.coroutine def inner1(): return 42 - @tasks.task + @tasks.coroutine def inner2(): return 1000 @@ -178,12 +178,12 @@ def inner2(): self.assertEqual(self.loop.run_until_complete(t), 1042) def test_cancel(self): - @tasks.task + @tasks.coroutine def task(): - yield from tasks.sleep(10.0) + yield from tasks.sleep(10.0, loop=self.loop) return 12 - t = task() + t = tasks.Task(task(), loop=self.loop) self.loop.call_soon(t.cancel) self.assertRaises( futures.CancelledError, self.loop.run_until_complete, t) @@ -191,13 +191,13 @@ def task(): self.assertFalse(t.cancel()) def test_cancel_yield(self): - @tasks.task + @tasks.coroutine def task(): yield yield return 12 - t = task() + t = tasks.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) # start coro t.cancel() self.assertRaises( @@ -210,7 +210,7 @@ def test_cancel_done_future(self): fut2 = futures.Future(loop=self.loop) fut3 = futures.Future(loop=self.loop) - @tasks.task + @tasks.coroutine def task(): yield from fut1 try: @@ -219,7 +219,7 @@ def task(): pass yield from fut3 - t = task() + t = tasks.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) fut1.set_result(None) t.cancel() @@ -238,10 +238,10 @@ def task(): def test_future_timeout(self): @tasks.coroutine def coro(): - yield from tasks.sleep(10.0) + yield from tasks.sleep(10.0, loop=self.loop) return 12 - t = tasks.Task(coro(), timeout=0.1) + t = tasks.Task(coro(), timeout=0.1, loop=self.loop) self.assertRaises( futures.CancelledError, @@ -252,7 +252,7 @@ def coro(): def test_future_timeout_catch(self): @tasks.coroutine def coro(): - yield from tasks.sleep(10.0) + yield from tasks.sleep(10.0, loop=self.loop) return 12 class Cancelled(Exception): @@ -261,7 +261,7 @@ class Cancelled(Exception): @tasks.coroutine def coro2(): try: - yield from tasks.Task(coro(), timeout=0.1) + yield from tasks.Task(coro(), timeout=0.1, loop=self.loop) except futures.CancelledError: raise Cancelled() @@ -274,7 +274,7 @@ def task(): t.cancel() return 12 - t = tasks.Task(task()) + t = tasks.Task(task(), loop=self.loop) self.assertRaises( futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) @@ -287,12 +287,12 @@ def test_stop_while_run_in_complete(self): def task(): nonlocal x while x < 10: - yield from tasks.sleep(0.1) + yield from tasks.sleep(0.1, loop=self.loop) x += 1 if x == 2: self.loop.stop() - t = tasks.Task(task()) + t = tasks.Task(task(), loop=self.loop) t0 = time.monotonic() self.assertRaises( RuntimeError, self.loop.run_until_complete, t) @@ -302,12 +302,12 @@ def task(): self.assertEqual(x, 2) def test_timeout(self): - @tasks.task + @tasks.coroutine def task(): - yield from tasks.sleep(10.0) + yield from tasks.sleep(10.0, loop=self.loop) return 42 - t = task() + t = tasks.Task(task(), loop=self.loop) t0 = time.monotonic() self.assertRaises( futures.TimeoutError, self.loop.run_until_complete, t, 0.1) @@ -316,12 +316,12 @@ def task(): self.assertTrue(0.08 <= t1-t0 <= 0.12) def test_timeout_not(self): - @tasks.task + @tasks.coroutine def task(): - yield from tasks.sleep(0.1) + yield from tasks.sleep(0.1, loop=self.loop) return 42 - t = task() + t = tasks.Task(task(), loop=self.loop) t0 = time.monotonic() r = self.loop.run_until_complete(t, 10.0) t1 = time.monotonic() @@ -330,24 +330,24 @@ def task(): self.assertTrue(0.08 <= t1-t0 <= 0.12) def test_wait(self): - a = tasks.Task(tasks.sleep(0.1)) - b = tasks.Task(tasks.sleep(0.15)) + a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) @tasks.coroutine def foo(): - done, pending = yield from tasks.wait([b, a]) + done, pending = yield from tasks.wait([b, a], loop=self.loop) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) return 42 t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) self.assertEqual(res, 42) # Doing it again should take no time and exercise a different path. t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) # TODO: Test different return_when values. @@ -355,17 +355,20 @@ def foo(): def test_wait_errors(self): self.assertRaises( ValueError, self.loop.run_until_complete, - tasks.wait(set())) + tasks.wait(set(), loop=self.loop)) self.assertRaises( ValueError, self.loop.run_until_complete, - tasks.wait([tasks.sleep(10.0)], return_when=-1)) + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) def test_wait_first_completed(self): - a = tasks.Task(tasks.sleep(10.0)) - b = tasks.Task(tasks.sleep(0.1)) - task = tasks.Task(tasks.wait( - [b, a], return_when=tasks.FIRST_COMPLETED)) + a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) + b = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) @@ -387,10 +390,12 @@ def coro2(): yield yield - a = tasks.Task(coro1()) - b = tasks.Task(coro2()) + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) task = tasks.Task( - tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED)) + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) done, pending = self.loop.run_until_complete(task) self.assertEqual({a, b}, done) @@ -401,15 +406,17 @@ def coro2(): def test_wait_first_exception(self): # first_exception, task already has exception - a = tasks.Task(tasks.sleep(10.0)) + a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) @tasks.coroutine def exc(): raise ZeroDivisionError('err') - b = tasks.Task(exc()) - task = tasks.Task(tasks.wait( - [b, a], return_when=tasks.FIRST_EXCEPTION)) + b = tasks.Task(exc(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=self.loop), + loop=self.loop) done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) @@ -417,69 +424,71 @@ def exc(): def test_wait_first_exception_in_wait(self): # first_exception, exception during waiting - a = tasks.Task(tasks.sleep(10.0)) + a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) @tasks.coroutine def exc(): - yield from tasks.sleep(0.01) + yield from tasks.sleep(0.01, loop=self.loop) raise ZeroDivisionError('err') - b = tasks.Task(exc()) - task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION) + b = tasks.Task(exc(), loop=self.loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=self.loop) done, pending = self.loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) def test_wait_with_exception(self): - a = tasks.Task(tasks.sleep(0.1)) + a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) @tasks.coroutine def sleeper(): - yield from tasks.sleep(0.15) + yield from tasks.sleep(0.15, loop=self.loop) raise ZeroDivisionError('really') - b = tasks.Task(sleeper()) + b = tasks.Task(sleeper(), loop=self.loop) @tasks.coroutine def foo(): - done, pending = yield from tasks.wait([b, a]) + done, pending = yield from tasks.wait([b, a], loop=self.loop) self.assertEqual(len(done), 2) self.assertEqual(pending, set()) errors = set(f for f in done if f.exception() is not None) self.assertEqual(len(errors), 1) t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) def test_wait_with_timeout(self): - a = tasks.Task(tasks.sleep(0.1)) - b = tasks.Task(tasks.sleep(0.15)) + a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) @tasks.coroutine def foo(): - done, pending = yield from tasks.wait([b, a], timeout=0.11) + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=self.loop) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo())) + self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.1) self.assertTrue(t1-t0 <= 0.13) def test_wait_concurrent_complete(self): - a = tasks.Task(tasks.sleep(0.1)) - b = tasks.Task(tasks.sleep(0.15)) + a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) done, pending = self.loop.run_until_complete( - tasks.wait([b, a], timeout=0.1)) + tasks.wait([b, a], timeout=0.1, loop=self.loop)) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) @@ -487,7 +496,7 @@ def test_wait_concurrent_complete(self): def test_as_completed(self): @tasks.coroutine def sleeper(dt, x): - yield from tasks.sleep(dt) + yield from tasks.sleep(dt, loop=self.loop) return x a = sleeper(0.1, 'a') @@ -497,12 +506,12 @@ def sleeper(dt, x): @tasks.coroutine def foo(): values = [] - for f in tasks.as_completed([b, c, a]): + for f in tasks.as_completed([b, c, a], loop=self.loop): values.append((yield from f)) return values t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.14) self.assertTrue('a' in res[:2]) @@ -510,18 +519,18 @@ def foo(): self.assertEqual(res[2], 'c') # Doing it again should take no time and exercise a different path. t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) def test_as_completed_with_timeout(self): - a = tasks.sleep(0.1, 'a') - b = tasks.sleep(0.15, 'b') + a = tasks.sleep(0.1, 'a', loop=self.loop) + b = tasks.sleep(0.15, 'b', loop=self.loop) @tasks.coroutine def foo(): values = [] - for f in tasks.as_completed([a, b], timeout=0.12): + for f in tasks.as_completed([a, b], timeout=0.12, loop=self.loop): try: v = yield from f values.append((1, v)) @@ -530,7 +539,7 @@ def foo(): return values t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo())) + res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 >= 0.11) self.assertEqual(len(res), 2, res) @@ -539,10 +548,10 @@ def foo(): self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) def test_as_completed_reverse_wait(self): - a = tasks.sleep(0.05, 'a') - b = tasks.sleep(0.10, 'b') + a = tasks.sleep(0.05, 'a', loop=self.loop) + b = tasks.sleep(0.10, 'b', loop=self.loop) fs = {a, b} - futs = list(tasks.as_completed(fs)) + futs = list(tasks.as_completed(fs, loop=self.loop)) self.assertEqual(len(futs), 2) x = self.loop.run_until_complete(futs[1]) self.assertEqual(x, 'a') @@ -550,23 +559,23 @@ def test_as_completed_reverse_wait(self): self.assertEqual(y, 'b') def test_as_completed_concurrent(self): - a = tasks.sleep(0.05, 'a') - b = tasks.sleep(0.05, 'b') + a = tasks.sleep(0.05, 'a', loop=self.loop) + b = tasks.sleep(0.05, 'b', loop=self.loop) fs = {a, b} - futs = list(tasks.as_completed(fs)) + futs = list(tasks.as_completed(fs, loop=self.loop)) self.assertEqual(len(futs), 2) - waiter = tasks.wait(futs) + waiter = tasks.wait(futs, loop=self.loop) done, pending = self.loop.run_until_complete(waiter) self.assertEqual(set(f.result() for f in done), {'a', 'b'}) def test_sleep(self): @tasks.coroutine def sleeper(dt, arg): - yield from tasks.sleep(dt/2) - res = yield from tasks.sleep(dt/2, arg) + yield from tasks.sleep(dt/2, loop=self.loop) + res = yield from tasks.sleep(dt/2, arg, loop=self.loop) return res - t = tasks.Task(sleeper(0.1, 'yeah')) + t = tasks.Task(sleeper(0.1, 'yeah'), loop=self.loop) t0 = time.monotonic() self.loop.run_until_complete(t) t1 = time.monotonic() @@ -575,7 +584,8 @@ def sleeper(dt, arg): self.assertEqual(t.result(), 'yeah') def test_sleep_cancel(self): - t = tasks.Task(tasks.sleep(10.0, 'yeah')) + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=self.loop), + loop=self.loop) handle = None orig_call_later = self.loop.call_later @@ -597,19 +607,19 @@ def call_later(self, delay, callback, *args): def test_task_cancel_sleeping_task(self): sleepfut = None - @tasks.task + @tasks.coroutine def sleep(dt): nonlocal sleepfut - sleepfut = tasks.sleep(dt) + sleepfut = tasks.sleep(dt, loop=self.loop) try: time.monotonic() yield from sleepfut finally: time.monotonic() - @tasks.task + @tasks.coroutine def doit(): - sleeper = sleep(5000) + sleeper = tasks.Task(sleep(5000), loop=self.loop) self.loop.call_later(0.1, sleeper.cancel) try: time.monotonic() @@ -629,14 +639,14 @@ def doit(): def test_task_cancel_waiter_future(self): fut = futures.Future(loop=self.loop) - @tasks.task + @tasks.coroutine def coro(): try: yield from fut except futures.CancelledError: pass - task = coro() + task = tasks.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) @@ -651,7 +661,7 @@ def test_step_in_completed_task(self): def notmuch(): return 'ko' - task = tasks.Task(notmuch()) + task = tasks.Task(notmuch(), loop=self.loop) task.set_result('ok') self.assertRaises(AssertionError, task._step) @@ -670,23 +680,23 @@ def test_step_result_future(self): # If coroutine returns future, task waits on this future. class Fut(futures.Future): - def __init__(self, *args): + def __init__(self, *args, **kwds): self.cb_added = False - super().__init__(*args) + super().__init__(*args, **kwds) def add_done_callback(self, fn): self.cb_added = True super().add_done_callback(fn) - fut = Fut() + fut = Fut(loop=self.loop) result = None - @tasks.task + @tasks.coroutine def wait_for_future(): nonlocal result result = yield from fut - t = wait_for_future() + t = tasks.Task(wait_for_future(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertTrue(fut.cb_added) @@ -702,7 +712,7 @@ def test_step_with_baseexception(self): def notmutch(): raise BaseException() - task = tasks.Task(notmutch()) + task = tasks.Task(notmutch(), loop=self.loop) self.assertRaises(BaseException, task._step) self.assertTrue(task.done()) @@ -711,7 +721,7 @@ def notmutch(): def test_baseexception_during_cancel(self): @tasks.coroutine def sleeper(): - yield from tasks.sleep(10) + yield from tasks.sleep(10, loop=self.loop) @tasks.coroutine def notmutch(): @@ -720,7 +730,7 @@ def notmutch(): except futures.CancelledError: raise BaseException() - task = tasks.Task(notmutch()) + task = tasks.Task(notmutch(), loop=self.loop) test_utils.run_briefly(self.loop) task.cancel() @@ -749,7 +759,7 @@ def fn2(): def test_yield_vs_yield_from(self): fut = futures.Future(loop=self.loop) - @tasks.task + @tasks.coroutine def wait_for_future(): yield fut @@ -765,7 +775,7 @@ def test_yield_vs_yield_from_generator(self): def coro(): yield - @tasks.task + @tasks.coroutine def wait_for_future(): yield coro() @@ -798,8 +808,8 @@ def func(): def coro(): fut.set_result('test') - t1 = tasks.Task(func()) - t2 = tasks.Task(coro()) + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) res = self.loop.run_until_complete(t1) self.assertEqual(res, 'test') self.assertIsNone(t2.result()) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 8ad66308..314ba065 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -23,7 +23,7 @@ class SelectorEventLoopTests(unittest.TestCase): def setUp(self): self.loop = unix_events.SelectorEventLoop() - events.set_event_loop(self.loop) + events.set_event_loop(None) def tearDown(self): self.loop.close() @@ -208,7 +208,7 @@ def test_ctor(self): self.protocol.connection_made, tr) def test_ctor_with_waiter(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol, fut) self.loop.call_soon.assert_called_with(fut.set_result, None) @@ -353,7 +353,7 @@ def test_ctor(self): self.assertTrue(tr._enable_read_hack) def test_ctor_with_waiter(self): - fut = futures.Future() + fut = futures.Future(loop=self.loop) tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol, fut) self.loop.call_soon.assert_called_with(fut.set_result, None) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index bb3a8a65..73304c2b 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -30,9 +30,10 @@ class ProactorTests(unittest.TestCase): def setUp(self): self.loop = windows_events.ProactorEventLoop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(self.loop) # TODO: Use None, test on Windows. def tearDown(self): + tulip.set_event_loop(None) self.loop.close() def test_pause_resume_discard(self): diff --git a/tulip/base_events.py b/tulip/base_events.py index 3cfe6625..69fee3e5 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -253,7 +253,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, else: f2 = None - yield from tasks.wait(fs) + yield from tasks.wait(fs, loop=self) infos = f1.result() if not infos: @@ -393,10 +393,40 @@ def create_datagram_endpoint(self, protocol_factory, sock, protocol, r_addr, extra={'addr': l_addr}) return transport, protocol - @tasks.task - def start_serving(self, protocol_factory, host=None, port=None, *, - family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None): + # This returns a Task made from self._start_serving_internal(). + # We want start_serving() to return a Task so that it will start + # running right away (when the event loop runs) even if the caller + # doesn't wait for it. Note that this is different from + # e.g. create_connection(), or create_datagram_endpoint(), which + # are a "mere" coroutines and require their caller to wait for + # them. The reason for the difference is that only + # start_serving() creates multiple transports and protocols. + def start_serving(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + coro = self._start_serving_internal(protocol_factory, host, port, + family=family, + flags=flags, + sock=sock, + backlog=backlog, + ssl=ssl, + reuse_address=reuse_address) + return tasks.Task(coro, loop=self) + + @tasks.coroutine + def _start_serving_internal(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): """XXX""" if host is not None or port is not None: if sock is not None: diff --git a/tulip/events.py b/tulip/events.py index a1a5fd3e..23fad05d 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -172,9 +172,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None): - """Creates a TCP server bound to host and port and return - a list of socket objects which will later be handled by - protocol_factory. + """Creates a TCP server bound to host and port and return a + Task whose result will be a list of socket objects which will + later be handled by protocol_factory. If host is an empty string or None all interfaces are assumed and a list of multiple sockets will be returned (most likely @@ -310,6 +310,7 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): """ _loop = None + _set_called = False def get_event_loop(self): """Get the event loop. @@ -317,13 +318,18 @@ def get_event_loop(self): This may be None or an instance of EventLoop. """ if (self._loop is None and + not self._set_called and threading.current_thread().name == 'MainThread'): self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) return self._loop def set_event_loop(self, loop): """Set the event loop.""" # TODO: The isinstance() test violates the PEP. + self._set_called = True assert loop is None or isinstance(loop, AbstractEventLoop) self._loop = loop diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py index e790d285..cae33918 100644 --- a/tulip/subprocess_transport.py +++ b/tulip/subprocess_transport.py @@ -18,10 +18,12 @@ class UnixSubprocessTransport(transports.Transport): and something else that handles pipe setup, fork, and exec. """ - def __init__(self, protocol, args): + def __init__(self, protocol, args, *, loop=None): self._protocol = protocol # Not a factory! :-) self._args = args # args[0] must be full path of binary. - self._event_loop = events.get_event_loop() + if loop is None: + loop = events.get_event_loop() + self._event_loop = loop self._buffer = [] self._eof = False rstdin, self._wstdin = os.pipe() diff --git a/tulip/test_utils.py b/tulip/test_utils.py index ca924491..98c61a3e 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -32,10 +32,11 @@ def run_briefly(loop): - @tulip.task + @tulip.coroutine def once(): pass - loop.run_until_complete(once()) + t = tulip.Task(once(), loop=loop) + loop.run_until_complete(t) def run_once(loop): From 7774134c4fe9bb8cd671895eecab563c4c0671c2 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 01:49:01 +0300 Subject: [PATCH 0522/1502] Implement subprocess execution for UNIX systems --- tests/events_test.py | 145 ++++++++++++++++++++++++++++- tests/unix_events_test.py | 4 +- tulip/base_events.py | 36 ++++++++ tulip/events.py | 12 ++- tulip/protocols.py | 22 +++++ tulip/transports.py | 52 +++++++++++ tulip/unix_events.py | 186 +++++++++++++++++++++++++++++++++++++- 7 files changed, 452 insertions(+), 5 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index a34fd8a5..0315db8f 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1,5 +1,6 @@ """Tests for events.py.""" +import functools import gc import io import os @@ -26,6 +27,7 @@ from tulip import selector_events from tulip import tasks from tulip import test_utils +from tulip import locks class MyProto(protocols.Protocol): @@ -117,7 +119,7 @@ def connection_lost(self, exc): self.done.set_result(None) -class MyWritePipeProto(protocols.Protocol): +class MyWritePipeProto(protocols.BaseProtocol): done = None def __init__(self, loop=None): @@ -138,6 +140,44 @@ def connection_lost(self, exc): self.done.set_result(None) +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = locks.EventWaiter(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data.set() + + def pipe_connection_lost(self, fd, exc): + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + self.returncode = self.transport.get_returncode() + + class EventLoopTestsMixin: def setUp(self): @@ -901,6 +941,109 @@ def main(): r.close() w.close() + @unittest.skipUnless(sys.platform != 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + 'tr', '[a-z]', '[A-Z]') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + stdin.write(b'The Winner') + stdin.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertEqual(b'PYTHON THE WINNER', proto.data[1]) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data.wait(10)) + proto.got_data.clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data.wait(10)) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + if sys.platform == 'win32': from tulip import windows_events diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 314ba065..29a8a949 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -233,7 +233,9 @@ def test__read_ready_eof(self, m_read): m_read.assert_called_with(5, tr.max_size) self.loop.remove_reader.assert_called_with(5) - self.protocol.eof_received.assert_called_with() + self.loop.call_soon.assert_has_calls([ + unittest.mock.call(self.protocol.eof_received), + unittest.mock.call(tr._call_connection_lost, None)]) @unittest.mock.patch('os.read') def test__read_ready_blocked(self, m_read): diff --git a/tulip/base_events.py b/tulip/base_events.py index 69fee3e5..04af67b8 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -19,6 +19,7 @@ import heapq import logging import socket +import subprocess import time import os import sys @@ -78,6 +79,14 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, """Create write pipe transport.""" raise NotImplementedError + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + yield + def _read_from_self(self): """XXX""" raise NotImplementedError @@ -503,6 +512,33 @@ def connect_write_pipe(self, protocol_factory, pipe): yield from waiter return transport, protocol + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" assert isinstance(handle, events.Handle), 'A Handle is required here' diff --git a/tulip/events.py b/tulip/events.py index 23fad05d..37b95594 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -10,6 +10,7 @@ 'get_event_loop', 'set_event_loop', 'new_event_loop', ] +import subprocess import sys import threading import socket @@ -237,8 +238,15 @@ def connect_write_pipe(self, protocol_factory, pipe): # close fd in pipe transport then close f and vise versa. raise NotImplementedError - #def spawn_subprocess(self, protocol_factory, pipe): - # raise NotImplementedError + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError # Ready-based callback registration methods. # The add_*() methods return None. diff --git a/tulip/protocols.py b/tulip/protocols.py index 593ee745..d76f25a2 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -76,3 +76,25 @@ def datagram_received(self, data, addr): def connection_refused(self, exc): """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/transports.py b/tulip/transports.py index 2b34bc59..c571fcc8 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -147,3 +147,55 @@ def abort(self): called with None as its argument. """ raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 563ff6c3..2b0af64f 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -1,11 +1,15 @@ """Selector eventloop for Unix with signal handling.""" +import collections import errno import fcntl +import functools import os import socket import stat +import subprocess import sys +import weakref try: import signal @@ -14,7 +18,9 @@ from . import constants from . import events +from . import protocols from . import selector_events +from . import tasks from . import transports from .log import tulip_log @@ -35,10 +41,18 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop): def __init__(self, selector=None): super().__init__(selector) self._signal_handlers = {} + self._subprocesses = weakref.WeakValueDictionary() def _socketpair(self): return socket.socketpair() + def close(self): + if signal is not None: + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. @@ -139,6 +153,39 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, extra=None): return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + assert signal, "signal support is required" + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, + self._sig_chld) + + def _sig_chld(self): + try: + while True: + grp = os.getpgrp() + ret = os.waitid(os.P_PGID, grp, + os.WNOHANG|os.WNOWAIT|os.WEXITED) + if ret is None: + break + pid = ret.si_pid + transp = self._subprocesses.get(pid) + if transp is not None: + transp._poll() + except ChildProcessError: + pass + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) @@ -175,8 +222,10 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: + self._closing = True self._loop.remove_reader(self._fileno) - self._protocol.eof_received() + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) def pause(self): self._loop.remove_reader(self._fileno) @@ -354,3 +403,138 @@ def discard_output(self): if self._buffer: self._loop.remove_writer(self._fileno) self._buffer.clear() + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._pipe_connection_made(self.fd) + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._exited = False + self._done = False + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + if not self._pipes: + self._loop.call_soon(self._protocol.connection_made, self) + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._proc.returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + self._poll() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _pipe_connection_made(self, fd): + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._poll() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _poll(self): + if not self._exited: + returncode = self._proc.poll() + if returncode is not None: + self._exited = True + self._call(self._protocol.process_exited) + if self._exited and not self._done: + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._call(self._protocol.connection_lost, None) + self._done = True From 184864f27dafd1a4dffd22af540b8faaf9f2a7ae Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 01:53:36 +0300 Subject: [PATCH 0523/1502] Skip subpocess tests only for Windows --- tests/events_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 0315db8f..7918e388 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -941,8 +941,8 @@ def main(): r.close() w.close() - @unittest.skipUnless(sys.platform != 'win32', - "Don't support subprocess for Windows yet") + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") def test_subprocess_exec(self): proto = None transp = None @@ -967,8 +967,8 @@ def connect(): self.assertEqual(0, proto.returncode) self.assertEqual(b'PYTHON THE WINNER', proto.data[1]) - @unittest.skipUnless(sys.platform != 'win32', - "Don't support subprocess for Windows yet") + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") def test_subprocess_interactive(self): proto = None transp = None @@ -1003,7 +1003,7 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) - @unittest.skipUnless(sys.platform != 'win32', + @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") def test_subprocess_shell(self): proto = None @@ -1026,8 +1026,8 @@ def connect(): self.assertTrue(all(f.done() for f in proto.disconnects.values())) self.assertEqual({1: b'Python\n', 2: b''}, proto.data) - @unittest.skipUnless(sys.platform != 'win32', - "Don't support subprocess for Windows yet") + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") def test_subprocess_exitcode(self): proto = None transp = None From 4277c82060aa3004a8c28b0f11abad89f80cba53 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 01:54:38 +0300 Subject: [PATCH 0524/1502] Add missing test file --- tests/echo.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 tests/echo.py diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) From d6774ff180a1050015b2b1d7e88bd516a2a2824b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 9 Aug 2013 16:52:05 -0700 Subject: [PATCH 0525/1502] Get rid of Future.running(). It is useless. --- tests/futures_test.py | 4 ---- tulip/futures.py | 8 +------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index e448ff82..7ec40be5 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -26,7 +26,6 @@ def tearDown(self): def test_initial_state(self): f = futures.Future(loop=self.loop) self.assertFalse(f.cancelled()) - self.assertFalse(f.running()) self.assertFalse(f.done()) f.cancel() self.assertTrue(f.cancelled()) @@ -47,7 +46,6 @@ def test_cancel(self): f = futures.Future(loop=self.loop) self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) - self.assertFalse(f.running()) self.assertTrue(f.done()) self.assertRaises(futures.CancelledError, f.result) self.assertRaises(futures.CancelledError, f.exception) @@ -62,7 +60,6 @@ def test_result(self): f.set_result(42) self.assertFalse(f.cancelled()) - self.assertFalse(f.running()) self.assertTrue(f.done()) self.assertEqual(f.result(), 42) self.assertEqual(f.exception(), None) @@ -78,7 +75,6 @@ def test_exception(self): f.set_exception(exc) self.assertFalse(f.cancelled()) - self.assertFalse(f.running()) self.assertTrue(f.done()) self.assertRaises(RuntimeError, f.result) self.assertEqual(f.exception(), exc) diff --git a/tulip/futures.py b/tulip/futures.py index 3965e2b5..676cea16 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -209,13 +209,7 @@ def cancelled(self): """Return True if the future was cancelled.""" return self._state == _CANCELLED - def running(self): - """Always return False. - - This method is for compatibility with concurrent.futures; we don't - have a running state. - """ - return False # We don't have a running state. + # Don't implement running(); see http://bugs.python.org/issue18699 def done(self): """Return True if the future is done. From 6f9a0ed9c17d845d9230d9d525965e845cd22044 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 04:03:50 +0300 Subject: [PATCH 0526/1502] Replace waitid to waitpid to make work on OSX --- tulip/unix_events.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 2b0af64f..f1ad0b97 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -175,14 +175,12 @@ def _sig_chld(self): try: while True: grp = os.getpgrp() - ret = os.waitid(os.P_PGID, grp, - os.WNOHANG|os.WNOWAIT|os.WEXITED) - if ret is None: + pid, status = os.waitpid(0, os.WNOHANG) + if pid == 0: break - pid = ret.si_pid transp = self._subprocesses.get(pid) if transp is not None: - transp._poll() + transp._poll(status) except ChildProcessError: pass @@ -527,9 +525,13 @@ def _pipe_connection_lost(self, fd, exc): def _pipe_data_received(self, fd, data): self._call(self._protocol.pipe_data_received, fd, data) - def _poll(self): + def _poll(self, status=None): if not self._exited: - returncode = self._proc.poll() + if status is None: + returncode = self._proc.poll() + else: + self._proc._handle_exitstatus(status) + returncode = self._proc.returncode if returncode is not None: self._exited = True self._call(self._protocol.process_exited) From b85c0eedac257e93e60c6cb4c42562f811935e3e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 16:23:50 +0300 Subject: [PATCH 0527/1502] Get rid of using undocumented subprocess.Popen method --- tulip/unix_events.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index f1ad0b97..5df9d444 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -178,9 +178,15 @@ def _sig_chld(self): pid, status = os.waitpid(0, os.WNOHANG) if pid == 0: break + if os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + break transp = self._subprocesses.get(pid) if transp is not None: - transp._poll(status) + transp._poll(returncode) except ChildProcessError: pass @@ -450,8 +456,8 @@ def __init__(self, loop, protocol, args, shell, if stderr == subprocess.PIPE: self._pipes[2] = None self._pending_calls = collections.deque() - self._exited = False self._done = False + self._returncode = None self._proc = subprocess.Popen( args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, @@ -464,13 +470,14 @@ def __init__(self, loop, protocol, args, shell, def close(self): for proto in self._pipes.values(): proto.pipe.close() - self.terminate() + if self._returncode is None: + self.terminate() def get_pid(self): return self._proc.pid def get_returncode(self): - return self._proc.returncode + return self._returncode def get_pipe_transport(self, fd): if fd in self._pipes: @@ -525,17 +532,14 @@ def _pipe_connection_lost(self, fd, exc): def _pipe_data_received(self, fd, data): self._call(self._protocol.pipe_data_received, fd, data) - def _poll(self, status=None): - if not self._exited: - if status is None: + def _poll(self, returncode=None): + if self._returncode is None: + if returncode is None: returncode = self._proc.poll() - else: - self._proc._handle_exitstatus(status) - returncode = self._proc.returncode if returncode is not None: - self._exited = True + self._returncode = returncode self._call(self._protocol.process_exited) - if self._exited and not self._done: + if self._returncode is not None and not self._done: if all(p is not None and p.disconnected for p in self._pipes.values()): self._call(self._protocol.connection_lost, None) From adac0bfa7ae95b2640d4023f309ab5b37de87d12 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 10 Aug 2013 16:30:31 +0300 Subject: [PATCH 0528/1502] Dont use subtle weakref to subprocess transport, switch to direct notification on child exiting --- tulip/unix_events.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 5df9d444..9c5382b7 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -9,7 +9,6 @@ import stat import subprocess import sys -import weakref try: import signal @@ -41,7 +40,7 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop): def __init__(self, selector=None): super().__init__(selector) self._signal_handlers = {} - self._subprocesses = weakref.WeakValueDictionary() + self._subprocesses = {} def _socketpair(self): return socket.socketpair() @@ -190,6 +189,11 @@ def _sig_chld(self): except ChildProcessError: pass + def _subprocess_closed(self, transport): + pid = transport.get_pid() + if self._subprocesses.get(pid): + del self._subprocesses[pid] + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) @@ -538,6 +542,7 @@ def _poll(self, returncode=None): returncode = self._proc.poll() if returncode is not None: self._returncode = returncode + self._loop._subprocess_closed(self) self._call(self._protocol.process_exited) if self._returncode is not None and not self._done: if all(p is not None and p.disconnected From 708019bf5ec836921d61578703b1605fa705858d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 11 Aug 2013 21:22:04 +0300 Subject: [PATCH 0529/1502] Fix windows tests, use explicit loop --- tests/windows_events_test.py | 12 ++++++------ tulip/proactor_events.py | 1 + tulip/windows_events.py | 10 +++++++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index 73304c2b..09faf2bb 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -9,7 +9,7 @@ def connect_read_pipe(loop, file): - stream_reader = streams.StreamReader() + stream_reader = streams.StreamReader(loop=loop) protocol = _StreamReaderProtocol(stream_reader) transport = loop._make_read_pipe_transport(file, protocol) return stream_reader @@ -30,22 +30,22 @@ class ProactorTests(unittest.TestCase): def setUp(self): self.loop = windows_events.ProactorEventLoop() - tulip.set_event_loop(self.loop) # TODO: Use None, test on Windows. + tulip.set_event_loop(None) def tearDown(self): - tulip.set_event_loop(None) self.loop.close() + self.loop = None def test_pause_resume_discard(self): a, b = self.loop._socketpair() trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) reader = connect_read_pipe(self.loop, b) - f = tulip.async(reader.readline()) + f = tulip.async(reader.readline(), loop=self.loop) trans.write(b'msg1\n') self.loop.run_until_complete(f, timeout=0.01) self.assertEqual(f.result(), b'msg1\n') - f = tulip.async(reader.readline()) + f = tulip.async(reader.readline(), loop=self.loop) trans.pause_writing() trans.write(b'msg2\n') @@ -56,7 +56,7 @@ def test_pause_resume_discard(self): trans.resume_writing() self.loop.run_until_complete(f, timeout=0.1) self.assertEqual(f.result(), b'msg2\n') - f = tulip.async(reader.readline()) + f = tulip.async(reader.readline(), loop=self.loop) trans.pause_writing() trans.write(b'msg3\n') diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index af6e00f8..cda87918 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -195,6 +195,7 @@ def __init__(self, proactor): tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias + proactor.set_loop(self) self._make_self_pipe() def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 08f221e2..527922e1 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -29,8 +29,8 @@ class _OverlappedFuture(futures.Future): Cancelling it will immediately cancel the overlapped operation. """ - def __init__(self, ov): - super().__init__() + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) self.ov = ov def cancel(self): @@ -59,6 +59,7 @@ def _socketpair(self): class IocpProactor: def __init__(self, concurrency=0xffffffff): + self._loop = None self._results = [] self._iocp = _overlapped.CreateIoCompletionPort( _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) @@ -66,6 +67,9 @@ def __init__(self, concurrency=0xffffffff): self._registered = weakref.WeakSet() self._stopped_serving = weakref.WeakSet() + def set_loop(self, loop): + self._loop = loop + def registered_count(self): return len(self._cache) @@ -142,7 +146,7 @@ def _register_with_iocp(self, obj): _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) def _register(self, ov, obj, callback): - f = _OverlappedFuture(ov) + f = _OverlappedFuture(ov, loop=self._loop) self._cache[ov.address] = (f, ov, obj, callback) return f From bc167360a39723c91be33e6ae53b0d0dd859a93a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 12 Aug 2013 03:48:30 +0300 Subject: [PATCH 0530/1502] Refactor unix subprocess support --- tulip/unix_events.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 9c5382b7..959f5f27 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -185,7 +185,7 @@ def _sig_chld(self): break transp = self._subprocesses.get(pid) if transp is not None: - transp._poll(returncode) + transp._process_exited(returncode) except ChildProcessError: pass @@ -460,7 +460,7 @@ def __init__(self, loop, protocol, args, shell, if stderr == subprocess.PIPE: self._pipes[2] = None self._pending_calls = collections.deque() - self._done = False + self._finished = False self._returncode = None self._proc = subprocess.Popen( @@ -514,7 +514,6 @@ def _post_init(self): transp, proto = yield from loop.connect_read_pipe( functools.partial(_UnixReadSubprocessPipeProto, self, 2), proc.stderr) - self._poll() def _call(self, cb, *data): if self._pending_calls is not None: @@ -531,21 +530,25 @@ def _pipe_connection_made(self, fd): def _pipe_connection_lost(self, fd, exc): self._call(self._protocol.pipe_connection_lost, fd, exc) - self._poll() + self._try_finish() def _pipe_data_received(self, fd, data): self._call(self._protocol.pipe_data_received, fd, data) - def _poll(self, returncode=None): + def _process_exited(self, returncode): + assert returncode is not None + if self._returncode is not None: + return + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished if self._returncode is None: - if returncode is None: - returncode = self._proc.poll() - if returncode is not None: - self._returncode = returncode - self._loop._subprocess_closed(self) - self._call(self._protocol.process_exited) - if self._returncode is not None and not self._done: - if all(p is not None and p.disconnected - for p in self._pipes.values()): - self._call(self._protocol.connection_lost, None) - self._done = True + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._protocol.connection_lost, None) From 8967ca3e0d2721ed718a478df4bc8e3af8f7ea4d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 12 Aug 2013 10:52:20 +0300 Subject: [PATCH 0531/1502] Drop never executed code --- tulip/unix_events.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 959f5f27..e5c80737 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -537,8 +537,6 @@ def _pipe_data_received(self, fd, data): def _process_exited(self, returncode): assert returncode is not None - if self._returncode is not None: - return self._returncode = returncode self._loop._subprocess_closed(self) self._call(self._protocol.process_exited) From 5de6ce285af11e28f0756c9dc4442e346759459b Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 12 Aug 2013 09:39:45 -0700 Subject: [PATCH 0532/1502] make InvalidStateError more informative --- tulip/futures.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tulip/futures.py b/tulip/futures.py index 676cea16..cad42b47 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -232,7 +232,7 @@ def result(self, timeout=0): if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: - raise InvalidStateError + raise InvalidStateError('Result is not ready.') if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None @@ -253,7 +253,7 @@ def exception(self, timeout=0): if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: - raise InvalidStateError + raise InvalidStateError('Exception is not set.') if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None @@ -293,7 +293,7 @@ def set_result(self, result): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError + raise InvalidStateError('{}: {!r}'.format(self._state, self)) self._result = result self._state = _FINISHED self._schedule_callbacks() @@ -305,7 +305,7 @@ def set_exception(self, exception): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError + raise InvalidStateError('{}: {!r}'.format(self._state, self)) self._exception = exception self._tb_logger = _TracebackLogger(exception) self._state = _FINISHED From b9012fa98f0e1a34b048bf3c4b0446b6317f5173 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 02:21:16 +0300 Subject: [PATCH 0533/1502] Improve test coverage, fix couple bugs in Unix subprocess --- tests/events_test.py | 91 ++++++++++++++++++++++++++++++++++++--- tests/transports_test.py | 10 +++++ tests/unix_events_test.py | 89 ++++++++++++++++++++++++++++++++++++++ tulip/unix_events.py | 51 +++++++++++----------- 4 files changed, 210 insertions(+), 31 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 7918e388..33f6fbb3 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -169,12 +169,14 @@ def pipe_data_received(self, fd, data): self.got_data.set() def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state if exc: self.disconnects[fd].set_exception(exc) else: self.disconnects[fd].set_result(exc) def process_exited(self): + assert self.state == 'CONNECTED', self.state self.returncode = self.transport.get_returncode() @@ -947,12 +949,14 @@ def test_subprocess_exec(self): proto = None transp = None + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + @tasks.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( functools.partial(MySubprocessProtocol, self.loop), - 'tr', '[a-z]', '[A-Z]') + sys.executable, prog) self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(connect()) @@ -960,12 +964,12 @@ def connect(): self.assertEqual('CONNECTED', proto.state) stdin = transp.get_pipe_transport(0) - stdin.write(b'Python ') - stdin.write(b'The Winner') - stdin.close() + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data.wait(10)) + transp.close() self.loop.run_until_complete(proto.completed) - self.assertEqual(0, proto.returncode) - self.assertEqual(b'PYTHON THE WINNER', proto.data[1]) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") @@ -1037,12 +1041,78 @@ def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_shell( functools.partial(MySubprocessProtocol, self.loop), - 'exit 7') + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGQUIT) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGQUIT, proto.returncode) if sys.platform == 'win32': @@ -1308,6 +1378,13 @@ def test_empty(self): self.assertIsNone(dp.connection_refused(f)) self.assertIsNone(dp.datagram_received(f, f)) + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + class PolicyTests(unittest.TestCase): diff --git a/tests/transports_test.py b/tests/transports_test.py index b1c932f0..d2688c3a 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -46,3 +46,13 @@ def test_dgram_not_implemented(self): self.assertRaises(NotImplementedError, transport.sendto, 'data') self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 29a8a949..52af8055 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -187,6 +187,95 @@ class Err(OSError): self.assertRaises( RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + class UnixReadPipeTransportTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index e5c80737..ca7488c0 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -171,28 +171,30 @@ def _reg_sigchld(self): self._sig_chld) def _sig_chld(self): - try: - while True: - grp = os.getpgrp() + while True: + try: pid, status = os.waitpid(0, os.WNOHANG) - if pid == 0: - break - if os.WIFSIGNALED(status): - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - returncode = os.WEXITSTATUS(status) - else: - break - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) - except ChildProcessError: - pass + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) def _subprocess_closed(self, transport): pid = transport.get_pid() - if self._subprocesses.get(pid): - del self._subprocesses[pid] + self._subprocesses.pop(pid, None) def _set_nonblocking(fd): @@ -426,7 +428,7 @@ def __init__(self, proc, fd): def connection_made(self, transport): self.connected = True self.pipe = transport - self.proc._pipe_connection_made(self.fd) + self.proc._try_connected() def connection_lost(self, exc): self.disconnected = True @@ -468,9 +470,6 @@ def __init__(self, loop, protocol, args, shell, universal_newlines=False, bufsize=bufsize, **kwargs) self._extra['subprocess'] = self._proc - if not self._pipes: - self._loop.call_soon(self._protocol.connection_made, self) - def close(self): for proto in self._pipes.values(): proto.pipe.close() @@ -514,6 +513,8 @@ def _post_init(self): transp, proto = yield from loop.connect_read_pipe( functools.partial(_UnixReadSubprocessPipeProto, self, 2), proc.stderr) + if not self._pipes: + self._try_connected() def _call(self, cb, *data): if self._pending_calls is not None: @@ -521,7 +522,8 @@ def _call(self, cb, *data): else: self._loop.call_soon(cb, *data) - def _pipe_connection_made(self, fd): + def _try_connected(self): + assert self._pending_calls is not None if all(p is not None and p.connected for p in self._pipes.values()): self._loop.call_soon(self._protocol.connection_made, self) for callback, data in self._pending_calls: @@ -536,7 +538,8 @@ def _pipe_data_received(self, fd, data): self._call(self._protocol.pipe_data_received, fd, data) def _process_exited(self, returncode): - assert returncode is not None + assert returncode is not None, returncode + assert self._returncode is None, self._returncode self._returncode = returncode self._loop._subprocess_closed(self) self._call(self._protocol.process_exited) From 49c70fa3f98fa32502f138f90faf8f7345a9b34a Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 05:16:07 +0300 Subject: [PATCH 0534/1502] Add test for stderr in subprocess --- tests/echo2.py | 7 +++++++ tests/events_test.py | 43 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 6 deletions(-) create mode 100644 tests/echo2.py diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..24503295 --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,7 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/events_test.py b/tests/events_test.py index 33f6fbb3..2da48a65 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -150,7 +150,8 @@ def __init__(self, loop): self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} self.data = {1: b'', 2: b''} self.returncode = None - self.got_data = locks.EventWaiter(loop=loop) + self.got_data = {1: locks.EventWaiter(loop=loop), + 2: locks.EventWaiter(loop=loop)} def connection_made(self, transport): self.transport = transport @@ -166,7 +167,7 @@ def connection_lost(self, exc): def pipe_data_received(self, fd, data): assert self.state == 'CONNECTED', self.state self.data[fd] += data - self.got_data.set() + self.got_data[fd].set() def pipe_connection_lost(self, fd, exc): assert self.state == 'CONNECTED', self.state @@ -965,7 +966,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') - self.loop.run_until_complete(proto.got_data.wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(10)) transp.close() self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) @@ -994,12 +995,12 @@ def connect(): try: stdin = transp.get_pipe_transport(0) stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data.wait(10)) - proto.got_data.clear() + self.loop.run_until_complete(proto.got_data[1].wait(10)) + proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1]) stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data.wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(10)) self.assertEqual(b'Python The Winner', proto.data[1]) finally: transp.close() @@ -1114,6 +1115,36 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGQUIT, proto.returncode) + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.assertEqual(b'OUT:test', proto.data[1]) + self.loop.run_until_complete(proto.got_data[2].wait(10)) + self.assertEqual(b'ERR:test', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + if sys.platform == 'win32': from tulip import windows_events From ae4ec5b09c484fc336f41d53fd182d051c776d95 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 05:30:17 +0300 Subject: [PATCH 0535/1502] Add test for subpocess.STDOUT --- tests/events_test.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/events_test.py b/tests/events_test.py index 2da48a65..2f65c10b 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -11,6 +11,7 @@ import ssl except ImportError: ssl = None +import subprocess import sys import threading import time @@ -1145,6 +1146,35 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.assertEqual(b'OUT:testERR:test', proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + if sys.platform == 'win32': from tulip import windows_events From 2d9b4073943e8e69138043a5016e3b1d4eb0e38e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 18:16:17 +0300 Subject: [PATCH 0536/1502] Add log message for unexpected exceptions in SIGCHLD handler, don't cancel handler on unknown exception --- tests/unix_events_test.py | 23 +++++++++++++++++++++ tulip/unix_events.py | 43 +++++++++++++++++++++------------------ 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 52af8055..0ddd11bc 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -276,6 +276,29 @@ def test__sig_chld_unknown_status(self, m_waitpid, self.assertFalse(m_WEXITSTATUS.called) self.assertFalse(m_WTERMSIG.called) + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + class UnixReadPipeTransportTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index ca7488c0..f27f9dc2 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -171,26 +171,29 @@ def _reg_sigchld(self): self._sig_chld) def _sig_chld(self): - while True: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except ChildProcessError: - break - if pid == 0: - continue - elif os.WIFSIGNALED(status): - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - returncode = os.WEXITSTATUS(status) - else: - # covered by - # SelectorEventLoopTests.test__sig_chld_unknown_status - # from tests/unix_events_test.py - # bug in coverage.py version 3.6 ??? - continue # pragma: no cover - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + tulip_log.exception('Unknown exception in SIGCHLD handler') def _subprocess_closed(self, transport): pid = transport.get_pid() From 635739ebf3421f1e49e3d4cf8e9ed3bf0f0be1d3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 18:45:23 +0300 Subject: [PATCH 0537/1502] Point to setuptools instead of distribute in runtests.py Please note: distribute has joined back into setuptools --- runtests.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/runtests.py b/runtests.py index e254fdc1..e85d12d8 100644 --- a/runtests.py +++ b/runtests.py @@ -162,12 +162,12 @@ def runtests(): def runcoverage(sdir, args): """ To install coverage3 for Python 3, you need: - - Distribute (http://packages.python.org/distribute/) + - Setiptools (https://pypi.python.org/pypi/setuptools) What worked for me: - - download http://python-distribute.org/distribute_setup.py - * curl -O http://python-distribute.org/distribute_setup.py - - python3 distribute_setup.py + - download https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + * curl -O https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + - python3 ez_setup.py - python3 -m easy_install coverage """ try: From 44648b63cbc5d040ff42ea558cd02df59461e034 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 19:36:44 +0300 Subject: [PATCH 0538/1502] Fix typo --- runtests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index e85d12d8..6b784d02 100644 --- a/runtests.py +++ b/runtests.py @@ -162,7 +162,7 @@ def runtests(): def runcoverage(sdir, args): """ To install coverage3 for Python 3, you need: - - Setiptools (https://pypi.python.org/pypi/setuptools) + - Setuptools (https://pypi.python.org/pypi/setuptools) What worked for me: - download https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py From 7069a069eb4811ed15e3a31d0033bf52b9926d53 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 13 Aug 2013 19:39:14 +0300 Subject: [PATCH 0539/1502] Improve subprocess tests --- tests/events_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/events_test.py b/tests/events_test.py index 2f65c10b..b22dfe23 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1066,6 +1066,8 @@ def connect(): self.loop.run_until_complete(connect()) self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) self.assertIsNone(transp.close()) @@ -1170,6 +1172,8 @@ def connect(): self.loop.run_until_complete(proto.got_data[1].wait(10)) self.assertEqual(b'OUT:testERR:test', proto.data[1]) self.assertEqual(b'', proto.data[2]) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) transp.close() self.loop.run_until_complete(proto.completed) From 5d97e9825cf6d0046807c387163404735099dc64 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 14 Aug 2013 15:27:11 -0700 Subject: [PATCH 0540/1502] Stop relying on get_event_loop() in http tests, test coverage --- tests/base_events_test.py | 12 +++++ tests/events_test.py | 10 +++-- tests/futures_test.py | 10 +++++ tests/http_client_functional_test.py | 65 +++++++++++++++++++--------- tests/http_client_test.py | 10 ++--- tests/http_parser_test.py | 15 +++++++ tests/http_protocol_test.py | 2 + tests/http_server_test.py | 45 ++++++++++--------- tests/http_wsgi_test.py | 38 ++++++++-------- tests/streams_test.py | 5 +++ tests/tasks_test.py | 38 ++++++++++++---- tests/unix_events_test.py | 12 ++--- tests/windows_events_test.py | 6 ++- tulip/base_events.py | 1 - tulip/http/client.py | 4 +- tulip/http/server.py | 6 +-- tulip/http/session.py | 4 +- 17 files changed, 190 insertions(+), 93 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 84f26a2d..31ee30d9 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -45,6 +45,9 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) def test__add_callback_handle(self): h = events.Handle(lambda: False, ()) @@ -366,8 +369,10 @@ def test_create_connection_no_getaddrinfo(self): @tasks.coroutine def getaddrinfo(*args, **kw): yield from [] + def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( @@ -378,8 +383,10 @@ def test_create_connection_connect_err(self): def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] + def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error @@ -393,8 +400,10 @@ def test_create_connection_multiple(self): def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] + def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error @@ -420,8 +429,10 @@ def bind(addr): def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] + def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.side_effect = socket.error('Err2') @@ -443,6 +454,7 @@ def getaddrinfo(host, *args, **kw): (2, 1, 6, '', ('107.6.106.82', 80))] else: return [] + def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task diff --git a/tests/events_test.py b/tests/events_test.py index b22dfe23..d070b25b 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1010,7 +1010,7 @@ def connect(): self.assertEqual(-signal.SIGTERM, proto.returncode) @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") + "Don't support subprocess for Windows yet") def test_subprocess_shell(self): proto = None transp = None @@ -1036,11 +1036,10 @@ def connect(): "Don't support subprocess for Windows yet") def test_subprocess_exitcode(self): proto = None - transp = None @tasks.coroutine def connect(): - nonlocal proto, transp + nonlocal proto transp, proto = yield from self.loop.subprocess_shell( functools.partial(MySubprocessProtocol, self.loop), 'exit 7', stdin=None, stdout=None, stderr=None) @@ -1425,6 +1424,11 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, loop.connect_write_pipe, f, unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) class ProtocolsAbsTests(unittest.TestCase): diff --git a/tests/futures_test.py b/tests/futures_test.py index 7ec40be5..79aeddd2 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -233,6 +233,16 @@ def test_wrap_future_future(self): f2 = futures.wrap_future(f1) self.assertIs(f1, f2) + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + time.sleep(0.1) + return arg + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + # A fake event loop for tests. All it does is implement a call_soon method # that immediately invokes the given function. diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index 039ab878..c351607a 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -16,13 +16,12 @@ class HttpClientFunctionalTests(unittest.TestCase): def setUp(self): self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(None) def tearDown(self): # just in case if we have transport close callbacks test_utils.run_briefly(self.loop) - tulip.set_event_loop(None) self.loop.close() gc.collect() @@ -41,6 +40,23 @@ def test_HTTP_200_OK_METHOD(self): self.assertEqual(content1, content2) r.close() + def test_use_global_loop(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + try: + tulip.set_event_loop(self.loop) + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'))) + finally: + tulip.set_event_loop(None) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "GET"', content) + self.assertEqual(content1, content2) + r.close() + def test_HTTP_302_REDIRECT_GET(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( @@ -227,7 +243,7 @@ def test_POST_FILES_LIST_CT(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, + client.request('post', url, loop=self.loop, files=[('some', f, 'text/plain')])) content = self.loop.run_until_complete(r.read(True)) @@ -252,7 +268,7 @@ def test_POST_FILES_SINGLE(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, files=[f])) + client.request('post', url, files=[f], loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -275,7 +291,7 @@ def test_POST_FILES_IO(self): data = io.BytesIO(b'data') r = self.loop.run_until_complete( - client.request('post', url, files=[data])) + client.request('post', url, files=[data], loop=self.loop)) content = self.loop.run_until_complete(r.read(True)) @@ -293,7 +309,7 @@ def test_POST_FILES_WITH_DATA(self): with open(__file__) as f: r = self.loop.run_until_complete( - client.request('post', url, + client.request('post', url, loop=self.loop, data={'test': 'true'}, files={'some': f})) content = self.loop.run_until_complete(r.read(True)) @@ -317,11 +333,13 @@ def test_POST_FILES_WITH_DATA(self): def test_encoding(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( - client.request('get', httpd.url('encoding', 'deflate'))) + client.request('get', httpd.url('encoding', 'deflate'), + loop=self.loop)) self.assertEqual(r.status, 200) r = self.loop.run_until_complete( - client.request('get', httpd.url('encoding', 'gzip'))) + client.request('get', httpd.url('encoding', 'gzip'), + loop=self.loop)) self.assertEqual(r.status, 200) def test_cookies(self): @@ -331,7 +349,7 @@ def test_cookies(self): r = self.loop.run_until_complete( client.request( - 'get', httpd.url('method', 'get'), + 'get', httpd.url('method', 'get'), loop=self.loop, cookies={'test1': '123', 'test2': c})) self.assertEqual(r.status, 200) @@ -341,7 +359,7 @@ def test_cookies(self): def test_set_cookies(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: resp = self.loop.run_until_complete( - client.request('get', httpd.url('cookies'))) + client.request('get', httpd.url('cookies'), loop=self.loop)) self.assertEqual(resp.status, 200) self.assertEqual(resp.cookies['c1'].value, 'cookie1') @@ -350,7 +368,7 @@ def test_set_cookies(self): def test_chunked(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( - client.request('get', httpd.url('chunked'))) + client.request('get', httpd.url('chunked'), loop=self.loop)) self.assertEqual(r.status, 200) self.assertEqual(r['Transfer-Encoding'], 'chunked') content = self.loop.run_until_complete(r.read(True)) @@ -362,13 +380,15 @@ def test_timeout(self): self.assertRaises( tulip.TimeoutError, self.loop.run_until_complete, - client.request('get', httpd.url('method', 'get'), timeout=0.1)) + client.request('get', httpd.url('method', 'get'), + timeout=0.1, loop=self.loop)) def test_request_conn_error(self): self.assertRaises( OSError, self.loop.run_until_complete, - client.request('get', 'http://0.0.0.0:1', timeout=0.1)) + client.request('get', 'http://0.0.0.0:1', + timeout=0.1, loop=self.loop)) def test_request_conn_closed(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -376,7 +396,8 @@ def test_request_conn_closed(self): self.assertRaises( tulip.http.HttpException, self.loop.run_until_complete, - client.request('get', httpd.url('method', 'get'))) + client.request('get', httpd.url('method', 'get'), + loop=self.loop)) def test_keepalive(self): from tulip.http import session @@ -384,14 +405,16 @@ def test_keepalive(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive',), session=s)) + client.request('get', httpd.url('keepalive',), + session=s, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) self.assertEqual(content['content'], 'requests=1') r.close() r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive'), session=s)) + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) self.assertEqual(content['content'], 'requests=2') @@ -404,14 +427,16 @@ def test_session_close(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: r = self.loop.run_until_complete( client.request( - 'get', httpd.url('keepalive') + '?close=1', session=s)) + 'get', httpd.url('keepalive') + '?close=1', + session=s, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) self.assertEqual(content['content'], 'requests=1') r.close() r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive'), session=s)) + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) self.assertEqual(content['content'], 'requests=1') @@ -424,8 +449,8 @@ def test_session_cookies(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: s.update_cookies({'test': '1'}) r = self.loop.run_until_complete( - client.request( - 'get', httpd.url('cookies'), session=s)) + client.request('get', httpd.url('cookies'), + session=s, loop=self.loop)) self.assertEqual(r.status, 200) content = self.loop.run_until_complete(r.read(True)) diff --git a/tests/http_client_test.py b/tests/http_client_test.py index c59f7758..1aa27244 100644 --- a/tests/http_client_test.py +++ b/tests/http_client_test.py @@ -15,14 +15,13 @@ class HttpResponseTests(unittest.TestCase): def setUp(self): self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(None) self.transport = unittest.mock.Mock() - self.stream = tulip.StreamBuffer() + self.stream = tulip.StreamBuffer(loop=self.loop) self.response = HttpResponse('get', 'http://python.org') def tearDown(self): - tulip.set_event_loop(None) self.loop.close() def test_close(self): @@ -44,13 +43,12 @@ class HttpRequestTests(unittest.TestCase): def setUp(self): self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(None) self.transport = unittest.mock.Mock() - self.stream = tulip.StreamBuffer() + self.stream = tulip.StreamBuffer(loop=self.loop) def tearDown(self): - tulip.set_event_loop(None) self.loop.close() def test_method(self): diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py index 91accfd7..6240ad49 100644 --- a/tests/http_parser_test.py +++ b/tests/http_parser_test.py @@ -12,6 +12,9 @@ class ParseHeadersTests(unittest.TestCase): + def setUp(self): + tulip.set_event_loop(None) + def test_parse_headers(self): hdrs = ('', 'test: line\r\n', ' continue\r\n', 'test2: data\r\n', '\r\n') @@ -96,6 +99,9 @@ def test_invalid_name(self): class DeflateBufferTests(unittest.TestCase): + def setUp(self): + tulip.set_event_loop(None) + def test_feed_data(self): buf = tulip.DataBuffer() dbuf = protocol.DeflateBuffer(buf, 'deflate') @@ -140,6 +146,9 @@ def test_feed_eof_err(self): class ParsePayloadTests(unittest.TestCase): + def setUp(self): + tulip.set_event_loop(None) + def test_parse_eof_payload(self): out = tulip.DataBuffer() buf = tulip.ParserBuffer() @@ -356,6 +365,9 @@ def test_http_payload_parser_length_zero(self): class ParseRequestTests(unittest.TestCase): + def setUp(self): + tulip.set_event_loop(None) + def test_http_request_parser_max_headers(self): p = protocol.http_request_parser(8190, 20, 8190) next(p) @@ -441,6 +453,9 @@ def test_http_request_parser_bad_version(self): class ParseResponseTests(unittest.TestCase): + def setUp(self): + tulip.set_event_loop(None) + def test_http_response_parser_bad_status_line(self): p = protocol.http_response_parser() next(p) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index fc6d2842..e74b8f27 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -4,6 +4,7 @@ import unittest.mock import zlib +import tulip from tulip.http import protocol @@ -11,6 +12,7 @@ class HttpMessageTests(unittest.TestCase): def setUp(self): self.transport = unittest.mock.Mock() + tulip.set_event_loop(None) def test_start_request(self): msg = protocol.Request( diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 5162dbb0..49cfc8fa 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -13,10 +13,9 @@ class HttpServerProtocolTests(unittest.TestCase): def setUp(self): self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(None) def tearDown(self): - tulip.set_event_loop(None) self.loop.close() def test_http_error_exception(self): @@ -27,7 +26,7 @@ def test_http_error_exception(self): def test_handle_request(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(transport) rline = unittest.mock.Mock() @@ -39,14 +38,14 @@ def test_handle_request(self): self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) def test_connection_made(self): - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) self.assertIsNone(srv._request_handler) srv.connection_made(unittest.mock.Mock()) self.assertIsNotNone(srv._request_handler) def test_data_received(self): - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(unittest.mock.Mock()) srv.data_received(b'123') @@ -56,13 +55,13 @@ def test_data_received(self): self.assertEqual(b'123456', bytes(srv.stream._buffer)) def test_eof_received(self): - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(unittest.mock.Mock()) srv.eof_received() self.assertTrue(srv.stream._eof) def test_connection_lost(self): - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(unittest.mock.Mock()) srv.data_received(b'123') @@ -82,7 +81,7 @@ def test_connection_lost(self): self.assertIsNone(srv._keep_alive_handle) def test_srv_keep_alive(self): - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) self.assertFalse(srv._keep_alive) srv.keep_alive(True) @@ -93,7 +92,7 @@ def test_srv_keep_alive(self): def test_handle_error(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(transport) srv.keep_alive(True) @@ -107,7 +106,7 @@ def test_handle_error(self): def test_handle_error_traceback_exc(self, m_trace): transport = unittest.mock.Mock() log = unittest.mock.Mock() - srv = server.ServerHttpProtocol(debug=True, log=log) + srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) srv.connection_made(transport) m_trace.format_exc.side_effect = ValueError @@ -120,7 +119,7 @@ def test_handle_error_traceback_exc(self, m_trace): def test_handle_error_debug(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.debug = True srv.connection_made(transport) @@ -138,7 +137,7 @@ def test_handle_error_500(self): log = unittest.mock.Mock() transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(log=log) + srv = server.ServerHttpProtocol(log=log, loop=self.loop) srv.connection_made(transport) srv.handle_error(500) @@ -146,7 +145,7 @@ def test_handle_error_500(self): def test_handle(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(transport) handle = srv.handle_request = unittest.mock.Mock() @@ -161,7 +160,7 @@ def test_handle(self): def test_handle_coro(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) called = False @@ -184,24 +183,24 @@ def test_handle_cancel(self): log = unittest.mock.Mock() transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(log=log, debug=True) + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) srv.connection_made(transport) srv.handle_request = unittest.mock.Mock() - @tulip.task + @tulip.coroutine def cancel(): srv._request_handler.cancel() self.loop.run_until_complete( - tulip.wait([srv._request_handler, cancel()])) + tulip.wait([srv._request_handler, cancel()], loop=self.loop)) self.assertTrue(log.debug.called) def test_handle_cancelled(self): log = unittest.mock.Mock() transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(log=log, debug=True) + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) srv.connection_made(transport) srv.handle_request = unittest.mock.Mock() @@ -218,7 +217,7 @@ def test_handle_cancelled(self): def test_handle_400(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(transport) srv.handle_error = unittest.mock.Mock() srv.keep_alive(True) @@ -231,7 +230,7 @@ def test_handle_400(self): def test_handle_500(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.connection_made(transport) handle = srv.handle_request = unittest.mock.Mock() @@ -248,7 +247,7 @@ def test_handle_500(self): def test_handle_error_no_handle_task(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol() + srv = server.ServerHttpProtocol(loop=self.loop) srv.keep_alive(True) srv.connection_made(transport) srv.connection_lost(None) @@ -257,7 +256,7 @@ def test_handle_error_no_handle_task(self): self.assertFalse(srv._keep_alive) def test_keep_alive(self): - srv = server.ServerHttpProtocol(keep_alive=0.1) + srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) transport = unittest.mock.Mock() closed = False @@ -284,7 +283,7 @@ def close(): def test_keep_alive_close_existing(self): transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(keep_alive=15) + srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) srv.connection_made(transport) self.assertIsNone(srv._keep_alive_handle) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py index 4d27f64c..053f5a69 100644 --- a/tests/http_wsgi_test.py +++ b/tests/http_wsgi_test.py @@ -13,7 +13,7 @@ class HttpWsgiServerProtocolTests(unittest.TestCase): def setUp(self): self.loop = tulip.new_event_loop() - tulip.set_event_loop(self.loop) + tulip.set_event_loop(None) self.wsgi = unittest.mock.Mock() self.stream = unittest.mock.Mock() @@ -29,16 +29,15 @@ def setUp(self): self.payload.feed_eof() def tearDown(self): - tulip.set_event_loop(None) self.loop.close() def test_ctor(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) self.assertIs(srv.wsgi, self.wsgi) self.assertFalse(srv.readpayload) def _make_one(self, **kw): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, **kw) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) srv.stream = self.stream srv.transport = self.transport return srv.create_wsgi_environ(self.message, self.payload) @@ -120,7 +119,7 @@ def test_environ_forward(self): self.assertEqual(environ['REMOTE_PORT'], '80') def test_wsgi_response(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -128,7 +127,7 @@ def test_wsgi_response(self): self.assertIsInstance(resp, wsgi.WsgiResponse) def test_wsgi_response_start_response(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -139,7 +138,7 @@ def test_wsgi_response_start_response(self): self.assertIsInstance(resp.response, protocol.Response) def test_wsgi_response_start_response_exc(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -150,7 +149,7 @@ def test_wsgi_response_start_response_exc(self): self.assertIsInstance(resp.response, protocol.Response) def test_wsgi_response_start_response_exc_status(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -164,7 +163,7 @@ def test_wsgi_response_start_response_exc_status(self): @unittest.mock.patch('tulip.http.wsgi.tulip') def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi) + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -192,13 +191,13 @@ def test_handle_request_futures(self): def wsgi_app(env, start): start('200 OK', [('Content-Type', 'text/plain')]) - f1 = tulip.Future() + f1 = tulip.Future(loop=self.loop) f1.set_result(b'data') - fut = tulip.Future() + fut = tulip.Future(loop=self.loop) fut.set_result([f1]) return fut - srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -216,14 +215,15 @@ def wsgi_app(env, start): start('200 OK', [('Content-Type', 'text/plain')]) return [b'data'] - stream = tulip.StreamReader() + stream = tulip.StreamReader(loop=self.loop) stream.feed_data(b'data') stream.feed_eof() self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 1), self.headers, True, 'deflate') - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -242,7 +242,7 @@ def wsgi_app(env, start): start('200 OK', [('Content-Type', 'text/plain')]) return io.BytesIO(b'data') - srv = wsgi.WSGIServerHttpProtocol(wsgi_app) + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -260,14 +260,15 @@ def wsgi_app(env, start): start('200 OK', [('Content-Type', 'text/plain')]) return [b'data'] - stream = tulip.StreamReader() + stream = tulip.StreamReader(loop=self.loop) stream.feed_data(b'data') stream.feed_eof() self.message = protocol.RawRequestMessage( 'GET', '/path', (1, 1), self.headers, False, 'deflate') - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) srv.stream = self.stream srv.transport = self.transport @@ -286,7 +287,8 @@ def wsgi_app(env, start): start('200 OK', [('Content-Type', 'text/plain')]) return [env['wsgi.input'].read()] - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, readpayload=True) + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) srv.stream = self.stream srv.transport = self.transport diff --git a/tests/streams_test.py b/tests/streams_test.py index 87fc9180..81221817 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -19,6 +19,11 @@ def setUp(self): def tearDown(self): self.loop.close() + @unittest.mock.patch('tulip.streams.events') + def test_ctor_global_loop(self, m_events): + stream = streams.StreamReader() + self.assertIs(stream.loop, m_events.get_event_loop.return_value) + def test_open_connection(self): with test_utils.run_test_server(self.loop) as httpd: f = streams.open_connection(*httpd.address, loop=self.loop) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index f228b3dc..4dc0a65b 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -44,32 +44,52 @@ def notmuch(): loop.close() def test_task_decorator(self): - @tasks.coroutine + @tasks.task def notmuch(): yield from [] return 'ko' - t = tasks.Task(notmuch(), loop=self.loop) + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') def test_task_decorator_func(self): - @tasks.coroutine + @tasks.task def notmuch(): return 'ko' - t = tasks.Task(notmuch(), loop=self.loop) + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') def test_task_decorator_fut(self): - fut = futures.Future(loop=self.loop) - fut.set_result('ko') - - @tasks.coroutine + @tasks.task def notmuch(): + fut = futures.Future(loop=self.loop) + fut.set_result('ko') return fut - t = tasks.Task(notmuch(), loop=self.loop) + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ko') diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 0ddd11bc..c927701c 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -211,7 +211,7 @@ def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, @unittest.mock.patch('os.WIFEXITED') @unittest.mock.patch('os.waitpid') def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): + m_WEXITSTATUS, m_WTERMSIG): m_waitpid.side_effect = [(7, object()), ChildProcessError] m_WIFEXITED.return_value = False m_WIFSIGNALED.return_value = True @@ -229,7 +229,7 @@ def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, @unittest.mock.patch('os.WIFEXITED') @unittest.mock.patch('os.waitpid') def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): + m_WEXITSTATUS, m_WTERMSIG): m_waitpid.side_effect = [(0, object()), ChildProcessError] transp = unittest.mock.Mock() self.loop._subprocesses[7] = transp @@ -282,10 +282,10 @@ def test__sig_chld_unknown_status(self, m_waitpid, @unittest.mock.patch('os.WIFSIGNALED') @unittest.mock.patch('os.WIFEXITED') @unittest.mock.patch('os.waitpid') - def test__sig_chld_unknown_status(self, m_waitpid, - m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG, - m_log): + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): m_waitpid.side_effect = Exception transp = unittest.mock.Mock() self.loop._subprocesses[7] = transp diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index 09faf2bb..ce9b74da 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -5,23 +5,25 @@ from tulip import windows_events from tulip import protocols from tulip import streams -from tulip import test_utils def connect_read_pipe(loop, file): stream_reader = streams.StreamReader(loop=loop) protocol = _StreamReaderProtocol(stream_reader) - transport = loop._make_read_pipe_transport(file, protocol) + loop._make_read_pipe_transport(file, protocol) return stream_reader class _StreamReaderProtocol(protocols.Protocol): def __init__(self, stream_reader): self.stream_reader = stream_reader + def connection_lost(self, exc): self.stream_reader.set_exception(exc) + def data_received(self, data): self.stream_reader.feed_data(data) + def eof_received(self): self.stream_reader.feed_eof() diff --git a/tulip/base_events.py b/tulip/base_events.py index 04af67b8..19de896a 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -85,7 +85,6 @@ def _make_subprocess_transport(self, protocol, args, shell, extra=None, **kwargs): """Create subprocess transport.""" raise NotImplementedError - yield def _read_from_self(self): """XXX""" diff --git a/tulip/http/client.py b/tulip/http/client.py index babfd7f6..2aedfdd1 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -13,6 +13,7 @@ import base64 import email.message +import functools import http.client import http.cookies import json @@ -127,7 +128,8 @@ def request(method, url, *, @tulip.coroutine def start(req, loop): transport, p = yield from loop.create_connection( - tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) try: resp = req.send(transport) diff --git a/tulip/http/server.py b/tulip/http/server.py index 72fb15ee..fc5621c5 100644 --- a/tulip/http/server.py +++ b/tulip/http/server.py @@ -59,8 +59,8 @@ def __init__(self, *, log=logging, debug=False, def connection_made(self, transport): self.transport = transport - self.stream = tulip.StreamBuffer() - self._request_handler = self.start() + self.stream = tulip.StreamBuffer(loop=self._loop) + self._request_handler = tulip.Task(self.start(), loop=self._loop) def data_received(self, data): self.stream.feed_data(data) @@ -91,7 +91,7 @@ def log_debug(self, *args, **kw): def log_exception(self, *args, **kw): self.log.exception(*args, **kw) - @tulip.task + @tulip.coroutine def start(self): """Start processing of incoming requests. It reads request line, request headers and request payload, then diff --git a/tulip/http/session.py b/tulip/http/session.py index baf19dba..9cdd9cea 100644 --- a/tulip/http/session.py +++ b/tulip/http/session.py @@ -2,6 +2,7 @@ __all__ = ['Session'] +import functools import tulip import http.cookies @@ -47,7 +48,8 @@ def start(self, req, loop, new_conn=False, set_cookies=True): if new_conn or transport is None: new = True transport, proto = yield from loop.create_connection( - tulip.StreamProtocol, req.host, req.port, ssl=req.ssl) + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) else: new = False From a63a64f2d07dde43b326ea97d834a2cebc2ad4ec Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 01:14:00 +0300 Subject: [PATCH 0541/1502] Test that subprocess gets closing of stdout --- tests/events_test.py | 47 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 41 insertions(+), 6 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index d070b25b..480841db 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -967,7 +967,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(1)) transp.close() self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) @@ -996,12 +996,12 @@ def connect(): try: stdin = transp.get_pipe_transport(0) stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(1)) proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1]) stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(1)) self.assertEqual(b'Python The Winner', proto.data[1]) finally: transp.close() @@ -1138,9 +1138,9 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(1)) self.assertEqual(b'OUT:test', proto.data[1]) - self.loop.run_until_complete(proto.got_data[2].wait(10)) + self.loop.run_until_complete(proto.got_data[2].wait(1)) self.assertEqual(b'ERR:test', proto.data[2]) transp.close() @@ -1168,7 +1168,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(10)) + self.loop.run_until_complete(proto.got_data[1].wait(1)) self.assertEqual(b'OUT:testERR:test', proto.data[1]) self.assertEqual(b'', proto.data[2]) self.assertIsNotNone(transp.get_pipe_transport(1)) @@ -1178,6 +1178,41 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait(1)) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + if sys.platform == 'win32': from tulip import windows_events From 02954b9be355e014e64c7f2e0ddda34bd515d9c9 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 01:20:34 +0300 Subject: [PATCH 0542/1502] Add missing file --- tests/echo3.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 tests/echo3.py diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) From 3e530cde58979932373eb2abd8ec8704966c04e0 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 02:14:13 +0300 Subject: [PATCH 0543/1502] Restore running of tests/streams_test.py --- tests/streams_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streams_test.py b/tests/streams_test.py index 81221817..123aef9d 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,6 +1,7 @@ """Tests for streams.py.""" import unittest +import unittest.mock from tulip import events from tulip import streams From 85ade64b442eb9b60981c48a773f6bfa6d531edb Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 03:30:37 +0300 Subject: [PATCH 0544/1502] Remove resource leakage, make some subprocess tests stable --- tests/echo2.py | 7 +++---- tests/events_test.py | 21 ++++++++++----------- tests/streams_test.py | 11 +++++++++++ tests/tasks_test.py | 15 +++++++++++++-- tulip/test_utils.py | 23 ++++++++++++----------- 5 files changed, 49 insertions(+), 28 deletions(-) diff --git a/tests/echo2.py b/tests/echo2.py index 24503295..e83ca09f 100644 --- a/tests/echo2.py +++ b/tests/echo2.py @@ -1,7 +1,6 @@ import os if __name__ == '__main__': - while True: - buf = os.read(0, 1024) - os.write(1, b'OUT:'+buf) - os.write(2, b'ERR:'+buf) + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/events_test.py b/tests/events_test.py index 480841db..8842b6e1 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1138,14 +1138,13 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(1)) - self.assertEqual(b'OUT:test', proto.data[1]) - self.loop.run_until_complete(proto.got_data[2].wait(1)) - self.assertEqual(b'ERR:test', proto.data[2]) - transp.close() self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGTERM, proto.returncode) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertEqual(b'ERR:test', proto.data[2]) + self.assertEqual(0, proto.returncode) @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") @@ -1167,16 +1166,16 @@ def connect(): self.loop.run_until_complete(proto.connected) stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.completed) self.assertEqual(b'OUT:testERR:test', proto.data[1]) self.assertEqual(b'', proto.data[2]) - self.assertIsNotNone(transp.get_pipe_transport(1)) - self.assertIsNone(transp.get_pipe_transport(2)) transp.close() - self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(0, proto.returncode) @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") diff --git a/tests/streams_test.py b/tests/streams_test.py index 123aef9d..347b1262 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,5 +1,6 @@ """Tests for streams.py.""" +import gc import unittest import unittest.mock @@ -18,7 +19,11 @@ def setUp(self): events.set_event_loop(None) def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + self.loop.close() + gc.collect() @unittest.mock.patch('tulip.streams.events') def test_ctor_global_loop(self, m_events): @@ -37,6 +42,8 @@ def test_open_connection(self): data = self.loop.run_until_complete(f) self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(self.loop, use_ssl=True) as httpd: try: @@ -50,6 +57,8 @@ def test_open_connection_no_loop_ssl(self): data = self.loop.run_until_complete(f) self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + def test_open_connection_error(self): with test_utils.run_test_server(self.loop) as httpd: f = streams.open_connection(*httpd.address, loop=self.loop) @@ -59,6 +68,8 @@ def test_open_connection_error(self): with self.assertRaises(ZeroDivisionError): self.loop.run_until_complete(f) + writer.close() + def test_feed_empty_data(self): stream = streams.StreamReader(loop=self.loop) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 4dc0a65b..1daadd80 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -1,5 +1,6 @@ """Tests for tasks.py.""" +import gc import time import unittest import unittest.mock @@ -27,6 +28,7 @@ def setUp(self): def tearDown(self): self.loop.close() + gc.collect() def test_task_class(self): @tasks.coroutine @@ -107,6 +109,7 @@ def notmuch(): loop = events.new_event_loop() t = tasks.async(notmuch(), loop=loop) self.assertIs(t._loop, loop) + loop.close() def test_async_future(self): f_orig = futures.Future(loop=self.loop) @@ -118,9 +121,13 @@ def test_async_future(self): self.assertEqual(f.result(), 'ko') self.assertIs(f, f_orig) + loop = events.new_event_loop() + with self.assertRaises(ValueError): - loop = events.new_event_loop() f = tasks.async(f_orig, loop=loop) + + loop.close() + f = tasks.async(f_orig, loop=self.loop) self.assertIs(f, f_orig) @@ -135,9 +142,13 @@ def notmuch(): self.assertEqual(t.result(), 'ok') self.assertIs(t, t_orig) + loop = events.new_event_loop() + with self.assertRaises(ValueError): - loop = events.new_event_loop() t = tasks.async(t_orig, loop=loop) + + loop.close() + t = tasks.async(t_orig, loop=self.loop) self.assertIs(t, t_orig) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 98c61a3e..8e62b49a 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -132,20 +132,21 @@ def run(loop, fut): loop.call_soon_threadsafe( fut.set_result, (thread_loop, waiter, socks[0].getsockname())) - thread_loop.run_until_complete(waiter) + try: + thread_loop.run_until_complete(waiter) + finally: + # close opened trnsports + for tr in transports: + tr.close() - # close opened trnsports - for tr in transports: - tr.close() + run_briefly(thread_loop) # call close callbacks - run_briefly(thread_loop) # call close callbacks + for s in socks: + thread_loop.stop_serving(s) - for s in socks: - thread_loop.stop_serving(s) - - thread_loop.stop() - thread_loop.close() - gc.collect() + thread_loop.stop() + thread_loop.close() + gc.collect() fut = tulip.Future(loop=loop) server_thread = threading.Thread(target=run, args=(loop, fut)) From df1b721795a22ab0f97ba0fb7cb97f3b69b398e4 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 03:44:02 +0300 Subject: [PATCH 0545/1502] Remove subtle resource leak --- tulip/test_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 8e62b49a..17f00475 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -135,6 +135,9 @@ def run(loop, fut): try: thread_loop.run_until_complete(waiter) finally: + # call pending connection_made if present + run_briefly(thread_loop) + # close opened trnsports for tr in transports: tr.close() From 38249877926ac7c7591af4c2c2f08f4b4ee6e724 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 03:55:02 +0300 Subject: [PATCH 0546/1502] Get rid of socket.error alias, use OSError instead --- tests/base_events_test.py | 49 +++++++++++++++-------------------- tests/events_test.py | 4 +-- tests/selector_events_test.py | 2 +- tulip/base_events.py | 22 ++++++++-------- tulip/selector_events.py | 4 +-- 5 files changed, 37 insertions(+), 44 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 31ee30d9..3fa7a06f 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -343,10 +343,9 @@ def getaddrinfo_task(*args, **kwds): def _socket(*args, **kw): nonlocal idx, errors idx += 1 - raise socket.error(errors[idx]) + raise OSError(errors[idx]) m_socket.socket = _socket - m_socket.error = socket.error self.loop.getaddrinfo = getaddrinfo_task @@ -376,7 +375,7 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) def test_create_connection_connect_err(self): @tasks.coroutine @@ -389,11 +388,11 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error + self.loop.sock_connect.side_effect = OSError coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) def test_create_connection_multiple(self): @tasks.coroutine @@ -406,20 +405,19 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error + self.loop.sock_connect.side_effect = OSError coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET) - with self.assertRaises(socket.error): + with self.assertRaises(OSError): self.loop.run_until_complete(coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_connection_mutiple_errors_local_addr(self, m_socket): - m_socket.error = socket.error def bind(addr): if addr[0] == '0.0.0.1': - err = socket.error('Err') + err = OSError('Err') err.strerror = 'Err' raise err @@ -435,12 +433,12 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error('Err2') + self.loop.sock_connect.side_effect = OSError('Err2') coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET, local_addr=(None, 8080)) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: self.loop.run_until_complete(coro) self.assertTrue(str(cm.exception), 'Multiple exceptions: ') @@ -463,7 +461,7 @@ def getaddrinfo_task(*args, **kwds): MyProto, 'example.com', 80, family=socket.AF_INET, local_addr=(None, 8080)) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) def test_start_serving_empty_host(self): # if host is empty string use None instead @@ -497,15 +495,14 @@ def test_start_serving_no_getaddrinfo(self): getaddrinfo.return_value = [] f = self.loop.start_serving(MyProto, '0.0.0.0', 0) - self.assertRaises(socket.error, self.loop.run_until_complete, f) + self.assertRaises(OSError, self.loop.run_until_complete, f) @unittest.mock.patch('tulip.base_events.socket') def test_start_serving_cant_bind(self, m_socket): - class Err(socket.error): + class Err(OSError): strerror = 'error' - m_socket.error = socket.error m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] m_sock = m_socket.socket.return_value = unittest.mock.Mock() @@ -517,13 +514,12 @@ class Err(socket.error): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): - m_socket.error = socket.error m_socket.getaddrinfo.return_value = [] coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_addr_error(self): coro = self.loop.create_datagram_endpoint( @@ -537,28 +533,27 @@ def test_create_datagram_endpoint_addr_error(self): def test_create_datagram_endpoint_connect_err(self): self.loop.sock_connect = unittest.mock.Mock() - self.loop.sock_connect.side_effect = socket.error + self.loop.sock_connect.side_effect = OSError coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): - m_socket.error = socket.error m_socket.getaddrinfo = socket.getaddrinfo - m_socket.socket.side_effect = socket.error + m_socket.socket.side_effect = OSError coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( @@ -569,13 +564,12 @@ def test_create_datagram_endpoint_no_matching_family(self): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): - m_socket.error = socket.error - m_socket.socket.return_value.setblocking.side_effect = socket.error + m_socket.socket.return_value.setblocking.side_effect = OSError coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - socket.error, self.loop.run_until_complete, coro) + OSError, self.loop.run_until_complete, coro) self.assertTrue( m_socket.socket.return_value.close.called) @@ -586,10 +580,9 @@ def test_create_datagram_endpoint_noaddr_nofamily(self): @unittest.mock.patch('tulip.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): - class Err(socket.error): + class Err(OSError): pass - m_socket.error = socket.error m_socket.AF_INET6 = socket.AF_INET6 m_socket.getaddrinfo = socket.getaddrinfo m_sock = m_socket.socket.return_value = unittest.mock.Mock() diff --git a/tests/events_test.py b/tests/events_test.py index 8842b6e1..618f319d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -544,7 +544,7 @@ def test_create_connection_local_addr_in_use(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address, local_addr=httpd.address) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) self.assertIn(str(httpd.address), cm.exception.strerror) @@ -686,7 +686,7 @@ def test_start_serving_addr_in_use(self): host, port = sock.getsockname() f = self.loop.start_serving(MyProto, host=host, port=port) - with self.assertRaises(socket.error) as cm: + with self.assertRaises(OSError) as cm: self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 93346600..76d72910 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -325,7 +325,7 @@ def test__sock_connect_exception(self): self.loop.remove_writer = unittest.mock.Mock() self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) - self.assertIsInstance(f.exception(), socket.error) + self.assertIsInstance(f.exception(), OSError) def test_sock_accept(self): sock = unittest.mock.Mock() diff --git a/tulip/base_events.py b/tulip/base_events.py index 19de896a..c3378da6 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -265,11 +265,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, infos = f1.result() if not infos: - raise socket.error('getaddrinfo() returned empty list') + raise OSError('getaddrinfo() returned empty list') if f2 is not None: laddr_infos = f2.result() if not laddr_infos: - raise socket.error('getaddrinfo() returned empty list') + raise OSError('getaddrinfo() returned empty list') exceptions = [] for family, type, proto, cname, address in infos: @@ -281,8 +281,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, try: sock.bind(laddr) break - except socket.error as exc: - exc = socket.error( + except OSError as exc: + exc = OSError( exc.errno, 'error while ' 'attempting to bind on address ' '{!r}: {}'.format( @@ -293,7 +293,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock = None continue yield from self.sock_connect(sock, address) - except socket.error as exc: + except OSError as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -309,7 +309,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise exceptions[0] # Raise a combined exception so the user can see all # the various error messages. - raise socket.error('Multiple exceptions: {}'.format( + raise OSError('Multiple exceptions: {}'.format( ', '.join(str(exc) for exc in exceptions))) elif sock is None: @@ -351,7 +351,7 @@ def create_datagram_endpoint(self, protocol_factory, *addr, family=family, type=socket.SOCK_DGRAM, proto=proto, flags=flags) if not infos: - raise socket.error('getaddrinfo() returned empty list') + raise OSError('getaddrinfo() returned empty list') for fam, _, pro, _, address in infos: key = (fam, pro) @@ -387,7 +387,7 @@ def create_datagram_endpoint(self, protocol_factory, if remote_addr: yield from self.sock_connect(sock, remote_address) r_addr = remote_address - except socket.error as exc: + except OSError as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -452,7 +452,7 @@ def _start_serving_internal(self, protocol_factory, host=None, port=None, host, port, family=family, type=socket.SOCK_STREAM, proto=0, flags=flags) if not infos: - raise socket.error('getaddrinfo() returned empty list') + raise OSError('getaddrinfo() returned empty list') completed = False try: @@ -472,8 +472,8 @@ def _start_serving_internal(self, protocol_factory, host=None, port=None, True) try: sock.bind(sa) - except socket.error as err: - raise socket.error(err.errno, 'error while attempting ' + except OSError as err: + raise OSError(err.errno, 'error while attempting ' 'to bind on address %r: %s' % (sa, err.strerror.lower())) completed = True diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 225af6f3..82b86939 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -279,7 +279,7 @@ def _sock_connect(self, fut, registered, sock, address): err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: # Jump to the except clause below. - raise socket.error(err, 'Connect call failed') + raise OSError(err, 'Connect call failed') fut.set_result(None) except (BlockingIOError, InterruptedError): self.add_writer(fd, self._sock_connect, fut, True, sock, address) @@ -419,7 +419,7 @@ def write(self, data): n = self._sock.send(data) except (BlockingIOError, InterruptedError): n = 0 - except socket.error as exc: + except OSError as exc: self._fatal_error(exc) return From 3c77c297c0b9336db757d141eec68b8bf5e09f52 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 17 Aug 2013 21:27:56 +0300 Subject: [PATCH 0547/1502] Add --catch parameter to runtests.py --- runtests.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/runtests.py b/runtests.py index 6b784d02..35e3ea3c 100644 --- a/runtests.py +++ b/runtests.py @@ -31,6 +31,8 @@ import unittest import importlib.machinery +from unittest.signals import installHandler + assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' ARGS = argparse.ArgumentParser(description="Run all unittests.") @@ -42,6 +44,9 @@ ARGS.add_argument( '-f', '--failfast', action="store_true", default=False, dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') ARGS.add_argument( '-q', action="store_true", dest='quiet', help='quiet') ARGS.add_argument( @@ -142,6 +147,7 @@ def runtests(): v = 0 if args.quiet else args.verbose + 1 failfast = args.failfast + catchbreak = args.catchbreak tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() @@ -155,6 +161,8 @@ def runtests(): logger.setLevel(logging.INFO) elif v >= 4: logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() result = unittest.TextTestRunner(verbosity=v, failfast=failfast).run(tests) sys.exit(not result.wasSuccessful()) From cfc14727f4516286fffaf8ec7b089f61aa7c983e Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 18 Aug 2013 21:40:19 +0300 Subject: [PATCH 0548/1502] Replase SIGQUIT with SIGHUP in tests to don't create core dumps on test run. --- tests/events_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 618f319d..0f8f6f47 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1113,9 +1113,9 @@ def connect(): self.loop.run_until_complete(connect()) self.loop.run_until_complete(proto.connected) - transp.send_signal(signal.SIGQUIT) + transp.send_signal(signal.SIGHUP) self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGQUIT, proto.returncode) + self.assertEqual(-signal.SIGHUP, proto.returncode) @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") From 56ecb123c925c9ec3d4cf8de96b1f7d77e2d3c2b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 18 Aug 2013 12:28:40 -0700 Subject: [PATCH 0549/1502] Shorten URL to fit in 80 chars. --- tulip/transports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/transports.py b/tulip/transports.py index c571fcc8..56425aa9 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -171,7 +171,7 @@ def send_signal(self, signal): """Send signal to subprocess. See also: - http://docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal """ raise NotImplementedError From 3352f59d971c860ee45900f69121ed95ab5e653d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 18 Aug 2013 12:33:33 -0700 Subject: [PATCH 0550/1502] Shorten/fold lines to fit 80 chars. --- runtests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/runtests.py b/runtests.py index 35e3ea3c..be7bbe4a 100644 --- a/runtests.py +++ b/runtests.py @@ -173,8 +173,9 @@ def runcoverage(sdir, args): - Setuptools (https://pypi.python.org/pypi/setuptools) What worked for me: - - download https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - * curl -O https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + * curl -O \ + https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - python3 ez_setup.py - python3 -m easy_install coverage """ From 3de2c2e3ed75845d159d7842cdbe5d402aaf2d33 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 19 Aug 2013 16:32:25 +0300 Subject: [PATCH 0551/1502] Force pipe transport to work with pipes only --- tests/unix_events_test.py | 64 --------------------------------------- tulip/unix_events.py | 25 ++++----------- 2 files changed, 6 insertions(+), 83 deletions(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index c927701c..a384b4a1 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -464,7 +464,6 @@ def test_ctor(self): self.loop.add_reader.assert_called_with(5, tr._read_ready) self.loop.call_soon.assert_called_with( self.protocol.connection_made, tr) - self.assertTrue(tr._enable_read_hack) def test_ctor_with_waiter(self): fut = futures.Future(loop=self.loop) @@ -472,7 +471,6 @@ def test_ctor_with_waiter(self): self.loop, self.pipe, self.protocol, fut) self.loop.call_soon.assert_called_with(fut.set_result, None) self.loop.add_reader.assert_called_with(5, tr._read_ready) - self.assertTrue(tr._enable_read_hack) fut.cancel() def test_can_write_eof(self): @@ -782,65 +780,3 @@ def test_discard_output_without_pending_writes(self): self.assertTrue(tr._writing) self.assertFalse(self.loop.remove_writer.called) self.assertEqual([], tr._buffer) - - -class UnixWritePipeRegularFileTests(unittest.TestCase): - - def setUp(self): - self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) - self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) - - def test_ctor_with_regular_file(self): - with tempfile.TemporaryFile() as f: - tr = unix_events._UnixWritePipeTransport(self.loop, f, - self.protocol) - self.assertFalse(self.loop.add_reader.called) - self.loop.call_soon.assert_called_with( - self.protocol.connection_made, tr) - self.assertFalse(tr._enable_read_hack) - - def test_write_eof(self): - with tempfile.TemporaryFile() as f: - tr = unix_events._UnixWritePipeTransport( - self.loop, f, self.protocol) - - tr.write_eof() - self.assertTrue(tr._closing) - self.assertFalse(self.loop.remove_reader.called) - self.loop.call_soon.assert_called_with( - tr._call_connection_lost, None) - - @unittest.mock.patch('os.write') - def test__write_ready_closing(self, m_write): - with tempfile.TemporaryFile() as f: - fileno = f.fileno() - tr = unix_events._UnixWritePipeTransport( - self.loop, f, self.protocol) - - tr._closing = True - tr._buffer = [b'da', b'ta'] - m_write.return_value = 4 - tr._write_ready() - m_write.assert_called_with(fileno, b'data') - self.loop.remove_writer.assert_called_with(fileno) - self.assertFalse(self.loop.remove_reader.called) - self.assertEqual([], tr._buffer) - self.protocol.connection_lost.assert_called_with(None) - self.assertTrue(f.closed) - - @unittest.mock.patch('os.write') - def test_abort(self, m_write): - with tempfile.TemporaryFile() as f: - fileno = f.fileno() - tr = unix_events._UnixWritePipeTransport( - self.loop, f, self.protocol) - - tr._buffer = [b'da', b'ta'] - tr.abort() - self.assertFalse(m_write.called) - self.loop.remove_writer.assert_called_with(fileno) - self.assertFalse(self.loop.remove_reader.called) - self.assertEqual([], tr._buffer) - self.assertTrue(tr._closing) - self.loop.call_soon.assert_called_with( - tr._call_connection_lost, None) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index f27f9dc2..f7c5e5ac 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -281,19 +281,9 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. self._writing = True - # Do nothing if it is a regular file. - # Enable hack only if pipe is FIFO object. - # Look on twisted.internet.process:ProcessWriter.__init__ - if stat.S_ISFIFO(os.fstat(self._fileno).st_mode): - self._enable_read_hack = True - else: - # If the pipe is not a unix pipe, then the read hack is never - # applicable. This case arises when _UnixWritePipeTransport - # is used by subprocess and stdout/stderr - # are redirected to a normal file. - self._enable_read_hack = False - if self._enable_read_hack: - self._loop.add_reader(self._fileno, self._read_ready) + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: @@ -352,8 +342,7 @@ def _write_ready(self): if n == len(data): self._loop.remove_writer(self._fileno) if self._closing: - if self._enable_read_hack: - self._loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) self._call_connection_lost(None) return elif n > 0: @@ -369,8 +358,7 @@ def write_eof(self): assert self._pipe self._closing = True if not self._buffer: - if self._enable_read_hack: - self._loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, None) def close(self): @@ -390,8 +378,7 @@ def _close(self, exc=None): self._closing = True self._buffer.clear() self._loop.remove_writer(self._fileno) - if self._enable_read_hack: - self._loop.remove_reader(self._fileno) + self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): From 2813e4239ae2d11fe45ee9fe2712bcda9d42cef4 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 19 Aug 2013 16:34:16 +0300 Subject: [PATCH 0552/1502] Drop useless comment --- tulip/selector_events.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 82b86939..b86acf10 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -20,19 +20,6 @@ from .log import tulip_log -# Errno values indicating the connection was disconnected. -# Comment out _DISCONNECTED as never used -# TODO: make sure that errors has processed properly -# for now we have no exception clsses for ENOTCONN and EBADF -# _DISCONNECTED = frozenset((errno.ECONNRESET, -# errno.ENOTCONN, -# errno.ESHUTDOWN, -# errno.ECONNABORTED, -# errno.EPIPE, -# errno.EBADF, -# )) - - class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. From d5becc4e3430ee6a81abc8da6833adbc342cb5ae Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 19 Aug 2013 18:09:58 +0300 Subject: [PATCH 0553/1502] Move pipe type check before setting it nonblocking --- tulip/unix_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index f7c5e5ac..cccc568a 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -275,14 +275,14 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") _set_nonblocking(self._fileno) self._protocol = protocol self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. self._writing = True - if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): - raise ValueError("Pipe transport is for pipes only.") self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) From 64a251a0b3561ffc197a07be1b62d33a9ac7ec24 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 20 Aug 2013 12:26:31 -0700 Subject: [PATCH 0554/1502] Add tasks.wait_for() helper coroutine, it waits on the single Future or coroutine. --- tests/tasks_test.py | 63 ++++++++++++++++++++++++++++++++++++++++++++- tulip/tasks.py | 34 +++++++++++++++++++++--- 2 files changed, 93 insertions(+), 4 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 1daadd80..1f45b82f 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -360,6 +360,48 @@ def task(): self.assertEqual(r, 42) self.assertTrue(0.08 <= t1-t0 <= 0.12) + def test_wait_for(self): + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=self.loop) + return 'done' + + fut = tasks.Task(foo(), loop=self.loop) + + t0 = time.monotonic() + self.assertRaises( + futures.TimeoutError, + self.loop.run_until_complete, + tasks.wait_for(fut, 0.1, loop=self.loop)) + t1 = time.monotonic() + self.assertFalse(fut.done()) + + # wait for result + res = self.loop.run_until_complete( + tasks.wait_for(fut, 0.2, loop=self.loop)) + t2 = time.monotonic() + self.assertEqual(res, 'done') + + self.assertTrue(0.08 <= t1-t0 <= 0.12) + self.assertTrue(0.18 <= t2-t0 <= 0.22) + + def test_wait_for_with_global_loop(self): + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=self.loop) + return 'done' + + events.set_event_loop(self.loop) + try: + fut = tasks.Task(foo(), loop=self.loop) + self.assertRaises( + futures.TimeoutError, + self.loop.run_until_complete, tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertFalse(fut.done()) + def test_wait(self): a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) @@ -381,7 +423,26 @@ def foo(): res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) t1 = time.monotonic() self.assertTrue(t1-t0 <= 0.01) - # TODO: Test different return_when values. + + def test_wait_with_global_loop(self): + a = tasks.Task(tasks.sleep(0.01, loop=self.loop), loop=self.loop) + b = tasks.Task(tasks.sleep(0.015, loop=self.loop), loop=self.loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(self.loop) + try: + res = self.loop.run_until_complete( + tasks.Task(foo(), loop=self.loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) def test_wait_errors(self): self.assertRaises( diff --git a/tulip/tasks.py b/tulip/tasks.py index 09b923dd..ca513a10 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -2,7 +2,7 @@ __all__ = ['coroutine', 'task', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', - 'wait', 'as_completed', 'sleep', 'async', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', ] import collections @@ -209,13 +209,15 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): done, pending = yield from tulip.wait(fs) - Note: This does not raise TimeoutError! Futures that aren't done + Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ if not fs: raise ValueError('Set of coroutines/Futures is empty.') - loop = loop if loop is not None else events.get_event_loop() + if loop is None: + loop = events.get_event_loop() + fs = set(async(f, loop=loop) for f in fs) if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): @@ -223,6 +225,32 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): return (yield from _wait(fs, timeout, return_when, loop)) +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + @coroutine def _wait(fs, timeout, return_when, loop): """Internal helper for wait(return_when=FIRST_COMPLETED). From ac500c92c44e7bfaa191943abfafd1b0fea7a138 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 21 Aug 2013 04:55:09 +0300 Subject: [PATCH 0555/1502] Reduce test execution time --- tests/events_test.py | 42 +-- tests/futures_test.py | 23 +- tests/http_server_test.py | 6 +- tests/http_session_test.py | 4 +- tests/locks_test.py | 270 ++++++++++------- tests/queues_test.py | 164 ++++++++--- tests/tasks_test.py | 582 ++++++++++++++++++++++++++----------- tulip/futures.py | 7 + tulip/locks.py | 5 +- tulip/test_utils.py | 80 +++++ 10 files changed, 839 insertions(+), 344 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 0f8f6f47..3d562caf 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -264,6 +264,7 @@ def callback(arg1, arg2): def test_call_soon_threadsafe(self): results = [] + lock = threading.Lock() def callback(arg): results.append(arg) @@ -272,16 +273,17 @@ def callback(arg): def run_in_thread(): self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + lock.acquire() t = threading.Thread(target=run_in_thread) - self.loop.call_later(0.1, callback, 'world') - t0 = time.monotonic() t.start() - self.loop.run_forever() - t1 = time.monotonic() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() t.join() self.assertEqual(results, ['hello', 'world']) - self.assertTrue(t1-t0 >= 0.09) def test_call_soon_threadsafe_same_thread(self): results = [] @@ -291,18 +293,18 @@ def callback(arg): if len(results) >= 2: self.loop.stop() - self.loop.call_later(0.1, callback, 'world') self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') self.loop.run_forever() self.assertEqual(results, ['hello', 'world']) def test_run_in_executor(self): def run(arg): - time.sleep(0.1) - return arg + return (arg, threading.get_ident()) f2 = self.loop.run_in_executor(None, run, 'yo') - res = self.loop.run_until_complete(f2) + res, thread_id = self.loop.run_until_complete(f2) self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) def test_reader_callback(self): r, w = test_utils.socketpair() @@ -322,10 +324,11 @@ def reader(): r.close() self.loop.add_reader(r.fileno(), reader) - self.loop.call_later(0.05, w.send, b'abc') - self.loop.call_later(0.1, w.send, b'def') - self.loop.call_later(0.15, w.close) - self.loop.call_later(0.16, self.loop.stop) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) self.loop.run_forever() self.assertEqual(b''.join(bytes_read), b'abcdef') @@ -333,12 +336,13 @@ def test_writer_callback(self): r, w = test_utils.socketpair() w.setblocking(False) self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) def remove_writer(): self.assertTrue(self.loop.remove_writer(w.fileno())) - self.loop.call_later(0.1, remove_writer) - self.loop.call_later(0.11, self.loop.stop) + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) self.loop.run_forever() w.close() data = r.recv(256*1024) @@ -448,11 +452,11 @@ def test_signal_handling_while_selecting(self): def my_handler(): nonlocal caught caught += 1 + self.loop.stop() self.loop.add_signal_handler(signal.SIGALRM, my_handler) - signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - self.loop.call_later(0.15, self.loop.stop) + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. self.loop.run_forever() self.assertEqual(caught, 1) @@ -468,8 +472,8 @@ def my_handler(*args): self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) - signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. - self.loop.call_later(0.15, self.loop.stop) + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) self.loop.run_forever() self.assertEqual(caught, 1) diff --git a/tests/futures_test.py b/tests/futures_test.py index 79aeddd2..b88c2a75 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -1,7 +1,7 @@ """Tests for futures.py.""" import concurrent.futures -import time +import threading import unittest import unittest.mock @@ -17,7 +17,7 @@ def _fakefunc(f): class FutureTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -134,6 +134,15 @@ def test_repr(self): self.assertIn('<18 more>', r) f_many_callbacks.cancel() + f_pending = futures.Future(loop=self.loop, timeout=10) + self.assertEqual('Future{timeout=10, when=10}', + repr(f_pending)) + f_pending.cancel() + + f_pending = futures.Future(loop=self.loop, timeout=10) + f_pending.cancel() + self.assertEqual('Future{timeout=10}', repr(f_pending)) + def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. @@ -218,15 +227,16 @@ def test_tb_logger_exception_result_retrieved(self, m_log): self.assertFalse(m_log.error.called) def test_wrap_future(self): + def run(arg): - time.sleep(0.1) - return arg + return (arg, threading.get_ident()) ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = futures.wrap_future(f1, loop=self.loop) - res = self.loop.run_until_complete(f2) + res, ident = self.loop.run_until_complete(f2) self.assertIsInstance(f2, futures.Future) self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) def test_wrap_future_future(self): f1 = futures.Future(loop=self.loop) @@ -236,8 +246,7 @@ def test_wrap_future_future(self): @unittest.mock.patch('tulip.futures.events') def test_wrap_future_use_global_loop(self, m_events): def run(arg): - time.sleep(0.1) - return arg + return (arg, threading.get_ident()) ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = futures.wrap_future(f1) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 49cfc8fa..862779b9 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -6,13 +6,13 @@ import tulip from tulip.http import server from tulip.http import errors -from tulip.test_utils import run_briefly +from tulip import test_utils class HttpServerProtocolTests(unittest.TestCase): def setUp(self): - self.loop = tulip.new_event_loop() + self.loop = test_utils.TestLoop() tulip.set_event_loop(None) def tearDown(self): @@ -204,7 +204,7 @@ def test_handle_cancelled(self): srv.connection_made(transport) srv.handle_request = unittest.mock.Mock() - run_briefly(self.loop) # start request_handler task + test_utils.run_briefly(self.loop) # start request_handler task srv.stream.feed_data( b'GET / HTTP/1.0\r\n' diff --git a/tests/http_session_test.py b/tests/http_session_test.py index cd55b7c0..39a80091 100644 --- a/tests/http_session_test.py +++ b/tests/http_session_test.py @@ -10,11 +10,13 @@ from tulip.http.client import HttpResponse from tulip.http.session import Session +from tulip import test_utils + class HttpSessionTests(unittest.TestCase): def setUp(self): - self.loop = tulip.new_event_loop() + self.loop = test_utils.TestLoop() tulip.set_event_loop(self.loop) self.transport = unittest.mock.Mock() diff --git a/tests/locks_test.py b/tests/locks_test.py index 7105571d..83663ec0 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -1,6 +1,5 @@ """Tests for lock.py""" -import time import unittest import unittest.mock @@ -8,13 +7,13 @@ from tulip import futures from tulip import locks from tulip import tasks -from tulip.test_utils import run_briefly +from tulip import test_utils class LockTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -89,24 +88,24 @@ def c3(result): t1 = tasks.Task(c1(result), loop=self.loop) t2 = tasks.Task(c2(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) lock.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) t3 = tasks.Task(c3(result), loop=self.loop) lock.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2], result) lock.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -117,33 +116,50 @@ def c3(result): self.assertTrue(t3.result()) def test_acquire_timeout(self): - lock = locks.Lock(loop=self.loop) - self.assertTrue(self.loop.run_until_complete(lock.acquire())) - t0 = time.monotonic() - acquired = self.loop.run_until_complete(lock.acquire(timeout=0.1)) - self.assertFalse(acquired) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + lock = locks.Lock(loop=loop) - lock = locks.Lock(loop=self.loop) + self.assertTrue(loop.run_until_complete(lock.acquire())) + + acquired = loop.run_until_complete(lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + lock = locks.Lock(loop=loop) self.loop.run_until_complete(lock.acquire()) - self.loop.call_later(0.01, lock.release) - acquired = self.loop.run_until_complete(lock.acquire(10.1)) + loop.call_soon(lock.release) + acquired = loop.run_until_complete(lock.acquire(10.1)) self.assertTrue(acquired) + self.assertAlmostEqual(0.1, loop.time()) def test_acquire_timeout_mixed(self): - lock = locks.Lock(loop=self.loop) - self.loop.run_until_complete(lock.acquire()) - tasks.Task(lock.acquire(), loop=self.loop) - tasks.Task(lock.acquire(), loop=self.loop) - acquire_task = tasks.Task(lock.acquire(0.01), loop=self.loop) - tasks.Task(lock.acquire(), loop=self.loop) - acquired = self.loop.run_until_complete(acquire_task) + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + lock = locks.Lock(loop=loop) + loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire(), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + acquire_task = tasks.Task(lock.acquire(0.01), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) self.assertEqual(3, len(lock._waiters)) @@ -198,7 +214,7 @@ def test_context_manager_no_yield(self): class EventWaiterTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -251,13 +267,13 @@ def c3(result): t1 = tasks.Task(c1(result), loop=self.loop) t2 = tasks.Task(c2(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) t3 = tasks.Task(c3(result), loop=self.loop) ev.set() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([3, 1, 2], result) self.assertTrue(t1.done()) @@ -275,33 +291,48 @@ def test_wait_on_set(self): self.assertTrue(res) def test_wait_timeout(self): - ev = locks.EventWaiter(loop=self.loop) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) - t0 = time.monotonic() - res = self.loop.run_until_complete(ev.wait(0.1)) + ev = locks.EventWaiter(loop=loop) + + res = loop.run_until_complete(ev.wait(0.1)) self.assertFalse(res) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + self.assertAlmostEqual(0.1, loop.time()) - ev = locks.EventWaiter(loop=self.loop) - self.loop.call_later(0.01, ev.set) - acquired = self.loop.run_until_complete(ev.wait(10.1)) + ev = locks.EventWaiter(loop=loop) + loop.call_later(0.01, ev.set) + acquired = loop.run_until_complete(ev.wait(10.1)) self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) def test_wait_timeout_mixed(self): - ev = locks.EventWaiter(loop=self.loop) - tasks.Task(ev.wait(), loop=self.loop) - tasks.Task(ev.wait(), loop=self.loop) - acquire_task = tasks.Task(ev.wait(0.1), loop=self.loop) - tasks.Task(ev.wait(), loop=self.loop) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 - t0 = time.monotonic() - acquired = self.loop.run_until_complete(acquire_task) - self.assertFalse(acquired) + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + ev = locks.EventWaiter(loop=loop) + tasks.Task(ev.wait(), loop=loop) + tasks.Task(ev.wait(), loop=loop) + acquire_task = tasks.Task(ev.wait(0.1), loop=loop) + tasks.Task(ev.wait(), loop=loop) + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) self.assertEqual(3, len(ev._waiters)) def test_wait_cancel(self): @@ -335,7 +366,7 @@ def c1(result): return True t = tasks.Task(c1(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) ev.set() @@ -346,7 +377,7 @@ def c1(result): ev.set() self.assertEqual(1, len(ev._waiters)) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.assertEqual(0, len(ev._waiters)) @@ -357,7 +388,7 @@ def c1(result): class ConditionTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -408,33 +439,33 @@ def c3(result): t2 = tasks.Task(c2(result), loop=self.loop) t3 = tasks.Task(c3(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) self.assertFalse(cond.locked()) self.assertTrue(self.loop.run_until_complete(cond.acquire())) cond.notify() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) self.assertTrue(cond.locked()) cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.notify(2) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2], result) self.assertTrue(cond.locked()) cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(cond.locked()) @@ -446,16 +477,21 @@ def c3(result): self.assertTrue(t3.result()) def test_wait_timeout(self): - cond = locks.Condition(loop=self.loop) - self.loop.run_until_complete(cond.acquire()) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) + loop.run_until_complete(cond.acquire()) - t0 = time.monotonic() - wait = self.loop.run_until_complete(cond.wait(0.1)) + wait = loop.run_until_complete(cond.wait(0.1)) self.assertFalse(wait) self.assertTrue(cond.locked()) - - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + self.assertAlmostEqual(0.1, loop.time()) def test_wait_cancel(self): cond = locks.Condition(loop=self.loop) @@ -494,32 +530,41 @@ def c1(result): t = tasks.Task(c1(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) presult = True self.loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(t.done()) self.assertTrue(t.result()) def test_wait_for_timeout(self): - cond = locks.Condition(loop=self.loop) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) result = [] - predicate = unittest.mock.Mock() - predicate.return_value = False + predicate = unittest.mock.Mock(return_value=False) @tasks.coroutine def c1(result): @@ -530,25 +575,22 @@ def c1(result): result.append(2) cond.release() - wait_for = tasks.Task(c1(result), loop=self.loop) + wait_for = tasks.Task(c1(result), loop=loop) - t0 = time.monotonic() - - run_briefly(self.loop) + test_utils.run_briefly(loop) self.assertEqual([], result) - self.loop.run_until_complete(cond.acquire()) + loop.run_until_complete(cond.acquire()) cond.notify() cond.release() - run_briefly(self.loop) + test_utils.run_briefly(loop) self.assertEqual([], result) - self.loop.run_until_complete(wait_for) + loop.run_until_complete(wait_for) self.assertEqual([2], result) self.assertEqual(3, predicate.call_count) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + self.assertAlmostEqual(0.1, loop.time()) def test_wait_for_unacquired(self): cond = locks.Condition(loop=self.loop) @@ -594,20 +636,20 @@ def c3(result): t2 = tasks.Task(c2(result), loop=self.loop) t3 = tasks.Task(c3(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.notify(2048) cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -641,13 +683,13 @@ def c2(result): t1 = tasks.Task(c1(result), loop=self.loop) t2 = tasks.Task(c2(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([], result) self.loop.run_until_complete(cond.acquire()) cond.notify_all() cond.release() - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1, 2], result) self.assertTrue(t1.done()) @@ -667,7 +709,7 @@ def test_notify_all_unacquired(self): class SemaphoreTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -753,7 +795,7 @@ def c4(result): t2 = tasks.Task(c2(result), loop=self.loop) t3 = tasks.Task(c3(result), loop=self.loop) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual([1], result) self.assertTrue(sem.locked()) self.assertEqual(2, len(sem._waiters)) @@ -765,7 +807,7 @@ def c4(result): sem.release() self.assertEqual(2, sem._value) - run_briefly(self.loop) + test_utils.run_briefly(self.loop) self.assertEqual(0, sem._value) self.assertEqual([1, 2, 3], result) self.assertTrue(sem.locked()) @@ -781,37 +823,53 @@ def c4(result): self.assertFalse(t4.done()) def test_acquire_timeout(self): - sem = locks.Semaphore(loop=self.loop) - self.loop.run_until_complete(sem.acquire()) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) - t0 = time.monotonic() - acquired = self.loop.run_until_complete(sem.acquire(0.1)) - self.assertFalse(acquired) + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + acquired = loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) - sem = locks.Semaphore(loop=self.loop) - self.loop.run_until_complete(sem.acquire()) + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) - self.loop.call_later(0.01, sem.release) - acquired = self.loop.run_until_complete(sem.acquire(10.1)) + loop.call_later(0.01, sem.release) + acquired = loop.run_until_complete(sem.acquire(10.1)) self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) def test_acquire_timeout_mixed(self): - sem = locks.Semaphore(loop=self.loop) - self.loop.run_until_complete(sem.acquire()) - tasks.Task(sem.acquire(), loop=self.loop) - tasks.Task(sem.acquire(), loop=self.loop) - acquire_task = tasks.Task(sem.acquire(0.1), loop=self.loop) - tasks.Task(sem.acquire(), loop=self.loop) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire(), loop=loop) + tasks.Task(sem.acquire(), loop=loop) + acquire_task = tasks.Task(sem.acquire(0.1), loop=loop) + tasks.Task(sem.acquire(), loop=loop) - t0 = time.monotonic() - acquired = self.loop.run_until_complete(acquire_task) + acquired = loop.run_until_complete(acquire_task) self.assertFalse(acquired) - total_time = (time.monotonic() - t0) - self.assertTrue(0.08 < total_time < 0.12) + self.assertAlmostEqual(0.1, loop.time()) self.assertEqual(3, len(sem._waiters)) diff --git a/tests/queues_test.py b/tests/queues_test.py index fb81b3fe..07440585 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -8,12 +8,13 @@ from tulip import locks from tulip import queues from tulip import tasks +from tulip import test_utils class _QueueTestBase(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() + self.loop = test_utils.TestLoop() events.set_event_loop(None) def tearDown(self): @@ -28,35 +29,45 @@ def _test_repr_or_str(self, fn, expect_id): fn is repr or str. expect_id is True if we expect the Queue's id to appear in fn(Queue()). """ - q = queues.Queue(loop=self.loop) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) self.assertTrue(fn(q).startswith('= 0.14) + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + # Doing it again should take no time and exercise a different path. - t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 <= 0.01) + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) def test_wait_with_global_loop(self): - a = tasks.Task(tasks.sleep(0.01, loop=self.loop), loop=self.loop) - b = tasks.Task(tasks.sleep(0.015, loop=self.loop), loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) @tasks.coroutine def foo(): @@ -435,10 +536,10 @@ def foo(): self.assertEqual(pending, set()) return 42 - events.set_event_loop(self.loop) + events.set_event_loop(loop) try: - res = self.loop.run_until_complete( - tasks.Task(foo(), loop=self.loop)) + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) finally: events.set_event_loop(None) @@ -455,19 +556,31 @@ def test_wait_errors(self): return_when=-1, loop=self.loop)) def test_wait_first_completed(self): - a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) - b = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) task = tasks.Task( tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, - loop=self.loop), - loop=self.loop) + loop=loop), + loop=loop) - done, pending = self.loop.run_until_complete(task) + done, pending = loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) self.assertFalse(a.done()) self.assertTrue(b.done()) self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) def test_wait_really_done(self): # there is possibility that some tasks in the pending list @@ -497,132 +610,215 @@ def coro2(): self.assertIsNone(b.result()) def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + # first_exception, task already has exception - a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) @tasks.coroutine def exc(): raise ZeroDivisionError('err') - b = tasks.Task(exc(), loop=self.loop) + b = tasks.Task(exc(), loop=loop) task = tasks.Task( tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, - loop=self.loop), - loop=self.loop) + loop=loop), + loop=loop) - done, pending = self.loop.run_until_complete(task) + done, pending = loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + # first_exception, exception during waiting - a = tasks.Task(tasks.sleep(10.0, loop=self.loop), loop=self.loop) + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) @tasks.coroutine def exc(): - yield from tasks.sleep(0.01, loop=self.loop) + yield from tasks.sleep(0.01, loop=loop) raise ZeroDivisionError('err') - b = tasks.Task(exc(), loop=self.loop) + b = tasks.Task(exc(), loop=loop) task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, - loop=self.loop) + loop=loop) - done, pending = self.loop.run_until_complete(task) + done, pending = loop.run_until_complete(task) self.assertEqual({b}, done) self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) def test_wait_with_exception(self): - a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) @tasks.coroutine def sleeper(): - yield from tasks.sleep(0.15, loop=self.loop) + yield from tasks.sleep(0.15, loop=loop) raise ZeroDivisionError('really') - b = tasks.Task(sleeper(), loop=self.loop) + b = tasks.Task(sleeper(), loop=loop) @tasks.coroutine def foo(): - done, pending = yield from tasks.wait([b, a], loop=self.loop) + done, pending = yield from tasks.wait([b, a], loop=loop) self.assertEqual(len(done), 2) self.assertEqual(pending, set()) errors = set(f for f in done if f.exception() is not None) self.assertEqual(len(errors), 1) - t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 >= 0.14) - t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 <= 0.01) + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) def test_wait_with_timeout(self): - a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) - b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) @tasks.coroutine def foo(): done, pending = yield from tasks.wait([b, a], timeout=0.11, - loop=self.loop) + loop=loop) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) - t0 = time.monotonic() - self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 >= 0.1) - self.assertTrue(t1-t0 <= 0.13) + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) def test_wait_concurrent_complete(self): - a = tasks.Task(tasks.sleep(0.1, loop=self.loop), loop=self.loop) - b = tasks.Task(tasks.sleep(0.15, loop=self.loop), loop=self.loop) - done, pending = self.loop.run_until_complete( - tasks.wait([b, a], timeout=0.1, loop=self.loop)) + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + @tasks.coroutine def sleeper(dt, x): - yield from tasks.sleep(dt, loop=self.loop) + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) return x - a = sleeper(0.1, 'a') - b = sleeper(0.1, 'b') + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') c = sleeper(0.15, 'c') @tasks.coroutine def foo(): values = [] - for f in tasks.as_completed([b, c, a], loop=self.loop): + for f in tasks.as_completed([b, c, a], loop=loop): values.append((yield from f)) return values - t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 >= 0.14) + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) self.assertTrue('a' in res[:2]) self.assertTrue('b' in res[:2]) self.assertEqual(res[2], 'c') + # Doing it again should take no time and exercise a different path. - t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 <= 0.01) + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) def test_as_completed_with_timeout(self): - a = tasks.sleep(0.1, 'a', loop=self.loop) - b = tasks.sleep(0.15, 'b', loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) @tasks.coroutine def foo(): values = [] - for f in tasks.as_completed([a, b], timeout=0.12, loop=self.loop): + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): try: v = yield from f values.append((1, v)) @@ -630,103 +826,148 @@ def foo(): values.append((2, exc)) return values - t0 = time.monotonic() - res = self.loop.run_until_complete(tasks.Task(foo(), loop=self.loop)) - t1 = time.monotonic() - self.assertTrue(t1-t0 >= 0.11) + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) self.assertEqual(len(res), 2, res) self.assertEqual(res[0], (1, 'a')) self.assertEqual(res[1][0], 2) self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) def test_as_completed_reverse_wait(self): - a = tasks.sleep(0.05, 'a', loop=self.loop) - b = tasks.sleep(0.10, 'b', loop=self.loop) + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) fs = {a, b} - futs = list(tasks.as_completed(fs, loop=self.loop)) + futs = list(tasks.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) - x = self.loop.run_until_complete(futs[1]) + + x = loop.run_until_complete(futs[1]) self.assertEqual(x, 'a') - y = self.loop.run_until_complete(futs[0]) + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) def test_as_completed_concurrent(self): - a = tasks.sleep(0.05, 'a', loop=self.loop) - b = tasks.sleep(0.05, 'b', loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) fs = {a, b} - futs = list(tasks.as_completed(fs, loop=self.loop)) + futs = list(tasks.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) - waiter = tasks.wait(futs, loop=self.loop) - done, pending = self.loop.run_until_complete(waiter) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) self.assertEqual(set(f.result() for f in done), {'a', 'b'}) def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + @tasks.coroutine def sleeper(dt, arg): - yield from tasks.sleep(dt/2, loop=self.loop) - res = yield from tasks.sleep(dt/2, arg, loop=self.loop) + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) return res - t = tasks.Task(sleeper(0.1, 'yeah'), loop=self.loop) - t0 = time.monotonic() - self.loop.run_until_complete(t) - t1 = time.monotonic() - self.assertTrue(t1-t0 >= 0.09) + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) def test_sleep_cancel(self): - t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=self.loop), - loop=self.loop) + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) handle = None - orig_call_later = self.loop.call_later + orig_call_later = loop.call_later def call_later(self, delay, callback, *args): nonlocal handle handle = orig_call_later(self, delay, callback, *args) return handle - self.loop.call_later = call_later - test_utils.run_briefly(self.loop) + loop.call_later = call_later + test_utils.run_briefly(loop) self.assertFalse(handle._cancelled) t.cancel() - test_utils.run_briefly(self.loop) + test_utils.run_briefly(loop) self.assertTrue(handle._cancelled) def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + sleepfut = None @tasks.coroutine def sleep(dt): nonlocal sleepfut - sleepfut = tasks.sleep(dt, loop=self.loop) - try: - time.monotonic() - yield from sleepfut - finally: - time.monotonic() + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut @tasks.coroutine def doit(): - sleeper = tasks.Task(sleep(5000), loop=self.loop) - self.loop.call_later(0.1, sleeper.cancel) + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) try: - time.monotonic() yield from sleeper except futures.CancelledError: - time.monotonic() return 'cancelled' else: return 'slept in' - t0 = time.monotonic() doer = doit() - self.assertEqual(self.loop.run_until_complete(doer), 'cancelled') - t1 = time.monotonic() - self.assertTrue(0.09 <= t1-t0 <= 0.13, (t1-t0, sleepfut, doer)) + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) def test_task_cancel_waiter_future(self): fut = futures.Future(loop=self.loop) @@ -811,9 +1052,18 @@ def notmutch(): self.assertIsInstance(task.exception(), BaseException) def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + @tasks.coroutine def sleeper(): - yield from tasks.sleep(10, loop=self.loop) + yield from tasks.sleep(10, loop=loop) @tasks.coroutine def notmutch(): @@ -822,13 +1072,13 @@ def notmutch(): except futures.CancelledError: raise BaseException() - task = tasks.Task(notmutch(), loop=self.loop) - test_utils.run_briefly(self.loop) + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) task.cancel() self.assertFalse(task.done()) - self.assertRaises(BaseException, test_utils.run_briefly, self.loop) + self.assertRaises(BaseException, test_utils.run_briefly, loop) self.assertTrue(task.done()) self.assertTrue(task.cancelled()) diff --git a/tulip/futures.py b/tulip/futures.py index cad42b47..068f77ee 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -171,6 +171,13 @@ def __repr__(self): res += '<{}, {}>'.format(self._state, self._callbacks) else: res += '<{}>'.format(self._state) + dct = {} + if self._timeout is not None: + dct['timeout'] = self._timeout + if self._timeout_handle is not None: + dct['when'] = self._timeout_handle._when + if dct: + res += '{' + ', '.join(k+'='+str(dct[k]) for k in sorted(dct)) + '}' return res def cancel(self): diff --git a/tulip/locks.py b/tulip/locks.py index dfe9905d..622a499b 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -3,7 +3,6 @@ __all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] import collections -import time from . import events from . import futures @@ -288,9 +287,9 @@ def wait_for(self, predicate, timeout=None): while not result: if waittime is not None: if endtime is None: - endtime = time.monotonic() + waittime + endtime = self._loop.time() + waittime else: - waittime = endtime - time.monotonic() + waittime = endtime - self._loop.time() if waittime <= 0: break diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 17f00475..5077c4a4 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -8,6 +8,7 @@ import json import logging import io +import unittest.mock import os import re import socket @@ -24,6 +25,9 @@ import tulip.http from tulip.http import client +from tulip import base_events +from tulip import selectors + if sys.platform == 'win32': # pragma: no cover from .windows_utils import socketpair @@ -300,3 +304,79 @@ def _response(self, response, body=None, headers=None, chunked=False): # keep-alive if response.keep_alive(): self._srv.keep_alive(True) + + +class TestSelector(selectors._BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + self._check_on_close = False + def gen(): + yield + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: + raise AssertionError("Time generator is not finished") + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass From 2fbbe7b5850d406c524ecb72e69ecd52fe43e274 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 21 Aug 2013 05:47:13 +0300 Subject: [PATCH 0556/1502] Close response object explicitly in http tests. --- tests/http_client_functional_test.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py index c351607a..91badfc4 100644 --- a/tests/http_client_functional_test.py +++ b/tests/http_client_functional_test.py @@ -65,6 +65,7 @@ def test_HTTP_302_REDIRECT_GET(self): self.assertEqual(r.status, 200) self.assertEqual(2, httpd['redirects']) + r.close() def test_HTTP_302_REDIRECT_NON_HTTP(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -85,6 +86,7 @@ def test_HTTP_302_REDIRECT_POST(self): self.assertEqual(r.status, 200) self.assertIn('"method": "POST"', content) self.assertEqual(2, httpd['redirects']) + r.close() def test_HTTP_302_max_redirects(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -94,6 +96,7 @@ def test_HTTP_302_max_redirects(self): self.assertEqual(r.status, 302) self.assertEqual(2, httpd['redirects']) + r.close() def test_HTTP_200_GET_WITH_PARAMS(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -105,6 +108,7 @@ def test_HTTP_200_GET_WITH_PARAMS(self): self.assertIn('"query": "q=test"', content) self.assertEqual(r.status, 200) + r.close() def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -117,6 +121,7 @@ def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): self.assertIn('"query": "test=true&q=test"', content) self.assertEqual(r.status, 200) + r.close() def test_POST_DATA(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -129,6 +134,7 @@ def test_POST_DATA(self): content = self.loop.run_until_complete(r.read(True)) self.assertEqual({'some': ['data']}, content['form']) self.assertEqual(r.status, 200) + r.close() def test_POST_DATA_DEFLATE(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -143,6 +149,7 @@ def test_POST_DATA_DEFLATE(self): self.assertEqual('deflate', content['compression']) self.assertEqual({'some': ['data']}, content['form']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -167,6 +174,7 @@ def test_POST_FILES(self): self.assertEqual( f.read(), content['multipart-data'][0]['data']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_DEFLATE(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -192,6 +200,7 @@ def test_POST_FILES_DEFLATE(self): self.assertEqual( f.read(), content['multipart-data'][0]['data']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_STR(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -213,6 +222,7 @@ def test_POST_FILES_STR(self): self.assertEqual( f.read(), content['multipart-data'][0]['data']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_LIST(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -236,6 +246,7 @@ def test_POST_FILES_LIST(self): self.assertEqual( f.read(), content['multipart-data'][0]['data']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_LIST_CT(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -261,6 +272,7 @@ def test_POST_FILES_LIST_CT(self): self.assertEqual( 'text/plain', content['multipart-data'][0]['content-type']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_SINGLE(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -283,6 +295,7 @@ def test_POST_FILES_SINGLE(self): self.assertEqual( f.read(), content['multipart-data'][0]['data']) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_IO(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -302,6 +315,7 @@ def test_POST_FILES_IO(self): 'filename': 'unknown', 'name': 'unknown'}, content['multipart-data'][0]) self.assertEqual(r.status, 200) + r.close() def test_POST_FILES_WITH_DATA(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -329,6 +343,7 @@ def test_POST_FILES_WITH_DATA(self): self.assertEqual( f.read(), content['multipart-data'][1]['data']) self.assertEqual(r.status, 200) + r.close() def test_encoding(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -341,6 +356,7 @@ def test_encoding(self): client.request('get', httpd.url('encoding', 'gzip'), loop=self.loop)) self.assertEqual(r.status, 200) + r.close() def test_cookies(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -355,6 +371,7 @@ def test_cookies(self): content = self.loop.run_until_complete(r.content.read()) self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + r.close() def test_set_cookies(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -364,6 +381,7 @@ def test_set_cookies(self): self.assertEqual(resp.cookies['c1'].value, 'cookie1') self.assertEqual(resp.cookies['c2'].value, 'cookie2') + resp.close() def test_chunked(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: @@ -373,6 +391,7 @@ def test_chunked(self): self.assertEqual(r['Transfer-Encoding'], 'chunked') content = self.loop.run_until_complete(r.read(True)) self.assertEqual(content['path'], '/chunked') + r.close() def test_timeout(self): with test_utils.run_test_server(self.loop, router=Functional) as httpd: From 9c519c2731b9ac6c309dfc5d632c397803634993 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 21 Aug 2013 08:29:57 +0300 Subject: [PATCH 0557/1502] Unittest refactoring. --- tests/selector_events_test.py | 177 +++++++++++++++++----------------- tests/transports_test.py | 1 + tests/unix_events_test.py | 131 ++++++++++++++----------- tulip/test_utils.py | 61 ++++++++++++ tulip/unix_events.py | 7 +- 5 files changed, 228 insertions(+), 149 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 76d72910..58cef434 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -11,6 +11,7 @@ from tulip import futures from tulip import selectors +from tulip import test_utils from tulip.events import AbstractEventLoop from tulip.protocols import DatagramProtocol, Protocol from tulip.selector_events import BaseSelectorEventLoop @@ -401,7 +402,7 @@ def test_add_reader(self): self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_READ, mask) self.assertEqual(cb, r._callback) - self.assertEqual(None, w) + self.assertIsNone(w) def test_add_reader_existing(self): reader = unittest.mock.Mock() @@ -469,7 +470,7 @@ def test_add_writer(self): fd, mask, (r, w) = self.loop._selector.register.call_args[0] self.assertEqual(1, fd) self.assertEqual(selectors.EVENT_WRITE, mask) - self.assertEqual(None, r) + self.assertIsNone(r) self.assertEqual(cb, w._callback) def test_add_writer_existing(self): @@ -555,10 +556,10 @@ def test_process_events_write_cancelled(self): class SelectorTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 - self.protocol = unittest.mock.Mock(Protocol) def test_ctor(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) @@ -578,37 +579,39 @@ def test_close(self): tr.close() self.assertTrue(tr._closing) - self.loop.remove_reader.assert_called_with(7) + self.assertEqual(1, self.loop.remove_reader_count[7]) self.protocol.connection_lost(None) self.assertEqual(tr._conn_lost, 1) - self.loop.reset_mock() tr.close() self.assertEqual(tr._conn_lost, 1) - self.assertFalse(self.loop.remove_reader.called) + self.assertEqual(1, self.loop.remove_reader_count[7]) def test_close_write_buffer(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - self.loop.reset_mock() tr._buffer.append(b'data') tr.close() - self.assertTrue(self.loop.remove_reader.called) - self.assertFalse(self.loop.call_soon.called) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) def test_force_close(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) tr._buffer = [b'1'] + self.loop.add_reader(7, unittest.mock.sentinel) + self.loop.add_writer(7, unittest.mock.sentinel) tr._force_close(None) self.assertTrue(tr._closing) self.assertEqual(tr._buffer, []) - self.loop.remove_reader.assert_called_with(7) - self.loop.remove_writer.assert_called_with(7) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) - self.loop.reset_mock() + # second close should not remove reader tr._force_close(None) - self.assertFalse(self.loop.remove_reader.called) + self.assertFalse(self.loop.readers) + self.assertEqual(1, self.loop.remove_reader_count[7]) @unittest.mock.patch('tulip.log.tulip_log.exception') def test_fatal_error(self, m_exc): @@ -632,17 +635,17 @@ def test_connection_lost(self): class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock_fd = self.sock.fileno.return_value = 7 - self.protocol = unittest.mock.Mock(Protocol) def test_ctor(self): tr = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - self.loop.add_reader.assert_called_with(7, tr._read_ready) - self.loop.call_soon.assert_called_with( - self.protocol.connection_made, tr) + self.loop.assert_reader(7, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) self.assertTrue(tr._writing) def test_ctor_with_waiter(self): @@ -650,9 +653,8 @@ def test_ctor_with_waiter(self): _SelectorSocketTransport( self.loop, self.sock, self.protocol, fut) - self.assertEqual(2, self.loop.call_soon.call_count) - self.assertEqual(fut.set_result, - self.loop.call_soon.call_args[0][0]) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) def test_read_ready(self): transport = _SelectorSocketTransport( @@ -668,7 +670,6 @@ def test_read_ready_eof(self): self.loop, self.sock, self.protocol) transport.close = unittest.mock.Mock() - self.loop.reset_mock() self.sock.recv.return_value = b'' transport._read_ready() @@ -759,10 +760,7 @@ def test_write_partial(self): self.loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.loop.add_writer.called) - self.assertEqual( - transport._write_ready, self.loop.add_writer.call_args[0][1]) - + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'ta'], transport._buffer) def test_write_partial_none(self): @@ -774,8 +772,7 @@ def test_write_partial_none(self): self.loop, self.sock, self.protocol) transport.write(data) - self.loop.add_writer.assert_called_with( - 7, transport._write_ready) + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'data'], transport._buffer) def test_write_tryagain(self): @@ -786,10 +783,7 @@ def test_write_tryagain(self): self.loop, self.sock, self.protocol) transport.write(data) - self.assertTrue(self.loop.add_writer.called) - self.assertEqual( - transport._write_ready, self.loop.add_writer.call_args[0][1]) - + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'data'], transport._buffer) @unittest.mock.patch('tulip.selector_events.tulip_log') @@ -834,10 +828,11 @@ def test_write_ready(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.assertTrue(self.sock.send.called) self.assertEqual(self.sock.send.call_args[0], (data,)) - self.assertTrue(self.loop.remove_writer.called) + self.assertFalse(self.loop.writers) def test_write_ready_paused(self): transport = _SelectorSocketTransport( @@ -856,9 +851,10 @@ def test_write_ready_closing(self): self.loop, self.sock, self.protocol) transport._closing = True transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.sock.send.assert_called_with(data) - self.loop.remove_writer.assert_called_with(7) + self.assertFalse(self.loop.writers) self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) @@ -874,8 +870,9 @@ def test_write_ready_partial(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) transport._write_ready() - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'ta'], transport._buffer) def test_write_ready_partial_none(self): @@ -885,8 +882,9 @@ def test_write_ready_partial_none(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer.append(data) + self.loop.add_writer(7, transport._write_ready) transport._write_ready() - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'data'], transport._buffer) def test_write_ready_tryagain(self): @@ -895,9 +893,10 @@ def test_write_ready_tryagain(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer = [b'data1', b'data2'] + self.loop.add_writer(7, transport._write_ready) transport._write_ready() - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(7, transport._write_ready) self.assertEqual([b'data1data2'], transport._buffer) def test_write_ready_exception(self): @@ -914,33 +913,33 @@ def test_pause_writing(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer.append(b'data') + self.loop.add_writer(self.sock_fd, transport._write_ready) transport.pause_writing() self.assertFalse(transport._writing) - self.loop.remove_writer.assert_called_with(self.sock_fd) + self.assertFalse(self.loop.writers) + self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) - self.loop.reset_mock() transport.pause_writing() - self.assertFalse(self.loop.remove_writer.called) + self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) def test_pause_writing_no_buffer(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport.pause_writing() self.assertFalse(transport._writing) - self.assertFalse(self.loop.remove_writer.called) + self.assertEqual(0, self.loop.remove_writer_count[7]) def test_resume_writing(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._buffer.append(b'data') transport.resume_writing() - self.assertFalse(self.loop.add_writer.called) + self.assertFalse(self.loop.writers) transport._writing = False transport.resume_writing() self.assertTrue(transport._writing) - self.loop.add_writer.assert_called_with( - self.sock_fd, transport._write_ready) + self.loop.assert_writer(self.sock_fd, transport._write_ready) def test_resume_writing_no_buffer(self): transport = _SelectorSocketTransport( @@ -948,28 +947,30 @@ def test_resume_writing_no_buffer(self): transport._writing = False transport.resume_writing() self.assertTrue(transport._writing) - self.assertFalse(self.loop.add_writer.called) + self.assertFalse(self.loop.writers) def test_discard_output(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport.discard_output() - self.assertFalse(self.loop.remove_writer.called) + self.assertEqual(0, self.loop.remove_writer_count[self.sock_fd]) transport._buffer.append(b'data') + self.loop.add_writer(self.sock_fd, transport._write_ready) transport.discard_output() self.assertEqual(transport._buffer, []) - self.loop.remove_writer.assert_called_with(self.sock_fd) + self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) + self.assertFalse(self.loop.writers) @unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 - self.protocol = unittest.mock.Mock(spec_set=Protocol) self.sslsock = unittest.mock.Mock() self.sslsock.fileno.return_value = 1 self.sslcontext = unittest.mock.Mock() @@ -978,46 +979,42 @@ def setUp(self): def _make_one(self, create_waiter=None): transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - self.loop.reset_mock() self.sock.reset_mock() self.protocol.reset_mock() self.sslsock.reset_mock() self.sslcontext.reset_mock() + self.loop.reset_counters() return transport def test_on_handshake(self): - tr = self._make_one() - tr._waiter = futures.Future(loop=self.loop) - tr._on_handshake() + waiter = futures.Future(loop=self.loop) + tr = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, waiter=waiter) self.assertTrue(self.sslsock.do_handshake.called) - self.assertTrue(self.loop.remove_reader.called) - self.assertTrue(self.loop.remove_writer.called) - self.assertEqual((1, tr._on_ready,), - self.loop.add_reader.call_args[0]) - self.assertEqual((1, tr._on_ready,), - self.loop.add_writer.call_args[0]) - self.assertEqual((tr._waiter.set_result, None), - self.loop.call_soon.call_args[0]) - tr._waiter.cancel() + self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_writer(1, tr._on_ready) + test_utils.run_briefly(self.loop) + self.assertIsNone(waiter.result()) def test_on_handshake_reader_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError - transport = self._make_one() + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) transport._on_handshake() - self.assertEqual((1, transport._on_handshake,), - self.loop.add_reader.call_args[0]) + self.loop.assert_reader(1, transport._on_handshake) def test_on_handshake_writer_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError - transport = self._make_one() + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) transport._on_handshake() - self.assertEqual((1, transport._on_handshake,), - self.loop.add_writer.call_args[0]) + self.loop.assert_writer(1, transport._on_handshake) def test_on_handshake_exc(self): exc = ValueError() self.sslsock.do_handshake.side_effect = exc - transport = self._make_one() + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) transport._waiter = futures.Future(loop=self.loop) transport._on_handshake() self.assertTrue(self.sslsock.close.called) @@ -1025,7 +1022,8 @@ def test_on_handshake_exc(self): self.assertIs(exc, transport._waiter.exception()) def test_on_handshake_base_exc(self): - transport = self._make_one() + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext) transport._waiter = futures.Future(loop=self.loop) exc = BaseException() self.sslsock.do_handshake.side_effect = exc @@ -1153,10 +1151,9 @@ def test_on_ready_send_closing(self): transport = self._make_one() transport.close() transport._buffer = [b'data'] - transport._call_connection_lost = unittest.mock.Mock() transport._on_ready() - self.assertTrue(self.loop.remove_writer.called) - self.assertTrue(transport._call_connection_lost.called) + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) def test_on_ready_send_closing_empty_buffer(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError @@ -1164,10 +1161,9 @@ def test_on_ready_send_closing_empty_buffer(self): transport = self._make_one() transport.close() transport._buffer = [] - transport._call_connection_lost = unittest.mock.Mock() transport._on_ready() - self.assertTrue(self.loop.remove_writer.called) - self.assertTrue(transport._call_connection_lost.called) + self.assertFalse(self.loop.writers) + self.protocol.connection_lost.assert_called_with(None) def test_on_ready_send_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError @@ -1204,22 +1200,21 @@ def test_close(self): tr.close() self.assertTrue(tr._closing) - self.loop.remove_reader.assert_called_with(1) + self.assertEqual(1, self.loop.remove_reader_count[1]) self.assertEqual(tr._conn_lost, 1) - self.loop.reset_mock() tr.close() self.assertEqual(tr._conn_lost, 1) - self.assertFalse(self.loop.remove_reader.called) + self.assertEqual(1, self.loop.remove_reader_count[1]) class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(DatagramProtocol) self.sock = unittest.mock.Mock(spec_set=socket.socket) self.sock.fileno.return_value = 7 - self.protocol = unittest.mock.Mock(spec_set=DatagramProtocol) def test_read_ready(self): transport = _SelectorDatagramTransport( @@ -1289,11 +1284,7 @@ def test_sendto_tryagain(self): self.loop, self.sock, self.protocol) transport.sendto(data, ('0.0.0.0', 12345)) - self.assertTrue(self.loop.add_writer.called) - self.assertEqual( - transport._sendto_ready, - self.loop.add_writer.call_args[0][1]) - + self.loop.assert_writer(7, transport._sendto_ready) self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) @@ -1370,11 +1361,12 @@ def test_sendto_ready(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) transport._buffer.append((data, ('0.0.0.0', 12345))) + self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() self.assertTrue(self.sock.sendto.called) self.assertEqual( self.sock.sendto.call_args[0], (data, ('0.0.0.0', 12345))) - self.assertTrue(self.loop.remove_writer.called) + self.assertFalse(self.loop.writers) def test_sendto_ready_closing(self): data = b'data' @@ -1384,18 +1376,20 @@ def test_sendto_ready_closing(self): self.loop, self.sock, self.protocol) transport._closing = True transport._buffer.append((data, ())) + self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() self.sock.sendto.assert_called_with(data, ()) - self.loop.remove_writer.assert_called_with(7) + self.assertFalse(self.loop.writers) self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) def test_sendto_ready_no_data(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) + self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() self.assertFalse(self.sock.sendto.called) - self.assertTrue(self.loop.remove_writer.called) + self.assertFalse(self.loop.writers) def test_sendto_ready_tryagain(self): self.sock.sendto.side_effect = BlockingIOError @@ -1403,9 +1397,10 @@ def test_sendto_ready_tryagain(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) transport._buffer.extend([(b'data1', ()), (b'data2', ())]) + self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(7, transport._sendto_ready) self.assertEqual( [(b'data1', ()), (b'data2', ())], list(transport._buffer)) diff --git a/tests/transports_test.py b/tests/transports_test.py index d2688c3a..5920cda6 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -3,6 +3,7 @@ import unittest import unittest.mock +from tulip import futures from tulip import transports diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index a384b4a1..d12d7dde 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -15,6 +15,7 @@ from tulip import events from tulip import futures from tulip import protocols +from tulip import test_utils from tulip import unix_events @@ -303,10 +304,10 @@ def test__sig_chld_unknown_status_in_handler(self, m_waitpid, class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) fcntl_patcher = unittest.mock.patch('fcntl.fcntl') fcntl_patcher.start() @@ -315,16 +316,16 @@ def setUp(self): def test_ctor(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - self.loop.add_reader.assert_called_with(5, tr._read_ready) - self.loop.call_soon.assert_called_with( - self.protocol.connection_made, tr) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): fut = futures.Future(loop=self.loop) unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol, fut) - self.loop.call_soon.assert_called_with(fut.set_result, None) - fut.cancel() + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) @unittest.mock.patch('os.read') def test__read_ready(self, m_read): @@ -344,20 +345,20 @@ def test__read_ready_eof(self, m_read): tr._read_ready() m_read.assert_called_with(5, tr.max_size) - self.loop.remove_reader.assert_called_with(5) - self.loop.call_soon.assert_has_calls([ - unittest.mock.call(self.protocol.eof_received), - unittest.mock.call(tr._call_connection_lost, None)]) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) @unittest.mock.patch('os.read') def test__read_ready_blocked(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - self.loop.reset_mock() m_read.side_effect = BlockingIOError tr._read_ready() m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) @unittest.mock.patch('tulip.log.tulip_log.exception') @@ -379,8 +380,10 @@ def test_pause(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) + m = unittest.mock.Mock() + self.loop.add_reader(5, m) tr.pause() - self.loop.remove_reader.assert_called_with(5) + self.assertFalse(self.loop.readers) @unittest.mock.patch('os.read') def test_resume(self, m_read): @@ -388,7 +391,7 @@ def test_resume(self, m_read): self.loop, self.pipe, self.protocol) tr.resume() - self.loop.add_reader.assert_called_with(5, tr._read_ready) + self.loop.assert_reader(5, tr._read_ready) @unittest.mock.patch('os.read') def test_close(self, m_read): @@ -417,8 +420,9 @@ def test__close(self, m_read): err = object() tr._close(err) self.assertTrue(tr._closing) - self.loop.remove_reader.assert_called_with(5) - self.loop.call_soon.assert_called_with(tr._call_connection_lost, err) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) def test__call_connection_lost(self): tr = unix_events._UnixReadPipeTransport( @@ -442,10 +446,10 @@ def test__call_connection_lost_with_err(self): class UnixWritePipeTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock(spec_set=events.AbstractEventLoop) + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - self.protocol = unittest.mock.Mock(spec_set=protocols.Protocol) fcntl_patcher = unittest.mock.patch('fcntl.fcntl') fcntl_patcher.start() @@ -461,17 +465,17 @@ def setUp(self): def test_ctor(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) - self.loop.add_reader.assert_called_with(5, tr._read_ready) - self.loop.call_soon.assert_called_with( - self.protocol.connection_made, tr) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): fut = futures.Future(loop=self.loop) tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol, fut) - self.loop.call_soon.assert_called_with(fut.set_result, None) - self.loop.add_reader.assert_called_with(5, tr._read_ready) - fut.cancel() + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) def test_can_write_eof(self): tr = unix_events._UnixWritePipeTransport( @@ -486,7 +490,7 @@ def test_write(self, m_write): m_write.return_value = 4 tr.write(b'data') m_write.assert_called_with(5, b'data') - self.assertFalse(self.loop.add_writer.called) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @@ -496,7 +500,7 @@ def test_write_no_data(self, m_write): tr.write(b'') self.assertFalse(m_write.called) - self.assertFalse(self.loop.add_writer.called) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @@ -507,7 +511,7 @@ def test_write_partial(self, m_write): m_write.return_value = 2 tr.write(b'data') m_write.assert_called_with(5, b'data') - self.loop.add_writer.assert_called_with(5, tr._write_ready) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'ta'], tr._buffer) @unittest.mock.patch('os.write') @@ -515,10 +519,11 @@ def test_write_buffer(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'previous'] tr.write(b'data') self.assertFalse(m_write.called) - self.assertFalse(self.loop.add_writer.called) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'previous', b'data'], tr._buffer) @unittest.mock.patch('os.write') @@ -529,7 +534,7 @@ def test_write_again(self, m_write): m_write.side_effect = BlockingIOError() tr.write(b'data') m_write.assert_called_with(5, b'data') - self.loop.add_writer.assert_called_with(5, tr._write_ready) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('tulip.unix_events.tulip_log') @@ -543,7 +548,7 @@ def test_write_err(self, m_write, m_log): tr._fatal_error = unittest.mock.Mock() tr.write(b'data') m_write.assert_called_with(5, b'data') - self.assertFalse(self.loop.called) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) tr._fatal_error.assert_called_with(err) self.assertEqual(1, tr._conn_lost) @@ -561,21 +566,22 @@ def test__read_ready(self): tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, self.protocol) tr._read_ready() - self.loop.remove_writer.assert_called_with(5) - self.loop.remove_reader.assert_called_with(5) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) self.assertTrue(tr._closing) - self.loop.call_soon.assert_called_with(tr._call_connection_lost, - None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) @unittest.mock.patch('os.write') def test__write_ready(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 4 tr._write_ready() m_write.assert_called_with(5, b'data') - self.loop.remove_writer.assert_called_with(5) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) @unittest.mock.patch('os.write') @@ -583,11 +589,12 @@ def test__write_ready_partial(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 3 tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'a'], tr._buffer) @unittest.mock.patch('os.write') @@ -595,11 +602,12 @@ def test__write_ready_again(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.side_effect = BlockingIOError() tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('os.write') @@ -607,11 +615,12 @@ def test__write_ready_empty(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 0 tr._write_ready() m_write.assert_called_with(5, b'data') - self.assertFalse(self.loop.remove_writer.called) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) @unittest.mock.patch('tulip.log.tulip_log.exception') @@ -620,31 +629,34 @@ def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.side_effect = err = OSError() tr._write_ready() m_write.assert_called_with(5, b'data') - self.loop.remove_writer.assert_called_with(5) - self.loop.remove_reader.assert_called_with(5) + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.loop.call_soon.assert_called_with( - tr._call_connection_lost, err) m_logexc.assert_called_with('Fatal error for %s', tr) self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + @unittest.mock.patch('os.write') def test__write_ready_closing(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._closing = True tr._buffer = [b'da', b'ta'] m_write.return_value = 4 tr._write_ready() m_write.assert_called_with(5, b'data') - self.loop.remove_writer.assert_called_with(5) - self.loop.remove_reader.assert_called_with(5) + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) self.assertEqual([], tr._buffer) self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() @@ -654,15 +666,17 @@ def test_abort(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) tr._buffer = [b'da', b'ta'] tr.abort() self.assertFalse(m_write.called) - self.loop.remove_writer.assert_called_with(5) - self.loop.remove_reader.assert_called_with(5) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - self.loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) def test__call_connection_lost(self): tr = unix_events._UnixWritePipeTransport( @@ -705,9 +719,9 @@ def test_write_eof(self): tr.write_eof() self.assertTrue(tr._closing) - self.loop.remove_reader.assert_called_with(5) - self.loop.call_soon.assert_called_with( - tr._call_connection_lost, None) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) def test_write_eof_pending(self): tr = unix_events._UnixWritePipeTransport( @@ -740,27 +754,29 @@ def test_double_pause_resume_writing(self): def test_pause_resume_writing_with_nonempty_buffer(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] tr.pause_writing() self.assertFalse(tr._writing) - self.loop.remove_writer.assert_called_with(5) + self.assertFalse(self.loop.writers) self.assertEqual([b'da', b'ta'], tr._buffer) tr.resume_writing() self.assertTrue(tr._writing) - self.loop.add_writer.assert_called_with(5, tr._write_ready) + self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'da', b'ta'], tr._buffer) @unittest.mock.patch('os.write') def test__write_ready_on_pause(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] tr.pause_writing() - self.loop.remove_writer.reset_mock() + tr._write_ready() self.assertFalse(m_write.called) - self.assertFalse(self.loop.remove_writer.called) + self.assertFalse(self.loop.writers) self.assertEqual([b'da', b'ta'], tr._buffer) self.assertFalse(tr._writing) @@ -768,9 +784,10 @@ def test_discard_output(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) tr._buffer = [b'da', b'ta'] + self.loop.add_writer(5, tr._write_ready) tr.discard_output() self.assertTrue(tr._writing) - self.loop.remove_writer.assert_called_with(5) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) def test_discard_output_without_pending_writes(self): @@ -778,5 +795,5 @@ def test_discard_output_without_pending_writes(self): self.loop, self.pipe, self.protocol) tr.discard_output() self.assertTrue(tr._writing) - self.assertFalse(self.loop.remove_writer.called) + self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 5077c4a4..c6b8fe1c 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,6 +1,7 @@ """Utilities shared by tests.""" import cgi +import collections import contextlib import gc import email.parser @@ -15,6 +16,8 @@ import sys import threading import traceback +import unittest +import unittest.mock import urllib.parse try: import ssl @@ -24,6 +27,8 @@ import tulip import tulip.http from tulip.http import client +from tulip import base_events +from tulip import events from tulip import base_events from tulip import selectors @@ -306,6 +311,16 @@ def _response(self, response, body=None, headers=None, chunked=False): self._srv.keep_alive(True) +def make_test_protocol(base): + proto = unittest.mock.Mock(spec_set=base) + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + getattr(proto, name).return_value = None + return proto + + class TestSelector(selectors._BaseSelector): def select(self, timeout): @@ -347,6 +362,10 @@ def gen(): self._timers = [] self._selector = TestSelector() + self.readers = {} + self.writers = {} + self.reset_counters() + def time(self): return self._time @@ -364,6 +383,48 @@ def close(self): else: raise AssertionError("Time generator is not finished") + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + def _run_once(self): super()._run_once() for when in self._timers: diff --git a/tulip/unix_events.py b/tulip/unix_events.py index cccc568a..3e78da39 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -291,6 +291,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): def _read_ready(self): # pipe was closed by peer + self._close() def write(self, data): @@ -337,6 +338,9 @@ def _write_ready(self): self._buffer.append(data) except Exception as exc: self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) self._fatal_error(exc) else: if n == len(data): @@ -376,8 +380,9 @@ def _fatal_error(self, exc): def _close(self, exc=None): self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) self._buffer.clear() - self._loop.remove_writer(self._fileno) self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, exc) From 1497883f359b94dd747a5e3226cc1ea62fc72e78 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 21 Aug 2013 11:08:56 -0700 Subject: [PATCH 0558/1502] Fold long line. --- tests/selector_events_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 58cef434..5880c89f 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -989,7 +989,8 @@ def _make_one(self, create_waiter=None): def test_on_handshake(self): waiter = futures.Future(loop=self.loop) tr = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, waiter=waiter) + self.loop, self.sock, self.protocol, self.sslcontext, + waiter=waiter) self.assertTrue(self.sslsock.do_handshake.called) self.loop.assert_reader(1, tr._on_ready) self.loop.assert_writer(1, tr._on_ready) From 827976ac4bc43a3d827f6f7a4f5d136969a3dedd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 21 Aug 2013 22:21:22 -0700 Subject: [PATCH 0559/1502] Fix a bunch of assertTrue calls. --- tests/base_events_test.py | 2 +- tests/events_test.py | 17 +++++++++-------- tests/queues_test.py | 2 +- tests/windows_utils_test.py | 4 ++-- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 3fa7a06f..22679f79 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -180,7 +180,7 @@ def test__run_once(self): self.loop._run_once() t = self.loop._selector.select.call_args[0][0] - self.assertTrue(9.99 < t < 10.1) + self.assertTrue(9.99 < t < 10.1, t) self.assertEqual([h2], self.loop._scheduled) self.assertTrue(self.loop._process_events.called) diff --git a/tests/events_test.py b/tests/events_test.py index 3d562caf..e59a2ebc 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -249,7 +249,7 @@ def callback(arg): self.loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(0.09 <= t1-t0 <= 0.12) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) def test_call_soon(self): results = [] @@ -347,7 +347,7 @@ def remove_writer(): w.close() data = r.recv(256*1024) r.close() - self.assertTrue(len(data) >= 200) + self.assertGreaterEqual(len(data), 200) def test_sock_client_ops(self): with test_utils.run_test_server(self.loop) as httpd: @@ -485,7 +485,7 @@ def test_create_connection(self): self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.loop.run_until_complete(pr.done) - self.assertTrue(pr.nbytes > 0) + self.assertGreater(pr.nbytes, 0) tr.close() def test_create_connection_sock(self): @@ -513,7 +513,7 @@ def test_create_connection_sock(self): self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.loop.run_until_complete(pr.done) - self.assertTrue(pr.nbytes > 0) + self.assertGreater(pr.nbytes, 0) tr.close() @unittest.skipIf(ssl is None, 'No ssl module') @@ -529,7 +529,7 @@ def test_create_ssl_connection(self): self.assertTrue( hasattr(tr.get_extra_info('socket'), 'getsockname')) self.loop.run_until_complete(pr.done) - self.assertTrue(pr.nbytes > 0) + self.assertGreater(pr.nbytes, 0) tr.close() def test_create_connection_local_addr(self): @@ -940,7 +940,7 @@ def main(): self.loop.run_forever() elapsed = time.monotonic() - start - self.assertTrue(elapsed < 0.1) + self.assertLess(elapsed, 0.1) self.assertEqual(t.result(), 'cancelled') self.assertRaises(futures.CancelledError, f.result) self.assertTrue(ov is None or not ov.pending) @@ -1306,7 +1306,7 @@ def callback(*args): self.assertTrue(r.startswith( 'Handle(' '.callback')) - self.assertTrue(r.endswith('())')) + self.assertTrue(r.endswith('())'), r) def test_make_handle(self): def callback(*args): @@ -1350,7 +1350,7 @@ def callback(*args): self.assertTrue(h._cancelled) r = repr(h) - self.assertTrue(r.endswith('())')) + self.assertTrue(r.endswith('())'), r) self.assertRaises(AssertionError, events.TimerHandle, None, callback, args) @@ -1363,6 +1363,7 @@ def callback(*args): h1 = events.TimerHandle(when, callback, ()) h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. self.assertFalse(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) diff --git a/tests/queues_test.py b/tests/queues_test.py index 07440585..4d4876b9 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -40,7 +40,7 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(loop=loop) - self.assertTrue(fn(q).startswith(' 0) - self.assertTrue(len(err) > 0) + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) # allow for partial reads... self.assertTrue(msg.upper().rstrip().startswith(out)) self.assertTrue(b"stderr".startswith(err)) From 3f0b6b7d17f071cc4c485499733f68a05ee54982 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 22 Aug 2013 09:48:16 -0700 Subject: [PATCH 0560/1502] Skip test if no ssl. --- tests/streams_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/streams_test.py b/tests/streams_test.py index 347b1262..2267a0f5 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,6 +1,7 @@ """Tests for streams.py.""" import gc +import ssl import unittest import unittest.mock @@ -44,6 +45,7 @@ def test_open_connection(self): writer.close() + @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(self.loop, use_ssl=True) as httpd: try: From 3f8c2aa805398711b590e10ce7c390233d814c97 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 22 Aug 2013 10:50:15 -0700 Subject: [PATCH 0561/1502] Increase sleep times so Windows has a chance. --- tests/events_test.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index e59a2ebc..2431334f 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -210,32 +210,35 @@ def coro2(): self.assertRaises( RuntimeError, self.loop.run_until_complete, coro2()) + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + def test_run_until_complete(self): t0 = self.loop.time() - self.loop.run_until_complete(tasks.sleep(0.010, loop=self.loop)) + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) t1 = self.loop.time() - self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) def test_run_until_complete_stopped(self): @tasks.coroutine def cb(): self.loop.stop() - yield from tasks.sleep(0.010, loop=self.loop) + yield from tasks.sleep(0.1, loop=self.loop) task = cb() self.assertRaises(RuntimeError, self.loop.run_until_complete, task) def test_run_until_complete_timeout(self): t0 = self.loop.time() - task = tasks.async(tasks.sleep(0.020, loop=self.loop), loop=self.loop) + task = tasks.async(tasks.sleep(0.2, loop=self.loop), loop=self.loop) self.assertRaises(futures.TimeoutError, self.loop.run_until_complete, - task, timeout=0.010) + task, timeout=0.1) t1 = self.loop.time() - self.assertTrue(0.009 <= t1-t0 <= 0.018, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) self.loop.run_until_complete(task) t2 = self.loop.time() - self.assertTrue(0.018 <= t2-t0 <= 0.028, t2-t0) + self.assertTrue(0.18 <= t2-t0 <= 0.22, t2-t0) def test_call_later(self): results = [] From 136878b14d55311b74790353eba551e475015ef6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 22 Aug 2013 11:42:42 -0700 Subject: [PATCH 0562/1502] Move set_result() calls from try clause to else clause. --- tulip/selector_events.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index b86acf10..bfaedf59 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -197,11 +197,12 @@ def _sock_recv(self, fut, registered, sock, n): return try: data = sock.recv(n) - fut.set_result(data) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_recv, fut, True, sock, n) except Exception as exc: fut.set_exception(exc) + else: + fut.set_result(data) def sock_sendall(self, sock, data): """XXX""" @@ -267,11 +268,12 @@ def _sock_connect(self, fut, registered, sock, address): if err != 0: # Jump to the except clause below. raise OSError(err, 'Connect call failed') - fut.set_result(None) except (BlockingIOError, InterruptedError): self.add_writer(fd, self._sock_connect, fut, True, sock, address) except Exception as exc: fut.set_exception(exc) + else: + fut.set_result(None) def sock_accept(self, sock): """XXX""" @@ -288,11 +290,12 @@ def _sock_accept(self, fut, registered, sock): try: conn, address = sock.accept() conn.setblocking(False) - fut.set_result((conn, address)) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_accept, fut, True, sock) except Exception as exc: fut.set_exception(exc) + else: + fut.set_result((conn, address)) def _process_events(self, event_list): for fileobj, mask, (reader, writer) in event_list: From c9e8a2b097f9eba62e8b5860d163f46059277e3d Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 23 Aug 2013 05:21:22 +0300 Subject: [PATCH 0563/1502] Reduce reference cycles, drop refs to loop and protocol on connection_lost --- tests/selector_events_test.py | 14 ++++++++++++-- tests/unix_events_test.py | 31 +++++++++++++++++++++++++++++++ tulip/selector_events.py | 2 ++ tulip/test_utils.py | 6 +++--- tulip/unix_events.py | 16 +++++++++++++++- 5 files changed, 63 insertions(+), 6 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 5880c89f..1395bf9b 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -1,7 +1,10 @@ """Tests for selector_events.py""" import errno +import gc +import pprint import socket +import sys import unittest import unittest.mock try: @@ -624,12 +627,20 @@ def test_fatal_error(self, m_exc): tr._force_close.assert_called_with(exc) def test_connection_lost(self): - exc = object() + exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) tr._call_connection_lost(exc) self.protocol.connection_lost.assert_called_with(exc) self.sock.close.assert_called_with() + self.assertIsNone(tr._sock) + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) class SelectorSocketTransportTests(unittest.TestCase): @@ -980,7 +991,6 @@ def _make_one(self, create_waiter=None): transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) self.sock.reset_mock() - self.protocol.reset_mock() self.sslsock.reset_mock() self.sslcontext.reset_mock() self.loop.reset_counters() diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index d12d7dde..3d3e0415 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -1,8 +1,11 @@ """Tests for unix_events.py.""" +import gc import errno import io +import pprint import stat +import sys import tempfile import unittest import unittest.mock @@ -433,6 +436,13 @@ def test__call_connection_lost(self): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + def test__call_connection_lost_with_err(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -442,6 +452,13 @@ def test__call_connection_lost_with_err(self): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + class UnixWritePipeTransportTests(unittest.TestCase): @@ -687,6 +704,13 @@ def test__call_connection_lost(self): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + def test__call_connection_lost_with_err(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -696,6 +720,13 @@ def test__call_connection_lost_with_err(self): self.protocol.connection_lost.assert_called_with(err) self.pipe.close.assert_called_with() + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + def test_close(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index bfaedf59..99e73ece 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -362,6 +362,8 @@ def _call_connection_lost(self, exc): finally: self._sock.close() self._sock = None + self._protocol = None + self._loop = None class _SelectorSocketTransport(_SelectorTransport): diff --git a/tulip/test_utils.py b/tulip/test_utils.py index c6b8fe1c..05d5e6ab 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -312,13 +312,13 @@ def _response(self, response, body=None, headers=None, chunked=False): def make_test_protocol(base): - proto = unittest.mock.Mock(spec_set=base) + dct = {} for name in dir(base): if name.startswith('__') and name.endswith('__'): # skip magic names continue - getattr(proto, name).return_value = None - return proto + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() class TestSelector(selectors._BaseSelector): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 3e78da39..09af7da1 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -265,6 +265,9 @@ def _call_connection_lost(self, exc): self._protocol.connection_lost(exc) finally: self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None class _UnixWritePipeTransport(transports.WriteTransport): @@ -391,6 +394,9 @@ def _call_connection_lost(self, exc): self._protocol.connection_lost(exc) finally: self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None def pause_writing(self): if self._writing: @@ -547,4 +553,12 @@ def _try_finish(self): if all(p is not None and p.disconnected for p in self._pipes.values()): self._finished = True - self._loop.call_soon(self._protocol.connection_lost, None) + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None From fa0e9bdddb2872b1bb78827dc5ce6f1a6071235c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 24 Aug 2013 14:08:50 +0300 Subject: [PATCH 0564/1502] Remove obsolete subpocess_transport.py --- tests/subprocess_test.py | 62 ------------- tulip/subprocess_transport.py | 158 ---------------------------------- 2 files changed, 220 deletions(-) delete mode 100644 tests/subprocess_test.py delete mode 100644 tulip/subprocess_transport.py diff --git a/tests/subprocess_test.py b/tests/subprocess_test.py deleted file mode 100644 index bdf7ce3c..00000000 --- a/tests/subprocess_test.py +++ /dev/null @@ -1,62 +0,0 @@ -# NOTE: This is a hack. Andrew Svetlov is working in a proper -# subprocess management transport for use with -# connect_{read,write}_pipe(). - -"""Tests for subprocess_transport.py.""" - -import logging -import unittest - -from tulip import events -from tulip import futures -from tulip import protocols -from tulip import subprocess_transport - - -class MyProto(protocols.Protocol): - - def __init__(self, loop): - self.state = 'INITIAL' - self.nbytes = 0 - self.done = futures.Future(loop=loop) - - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' - transport.write_eof() - - def data_received(self, data): - logging.info('received: %r', data) - assert self.state == 'CONNECTED', self.state - self.nbytes += len(data) - - def eof_received(self): - assert self.state == 'CONNECTED', self.state - self.state = 'EOF' - self.transport.close() - - def connection_lost(self, exc): - assert self.state in ('CONNECTED', 'EOF'), self.state - self.state = 'CLOSED' - self.done.set_result(None) - - -class FutureTests(unittest.TestCase): - - def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_unix_subprocess(self): - p = MyProto(self.loop) - subprocess_transport.UnixSubprocessTransport(p, ['/bin/ls', '-lR'], - loop=self.loop) - self.loop.run_until_complete(p.done) - - -if __name__ == '__main__': - unittest.main() diff --git a/tulip/subprocess_transport.py b/tulip/subprocess_transport.py deleted file mode 100644 index cae33918..00000000 --- a/tulip/subprocess_transport.py +++ /dev/null @@ -1,158 +0,0 @@ -# NOTE: This is a hack. Andrew Svetlov is working in a proper -# subprocess management transport for use with -# connect_{read,write}_pipe(). - -import fcntl -import os -import traceback - -from . import transports -from . import events -from .log import tulip_log - - -class UnixSubprocessTransport(transports.Transport): - """Transport class managing a subprocess. - - TODO: Separate this into something that just handles pipe I/O, - and something else that handles pipe setup, fork, and exec. - """ - - def __init__(self, protocol, args, *, loop=None): - self._protocol = protocol # Not a factory! :-) - self._args = args # args[0] must be full path of binary. - if loop is None: - loop = events.get_event_loop() - self._event_loop = loop - self._buffer = [] - self._eof = False - rstdin, self._wstdin = os.pipe() - self._rstdout, wstdout = os.pipe() - - # TODO: This is incredibly naive. Should look at - # subprocess.py for all the precautions around fork/exec. - pid = os.fork() - if not pid: - # Child. - try: - os.dup2(rstdin, 0) - os.dup2(wstdout, 1) - # TODO: What to do with stderr? - os.execv(args[0], args) - except: - try: - traceback.print_traceback() - finally: - os._exit(127) - - # Parent. - os.close(rstdin) - os.close(wstdout) - _setnonblocking(self._wstdin) - _setnonblocking(self._rstdout) - self._event_loop.call_soon(self._protocol.connection_made, self) - self._event_loop.add_reader(self._rstdout, self._stdout_callback) - - def write(self, data): - assert not self._eof - assert isinstance(data, bytes), repr(data) - if not data: - return - - if not self._buffer: - # Attempt to write it right away first. - try: - n = os.write(self._wstdin, data) - except BlockingIOError: - pass - except Exception as exc: - self._fatal_error(exc) - return - else: - if n == len(data): - return - elif n: - data = data[n:] - self._event_loop.add_writer(self._wstdin, self._stdin_callback) - self._buffer.append(data) - - def write_eof(self): - assert not self._eof - assert self._wstdin >= 0 - self._eof = True - if not self._buffer: - self._event_loop.remove_writer(self._wstdin) - os.close(self._wstdin) - self._wstdin = -1 - self._maybe_cleanup() - - def close(self): - if not self._eof: - self.write_eof() - self._maybe_cleanup() - - def _fatal_error(self, exc): - tulip_log.error('Fatal error: %r', exc) - if self._rstdout >= 0: - os.close(self._rstdout) - self._rstdout = -1 - if self._wstdin >= 0: - os.close(self._wstdin) - self._wstdin = -1 - self._eof = True - self._buffer = None - self._maybe_cleanup(exc) - - _conn_lost_called = False - - def _maybe_cleanup(self, exc=None): - if (self._wstdin < 0 and - self._rstdout < 0 and - not self._conn_lost_called): - self._conn_lost_called = True - self._event_loop.call_soon(self._protocol.connection_lost, exc) - - def _stdin_callback(self): - data = b''.join(self._buffer) - assert data, "Data shold not be empty" - - self._buffer = [] - try: - n = os.write(self._wstdin, data) - except BlockingIOError: - self._buffer.append(data) - except Exception as exc: - self._fatal_error(exc) - else: - if n >= len(data): - self._event_loop.remove_writer(self._wstdin) - if self._eof: - os.close(self._wstdin) - self._wstdin = -1 - self._maybe_cleanup() - return - - elif n > 0: - data = data[n:] - - self._buffer.append(data) # Try again later. - - def _stdout_callback(self): - try: - data = os.read(self._rstdout, 1024) - except BlockingIOError: - pass - else: - if data: - self._event_loop.call_soon(self._protocol.data_received, data) - else: - self._event_loop.remove_reader(self._rstdout) - os.close(self._rstdout) - self._rstdout = -1 - self._event_loop.call_soon(self._protocol.eof_received) - self._maybe_cleanup() - - -def _setnonblocking(fd): - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) From 125a989ba79302bd592e283da5ee179eab1d90dc Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 24 Aug 2013 15:17:43 +0300 Subject: [PATCH 0565/1502] Make signal module mandatory for Unix event loop. --- tests/unix_events_test.py | 14 +------------- tulip/unix_events.py | 19 +++++-------------- 2 files changed, 6 insertions(+), 27 deletions(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 3d3e0415..f0b42a39 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -4,16 +4,13 @@ import errno import io import pprint +import signal import stat import sys import tempfile import unittest import unittest.mock -try: - import signal -except ImportError: - signal = None from tulip import events from tulip import futures @@ -38,15 +35,6 @@ def test_check_signal(self): self.assertRaises( ValueError, self.loop._check_signal, signal.NSIG + 1) - unix_events.signal = None - - def restore_signal(): - unix_events.signal = signal - self.addCleanup(restore_signal) - - self.assertRaises( - RuntimeError, self.loop._check_signal, signal.SIGINT) - def test_handle_signal_no_handler(self): self.loop._handle_signal(signal.NSIG + 1, ()) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 09af7da1..75131851 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -5,15 +5,12 @@ import fcntl import functools import os +import signal import socket import stat import subprocess import sys -try: - import signal -except ImportError: # pragma: no cover - signal = None from . import constants from . import events @@ -46,10 +43,9 @@ def _socketpair(self): return socket.socketpair() def close(self): - if signal is not None: - handler = self._signal_handlers.get(signal.SIGCHLD) - if handler is not None: - self.remove_signal_handler(signal.SIGCHLD) + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) super().close() def add_signal_handler(self, sig, callback, *args): @@ -137,9 +133,6 @@ def _check_signal(self, sig): if not isinstance(sig, int): raise TypeError('sig must be an int, not {!r}'.format(sig)) - if signal is None: - raise RuntimeError('Signals are not supported') - if not (1 <= sig < signal.NSIG): raise ValueError( 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) @@ -165,10 +158,8 @@ def _make_subprocess_transport(self, protocol, args, shell, return transp def _reg_sigchld(self): - assert signal, "signal support is required" if signal.SIGCHLD not in self._signal_handlers: - self.add_signal_handler(signal.SIGCHLD, - self._sig_chld) + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) def _sig_chld(self): try: From 76884f93f9d8d7be658c34508a839fe6171db1ac Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 26 Aug 2013 06:12:33 +0300 Subject: [PATCH 0566/1502] Add --forever option to runtests.py, fix sporadic test error --- runtests.py | 15 +++++++++++++-- tests/events_test.py | 18 +++++++++++++++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/runtests.py b/runtests.py index be7bbe4a..7c62bf45 100644 --- a/runtests.py +++ b/runtests.py @@ -47,6 +47,9 @@ ARGS.add_argument( '-c', '--catch', action="store_true", default=False, dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='Run tests forever to catch sporadic errors') ARGS.add_argument( '-q', action="store_true", dest='quiet', help='quiet') ARGS.add_argument( @@ -163,8 +166,16 @@ def runtests(): logger.setLevel(logging.DEBUG) if catchbreak: installHandler() - result = unittest.TextTestRunner(verbosity=v, failfast=failfast).run(tests) - sys.exit(not result.wasSuccessful()) + if args.forever: + while True: + result = unittest.TextTestRunner(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = unittest.TextTestRunner(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) def runcoverage(sdir, args): diff --git a/tests/events_test.py b/tests/events_test.py index 2431334f..1273d9dd 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -708,9 +708,21 @@ def connection_made(self, transport): super().connection_made(transport) f_proto.set_result(self) - port = find_unused_port() - f = self.loop.start_serving(TestMyProto, host=None, port=port) - socks = self.loop.run_until_complete(f) + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break client = socket.socket() client.connect(('127.0.0.1', port)) client.send(b'xxx') From b796492cf95a446935a9a653302606a3a37c3285 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 26 Aug 2013 13:46:22 +0300 Subject: [PATCH 0567/1502] Refactor tests for proactor transports. --- tests/proactor_events_test.py | 48 +++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 6b6de32f..da4dea35 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -7,24 +7,26 @@ import tulip from tulip.proactor_events import BaseProactorEventLoop from tulip.proactor_events import _ProactorSocketTransport +from tulip import test_utils class ProactorSocketTransportTests(unittest.TestCase): def setUp(self): - self.loop = unittest.mock.Mock() + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) self.sock = unittest.mock.Mock(socket.socket) - self.protocol = unittest.mock.Mock(tulip.Protocol) def test_ctor(self): fut = tulip.Future(loop=self.loop) tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol, fut) - self.loop.call_soon.mock_calls[0].assert_called_with(tr._loop_reading) - self.loop.call_soon.mock_calls[1].assert_called_with( - self.protocol.connection_made, tr) - self.loop.call_soon.mock_calls[2].assert_called_with( - fut.set_result, None) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) def test_loop_reading(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -163,12 +165,12 @@ def test_loop_writing_closing(self): fut.set_result(1) tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - self.loop.reset_mock() tr._write_fut = fut tr.close() tr._loop_writing(fut) self.assertIsNone(tr._write_fut) - self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) def test_abort(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -178,29 +180,30 @@ def test_abort(self): def test_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - self.loop.reset_mock() tr.close() - self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) self.assertTrue(tr._closing) self.assertEqual(tr._conn_lost, 1) - self.loop.reset_mock() + self.protocol.connection_lost.reset_mock() tr.close() - self.assertFalse(self.loop.call_soon.called) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) def test_close_write_fut(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._write_fut = unittest.mock.Mock() - self.loop.reset_mock() tr.close() - self.assertFalse(self.loop.call_soon.called) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) def test_close_buffer(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] - self.loop.reset_mock() tr.close() - self.assertFalse(self.loop.call_soon.called) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) @unittest.mock.patch('tulip.proactor_events.tulip_log') def test_fatal_error(self, m_logging): @@ -219,23 +222,25 @@ def test_force_close(self): read_fut.cancel.assert_called_with() write_fut.cancel.assert_called_with() - self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) self.assertEqual([], tr._buffer) self.assertEqual(tr._conn_lost, 1) def test_force_close_idempotent(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._closing = True - self.loop.reset_mock() tr._force_close(None) - self.assertFalse(self.loop.call_soon.called) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) def test_fatal_error_2(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] tr._force_close(None) - self.loop.call_soon.assert_called_with(tr._call_connection_lost, None) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) self.assertEqual([], tr._buffer) def test_call_connection_lost(self): @@ -376,7 +381,6 @@ def test_start_serving_cancel(self): loop = call_soon.call_args[0][0] # cancelled - self.sock.reset_mock() fut = tulip.Future(loop=self.loop) fut.cancel() loop(fut) From cc32b5f7cae091a1cbe1ad8effdfe8afc9cc48c5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 26 Aug 2013 09:32:10 -0700 Subject: [PATCH 0568/1502] Kill non-functional timeout=0 arg on result() and exception(). --- tests/futures_test.py | 2 -- tulip/futures.py | 11 +++-------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index b88c2a75..c7228c00 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -56,7 +56,6 @@ def test_cancel(self): def test_result(self): f = futures.Future(loop=self.loop) self.assertRaises(futures.InvalidStateError, f.result) - self.assertRaises(futures.InvalidTimeoutError, f.result, 10) f.set_result(42) self.assertFalse(f.cancelled()) @@ -71,7 +70,6 @@ def test_exception(self): exc = RuntimeError() f = futures.Future(loop=self.loop) self.assertRaises(futures.InvalidStateError, f.exception) - self.assertRaises(futures.InvalidTimeoutError, f.exception, 10) f.set_exception(exc) self.assertFalse(f.cancelled()) diff --git a/tulip/futures.py b/tulip/futures.py index 068f77ee..76a2bce3 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -226,16 +226,13 @@ def done(self): """ return self._state != _PENDING - def result(self, timeout=0): + def result(self): """Return the result this future represents. If the future has been cancelled, raises CancelledError. If the future's result isn't yet available, raises InvalidStateError. If the future is done and has an exception set, this exception is raised. - Timeout values other than 0 are not supported. """ - if timeout != 0: - raise InvalidTimeoutError if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: @@ -247,16 +244,14 @@ def result(self, timeout=0): raise self._exception return self._result - def exception(self, timeout=0): + def exception(self): """Return the exception that was set on this future. The exception (or None if no exception was set) is returned only if the future is done. If the future has been cancelled, raises CancelledError. If the future isn't done yet, raises - InvalidStateError. Timeout values other than 0 are not supported. + InvalidStateError. """ - if timeout != 0: - raise InvalidTimeoutError if self._state == _CANCELLED: raise CancelledError if self._state != _FINISHED: From 3d19a62c091e803e9d3ae7efd35ef96918981fb4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 26 Aug 2013 10:33:25 -0700 Subject: [PATCH 0569/1502] Kill unused timeout argument on internal _run_once() method. --- tests/base_events_test.py | 19 ------------------- tulip/base_events.py | 3 ++- 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 22679f79..104dc763 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -184,25 +184,6 @@ def test__run_once(self): self.assertEqual([h2], self.loop._scheduled) self.assertTrue(self.loop._process_events.called) - def test__run_once_timeout(self): - h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) - - self.loop._process_events = unittest.mock.Mock() - self.loop._scheduled.append(h) - self.loop._run_once(1.0) - self.assertEqual((1.0,), self.loop._selector.select.call_args[0]) - - def test__run_once_timeout_with_ready(self): - # If event loop has ready callbacks, select timeout is always 0. - h = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) - - self.loop._process_events = unittest.mock.Mock() - self.loop._scheduled.append(h) - self.loop._ready.append(h) - self.loop._run_once(1.0) - - self.assertEqual((0,), self.loop._selector.select.call_args[0]) - @unittest.mock.patch('tulip.base_events.time') @unittest.mock.patch('tulip.base_events.tulip_log') def test__run_once_logging(self, m_logging, m_time): diff --git a/tulip/base_events.py b/tulip/base_events.py index c3378da6..5ff2d3c9 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -553,7 +553,7 @@ def _add_callback_signalsafe(self, handle): self._add_callback(handle) self._write_to_self() - def _run_once(self, timeout=None): + def _run_once(self): """Run one full iteration of the event loop. This calls all currently ready callbacks, polls for I/O, @@ -564,6 +564,7 @@ def _run_once(self, timeout=None): while self._scheduled and self._scheduled[0]._cancelled: heapq.heappop(self._scheduled) + timeout = None if self._ready: timeout = 0 elif self._scheduled: From 189bcc99dd1612f253e38687e84062a05ebc89f6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 26 Aug 2013 10:39:48 -0700 Subject: [PATCH 0570/1502] Fold long line. --- tulip/futures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tulip/futures.py b/tulip/futures.py index 76a2bce3..8593e9ae 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -177,7 +177,8 @@ def __repr__(self): if self._timeout_handle is not None: dct['when'] = self._timeout_handle._when if dct: - res += '{' + ', '.join(k+'='+str(dct[k]) for k in sorted(dct)) + '}' + res += '{' + ', '.join('{}={}'.format(k, dct[k]) + for k in sorted(dct)) + '}' return res def cancel(self): From f565fcdcd1336636e65b983547bfa9a582b2f26c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 26 Aug 2013 23:00:45 +0300 Subject: [PATCH 0571/1502] Fix tests to pass on Python 3.3 debug build. --- .hgeol | 4 + .hgignore | 12 + Makefile | 35 + NOTES | 176 +++ README | 21 + TODO | 163 +++ check.py | 41 + examples/child_process.py | 127 +++ examples/crawl.py | 104 ++ examples/curl.py | 24 + examples/mpsrv.py | 289 +++++ examples/srv.py | 163 +++ examples/tcp_echo.py | 113 ++ examples/tcp_protocol_parser.py | 170 +++ examples/udp_echo.py | 98 ++ examples/websocket.html | 90 ++ examples/wsclient.py | 97 ++ examples/wssrv.py | 309 +++++ overlapped.c | 1009 ++++++++++++++++ runtests.py | 224 ++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 592 ++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1583 ++++++++++++++++++++++++++ tests/futures_test.py | 326 ++++++ tests/http_client_functional_test.py | 552 +++++++++ tests/http_client_test.py | 289 +++++ tests/http_parser_test.py | 539 +++++++++ tests/http_protocol_test.py | 394 +++++++ tests/http_server_test.py | 300 +++++ tests/http_session_test.py | 139 +++ tests/http_websocket_test.py | 439 +++++++ tests/http_wsgi_test.py | 301 +++++ tests/locks_test.py | 918 +++++++++++++++ tests/parsers_test.py | 598 ++++++++++ tests/proactor_events_test.py | 393 +++++++ tests/queues_test.py | 502 ++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1459 ++++++++++++++++++++++++ tests/selectors_test.py | 143 +++ tests/streams_test.py | 343 ++++++ tests/tasks_test.py | 1161 +++++++++++++++++++ tests/transports_test.py | 59 + tests/unix_events_test.py | 818 +++++++++++++ tests/windows_events_test.py | 81 ++ tests/windows_utils_test.py | 132 +++ tulip/TODO | 26 + tulip/__init__.py | 28 + tulip/base_events.py | 611 ++++++++++ tulip/constants.py | 4 + tulip/events.py | 393 +++++++ tulip/futures.py | 362 ++++++ tulip/http/__init__.py | 16 + tulip/http/client.py | 565 +++++++++ tulip/http/errors.py | 46 + tulip/http/protocol.py | 756 ++++++++++++ tulip/http/server.py | 215 ++++ tulip/http/session.py | 103 ++ tulip/http/websocket.py | 233 ++++ tulip/http/wsgi.py | 227 ++++ tulip/locks.py | 442 +++++++ tulip/log.py | 6 + tulip/parsers.py | 399 +++++++ tulip/proactor_events.py | 288 +++++ tulip/protocols.py | 100 ++ tulip/queues.py | 298 +++++ tulip/selector_events.py | 671 +++++++++++ tulip/selectors.py | 426 +++++++ tulip/streams.py | 211 ++++ tulip/tasks.py | 359 ++++++ tulip/test_utils.py | 443 +++++++ tulip/transports.py | 201 ++++ tulip/unix_events.py | 555 +++++++++ tulip/windows_events.py | 206 ++++ tulip/windows_utils.py | 181 +++ 79 files changed, 23767 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100755 examples/crawl.py create mode 100755 examples/curl.py create mode 100755 examples/mpsrv.py create mode 100755 examples/srv.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/tcp_protocol_parser.py create mode 100755 examples/udp_echo.py create mode 100644 examples/websocket.html create mode 100755 examples/wsclient.py create mode 100755 examples/wssrv.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_parser_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_session_test.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/parsers_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py create mode 100644 tulip/TODO create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/session.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/parsers.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/windows_utils.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..11fe52ca --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See README for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..85bfe5a7 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The old code lives in the subdirectory 'old'; the new code (conforming +to PEP 3156, under construction) lives in the 'tulip' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..9ab6bcc0 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..d4a035bd --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100755 index 00000000..ac9c25e9 --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + + self.session.close() + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request( + 'get', url, session=self.session) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/examples/curl.py b/examples/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/examples/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100755 index 00000000..6b1ebb8f --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) + x = loop.run_until_complete(f)[0] + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/examples/srv.py b/examples/srv.py new file mode 100755 index 00000000..e01e407c --- /dev/null +++ b/examples/srv.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('method = {!r}; path = {!r}; version = {!r}'.format( + message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, + ssl=sslcontext) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..a0258613 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100755 index 00000000..f5b2ef58 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""websocket cmd client for wssrv.py example.""" +import argparse +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket +import tulip.selectors + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop, url): + name = input('Please enter your name: ').encode() + + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) + + # input reader + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_TEXT: + print(msg.data.strip()) + elif msg.tp == websocket.MSG_CLOSE: + break + + yield from dispatch() + + +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + + loop.add_signal_handler(signal.SIGINT, loop.stop) + tulip.Task(start_client(loop, url)) + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100755 index 00000000..f96e0855 --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message.method, message.headers, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + socks = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, keep_alive=75, + parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), socks[0].getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..3a2c1208 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1009 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..7c62bf45 --- /dev/null +++ b/runtests.py @@ -0,0 +1,224 @@ +"""Run all unittests. + +Usage: + python3 runtests.py [-v] [-q] [pattern] ... + +Where: + -v: verbose + -q: quiet + pattern: optional regex patterns to match test ids (default all tests) + +Note that the test id is the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +runtests.py with --coverage argument is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='Run tests forever to catch sporadic errors') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + if args.forever: + while True: + result = unittest.TextTestRunner(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = unittest.TextTestRunner(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + * curl -O \ + https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + - python3 ez_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: {}\n".format(sdir)) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..104dc763 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,592 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + self.loop._selector.registered_count.return_value = 1 + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + task = tasks.Task( + self.loop.create_connection(MyProto, 'example.com', 80)) + yield from tasks.wait(task) + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_start_serving_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.start_serving(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_start_serving_host_port_sock(self): + fut = self.loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..16058f19 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1583 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils +from tulip import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.EventWaiter(loop=loop), + 2: locks.EventWaiter(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_run_until_complete_timeout(self): + t0 = self.loop.time() + task = tasks.async(tasks.sleep(0.2, loop=self.loop), loop=self.loop) + self.assertRaises(futures.TimeoutError, + self.loop.run_until_complete, + task, timeout=0.1) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + self.loop.run_until_complete(task) + t2 = self.loop.time() + self.assertTrue(0.18 <= t2-t0 <= 0.22, t2-t0) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.loop, use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server(self.loop) as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('socket').getsockname()[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.start_serving(factory, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close start_serving socks + self.loop.stop_serving(sock) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + self.loop.stop_serving(sock) + + def test_start_serving_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + def test_start_serving_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(MyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + f = self.loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + self.loop.stop_serving(sock) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + for s in socks: + self.loop.stop_serving(s) + + def test_stop_serving(self): + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + sock = socks[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), timeout=1, loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop.stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait(1)) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait(1)) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait(1)) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.start_serving, f) + self.assertRaises( + NotImplementedError, loop.stop_serving, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('tulip.events.threading') + def test_get_event_loop_thread(self, m_threading): + m_t = m_threading.current_thread.return_value = unittest.mock.Mock() + m_t.name = 'Thread 1' + + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..c7228c00 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,326 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + f_pending = futures.Future(loop=self.loop, timeout=10) + self.assertEqual('Future{timeout=10, when=10}', + repr(f_pending)) + f_pending.cancel() + + f_pending = futures.Future(loop=self.loop, timeout=10) + f_pending.cancel() + self.assertEqual('Future{timeout=10}', repr(f_pending)) + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, *args): + fn(*args) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..91badfc4 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,552 @@ +"""Http client functional tests.""" + +import gc +import io +import os.path +import http.cookies +import unittest + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth), + loop=self.loop)) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + r.close() + + def test_use_global_loop(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + try: + tulip.set_event_loop(self.loop) + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'))) + finally: + tulip.set_event_loop(None) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "GET"', content) + self.assertEqual(content1, content2) + r.close() + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2), + loop=self.loop)) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'), + loop=self.loop)) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2, loop=self.loop)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'}, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'}, + loop=self.loop)) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate', + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'), + loop=self.loop)) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'), + loop=self.loop)) + self.assertEqual(r.status, 200) + r.close() + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), loop=self.loop, + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + r.close() + + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), loop=self.loop)) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + resp.close() + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'), loop=self.loop)) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + r.close() + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + timeout=0.1, loop=self.loop)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', + timeout=0.1, loop=self.loop)) + + def test_request_conn_closed(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['close'] = True + self.assertRaises( + tulip.http.HttpException, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + loop=self.loop)) + + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..1aa27244 --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpRequest, HttpResponse + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response.transport = self.transport + self.response.close() + self.assertIsNone(self.response.transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..6240ad49 --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,539 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_eof(self): + # http_request_parser does not fail on EofStream() + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'get /path HTTP/1.1\r\n') + try: + p.throw(tulip.EofStream()) + except StopIteration: + pass + self.assertFalse(out._buffer) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..e74b8f27 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,394 @@ +"""Tests for http/protocol.py""" + +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + tulip.set_event_loop(None) + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200, close=True) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_keep_alive_http10(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + self.assertFalse(msg.keepalive) + self.assertFalse(msg.keep_alive()) + + msg = protocol.Response(self.transport, 200, http_version=(1, 1)) + self.assertIsNone(msg.keepalive) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], list(msg.headers)) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], list(msg.headers)) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..862779b9 --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,300 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip import test_utils + + +class HttpServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_http_error_exception(self): + exc = errors.HttpErrorException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertIsNone(srv._request_handler) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handler) + + def test_data_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', bytes(srv.stream._buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', bytes(srv.stream._buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream._eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + + handle = srv._request_handler + srv.connection_lost(None) + + self.assertIsNone(srv._request_handler) + self.assertTrue(handle.cancelled()) + + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(keep_alive_handle.cancel.called) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handler) + self.assertIsNone(srv._keep_alive_handle) + + def test_srv_keep_alive(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertFalse(srv._keep_alive) + + srv.keep_alive(True) + self.assertTrue(srv._keep_alive) + + srv.keep_alive(False) + self.assertFalse(srv._keep_alive) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.keep_alive(True) + + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) + self.assertFalse(srv._keep_alive) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, loop=self.loop) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + + called = False + + @tulip.coroutine + def coro(message, payload): + nonlocal called + called = True + srv.eof_received() + + srv.handle_request = coro + srv.connection_made(transport) + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.coroutine + def cancel(): + srv._request_handler.cancel() + + self.loop.run_until_complete( + tulip.wait([srv._request_handler, cancel()], loop=self.loop)) + self.assertTrue(log.debug.called) + + def test_handle_cancelled(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + test_utils.run_briefly(self.loop) # start request_handler task + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + r_handler = srv._request_handler + srv._request_handler = None # emulate srv.connection_lost() + + self.assertIsNone(self.loop.run_until_complete(r_handler)) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + srv.keep_alive(True) + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.keep_alive(True) + srv.connection_made(transport) + srv.connection_lost(None) + + srv.handle_error(300) + self.assertFalse(srv._keep_alive) + + def test_keep_alive(self): + srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) + transport = unittest.mock.Mock() + closed = False + + def close(): + nonlocal closed + closed = True + srv.connection_lost(None) + self.loop.stop() + + transport.close = close + + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.1\r\n' + b'CONNECTION: keep-alive\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_forever() + self.assertTrue(handle.called) + self.assertTrue(closed) + + def test_keep_alive_close_existing(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) + srv.connection_made(transport) + + self.assertIsNone(srv._keep_alive_handle) + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(keep_alive_handle.cancel.called) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(transport.close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..39a80091 --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,139 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + +from tulip import test_utils + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + tulip.set_event_loop(None) + self.loop.close() + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..319538ae --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,439 @@ +"""Tests for http/websocket.py""" + +import base64 +import hashlib +import os +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket, protocol, errors + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_not_get(self): + self.assertRaises( + errors.HttpErrorException, + websocket.do_handshake, + 'POST', self.message.headers, self.transport) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message.method, self.message.headers, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..053f5a69 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,301 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol + + +class HttpWsgiServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() + + def tearDown(self): + self.loop.close() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + @unittest.mock.patch('tulip.http.wsgi.tulip') + def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '101 Switching Protocols', (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'))) + self.assertEqual(resp.status, '101 Switching Protocols') + self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future(loop=self.loop) + f1.set_result(b'data') + fut = tulip.Future(loop=self.loop) + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertFalse(srv._keep_alive) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_keep_alive(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, False, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertTrue(srv._keep_alive) + + def test_handle_request_readpayload(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [env['wsgi.input'].read()] + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..83663ec0 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,918 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + lock = locks.Lock(loop=loop) + + self.assertTrue(loop.run_until_complete(lock.acquire())) + + acquired = loop.run_until_complete(lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + lock = locks.Lock(loop=loop) + self.loop.run_until_complete(lock.acquire()) + + loop.call_soon(lock.release) + acquired = loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + def test_acquire_timeout_mixed(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + lock = locks.Lock(loop=loop) + loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire(), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + acquire_task = tasks.Task(lock.acquire(0.01), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + self.assertEqual(3, len(lock._waiters)) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.EventWaiter(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.EventWaiter(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.EventWaiter() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.EventWaiter(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + ev = locks.EventWaiter(loop=loop) + + res = loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + self.assertAlmostEqual(0.1, loop.time()) + + ev = locks.EventWaiter(loop=loop) + loop.call_later(0.01, ev.set) + acquired = loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) + + def test_wait_timeout_mixed(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + ev = locks.EventWaiter(loop=loop) + tasks.Task(ev.wait(), loop=loop) + tasks.Task(ev.wait(), loop=loop) + acquire_task = tasks.Task(ev.wait(0.1), loop=loop) + tasks.Task(ev.wait(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(3, len(ev._waiters)) + + def test_wait_cancel(self): + ev = locks.EventWaiter(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) + loop.run_until_complete(cond.acquire()) + + wait = loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) + + result = [] + + predicate = unittest.mock.Mock(return_value=False) + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.1)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result), loop=loop) + + test_utils.run_briefly(loop) + self.assertEqual([], result) + + loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(loop) + self.assertEqual([], result) + + loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + self.assertAlmostEqual(0.1, loop.time()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + def test_acquire_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + + acquired = loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + + loop.call_later(0.01, sem.release) + acquired = loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) + + def test_acquire_timeout_mixed(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire(), loop=loop) + tasks.Task(sem.acquire(), loop=loop) + acquire_task = tasks.Task(sem.acquire(0.1), loop=loop) + tasks.Task(sem.acquire(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + self.assertAlmostEqual(0.1, loop.time()) + + self.assertEqual(3, len(sem._waiters)) + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..debc532c --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,598 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer(loop=self.loop) + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer(loop=self.loop) + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.skipuntil(b'\n') + try: + next(p) + except StopIteration: + pass + self.assertEqual(b'', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..da4dea35 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,393 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport +from tulip import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = tulip.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_start_serving_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = tulip.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop.stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor.stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..4d4876b9 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,502 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import queues +from tulip import tasks +from tulip import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith('", repr(key)) + + def test_register(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fd, 10) + self.assertIs(key, s._fd_to_key[10]) + + def test_register_unknown_event(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + + def test_register_already_registered(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + + def test_unregister(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + s.register(fobj, selectors.EVENT_READ) + s.unregister(fobj) + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_unregister_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + + def test_modify_unknown(self): + s = selectors._BaseSelector() + self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + + def test_modify(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ) + key2 = s.modify(fobj, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + + def test_modify_data(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + + def test_modify_same(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + data = object() + + s = selectors._BaseSelector() + key = s.register(fobj, selectors.EVENT_READ, data) + key2 = s.modify(fobj, selectors.EVENT_READ, data) + self.assertIs(key, key2) + + def test_select(self): + s = selectors._BaseSelector() + self.assertRaises(NotImplementedError, s.select) + + def test_close(self): + s = selectors._BaseSelector() + s.register(1, selectors.EVENT_READ) + + s.close() + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + def test_registered_count(self): + s = selectors._BaseSelector() + self.assertEqual(0, s.registered_count()) + + s.register(1, selectors.EVENT_READ) + self.assertEqual(1, s.registered_count()) + + s.unregister(1) + self.assertEqual(0, s.registered_count()) + + def test_context_manager(self): + s = selectors._BaseSelector() + + with s as sel: + sel.register(1, selectors.EVENT_READ) + + self.assertFalse(s._fd_to_key) + self.assertFalse(s._fileobj_to_key) + + @unittest.mock.patch('tulip.selectors.tulip_log') + def test_key_from_fd(self, m_log): + s = selectors._BaseSelector() + key = s.register(1, selectors.EVENT_READ) + + self.assertIs(key, s._key_from_fd(1)) + self.assertIsNone(s._key_from_fd(10)) + m_log.warning.assert_called_with('No key found for fd %r', 10) + + if hasattr(selectors.DefaultSelector, 'fileno'): + def test_fileno(self): + self.assertIsInstance(selectors.DefaultSelector().fileno(), int) diff --git a/tests/streams_test.py b/tests/streams_test.py new file mode 100644 index 00000000..2267a0f5 --- /dev/null +++ b/tests/streams_test.py @@ -0,0 +1,343 @@ +"""Tests for streams.py.""" + +import gc +import ssl +import unittest +import unittest.mock + +from tulip import events +from tulip import streams +from tulip import tasks +from tulip import test_utils + + +class StreamReaderTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + @unittest.mock.patch('tulip.streams.events') + def test_ctor_global_loop(self, m_events): + stream = streams.StreamReader() + self.assertIs(stream.loop, m_events.get_event_loop.return_value) + + def test_open_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_open_connection_no_loop_ssl(self): + with test_utils.run_test_server(self.loop, use_ssl=True) as httpd: + try: + events.set_event_loop(self.loop) + f = streams.open_connection(*httpd.address, ssl=True) + reader, writer = self.loop.run_until_complete(f) + finally: + events.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() + + def test_open_connection_error(self): + with test_utils.run_test_server(self.loop) as httpd: + f = streams.open_connection(*httpd.address, loop=self.loop) + reader, writer = self.loop.run_until_complete(f) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + + writer.close() + + def test_feed_empty_data(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(b'') + self.assertEqual(0, stream.byte_count) + + def test_feed_data_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + + stream.feed_data(self.DATA) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read_zero(self): + # Read zero bytes. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.read(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_read(self): + # Read bytes. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(30), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_read_line_breaks(self): + # Read bytes without line breaks. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line1') + stream.feed_data(b'line2') + + data = self.loop.run_until_complete(stream.read(5)) + + self.assertEqual(b'line1', data) + self.assertEqual(5, stream.byte_count) + + def test_read_eof(self): + # Read bytes, stop at eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(1024), loop=self.loop) + + def cb(): + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(b'', data) + self.assertFalse(stream.byte_count) + + def test_read_until_eof(self): + # Read all bytes until eof. + stream = streams.StreamReader(loop=self.loop) + read_task = tasks.Task(stream.read(-1), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1\n') + stream.feed_data(b'chunk2') + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + + self.assertEqual(b'chunk1\nchunk2', data) + self.assertFalse(stream.byte_count) + + def test_read_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.read(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.read(2)) + + def test_readline(self): + # Read one line. + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'chunk1 ') + read_task = tasks.Task(stream.readline(), loop=self.loop) + + def cb(): + stream.feed_data(b'chunk2 ') + stream.feed_data(b'chunk3 ') + stream.feed_data(b'\n chunk4') + self.loop.call_soon(cb) + + line = self.loop.run_until_complete(read_task) + self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) + self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + + def test_readline_limit_with_existing_data(self): + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1\nline2\n') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'line2\n'], list(stream.buffer)) + + stream = streams.StreamReader(3, loop=self.loop) + stream.feed_data(b'li') + stream.feed_data(b'ne1') + stream.feed_data(b'li') + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'li'], list(stream.buffer)) + self.assertEqual(2, stream.byte_count) + + def test_readline_limit(self): + stream = streams.StreamReader(7, loop=self.loop) + + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual([b'chunk3\n'], list(stream.buffer)) + self.assertEqual(7, stream.byte_count) + + def test_readline_line_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA[:6]) + stream.feed_data(self.DATA[6:]) + + line = self.loop.run_until_complete(stream.readline()) + + self.assertEqual(b'line1\n', line) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + + def test_readline_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'some data') + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'some data', line) + + def test_readline_empty_eof(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_eof() + + line = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'', line) + + def test_readline_read_byte_count(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + self.loop.run_until_complete(stream.readline()) + + data = self.loop.run_until_complete(stream.read(7)) + + self.assertEqual(b'line2\nl', data) + self.assertEqual( + len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), + stream.byte_count) + + def test_readline_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readline()) + self.assertEqual(b'line\n', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + + def test_readexactly_zero_or_less(self): + # Read exact number of bytes (zero or less). + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(self.DATA) + + data = self.loop.run_until_complete(stream.readexactly(0)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + data = self.loop.run_until_complete(stream.readexactly(-1)) + self.assertEqual(b'', data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly(self): + # Read exact number of bytes. + stream = streams.StreamReader(loop=self.loop) + + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + stream.feed_data(self.DATA) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA + self.DATA, data) + self.assertEqual(len(self.DATA), stream.byte_count) + + def test_readexactly_eof(self): + # Read exact number of bytes (eof). + stream = streams.StreamReader(loop=self.loop) + n = 2 * len(self.DATA) + read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + + def cb(): + stream.feed_data(self.DATA) + stream.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertEqual(self.DATA, data) + self.assertFalse(stream.byte_count) + + def test_readexactly_exception(self): + stream = streams.StreamReader(loop=self.loop) + stream.feed_data(b'line\n') + + data = self.loop.run_until_complete(stream.readexactly(2)) + self.assertEqual(b'li', data) + + stream.set_exception(ValueError()) + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readexactly(2)) + + def test_exception(self): + stream = streams.StreamReader(loop=self.loop) + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def set_err(): + stream.set_exception(ValueError()) + + @tasks.coroutine + def readline(): + yield from stream.readline() + + t1 = tasks.Task(stream.readline(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/tasks_test.py b/tests/tasks_test.py new file mode 100644 index 00000000..56a5e128 --- /dev/null +++ b/tests/tasks_test.py @@ -0,0 +1,1161 @@ +"""Tests for tasks.py.""" + +import gc +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import tasks +from tulip import test_utils + + +class Dummy: + + def __repr__(self): + return 'Dummy()' + + def __call__(self, *args): + pass + + +class TaskTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + gc.collect() + + def test_task_class(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_func(self): + @tasks.task + def notmuch(): + return 'ko' + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_fut(self): + @tasks.task + def notmuch(): + fut = futures.Future(loop=self.loop) + fut.set_result('ko') + return fut + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_async_coroutine(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t = tasks.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = events.new_event_loop() + t = tasks.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.close() + + def test_async_future(self): + f_orig = futures.Future(loop=self.loop) + f_orig.set_result('ko') + + f = tasks.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + f = tasks.async(f_orig, loop=loop) + + loop.close() + + f = tasks.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @tasks.coroutine + def notmuch(): + return 'ok' + t_orig = tasks.Task(notmuch(), loop=self.loop) + t = tasks.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = events.new_event_loop() + + with self.assertRaises(ValueError): + t = tasks.async(t_orig, loop=loop) + + loop.close() + + t = tasks.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + tasks.async('ok') + + def test_task_repr(self): + @tasks.coroutine + def notmuch(): + yield from [] + return 'abc' + + t = tasks.Task(notmuch(), loop=self.loop) + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), 'Task()') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro(), loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_done_future(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + yield from fut3 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + fut1.set_result(None) + t.cancel() + test_utils.run_once(self.loop) # process fut1 result, delay cancel + self.assertFalse(t.done()) + test_utils.run_once(self.loop) # cancel fut2, but coro still alive + self.assertFalse(t.done()) + test_utils.run_briefly(self.loop) # cancel fut3 + self.assertTrue(t.done()) + + self.assertEqual(fut1.result(), None) + self.assertTrue(fut2.cancelled()) + self.assertTrue(fut3.cancelled()) + self.assertTrue(t.cancelled()) + + def test_future_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(coro(), timeout=0.1, loop=loop) + + self.assertRaises( + futures.CancelledError, + loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_future_timeout_catch(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + class Cancelled(Exception): + pass + + @tasks.coroutine + def coro2(): + try: + yield from tasks.Task(coro(), timeout=0.1, loop=loop) + except futures.CancelledError: + raise Cancelled() + + self.assertRaises( + Cancelled, loop.run_until_complete, coro2()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + return 12 + + t = tasks.Task(task(), loop=self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + yield from tasks.sleep(0.1, loop=loop) + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + def test_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 42 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.TimeoutError, loop.run_until_complete, t, 0.1) + self.assertAlmostEqual(0.1, loop.time()) + self.assertFalse(t.done()) + + def test_timeout_not(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(0.1, loop=loop) + return 42 + + t = tasks.Task(task(), loop=loop) + r = loop.run_until_complete(t, 10.0) + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertAlmostEqual(0.1, loop.time()) + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + loop + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + try: + yield from fut + except futures.CancelledError: + pass + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch(), loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise BaseException() + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertTrue(task.cancelled()) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.loop.run_until_complete(task) + + self.assertTrue(fut.done()) + self.assertIs(fut.exception(), cm.exception) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..5920cda6 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,59 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import futures +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + self.assertRaises(NotImplementedError, transport.pause_writing) + self.assertRaises(NotImplementedError, transport.resume_writing) + self.assertRaises(NotImplementedError, transport.discard_output) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..f0b42a39 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,818 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import tempfile +import unittest +import unittest.mock + + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + def test_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_double_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_pause_resume_writing_with_nonempty_buffer(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.assertFalse(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + + tr.resume_writing() + self.assertTrue(tr._writing) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'da', b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_on_pause(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + + tr._write_ready() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + self.assertFalse(tr._writing) + + def test_discard_output(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + self.loop.add_writer(5, tr._write_ready) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + def test_discard_output_without_pending_writes(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..ce9b74da --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,81 @@ +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams + + +def connect_read_pipe(loop, file): + stream_reader = streams.StreamReader(loop=loop) + protocol = _StreamReaderProtocol(stream_reader) + loop._make_read_pipe_transport(file, protocol) + return stream_reader + + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_pause_resume_discard(self): + a, b = self.loop._socketpair() + trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) + reader = connect_read_pipe(self.loop, b) + f = tulip.async(reader.readline(), loop=self.loop) + + trans.write(b'msg1\n') + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg1\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg2\n') + with self.assertRaises(tulip.TimeoutError): + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(trans._buffer, [b'msg2\n']) + + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.1) + self.assertEqual(f.result(), b'msg2\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg3\n') + self.assertEqual(trans._buffer, [b'msg3\n']) + trans.discard_output() + self.assertEqual(trans._buffer, []) + + trans.write(b'msg4\n') + self.assertEqual(trans._buffer, [b'msg4\n']) + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg4\n') + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f, timeout=1) + self.assertEqual(f.result(), b'') diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..b23896d3 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils +from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tulip/TODO b/tulip/TODO new file mode 100644 index 00000000..b3a9302e --- /dev/null +++ b/tulip/TODO @@ -0,0 +1,26 @@ +TODO in tulip v2 (tulip/ package directory) +------------------------------------------- + +- See also TBD and Open Issues in PEP 3156 + +- Refactor unix_events.py (it's getting too long) + +- Docstrings + +- Unittests + +- better run_once() behavior? (Run ready list last.) + +- start_serving() + +- Make Handler() callable? Log the exception in there? + +- Add the repeat interval to the Handler class? + +- Recognize Handler passed to add_reader(), call_soon(), etc.? + +- SSL support + +- buffered stream implementation + +- Primitives like par() and wait_one() diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..9de84cb0 --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,28 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .parsers import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + parsers.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..5ff2d3c9 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,611 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + handle_called = False + + if timeout is None: + self.run_forever() + else: + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + + future.remove_done_callback(_raise_stop_error) + + if handle_called: + raise futures.TimeoutError + + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # This returns a Task made from self._start_serving_internal(). + # We want start_serving() to return a Task so that it will start + # running right away (when the event loop runs) even if the caller + # doesn't wait for it. Note that this is different from + # e.g. create_connection(), or create_datagram_endpoint(), which + # are a "mere" coroutines and require their caller to wait for + # them. The reason for the difference is that only + # start_serving() creates multiple transports and protocols. + def start_serving(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + coro = self._start_serving_internal(protocol_factory, host, port, + family=family, + flags=flags, + sock=sock, + backlog=backlog, + ssl=ssl, + reuse_address=reuse_address) + return tasks.Task(coro, loop=self) + + @tasks.coroutine + def _start_serving_internal(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..37b95594 --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,393 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future, timeout=None): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """Creates a TCP server bound to host and port and return a + Task whose result will be a list of socket objects which will + later be handled by protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + threading.current_thread().name == 'MainThread'): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..8593e9ae --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,362 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', 'InvalidTimeoutError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import tulip_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + tulip_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _timeout = None + _timeout_handle = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None, timeout=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + if timeout is not None: + self._timeout = timeout + self._timeout_handle = self._loop.call_later(timeout, self.cancel) + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + dct = {} + if self._timeout is not None: + dct['timeout'] = self._timeout + if self._timeout_handle is not None: + dct['when'] = self._timeout_handle._when + if dct: + res += '{' + ', '.join('{}={}'.format(k, dct[k]) + for k in sorted(dct)) + '}' + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..a1432dee --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,16 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .session import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + session.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..2aedfdd1 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,565 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import functools +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +import tulip.http + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None, + session=None, + loop=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. + loop: Optional event loop. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + if loop is None: + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + + # connection timeout + try: + resp = yield from tulip.Task(conn, timeout=timeout, loop=loop) + except tulip.CancelledError: + raise tulip.TimeoutError from None + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + + try: + resp = req.send(transport) + yield from resp.start(p, transport) + except: + transport.close() + raise + + return resp + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except ValueError: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except ValueError: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + self.update_cookies(cookies) + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + compress = enc + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = str(len(self.body)) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['transfer-encoding'] = 'chunked' + + chunked = chunked if type(chunked) is int else 8196 + else: + if 'chunked' in te: + chunked = 8196 + else: + chunked = None + self.headers['content-length'] = str(len(self.body)) + + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + message = None # RawResponseMessage object + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + cookies = None # Response cookies (Set-Cookie) + + content = None # payload stream + stream = None # input stream + transport = None # current transport + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + self._content = None + + def __del__(self): + self.close() + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self.stream = stream + self.transport = transport + + httpstream = stream.set_parser(tulip.http.http_response_parser()) + + # read response + self.message = yield from httpstream.read() + + # response status + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason + + # headers + for hdr, val in self.message.headers: + self.add_header(hdr, val) + + # payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) + + return self + + def close(self): + if self.transport is not None: + self.transport.close() + self.transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..f8b77e9b --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,46 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + message = '' + + +class HttpErrorException(HttpException): + + def __init__(self, code, message='', headers=None): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + message = 'Bad Request' + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..7081fd59 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,756 @@ +"""Http related helper utils.""" + +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from tulip.http import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() + +RESPONSES = http.server.BaseHTTPRequestHandler.responses + + +RawRequestMessage = collections.namedtuple( + 'RawRequestMessage', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) + + +RawResponseMessage = collections.namedtuple( + 'RawResponseMessage', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) + + +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield + + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + # request line + line = lines[0] + try: + method, path, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if version <= (1, 0): + close = True + elif close is None: + close = False + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None + + +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() + + lines_idx = 1 + line = lines[1] + + while line not in ('\r\n', '\n'): + header_length = len(line) + + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None + + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line[0] in CONTINUATION + + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') + + value = value.strip() + + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + headers.append((name, value)) + + return headers, close_conn, encoding + + +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield + + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) + + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) + + out.feed_eof() + + +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: # eof marker + break + + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) + + # toss the CRLF at the end of the chunk + yield from buf.skip(2) + + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) + + +class DeflateBuffer: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, out, encoding): + self.out = out + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except Exception: + raise errors.IncompleteRead(b'') from None + + if chunk: + self.out.feed_data(chunk) + + def feed_eof(self): + self.out.feed_data(self.zlib.flush()) + if not self.zlib.eof: + raise errors.IncompleteRead(b'') + + self.out.feed_eof() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + upgrade = False # Connection: UPGRADE + websocket = False # Upgrade: WEBSOCKET + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + + # disable keep-alive for http/1.0 + if version <= (1, 0): + self.keepalive = False + else: + self.keepalive = None + + self.chunked = False + self.length = None + self.headers = collections.deque() + self.headers_sent = False + + def force_close(self): + self.closing = True + self.keepalive = False + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + if self.keepalive is None: + return not self.closing + else: + return self.keepalive + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower() + # handle websocket + if 'upgrade' in val: + self.upgrade = True + # connection keep-alive + elif 'close' in val: + self.keepalive = False + elif 'keep-alive' in val and self.version >= (1, 1): + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.websocket = True + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + self._add_default_headers() + + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) + + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif not self.closing if self.keepalive is None else self.keepalive: + connection = 'keep-alive' + else: + connection = 'close' + + if self.chunked: + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) + + self.headers.appendleft(('CONNECTION', connection)) + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(tulip.EofStream()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except tulip.EofStream: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except tulip.EofStream: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except tulip.EofStream: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, path, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.path = path + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, path, http_version) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..fc5621c5 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,215 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +from tulip.http import errors + + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + log: custom logging object + debug: enable debug mode + keep_alive: number of seconds before closing keep alive connection + loop: event loop object + """ + _request_count = 0 + _request_handler = None + _keep_alive = False # keep transport open + _keep_alive_handle = None # keep alive timer handle + + def __init__(self, *, log=logging, debug=False, + keep_alive=None, loop=None, **kwargs): + self.__dict__.update(kwargs) + self.log = log + self.debug = debug + + self._keep_alive_period = keep_alive # number of seconds to keep alive + + if keep_alive and loop is None: + loop = tulip.get_event_loop() + self._loop = loop + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.StreamBuffer(loop=self._loop) + self._request_handler = tulip.Task(self.start(), loop=self._loop) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + self.stream.feed_eof() + + if self._request_handler is not None: + self._request_handler.cancel() + self._request_handler = None + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + def keep_alive(self, val): + self._keep_alive = val + + def log_access(self, status, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.coroutine + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. + """ + + while True: + info = None + message = None + self._request_count += 1 + self._keep_alive = False + + try: + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) + + message = yield from httpstream.read() + + # cancel keep-alive timer + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._request_handler: + if self._keep_alive and self._keep_alive_period: + self._keep_alive_handle = self._loop.call_later( + self._keep_alive_period, self.transport.close) + else: + self.transport.close() + self._request_handler = None + break + else: + break + + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + try: + if self._request_handler is None: + # client has been disconnected during writing. + return + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.keep_alive(False) + + def handle_request(self, message, payload): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=message.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.keep_alive(False) + self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..9cdd9cea --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,103 @@ +"""client session support.""" + +__all__ = ['Session'] + +import functools +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..c3dd5872 --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,233 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import base64 +import binascii +import collections +import hashlib +import struct +from tulip.http import errors + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) + + +def do_handshake(method, headers, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + perform any IO.""" + + # WebSocket accepts only GET + if method.upper() != 'GET': + raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) + + headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException( + 'No WebSocket UPGRADE hdr: {}\n' + 'Can "Upgrade" only to "WebSocket".'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8', '7'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..738e100f --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,227 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) + + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, message, payload): + """Handle a single HTTP request""" + + if self.readpayload: + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + wsgiinput.seek(0) + payload = wsgiinput + + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if resp.keep_alive(): + self.keep_alive(True) + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, message): + self.transport = transport + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + resp = self.response = tulip.http.Response( + self.transport, status_code, + self.message.version, self.message.should_close) + resp.add_headers(*headers) + + # send headers immediately for websocket connection + if status_code == 101 and resp.upgrade and resp.websocket: + resp.send_headers() + else: + resp._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..622a499b --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,442 @@ +"""Synchronization primitives""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock object could be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], 'locked' if self._locked else 'unlocked') + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a lock. + + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + if self._waiters: + self._waiters[0].set_result(True) + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True + + +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + + self._condition_waiters = collections.deque() + + @tasks.coroutine + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). + + The return value is True unless a given timeout expired, in which + case it is False. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + + fut = futures.Future(loop=self._loop, timeout=timeout) + + self._condition_waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False + else: + f = self._condition_waiters.popleft() + assert fut is f + finally: + yield from self.acquire() + + return True + + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. + """ + endtime = None + waittime = timeout + result = predicate() + + while not result: + if waittime is not None: + if endtime is None: + endtime = self._loop.time() + waittime + else: + waittime = endtime - self._loop.time() + if waittime <= 0: + break + + yield from self.wait(waittime) + result = predicate() + + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop, timeout=timeout) + + self._waiters.append(fut) + try: + yield from fut + except futures.CancelledError: + self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..43ddc2e9 --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,399 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = ParserBuffer() + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer(loop=self._loop) + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + if self._parser is None: + return + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.done(): + waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(False) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def _shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len + size = end - self.offset + if limit is not None and size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - (end - self.offset) + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..cda87918 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,288 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import tulip_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._writing_disabled = False + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, waiter, extra) + self._loop.call_soon(self._loop_reading) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + try: + self._protocol.eof_received() + finally: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None and not self._writing_disabled: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + if not self._writing_disabled: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except OSError as exc: + self._fatal_error(exc) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._force_close(None) + + def pause_writing(self): + self._writing_disabled = True + + def resume_writing(self): + self._writing_disabled = False + if self._buffer and self._write_fut is None: + self._loop_writing() + + def discard_output(self): + if self._buffer: + self._buffer = [] + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def stop_serving(self, sock): + self._proactor.stop_serving(sock) + sock.close() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..d76f25a2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,100 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..8214d0ec --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,298 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import concurrent.futures +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item, timeout=None): + """Put an item into the queue. + + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise Full if no free slot becomes + available before the timeout. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop, timeout=timeout) + + self._putters.append((item, waiter)) + try: + yield from waiter + except concurrent.futures.CancelledError: + raise Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self, timeout=None): + """Remove and return an item from the queue. + + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise Empty if no item is available + before the timeout. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop, timeout=timeout) + + self._getters.append(waiter) + try: + return (yield from waiter) + except concurrent.futures.CancelledError: + raise Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self, timeout=None): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..99e73ece --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,671 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + mask, (reader, writer) = self._selector.get_info(fd) + except KeyError: + return False + else: + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for fileobj, mask, (reader, writer) in event_list: + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + def __init__(self, loop, sock, protocol, extra): + super().__init__(extra) + self._extra['socket'] = sock + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._writing = True + self._closing = False # Set when close() called. + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_writer(self._sock_fd) + self._loop.remove_reader(self._sock_fd) + self._buffer.clear() + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except OSError as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return # transmission off + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, extra=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + + super().__init__(loop, sslsock, protocol, extra) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing: + try: + data = self._sock.recv(8192) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + # TODO: write_eof(), can_write_eof(). + + +class _SelectorDatagramTransport(_SelectorTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._buffer = collections.deque() + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..388df25f --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,426 @@ +"""Select module. + +This module supports asynchronous I/O on multiple file descriptors. +""" + +import sys +from select import * + +from .log import tulip_log + + +# generic events, that must be mapped to implementation-specific ones +# read event +EVENT_READ = (1 << 0) +# write event +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file descriptor, or any object with a `fileno()` method + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (ValueError, TypeError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) + return fd + + +class SelectorKey: + """Object used internally to associate a file object to its backing file + descriptor, selected event mask and attached data.""" + + def __init__(self, fileobj, events, data=None): + self.fileobj = fileobj + self.fd = _fileobj_to_fd(fileobj) + self.events = events + self.data = data + + def __repr__(self): + return '{}'.format( + self.__class__.__name__, + self.fileobj, self.fd, self.events, self.data) + + +class _BaseSelector: + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + if fileobj in self._fileobj_to_key: + raise ValueError("{!r} is already registered".format(fileobj)) + + key = SelectorKey(fileobj, events, data) + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key[fileobj] + del self._fd_to_key[key.fd] + del self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise ValueError("{!r} is not registered".format(fileobj)) + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout == 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (fileobj, events, attached data) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_info(self, fileobj): + """Return information about a registered file object. + + Returns: + (events, data) associated to this file object + + Raises KeyError if the file object is not registered. + """ + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) + return key.events, key.data + + def registered_count(self): + """Return the number of registered file objects. + + Returns: + number of currently registered file objects + """ + return len(self._fd_to_key) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key + """ + try: + return self._fd_to_key[fd] + except KeyError: + tulip_log.warning('No key found for fd %r', fd) + return None + + +class SelectSelector(_BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + def select(self, timeout=None): + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + r = set(r) + w = set(w) + ready = [] + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select(r, w, w, timeout) + return r, w + x, [] + else: + from select import select as _select + + +if 'poll' in globals(): + + # TODO: Implement poll() for Windows with workaround for + # brokenness in WSAPoll() (Richard Oudkerk, see + # http://bugs.python.org/issue16507). + + class PollSelector(_BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= POLLIN + if events & EVENT_WRITE: + poll_events |= POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else int(1000 * timeout) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~POLLIN: + events |= EVENT_WRITE + if event & ~POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + +if 'epoll' in globals(): + + class EpollSelector(_BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= EPOLLIN + if events & EVENT_WRITE: + epoll_events |= EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else timeout + max_ev = self.registered_count() + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for fd, event in fd_event_list: + events = 0 + if event & ~EPOLLIN: + events |= EVENT_WRITE + if event & ~EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if 'kqueue' in globals(): + + class KqueueSelector(_BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + max_ev = self.registered_count() + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + # A signal arrived. Don't die, just return no events. + return [] + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == KQ_FILTER_READ: + events |= EVENT_READ + if flag == KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key.fileobj, events & key.events, key.data)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..3203b7d6 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,211 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + self.limit = limit # Max line length. (Security feature.) + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..ca513a10 --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,359 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'task', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + ] + +import collections +import concurrent.futures +import functools +import inspect + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + if inspect.isgeneratorfunction(func): + coro = func + else: + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + def task_wrapper(*args, **kwds): + return Task(coro(*args, **kwds)) + + return task_wrapper + + +_marker = object() + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, *, loop=None, timeout=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(loop=loop, timeout=timeout) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done() or self._must_cancel: + return False + self._must_cancel = True + # _step() will call super().cancel() to call the callbacks. + if self._fut_waiter is not None: + return self._fut_waiter.cancel() + else: + self._loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() + + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=_marker, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + + # We'll call either coro.throw(exc) or coro.send(value). + # Task cancel has to be delayed if current waiter future is done. + if self._must_cancel and exc is None and value is _marker: + exc = futures.CancelledError + + coro = self._coro + value = None if value is _marker else value + self._fut_waiter = None + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) + except Exception as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + except BaseException as exc: + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + if not result._blocking: + result.set_exception( + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + + # task cancellation has been delayed. + if self._must_cancel: + self._fut_waiter.cancel() + + else: + if inspect.isgenerator(result): + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + if result is not None: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + else: + self._loop.call_soon(self._step_maybe) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(return_when=FIRST_COMPLETED). + + The fs argument must be a set of Futures. + The timeout argument is like for wait(). + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop, timeout=timeout) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + waiter.cancel() + + for f in fs: + f.add_done_callback(_on_completion) + try: + yield from waiter + except futures.CancelledError: + pass + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None, timeout=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if ((loop is not None and loop is not coro_or_future._loop) or + (timeout is not None and timeout != coro_or_future._timeout)): + raise ValueError( + 'loop and timeout arguments must agree with Future') + + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop, timeout=timeout) + else: + raise TypeError('A Future or coroutine is required') diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..05d5e6ab --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,443 @@ +"""Utilities shared by tests.""" + +import cgi +import collections +import contextlib +import gc +import email.parser +import http.server +import json +import logging +import io +import unittest.mock +import os +import re +import socket +import sys +import threading +import traceback +import unittest +import unittest.mock +import urllib.parse +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client +from tulip import base_events +from tulip import events + +from tulip import base_events +from tulip import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def run_briefly(loop): + @tulip.coroutine + def once(): + pass + t = tulip.Task(once(), loop=loop) + loop.run_until_complete(t) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + transports = [] + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + def handle_request(self, message, payload): + if properties.get('close', False): + return + + if properties.get('noresponse', False): + yield from tulip.sleep(99999) + + if router is not None: + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router( + self, properties, + self.transport, message, bytes(body)) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + socks = thread_loop.run_until_complete( + thread_loop.start_serving( + lambda: TestHttpServer(keep_alive=0.5), + host, port, ssl=sslcontext)) + + waiter = tulip.Future(loop=thread_loop) + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) + + try: + thread_loop.run_until_complete(waiter) + finally: + # call pending connection_made if present + run_briefly(thread_loop) + + # close opened trnsports + for tr in transports: + tr.close() + + run_briefly(thread_loop) # call close callbacks + + for s in socks: + thread_loop.stop_serving(s) + + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = tulip.Future(loop=loop) + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, srv, props, transport, message, payload): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in message.headers: + self._headers.add_header(hdr, val) + + self._srv = srv + self._props = props + self._transport = transport + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except Exception: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() + + # keep-alive + if response.keep_alive(): + self._srv.keep_alive(True) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors._BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + self._check_on_close = False + def gen(): + yield + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..56425aa9 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,201 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause_writing(self): + """Pause transmission on the transport. + + Subsequent writes are deferred until resume_writing() is called. + """ + raise NotImplementedError + + def resume_writing(self): + """Resume transmission on the transport. """ + raise NotImplementedError + + def discard_output(self): + """Discard any buffered data awaiting transmission on the transport.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..75131851 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,555 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + tulip_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._writing = True + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._fileno) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._fileno, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..527922e1 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,206 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import windows_utils +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def registered_count(self): + return len(self._cache) + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(handle, nbytes) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(handle, buf) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # the socket needs to be locally bound before we call ConnectEx() + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # probably already locally bound; check using getsockname() + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = _OverlappedFuture(ov, loop=self._loop) + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop.stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py new file mode 100644 index 00000000..bf85f31e --- /dev/null +++ b/tulip/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter=itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) From 0740e1ffc7874efd2ae15683738146cb25ce9406 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 27 Aug 2013 09:00:16 -0700 Subject: [PATCH 0572/1502] Avoid yield in finally clause of Condition.wait(). --- tulip/locks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tulip/locks.py b/tulip/locks.py index 622a499b..7c0a8f2a 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -261,16 +261,21 @@ def wait(self, timeout=None): fut = futures.Future(loop=self._loop, timeout=timeout) self._condition_waiters.append(fut) + keep_lock = True try: yield from fut except futures.CancelledError: self._condition_waiters.remove(fut) return False + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise else: f = self._condition_waiters.popleft() assert fut is f finally: - yield from self.acquire() + if keep_lock: + yield from self.acquire() return True From e4e06680f80b63b7e600456ce93690cc09ba0ad3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 29 Aug 2013 11:07:14 +0300 Subject: [PATCH 0573/1502] Add --findleaks option to runtests.py --- runtests.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/runtests.py b/runtests.py index 7c62bf45..1ec0b298 100644 --- a/runtests.py +++ b/runtests.py @@ -23,6 +23,7 @@ # Originally written by Beech Horn (for NDB). import argparse +import gc import logging import os import re @@ -49,7 +50,10 @@ dest='catchbreak', help='Catch control-C and display results') ARGS.add_argument( '--forever', action="store_true", dest='forever', default=False, - help='Run tests forever to catch sporadic errors') + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') ARGS.add_argument( '-q', action="store_true", dest='quiet', help='quiet') ARGS.add_argument( @@ -133,6 +137,43 @@ def load_tests(testsdir, includes=(), excludes=()): return suite +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln(" Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + def runtests(): args = ARGS.parse_args() @@ -151,6 +192,8 @@ def runtests(): v = 0 if args.quiet else args.verbose + 1 failfast = args.failfast catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() @@ -168,13 +211,13 @@ def runtests(): installHandler() if args.forever: while True: - result = unittest.TextTestRunner(verbosity=v, - failfast=failfast).run(tests) + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) if not result.wasSuccessful(): sys.exit(1) else: - result = unittest.TextTestRunner(verbosity=v, - failfast=failfast).run(tests) + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) sys.exit(not result.wasSuccessful()) From 5f39b10be5bb03a731405b480a7ef6bb2df9be4d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 29 Aug 2013 08:10:18 -0700 Subject: [PATCH 0574/1502] Make script docstrings and doc pointers for coverage consistent. --- Makefile | 2 +- README | 4 ++-- runtests.py | 20 +++++++++----------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/Makefile b/Makefile index 11fe52ca..6064fc63 100644 --- a/Makefile +++ b/Makefile @@ -14,7 +14,7 @@ vtest: testloop: while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done -# See README for coverage installation instructions. +# See runtests.py for coverage installation instructions. cov coverage: $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) echo "open file://`pwd`/htmlcov/index.html" diff --git a/README b/README index 85bfe5a7..8f2b6373 100644 --- a/README +++ b/README @@ -8,8 +8,8 @@ Copyright/license: Open source, Apache 2.0. Enjoy. Master Mercurial repo: http://code.google.com/p/tulip/ -The old code lives in the subdirectory 'old'; the new code (conforming -to PEP 3156, under construction) lives in the 'tulip' subdirectory. +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. To run tests: - make test diff --git a/runtests.py b/runtests.py index 1ec0b298..484bff09 100644 --- a/runtests.py +++ b/runtests.py @@ -1,18 +1,15 @@ -"""Run all unittests. +"""Run Tulip unittests. Usage: - python3 runtests.py [-v] [-q] [pattern] ... + python3 runtests.py [flags] [pattern] ... -Where: - -v: verbose - -q: quiet - pattern: optional regex patterns to match test ids (default all tests) - -Note that the test id is the fully qualified name of the test, +Patterns are matched against the fully qualified name of the test, including package, module, class and method, e.g. 'tests.events_test.PolicyTests.testPolicy'. -runtests.py with --coverage argument is equivalent of: +For full help, try --help. + +runtests.py --coverage is equivalent of: $(COVERAGE) run --branch runtests.py -v $(COVERAGE) html $(list of files) @@ -152,8 +149,9 @@ def addSuccess(self, test): gc.collect() if gc.garbage: if self.showAll: - self.stream.writeln(" Warning: test created {} uncollectable " - "object(s).".format(len(gc.garbage))) + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) # move the uncollectable objects somewhere so we don't see # them again self.leaks.append((self.getDescription(test), gc.garbage[:])) From 6b5141dc1ac204dd75d1ced3b0f7feddb1a7f63c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 29 Aug 2013 18:43:11 +0300 Subject: [PATCH 0575/1502] Update tests to get rid of some resource leaks. --- tests/http_protocol_test.py | 6 ++++++ tests/locks_test.py | 17 +++++++++++++++ tests/queues_test.py | 14 +++++++++++-- tests/tasks_test.py | 41 ++++++++++++++++++++++++++++++++++++- 4 files changed, 75 insertions(+), 3 deletions(-) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index e74b8f27..ec3aaf58 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -189,6 +189,8 @@ def test_send_headers(self): self.assertIn(b'CONTENT-TYPE: plain/html', content) self.assertTrue(msg.headers_sent) self.assertTrue(msg.is_headers_sent()) + # cleanup + msg.writer.close() def test_send_headers_nomore_add(self): msg = protocol.Response(self.transport, 200) @@ -197,6 +199,8 @@ def test_send_headers_nomore_add(self): self.assertRaises(AssertionError, msg.add_header, 'content-type', 'plain/html') + # cleanup + msg.writer.close() def test_prepare_length(self): msg = protocol.Response(self.transport, 200) @@ -244,6 +248,8 @@ def test_write_auto_send_headers(self): msg.write(b'data1') self.assertTrue(msg.headers_sent) + # cleanup + msg.writer.close() def test_write_payload_eof(self): write = self.transport.write = unittest.mock.Mock() diff --git a/tests/locks_test.py b/tests/locks_test.py index 83663ec0..529c7268 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -163,6 +163,11 @@ def gen(): self.assertEqual(3, len(lock._waiters)) + # wakeup to close waiting coroutines + for i in range(3): + lock.release() + test_utils.run_briefly(loop) + def test_acquire_cancel(self): lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) @@ -335,6 +340,10 @@ def gen(): self.assertAlmostEqual(0.1, loop.time()) self.assertEqual(3, len(ev._waiters)) + # wakeup to close waiting coroutines + ev.set() + test_utils.run_briefly(loop) + def test_wait_cancel(self): ev = locks.EventWaiter(loop=self.loop) @@ -822,6 +831,9 @@ def c4(result): self.assertTrue(t3.result()) self.assertFalse(t4.done()) + # cleanup locked semaphore + sem.release() + def test_acquire_timeout(self): def gen(): when = yield @@ -873,6 +885,11 @@ def gen(): self.assertEqual(3, len(sem._waiters)) + # wakeup to close waiting coroutines + for i in range(3): + sem.release() + test_utils.run_briefly(loop) + def test_acquire_cancel(self): sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) diff --git a/tests/queues_test.py b/tests/queues_test.py index 4d4876b9..0dce6653 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -52,6 +52,8 @@ def add_getter(): # Let it start waiting. yield from tasks.sleep(0.1, loop=loop) self.assertTrue('_getters[1]' in fn(q)) + # resume q.get coroutine to finish generator + q.put_nowait(0) loop.run_until_complete(add_getter()) @@ -60,10 +62,12 @@ def add_putter(): q = queues.Queue(maxsize=1, loop=loop) q.put_nowait(1) # Start a task that waits to put. - tasks.Task(q.put(2), loop=loop) + t = tasks.Task(q.put(2), loop=loop) # Let it start waiting. yield from tasks.sleep(0.1, loop=loop) self.assertTrue('_putters[1]' in fn(q)) + # resume q.put coroutine to finish generator + q.get_nowait() loop.run_until_complete(add_putter()) @@ -437,12 +441,13 @@ def test_task_done(self): # Two workers get items from the queue and call task_done after each. # Join the queue and assert all items have been processed. + running = True @tasks.coroutine def worker(): nonlocal accumulator - while True: + while running: item = yield from q.get() accumulator += item q.task_done() @@ -457,6 +462,11 @@ def test(): self.loop.run_until_complete(test()) self.assertEqual(sum(range(100)), accumulator) + # close running generators + running = False + for i in range(2): + q.put_nowait(0) + def test_join_empty_queue(self): q = queues.JoinableQueue(loop=self.loop) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 56a5e128..3e1220dc 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -358,12 +358,14 @@ def gen(): self.addCleanup(loop.close) x = 0 + waiters = [] @tasks.coroutine def task(): nonlocal x while x < 10: - yield from tasks.sleep(0.1, loop=loop) + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] x += 1 if x == 2: loop.stop() @@ -375,6 +377,10 @@ def task(): self.assertEqual(x, 2) self.assertAlmostEqual(0.3, loop.time()) + # close generators + for w in waiters: + w.close() + def test_timeout(self): def gen(): @@ -398,6 +404,11 @@ def task(): self.assertAlmostEqual(0.1, loop.time()) self.assertFalse(t.done()) + # move forward to close generator + loop.advance_time(10) + self.assertEqual(42, loop.run_until_complete(t)) + self.assertTrue(t.done()) + def test_timeout_not(self): def gen(): @@ -483,6 +494,10 @@ def foo(): self.assertAlmostEqual(0.01, loop.time()) self.assertFalse(fut.done()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + def test_wait(self): def gen(): @@ -582,6 +597,10 @@ def gen(): self.assertIsNone(b.result()) self.assertAlmostEqual(0.1, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_wait_really_done(self): # there is possibility that some tasks in the pending list # became done but their callbacks haven't all been called yet @@ -637,6 +656,10 @@ def exc(): self.assertEqual({a}, pending) self.assertAlmostEqual(0, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_wait_first_exception_in_wait(self): def gen(): @@ -666,6 +689,10 @@ def exc(): self.assertEqual({a}, pending) self.assertAlmostEqual(0.01, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_wait_with_exception(self): def gen(): @@ -728,6 +755,10 @@ def foo(): loop.run_until_complete(tasks.Task(foo(), loop=loop)) self.assertAlmostEqual(0.11, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_wait_concurrent_complete(self): def gen(): @@ -752,6 +783,10 @@ def gen(): self.assertEqual(pending, set([b])) self.assertAlmostEqual(0.1, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_as_completed(self): def gen(): @@ -833,6 +868,10 @@ def foo(): self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) self.assertAlmostEqual(0.12, loop.time()) + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + def test_as_completed_reverse_wait(self): def gen(): From 2b0706bb0061c0542a5d8ed691d28b13d999c195 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 30 Aug 2013 06:48:41 +0300 Subject: [PATCH 0576/1502] Mark debug assertion in test_utils as not be analyzed by coverage. --- tulip/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 05d5e6ab..2534a9a1 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -380,7 +380,7 @@ def close(self): self._gen.send(0) except StopIteration: pass - else: + else: # pragma: no cover raise AssertionError("Time generator is not finished") def add_reader(self, fd, callback, *args): From 4537576408e4170d91a1577307ccb2889820f7e7 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 30 Aug 2013 13:14:01 +0300 Subject: [PATCH 0577/1502] Get rid of using writable attr for main thread detection. --- tests/events_test.py | 15 +++++++++------ tulip/events.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 16058f19..7c342bad 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1534,13 +1534,16 @@ def test_get_event_loop_after_set_none(self): policy.set_event_loop(None) self.assertRaises(AssertionError, policy.get_event_loop) - @unittest.mock.patch('tulip.events.threading') - def test_get_event_loop_thread(self, m_threading): - m_t = m_threading.current_thread.return_value = unittest.mock.Mock() - m_t.name = 'Thread 1' + @unittest.mock.patch('tulip.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): - policy = events.DefaultEventLoopPolicy() - self.assertRaises(AssertionError, policy.get_event_loop) + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() def test_new_event_loop(self): policy = events.DefaultEventLoopPolicy() diff --git a/tulip/events.py b/tulip/events.py index 37b95594..e292eea2 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -327,7 +327,7 @@ def get_event_loop(self): """ if (self._loop is None and not self._set_called and - threading.current_thread().name == 'MainThread'): + isinstance(threading.current_thread(), threading._MainThread)): self._loop = self.new_event_loop() assert self._loop is not None, \ ('There is no current event loop in thread %r.' % From d7a48848afca157221130c9db379b720715a2cd9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 30 Aug 2013 11:58:38 -0700 Subject: [PATCH 0578/1502] Add C-F Natali's selectors.py from http://bugs.python.org/issue16853. --- tests/base_events_test.py | 1 - tests/selector_events_test.py | 64 +++++---- tests/selectors_test.py | 67 +++++----- tulip/selector_events.py | 15 ++- tulip/selectors.py | 240 +++++++++++++++++----------------- tulip/test_utils.py | 2 +- tulip/windows_events.py | 3 - 7 files changed, 200 insertions(+), 192 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 104dc763..e27b3ab9 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -19,7 +19,6 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = unittest.mock.Mock() - self.loop._selector.registered_count.return_value = 1 events.set_event_loop(None) def test_not_implemented(self): diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 1395bf9b..06318622 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -396,7 +396,7 @@ def test__sock_accept_exception(self): self.assertIs(err, f.exception()) def test_add_reader(self): - self.loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_key.side_effect = KeyError cb = lambda: True self.loop.add_reader(1, cb) @@ -410,8 +410,8 @@ def test_add_reader(self): def test_add_reader_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.loop._selector.get_info.return_value = ( - selectors.EVENT_WRITE, (reader, writer)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (reader, writer)) cb = lambda: True self.loop.add_reader(1, cb) @@ -426,8 +426,8 @@ def test_add_reader_existing(self): def test_add_reader_existing_writer(self): writer = unittest.mock.Mock() - self.loop._selector.get_info.return_value = ( - selectors.EVENT_WRITE, (None, writer)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, writer)) cb = lambda: True self.loop.add_reader(1, cb) @@ -440,8 +440,8 @@ def test_add_reader_existing_writer(self): self.assertEqual(writer, w) def test_remove_reader(self): - self.loop._selector.get_info.return_value = ( - selectors.EVENT_READ, (None, None)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (None, None)) self.assertFalse(self.loop.remove_reader(1)) self.assertTrue(self.loop._selector.unregister.called) @@ -449,8 +449,9 @@ def test_remove_reader(self): def test_remove_reader_read_write(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.loop._selector.get_info.return_value = ( - selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) self.assertTrue( self.loop.remove_reader(1)) @@ -460,12 +461,12 @@ def test_remove_reader_read_write(self): self.loop._selector.modify.call_args[0]) def test_remove_reader_unknown(self): - self.loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_key.side_effect = KeyError self.assertFalse( self.loop.remove_reader(1)) def test_add_writer(self): - self.loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_key.side_effect = KeyError cb = lambda: True self.loop.add_writer(1, cb) @@ -479,8 +480,8 @@ def test_add_writer(self): def test_add_writer_existing(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.loop._selector.get_info.return_value = ( - selectors.EVENT_READ, (reader, writer)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, writer)) cb = lambda: True self.loop.add_writer(1, cb) @@ -494,8 +495,8 @@ def test_add_writer_existing(self): self.assertEqual(cb, w._callback) def test_remove_writer(self): - self.loop._selector.get_info.return_value = ( - selectors.EVENT_WRITE, (None, None)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_WRITE, (None, None)) self.assertFalse(self.loop.remove_writer(1)) self.assertTrue(self.loop._selector.unregister.called) @@ -503,8 +504,9 @@ def test_remove_writer(self): def test_remove_writer_read_write(self): reader = unittest.mock.Mock() writer = unittest.mock.Mock() - self.loop._selector.get_info.return_value = ( - selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) + self.loop._selector.get_key.return_value = selectors.SelectorKey( + 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, + (reader, writer)) self.assertTrue( self.loop.remove_writer(1)) @@ -514,7 +516,7 @@ def test_remove_writer_read_write(self): self.loop._selector.modify.call_args[0]) def test_remove_writer_unknown(self): - self.loop._selector.get_info.side_effect = KeyError + self.loop._selector.get_key.side_effect = KeyError self.assertFalse( self.loop.remove_writer(1)) @@ -523,8 +525,10 @@ def test_process_events_read(self): reader._cancelled = False self.loop._add_callback = unittest.mock.Mock() - self.loop._process_events( - ((1, selectors.EVENT_READ, (reader, None)),)) + self.loop._process_events([ + (selectors.SelectorKey(1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ), + ]) self.assertTrue(self.loop._add_callback.called) self.loop._add_callback.assert_called_with(reader) @@ -533,8 +537,10 @@ def test_process_events_read_cancelled(self): reader.cancelled = True self.loop.remove_reader = unittest.mock.Mock() - self.loop._process_events( - ((1, selectors.EVENT_READ, (reader, None)),)) + self.loop._process_events([ + (selectors.SelectorKey(1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ), + ]) self.loop.remove_reader.assert_called_with(1) def test_process_events_write(self): @@ -542,8 +548,11 @@ def test_process_events_write(self): writer._cancelled = False self.loop._add_callback = unittest.mock.Mock() - self.loop._process_events( - ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.loop._process_events([ + (selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE), + ]) self.loop._add_callback.assert_called_with(writer) def test_process_events_write_cancelled(self): @@ -551,8 +560,11 @@ def test_process_events_write_cancelled(self): writer.cancelled = True self.loop.remove_writer = unittest.mock.Mock() - self.loop._process_events( - ((1, selectors.EVENT_WRITE, (None, writer)),)) + self.loop._process_events([ + (selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE), + ]) self.loop.remove_writer.assert_called_with(1) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index f933f35e..68c1c06b 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -6,6 +6,13 @@ from tulip import selectors +class FakeSelector(selectors.BaseSelector): + """Trivial non-abstract subclass of BaseSelector.""" + + def select(self, timeout=None): + raise NotImplementedError + + class BaseSelectorTests(unittest.TestCase): def test_fileobj_to_fd(self): @@ -15,63 +22,65 @@ def test_fileobj_to_fd(self): f.fileno.return_value = 10 self.assertEqual(10, selectors._fileobj_to_fd(f)) - f.fileno.side_effect = TypeError + f.fileno.side_effect = AttributeError self.assertRaises(ValueError, selectors._fileobj_to_fd, f) def test_selector_key_repr(self): - key = selectors.SelectorKey(10, selectors.EVENT_READ) + key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None) self.assertEqual( - "SelectorKey", repr(key)) + "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key)) def test_register(self): fobj = unittest.mock.Mock() fobj.fileno.return_value = 10 - s = selectors._BaseSelector() + s = FakeSelector() key = s.register(fobj, selectors.EVENT_READ) self.assertIsInstance(key, selectors.SelectorKey) self.assertEqual(key.fd, 10) self.assertIs(key, s._fd_to_key[10]) def test_register_unknown_event(self): - s = selectors._BaseSelector() + s = FakeSelector() self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) def test_register_already_registered(self): fobj = unittest.mock.Mock() fobj.fileno.return_value = 10 - s = selectors._BaseSelector() + s = FakeSelector() s.register(fobj, selectors.EVENT_READ) - self.assertRaises(ValueError, s.register, fobj, selectors.EVENT_READ) + self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ) def test_unregister(self): fobj = unittest.mock.Mock() fobj.fileno.return_value = 10 - s = selectors._BaseSelector() + s = FakeSelector() s.register(fobj, selectors.EVENT_READ) s.unregister(fobj) self.assertFalse(s._fd_to_key) self.assertFalse(s._fileobj_to_key) def test_unregister_unknown(self): - s = selectors._BaseSelector() - self.assertRaises(ValueError, s.unregister, unittest.mock.Mock()) + s = FakeSelector() + self.assertRaises(KeyError, s.unregister, unittest.mock.Mock()) def test_modify_unknown(self): - s = selectors._BaseSelector() - self.assertRaises(ValueError, s.modify, unittest.mock.Mock(), 1) + s = FakeSelector() + self.assertRaises(KeyError, s.modify, unittest.mock.Mock(), 1) def test_modify(self): fobj = unittest.mock.Mock() fobj.fileno.return_value = 10 - s = selectors._BaseSelector() + s = FakeSelector() key = s.register(fobj, selectors.EVENT_READ) key2 = s.modify(fobj, selectors.EVENT_WRITE) self.assertNotEqual(key.events, key2.events) - self.assertEqual((selectors.EVENT_WRITE, None), s.get_info(fobj)) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None), + s.get_key(fobj)) def test_modify_data(self): fobj = unittest.mock.Mock() @@ -80,12 +89,14 @@ def test_modify_data(self): d1 = object() d2 = object() - s = selectors._BaseSelector() + s = FakeSelector() key = s.register(fobj, selectors.EVENT_READ, d1) key2 = s.modify(fobj, selectors.EVENT_READ, d2) self.assertEqual(key.events, key2.events) self.assertNotEqual(key.data, key2.data) - self.assertEqual((selectors.EVENT_READ, d2), s.get_info(fobj)) + self.assertEqual( + selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2), + s.get_key(fobj)) def test_modify_same(self): fobj = unittest.mock.Mock() @@ -93,35 +104,25 @@ def test_modify_same(self): data = object() - s = selectors._BaseSelector() + s = FakeSelector() key = s.register(fobj, selectors.EVENT_READ, data) key2 = s.modify(fobj, selectors.EVENT_READ, data) self.assertIs(key, key2) def test_select(self): - s = selectors._BaseSelector() + s = FakeSelector() self.assertRaises(NotImplementedError, s.select) def test_close(self): - s = selectors._BaseSelector() + s = FakeSelector() s.register(1, selectors.EVENT_READ) s.close() self.assertFalse(s._fd_to_key) self.assertFalse(s._fileobj_to_key) - def test_registered_count(self): - s = selectors._BaseSelector() - self.assertEqual(0, s.registered_count()) - - s.register(1, selectors.EVENT_READ) - self.assertEqual(1, s.registered_count()) - - s.unregister(1) - self.assertEqual(0, s.registered_count()) - def test_context_manager(self): - s = selectors._BaseSelector() + s = FakeSelector() with s as sel: sel.register(1, selectors.EVENT_READ) @@ -129,14 +130,12 @@ def test_context_manager(self): self.assertFalse(s._fd_to_key) self.assertFalse(s._fileobj_to_key) - @unittest.mock.patch('tulip.selectors.tulip_log') - def test_key_from_fd(self, m_log): - s = selectors._BaseSelector() + def test_key_from_fd(self): + s = FakeSelector() key = s.register(1, selectors.EVENT_READ) self.assertIs(key, s._key_from_fd(1)) self.assertIsNone(s._key_from_fd(10)) - m_log.warning.assert_called_with('No key found for fd %r', 10) if hasattr(selectors.DefaultSelector, 'fileno'): def test_fileno(self): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 99e73ece..82d22bb6 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -116,11 +116,12 @@ def add_reader(self, fd, callback, *args): """Add a reader callback.""" handle = events.make_handle(callback, args) try: - mask, (reader, writer) = self._selector.get_info(fd) + key = self._selector.get_key(fd) except KeyError: self._selector.register(fd, selectors.EVENT_READ, (handle, None)) else: + mask, (reader, writer) = key.events, key.data self._selector.modify(fd, mask | selectors.EVENT_READ, (handle, writer)) if reader is not None: @@ -129,10 +130,11 @@ def add_reader(self, fd, callback, *args): def remove_reader(self, fd): """Remove a reader callback.""" try: - mask, (reader, writer) = self._selector.get_info(fd) + key = self._selector.get_key(fd) except KeyError: return False else: + mask, (reader, writer) = key.events, key.data mask &= ~selectors.EVENT_READ if not mask: self._selector.unregister(fd) @@ -149,11 +151,12 @@ def add_writer(self, fd, callback, *args): """Add a writer callback..""" handle = events.make_handle(callback, args) try: - mask, (reader, writer) = self._selector.get_info(fd) + key = self._selector.get_key(fd) except KeyError: self._selector.register(fd, selectors.EVENT_WRITE, (None, handle)) else: + mask, (reader, writer) = key.events, key.data self._selector.modify(fd, mask | selectors.EVENT_WRITE, (reader, handle)) if writer is not None: @@ -162,10 +165,11 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): """Remove a writer callback.""" try: - mask, (reader, writer) = self._selector.get_info(fd) + key = self._selector.get_key(fd) except KeyError: return False else: + mask, (reader, writer) = key.events, key.data # Remove both writer and connector. mask &= ~selectors.EVENT_WRITE if not mask: @@ -298,7 +302,8 @@ def _sock_accept(self, fut, registered, sock): fut.set_result((conn, address)) def _process_events(self, event_list): - for fileobj, mask, (reader, writer) in event_list: + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data if mask & selectors.EVENT_READ and reader is not None: if reader._cancelled: self.remove_reader(fileobj) diff --git a/tulip/selectors.py b/tulip/selectors.py index 388df25f..fd1abf50 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -1,18 +1,23 @@ -"""Select module. +"""Selectors module. -This module supports asynchronous I/O on multiple file descriptors. +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. """ -import sys -from select import * -from .log import tulip_log +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys +try: + from time import monotonic as time +except ImportError: + from time import time as time # generic events, that must be mapped to implementation-specific ones -# read event EVENT_READ = (1 << 0) -# write event EVENT_WRITE = (1 << 1) @@ -20,7 +25,7 @@ def _fileobj_to_fd(fileobj): """Return a file descriptor from a file object. Parameters: - fileobj -- file descriptor, or any object with a `fileno()` method + fileobj -- file object Returns: corresponding file descriptor @@ -30,28 +35,40 @@ def _fileobj_to_fd(fileobj): else: try: fd = int(fileobj.fileno()) - except (ValueError, TypeError): - raise ValueError("Invalid file object: {!r}".format(fileobj)) + except (AttributeError, ValueError): + raise ValueError("Invalid file object: {!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) return fd -class SelectorKey: - """Object used internally to associate a file object to its backing file - descriptor, selected event mask and attached data.""" +def _select_interrupt_wrapper(func): + """InterruptedError-safe wrapper for select(), taking the (optional) + timeout into account.""" + @functools.wraps(func) + def wrapper(self, timeout=None): + if timeout is not None and timeout > 0: + deadline = time() + timeout + while True: + try: + return func(self, timeout) + except InterruptedError: + if timeout is not None: + if timeout > 0: + timeout = deadline - time() + if timeout <= 0: + # timeout expired + return [] - def __init__(self, fileobj, events, data=None): - self.fileobj = fileobj - self.fd = _fileobj_to_fd(fileobj) - self.events = events - self.data = data + return wrapper - def __repr__(self): - return '{}'.format( - self.__class__.__name__, - self.fileobj, self.fd, self.events, self.data) +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" -class _BaseSelector: + +class BaseSelector(metaclass=ABCMeta): """Base selector class. A selector supports registering file objects to be monitored for specific @@ -83,13 +100,15 @@ def register(self, fileobj, events, data=None): Returns: SelectorKey instance """ - if (not events) or (events & ~(EVENT_READ|EVENT_WRITE)): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): raise ValueError("Invalid events: {}".format(events)) - if fileobj in self._fileobj_to_key: - raise ValueError("{!r} is already registered".format(fileobj)) + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) - key = SelectorKey(fileobj, events, data) self._fd_to_key[key.fd] = key self._fileobj_to_key[fileobj] = key return key @@ -104,11 +123,10 @@ def unregister(self, fileobj): SelectorKey instance """ try: - key = self._fileobj_to_key[fileobj] + key = self._fileobj_to_key.pop(fileobj) del self._fd_to_key[key.fd] - del self._fileobj_to_key[fileobj] except KeyError: - raise ValueError("{!r} is not registered".format(fileobj)) + raise KeyError("{!r} is not registered".format(fileobj)) from None return key def modify(self, fileobj, events, data=None): @@ -118,12 +136,15 @@ def modify(self, fileobj, events, data=None): fileobj -- file object events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data + + Returns: + SelectorKey instance """ # TODO: Subclasses can probably optimize this even further. try: key = self._fileobj_to_key[fileobj] except KeyError: - raise ValueError("{!r} is not registered".format(fileobj)) + raise KeyError("{!r} is not registered".format(fileobj)) from None if events != key.events or data != key.data: # TODO: If only the data changed, use a shortcut that only # updates the data. @@ -132,6 +153,7 @@ def modify(self, fileobj, events, data=None): else: return key + @abstractmethod def select(self, timeout=None): """Perform the actual selection, until some monitored file objects are ready or a timeout expires. @@ -139,13 +161,13 @@ def select(self, timeout=None): Parameters: timeout -- if timeout > 0, this specifies the maximum wait time, in seconds - if timeout == 0, the select() call won't block, and will + if timeout <= 0, the select() call won't block, and will report the currently ready file objects if timeout is None, select() will block until a monitored file object becomes ready Returns: - list of (fileobj, events, attached data) for ready file objects + list of (key, events) for ready file objects `events` is a bitwise mask of EVENT_READ|EVENT_WRITE """ raise NotImplementedError() @@ -158,27 +180,16 @@ def close(self): self._fd_to_key.clear() self._fileobj_to_key.clear() - def get_info(self, fileobj): - """Return information about a registered file object. + def get_key(self, fileobj): + """Return the key associated to a registered file object. Returns: - (events, data) associated to this file object - - Raises KeyError if the file object is not registered. + SelectorKey for this file object """ try: - key = self._fileobj_to_key[fileobj] + return self._fileobj_to_key[fileobj] except KeyError: - raise KeyError("{} is not registered".format(fileobj)) - return key.events, key.data - - def registered_count(self): - """Return the number of registered file objects. - - Returns: - number of currently registered file objects - """ - return len(self._fd_to_key) + raise KeyError("{} is not registered".format(fileobj)) from None def __enter__(self): return self @@ -193,16 +204,15 @@ def _key_from_fd(self, fd): fd -- file descriptor Returns: - corresponding key + corresponding key, or None if not found """ try: return self._fd_to_key[fd] except KeyError: - tulip_log.warning('No key found for fd %r', fd) return None -class SelectSelector(_BaseSelector): +class SelectSelector(BaseSelector): """Select-based selector.""" def __init__(self): @@ -224,12 +234,17 @@ def unregister(self, fileobj): self._writers.discard(key.fd) return key + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + @_select_interrupt_wrapper def select(self, timeout=None): - try: - r, w, _ = self._select(self._readers, self._writers, [], timeout) - except InterruptedError: - # A signal arrived. Don't die, just return no events. - return [] + timeout = None if timeout is None else max(timeout, 0) + r, w, _ = self._select(self._readers, self._writers, [], timeout) r = set(r) w = set(w) ready = [] @@ -242,37 +257,26 @@ def select(self, timeout=None): key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events & key.events, key.data)) + ready.append((key, events & key.events)) return ready - if sys.platform == 'win32': - def _select(self, r, w, _, timeout=None): - r, w, x = select(r, w, w, timeout) - return r, w + x, [] - else: - from select import select as _select - -if 'poll' in globals(): +if hasattr(select, 'poll'): - # TODO: Implement poll() for Windows with workaround for - # brokenness in WSAPoll() (Richard Oudkerk, see - # http://bugs.python.org/issue16507). - - class PollSelector(_BaseSelector): + class PollSelector(BaseSelector): """Poll-based selector.""" def __init__(self): super().__init__() - self._poll = poll() + self._poll = select.poll() def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) poll_events = 0 if events & EVENT_READ: - poll_events |= POLLIN + poll_events |= select.POLLIN if events & EVENT_WRITE: - poll_events |= POLLOUT + poll_events |= select.POLLOUT self._poll.register(key.fd, poll_events) return key @@ -281,35 +285,32 @@ def unregister(self, fileobj): self._poll.unregister(key.fd) return key + @_select_interrupt_wrapper def select(self, timeout=None): - timeout = None if timeout is None else int(1000 * timeout) + timeout = None if timeout is None else max(int(1000 * timeout), 0) ready = [] - try: - fd_event_list = self._poll.poll(timeout) - except InterruptedError: - # A signal arrived. Don't die, just return no events. - return [] + fd_event_list = self._poll.poll(timeout) for fd, event in fd_event_list: events = 0 - if event & ~POLLIN: + if event & ~select.POLLIN: events |= EVENT_WRITE - if event & ~POLLOUT: + if event & ~select.POLLOUT: events |= EVENT_READ key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events & key.events, key.data)) + ready.append((key, events & key.events)) return ready -if 'epoll' in globals(): +if hasattr(select, 'epoll'): - class EpollSelector(_BaseSelector): + class EpollSelector(BaseSelector): """Epoll-based selector.""" def __init__(self): super().__init__() - self._epoll = epoll() + self._epoll = select.epoll() def fileno(self): return self._epoll.fileno() @@ -318,9 +319,9 @@ def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) epoll_events = 0 if events & EVENT_READ: - epoll_events |= EPOLLIN + epoll_events |= select.EPOLLIN if events & EVENT_WRITE: - epoll_events |= EPOLLOUT + epoll_events |= select.EPOLLOUT self._epoll.register(key.fd, epoll_events) return key @@ -329,25 +330,22 @@ def unregister(self, fileobj): self._epoll.unregister(key.fd) return key + @_select_interrupt_wrapper def select(self, timeout=None): - timeout = -1 if timeout is None else timeout - max_ev = self.registered_count() + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) ready = [] - try: - fd_event_list = self._epoll.poll(timeout, max_ev) - except InterruptedError: - # A signal arrived. Don't die, just return no events. - return [] + fd_event_list = self._epoll.poll(timeout, max_ev) for fd, event in fd_event_list: events = 0 - if event & ~EPOLLIN: + if event & ~select.EPOLLIN: events |= EVENT_WRITE - if event & ~EPOLLOUT: + if event & ~select.EPOLLOUT: events |= EVENT_READ key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events & key.events, key.data)) + ready.append((key, events & key.events)) return ready def close(self): @@ -355,58 +353,56 @@ def close(self): self._epoll.close() -if 'kqueue' in globals(): +if hasattr(select, 'kqueue'): - class KqueueSelector(_BaseSelector): + class KqueueSelector(BaseSelector): """Kqueue-based selector.""" def __init__(self): super().__init__() - self._kqueue = kqueue() + self._kqueue = select.kqueue() def fileno(self): return self._kqueue.fileno() - def unregister(self, fileobj): - key = super().unregister(fileobj) - if key.events & EVENT_READ: - kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - if key.events & EVENT_WRITE: - kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) - return key - def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: - kev = kevent(key.fd, KQ_FILTER_READ, KQ_EV_ADD) + kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) if events & EVENT_WRITE: - kev = kevent(key.fd, KQ_FILTER_WRITE, KQ_EV_ADD) + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key + @_select_interrupt_wrapper def select(self, timeout=None): - max_ev = self.registered_count() + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) ready = [] - try: - kev_list = self._kqueue.control(None, max_ev, timeout) - except InterruptedError: - # A signal arrived. Don't die, just return no events. - return [] + kev_list = self._kqueue.control(None, max_ev, timeout) for kev in kev_list: fd = kev.ident flag = kev.filter events = 0 - if flag == KQ_FILTER_READ: + if flag == select.KQ_FILTER_READ: events |= EVENT_READ - if flag == KQ_FILTER_WRITE: + if flag == select.KQ_FILTER_WRITE: events |= EVENT_WRITE key = self._key_from_fd(fd) if key: - ready.append((key.fileobj, events & key.events, key.data)) + ready.append((key, events & key.events)) return ready def close(self): diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 2534a9a1..b4af0c89 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -321,7 +321,7 @@ def make_test_protocol(base): return type('TestProtocol', (base,) + base.__bases__, dct)() -class TestSelector(selectors._BaseSelector): +class TestSelector(selectors.BaseSelector): def select(self, timeout): return [] diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 527922e1..629b3475 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -70,9 +70,6 @@ def __init__(self, concurrency=0xffffffff): def set_loop(self, loop): self._loop = loop - def registered_count(self): - return len(self._cache) - def select(self, timeout=None): if not self._results: self._poll(timeout) From 105c7b92301e684ad2ff2ddaf5af655a581da7ed Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 30 Aug 2013 16:55:26 -0700 Subject: [PATCH 0579/1502] Remove ancient TODO list. Use the tracker instead. --- tulip/TODO | 26 -------------------------- 1 file changed, 26 deletions(-) delete mode 100644 tulip/TODO diff --git a/tulip/TODO b/tulip/TODO deleted file mode 100644 index b3a9302e..00000000 --- a/tulip/TODO +++ /dev/null @@ -1,26 +0,0 @@ -TODO in tulip v2 (tulip/ package directory) -------------------------------------------- - -- See also TBD and Open Issues in PEP 3156 - -- Refactor unix_events.py (it's getting too long) - -- Docstrings - -- Unittests - -- better run_once() behavior? (Run ready list last.) - -- start_serving() - -- Make Handler() callable? Log the exception in there? - -- Add the repeat interval to the Handler class? - -- Recognize Handler passed to add_reader(), call_soon(), etc.? - -- SSL support - -- buffered stream implementation - -- Primitives like par() and wait_one() From d646c87ca3d64ec1354f5783c2716d4d729f8c6d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 31 Aug 2013 10:29:25 -0700 Subject: [PATCH 0580/1502] New version of selectors.py, dropping @_select_interrupt_wrapper. --- tulip/selectors.py | 74 +++++++++++++++++++--------------------------- 1 file changed, 31 insertions(+), 43 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index fd1abf50..b81b1dbe 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -10,10 +10,6 @@ import functools import select import sys -try: - from time import monotonic as time -except ImportError: - from time import time as time # generic events, that must be mapped to implementation-specific ones @@ -25,7 +21,7 @@ def _fileobj_to_fd(fileobj): """Return a file descriptor from a file object. Parameters: - fileobj -- file object + fileobj -- file object or file descriptor Returns: corresponding file descriptor @@ -36,33 +32,13 @@ def _fileobj_to_fd(fileobj): try: fd = int(fileobj.fileno()) except (AttributeError, ValueError): - raise ValueError("Invalid file object: {!r}".format(fileobj)) from None + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None if fd < 0: raise ValueError("Invalid file descriptor: {}".format(fd)) return fd -def _select_interrupt_wrapper(func): - """InterruptedError-safe wrapper for select(), taking the (optional) - timeout into account.""" - @functools.wraps(func) - def wrapper(self, timeout=None): - if timeout is not None and timeout > 0: - deadline = time() + timeout - while True: - try: - return func(self, timeout) - except InterruptedError: - if timeout is not None: - if timeout > 0: - timeout = deadline - time() - if timeout <= 0: - # timeout expired - return [] - - return wrapper - - SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) """Object used to associate a file object to its backing file descriptor, selected event mask and attached data.""" @@ -93,7 +69,7 @@ def register(self, fileobj, events, data=None): """Register a file object. Parameters: - fileobj -- file object + fileobj -- file object or file descriptor events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data @@ -117,7 +93,7 @@ def unregister(self, fileobj): """Unregister a file object. Parameters: - fileobj -- file object + fileobj -- file object or file descriptor Returns: SelectorKey instance @@ -133,7 +109,7 @@ def modify(self, fileobj, events, data=None): """Change a registered file object monitored events or attached data. Parameters: - fileobj -- file object + fileobj -- file object or file descriptor events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) data -- attached data @@ -241,13 +217,15 @@ def _select(self, r, w, _, timeout=None): else: _select = select.select - @_select_interrupt_wrapper def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) - r, w, _ = self._select(self._readers, self._writers, [], timeout) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready r = set(r) w = set(w) - ready = [] for fd in r | w: events = 0 if fd in r: @@ -285,11 +263,13 @@ def unregister(self, fileobj): self._poll.unregister(key.fd) return key - @_select_interrupt_wrapper def select(self, timeout=None): timeout = None if timeout is None else max(int(1000 * timeout), 0) ready = [] - fd_event_list = self._poll.poll(timeout) + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready for fd, event in fd_event_list: events = 0 if event & ~select.POLLIN: @@ -330,12 +310,14 @@ def unregister(self, fileobj): self._epoll.unregister(key.fd) return key - @_select_interrupt_wrapper def select(self, timeout=None): timeout = -1 if timeout is None else max(timeout, 0) max_ev = len(self._fd_to_key) ready = [] - fd_event_list = self._epoll.poll(timeout, max_ev) + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready for fd, event in fd_event_list: events = 0 if event & ~select.EPOLLIN: @@ -368,29 +350,35 @@ def fileno(self): def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) if events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_ADD) + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) self._kqueue.control([kev], 0, 0) return key def unregister(self, fileobj): key = super().unregister(fileobj) if key.events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) if key.events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) self._kqueue.control([kev], 0, 0) return key - @_select_interrupt_wrapper def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) max_ev = len(self._fd_to_key) ready = [] - kev_list = self._kqueue.control(None, max_ev, timeout) + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready for kev in kev_list: fd = kev.ident flag = kev.filter From 0149df44ba7bdf8d57110afaebc30aac32646890 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 29 Aug 2013 17:42:57 -0700 Subject: [PATCH 0581/1502] Remove timeout parameter from run_until_complete(). --- tests/events_test.py | 12 ------------ tests/tasks_test.py | 15 +++++---------- tulip/base_events.py | 27 ++++----------------------- tulip/events.py | 6 +----- 4 files changed, 10 insertions(+), 50 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 7c342bad..f505e2d7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -228,18 +228,6 @@ def cb(): self.assertRaises(RuntimeError, self.loop.run_until_complete, task) - def test_run_until_complete_timeout(self): - t0 = self.loop.time() - task = tasks.async(tasks.sleep(0.2, loop=self.loop), loop=self.loop) - self.assertRaises(futures.TimeoutError, - self.loop.run_until_complete, - task, timeout=0.1) - t1 = self.loop.time() - self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) - self.loop.run_until_complete(task) - t2 = self.loop.time() - self.assertTrue(0.18 <= t2-t0 <= 0.22, t2-t0) - def test_call_later(self): results = [] diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 3e1220dc..15056d96 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -398,16 +398,11 @@ def task(): yield from tasks.sleep(10.0, loop=loop) return 42 - t = tasks.Task(task(), loop=loop) + t = tasks.Task(task(), loop=loop, timeout=0.1) self.assertRaises( - futures.TimeoutError, loop.run_until_complete, t, 0.1) + futures.CancelledError, loop.run_until_complete, t) self.assertAlmostEqual(0.1, loop.time()) - self.assertFalse(t.done()) - - # move forward to close generator - loop.advance_time(10) - self.assertEqual(42, loop.run_until_complete(t)) - self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) def test_timeout_not(self): @@ -426,8 +421,8 @@ def task(): yield from tasks.sleep(0.1, loop=loop) return 42 - t = tasks.Task(task(), loop=loop) - r = loop.run_until_complete(t, 10.0) + t = tasks.Task(task(), loop=loop, timeout=10.0) + r = loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(r, 42) self.assertAlmostEqual(0.1, loop.time()) diff --git a/tulip/base_events.py b/tulip/base_events.py index 5ff2d3c9..3bccfc83 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -112,8 +112,8 @@ def run_forever(self): finally: self._running = False - def run_until_complete(self, future, timeout=None): - """Run until the Future is done, or until a timeout. + def run_until_complete(self, future): + """Run until the Future is done. If the argument is a coroutine, it is wrapped in a Task. @@ -121,31 +121,12 @@ def run_until_complete(self, future, timeout=None): with the same coroutine twice -- it would wrap it in two different Tasks and that can't be good. - Return the Future's result, or raise its exception. If the - timeout is reached or stop() is called, raise TimeoutError. + Return the Future's result, or raise its exception. """ future = tasks.async(future, loop=self) future.add_done_callback(_raise_stop_error) - handle_called = False - - if timeout is None: - self.run_forever() - else: - - def stop_loop(): - nonlocal handle_called - handle_called = True - raise _StopError - - handle = self.call_later(timeout, stop_loop) - self.run_forever() - handle.cancel() - + self.run_forever() future.remove_done_callback(_raise_stop_error) - - if handle_called: - raise futures.TimeoutError - if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') diff --git a/tulip/events.py b/tulip/events.py index e292eea2..7db2514d 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -109,14 +109,10 @@ def run_forever(self): """Run the event loop until stop() is called.""" raise NotImplementedError - def run_until_complete(self, future, timeout=None): + def run_until_complete(self, future): """Run the event loop until a Future is done. Return the Future's result, or raise its exception. - - If timeout is not None, run it for at most that long; - if the Future is still not done, raise TimeoutError - (but don't cancel the Future). """ raise NotImplementedError From f5c363341851674415e21768933ee7288aa9159d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 29 Aug 2013 20:45:02 -0700 Subject: [PATCH 0582/1502] Eradicate timeouts from locks and queues. --- tests/events_test.py | 10 +- tests/locks_test.py | 218 ------------------------------------------- tests/queues_test.py | 93 +----------------- tulip/locks.py | 217 ++++++++++++++++-------------------------- tulip/queues.py | 25 ++--- 5 files changed, 98 insertions(+), 465 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index f505e2d7..15e82fb8 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -974,7 +974,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) transp.close() self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) @@ -1003,12 +1003,12 @@ def connect(): try: stdin = transp.get_pipe_transport(0) stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1]) stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) self.assertEqual(b'Python The Winner', proto.data[1]) finally: transp.close() @@ -1207,13 +1207,13 @@ def connect(): stdin = transp.get_pipe_transport(0) stdout = transp.get_pipe_transport(1) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) self.assertEqual(b'OUT:test', proto.data[1]) stdout.close() self.loop.run_until_complete(proto.disconnects[1]) stdin.write(b'xxx') - self.loop.run_until_complete(proto.got_data[2].wait(1)) + self.loop.run_until_complete(proto.got_data[2].wait()) self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) transp.close() diff --git a/tests/locks_test.py b/tests/locks_test.py index 529c7268..65ea0487 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -115,59 +115,6 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) - def test_acquire_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - lock = locks.Lock(loop=loop) - - self.assertTrue(loop.run_until_complete(lock.acquire())) - - acquired = loop.run_until_complete(lock.acquire(timeout=0.1)) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - lock = locks.Lock(loop=loop) - self.loop.run_until_complete(lock.acquire()) - - loop.call_soon(lock.release) - acquired = loop.run_until_complete(lock.acquire(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - def test_acquire_timeout_mixed(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - lock = locks.Lock(loop=loop) - loop.run_until_complete(lock.acquire()) - tasks.Task(lock.acquire(), loop=loop) - tasks.Task(lock.acquire(), loop=loop) - acquire_task = tasks.Task(lock.acquire(0.01), loop=loop) - tasks.Task(lock.acquire(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - self.assertEqual(3, len(lock._waiters)) - - # wakeup to close waiting coroutines - for i in range(3): - lock.release() - test_utils.run_briefly(loop) - def test_acquire_cancel(self): lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) @@ -295,55 +242,6 @@ def test_wait_on_set(self): res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) - def test_wait_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0.1 - self.assertAlmostEqual(0.11, when) - when = yield 0 - self.assertAlmostEqual(10.2, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - ev = locks.EventWaiter(loop=loop) - - res = loop.run_until_complete(ev.wait(0.1)) - self.assertFalse(res) - self.assertAlmostEqual(0.1, loop.time()) - - ev = locks.EventWaiter(loop=loop) - loop.call_later(0.01, ev.set) - acquired = loop.run_until_complete(ev.wait(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.11, loop.time()) - - def test_wait_timeout_mixed(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - ev = locks.EventWaiter(loop=loop) - tasks.Task(ev.wait(), loop=loop) - tasks.Task(ev.wait(), loop=loop) - acquire_task = tasks.Task(ev.wait(0.1), loop=loop) - tasks.Task(ev.wait(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - self.assertEqual(3, len(ev._waiters)) - - # wakeup to close waiting coroutines - ev.set() - test_utils.run_briefly(loop) - def test_wait_cancel(self): ev = locks.EventWaiter(loop=self.loop) @@ -485,23 +383,6 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) - def test_wait_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - cond = locks.Condition(loop=loop) - loop.run_until_complete(cond.acquire()) - - wait = loop.run_until_complete(cond.wait(0.1)) - self.assertFalse(wait) - self.assertTrue(cond.locked()) - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_cancel(self): cond = locks.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) @@ -558,49 +439,6 @@ def c1(result): self.assertTrue(t.done()) self.assertTrue(t.result()) - def test_wait_for_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - cond = locks.Condition(loop=loop) - - result = [] - - predicate = unittest.mock.Mock(return_value=False) - - @tasks.coroutine - def c1(result): - yield from cond.acquire() - if (yield from cond.wait_for(predicate, 0.1)): - result.append(1) - else: - result.append(2) - cond.release() - - wait_for = tasks.Task(c1(result), loop=loop) - - test_utils.run_briefly(loop) - self.assertEqual([], result) - - loop.run_until_complete(cond.acquire()) - cond.notify() - cond.release() - test_utils.run_briefly(loop) - self.assertEqual([], result) - - loop.run_until_complete(wait_for) - self.assertEqual([2], result) - self.assertEqual(3, predicate.call_count) - - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_for_unacquired(self): cond = locks.Condition(loop=self.loop) @@ -834,62 +672,6 @@ def c4(result): # cleanup locked semaphore sem.release() - def test_acquire_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0.1 - self.assertAlmostEqual(0.11, when) - when = yield 0 - self.assertAlmostEqual(10.2, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - - acquired = loop.run_until_complete(sem.acquire(0.1)) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - - loop.call_later(0.01, sem.release) - acquired = loop.run_until_complete(sem.acquire(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.11, loop.time()) - - def test_acquire_timeout_mixed(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - tasks.Task(sem.acquire(), loop=loop) - tasks.Task(sem.acquire(), loop=loop) - acquire_task = tasks.Task(sem.acquire(0.1), loop=loop) - tasks.Task(sem.acquire(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - - self.assertAlmostEqual(0.1, loop.time()) - - self.assertEqual(3, len(sem._waiters)) - - # wakeup to close waiting coroutines - for i in range(3): - sem.release() - test_utils.run_briefly(loop) - def test_acquire_cancel(self): sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) diff --git a/tests/queues_test.py b/tests/queues_test.py index 0dce6653..131812a4 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -236,37 +236,7 @@ def test_nonblocking_get_exception(self): q = queues.Queue(loop=self.loop) self.assertRaises(queues.Empty, q.get_nowait) - def test_get_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.Queue(loop=loop) - - @tasks.coroutine - def queue_get(): - with self.assertRaises(queues.Empty): - return (yield from q.get(timeout=0.01)) - - # Get works after timeout, with blocking and non-blocking put. - q.put_nowait(1) - self.assertEqual(1, (yield from q.get())) - - t = tasks.Task(q.put(2), loop=loop) - self.assertEqual(2, (yield from q.get())) - - self.assertTrue(t.done()) - self.assertIsNone(t.result()) - - loop.run_until_complete(queue_get()) - self.assertAlmostEqual(0.01, loop.time()) - - def test_get_timeout_cancelled(self): + def test_get_cancelled(self): def gen(): when = yield @@ -282,7 +252,7 @@ def gen(): @tasks.coroutine def queue_get(): - return (yield from q.get(timeout=0.05)) + return (yield from tasks.wait_for(q.get(), 0.05, loop=loop)) @tasks.coroutine def test(): @@ -351,47 +321,12 @@ def test_nonblocking_put_exception(self): q.put_nowait(1) self.assertRaises(queues.Full, q.put_nowait, 2) - def test_put_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - when = yield 0.01 - self.assertAlmostEqual(0.02, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.Queue(1, loop=loop) - q.put_nowait(0) - - @tasks.coroutine - def queue_put(): - with self.assertRaises(queues.Full): - return (yield from q.put(1, timeout=0.01)) - - self.assertEqual(0, q.get_nowait()) - - # Put works after timeout, with blocking and non-blocking get. - get_task = tasks.Task(q.get(), loop=loop) - # Let the get start waiting. - yield from tasks.sleep(0.01, loop=loop) - q.put_nowait(2) - self.assertEqual(2, (yield from get_task)) - - q.put_nowait(3) - self.assertEqual(3, q.get_nowait()) - - loop.run_until_complete(queue_put()) - self.assertAlmostEqual(0.02, loop.time()) - - def test_put_timeout_cancelled(self): + def test_put_cancelled(self): q = queues.Queue(loop=self.loop) @tasks.coroutine def queue_put(): - yield from q.put(1, timeout=0.01) + yield from q.put(1) return True @tasks.coroutine @@ -480,26 +415,6 @@ def join(): self.loop.run_until_complete(join()) - def test_join_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.JoinableQueue(loop=loop) - q.put_nowait(1) - - @tasks.coroutine - def join(): - yield from q.join(0.1) - - # Join completes in ~ 0.1 seconds, although no one calls task_done(). - loop.run_until_complete(join()) - self.assertAlmostEqual(0.1, loop.time()) - def test_format(self): q = queues.JoinableQueue(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') diff --git a/tulip/locks.py b/tulip/locks.py index 7c0a8f2a..de7f9156 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -1,4 +1,4 @@ -"""Synchronization primitives""" +"""Synchronization primitives.""" __all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] @@ -10,29 +10,31 @@ class Lock: - """The class implementing primitive lock objects. - - A primitive lock is a synchronization primitive that is not owned by - a particular coroutine when locked. A primitive lock is in one of two - states, "locked" or "unlocked". - It is created in the unlocked state. It has two basic methods, - acquire() and release(). When the state is unlocked, acquire() changes - the state to locked and returns immediately. When the state is locked, - acquire() blocks until a call to release() in another coroutine changes - it to unlocked, then the acquire() call resets it to locked and returns. - The release() method should only be called in the locked state; it changes - the state to unlocked and returns immediately. If an attempt is made - to release an unlocked lock, a RuntimeError will be raised. - - When more than one coroutine is blocked in acquire() waiting for the state - to turn to unlocked, only one coroutine proceeds when a release() call - resets the state to unlocked; first coroutine which is blocked in acquire() - is being processed. - - acquire() method is a coroutine and should be called with "yield from" - - Locks also support the context manager protocol. (yield from lock) should - be used as context manager expression. + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. Usage: @@ -51,7 +53,7 @@ class Lock: with (yield from lock): ... - Lock object could be tested for locking state: + Lock objects can be tested for locking state: if not lock.locked(): yield from lock @@ -79,37 +81,24 @@ def locked(self): return self._locked @tasks.coroutine - def acquire(self, timeout=None): + def acquire(self): """Acquire a lock. - Acquire method blocks until the lock is unlocked, then set it to - locked and return True. - - When invoked with the floating-point timeout argument set, blocks for - at most the number of seconds specified by timeout and as long as - the lock cannot be acquired. - - The return value is True if the lock is acquired successfully, - False if not (for example if the timeout expired). + This method blocks until the lock is unlocked, then sets it to + locked and returns True. """ if not self._waiters and not self._locked: self._locked = True return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + self._locked = True + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - - self._locked = True - return True def release(self): """Release a lock. @@ -187,39 +176,23 @@ def clear(self): self._value = False @tasks.coroutine - def wait(self, timeout=None): - """Block until the internal flag is true. If the internal flag - is true on entry, return immediately. Otherwise, block until another - coroutine calls set() to set the flag to true, or until the optional - timeout occurs. - - When the timeout argument is present and not None, it should be - a floating point number specifying a timeout for the operation in - seconds (or fractions thereof). - - This method returns true if and only if the internal flag has been - set to true, either before the wait call or after the wait starts, - so it will always return True except if a timeout is given and - the operation times out. - - wait() method is a coroutine. + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. """ if self._value: return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - - return True class Condition(Lock): @@ -236,71 +209,49 @@ def __init__(self, *, loop=None): self._condition_waiters = collections.deque() @tasks.coroutine - def wait(self, timeout=None): - """Wait until notified or until a timeout occurs. If the calling - coroutine has not acquired the lock when this method is called, - a RuntimeError is raised. - - This method releases the underlying lock, and then blocks until it is - awakened by a notify() or notify_all() call for the same condition - variable in another coroutine, or until the optional timeout occurs. - Once awakened or timed out, it re-acquires the lock and returns. + def wait(self): + """Wait until notified. - When the timeout argument is present and not None, it should be - a floating point number specifying a timeout for the operation - in seconds (or fractions thereof). + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. - The return value is True unless a given timeout expired, in which - case it is False. + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. """ if not self._locked: raise RuntimeError('cannot wait on un-acquired lock') - self.release() - - fut = futures.Future(loop=self._loop, timeout=timeout) - - self._condition_waiters.append(fut) keep_lock = True + self.release() try: - yield from fut - except futures.CancelledError: - self._condition_waiters.remove(fut) - return False + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) except GeneratorExit: keep_lock = False # Prevent yield in finally clause. raise - else: - f = self._condition_waiters.popleft() - assert fut is f finally: if keep_lock: yield from self.acquire() - return True - @tasks.coroutine - def wait_for(self, predicate, timeout=None): - """Wait until a condition evaluates to True. predicate should be a - callable which result will be interpreted as a boolean value. A timeout - may be provided giving the maximum time to wait. + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. """ - endtime = None - waittime = timeout result = predicate() - while not result: - if waittime is not None: - if endtime is None: - endtime = self._loop.time() + waittime - else: - waittime = endtime - self._loop.time() - if waittime <= 0: - break - - yield from self.wait(waittime) + yield from self.wait() result = predicate() - return result def notify(self, n=1): @@ -381,17 +332,14 @@ def locked(self): return self._locked @tasks.coroutine - def acquire(self, timeout=None): - """Acquire a semaphore. acquire() method is a coroutine. - - When invoked without arguments: if the internal counter is larger - than zero on entry, decrement it by one and return immediately. - If it is zero on entry, block, waiting until some other coroutine has - called release() to make it larger than zero. - - When invoked with a timeout other than None, it will block for at - most timeout seconds. If acquire does not complete successfully in - that interval, return false. Return true otherwise. + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. """ if not self._waiters and self._value > 0: self._value -= 1 @@ -399,22 +347,17 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - self._value -= 1 - if self._value == 0: - self._locked = True - return True def release(self): """Release a semaphore, incrementing the internal counter by one. diff --git a/tulip/queues.py b/tulip/queues.py index 8214d0ec..244d856d 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -105,14 +105,11 @@ def full(self): return self.qsize() == self._maxsize @coroutine - def put(self, item, timeout=None): + def put(self, item): """Put an item into the queue. - If you yield from put() and timeout is None (the default), wait until a - free slot is available before adding item. - - If a timeout is provided, raise Full if no free slot becomes - available before the timeout. + If you yield from put(), wait until a free slot is available + before adding item. """ self._consume_done_getters(self._getters) if self._getters: @@ -127,7 +124,7 @@ def put(self, item, timeout=None): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - waiter = futures.Future(loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) try: @@ -161,14 +158,10 @@ def put_nowait(self, item): self._put(item) @coroutine - def get(self, timeout=None): + def get(self): """Remove and return an item from the queue. - If you yield from get() and timeout is None (the default), wait until a - item is available. - - If a timeout is provided, raise Empty if no item is available - before the timeout. + If you yield from get(), wait until a item is available. """ self._consume_done_putters() if self._putters: @@ -187,7 +180,7 @@ def get(self, timeout=None): elif self.qsize(): return self._get() else: - waiter = futures.Future(loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop) self._getters.append(waiter) try: @@ -286,7 +279,7 @@ def task_done(self): self._finished.set() @coroutine - def join(self, timeout=None): + def join(self): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the @@ -295,4 +288,4 @@ def join(self, timeout=None): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait(timeout=timeout) + yield from self._finished.wait() From 556e288d2ca770b05eb476c10b8690d2ac078f89 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 29 Aug 2013 21:08:16 -0700 Subject: [PATCH 0583/1502] Eradicate timeout args from Future(), Task(), async(). --- tests/events_test.py | 2 +- tests/futures_test.py | 9 ---- tests/tasks_test.py | 103 ------------------------------------------ tulip/futures.py | 28 +----------- tulip/http/client.py | 9 +++- tulip/tasks.py | 26 ++++++----- 6 files changed, 25 insertions(+), 152 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 15e82fb8..240518c0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -939,7 +939,7 @@ def main(): return res start = time.monotonic() - t = tasks.Task(main(), timeout=1, loop=self.loop) + t = tasks.Task(main(), loop=self.loop) self.loop.run_forever() elapsed = time.monotonic() - start diff --git a/tests/futures_test.py b/tests/futures_test.py index c7228c00..7c2abd18 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -132,15 +132,6 @@ def test_repr(self): self.assertIn('<18 more>', r) f_many_callbacks.cancel() - f_pending = futures.Future(loop=self.loop, timeout=10) - self.assertEqual('Future{timeout=10, when=10}', - repr(f_pending)) - f_pending.cancel() - - f_pending = futures.Future(loop=self.loop, timeout=10) - f_pending.cancel() - self.assertEqual('Future{timeout=10}', repr(f_pending)) - def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 15056d96..d8457f4a 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -274,63 +274,6 @@ def task(): self.assertTrue(fut3.cancelled()) self.assertTrue(t.cancelled()) - def test_future_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def coro(): - yield from tasks.sleep(10.0, loop=loop) - return 12 - - t = tasks.Task(coro(), timeout=0.1, loop=loop) - - self.assertRaises( - futures.CancelledError, - loop.run_until_complete, t) - self.assertTrue(t.done()) - self.assertFalse(t.cancel()) - self.assertAlmostEqual(0.1, loop.time()) - - def test_future_timeout_catch(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def coro(): - yield from tasks.sleep(10.0, loop=loop) - return 12 - - class Cancelled(Exception): - pass - - @tasks.coroutine - def coro2(): - try: - yield from tasks.Task(coro(), timeout=0.1, loop=loop) - except futures.CancelledError: - raise Cancelled() - - self.assertRaises( - Cancelled, loop.run_until_complete, coro2()) - self.assertAlmostEqual(0.1, loop.time()) - def test_cancel_in_coro(self): @tasks.coroutine def task(): @@ -381,52 +324,6 @@ def task(): for w in waiters: w.close() - def test_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def task(): - yield from tasks.sleep(10.0, loop=loop) - return 42 - - t = tasks.Task(task(), loop=loop, timeout=0.1) - self.assertRaises( - futures.CancelledError, loop.run_until_complete, t) - self.assertAlmostEqual(0.1, loop.time()) - self.assertTrue(t.cancelled()) - - def test_timeout_not(self): - - def gen(): - when = yield - self.assertAlmostEqual(10.0, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def task(): - yield from tasks.sleep(0.1, loop=loop) - return 42 - - t = tasks.Task(task(), loop=loop, timeout=10.0) - r = loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(r, 42) - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_for(self): def gen(): diff --git a/tulip/futures.py b/tulip/futures.py index 8593e9ae..706e8c8a 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -1,7 +1,7 @@ """A Future class similar to the one in PEP 3148.""" __all__ = ['CancelledError', 'TimeoutError', - 'InvalidStateError', 'InvalidTimeoutError', + 'InvalidStateError', 'Future', 'wrap_future', ] @@ -30,11 +30,6 @@ class InvalidStateError(Error): # TODO: Show the future, its state, the method, and the required state. -class InvalidTimeoutError(Error): - """Called result() or exception() with timeout != 0.""" - # TODO: Print a nice error message. - - class _TracebackLogger: """Helper to log a traceback upon destruction if not cleared. @@ -129,15 +124,13 @@ class Future: _state = _PENDING _result = None _exception = None - _timeout = None - _timeout_handle = None _loop = None _blocking = False # proper use of future (yield vs yield from) _tb_logger = None - def __init__(self, *, loop=None, timeout=None): + def __init__(self, *, loop=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -150,10 +143,6 @@ def __init__(self, *, loop=None, timeout=None): self._loop = loop self._callbacks = [] - if timeout is not None: - self._timeout = timeout - self._timeout_handle = self._loop.call_later(timeout, self.cancel) - def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -171,14 +160,6 @@ def __repr__(self): res += '<{}, {}>'.format(self._state, self._callbacks) else: res += '<{}>'.format(self._state) - dct = {} - if self._timeout is not None: - dct['timeout'] = self._timeout - if self._timeout_handle is not None: - dct['when'] = self._timeout_handle._when - if dct: - res += '{' + ', '.join('{}={}'.format(k, dct[k]) - for k in sorted(dct)) + '}' return res def cancel(self): @@ -200,11 +181,6 @@ def _schedule_callbacks(self): The callbacks are scheduled to be called as soon as possible. Also clears the callback list. """ - # Cancel timeout handle - if self._timeout_handle is not None: - self._timeout_handle.cancel() - self._timeout_handle = None - callbacks = self._callbacks[:] if not callbacks: return diff --git a/tulip/http/client.py b/tulip/http/client.py index 2aedfdd1..ec7cd034 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -95,10 +95,17 @@ def request(method, url, *, conn = session.start(req, loop) # connection timeout + t = tulip.Task(conn, loop=loop) + th = None + if timeout is not None: + th = loop.call_later(timeout, t.cancel) try: - resp = yield from tulip.Task(conn, timeout=timeout, loop=loop) + resp = yield from t except tulip.CancelledError: raise tulip.TimeoutError from None + finally: + if th is not None: + th.cancel() # redirects if resp.status in (301, 302) and allow_redirects: diff --git a/tulip/tasks.py b/tulip/tasks.py index ca513a10..3fa9d25c 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -72,9 +72,9 @@ def task_wrapper(*args, **kwds): class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro, *, loop=None, timeout=None): + def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__(loop=loop, timeout=timeout) + super().__init__(loop=loop) self._coro = coro self._fut_waiter = None self._must_cancel = False @@ -227,11 +227,11 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): @coroutine def wait_for(fut, timeout, *, loop=None): - """Wait for the single Future or coroutine to complete. + """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. - Returns result of the Future or coroutine. Raises TimeoutError when + Returns result of the Future or coroutine. Raises TimeoutError when timeout occurs. Usage: @@ -259,7 +259,10 @@ def _wait(fs, timeout, return_when, loop): The timeout argument is like for wait(). """ assert fs, 'Set of Futures is empty.' - waiter = futures.Future(loop=loop, timeout=timeout) + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, waiter.cancel) counter = len(fs) def _on_completion(f): @@ -269,6 +272,8 @@ def _on_completion(f): return_when == FIRST_COMPLETED or return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() waiter.cancel() for f in fs: @@ -341,19 +346,16 @@ def sleep(delay, result=None, *, loop=None): h.cancel() -def async(coro_or_future, *, loop=None, timeout=None): +def async(coro_or_future, *, loop=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. """ if isinstance(coro_or_future, futures.Future): - if ((loop is not None and loop is not coro_or_future._loop) or - (timeout is not None and timeout != coro_or_future._timeout)): - raise ValueError( - 'loop and timeout arguments must agree with Future') - + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') return coro_or_future elif iscoroutine(coro_or_future): - return Task(coro_or_future, loop=loop, timeout=timeout) + return Task(coro_or_future, loop=loop) else: raise TypeError('A Future or coroutine is required') From 964de6f4fa50489267174a749f0e3db679ff6bc3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 30 Aug 2013 20:12:14 -0700 Subject: [PATCH 0584/1502] Change Task.cancel() behavior according to new spec. Had to disable a few tests. --- tests/http_server_test.py | 1 + tests/tasks_test.py | 98 +++++++++++++++++------------------ tulip/tasks.py | 105 +++++++++++++++----------------------- 3 files changed, 91 insertions(+), 113 deletions(-) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 862779b9..a9d4d5ed 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -69,6 +69,7 @@ def test_connection_lost(self): handle = srv._request_handler srv.connection_lost(None) + test_utils.run_briefly(self.loop) self.assertIsNone(srv._request_handler) self.assertTrue(handle.cancelled()) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index d8457f4a..022f8142 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -244,47 +244,47 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) - def test_cancel_done_future(self): - fut1 = futures.Future(loop=self.loop) - fut2 = futures.Future(loop=self.loop) - fut3 = futures.Future(loop=self.loop) - - @tasks.coroutine - def task(): - yield from fut1 - try: - yield from fut2 - except futures.CancelledError: - pass - yield from fut3 - - t = tasks.Task(task(), loop=self.loop) - test_utils.run_briefly(self.loop) - fut1.set_result(None) - t.cancel() - test_utils.run_once(self.loop) # process fut1 result, delay cancel - self.assertFalse(t.done()) - test_utils.run_once(self.loop) # cancel fut2, but coro still alive - self.assertFalse(t.done()) - test_utils.run_briefly(self.loop) # cancel fut3 - self.assertTrue(t.done()) - - self.assertEqual(fut1.result(), None) - self.assertTrue(fut2.cancelled()) - self.assertTrue(fut3.cancelled()) - self.assertTrue(t.cancelled()) - - def test_cancel_in_coro(self): - @tasks.coroutine - def task(): - t.cancel() - return 12 - - t = tasks.Task(task(), loop=self.loop) - self.assertRaises( - futures.CancelledError, self.loop.run_until_complete, t) - self.assertTrue(t.done()) - self.assertFalse(t.cancel()) +## def test_cancel_done_future(self): +## fut1 = futures.Future(loop=self.loop) +## fut2 = futures.Future(loop=self.loop) +## fut3 = futures.Future(loop=self.loop) + +## @tasks.coroutine +## def task(): +## yield from fut1 +## try: +## yield from fut2 +## except futures.CancelledError: +## pass +## yield from fut3 + +## t = tasks.Task(task(), loop=self.loop) +## test_utils.run_briefly(self.loop) +## fut1.set_result(None) +## t.cancel() +## test_utils.run_once(self.loop) # process fut1 result, delay cancel +## self.assertFalse(t.done()) +## test_utils.run_once(self.loop) # cancel fut2, but coro still alive +## self.assertFalse(t.done()) +## test_utils.run_briefly(self.loop) # cancel fut3 +## self.assertTrue(t.done()) + +## self.assertEqual(fut1.result(), None) +## self.assertTrue(fut2.cancelled()) +## self.assertTrue(fut3.cancelled()) +## self.assertTrue(t.cancelled()) + +## def test_cancel_in_coro(self): +## @tasks.coroutine +## def task(): +## t.cancel() +## return 12 + +## t = tasks.Task(task(), loop=self.loop) +## self.assertRaises( +## futures.CancelledError, self.loop.run_until_complete, t) +## self.assertTrue(t.done()) +## self.assertFalse(t.cancel()) def test_stop_while_run_in_complete(self): @@ -905,16 +905,14 @@ def test_task_cancel_waiter_future(self): @tasks.coroutine def coro(): - try: - yield from fut - except futures.CancelledError: - pass + yield from fut task = tasks.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) task.cancel() + test_utils.run_briefly(self.loop) self.assertRaises( futures.CancelledError, self.loop.run_until_complete, task) self.assertIsNone(task._fut_waiter) @@ -996,12 +994,14 @@ def gen(): def sleeper(): yield from tasks.sleep(10, loop=loop) + base_exc = BaseException() + @tasks.coroutine def notmutch(): try: yield from sleeper() except futures.CancelledError: - raise BaseException() + raise base_exc task = tasks.Task(notmutch(), loop=loop) test_utils.run_briefly(loop) @@ -1012,7 +1012,8 @@ def notmutch(): self.assertRaises(BaseException, test_utils.run_briefly, loop) self.assertTrue(task.done()) - self.assertTrue(task.cancelled()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) def test_iscoroutinefunction(self): def fn(): @@ -1040,8 +1041,7 @@ def wait_for_future(): with self.assertRaises(RuntimeError) as cm: self.loop.run_until_complete(task) - self.assertTrue(fut.done()) - self.assertIs(fut.exception(), cm.exception) + self.assertFalse(fut.done()) def test_yield_vs_yield_from_generator(self): @tasks.coroutine diff --git a/tulip/tasks.py b/tulip/tasks.py index 3fa9d25c..4e3ea551 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -66,9 +66,6 @@ def task_wrapper(*args, **kwds): return task_wrapper -_marker = object() - - class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -93,36 +90,25 @@ def __repr__(self): return res def cancel(self): - if self.done() or self._must_cancel: + if self.done(): return False - self._must_cancel = True - # _step() will call super().cancel() to call the callbacks. if self._fut_waiter is not None: - return self._fut_waiter.cancel() - else: - self._loop.call_soon(self._step_maybe) - return True - - def cancelled(self): - return self._must_cancel or super().cancelled() - - def _step_maybe(self): - # Helper for cancel(). - if not self.done(): - return self._step() + if self._fut_waiter.cancel(): + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True - def _step(self, value=_marker, exc=None): + def _step(self, value=None, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - - # We'll call either coro.throw(exc) or coro.send(value). - # Task cancel has to be delayed if current waiter future is done. - if self._must_cancel and exc is None and value is _marker: - exc = futures.CancelledError - + if self._must_cancel: + assert self._fut_waiter is None + exc = futures.CancelledError() + value = None coro = self._coro - value = None if value is _marker else value self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). try: if exc is not None: result = coro.throw(exc) @@ -131,53 +117,44 @@ def _step(self, value=_marker, exc=None): else: result = next(coro) except StopIteration as exc: - if self._must_cancel: - super().cancel() - else: - self.set_result(exc.value) + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). except Exception as exc: - if self._must_cancel: - super().cancel() - else: - self.set_exception(exc) + self.set_exception(exc) except BaseException as exc: - if self._must_cancel: - super().cancel() - else: - self.set_exception(exc) + self.set_exception(exc) raise else: if isinstance(result, futures.Future): - if not result._blocking: - result.set_exception( + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + else: + self._loop.call_soon( + self._step, None, RuntimeError( 'yield was used instead of yield from ' 'in task {!r} with {!r}'.format(self, result))) - - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result - - # task cancellation has been delayed. - if self._must_cancel: - self._fut_waiter.cancel() - + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) else: - if inspect.isgenerator(result): - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) - else: - if result is not None: - self._loop.call_soon( - self._step, None, - RuntimeError( - 'Task got bad yield: {!r}'.format(result))) - else: - self._loop.call_soon(self._step_maybe) + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) self = None def _wakeup(self, future): From 01e947bc98ffd734a368c12266921e8e27ecb36c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 31 Aug 2013 11:17:45 -0700 Subject: [PATCH 0585/1502] Fix race in Lock.release(). Still need to do other lock classes. --- tests/locks_test.py | 61 +++++++++++++++++++++++++++++++++++++++++++++ tests/tasks_test.py | 51 ------------------------------------- tulip/locks.py | 23 +++++++++++++---- tulip/queues.py | 5 ++-- tulip/tasks.py | 19 +------------- 5 files changed, 82 insertions(+), 77 deletions(-) diff --git a/tests/locks_test.py b/tests/locks_test.py index 65ea0487..af65c97c 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -126,6 +126,67 @@ def test_acquire_cancel(self): self.loop.run_until_complete, task) self.assertFalse(lock._waiters) + def test_cancel_race(self): + # XXX replace assert with self.assertXXX; remove dprint(). + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + def dprint(*args): pass # or 'dprint = print' + lock = locks.Lock(loop=self.loop) + @tasks.coroutine + def lockit(name, blocker): + dprint(name, 'acquiring...') + yield from lock.acquire() + dprint(name, 'acquired') + try: + if blocker is not None: + dprint(name, 'blocking...') + yield from blocker + dprint(name, 'unlocked') + finally: + dprint(name, 'releasing...') + lock.release() + dprint(name, 'released') + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + assert lock.locked() + fb = futures.Future(loop=self.loop) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + assert len(lock._waiters) == 1 + fc = futures.Future(loop=self.loop) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + assert len(lock._waiters) == 2 + + # Create the race and check. + # Without the fix this failed at the last assert. + dprint('---create the race---') + fa.set_result(None) + tb.cancel() + assert lock._waiters[0].cancelled() + dprint(tb, lock) + dprint(fa, lock) + assert lock._waiters[0].cancelled() + test_utils.run_briefly(self.loop) + dprint(lock) + assert not lock.locked() + assert ta.done() + assert tb.cancelled() + assert tc.done() + def test_release_not_acquired(self): lock = locks.Lock(loop=self.loop) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 022f8142..8c26e3f9 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -44,57 +44,6 @@ def notmuch(): self.assertIs(t._loop, loop) loop.close() - def test_task_decorator(self): - @tasks.task - def notmuch(): - yield from [] - return 'ko' - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - - def test_task_decorator_func(self): - @tasks.task - def notmuch(): - return 'ko' - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - - def test_task_decorator_fut(self): - @tasks.task - def notmuch(): - fut = futures.Future(loop=self.loop) - fut.set_result('ko') - return fut - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - def test_async_coroutine(self): @tasks.coroutine def notmuch(): diff --git a/tulip/locks.py b/tulip/locks.py index de7f9156..87937ec0 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -73,8 +73,10 @@ def __init__(self, *, loop=None): def __repr__(self): res = super().__repr__() - return '<{} [{}]>'.format( - res[1:-1], 'locked' if self._locked else 'unlocked') + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) def locked(self): """Return true if lock is acquired.""" @@ -113,8 +115,11 @@ def release(self): """ if self._locked: self._locked = False - if self._waiters: - self._waiters[0].set_result(True) + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.cancelled(): + fut.set_result(True) + break else: raise RuntimeError('Lock is not acquired.') @@ -132,6 +137,7 @@ def __iter__(self): return self +# TODO: Why not call this Event? class EventWaiter: """A EventWaiter implementation, our equivalent to threading.Event @@ -150,6 +156,7 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): + # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') @@ -195,6 +202,7 @@ def wait(self): self._waiters.remove(fut) +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. class Condition(Lock): """A Condition implementation. @@ -205,9 +213,10 @@ class Condition(Lock): def __init__(self, *, loop=None): super().__init__(loop=loop) - self._condition_waiters = collections.deque() + # TODO: Add __repr__() with len(_condition_waiters). + @tasks.coroutine def wait(self): """Wait until notified. @@ -233,6 +242,7 @@ def wait(self): return True finally: self._condition_waiters.remove(fut) + except GeneratorExit: keep_lock = False # Prevent yield in finally clause. raise @@ -321,6 +331,7 @@ def __init__(self, value=1, bound=False, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): + # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format( res[1:-1], @@ -380,6 +391,8 @@ def release(self): break def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? return True def __exit__(self, *args): diff --git a/tulip/queues.py b/tulip/queues.py index 244d856d..4a46f1a2 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -4,7 +4,6 @@ 'Full', 'Empty'] import collections -import concurrent.futures import heapq import queue @@ -129,7 +128,7 @@ def put(self, item): self._putters.append((item, waiter)) try: yield from waiter - except concurrent.futures.CancelledError: + except futures.CancelledError: raise Full else: @@ -185,7 +184,7 @@ def get(self): self._getters.append(waiter) try: return (yield from waiter) - except concurrent.futures.CancelledError: + except futures.CancelledError: raise Empty def get_nowait(self): diff --git a/tulip/tasks.py b/tulip/tasks.py index 4e3ea551..a51ee29a 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'task', 'Task', +__all__ = ['coroutine', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', ] @@ -49,23 +49,6 @@ def iscoroutine(obj): return inspect.isgenerator(obj) # TODO: And what? -def task(func): - """Decorator for a coroutine to be wrapped in a Task.""" - if inspect.isgeneratorfunction(func): - coro = func - else: - def coro(*args, **kw): - res = func(*args, **kw) - if isinstance(res, futures.Future) or inspect.isgenerator(res): - res = yield from res - return res - - def task_wrapper(*args, **kwds): - return Task(coro(*args, **kwds)) - - return task_wrapper - - class Task(futures.Future): """A coroutine wrapped in a Future.""" From fd8633f3a8fba7a81e7b855760be4db8cbf8320b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 31 Aug 2013 14:31:45 -0700 Subject: [PATCH 0586/1502] Clean up test_cancel_race(). --- tests/locks_test.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/tests/locks_test.py b/tests/locks_test.py index af65c97c..9399d759 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -127,7 +127,6 @@ def test_acquire_cancel(self): self.assertFalse(lock._waiters) def test_cancel_race(self): - # XXX replace assert with self.assertXXX; remove dprint(). # Several tasks: # - A acquires the lock # - B is blocked in aqcuire() @@ -142,50 +141,38 @@ def test_cancel_race(self): # B's waiter; instead, it should move on to C's waiter. # Setup: A has the lock, b and c are waiting. - def dprint(*args): pass # or 'dprint = print' lock = locks.Lock(loop=self.loop) + @tasks.coroutine def lockit(name, blocker): - dprint(name, 'acquiring...') yield from lock.acquire() - dprint(name, 'acquired') try: if blocker is not None: - dprint(name, 'blocking...') yield from blocker - dprint(name, 'unlocked') finally: - dprint(name, 'releasing...') lock.release() - dprint(name, 'released') + fa = futures.Future(loop=self.loop) ta = tasks.Task(lockit('A', fa), loop=self.loop) test_utils.run_briefly(self.loop) - assert lock.locked() - fb = futures.Future(loop=self.loop) + self.assertTrue(lock.locked()) tb = tasks.Task(lockit('B', None), loop=self.loop) test_utils.run_briefly(self.loop) - assert len(lock._waiters) == 1 - fc = futures.Future(loop=self.loop) + self.assertEqual(len(lock._waiters), 1) tc = tasks.Task(lockit('C', None), loop=self.loop) test_utils.run_briefly(self.loop) - assert len(lock._waiters) == 2 + self.assertEqual(len(lock._waiters), 2) # Create the race and check. # Without the fix this failed at the last assert. - dprint('---create the race---') fa.set_result(None) tb.cancel() - assert lock._waiters[0].cancelled() - dprint(tb, lock) - dprint(fa, lock) - assert lock._waiters[0].cancelled() + self.assertTrue(lock._waiters[0].cancelled()) test_utils.run_briefly(self.loop) - dprint(lock) - assert not lock.locked() - assert ta.done() - assert tb.cancelled() - assert tc.done() + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) def test_release_not_acquired(self): lock = locks.Lock(loop=self.loop) From 2c8a720be0677ec6f20690c836f8bf65625f89c7 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 3 Sep 2013 09:31:09 -0700 Subject: [PATCH 0587/1502] test --- .hgeol | 4 + .hgignore | 12 + Makefile | 35 + NOTES | 176 +++ README | 21 + TODO | 163 +++ check.py | 41 + examples/child_process.py | 127 +++ examples/crawl.py | 104 ++ examples/curl.py | 24 + examples/mpsrv.py | 289 +++++ examples/srv.py | 163 +++ examples/tcp_echo.py | 113 ++ examples/tcp_protocol_parser.py | 170 +++ examples/udp_echo.py | 98 ++ examples/websocket.html | 90 ++ examples/wsclient.py | 97 ++ examples/wssrv.py | 309 +++++ overlapped.c | 1009 +++++++++++++++++ runtests.py | 265 +++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 591 ++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1574 ++++++++++++++++++++++++++ tests/futures_test.py | 317 ++++++ tests/http_client_functional_test.py | 552 +++++++++ tests/http_client_test.py | 289 +++++ tests/http_parser_test.py | 539 +++++++++ tests/http_protocol_test.py | 400 +++++++ tests/http_server_test.py | 301 +++++ tests/http_session_test.py | 139 +++ tests/http_websocket_test.py | 439 +++++++ tests/http_wsgi_test.py | 301 +++++ tests/locks_test.py | 765 +++++++++++++ tests/parsers_test.py | 598 ++++++++++ tests/proactor_events_test.py | 393 +++++++ tests/queues_test.py | 427 +++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1471 ++++++++++++++++++++++++ tests/selectors_test.py | 142 +++ tests/streams_test.py | 343 ++++++ tests/tasks_test.py | 1041 +++++++++++++++++ tests/transports_test.py | 59 + tests/unix_events_test.py | 818 +++++++++++++ tests/windows_events_test.py | 81 ++ tests/windows_utils_test.py | 132 +++ tulip/__init__.py | 28 + tulip/base_events.py | 592 ++++++++++ tulip/constants.py | 4 + tulip/events.py | 389 +++++++ tulip/futures.py | 338 ++++++ tulip/http/__init__.py | 16 + tulip/http/client.py | 572 ++++++++++ tulip/http/errors.py | 46 + tulip/http/protocol.py | 756 +++++++++++++ tulip/http/server.py | 215 ++++ tulip/http/session.py | 103 ++ tulip/http/websocket.py | 233 ++++ tulip/http/wsgi.py | 227 ++++ tulip/locks.py | 403 +++++++ tulip/log.py | 6 + tulip/parsers.py | 399 +++++++ tulip/proactor_events.py | 288 +++++ tulip/protocols.py | 100 ++ tulip/queues.py | 290 +++++ tulip/selector_events.py | 676 +++++++++++ tulip/selectors.py | 410 +++++++ tulip/streams.py | 211 ++++ tulip/tasks.py | 321 ++++++ tulip/test_utils.py | 443 ++++++++ tulip/transports.py | 201 ++++ tulip/unix_events.py | 555 +++++++++ tulip/windows_events.py | 203 ++++ tulip/windows_utils.py | 181 +++ 78 files changed, 23294 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100755 examples/crawl.py create mode 100755 examples/curl.py create mode 100755 examples/mpsrv.py create mode 100755 examples/srv.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/tcp_protocol_parser.py create mode 100755 examples/udp_echo.py create mode 100644 examples/websocket.html create mode 100755 examples/wsclient.py create mode 100755 examples/wssrv.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_parser_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_session_test.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/parsers_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/session.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/parsers.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/windows_utils.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..6064fc63 --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..8f2b6373 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..9ab6bcc0 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..d4a035bd --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100755 index 00000000..ac9c25e9 --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + + self.session.close() + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request( + 'get', url, session=self.session) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/examples/curl.py b/examples/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/examples/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100755 index 00000000..6b1ebb8f --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) + x = loop.run_until_complete(f)[0] + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/examples/srv.py b/examples/srv.py new file mode 100755 index 00000000..e01e407c --- /dev/null +++ b/examples/srv.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('method = {!r}; path = {!r}; version = {!r}'.format( + message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, + ssl=sslcontext) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..a0258613 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100755 index 00000000..f5b2ef58 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""websocket cmd client for wssrv.py example.""" +import argparse +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket +import tulip.selectors + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop, url): + name = input('Please enter your name: ').encode() + + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) + + # input reader + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_TEXT: + print(msg.data.strip()) + elif msg.tp == websocket.MSG_CLOSE: + break + + yield from dispatch() + + +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + + loop.add_signal_handler(signal.SIGINT, loop.stop) + tulip.Task(start_client(loop, url)) + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100755 index 00000000..f96e0855 --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message.method, message.headers, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + socks = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, keep_alive=75, + parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), socks[0].getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..3a2c1208 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1009 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..484bff09 --- /dev/null +++ b/runtests.py @@ -0,0 +1,265 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + if args.forever: + while True: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + * curl -O \ + https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + - python3 ez_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: {}\n".format(sdir)) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..e27b3ab9 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,591 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + task = tasks.Task( + self.loop.create_connection(MyProto, 'example.com', 80)) + yield from tasks.wait(task) + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_start_serving_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.start_serving(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_start_serving_host_port_sock(self): + fut = self.loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..240518c0 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1574 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils +from tulip import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.EventWaiter(loop=loop), + 2: locks.EventWaiter(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.loop, use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server(self.loop) as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('socket').getsockname()[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.start_serving(factory, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close start_serving socks + self.loop.stop_serving(sock) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + self.loop.stop_serving(sock) + + def test_start_serving_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + def test_start_serving_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(MyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + f = self.loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + self.loop.stop_serving(sock) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + for s in socks: + self.loop.stop_serving(s) + + def test_stop_serving(self): + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + sock = socks[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop.stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.start_serving, f) + self.assertRaises( + NotImplementedError, loop.stop_serving, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('tulip.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..7c2abd18 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,317 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, *args): + fn(*args) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..91badfc4 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,552 @@ +"""Http client functional tests.""" + +import gc +import io +import os.path +import http.cookies +import unittest + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth), + loop=self.loop)) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + r.close() + + def test_use_global_loop(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + try: + tulip.set_event_loop(self.loop) + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'))) + finally: + tulip.set_event_loop(None) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "GET"', content) + self.assertEqual(content1, content2) + r.close() + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2), + loop=self.loop)) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'), + loop=self.loop)) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2, loop=self.loop)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'}, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'}, + loop=self.loop)) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate', + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'), + loop=self.loop)) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'), + loop=self.loop)) + self.assertEqual(r.status, 200) + r.close() + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), loop=self.loop, + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + r.close() + + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), loop=self.loop)) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + resp.close() + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'), loop=self.loop)) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + r.close() + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + timeout=0.1, loop=self.loop)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', + timeout=0.1, loop=self.loop)) + + def test_request_conn_closed(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['close'] = True + self.assertRaises( + tulip.http.HttpException, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + loop=self.loop)) + + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..1aa27244 --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpRequest, HttpResponse + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response.transport = self.transport + self.response.close() + self.assertIsNone(self.response.transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..6240ad49 --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,539 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_eof(self): + # http_request_parser does not fail on EofStream() + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'get /path HTTP/1.1\r\n') + try: + p.throw(tulip.EofStream()) + except StopIteration: + pass + self.assertFalse(out._buffer) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..ec3aaf58 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,400 @@ +"""Tests for http/protocol.py""" + +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + tulip.set_event_loop(None) + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200, close=True) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_keep_alive_http10(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + self.assertFalse(msg.keepalive) + self.assertFalse(msg.keep_alive()) + + msg = protocol.Response(self.transport, 200, http_version=(1, 1)) + self.assertIsNone(msg.keepalive) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], list(msg.headers)) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], list(msg.headers)) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + # cleanup + msg.writer.close() + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + # cleanup + msg.writer.close() + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + # cleanup + msg.writer.close() + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..a9d4d5ed --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,301 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip import test_utils + + +class HttpServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_http_error_exception(self): + exc = errors.HttpErrorException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertIsNone(srv._request_handler) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handler) + + def test_data_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', bytes(srv.stream._buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', bytes(srv.stream._buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream._eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + + handle = srv._request_handler + srv.connection_lost(None) + test_utils.run_briefly(self.loop) + + self.assertIsNone(srv._request_handler) + self.assertTrue(handle.cancelled()) + + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(keep_alive_handle.cancel.called) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handler) + self.assertIsNone(srv._keep_alive_handle) + + def test_srv_keep_alive(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertFalse(srv._keep_alive) + + srv.keep_alive(True) + self.assertTrue(srv._keep_alive) + + srv.keep_alive(False) + self.assertFalse(srv._keep_alive) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.keep_alive(True) + + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) + self.assertFalse(srv._keep_alive) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, loop=self.loop) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + + called = False + + @tulip.coroutine + def coro(message, payload): + nonlocal called + called = True + srv.eof_received() + + srv.handle_request = coro + srv.connection_made(transport) + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.coroutine + def cancel(): + srv._request_handler.cancel() + + self.loop.run_until_complete( + tulip.wait([srv._request_handler, cancel()], loop=self.loop)) + self.assertTrue(log.debug.called) + + def test_handle_cancelled(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + test_utils.run_briefly(self.loop) # start request_handler task + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + r_handler = srv._request_handler + srv._request_handler = None # emulate srv.connection_lost() + + self.assertIsNone(self.loop.run_until_complete(r_handler)) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + srv.keep_alive(True) + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.keep_alive(True) + srv.connection_made(transport) + srv.connection_lost(None) + + srv.handle_error(300) + self.assertFalse(srv._keep_alive) + + def test_keep_alive(self): + srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) + transport = unittest.mock.Mock() + closed = False + + def close(): + nonlocal closed + closed = True + srv.connection_lost(None) + self.loop.stop() + + transport.close = close + + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.1\r\n' + b'CONNECTION: keep-alive\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_forever() + self.assertTrue(handle.called) + self.assertTrue(closed) + + def test_keep_alive_close_existing(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) + srv.connection_made(transport) + + self.assertIsNone(srv._keep_alive_handle) + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(keep_alive_handle.cancel.called) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(transport.close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..39a80091 --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,139 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + +from tulip import test_utils + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + tulip.set_event_loop(None) + self.loop.close() + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..319538ae --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,439 @@ +"""Tests for http/websocket.py""" + +import base64 +import hashlib +import os +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket, protocol, errors + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_not_get(self): + self.assertRaises( + errors.HttpErrorException, + websocket.do_handshake, + 'POST', self.message.headers, self.transport) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message.method, self.message.headers, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..053f5a69 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,301 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol + + +class HttpWsgiServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() + + def tearDown(self): + self.loop.close() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + @unittest.mock.patch('tulip.http.wsgi.tulip') + def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '101 Switching Protocols', (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'))) + self.assertEqual(resp.status, '101 Switching Protocols') + self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future(loop=self.loop) + f1.set_result(b'data') + fut = tulip.Future(loop=self.loop) + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertFalse(srv._keep_alive) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_keep_alive(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, False, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertTrue(srv._keep_alive) + + def test_handle_request_readpayload(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [env['wsgi.input'].read()] + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..9399d759 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.EventWaiter(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.EventWaiter(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.EventWaiter() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.EventWaiter(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.EventWaiter(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..debc532c --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,598 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer(loop=self.loop) + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer(loop=self.loop) + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.skipuntil(b'\n') + try: + next(p) + except StopIteration: + pass + self.assertEqual(b'', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..da4dea35 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,393 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport +from tulip import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = tulip.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_start_serving_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = tulip.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop.stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor.stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..131812a4 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,427 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import queues +from tulip import tasks +from tulip import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith(')') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro(), loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + +## def test_cancel_done_future(self): +## fut1 = futures.Future(loop=self.loop) +## fut2 = futures.Future(loop=self.loop) +## fut3 = futures.Future(loop=self.loop) + +## @tasks.coroutine +## def task(): +## yield from fut1 +## try: +## yield from fut2 +## except futures.CancelledError: +## pass +## yield from fut3 + +## t = tasks.Task(task(), loop=self.loop) +## test_utils.run_briefly(self.loop) +## fut1.set_result(None) +## t.cancel() +## test_utils.run_once(self.loop) # process fut1 result, delay cancel +## self.assertFalse(t.done()) +## test_utils.run_once(self.loop) # cancel fut2, but coro still alive +## self.assertFalse(t.done()) +## test_utils.run_briefly(self.loop) # cancel fut3 +## self.assertTrue(t.done()) + +## self.assertEqual(fut1.result(), None) +## self.assertTrue(fut2.cancelled()) +## self.assertTrue(fut3.cancelled()) +## self.assertTrue(t.cancelled()) + +## def test_cancel_in_coro(self): +## @tasks.coroutine +## def task(): +## t.cancel() +## return 12 + +## t = tasks.Task(task(), loop=self.loop) +## self.assertRaises( +## futures.CancelledError, self.loop.run_until_complete, t) +## self.assertTrue(t.done()) +## self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + loop + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch(), loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..5920cda6 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,59 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import futures +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + self.assertRaises(NotImplementedError, transport.pause_writing) + self.assertRaises(NotImplementedError, transport.resume_writing) + self.assertRaises(NotImplementedError, transport.discard_output) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..f0b42a39 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,818 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import tempfile +import unittest +import unittest.mock + + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + def test_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_double_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_pause_resume_writing_with_nonempty_buffer(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.assertFalse(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + + tr.resume_writing() + self.assertTrue(tr._writing) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'da', b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_on_pause(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + + tr._write_ready() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + self.assertFalse(tr._writing) + + def test_discard_output(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + self.loop.add_writer(5, tr._write_ready) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + def test_discard_output_without_pending_writes(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..ce9b74da --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,81 @@ +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams + + +def connect_read_pipe(loop, file): + stream_reader = streams.StreamReader(loop=loop) + protocol = _StreamReaderProtocol(stream_reader) + loop._make_read_pipe_transport(file, protocol) + return stream_reader + + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_pause_resume_discard(self): + a, b = self.loop._socketpair() + trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) + reader = connect_read_pipe(self.loop, b) + f = tulip.async(reader.readline(), loop=self.loop) + + trans.write(b'msg1\n') + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg1\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg2\n') + with self.assertRaises(tulip.TimeoutError): + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(trans._buffer, [b'msg2\n']) + + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.1) + self.assertEqual(f.result(), b'msg2\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg3\n') + self.assertEqual(trans._buffer, [b'msg3\n']) + trans.discard_output() + self.assertEqual(trans._buffer, []) + + trans.write(b'msg4\n') + self.assertEqual(trans._buffer, [b'msg4\n']) + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg4\n') + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f, timeout=1) + self.assertEqual(f.result(), b'') diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..b23896d3 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils +from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..9de84cb0 --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,28 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .parsers import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + parsers.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..3bccfc83 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,592 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # This returns a Task made from self._start_serving_internal(). + # We want start_serving() to return a Task so that it will start + # running right away (when the event loop runs) even if the caller + # doesn't wait for it. Note that this is different from + # e.g. create_connection(), or create_datagram_endpoint(), which + # are a "mere" coroutines and require their caller to wait for + # them. The reason for the difference is that only + # start_serving() creates multiple transports and protocols. + def start_serving(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + coro = self._start_serving_internal(protocol_factory, host, port, + family=family, + flags=flags, + sock=sock, + backlog=backlog, + ssl=ssl, + reuse_address=reuse_address) + return tasks.Task(coro, loop=self) + + @tasks.coroutine + def _start_serving_internal(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..7db2514d --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,389 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """Creates a TCP server bound to host and port and return a + Task whose result will be a list of socket objects which will + later be handled by protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..706e8c8a --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import tulip_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + tulip_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..a1432dee --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,16 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .session import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + session.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..ec7cd034 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,572 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import functools +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +import tulip.http + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None, + session=None, + loop=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. + loop: Optional event loop. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + if loop is None: + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + + # connection timeout + t = tulip.Task(conn, loop=loop) + th = None + if timeout is not None: + th = loop.call_later(timeout, t.cancel) + try: + resp = yield from t + except tulip.CancelledError: + raise tulip.TimeoutError from None + finally: + if th is not None: + th.cancel() + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + + try: + resp = req.send(transport) + yield from resp.start(p, transport) + except: + transport.close() + raise + + return resp + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except ValueError: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except ValueError: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + self.update_cookies(cookies) + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + compress = enc + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = str(len(self.body)) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['transfer-encoding'] = 'chunked' + + chunked = chunked if type(chunked) is int else 8196 + else: + if 'chunked' in te: + chunked = 8196 + else: + chunked = None + self.headers['content-length'] = str(len(self.body)) + + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + message = None # RawResponseMessage object + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + cookies = None # Response cookies (Set-Cookie) + + content = None # payload stream + stream = None # input stream + transport = None # current transport + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + self._content = None + + def __del__(self): + self.close() + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self.stream = stream + self.transport = transport + + httpstream = stream.set_parser(tulip.http.http_response_parser()) + + # read response + self.message = yield from httpstream.read() + + # response status + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason + + # headers + for hdr, val in self.message.headers: + self.add_header(hdr, val) + + # payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) + + return self + + def close(self): + if self.transport is not None: + self.transport.close() + self.transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..f8b77e9b --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,46 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + message = '' + + +class HttpErrorException(HttpException): + + def __init__(self, code, message='', headers=None): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + message = 'Bad Request' + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..7081fd59 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,756 @@ +"""Http related helper utils.""" + +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from tulip.http import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() + +RESPONSES = http.server.BaseHTTPRequestHandler.responses + + +RawRequestMessage = collections.namedtuple( + 'RawRequestMessage', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) + + +RawResponseMessage = collections.namedtuple( + 'RawResponseMessage', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) + + +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield + + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + # request line + line = lines[0] + try: + method, path, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if version <= (1, 0): + close = True + elif close is None: + close = False + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None + + +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() + + lines_idx = 1 + line = lines[1] + + while line not in ('\r\n', '\n'): + header_length = len(line) + + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None + + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line[0] in CONTINUATION + + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') + + value = value.strip() + + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + headers.append((name, value)) + + return headers, close_conn, encoding + + +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield + + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) + + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) + + out.feed_eof() + + +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: # eof marker + break + + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) + + # toss the CRLF at the end of the chunk + yield from buf.skip(2) + + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) + + +class DeflateBuffer: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, out, encoding): + self.out = out + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except Exception: + raise errors.IncompleteRead(b'') from None + + if chunk: + self.out.feed_data(chunk) + + def feed_eof(self): + self.out.feed_data(self.zlib.flush()) + if not self.zlib.eof: + raise errors.IncompleteRead(b'') + + self.out.feed_eof() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + upgrade = False # Connection: UPGRADE + websocket = False # Upgrade: WEBSOCKET + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + + # disable keep-alive for http/1.0 + if version <= (1, 0): + self.keepalive = False + else: + self.keepalive = None + + self.chunked = False + self.length = None + self.headers = collections.deque() + self.headers_sent = False + + def force_close(self): + self.closing = True + self.keepalive = False + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + if self.keepalive is None: + return not self.closing + else: + return self.keepalive + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower() + # handle websocket + if 'upgrade' in val: + self.upgrade = True + # connection keep-alive + elif 'close' in val: + self.keepalive = False + elif 'keep-alive' in val and self.version >= (1, 1): + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.websocket = True + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + self._add_default_headers() + + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) + + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif not self.closing if self.keepalive is None else self.keepalive: + connection = 'keep-alive' + else: + connection = 'close' + + if self.chunked: + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) + + self.headers.appendleft(('CONNECTION', connection)) + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(tulip.EofStream()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except tulip.EofStream: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except tulip.EofStream: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except tulip.EofStream: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, path, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.path = path + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, path, http_version) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..fc5621c5 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,215 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +from tulip.http import errors + + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + log: custom logging object + debug: enable debug mode + keep_alive: number of seconds before closing keep alive connection + loop: event loop object + """ + _request_count = 0 + _request_handler = None + _keep_alive = False # keep transport open + _keep_alive_handle = None # keep alive timer handle + + def __init__(self, *, log=logging, debug=False, + keep_alive=None, loop=None, **kwargs): + self.__dict__.update(kwargs) + self.log = log + self.debug = debug + + self._keep_alive_period = keep_alive # number of seconds to keep alive + + if keep_alive and loop is None: + loop = tulip.get_event_loop() + self._loop = loop + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.StreamBuffer(loop=self._loop) + self._request_handler = tulip.Task(self.start(), loop=self._loop) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + self.stream.feed_eof() + + if self._request_handler is not None: + self._request_handler.cancel() + self._request_handler = None + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + def keep_alive(self, val): + self._keep_alive = val + + def log_access(self, status, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.coroutine + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. + """ + + while True: + info = None + message = None + self._request_count += 1 + self._keep_alive = False + + try: + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) + + message = yield from httpstream.read() + + # cancel keep-alive timer + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._request_handler: + if self._keep_alive and self._keep_alive_period: + self._keep_alive_handle = self._loop.call_later( + self._keep_alive_period, self.transport.close) + else: + self.transport.close() + self._request_handler = None + break + else: + break + + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + try: + if self._request_handler is None: + # client has been disconnected during writing. + return + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.keep_alive(False) + + def handle_request(self, message, payload): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=message.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.keep_alive(False) + self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..9cdd9cea --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,103 @@ +"""client session support.""" + +__all__ = ['Session'] + +import functools +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..c3dd5872 --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,233 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import base64 +import binascii +import collections +import hashlib +import struct +from tulip.http import errors + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) + + +def do_handshake(method, headers, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + perform any IO.""" + + # WebSocket accepts only GET + if method.upper() != 'GET': + raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) + + headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException( + 'No WebSocket UPGRADE hdr: {}\n' + 'Can "Upgrade" only to "WebSocket".'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8', '7'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..738e100f --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,227 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) + + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, message, payload): + """Handle a single HTTP request""" + + if self.readpayload: + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + wsgiinput.seek(0) + payload = wsgiinput + + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if resp.keep_alive(): + self.keep_alive(True) + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, message): + self.transport = transport + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + resp = self.response = tulip.http.Response( + self.transport, status_code, + self.message.version, self.message.should_close) + resp.add_headers(*headers) + + # send headers immediately for websocket connection + if status_code == 101 and resp.upgrade and resp.websocket: + resp.send_headers() + else: + resp._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..87937ec0 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,403 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.cancelled(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +# TODO: Why not call this Event? +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..43ddc2e9 --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,399 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = ParserBuffer() + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer(loop=self._loop) + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + if self._parser is None: + return + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.done(): + waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(False) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def _shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len + size = end - self.offset + if limit is not None and size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - (end - self.offset) + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..cda87918 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,288 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import tulip_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._writing_disabled = False + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, waiter, extra) + self._loop.call_soon(self._loop_reading) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + try: + self._protocol.eof_received() + finally: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None and not self._writing_disabled: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + if not self._writing_disabled: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except OSError as exc: + self._fatal_error(exc) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._force_close(None) + + def pause_writing(self): + self._writing_disabled = True + + def resume_writing(self): + self._writing_disabled = False + if self._buffer and self._write_fut is None: + self._loop_writing() + + def discard_output(self): + if self._buffer: + self._buffer = [] + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def stop_serving(self, sock): + self._proactor.stop_serving(sock) + sock.close() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..d76f25a2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,100 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..4a46f1a2 --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,290 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self, waiters): + # Delete waiters at the head of the get() queue who've timed out. + while waiters and waiters[0].done(): + waiters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + try: + yield from waiter + except futures.CancelledError: + raise Full + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters(self._getters) + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + try: + return (yield from waiter) + except futures.CancelledError: + raise Empty + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..82d22bb6 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,676 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + def __init__(self, loop, sock, protocol, extra): + super().__init__(extra) + self._extra['socket'] = sock + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._writing = True + self._closing = False # Set when close() called. + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_writer(self._sock_fd) + self._loop.remove_reader(self._sock_fd) + self._buffer.clear() + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except OSError as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return # transmission off + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, extra=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + + super().__init__(loop, sslsock, protocol, extra) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing: + try: + data = self._sock.recv(8192) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + # TODO: write_eof(), can_write_eof(). + + +class _SelectorDatagramTransport(_SelectorTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._buffer = collections.deque() + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..b81b1dbe --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,410 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class BaseSelector(metaclass=ABCMeta): + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key.pop(fileobj) + del self._fd_to_key[key.fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + try: + return self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) from None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(int(1000 * timeout), 0) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..3203b7d6 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,211 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + self.limit = limit # Max line length. (Security feature.) + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..a51ee29a --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,321 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + ] + +import collections +import concurrent.futures +import functools +import inspect + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, *, loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(loop=loop) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + assert self._fut_waiter is None + exc = futures.CancelledError() + value = None + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(return_when=FIRST_COMPLETED). + + The fs argument must be a set of Futures. + The timeout argument is like for wait(). + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, waiter.cancel) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + waiter.cancel() + + for f in fs: + f.add_done_callback(_on_completion) + try: + yield from waiter + except futures.CancelledError: + pass + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..b4af0c89 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,443 @@ +"""Utilities shared by tests.""" + +import cgi +import collections +import contextlib +import gc +import email.parser +import http.server +import json +import logging +import io +import unittest.mock +import os +import re +import socket +import sys +import threading +import traceback +import unittest +import unittest.mock +import urllib.parse +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client +from tulip import base_events +from tulip import events + +from tulip import base_events +from tulip import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def run_briefly(loop): + @tulip.coroutine + def once(): + pass + t = tulip.Task(once(), loop=loop) + loop.run_until_complete(t) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + transports = [] + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + def handle_request(self, message, payload): + if properties.get('close', False): + return + + if properties.get('noresponse', False): + yield from tulip.sleep(99999) + + if router is not None: + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router( + self, properties, + self.transport, message, bytes(body)) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + socks = thread_loop.run_until_complete( + thread_loop.start_serving( + lambda: TestHttpServer(keep_alive=0.5), + host, port, ssl=sslcontext)) + + waiter = tulip.Future(loop=thread_loop) + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) + + try: + thread_loop.run_until_complete(waiter) + finally: + # call pending connection_made if present + run_briefly(thread_loop) + + # close opened trnsports + for tr in transports: + tr.close() + + run_briefly(thread_loop) # call close callbacks + + for s in socks: + thread_loop.stop_serving(s) + + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = tulip.Future(loop=loop) + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, srv, props, transport, message, payload): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in message.headers: + self._headers.add_header(hdr, val) + + self._srv = srv + self._props = props + self._transport = transport + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except Exception: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() + + # keep-alive + if response.keep_alive(): + self._srv.keep_alive(True) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + self._check_on_close = False + def gen(): + yield + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..56425aa9 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,201 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause_writing(self): + """Pause transmission on the transport. + + Subsequent writes are deferred until resume_writing() is called. + """ + raise NotImplementedError + + def resume_writing(self): + """Resume transmission on the transport. """ + raise NotImplementedError + + def discard_output(self): + """Discard any buffered data awaiting transmission on the transport.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..75131851 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,555 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + tulip_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._writing = True + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._fileno) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._fileno, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..629b3475 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,203 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import windows_utils +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(handle, nbytes) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(handle, buf) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # the socket needs to be locally bound before we call ConnectEx() + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # probably already locally bound; check using getsockname() + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = _OverlappedFuture(ov, loop=self._loop) + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop.stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py new file mode 100644 index 00000000..bf85f31e --- /dev/null +++ b/tulip/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter=itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) From ead782e867f8384d109c211ec6135fa62e7bf0cf Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 3 Sep 2013 16:32:42 -0700 Subject: [PATCH 0588/1502] add tulip.http to packages list --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dcaee96f..a19e3224 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,6 @@ setup(name='tulip', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip'], + packages=['tulip', 'tulip.http'], ext_modules=extensions ) From b6347f879772d9c5e77dc97123455b3af2da3774 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 06:36:51 +0300 Subject: [PATCH 0589/1502] Fix cancellation for queue. --- tests/queues_test.py | 28 ++++++++++++++++++++++++++++ tulip/queues.py | 21 +++++++-------------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/tests/queues_test.py b/tests/queues_test.py index 131812a4..ab4ee91d 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -264,6 +264,20 @@ def test(): self.assertEqual(1, loop.run_until_complete(test())) self.assertAlmostEqual(0.06, loop.time()) + def test_get_cancelled_race(self): + q = queues.Queue(loop=self.loop) + + t1 = tasks.Task(q.get(), loop=self.loop) + t2 = tasks.Task(q.get(), loop=self.loop) + + test_utils.run_briefly(self.loop) + t1.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t1.done()) + q.put_nowait('a') + test_utils.run_briefly(self.loop) + self.assertEqual(t2.result(), 'a') + class QueuePutTests(_QueueTestBase): @@ -338,6 +352,20 @@ def test(): self.assertTrue(t.done()) self.assertTrue(t.result()) + def test_put_cancelled_race(self): + q = queues.Queue(loop=self.loop, maxsize=1) + + t1 = tasks.Task(q.put('a'), loop=self.loop) + t2 = tasks.Task(q.put('b'), loop=self.loop) + t3 = tasks.Task(q.put('c'), loop=self.loop) + + test_utils.run_briefly(self.loop) + t2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t2.done()) + self.assertEqual(q.get_nowait(), 'a') + self.assertEqual(q.get_nowait(), 'c') + class LifoQueueTests(_QueueTestBase): diff --git a/tulip/queues.py b/tulip/queues.py index 4a46f1a2..23051234 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -69,10 +69,10 @@ def _format(self): result += ' _putters[{}]'.format(len(self._putters)) return result - def _consume_done_getters(self, waiters): + def _consume_done_getters(self): # Delete waiters at the head of the get() queue who've timed out. - while waiters and waiters[0].done(): - waiters.popleft() + while self._getters and self._getters[0].done(): + self._getters.popleft() def _consume_done_putters(self): # Delete waiters at the head of the put() queue who've timed out. @@ -110,7 +110,7 @@ def put(self, item): If you yield from put(), wait until a free slot is available before adding item. """ - self._consume_done_getters(self._getters) + self._consume_done_getters() if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -126,11 +126,7 @@ def put(self, item): waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) - try: - yield from waiter - except futures.CancelledError: - raise Full - + yield from waiter else: self._put(item) @@ -139,7 +135,7 @@ def put_nowait(self, item): If no free slot is immediately available, raise Full. """ - self._consume_done_getters(self._getters) + self._consume_done_getters() if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -182,10 +178,7 @@ def get(self): waiter = futures.Future(loop=self._loop) self._getters.append(waiter) - try: - return (yield from waiter) - except futures.CancelledError: - raise Empty + return (yield from waiter) def get_nowait(self): """Remove and return an item from the queue. From e7e35e0d3fdb54b91d6a968eab8987b144e00a34 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 07:38:14 +0300 Subject: [PATCH 0590/1502] Backout 715959bb5312 --- tests/queues_test.py | 28 ---------------------------- tulip/queues.py | 21 ++++++++++++++------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/tests/queues_test.py b/tests/queues_test.py index ab4ee91d..131812a4 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -264,20 +264,6 @@ def test(): self.assertEqual(1, loop.run_until_complete(test())) self.assertAlmostEqual(0.06, loop.time()) - def test_get_cancelled_race(self): - q = queues.Queue(loop=self.loop) - - t1 = tasks.Task(q.get(), loop=self.loop) - t2 = tasks.Task(q.get(), loop=self.loop) - - test_utils.run_briefly(self.loop) - t1.cancel() - test_utils.run_briefly(self.loop) - self.assertTrue(t1.done()) - q.put_nowait('a') - test_utils.run_briefly(self.loop) - self.assertEqual(t2.result(), 'a') - class QueuePutTests(_QueueTestBase): @@ -352,20 +338,6 @@ def test(): self.assertTrue(t.done()) self.assertTrue(t.result()) - def test_put_cancelled_race(self): - q = queues.Queue(loop=self.loop, maxsize=1) - - t1 = tasks.Task(q.put('a'), loop=self.loop) - t2 = tasks.Task(q.put('b'), loop=self.loop) - t3 = tasks.Task(q.put('c'), loop=self.loop) - - test_utils.run_briefly(self.loop) - t2.cancel() - test_utils.run_briefly(self.loop) - self.assertTrue(t2.done()) - self.assertEqual(q.get_nowait(), 'a') - self.assertEqual(q.get_nowait(), 'c') - class LifoQueueTests(_QueueTestBase): diff --git a/tulip/queues.py b/tulip/queues.py index 23051234..4a46f1a2 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -69,10 +69,10 @@ def _format(self): result += ' _putters[{}]'.format(len(self._putters)) return result - def _consume_done_getters(self): + def _consume_done_getters(self, waiters): # Delete waiters at the head of the get() queue who've timed out. - while self._getters and self._getters[0].done(): - self._getters.popleft() + while waiters and waiters[0].done(): + waiters.popleft() def _consume_done_putters(self): # Delete waiters at the head of the put() queue who've timed out. @@ -110,7 +110,7 @@ def put(self, item): If you yield from put(), wait until a free slot is available before adding item. """ - self._consume_done_getters() + self._consume_done_getters(self._getters) if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -126,7 +126,11 @@ def put(self, item): waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) - yield from waiter + try: + yield from waiter + except futures.CancelledError: + raise Full + else: self._put(item) @@ -135,7 +139,7 @@ def put_nowait(self, item): If no free slot is immediately available, raise Full. """ - self._consume_done_getters() + self._consume_done_getters(self._getters) if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -178,7 +182,10 @@ def get(self): waiter = futures.Future(loop=self._loop) self._getters.append(waiter) - return (yield from waiter) + try: + return (yield from waiter) + except futures.CancelledError: + raise Empty def get_nowait(self): """Remove and return an item from the queue. From 19c5a5d923ee9bcaae92dcb03757da3162116a18 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 07:38:31 +0300 Subject: [PATCH 0591/1502] Backout 2c359b312008 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a19e3224..dcaee96f 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,6 @@ setup(name='tulip', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip', 'tulip.http'], + packages=['tulip'], ext_modules=extensions ) From 916d4398c6176a5090a9681b029bf7399e051fbd Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 07:38:48 +0300 Subject: [PATCH 0592/1502] Backout 27985f461839 --- tests/events_test.py | 24 +++- tests/futures_test.py | 9 ++ tests/http_server_test.py | 1 - tests/locks_test.py | 266 +++++++++++++++++++++++++++++++------- tests/queues_test.py | 93 ++++++++++++- tests/tasks_test.py | 257 +++++++++++++++++++++++++++++------- tulip/base_events.py | 27 +++- tulip/events.py | 6 +- tulip/futures.py | 28 +++- tulip/http/client.py | 9 +- tulip/locks.py | 240 ++++++++++++++++++++-------------- tulip/queues.py | 30 +++-- tulip/tasks.py | 150 +++++++++++++-------- 13 files changed, 852 insertions(+), 288 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 240518c0..7c342bad 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -228,6 +228,18 @@ def cb(): self.assertRaises(RuntimeError, self.loop.run_until_complete, task) + def test_run_until_complete_timeout(self): + t0 = self.loop.time() + task = tasks.async(tasks.sleep(0.2, loop=self.loop), loop=self.loop) + self.assertRaises(futures.TimeoutError, + self.loop.run_until_complete, + task, timeout=0.1) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + self.loop.run_until_complete(task) + t2 = self.loop.time() + self.assertTrue(0.18 <= t2-t0 <= 0.22, t2-t0) + def test_call_later(self): results = [] @@ -939,7 +951,7 @@ def main(): return res start = time.monotonic() - t = tasks.Task(main(), loop=self.loop) + t = tasks.Task(main(), timeout=1, loop=self.loop) self.loop.run_forever() elapsed = time.monotonic() - start @@ -974,7 +986,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') - self.loop.run_until_complete(proto.got_data[1].wait()) + self.loop.run_until_complete(proto.got_data[1].wait(1)) transp.close() self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) @@ -1003,12 +1015,12 @@ def connect(): try: stdin = transp.get_pipe_transport(0) stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data[1].wait()) + self.loop.run_until_complete(proto.got_data[1].wait(1)) proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1]) stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data[1].wait()) + self.loop.run_until_complete(proto.got_data[1].wait(1)) self.assertEqual(b'Python The Winner', proto.data[1]) finally: transp.close() @@ -1207,13 +1219,13 @@ def connect(): stdin = transp.get_pipe_transport(0) stdout = transp.get_pipe_transport(1) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait()) + self.loop.run_until_complete(proto.got_data[1].wait(1)) self.assertEqual(b'OUT:test', proto.data[1]) stdout.close() self.loop.run_until_complete(proto.disconnects[1]) stdin.write(b'xxx') - self.loop.run_until_complete(proto.got_data[2].wait()) + self.loop.run_until_complete(proto.got_data[2].wait(1)) self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) transp.close() diff --git a/tests/futures_test.py b/tests/futures_test.py index 7c2abd18..c7228c00 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -132,6 +132,15 @@ def test_repr(self): self.assertIn('<18 more>', r) f_many_callbacks.cancel() + f_pending = futures.Future(loop=self.loop, timeout=10) + self.assertEqual('Future{timeout=10, when=10}', + repr(f_pending)) + f_pending.cancel() + + f_pending = futures.Future(loop=self.loop, timeout=10) + f_pending.cancel() + self.assertEqual('Future{timeout=10}', repr(f_pending)) + def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. diff --git a/tests/http_server_test.py b/tests/http_server_test.py index a9d4d5ed..862779b9 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -69,7 +69,6 @@ def test_connection_lost(self): handle = srv._request_handler srv.connection_lost(None) - test_utils.run_briefly(self.loop) self.assertIsNone(srv._request_handler) self.assertTrue(handle.cancelled()) diff --git a/tests/locks_test.py b/tests/locks_test.py index 9399d759..529c7268 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -115,6 +115,59 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) + def test_acquire_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + lock = locks.Lock(loop=loop) + + self.assertTrue(loop.run_until_complete(lock.acquire())) + + acquired = loop.run_until_complete(lock.acquire(timeout=0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + lock = locks.Lock(loop=loop) + self.loop.run_until_complete(lock.acquire()) + + loop.call_soon(lock.release) + acquired = loop.run_until_complete(lock.acquire(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + def test_acquire_timeout_mixed(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + lock = locks.Lock(loop=loop) + loop.run_until_complete(lock.acquire()) + tasks.Task(lock.acquire(), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + acquire_task = tasks.Task(lock.acquire(0.01), loop=loop) + tasks.Task(lock.acquire(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + self.assertEqual(3, len(lock._waiters)) + + # wakeup to close waiting coroutines + for i in range(3): + lock.release() + test_utils.run_briefly(loop) + def test_acquire_cancel(self): lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) @@ -126,54 +179,6 @@ def test_acquire_cancel(self): self.loop.run_until_complete, task) self.assertFalse(lock._waiters) - def test_cancel_race(self): - # Several tasks: - # - A acquires the lock - # - B is blocked in aqcuire() - # - C is blocked in aqcuire() - # - # Now, concurrently: - # - B is cancelled - # - A releases the lock - # - # If B's waiter is marked cancelled but not yet removed from - # _waiters, A's release() call will crash when trying to set - # B's waiter; instead, it should move on to C's waiter. - - # Setup: A has the lock, b and c are waiting. - lock = locks.Lock(loop=self.loop) - - @tasks.coroutine - def lockit(name, blocker): - yield from lock.acquire() - try: - if blocker is not None: - yield from blocker - finally: - lock.release() - - fa = futures.Future(loop=self.loop) - ta = tasks.Task(lockit('A', fa), loop=self.loop) - test_utils.run_briefly(self.loop) - self.assertTrue(lock.locked()) - tb = tasks.Task(lockit('B', None), loop=self.loop) - test_utils.run_briefly(self.loop) - self.assertEqual(len(lock._waiters), 1) - tc = tasks.Task(lockit('C', None), loop=self.loop) - test_utils.run_briefly(self.loop) - self.assertEqual(len(lock._waiters), 2) - - # Create the race and check. - # Without the fix this failed at the last assert. - fa.set_result(None) - tb.cancel() - self.assertTrue(lock._waiters[0].cancelled()) - test_utils.run_briefly(self.loop) - self.assertFalse(lock.locked()) - self.assertTrue(ta.done()) - self.assertTrue(tb.cancelled()) - self.assertTrue(tc.done()) - def test_release_not_acquired(self): lock = locks.Lock(loop=self.loop) @@ -290,6 +295,55 @@ def test_wait_on_set(self): res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) + def test_wait_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + ev = locks.EventWaiter(loop=loop) + + res = loop.run_until_complete(ev.wait(0.1)) + self.assertFalse(res) + self.assertAlmostEqual(0.1, loop.time()) + + ev = locks.EventWaiter(loop=loop) + loop.call_later(0.01, ev.set) + acquired = loop.run_until_complete(ev.wait(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) + + def test_wait_timeout_mixed(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + ev = locks.EventWaiter(loop=loop) + tasks.Task(ev.wait(), loop=loop) + tasks.Task(ev.wait(), loop=loop) + acquire_task = tasks.Task(ev.wait(0.1), loop=loop) + tasks.Task(ev.wait(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(3, len(ev._waiters)) + + # wakeup to close waiting coroutines + ev.set() + test_utils.run_briefly(loop) + def test_wait_cancel(self): ev = locks.EventWaiter(loop=self.loop) @@ -431,6 +485,23 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) + def test_wait_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) + loop.run_until_complete(cond.acquire()) + + wait = loop.run_until_complete(cond.wait(0.1)) + self.assertFalse(wait) + self.assertTrue(cond.locked()) + self.assertAlmostEqual(0.1, loop.time()) + def test_wait_cancel(self): cond = locks.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) @@ -487,6 +558,49 @@ def c1(result): self.assertTrue(t.done()) self.assertTrue(t.result()) + def test_wait_for_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + cond = locks.Condition(loop=loop) + + result = [] + + predicate = unittest.mock.Mock(return_value=False) + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate, 0.1)): + result.append(1) + else: + result.append(2) + cond.release() + + wait_for = tasks.Task(c1(result), loop=loop) + + test_utils.run_briefly(loop) + self.assertEqual([], result) + + loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(loop) + self.assertEqual([], result) + + loop.run_until_complete(wait_for) + self.assertEqual([2], result) + self.assertEqual(3, predicate.call_count) + + self.assertAlmostEqual(0.1, loop.time()) + def test_wait_for_unacquired(self): cond = locks.Condition(loop=self.loop) @@ -720,6 +834,62 @@ def c4(result): # cleanup locked semaphore sem.release() + def test_acquire_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.11, when) + when = yield 0 + self.assertAlmostEqual(10.2, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + + acquired = loop.run_until_complete(sem.acquire(0.1)) + self.assertFalse(acquired) + self.assertAlmostEqual(0.1, loop.time()) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + + loop.call_later(0.01, sem.release) + acquired = loop.run_until_complete(sem.acquire(10.1)) + self.assertTrue(acquired) + self.assertAlmostEqual(0.11, loop.time()) + + def test_acquire_timeout_mixed(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sem = locks.Semaphore(loop=loop) + loop.run_until_complete(sem.acquire()) + tasks.Task(sem.acquire(), loop=loop) + tasks.Task(sem.acquire(), loop=loop) + acquire_task = tasks.Task(sem.acquire(0.1), loop=loop) + tasks.Task(sem.acquire(), loop=loop) + + acquired = loop.run_until_complete(acquire_task) + self.assertFalse(acquired) + + self.assertAlmostEqual(0.1, loop.time()) + + self.assertEqual(3, len(sem._waiters)) + + # wakeup to close waiting coroutines + for i in range(3): + sem.release() + test_utils.run_briefly(loop) + def test_acquire_cancel(self): sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) diff --git a/tests/queues_test.py b/tests/queues_test.py index 131812a4..0dce6653 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -236,7 +236,37 @@ def test_nonblocking_get_exception(self): q = queues.Queue(loop=self.loop) self.assertRaises(queues.Empty, q.get_nowait) - def test_get_cancelled(self): + def test_get_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + + @tasks.coroutine + def queue_get(): + with self.assertRaises(queues.Empty): + return (yield from q.get(timeout=0.01)) + + # Get works after timeout, with blocking and non-blocking put. + q.put_nowait(1) + self.assertEqual(1, (yield from q.get())) + + t = tasks.Task(q.put(2), loop=loop) + self.assertEqual(2, (yield from q.get())) + + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + loop.run_until_complete(queue_get()) + self.assertAlmostEqual(0.01, loop.time()) + + def test_get_timeout_cancelled(self): def gen(): when = yield @@ -252,7 +282,7 @@ def gen(): @tasks.coroutine def queue_get(): - return (yield from tasks.wait_for(q.get(), 0.05, loop=loop)) + return (yield from q.get(timeout=0.05)) @tasks.coroutine def test(): @@ -321,12 +351,47 @@ def test_nonblocking_put_exception(self): q.put_nowait(1) self.assertRaises(queues.Full, q.put_nowait, 2) - def test_put_cancelled(self): + def test_put_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0.01 + self.assertAlmostEqual(0.02, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(1, loop=loop) + q.put_nowait(0) + + @tasks.coroutine + def queue_put(): + with self.assertRaises(queues.Full): + return (yield from q.put(1, timeout=0.01)) + + self.assertEqual(0, q.get_nowait()) + + # Put works after timeout, with blocking and non-blocking get. + get_task = tasks.Task(q.get(), loop=loop) + # Let the get start waiting. + yield from tasks.sleep(0.01, loop=loop) + q.put_nowait(2) + self.assertEqual(2, (yield from get_task)) + + q.put_nowait(3) + self.assertEqual(3, q.get_nowait()) + + loop.run_until_complete(queue_put()) + self.assertAlmostEqual(0.02, loop.time()) + + def test_put_timeout_cancelled(self): q = queues.Queue(loop=self.loop) @tasks.coroutine def queue_put(): - yield from q.put(1) + yield from q.put(1, timeout=0.01) return True @tasks.coroutine @@ -415,6 +480,26 @@ def join(): self.loop.run_until_complete(join()) + def test_join_timeout(self): + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.JoinableQueue(loop=loop) + q.put_nowait(1) + + @tasks.coroutine + def join(): + yield from q.join(0.1) + + # Join completes in ~ 0.1 seconds, although no one calls task_done(). + loop.run_until_complete(join()) + self.assertAlmostEqual(0.1, loop.time()) + def test_format(self): q = queues.JoinableQueue(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 8c26e3f9..3e1220dc 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -44,6 +44,57 @@ def notmuch(): self.assertIs(t._loop, loop) loop.close() + def test_task_decorator(self): + @tasks.task + def notmuch(): + yield from [] + return 'ko' + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_func(self): + @tasks.task + def notmuch(): + return 'ko' + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + + def test_task_decorator_fut(self): + @tasks.task + def notmuch(): + fut = futures.Future(loop=self.loop) + fut.set_result('ko') + return fut + + try: + events.set_event_loop(self.loop) + t = notmuch() + finally: + events.set_event_loop(None) + + self.assertIsInstance(t, tasks.Task) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ko') + def test_async_coroutine(self): @tasks.coroutine def notmuch(): @@ -193,47 +244,104 @@ def task(): self.assertTrue(t.done()) self.assertFalse(t.cancel()) -## def test_cancel_done_future(self): -## fut1 = futures.Future(loop=self.loop) -## fut2 = futures.Future(loop=self.loop) -## fut3 = futures.Future(loop=self.loop) - -## @tasks.coroutine -## def task(): -## yield from fut1 -## try: -## yield from fut2 -## except futures.CancelledError: -## pass -## yield from fut3 - -## t = tasks.Task(task(), loop=self.loop) -## test_utils.run_briefly(self.loop) -## fut1.set_result(None) -## t.cancel() -## test_utils.run_once(self.loop) # process fut1 result, delay cancel -## self.assertFalse(t.done()) -## test_utils.run_once(self.loop) # cancel fut2, but coro still alive -## self.assertFalse(t.done()) -## test_utils.run_briefly(self.loop) # cancel fut3 -## self.assertTrue(t.done()) - -## self.assertEqual(fut1.result(), None) -## self.assertTrue(fut2.cancelled()) -## self.assertTrue(fut3.cancelled()) -## self.assertTrue(t.cancelled()) - -## def test_cancel_in_coro(self): -## @tasks.coroutine -## def task(): -## t.cancel() -## return 12 - -## t = tasks.Task(task(), loop=self.loop) -## self.assertRaises( -## futures.CancelledError, self.loop.run_until_complete, t) -## self.assertTrue(t.done()) -## self.assertFalse(t.cancel()) + def test_cancel_done_future(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + yield from fut3 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + fut1.set_result(None) + t.cancel() + test_utils.run_once(self.loop) # process fut1 result, delay cancel + self.assertFalse(t.done()) + test_utils.run_once(self.loop) # cancel fut2, but coro still alive + self.assertFalse(t.done()) + test_utils.run_briefly(self.loop) # cancel fut3 + self.assertTrue(t.done()) + + self.assertEqual(fut1.result(), None) + self.assertTrue(fut2.cancelled()) + self.assertTrue(fut3.cancelled()) + self.assertTrue(t.cancelled()) + + def test_future_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(coro(), timeout=0.1, loop=loop) + + self.assertRaises( + futures.CancelledError, + loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_future_timeout_catch(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def coro(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + class Cancelled(Exception): + pass + + @tasks.coroutine + def coro2(): + try: + yield from tasks.Task(coro(), timeout=0.1, loop=loop) + except futures.CancelledError: + raise Cancelled() + + self.assertRaises( + Cancelled, loop.run_until_complete, coro2()) + self.assertAlmostEqual(0.1, loop.time()) + + def test_cancel_in_coro(self): + @tasks.coroutine + def task(): + t.cancel() + return 12 + + t = tasks.Task(task(), loop=self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) def test_stop_while_run_in_complete(self): @@ -273,6 +381,57 @@ def task(): for w in waiters: w.close() + def test_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(10.0, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 42 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.TimeoutError, loop.run_until_complete, t, 0.1) + self.assertAlmostEqual(0.1, loop.time()) + self.assertFalse(t.done()) + + # move forward to close generator + loop.advance_time(10) + self.assertEqual(42, loop.run_until_complete(t)) + self.assertTrue(t.done()) + + def test_timeout_not(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(0.1, loop=loop) + return 42 + + t = tasks.Task(task(), loop=loop) + r = loop.run_until_complete(t, 10.0) + self.assertTrue(t.done()) + self.assertEqual(r, 42) + self.assertAlmostEqual(0.1, loop.time()) + def test_wait_for(self): def gen(): @@ -854,14 +1013,16 @@ def test_task_cancel_waiter_future(self): @tasks.coroutine def coro(): - yield from fut + try: + yield from fut + except futures.CancelledError: + pass task = tasks.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) task.cancel() - test_utils.run_briefly(self.loop) self.assertRaises( futures.CancelledError, self.loop.run_until_complete, task) self.assertIsNone(task._fut_waiter) @@ -943,14 +1104,12 @@ def gen(): def sleeper(): yield from tasks.sleep(10, loop=loop) - base_exc = BaseException() - @tasks.coroutine def notmutch(): try: yield from sleeper() except futures.CancelledError: - raise base_exc + raise BaseException() task = tasks.Task(notmutch(), loop=loop) test_utils.run_briefly(loop) @@ -961,8 +1120,7 @@ def notmutch(): self.assertRaises(BaseException, test_utils.run_briefly, loop) self.assertTrue(task.done()) - self.assertFalse(task.cancelled()) - self.assertIs(task.exception(), base_exc) + self.assertTrue(task.cancelled()) def test_iscoroutinefunction(self): def fn(): @@ -990,7 +1148,8 @@ def wait_for_future(): with self.assertRaises(RuntimeError) as cm: self.loop.run_until_complete(task) - self.assertFalse(fut.done()) + self.assertTrue(fut.done()) + self.assertIs(fut.exception(), cm.exception) def test_yield_vs_yield_from_generator(self): @tasks.coroutine diff --git a/tulip/base_events.py b/tulip/base_events.py index 3bccfc83..5ff2d3c9 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -112,8 +112,8 @@ def run_forever(self): finally: self._running = False - def run_until_complete(self, future): - """Run until the Future is done. + def run_until_complete(self, future, timeout=None): + """Run until the Future is done, or until a timeout. If the argument is a coroutine, it is wrapped in a Task. @@ -121,12 +121,31 @@ def run_until_complete(self, future): with the same coroutine twice -- it would wrap it in two different Tasks and that can't be good. - Return the Future's result, or raise its exception. + Return the Future's result, or raise its exception. If the + timeout is reached or stop() is called, raise TimeoutError. """ future = tasks.async(future, loop=self) future.add_done_callback(_raise_stop_error) - self.run_forever() + handle_called = False + + if timeout is None: + self.run_forever() + else: + + def stop_loop(): + nonlocal handle_called + handle_called = True + raise _StopError + + handle = self.call_later(timeout, stop_loop) + self.run_forever() + handle.cancel() + future.remove_done_callback(_raise_stop_error) + + if handle_called: + raise futures.TimeoutError + if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') diff --git a/tulip/events.py b/tulip/events.py index 7db2514d..e292eea2 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -109,10 +109,14 @@ def run_forever(self): """Run the event loop until stop() is called.""" raise NotImplementedError - def run_until_complete(self, future): + def run_until_complete(self, future, timeout=None): """Run the event loop until a Future is done. Return the Future's result, or raise its exception. + + If timeout is not None, run it for at most that long; + if the Future is still not done, raise TimeoutError + (but don't cancel the Future). """ raise NotImplementedError diff --git a/tulip/futures.py b/tulip/futures.py index 706e8c8a..8593e9ae 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -1,7 +1,7 @@ """A Future class similar to the one in PEP 3148.""" __all__ = ['CancelledError', 'TimeoutError', - 'InvalidStateError', + 'InvalidStateError', 'InvalidTimeoutError', 'Future', 'wrap_future', ] @@ -30,6 +30,11 @@ class InvalidStateError(Error): # TODO: Show the future, its state, the method, and the required state. +class InvalidTimeoutError(Error): + """Called result() or exception() with timeout != 0.""" + # TODO: Print a nice error message. + + class _TracebackLogger: """Helper to log a traceback upon destruction if not cleared. @@ -124,13 +129,15 @@ class Future: _state = _PENDING _result = None _exception = None + _timeout = None + _timeout_handle = None _loop = None _blocking = False # proper use of future (yield vs yield from) _tb_logger = None - def __init__(self, *, loop=None): + def __init__(self, *, loop=None, timeout=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -143,6 +150,10 @@ def __init__(self, *, loop=None): self._loop = loop self._callbacks = [] + if timeout is not None: + self._timeout = timeout + self._timeout_handle = self._loop.call_later(timeout, self.cancel) + def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -160,6 +171,14 @@ def __repr__(self): res += '<{}, {}>'.format(self._state, self._callbacks) else: res += '<{}>'.format(self._state) + dct = {} + if self._timeout is not None: + dct['timeout'] = self._timeout + if self._timeout_handle is not None: + dct['when'] = self._timeout_handle._when + if dct: + res += '{' + ', '.join('{}={}'.format(k, dct[k]) + for k in sorted(dct)) + '}' return res def cancel(self): @@ -181,6 +200,11 @@ def _schedule_callbacks(self): The callbacks are scheduled to be called as soon as possible. Also clears the callback list. """ + # Cancel timeout handle + if self._timeout_handle is not None: + self._timeout_handle.cancel() + self._timeout_handle = None + callbacks = self._callbacks[:] if not callbacks: return diff --git a/tulip/http/client.py b/tulip/http/client.py index ec7cd034..2aedfdd1 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -95,17 +95,10 @@ def request(method, url, *, conn = session.start(req, loop) # connection timeout - t = tulip.Task(conn, loop=loop) - th = None - if timeout is not None: - th = loop.call_later(timeout, t.cancel) try: - resp = yield from t + resp = yield from tulip.Task(conn, timeout=timeout, loop=loop) except tulip.CancelledError: raise tulip.TimeoutError from None - finally: - if th is not None: - th.cancel() # redirects if resp.status in (301, 302) and allow_redirects: diff --git a/tulip/locks.py b/tulip/locks.py index 87937ec0..7c0a8f2a 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -1,4 +1,4 @@ -"""Synchronization primitives.""" +"""Synchronization primitives""" __all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] @@ -10,31 +10,29 @@ class Lock: - """Primitive lock objects. - - A primitive lock is a synchronization primitive that is not owned - by a particular coroutine when locked. A primitive lock is in one - of two states, 'locked' or 'unlocked'. - - It is created in the unlocked state. It has two basic methods, - acquire() and release(). When the state is unlocked, acquire() - changes the state to locked and returns immediately. When the - state is locked, acquire() blocks until a call to release() in - another coroutine changes it to unlocked, then the acquire() call - resets it to locked and returns. The release() method should only - be called in the locked state; it changes the state to unlocked - and returns immediately. If an attempt is made to release an - unlocked lock, a RuntimeError will be raised. - - When more than one coroutine is blocked in acquire() waiting for - the state to turn to unlocked, only one coroutine proceeds when a - release() call resets the state to unlocked; first coroutine which - is blocked in acquire() is being processed. - - acquire() is a coroutine and should be called with 'yield from'. - - Locks also support the context manager protocol. '(yield from lock)' - should be used as context manager expression. + """The class implementing primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned by + a particular coroutine when locked. A primitive lock is in one of two + states, "locked" or "unlocked". + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() changes + the state to locked and returns immediately. When the state is locked, + acquire() blocks until a call to release() in another coroutine changes + it to unlocked, then the acquire() call resets it to locked and returns. + The release() method should only be called in the locked state; it changes + the state to unlocked and returns immediately. If an attempt is made + to release an unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for the state + to turn to unlocked, only one coroutine proceeds when a release() call + resets the state to unlocked; first coroutine which is blocked in acquire() + is being processed. + + acquire() method is a coroutine and should be called with "yield from" + + Locks also support the context manager protocol. (yield from lock) should + be used as context manager expression. Usage: @@ -53,7 +51,7 @@ class Lock: with (yield from lock): ... - Lock objects can be tested for locking state: + Lock object could be tested for locking state: if not lock.locked(): yield from lock @@ -73,34 +71,45 @@ def __init__(self, *, loop=None): def __repr__(self): res = super().__repr__() - extra = 'locked' if self._locked else 'unlocked' - if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + return '<{} [{}]>'.format( + res[1:-1], 'locked' if self._locked else 'unlocked') def locked(self): """Return true if lock is acquired.""" return self._locked @tasks.coroutine - def acquire(self): + def acquire(self, timeout=None): """Acquire a lock. - This method blocks until the lock is unlocked, then sets it to - locked and returns True. + Acquire method blocks until the lock is unlocked, then set it to + locked and return True. + + When invoked with the floating-point timeout argument set, blocks for + at most the number of seconds specified by timeout and as long as + the lock cannot be acquired. + + The return value is True if the lock is acquired successfully, + False if not (for example if the timeout expired). """ if not self._waiters and not self._locked: self._locked = True return True - fut = futures.Future(loop=self._loop) + fut = futures.Future(loop=self._loop, timeout=timeout) + self._waiters.append(fut) try: yield from fut - self._locked = True - return True - finally: + except futures.CancelledError: self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + self._locked = True + return True def release(self): """Release a lock. @@ -115,11 +124,8 @@ def release(self): """ if self._locked: self._locked = False - # Wake up the first waiter who isn't cancelled. - for fut in self._waiters: - if not fut.cancelled(): - fut.set_result(True) - break + if self._waiters: + self._waiters[0].set_result(True) else: raise RuntimeError('Lock is not acquired.') @@ -137,7 +143,6 @@ def __iter__(self): return self -# TODO: Why not call this Event? class EventWaiter: """A EventWaiter implementation, our equivalent to threading.Event @@ -156,7 +161,6 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') @@ -183,26 +187,41 @@ def clear(self): self._value = False @tasks.coroutine - def wait(self): - """Block until the internal flag is true. - - If the internal flag is true on entry, return True - immediately. Otherwise, block until another coroutine calls - set() to set the flag to true, then return True. + def wait(self, timeout=None): + """Block until the internal flag is true. If the internal flag + is true on entry, return immediately. Otherwise, block until another + coroutine calls set() to set the flag to true, or until the optional + timeout occurs. + + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation in + seconds (or fractions thereof). + + This method returns true if and only if the internal flag has been + set to true, either before the wait call or after the wait starts, + so it will always return True except if a timeout is given and + the operation times out. + + wait() method is a coroutine. """ if self._value: return True - fut = futures.Future(loop=self._loop) + fut = futures.Future(loop=self._loop, timeout=timeout) + self._waiters.append(fut) try: yield from fut - return True - finally: + except futures.CancelledError: self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + + return True -# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. class Condition(Lock): """A Condition implementation. @@ -213,55 +232,75 @@ class Condition(Lock): def __init__(self, *, loop=None): super().__init__(loop=loop) - self._condition_waiters = collections.deque() - # TODO: Add __repr__() with len(_condition_waiters). + self._condition_waiters = collections.deque() @tasks.coroutine - def wait(self): - """Wait until notified. + def wait(self, timeout=None): + """Wait until notified or until a timeout occurs. If the calling + coroutine has not acquired the lock when this method is called, + a RuntimeError is raised. + + This method releases the underlying lock, and then blocks until it is + awakened by a notify() or notify_all() call for the same condition + variable in another coroutine, or until the optional timeout occurs. + Once awakened or timed out, it re-acquires the lock and returns. - If the calling coroutine has not acquired the lock when this - method is called, a RuntimeError is raised. + When the timeout argument is present and not None, it should be + a floating point number specifying a timeout for the operation + in seconds (or fractions thereof). - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another coroutine. Once - awakened, it re-acquires the lock and returns True. + The return value is True unless a given timeout expired, in which + case it is False. """ if not self._locked: raise RuntimeError('cannot wait on un-acquired lock') - keep_lock = True self.release() - try: - fut = futures.Future(loop=self._loop) - self._condition_waiters.append(fut) - try: - yield from fut - return True - finally: - self._condition_waiters.remove(fut) + fut = futures.Future(loop=self._loop, timeout=timeout) + + self._condition_waiters.append(fut) + keep_lock = True + try: + yield from fut + except futures.CancelledError: + self._condition_waiters.remove(fut) + return False except GeneratorExit: keep_lock = False # Prevent yield in finally clause. raise + else: + f = self._condition_waiters.popleft() + assert fut is f finally: if keep_lock: yield from self.acquire() - @tasks.coroutine - def wait_for(self, predicate): - """Wait until a predicate becomes true. + return True - The predicate should be a callable which result will be - interpreted as a boolean value. The final predicate value is - the return value. + @tasks.coroutine + def wait_for(self, predicate, timeout=None): + """Wait until a condition evaluates to True. predicate should be a + callable which result will be interpreted as a boolean value. A timeout + may be provided giving the maximum time to wait. """ + endtime = None + waittime = timeout result = predicate() + while not result: - yield from self.wait() + if waittime is not None: + if endtime is None: + endtime = self._loop.time() + waittime + else: + waittime = endtime - self._loop.time() + if waittime <= 0: + break + + yield from self.wait(waittime) result = predicate() + return result def notify(self, n=1): @@ -331,7 +370,6 @@ def __init__(self, value=1, bound=False, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format( res[1:-1], @@ -343,14 +381,17 @@ def locked(self): return self._locked @tasks.coroutine - def acquire(self): - """Acquire a semaphore. - - If the internal counter is larger than zero on entry, - decrement it by one and return True immediately. If it is - zero on entry, block, waiting until some other coroutine has - called release() to make it larger than 0, and then return - True. + def acquire(self, timeout=None): + """Acquire a semaphore. acquire() method is a coroutine. + + When invoked without arguments: if the internal counter is larger + than zero on entry, decrement it by one and return immediately. + If it is zero on entry, block, waiting until some other coroutine has + called release() to make it larger than zero. + + When invoked with a timeout other than None, it will block for at + most timeout seconds. If acquire does not complete successfully in + that interval, return false. Return true otherwise. """ if not self._waiters and self._value > 0: self._value -= 1 @@ -358,17 +399,22 @@ def acquire(self): self._locked = True return True - fut = futures.Future(loop=self._loop) + fut = futures.Future(loop=self._loop, timeout=timeout) + self._waiters.append(fut) try: yield from fut - self._value -= 1 - if self._value == 0: - self._locked = True - return True - finally: + except futures.CancelledError: self._waiters.remove(fut) + return False + else: + f = self._waiters.popleft() + assert f is fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True def release(self): """Release a semaphore, incrementing the internal counter by one. @@ -391,8 +437,6 @@ def release(self): break def __enter__(self): - # TODO: This is questionable. How do we know the user actually - # wrote "with (yield from sema)" instead of "with sema"? return True def __exit__(self, *args): diff --git a/tulip/queues.py b/tulip/queues.py index 4a46f1a2..8214d0ec 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -4,6 +4,7 @@ 'Full', 'Empty'] import collections +import concurrent.futures import heapq import queue @@ -104,11 +105,14 @@ def full(self): return self.qsize() == self._maxsize @coroutine - def put(self, item): + def put(self, item, timeout=None): """Put an item into the queue. - If you yield from put(), wait until a free slot is available - before adding item. + If you yield from put() and timeout is None (the default), wait until a + free slot is available before adding item. + + If a timeout is provided, raise Full if no free slot becomes + available before the timeout. """ self._consume_done_getters(self._getters) if self._getters: @@ -123,12 +127,12 @@ def put(self, item): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - waiter = futures.Future(loop=self._loop) + waiter = futures.Future(loop=self._loop, timeout=timeout) self._putters.append((item, waiter)) try: yield from waiter - except futures.CancelledError: + except concurrent.futures.CancelledError: raise Full else: @@ -157,10 +161,14 @@ def put_nowait(self, item): self._put(item) @coroutine - def get(self): + def get(self, timeout=None): """Remove and return an item from the queue. - If you yield from get(), wait until a item is available. + If you yield from get() and timeout is None (the default), wait until a + item is available. + + If a timeout is provided, raise Empty if no item is available + before the timeout. """ self._consume_done_putters() if self._putters: @@ -179,12 +187,12 @@ def get(self): elif self.qsize(): return self._get() else: - waiter = futures.Future(loop=self._loop) + waiter = futures.Future(loop=self._loop, timeout=timeout) self._getters.append(waiter) try: return (yield from waiter) - except futures.CancelledError: + except concurrent.futures.CancelledError: raise Empty def get_nowait(self): @@ -278,7 +286,7 @@ def task_done(self): self._finished.set() @coroutine - def join(self): + def join(self, timeout=None): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the @@ -287,4 +295,4 @@ def join(self): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait() + yield from self._finished.wait(timeout=timeout) diff --git a/tulip/tasks.py b/tulip/tasks.py index a51ee29a..ca513a10 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'Task', +__all__ = ['coroutine', 'task', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', ] @@ -49,12 +49,32 @@ def iscoroutine(obj): return inspect.isgenerator(obj) # TODO: And what? +def task(func): + """Decorator for a coroutine to be wrapped in a Task.""" + if inspect.isgeneratorfunction(func): + coro = func + else: + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + def task_wrapper(*args, **kwds): + return Task(coro(*args, **kwds)) + + return task_wrapper + + +_marker = object() + + class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro, *, loop=None): + def __init__(self, coro, *, loop=None, timeout=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__(loop=loop) + super().__init__(loop=loop, timeout=timeout) self._coro = coro self._fut_waiter = None self._must_cancel = False @@ -73,25 +93,36 @@ def __repr__(self): return res def cancel(self): - if self.done(): + if self.done() or self._must_cancel: return False - if self._fut_waiter is not None: - if self._fut_waiter.cancel(): - return True - # It must be the case that self._step is already scheduled. self._must_cancel = True - return True + # _step() will call super().cancel() to call the callbacks. + if self._fut_waiter is not None: + return self._fut_waiter.cancel() + else: + self._loop.call_soon(self._step_maybe) + return True + + def cancelled(self): + return self._must_cancel or super().cancelled() - def _step(self, value=None, exc=None): + def _step_maybe(self): + # Helper for cancel(). + if not self.done(): + return self._step() + + def _step(self, value=_marker, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - if self._must_cancel: - assert self._fut_waiter is None - exc = futures.CancelledError() - value = None + + # We'll call either coro.throw(exc) or coro.send(value). + # Task cancel has to be delayed if current waiter future is done. + if self._must_cancel and exc is None and value is _marker: + exc = futures.CancelledError + coro = self._coro + value = None if value is _marker else value self._fut_waiter = None - # Call either coro.throw(exc) or coro.send(value). try: if exc is not None: result = coro.throw(exc) @@ -100,44 +131,53 @@ def _step(self, value=None, exc=None): else: result = next(coro) except StopIteration as exc: - self.set_result(exc.value) - except futures.CancelledError as exc: - super().cancel() # I.e., Future.cancel(self). + if self._must_cancel: + super().cancel() + else: + self.set_result(exc.value) except Exception as exc: - self.set_exception(exc) + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) except BaseException as exc: - self.set_exception(exc) + if self._must_cancel: + super().cancel() + else: + self.set_exception(exc) raise else: if isinstance(result, futures.Future): - # Yielded Future must come from Future.__iter__(). - if result._blocking: - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result - else: - self._loop.call_soon( - self._step, None, + if not result._blocking: + result.set_exception( RuntimeError( 'yield was used instead of yield from ' 'in task {!r} with {!r}'.format(self, result))) - elif result is None: - # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self._step) - elif inspect.isgenerator(result): - # Yielding a generator is just wrong. - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) + + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + + # task cancellation has been delayed. + if self._must_cancel: + self._fut_waiter.cancel() + else: - # Yielding something else is an error. - self._loop.call_soon( - self._step, None, - RuntimeError( - 'Task got bad yield: {!r}'.format(result))) + if inspect.isgenerator(result): + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + if result is not None: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + else: + self._loop.call_soon(self._step_maybe) self = None def _wakeup(self, future): @@ -187,11 +227,11 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): @coroutine def wait_for(fut, timeout, *, loop=None): - """Wait for the single Future or coroutine to complete, with timeout. + """Wait for the single Future or coroutine to complete. Coroutine will be wrapped in Task. - Returns result of the Future or coroutine. Raises TimeoutError when + Returns result of the Future or coroutine. Raises TimeoutError when timeout occurs. Usage: @@ -219,10 +259,7 @@ def _wait(fs, timeout, return_when, loop): The timeout argument is like for wait(). """ assert fs, 'Set of Futures is empty.' - waiter = futures.Future(loop=loop) - timeout_handle = None - if timeout is not None: - timeout_handle = loop.call_later(timeout, waiter.cancel) + waiter = futures.Future(loop=loop, timeout=timeout) counter = len(fs) def _on_completion(f): @@ -232,8 +269,6 @@ def _on_completion(f): return_when == FIRST_COMPLETED or return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): - if timeout_handle is not None: - timeout_handle.cancel() waiter.cancel() for f in fs: @@ -306,16 +341,19 @@ def sleep(delay, result=None, *, loop=None): h.cancel() -def async(coro_or_future, *, loop=None): +def async(coro_or_future, *, loop=None, timeout=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. """ if isinstance(coro_or_future, futures.Future): - if loop is not None and loop is not coro_or_future._loop: - raise ValueError('loop argument must agree with Future') + if ((loop is not None and loop is not coro_or_future._loop) or + (timeout is not None and timeout != coro_or_future._timeout)): + raise ValueError( + 'loop and timeout arguments must agree with Future') + return coro_or_future elif iscoroutine(coro_or_future): - return Task(coro_or_future, loop=loop) + return Task(coro_or_future, loop=loop, timeout=timeout) else: raise TypeError('A Future or coroutine is required') From b9a148bb30dc39644b40db11400404ad70a60300 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 07:51:09 +0300 Subject: [PATCH 0593/1502] Fix queues cancellation. --- .hgeol | 4 + .hgignore | 12 + Makefile | 35 + NOTES | 176 +++ README | 21 + TODO | 163 +++ check.py | 41 + examples/child_process.py | 127 +++ examples/crawl.py | 104 ++ examples/curl.py | 24 + examples/mpsrv.py | 289 +++++ examples/srv.py | 163 +++ examples/tcp_echo.py | 113 ++ examples/tcp_protocol_parser.py | 170 +++ examples/udp_echo.py | 98 ++ examples/websocket.html | 90 ++ examples/wsclient.py | 97 ++ examples/wssrv.py | 309 +++++ overlapped.c | 1009 +++++++++++++++++ runtests.py | 265 +++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 591 ++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1574 ++++++++++++++++++++++++++ tests/futures_test.py | 317 ++++++ tests/http_client_functional_test.py | 552 +++++++++ tests/http_client_test.py | 289 +++++ tests/http_parser_test.py | 539 +++++++++ tests/http_protocol_test.py | 400 +++++++ tests/http_server_test.py | 301 +++++ tests/http_session_test.py | 139 +++ tests/http_websocket_test.py | 439 +++++++ tests/http_wsgi_test.py | 301 +++++ tests/locks_test.py | 765 +++++++++++++ tests/parsers_test.py | 598 ++++++++++ tests/proactor_events_test.py | 393 +++++++ tests/queues_test.py | 455 ++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1471 ++++++++++++++++++++++++ tests/selectors_test.py | 142 +++ tests/streams_test.py | 343 ++++++ tests/tasks_test.py | 1041 +++++++++++++++++ tests/transports_test.py | 59 + tests/unix_events_test.py | 818 +++++++++++++ tests/windows_events_test.py | 81 ++ tests/windows_utils_test.py | 132 +++ tulip/__init__.py | 28 + tulip/base_events.py | 592 ++++++++++ tulip/constants.py | 4 + tulip/events.py | 389 +++++++ tulip/futures.py | 338 ++++++ tulip/http/__init__.py | 16 + tulip/http/client.py | 572 ++++++++++ tulip/http/errors.py | 46 + tulip/http/protocol.py | 756 +++++++++++++ tulip/http/server.py | 215 ++++ tulip/http/session.py | 103 ++ tulip/http/websocket.py | 233 ++++ tulip/http/wsgi.py | 227 ++++ tulip/locks.py | 403 +++++++ tulip/log.py | 6 + tulip/parsers.py | 399 +++++++ tulip/proactor_events.py | 288 +++++ tulip/protocols.py | 100 ++ tulip/queues.py | 284 +++++ tulip/selector_events.py | 676 +++++++++++ tulip/selectors.py | 410 +++++++ tulip/streams.py | 211 ++++ tulip/tasks.py | 321 ++++++ tulip/test_utils.py | 443 ++++++++ tulip/transports.py | 201 ++++ tulip/unix_events.py | 555 +++++++++ tulip/windows_events.py | 203 ++++ tulip/windows_utils.py | 181 +++ 78 files changed, 23316 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100755 examples/crawl.py create mode 100755 examples/curl.py create mode 100755 examples/mpsrv.py create mode 100755 examples/srv.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/tcp_protocol_parser.py create mode 100755 examples/udp_echo.py create mode 100644 examples/websocket.html create mode 100755 examples/wsclient.py create mode 100755 examples/wssrv.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_parser_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_session_test.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/parsers_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/session.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/parsers.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/windows_utils.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..6064fc63 --- /dev/null +++ b/Makefile @@ -0,0 +1,35 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + echo "open file://`pwd`/htmlcov/index.html" + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..8f2b6373 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..9ab6bcc0 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..d4a035bd --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100755 index 00000000..ac9c25e9 --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + + self.session.close() + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request( + 'get', url, session=self.session) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/examples/curl.py b/examples/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/examples/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100755 index 00000000..6b1ebb8f --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) + x = loop.run_until_complete(f)[0] + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/examples/srv.py b/examples/srv.py new file mode 100755 index 00000000..e01e407c --- /dev/null +++ b/examples/srv.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('method = {!r}; path = {!r}; version = {!r}'.format( + message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, + ssl=sslcontext) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..a0258613 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100755 index 00000000..f5b2ef58 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""websocket cmd client for wssrv.py example.""" +import argparse +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket +import tulip.selectors + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop, url): + name = input('Please enter your name: ').encode() + + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) + + # input reader + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_TEXT: + print(msg.data.strip()) + elif msg.tp == websocket.MSG_CLOSE: + break + + yield from dispatch() + + +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + + loop.add_signal_handler(signal.SIGINT, loop.stop) + tulip.Task(start_client(loop, url)) + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100755 index 00000000..f96e0855 --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message.method, message.headers, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + socks = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, keep_alive=75, + parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), socks[0].getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..3a2c1208 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1009 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..484bff09 --- /dev/null +++ b/runtests.py @@ -0,0 +1,265 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import re +import sys +import subprocess +import unittest +import importlib.machinery + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + if args.forever: + while True: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + + +def runcoverage(sdir, args): + """ + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + * curl -O \ + https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py + - python3 ez_setup.py + - python3 -m easy_install coverage + """ + try: + import coverage + except ImportError: + print("Coverage package is not found.") + print(runcoverage.__doc__) + return + + sdir = os.path.abspath(sdir) + if not os.path.isdir(sdir): + print("Python files directory is not found: {}\n".format(sdir)) + ARGS.print_help() + return + + mods = [source for _, source in load_modules(sdir)] + coverage = [sys.executable, '-m', 'coverage'] + + try: + subprocess.check_call( + coverage + ['run', '--branch', 'runtests.py'] + args) + except: + pass + else: + subprocess.check_call(coverage + ['html'] + mods) + subprocess.check_call(coverage + ['report'] + mods) + + +if __name__ == '__main__': + if '--coverage' in sys.argv: + cov_args, args = COV_ARGS.parse_known_args() + runcoverage(cov_args.coverage, args) + else: + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..e27b3ab9 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,591 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + task = tasks.Task( + self.loop.create_connection(MyProto, 'example.com', 80)) + yield from tasks.wait(task) + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_start_serving_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.start_serving(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_start_serving_host_port_sock(self): + fut = self.loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..240518c0 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1574 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils +from tulip import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.EventWaiter(loop=loop), + 2: locks.EventWaiter(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.loop, use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server(self.loop) as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('socket').getsockname()[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.start_serving(factory, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close start_serving socks + self.loop.stop_serving(sock) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + self.loop.stop_serving(sock) + + def test_start_serving_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + def test_start_serving_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(MyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + f = self.loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + self.loop.stop_serving(sock) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + for s in socks: + self.loop.stop_serving(s) + + def test_stop_serving(self): + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + sock = socks[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop.stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.start_serving, f) + self.assertRaises( + NotImplementedError, loop.stop_serving, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('tulip.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..7c2abd18 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,317 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, *args): + fn(*args) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..91badfc4 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,552 @@ +"""Http client functional tests.""" + +import gc +import io +import os.path +import http.cookies +import unittest + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth), + loop=self.loop)) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + r.close() + + def test_use_global_loop(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + try: + tulip.set_event_loop(self.loop) + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'))) + finally: + tulip.set_event_loop(None) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "GET"', content) + self.assertEqual(content1, content2) + r.close() + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2), + loop=self.loop)) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'), + loop=self.loop)) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2, loop=self.loop)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'}, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'}, + loop=self.loop)) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate', + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'), + loop=self.loop)) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'), + loop=self.loop)) + self.assertEqual(r.status, 200) + r.close() + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), loop=self.loop, + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + r.close() + + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), loop=self.loop)) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + resp.close() + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'), loop=self.loop)) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + r.close() + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + timeout=0.1, loop=self.loop)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', + timeout=0.1, loop=self.loop)) + + def test_request_conn_closed(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['close'] = True + self.assertRaises( + tulip.http.HttpException, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + loop=self.loop)) + + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..1aa27244 --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpRequest, HttpResponse + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response.transport = self.transport + self.response.close() + self.assertIsNone(self.response.transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..6240ad49 --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,539 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_eof(self): + # http_request_parser does not fail on EofStream() + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'get /path HTTP/1.1\r\n') + try: + p.throw(tulip.EofStream()) + except StopIteration: + pass + self.assertFalse(out._buffer) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..ec3aaf58 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,400 @@ +"""Tests for http/protocol.py""" + +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + tulip.set_event_loop(None) + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200, close=True) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_keep_alive_http10(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + self.assertFalse(msg.keepalive) + self.assertFalse(msg.keep_alive()) + + msg = protocol.Response(self.transport, 200, http_version=(1, 1)) + self.assertIsNone(msg.keepalive) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], list(msg.headers)) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], list(msg.headers)) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + # cleanup + msg.writer.close() + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + # cleanup + msg.writer.close() + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + # cleanup + msg.writer.close() + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..a9d4d5ed --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,301 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip import test_utils + + +class HttpServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_http_error_exception(self): + exc = errors.HttpErrorException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertIsNone(srv._request_handler) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handler) + + def test_data_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', bytes(srv.stream._buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', bytes(srv.stream._buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream._eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + + handle = srv._request_handler + srv.connection_lost(None) + test_utils.run_briefly(self.loop) + + self.assertIsNone(srv._request_handler) + self.assertTrue(handle.cancelled()) + + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(keep_alive_handle.cancel.called) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handler) + self.assertIsNone(srv._keep_alive_handle) + + def test_srv_keep_alive(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertFalse(srv._keep_alive) + + srv.keep_alive(True) + self.assertTrue(srv._keep_alive) + + srv.keep_alive(False) + self.assertFalse(srv._keep_alive) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.keep_alive(True) + + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) + self.assertFalse(srv._keep_alive) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, loop=self.loop) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + + called = False + + @tulip.coroutine + def coro(message, payload): + nonlocal called + called = True + srv.eof_received() + + srv.handle_request = coro + srv.connection_made(transport) + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.coroutine + def cancel(): + srv._request_handler.cancel() + + self.loop.run_until_complete( + tulip.wait([srv._request_handler, cancel()], loop=self.loop)) + self.assertTrue(log.debug.called) + + def test_handle_cancelled(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + test_utils.run_briefly(self.loop) # start request_handler task + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + r_handler = srv._request_handler + srv._request_handler = None # emulate srv.connection_lost() + + self.assertIsNone(self.loop.run_until_complete(r_handler)) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + srv.keep_alive(True) + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.keep_alive(True) + srv.connection_made(transport) + srv.connection_lost(None) + + srv.handle_error(300) + self.assertFalse(srv._keep_alive) + + def test_keep_alive(self): + srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) + transport = unittest.mock.Mock() + closed = False + + def close(): + nonlocal closed + closed = True + srv.connection_lost(None) + self.loop.stop() + + transport.close = close + + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.1\r\n' + b'CONNECTION: keep-alive\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_forever() + self.assertTrue(handle.called) + self.assertTrue(closed) + + def test_keep_alive_close_existing(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) + srv.connection_made(transport) + + self.assertIsNone(srv._keep_alive_handle) + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(keep_alive_handle.cancel.called) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(transport.close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..39a80091 --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,139 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + +from tulip import test_utils + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + tulip.set_event_loop(None) + self.loop.close() + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..319538ae --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,439 @@ +"""Tests for http/websocket.py""" + +import base64 +import hashlib +import os +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket, protocol, errors + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_not_get(self): + self.assertRaises( + errors.HttpErrorException, + websocket.do_handshake, + 'POST', self.message.headers, self.transport) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message.method, self.message.headers, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..053f5a69 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,301 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol + + +class HttpWsgiServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() + + def tearDown(self): + self.loop.close() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + @unittest.mock.patch('tulip.http.wsgi.tulip') + def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '101 Switching Protocols', (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'))) + self.assertEqual(resp.status, '101 Switching Protocols') + self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future(loop=self.loop) + f1.set_result(b'data') + fut = tulip.Future(loop=self.loop) + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertFalse(srv._keep_alive) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_keep_alive(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, False, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertTrue(srv._keep_alive) + + def test_handle_request_readpayload(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [env['wsgi.input'].read()] + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..9399d759 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.EventWaiter(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.EventWaiter(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.EventWaiter() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.EventWaiter(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.EventWaiter(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..debc532c --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,598 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer(loop=self.loop) + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer(loop=self.loop) + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.skipuntil(b'\n') + try: + next(p) + except StopIteration: + pass + self.assertEqual(b'', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..da4dea35 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,393 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport +from tulip import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = tulip.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_start_serving_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = tulip.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop.stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor.stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..ab4ee91d --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,455 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import queues +from tulip import tasks +from tulip import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith(')') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro(), loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t.cancel()) + +## def test_cancel_done_future(self): +## fut1 = futures.Future(loop=self.loop) +## fut2 = futures.Future(loop=self.loop) +## fut3 = futures.Future(loop=self.loop) + +## @tasks.coroutine +## def task(): +## yield from fut1 +## try: +## yield from fut2 +## except futures.CancelledError: +## pass +## yield from fut3 + +## t = tasks.Task(task(), loop=self.loop) +## test_utils.run_briefly(self.loop) +## fut1.set_result(None) +## t.cancel() +## test_utils.run_once(self.loop) # process fut1 result, delay cancel +## self.assertFalse(t.done()) +## test_utils.run_once(self.loop) # cancel fut2, but coro still alive +## self.assertFalse(t.done()) +## test_utils.run_briefly(self.loop) # cancel fut3 +## self.assertTrue(t.done()) + +## self.assertEqual(fut1.result(), None) +## self.assertTrue(fut2.cancelled()) +## self.assertTrue(fut3.cancelled()) +## self.assertTrue(t.cancelled()) + +## def test_cancel_in_coro(self): +## @tasks.coroutine +## def task(): +## t.cancel() +## return 12 + +## t = tasks.Task(task(), loop=self.loop) +## self.assertRaises( +## futures.CancelledError, self.loop.run_until_complete, t) +## self.assertTrue(t.done()) +## self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + loop + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch(), loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..5920cda6 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,59 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import futures +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + self.assertRaises(NotImplementedError, transport.pause_writing) + self.assertRaises(NotImplementedError, transport.resume_writing) + self.assertRaises(NotImplementedError, transport.discard_output) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..f0b42a39 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,818 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import tempfile +import unittest +import unittest.mock + + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + def test_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_double_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_pause_resume_writing_with_nonempty_buffer(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.assertFalse(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + + tr.resume_writing() + self.assertTrue(tr._writing) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'da', b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_on_pause(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + + tr._write_ready() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + self.assertFalse(tr._writing) + + def test_discard_output(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + self.loop.add_writer(5, tr._write_ready) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + def test_discard_output_without_pending_writes(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..ce9b74da --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,81 @@ +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams + + +def connect_read_pipe(loop, file): + stream_reader = streams.StreamReader(loop=loop) + protocol = _StreamReaderProtocol(stream_reader) + loop._make_read_pipe_transport(file, protocol) + return stream_reader + + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_pause_resume_discard(self): + a, b = self.loop._socketpair() + trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) + reader = connect_read_pipe(self.loop, b) + f = tulip.async(reader.readline(), loop=self.loop) + + trans.write(b'msg1\n') + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg1\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg2\n') + with self.assertRaises(tulip.TimeoutError): + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(trans._buffer, [b'msg2\n']) + + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.1) + self.assertEqual(f.result(), b'msg2\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg3\n') + self.assertEqual(trans._buffer, [b'msg3\n']) + trans.discard_output() + self.assertEqual(trans._buffer, []) + + trans.write(b'msg4\n') + self.assertEqual(trans._buffer, [b'msg4\n']) + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg4\n') + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f, timeout=1) + self.assertEqual(f.result(), b'') diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..b23896d3 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils +from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..9de84cb0 --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,28 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .parsers import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + parsers.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..3bccfc83 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,592 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # This returns a Task made from self._start_serving_internal(). + # We want start_serving() to return a Task so that it will start + # running right away (when the event loop runs) even if the caller + # doesn't wait for it. Note that this is different from + # e.g. create_connection(), or create_datagram_endpoint(), which + # are a "mere" coroutines and require their caller to wait for + # them. The reason for the difference is that only + # start_serving() creates multiple transports and protocols. + def start_serving(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + coro = self._start_serving_internal(protocol_factory, host, port, + family=family, + flags=flags, + sock=sock, + backlog=backlog, + ssl=ssl, + reuse_address=reuse_address) + return tasks.Task(coro, loop=self) + + @tasks.coroutine + def _start_serving_internal(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..7db2514d --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,389 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """Creates a TCP server bound to host and port and return a + Task whose result will be a list of socket objects which will + later be handled by protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..706e8c8a --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import tulip_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + tulip_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..a1432dee --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,16 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .session import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + session.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..ec7cd034 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,572 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import functools +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +import tulip.http + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None, + session=None, + loop=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. + loop: Optional event loop. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + if loop is None: + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + + # connection timeout + t = tulip.Task(conn, loop=loop) + th = None + if timeout is not None: + th = loop.call_later(timeout, t.cancel) + try: + resp = yield from t + except tulip.CancelledError: + raise tulip.TimeoutError from None + finally: + if th is not None: + th.cancel() + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + + try: + resp = req.send(transport) + yield from resp.start(p, transport) + except: + transport.close() + raise + + return resp + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except ValueError: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except ValueError: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + self.update_cookies(cookies) + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + compress = enc + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = str(len(self.body)) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['transfer-encoding'] = 'chunked' + + chunked = chunked if type(chunked) is int else 8196 + else: + if 'chunked' in te: + chunked = 8196 + else: + chunked = None + self.headers['content-length'] = str(len(self.body)) + + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + message = None # RawResponseMessage object + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + cookies = None # Response cookies (Set-Cookie) + + content = None # payload stream + stream = None # input stream + transport = None # current transport + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + self._content = None + + def __del__(self): + self.close() + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self.stream = stream + self.transport = transport + + httpstream = stream.set_parser(tulip.http.http_response_parser()) + + # read response + self.message = yield from httpstream.read() + + # response status + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason + + # headers + for hdr, val in self.message.headers: + self.add_header(hdr, val) + + # payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) + + return self + + def close(self): + if self.transport is not None: + self.transport.close() + self.transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..f8b77e9b --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,46 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + message = '' + + +class HttpErrorException(HttpException): + + def __init__(self, code, message='', headers=None): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + message = 'Bad Request' + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..7081fd59 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,756 @@ +"""Http related helper utils.""" + +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from tulip.http import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() + +RESPONSES = http.server.BaseHTTPRequestHandler.responses + + +RawRequestMessage = collections.namedtuple( + 'RawRequestMessage', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) + + +RawResponseMessage = collections.namedtuple( + 'RawResponseMessage', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) + + +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield + + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + # request line + line = lines[0] + try: + method, path, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if version <= (1, 0): + close = True + elif close is None: + close = False + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None + + +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() + + lines_idx = 1 + line = lines[1] + + while line not in ('\r\n', '\n'): + header_length = len(line) + + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None + + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line[0] in CONTINUATION + + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') + + value = value.strip() + + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + headers.append((name, value)) + + return headers, close_conn, encoding + + +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield + + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) + + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) + + out.feed_eof() + + +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: # eof marker + break + + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) + + # toss the CRLF at the end of the chunk + yield from buf.skip(2) + + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) + + +class DeflateBuffer: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, out, encoding): + self.out = out + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except Exception: + raise errors.IncompleteRead(b'') from None + + if chunk: + self.out.feed_data(chunk) + + def feed_eof(self): + self.out.feed_data(self.zlib.flush()) + if not self.zlib.eof: + raise errors.IncompleteRead(b'') + + self.out.feed_eof() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + upgrade = False # Connection: UPGRADE + websocket = False # Upgrade: WEBSOCKET + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + + # disable keep-alive for http/1.0 + if version <= (1, 0): + self.keepalive = False + else: + self.keepalive = None + + self.chunked = False + self.length = None + self.headers = collections.deque() + self.headers_sent = False + + def force_close(self): + self.closing = True + self.keepalive = False + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + if self.keepalive is None: + return not self.closing + else: + return self.keepalive + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower() + # handle websocket + if 'upgrade' in val: + self.upgrade = True + # connection keep-alive + elif 'close' in val: + self.keepalive = False + elif 'keep-alive' in val and self.version >= (1, 1): + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.websocket = True + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + self._add_default_headers() + + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) + + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif not self.closing if self.keepalive is None else self.keepalive: + connection = 'keep-alive' + else: + connection = 'close' + + if self.chunked: + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) + + self.headers.appendleft(('CONNECTION', connection)) + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(tulip.EofStream()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except tulip.EofStream: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except tulip.EofStream: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except tulip.EofStream: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, path, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.path = path + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, path, http_version) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..fc5621c5 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,215 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +from tulip.http import errors + + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + log: custom logging object + debug: enable debug mode + keep_alive: number of seconds before closing keep alive connection + loop: event loop object + """ + _request_count = 0 + _request_handler = None + _keep_alive = False # keep transport open + _keep_alive_handle = None # keep alive timer handle + + def __init__(self, *, log=logging, debug=False, + keep_alive=None, loop=None, **kwargs): + self.__dict__.update(kwargs) + self.log = log + self.debug = debug + + self._keep_alive_period = keep_alive # number of seconds to keep alive + + if keep_alive and loop is None: + loop = tulip.get_event_loop() + self._loop = loop + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.StreamBuffer(loop=self._loop) + self._request_handler = tulip.Task(self.start(), loop=self._loop) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + self.stream.feed_eof() + + if self._request_handler is not None: + self._request_handler.cancel() + self._request_handler = None + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + def keep_alive(self, val): + self._keep_alive = val + + def log_access(self, status, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.coroutine + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. + """ + + while True: + info = None + message = None + self._request_count += 1 + self._keep_alive = False + + try: + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) + + message = yield from httpstream.read() + + # cancel keep-alive timer + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._request_handler: + if self._keep_alive and self._keep_alive_period: + self._keep_alive_handle = self._loop.call_later( + self._keep_alive_period, self.transport.close) + else: + self.transport.close() + self._request_handler = None + break + else: + break + + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + try: + if self._request_handler is None: + # client has been disconnected during writing. + return + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.keep_alive(False) + + def handle_request(self, message, payload): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=message.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.keep_alive(False) + self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..9cdd9cea --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,103 @@ +"""client session support.""" + +__all__ = ['Session'] + +import functools +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..c3dd5872 --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,233 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import base64 +import binascii +import collections +import hashlib +import struct +from tulip.http import errors + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) + + +def do_handshake(method, headers, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + perform any IO.""" + + # WebSocket accepts only GET + if method.upper() != 'GET': + raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) + + headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException( + 'No WebSocket UPGRADE hdr: {}\n' + 'Can "Upgrade" only to "WebSocket".'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8', '7'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..738e100f --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,227 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) + + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, message, payload): + """Handle a single HTTP request""" + + if self.readpayload: + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + wsgiinput.seek(0) + payload = wsgiinput + + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if resp.keep_alive(): + self.keep_alive(True) + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, message): + self.transport = transport + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + resp = self.response = tulip.http.Response( + self.transport, status_code, + self.message.version, self.message.should_close) + resp.add_headers(*headers) + + # send headers immediately for websocket connection + if status_code == 101 and resp.upgrade and resp.websocket: + resp.send_headers() + else: + resp._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..87937ec0 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,403 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.cancelled(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +# TODO: Why not call this Event? +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..43ddc2e9 --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,399 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = ParserBuffer() + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer(loop=self._loop) + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + if self._parser is None: + return + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.done(): + waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(False) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def _shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len + size = end - self.offset + if limit is not None and size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - (end - self.offset) + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..cda87918 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,288 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import tulip_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._writing_disabled = False + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, waiter, extra) + self._loop.call_soon(self._loop_reading) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + try: + self._protocol.eof_received() + finally: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None and not self._writing_disabled: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + if not self._writing_disabled: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except OSError as exc: + self._fatal_error(exc) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._force_close(None) + + def pause_writing(self): + self._writing_disabled = True + + def resume_writing(self): + self._writing_disabled = False + if self._buffer and self._write_fut is None: + self._loop_writing() + + def discard_output(self): + if self._buffer: + self._buffer = [] + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def stop_serving(self, sock): + self._proactor.stop_serving(sock) + sock.close() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..d76f25a2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,100 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..b658e67e --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,284 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..82d22bb6 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,676 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + def __init__(self, loop, sock, protocol, extra): + super().__init__(extra) + self._extra['socket'] = sock + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._writing = True + self._closing = False # Set when close() called. + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_writer(self._sock_fd) + self._loop.remove_reader(self._sock_fd) + self._buffer.clear() + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except OSError as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return # transmission off + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, extra=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + + super().__init__(loop, sslsock, protocol, extra) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing: + try: + data = self._sock.recv(8192) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + # TODO: write_eof(), can_write_eof(). + + +class _SelectorDatagramTransport(_SelectorTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._buffer = collections.deque() + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..b81b1dbe --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,410 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class BaseSelector(metaclass=ABCMeta): + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key.pop(fileobj) + del self._fd_to_key[key.fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + try: + return self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) from None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(int(1000 * timeout), 0) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..3203b7d6 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,211 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + self.limit = limit # Max line length. (Security feature.) + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..a51ee29a --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,321 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + ] + +import collections +import concurrent.futures +import functools +import inspect + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, *, loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(loop=loop) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + assert self._fut_waiter is None + exc = futures.CancelledError() + value = None + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(return_when=FIRST_COMPLETED). + + The fs argument must be a set of Futures. + The timeout argument is like for wait(). + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, waiter.cancel) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + waiter.cancel() + + for f in fs: + f.add_done_callback(_on_completion) + try: + yield from waiter + except futures.CancelledError: + pass + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..b4af0c89 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,443 @@ +"""Utilities shared by tests.""" + +import cgi +import collections +import contextlib +import gc +import email.parser +import http.server +import json +import logging +import io +import unittest.mock +import os +import re +import socket +import sys +import threading +import traceback +import unittest +import unittest.mock +import urllib.parse +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client +from tulip import base_events +from tulip import events + +from tulip import base_events +from tulip import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def run_briefly(loop): + @tulip.coroutine + def once(): + pass + t = tulip.Task(once(), loop=loop) + loop.run_until_complete(t) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + transports = [] + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + def handle_request(self, message, payload): + if properties.get('close', False): + return + + if properties.get('noresponse', False): + yield from tulip.sleep(99999) + + if router is not None: + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router( + self, properties, + self.transport, message, bytes(body)) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + socks = thread_loop.run_until_complete( + thread_loop.start_serving( + lambda: TestHttpServer(keep_alive=0.5), + host, port, ssl=sslcontext)) + + waiter = tulip.Future(loop=thread_loop) + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) + + try: + thread_loop.run_until_complete(waiter) + finally: + # call pending connection_made if present + run_briefly(thread_loop) + + # close opened trnsports + for tr in transports: + tr.close() + + run_briefly(thread_loop) # call close callbacks + + for s in socks: + thread_loop.stop_serving(s) + + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = tulip.Future(loop=loop) + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, srv, props, transport, message, payload): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in message.headers: + self._headers.add_header(hdr, val) + + self._srv = srv + self._props = props + self._transport = transport + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except Exception: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() + + # keep-alive + if response.keep_alive(): + self._srv.keep_alive(True) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + self._check_on_close = False + def gen(): + yield + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..56425aa9 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,201 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause_writing(self): + """Pause transmission on the transport. + + Subsequent writes are deferred until resume_writing() is called. + """ + raise NotImplementedError + + def resume_writing(self): + """Resume transmission on the transport. """ + raise NotImplementedError + + def discard_output(self): + """Discard any buffered data awaiting transmission on the transport.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..75131851 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,555 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + tulip_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._writing = True + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._fileno) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._fileno, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..629b3475 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,203 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import windows_utils +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(handle, nbytes) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(handle, buf) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # the socket needs to be locally bound before we call ConnectEx() + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # probably already locally bound; check using getsockname() + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = _OverlappedFuture(ov, loop=self._loop) + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop.stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py new file mode 100644 index 00000000..bf85f31e --- /dev/null +++ b/tulip/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter=itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) From 0bc1d149d6707dddd363f4601102d36348cb573c Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 08:50:04 +0300 Subject: [PATCH 0594/1502] Refactor coverage support to allow collecting coverage for filtered tests only. For example now you can use 'python3 runtests.py --coverage queues_test' --- Makefile | 1 - runtests.py | 102 +++++++++++++++++++++++++--------------------------- 2 files changed, 49 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index 6064fc63..edf088e2 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,6 @@ testloop: # See runtests.py for coverage installation instructions. cov coverage: $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) - echo "open file://`pwd`/htmlcov/index.html" check: $(PYTHON) check.py diff --git a/runtests.py b/runtests.py index 484bff09..50068b31 100644 --- a/runtests.py +++ b/runtests.py @@ -27,7 +27,12 @@ import sys import subprocess import unittest +import textwrap import importlib.machinery +try: + import coverage +except ImportError: + coverage = None from unittest.signals import installHandler @@ -57,8 +62,8 @@ '--tests', action="store", dest='testsdir', default='tests', help='tests directory') ARGS.add_argument( - '--coverage', action="store", dest='coverage', nargs='?', const='', - help='enable coverage report and provide python files directory') + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') ARGS.add_argument( 'pattern', action="store", nargs="*", help='optional regex patterns to match test ids (default all tests)') @@ -175,6 +180,22 @@ def run(self, test): def runtests(): args = ARGS.parse_args() + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + testsdir = os.path.abspath(args.testsdir) if not os.path.isdir(testsdir): print("Tests directory is not found: {}\n".format(testsdir)) @@ -193,6 +214,12 @@ def runtests(): findleaks = args.findleaks runner_factory = TestRunner if findleaks else unittest.TextTestRunner + if args.coverage: + cov = coverage.coverage(branch=True, + source=['tulip'], + ) + cov.start() + tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() if v == 0: @@ -207,59 +234,28 @@ def runtests(): logger.setLevel(logging.DEBUG) if catchbreak: installHandler() - if args.forever: - while True: + try: + if args.forever: + while True: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: result = runner_factory(verbosity=v, failfast=failfast).run(tests) - if not result.wasSuccessful(): - sys.exit(1) - else: - result = runner_factory(verbosity=v, - failfast=failfast).run(tests) - sys.exit(not result.wasSuccessful()) - - -def runcoverage(sdir, args): - """ - To install coverage3 for Python 3, you need: - - Setuptools (https://pypi.python.org/pypi/setuptools) - - What worked for me: - - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - * curl -O \ - https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - - python3 ez_setup.py - - python3 -m easy_install coverage - """ - try: - import coverage - except ImportError: - print("Coverage package is not found.") - print(runcoverage.__doc__) - return - - sdir = os.path.abspath(sdir) - if not os.path.isdir(sdir): - print("Python files directory is not found: {}\n".format(sdir)) - ARGS.print_help() - return - - mods = [source for _, source in load_modules(sdir)] - coverage = [sys.executable, '-m', 'coverage'] - - try: - subprocess.check_call( - coverage + ['run', '--branch', 'runtests.py'] + args) - except: - pass - else: - subprocess.check_call(coverage + ['html'] + mods) - subprocess.check_call(coverage + ['report'] + mods) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("open file://{}/htmlcov/index.html for html report".format( + here)) if __name__ == '__main__': - if '--coverage' in sys.argv: - cov_args, args = COV_ARGS.parse_known_args() - runcoverage(cov_args.coverage, args) - else: - runtests() + runtests() From 76c9d92f7381d09d27441c358d78894d1cf364b1 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 12:13:55 +0300 Subject: [PATCH 0595/1502] Improve test coverage for queues unit tests. --- tests/queues_test.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/queues_test.py b/tests/queues_test.py index ab4ee91d..98ca3199 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -278,6 +278,14 @@ def test_get_cancelled_race(self): test_utils.run_briefly(self.loop) self.assertEqual(t2.result(), 'a') + def test_get_with_waiting_putters(self): + q = queues.Queue(loop=self.loop, maxsize=1) + t1 = tasks.Task(q.put('a'), loop=self.loop) + t2 = tasks.Task(q.put('b'), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(self.loop.run_until_complete(q.get()), 'a') + self.assertEqual(self.loop.run_until_complete(q.get()), 'b') + class QueuePutTests(_QueueTestBase): @@ -366,6 +374,13 @@ def test_put_cancelled_race(self): self.assertEqual(q.get_nowait(), 'a') self.assertEqual(q.get_nowait(), 'c') + def test_put_with_waiting_getters(self): + q = queues.Queue(loop=self.loop) + t = tasks.Task(q.get(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.loop.run_until_complete(q.put('a')) + self.assertEqual(self.loop.run_until_complete(t), 'a') + class LifoQueueTests(_QueueTestBase): From 4d90c102f16b16885425e69d41e3990f2ddeede3 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Wed, 4 Sep 2013 12:30:15 +0300 Subject: [PATCH 0596/1502] Fix 'make coverage' command --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index edf088e2..ed3caf21 100644 --- a/Makefile +++ b/Makefile @@ -16,7 +16,7 @@ testloop: # See runtests.py for coverage installation instructions. cov coverage: - $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) check: $(PYTHON) check.py From 375476d961cd5d8e9acf888588c7d9c683f38624 Mon Sep 17 00:00:00 2001 From: Charles-Fran?ois Natali Date: Wed, 4 Sep 2013 18:34:24 +0200 Subject: [PATCH 0597/1502] Update test to be less dependant on selectors iternals now that it's part of the stdlib: - pass proper file-like mock objects (i.e. with a fileno() method) - remove explicit checks for _fileobj_to_key dict, which is useless and removed in latest selectors version --- tests/base_events_test.py | 1 + tests/selectors_test.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index e27b3ab9..deb82af7 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -584,6 +584,7 @@ def test_accept_connection_retry(self): @unittest.mock.patch('tulip.selector_events.tulip_log') def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() + sock.fileno.return_value = 10 sock.accept.side_effect = OSError() self.loop._accept_connection(MyProto, sock) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 68c1c06b..0f74db0f 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -60,15 +60,20 @@ def test_unregister(self): s.register(fobj, selectors.EVENT_READ) s.unregister(fobj) self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_unregister_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + s = FakeSelector() - self.assertRaises(KeyError, s.unregister, unittest.mock.Mock()) + self.assertRaises(KeyError, s.unregister, fobj) def test_modify_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + s = FakeSelector() - self.assertRaises(KeyError, s.modify, unittest.mock.Mock(), 1) + self.assertRaises(KeyError, s.modify, fobj, 1) def test_modify(self): fobj = unittest.mock.Mock() @@ -119,7 +124,6 @@ def test_close(self): s.close() self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_context_manager(self): s = FakeSelector() @@ -128,7 +132,6 @@ def test_context_manager(self): sel.register(1, selectors.EVENT_READ) self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_key_from_fd(self): s = FakeSelector() From e24703faa53bcfc229b0abf41ceba5fbdc9593d2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 4 Sep 2013 10:10:18 -0700 Subject: [PATCH 0598/1502] Latest selectors.py straight from CPython tree. --- tulip/selectors.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tulip/selectors.py b/tulip/selectors.py index b81b1dbe..fe027f09 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -31,7 +31,7 @@ def _fileobj_to_fd(fileobj): else: try: fd = int(fileobj.fileno()) - except (AttributeError, ValueError): + except (AttributeError, TypeError, ValueError): raise ValueError("Invalid file object: " "{!r}".format(fileobj)) from None if fd < 0: @@ -62,8 +62,6 @@ class BaseSelector(metaclass=ABCMeta): def __init__(self): # this maps file descriptors to keys self._fd_to_key = {} - # this maps file objects to keys - for fast (un)registering - self._fileobj_to_key = {} def register(self, fileobj, events, data=None): """Register a file object. @@ -77,7 +75,7 @@ def register(self, fileobj, events, data=None): SelectorKey instance """ if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {}".format(events)) + raise ValueError("Invalid events: {!r}".format(events)) key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) @@ -86,7 +84,6 @@ def register(self, fileobj, events, data=None): "registered".format(fileobj, key.fd)) self._fd_to_key[key.fd] = key - self._fileobj_to_key[fileobj] = key return key def unregister(self, fileobj): @@ -99,8 +96,7 @@ def unregister(self, fileobj): SelectorKey instance """ try: - key = self._fileobj_to_key.pop(fileobj) - del self._fd_to_key[key.fd] + key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None return key @@ -118,7 +114,7 @@ def modify(self, fileobj, events, data=None): """ # TODO: Subclasses can probably optimize this even further. try: - key = self._fileobj_to_key[fileobj] + key = self._fd_to_key[_fileobj_to_fd(fileobj)] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None if events != key.events or data != key.data: @@ -154,7 +150,6 @@ def close(self): This must be called to make sure that any underlying resource is freed. """ self._fd_to_key.clear() - self._fileobj_to_key.clear() def get_key(self, fileobj): """Return the key associated to a registered file object. @@ -163,9 +158,9 @@ def get_key(self, fileobj): SelectorKey for this file object """ try: - return self._fileobj_to_key[fileobj] + return self._fd_to_key[_fileobj_to_fd(fileobj)] except KeyError: - raise KeyError("{} is not registered".format(fileobj)) from None + raise KeyError("{!r} is not registered".format(fileobj)) from None def __enter__(self): return self From 5fe2fdb9353f5cb96979eec603d81c6ee4869752 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 5 Sep 2013 11:29:24 +0300 Subject: [PATCH 0599/1502] Add checks to tasks_test --- tests/tasks_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 8c26e3f9..a2b2d408 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -173,9 +173,10 @@ def task(): t = tasks.Task(task(), loop=loop) loop.call_soon(t.cancel) - self.assertRaises( - futures.CancelledError, loop.run_until_complete, t) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) def test_cancel_yield(self): @@ -191,6 +192,7 @@ def task(): self.assertRaises( futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) ## def test_cancel_done_future(self): From bf38e64c40dfba6ad718a5a8c557ade087672710 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 5 Sep 2013 11:29:44 +0300 Subject: [PATCH 0600/1502] Add checks to tasks_test --- tests/tasks_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index a2b2d408..a8c88340 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -945,7 +945,7 @@ def gen(): def sleeper(): yield from tasks.sleep(10, loop=loop) - base_exc = BaseException() + base_exc = BaseException() @tasks.coroutine def notmutch(): From 8bfe8c2e5dede03230dfd2d316d7503c2525c2b9 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Thu, 5 Sep 2013 11:35:41 +0300 Subject: [PATCH 0601/1502] Add test for cancelling future which is waited by task. --- tests/tasks_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index a8c88340..3c0fa96d 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -195,6 +195,22 @@ def task(): self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + ## def test_cancel_done_future(self): ## fut1 = futures.Future(loop=self.loop) ## fut2 = futures.Future(loop=self.loop) From c2ce8fdb75fcdf40965261ed1004485286192740 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 6 Sep 2013 06:17:22 +0300 Subject: [PATCH 0602/1502] Fix task when both task and inner waiter are cancelled. --- tests/tasks_test.py | 21 +++++++++++++++++++++ tulip/tasks.py | 6 +++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 3c0fa96d..d8d3870c 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -211,6 +211,27 @@ def task(): self.assertTrue(f.cancelled()) self.assertTrue(t.cancelled()) + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + ## def test_cancel_done_future(self): ## fut1 = futures.Future(loop=self.loop) ## fut2 = futures.Future(loop=self.loop) diff --git a/tulip/tasks.py b/tulip/tasks.py index a51ee29a..4a69315f 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -76,7 +76,11 @@ def cancel(self): if self.done(): return False if self._fut_waiter is not None: - if self._fut_waiter.cancel(): + # XXX: What to do if self._fut_waiter.cancel() returns False? + # If that's anready cancelled future everything is ok. + # What are other possible scenarios? + waiter, self._fut_waiter = self._fut_waiter, None + if waiter.cancel(): return True # It must be the case that self._step is already scheduled. self._must_cancel = True From de945ca4964388342efb7436612567737995352c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 5 Sep 2013 23:08:24 -0700 Subject: [PATCH 0603/1502] add tulip.http to packages list --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dcaee96f..a19e3224 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,6 @@ setup(name='tulip', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip'], + packages=['tulip', 'tulip.http'], ext_modules=extensions ) From cbda2628bba892e7259548da30cef48e0b492462 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 6 Sep 2013 10:08:06 -0700 Subject: [PATCH 0604/1502] Fold long line. --- runtests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/runtests.py b/runtests.py index 50068b31..725bfa2e 100644 --- a/runtests.py +++ b/runtests.py @@ -253,8 +253,8 @@ def runtests(): print("\nCoverage report:") cov.report(show_missing=False) here = os.path.dirname(os.path.abspath(__file__)) - print("open file://{}/htmlcov/index.html for html report".format( - here)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) if __name__ == '__main__': From 18869014273804e1e3ac7b44b090568f7ba8c5a8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 6 Sep 2013 11:07:00 -0700 Subject: [PATCH 0605/1502] Get rid of _FakeEventLoop in favor of run_briefly(). --- tests/futures_test.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index 7c2abd18..18cec8b0 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -242,14 +242,17 @@ def run(arg): self.assertIs(m_events.get_event_loop.return_value, f2._loop) -# A fake event loop for tests. All it does is implement a call_soon method -# that immediately invokes the given function. -class _FakeEventLoop: - def call_soon(self, fn, *args): - fn(*args) +class FutureDoneCallbackTests(unittest.TestCase): + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) -class FutureDoneCallbackTests(unittest.TestCase): + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) def _make_callback(self, bag, thing): # Create a callback function that appends thing to bag. @@ -258,7 +261,7 @@ def bag_appender(future): return bag_appender def _new_future(self): - return futures.Future(loop=_FakeEventLoop()) + return futures.Future(loop=self.loop) def test_callbacks_invoked_on_set_result(self): bag = [] @@ -268,6 +271,9 @@ def test_callbacks_invoked_on_set_result(self): self.assertEqual(bag, []) f.set_result('foo') + + self.run_briefly() + self.assertEqual(bag, [42, 17]) self.assertEqual(f.result(), 'foo') @@ -279,6 +285,9 @@ def test_callbacks_invoked_on_set_exception(self): self.assertEqual(bag, []) exc = RuntimeError() f.set_exception(exc) + + self.run_briefly() + self.assertEqual(bag, [100]) self.assertEqual(f.exception(), exc) @@ -309,6 +318,9 @@ def test_remove_done_callback(self): self.assertEqual(bag, []) f.set_result('foo') + + self.run_briefly() + self.assertEqual(bag, [2]) self.assertEqual(f.result(), 'foo') From 81465772d2e2930493a69b87b7a390647165fa2a Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 6 Sep 2013 21:00:49 +0200 Subject: [PATCH 0606/1502] Add tasks.gather(), similar to Twisted's gatherResults(). --- .hgeol | 4 + .hgignore | 12 + Makefile | 34 + NOTES | 176 +++ README | 21 + TODO | 163 +++ check.py | 41 + examples/child_process.py | 127 +++ examples/crawl.py | 104 ++ examples/curl.py | 24 + examples/mpsrv.py | 289 +++++ examples/srv.py | 163 +++ examples/tcp_echo.py | 113 ++ examples/tcp_protocol_parser.py | 170 +++ examples/udp_echo.py | 98 ++ examples/websocket.html | 90 ++ examples/wsclient.py | 97 ++ examples/wssrv.py | 309 +++++ overlapped.c | 1009 +++++++++++++++++ runtests.py | 261 +++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 591 ++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1574 ++++++++++++++++++++++++++ tests/futures_test.py | 317 ++++++ tests/http_client_functional_test.py | 552 +++++++++ tests/http_client_test.py | 289 +++++ tests/http_parser_test.py | 539 +++++++++ tests/http_protocol_test.py | 400 +++++++ tests/http_server_test.py | 301 +++++ tests/http_session_test.py | 139 +++ tests/http_websocket_test.py | 439 +++++++ tests/http_wsgi_test.py | 301 +++++ tests/locks_test.py | 765 +++++++++++++ tests/parsers_test.py | 598 ++++++++++ tests/proactor_events_test.py | 393 +++++++ tests/queues_test.py | 470 ++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1471 ++++++++++++++++++++++++ tests/selectors_test.py | 142 +++ tests/streams_test.py | 343 ++++++ tests/tasks_test.py | 1239 ++++++++++++++++++++ tests/transports_test.py | 59 + tests/unix_events_test.py | 818 +++++++++++++ tests/windows_events_test.py | 81 ++ tests/windows_utils_test.py | 132 +++ tulip/__init__.py | 28 + tulip/base_events.py | 592 ++++++++++ tulip/constants.py | 4 + tulip/events.py | 389 +++++++ tulip/futures.py | 338 ++++++ tulip/http/__init__.py | 16 + tulip/http/client.py | 572 ++++++++++ tulip/http/errors.py | 46 + tulip/http/protocol.py | 756 +++++++++++++ tulip/http/server.py | 215 ++++ tulip/http/session.py | 103 ++ tulip/http/websocket.py | 233 ++++ tulip/http/wsgi.py | 227 ++++ tulip/locks.py | 403 +++++++ tulip/log.py | 6 + tulip/parsers.py | 399 +++++++ tulip/proactor_events.py | 288 +++++ tulip/protocols.py | 100 ++ tulip/queues.py | 284 +++++ tulip/selector_events.py | 676 +++++++++++ tulip/selectors.py | 410 +++++++ tulip/streams.py | 211 ++++ tulip/tasks.py | 377 ++++++ tulip/test_utils.py | 443 ++++++++ tulip/transports.py | 201 ++++ tulip/unix_events.py | 555 +++++++++ tulip/windows_events.py | 203 ++++ tulip/windows_utils.py | 181 +++ 78 files changed, 23580 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100755 examples/crawl.py create mode 100755 examples/curl.py create mode 100755 examples/mpsrv.py create mode 100755 examples/srv.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/tcp_protocol_parser.py create mode 100755 examples/udp_echo.py create mode 100644 examples/websocket.html create mode 100755 examples/wsclient.py create mode 100755 examples/wssrv.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/http_client_functional_test.py create mode 100644 tests/http_client_test.py create mode 100644 tests/http_parser_test.py create mode 100644 tests/http_protocol_test.py create mode 100644 tests/http_server_test.py create mode 100644 tests/http_session_test.py create mode 100644 tests/http_websocket_test.py create mode 100644 tests/http_wsgi_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/parsers_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/http/__init__.py create mode 100644 tulip/http/client.py create mode 100644 tulip/http/errors.py create mode 100644 tulip/http/protocol.py create mode 100644 tulip/http/server.py create mode 100644 tulip/http/session.py create mode 100644 tulip/http/websocket.py create mode 100644 tulip/http/wsgi.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/parsers.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/windows_utils.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..ed3caf21 --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..8f2b6373 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..9ab6bcc0 --- /dev/null +++ b/check.py @@ -0,0 +1,41 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import sys, os + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..d4a035bd --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.task +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100755 index 00000000..ac9c25e9 --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 + +import logging +import re +import signal +import sys +import urllib.parse + +import tulip +import tulip.http + + +class Crawler: + + def __init__(self, rooturl, loop, maxtasks=100): + self.rooturl = rooturl + self.loop = loop + self.todo = set() + self.busy = set() + self.done = {} + self.tasks = set() + self.sem = tulip.Semaphore(maxtasks) + + # session stores cookies between requests and uses connection pool + self.session = tulip.http.Session() + + @tulip.task + def run(self): + self.addurls([(self.rooturl, '')]) # Set initial work. + yield from tulip.sleep(1) + while self.busy: + yield from tulip.sleep(1) + + self.session.close() + self.loop.stop() + + @tulip.task + def addurls(self, urls): + for url, parenturl in urls: + url = urllib.parse.urljoin(parenturl, url) + url, frag = urllib.parse.urldefrag(url) + if (url.startswith(self.rooturl) and + url not in self.busy and + url not in self.done and + url not in self.todo): + self.todo.add(url) + yield from self.sem.acquire() + task = self.process(url) + task.add_done_callback(lambda t: self.sem.release()) + task.add_done_callback(self.tasks.remove) + self.tasks.add(task) + + @tulip.task + def process(self, url): + print('processing:', url) + + self.todo.remove(url) + self.busy.add(url) + try: + resp = yield from tulip.http.request( + 'get', url, session=self.session) + except Exception as exc: + print('...', url, 'has error', repr(str(exc))) + self.done[url] = False + else: + if resp.status == 200 and resp.get_content_type() == 'text/html': + data = (yield from resp.read()).decode('utf-8', 'replace') + urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) + self.addurls([(u, url) for u in urls]) + + resp.close() + self.done[url] = True + + self.busy.remove(url) + print(len(self.done), 'completed tasks,', len(self.tasks), + 'still pending, todo', len(self.todo)) + + +def main(): + loop = tulip.get_event_loop() + + c = Crawler(sys.argv[1], loop) + c.run() + + try: + loop.add_signal_handler(signal.SIGINT, loop.stop) + except RuntimeError: + pass + loop.run_forever() + print('todo:', len(c.todo)) + print('busy:', len(c.busy)) + print('done:', len(c.done), '; ok:', sum(c.done.values())) + print('tasks:', len(c.tasks)) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + main() diff --git a/examples/curl.py b/examples/curl.py new file mode 100755 index 00000000..7063adcd --- /dev/null +++ b/examples/curl.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 + +import sys +import tulip +import tulip.http + + +def curl(url): + response = yield from tulip.http.request('GET', url) + print(repr(response)) + + data = yield from response.read() + print(data.decode('utf-8', 'replace')) + + +if __name__ == '__main__': + if '--iocp' in sys.argv: + from tulip import events, windows_events + sys.argv.remove('--iocp') + el = windows_events.ProactorEventLoop() + events.set_event_loop(el) + + loop = tulip.get_event_loop() + loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py new file mode 100755 index 00000000..6b1ebb8f --- /dev/null +++ b/examples/mpsrv.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +"""Simple multiprocess http server written using an event loop.""" + +import argparse +import email.message +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('{}: method = {!r}; path = {!r}; version = {!r}'.format( + os.getpid(), message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + path = None + else: + path = '.' + path + if not os.path.exists(path): + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) + x = loop.run_until_complete(f)[0] + print('Starting srv worker process {} on {}'.format( + os.getpid(), x.getsockname())) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, loop, args, sock): + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/examples/srv.py b/examples/srv.py new file mode 100755 index 00000000..e01e407c --- /dev/null +++ b/examples/srv.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +"""Simple server written using an event loop.""" + +import argparse +import email.message +import logging +import os +import sys +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +import tulip +import tulip.http + + +class HttpServer(tulip.http.ServerHttpProtocol): + + @tulip.coroutine + def handle_request(self, message, payload): + print('method = {!r}; path = {!r}; version = {!r}'.format( + message.method, message.path, message.version)) + + path = message.path + + if (not (path.isprintable() and path.startswith('/')) or '/.' in path): + print('bad path', repr(path)) + path = None + else: + path = '.' + path + if not os.path.exists(path): + print('no file', repr(path)) + path = None + else: + isdir = os.path.isdir(path) + + if not path: + raise tulip.http.HttpStatusException(404) + + headers = email.message.Message() + for hdr, val in message.headers: + print(hdr, val) + headers.add_header(hdr, val) + + if isdir and not path.endswith('/'): + path = path + '/' + raise tulip.http.HttpStatusException( + 302, headers=(('URI', path), ('Location', path))) + + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + + # content encoding + accept_encoding = headers.get('accept-encoding', '').lower() + if 'deflate' in accept_encoding: + response.add_header('Content-Encoding', 'deflate') + response.add_compression_filter('deflate') + elif 'gzip' in accept_encoding: + response.add_header('Content-Encoding', 'gzip') + response.add_compression_filter('gzip') + + response.add_chunking_filter(1025) + + if isdir: + response.add_header('Content-type', 'text/html') + response.send_headers() + + response.write(b'
    \r\n') + for name in sorted(os.listdir(path)): + if name.isprintable() and not name.startswith('.'): + try: + bname = name.encode('ascii') + except UnicodeError: + pass + else: + if os.path.isdir(os.path.join(path, name)): + response.write(b'
  • ' + bname + b'/
  • \r\n') + else: + response.write(b'
  • ' + bname + b'
  • \r\n') + response.write(b'
') + else: + response.add_header('Content-type', 'text/plain') + response.send_headers() + + try: + with open(path, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + response.write(chunk) + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') +ARGS.add_argument( + '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') +ARGS.add_argument( + '--sslcert', action="store", dest='certfile', help='SSL cert file.') +ARGS.add_argument( + '--sslkey', action="store", dest='keyfile', help='SSL key file.') + + +def main(): + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if args.iocp: + from tulip import windows_events + sys.argv.remove('--iocp') + logging.info('using iocp') + el = windows_events.ProactorEventLoop() + tulip.set_event_loop(el) + + if args.ssl: + here = os.path.join(os.path.dirname(__file__), 'tests') + + if args.certfile: + certfile = args.certfile or os.path.join(here, 'sample.crt') + keyfile = args.keyfile or os.path.join(here, 'sample.key') + else: + certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'sample.key') + + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + loop = tulip.get_event_loop() + f = loop.start_serving( + lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, + ssl=sslcontext) + socks = loop.run_until_complete(f) + print('serving on', socks[0].getsockname()) + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py new file mode 100755 index 00000000..a0258613 --- /dev/null +++ b/examples/tcp_protocol_parser.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +"""Protocol parser example.""" +import argparse +import collections +import tulip +try: + import signal +except ImportError: + signal = None + + +MSG_TEXT = b'text:' +MSG_PING = b'ping:' +MSG_PONG = b'pong:' +MSG_STOP = b'stop:' + +Message = collections.namedtuple('Message', ('tp', 'data')) + + +def my_protocol_parser(): + """Parser is used with StreamBuffer for incremental protocol parsing. + Parser is a generator function, but it is not a coroutine. Usually + parsers are implemented as a state machine. + + more details in tulip/parsers.py + existing parsers: + * http protocol parsers tulip/http/protocol.py + * websocket parser tulip/http/websocket.py + """ + out, buf = yield + + while True: + tp = yield from buf.read(5) + if tp in (MSG_PING, MSG_PONG): + # skip line + yield from buf.skipuntil(b'\r\n') + out.feed_data(Message(tp, None)) + elif tp == MSG_STOP: + out.feed_data(Message(tp, None)) + elif tp == MSG_TEXT: + # read text + text = yield from buf.readuntil(b'\r\n') + out.feed_data(Message(tp, text.strip().decode('utf-8'))) + else: + raise ValueError('Unknown protocol prefix.') + + +class MyProtocolWriter: + + def __init__(self, transport): + self.transport = transport + + def ping(self): + self.transport.write(b'ping:\r\n') + + def pong(self): + self.transport.write(b'pong:\r\n') + + def stop(self): + self.transport.write(b'stop:\r\n') + + def send_text(self, text): + self.transport.write( + 'text:{}\r\n'.format(text.strip()).encode('utf-8')) + + +class EchoServer(tulip.Protocol): + + def connection_made(self, transport): + print('Connection made') + self.transport = transport + self.stream = tulip.StreamBuffer() + self.dispatch() + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + print('Connection lost') + + @tulip.task + def dispatch(self): + reader = self.stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(self.transport) + + while True: + msg = yield from reader.read() + if msg is None: + break # client has been disconnected + + print('Message received: {}'.format(msg)) + + if msg.tp == MSG_PING: + writer.pong() + elif msg.tp == MSG_TEXT: + writer.send_text('Re: ' + msg.data) + elif msg.tp == MSG_STOP: + self.transport.close() + break + + +@tulip.task +def start_client(loop, host, port): + transport, stream = yield from loop.create_connection( + tulip.StreamProtocol, host, port) + reader = stream.set_parser(my_protocol_parser()) + writer = MyProtocolWriter(transport) + writer.ping() + + message = 'This is the message. It will be echoed.' + + while True: + msg = yield from reader.read() + + print('Message received: {}'.format(msg)) + if msg.tp == MSG_PONG: + writer.send_text(message) + print('data sent:', message) + elif msg.tp == MSG_TEXT: + writer.stop() + print('stop sent') + break + + transport.close() + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + loop.run_forever() + + +ARGS = argparse.ArgumentParser(description="Protocol parser example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/examples/websocket.html b/examples/websocket.html new file mode 100644 index 00000000..6bad7f74 --- /dev/null +++ b/examples/websocket.html @@ -0,0 +1,90 @@ + + + + + + + + +

Chat!

+
+  | Status: + disconnected +
+
+
+
+ + +
+ + diff --git a/examples/wsclient.py b/examples/wsclient.py new file mode 100755 index 00000000..f5b2ef58 --- /dev/null +++ b/examples/wsclient.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +"""websocket cmd client for wssrv.py example.""" +import argparse +import base64 +import hashlib +import os +import signal +import sys + +import tulip +import tulip.http +from tulip.http import websocket +import tulip.selectors + +WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +def start_client(loop, url): + name = input('Please enter your name: ').encode() + + sec_key = base64.b64encode(os.urandom(16)) + + # send request + response = yield from tulip.http.request( + 'get', url, + headers={ + 'UPGRADE': 'WebSocket', + 'CONNECTION': 'Upgrade', + 'SEC-WEBSOCKET-VERSION': '13', + 'SEC-WEBSOCKET-KEY': sec_key.decode(), + }, timeout=1.0) + + # websocket handshake + if response.status != 101: + raise ValueError("Handshake error: Invalid response status") + if response.get('upgrade', '').lower() != 'websocket': + raise ValueError("Handshake error - Invalid upgrade header") + if response.get('connection', '').lower() != 'upgrade': + raise ValueError("Handshake error - Invalid connection header") + + key = response.get('sec-websocket-accept', '').encode() + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) + if key != match: + raise ValueError("Handshake error - Invalid challenge response") + + # switch to websocket protocol + stream = response.stream.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(response.transport) + + # input reader + def stdin_callback(): + line = sys.stdin.buffer.readline() + if not line: + loop.stop() + else: + writer.send(name + b': ' + line) + loop.add_reader(sys.stdin.fileno(), stdin_callback) + + @tulip.coroutine + def dispatch(): + while True: + msg = yield from stream.read() + if msg is None: + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_TEXT: + print(msg.data.strip()) + elif msg.tp == websocket.MSG_CLOSE: + break + + yield from dispatch() + + +ARGS = argparse.ArgumentParser( + description="websocket console client for wssrv.py example.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + url = 'http://{}:{}'.format(args.host, args.port) + + loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) + tulip.set_event_loop(loop) + + loop.add_signal_handler(signal.SIGINT, loop.stop) + tulip.Task(start_client(loop, url)) + loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py new file mode 100755 index 00000000..f96e0855 --- /dev/null +++ b/examples/wssrv.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +"""Multiprocess WebSocket http chat example.""" + +import argparse +import os +import socket +import signal +import time +import tulip +import tulip.http +from tulip.http import websocket + +ARGS = argparse.ArgumentParser(description="Run simple http server.") +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=8080, type=int, help='Port number') +ARGS.add_argument( + '--workers', action="store", dest='workers', + default=2, type=int, help='Number of workers.') + +WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') + + +class HttpServer(tulip.http.ServerHttpProtocol): + + clients = None # list of all active connections + parent = None # process supervisor + # we use it as broadcaster to all workers + + @tulip.coroutine + def handle_request(self, message, payload): + upgrade = False + for hdr, val in message.headers: + if hdr == 'UPGRADE': + upgrade = 'websocket' in val.lower() + break + + if upgrade: + # websocket handshake + status, headers, parser, writer = websocket.do_handshake( + message.method, message.headers, self.transport) + + resp = tulip.http.Response(self.transport, status) + resp.add_headers(*headers) + resp.send_headers() + + # install websocket parser + databuffer = self.stream.set_parser(parser) + + # notify everybody + print('{}: Someone joined.'.format(os.getpid())) + for wsc in self.clients: + wsc.send(b'Someone joined.') + self.clients.append(writer) + self.parent.send(b'Someone joined.') + + # chat dispatcher + while True: + msg = yield from databuffer.read() + if msg is None: # client droped connection + break + + if msg.tp == websocket.MSG_PING: + writer.pong() + + elif msg.tp == websocket.MSG_TEXT: + data = msg.data.strip() + print('{}: {}'.format(os.getpid(), data)) + for wsc in self.clients: + if wsc is not writer: + wsc.send(data.encode()) + self.parent.send(data) + + elif msg.tp == websocket.MSG_CLOSE: + break + + # notify everybody + print('{}: Someone disconnected.'.format(os.getpid())) + self.parent.send(b'Someone disconnected.') + self.clients.remove(writer) + for wsc in self.clients: + wsc.send(b'Someone disconnected.') + + else: + # send html page with js chat + response = tulip.http.Response(self.transport, 200) + response.add_header('Transfer-Encoding', 'chunked') + response.add_header('Content-type', 'text/html') + response.send_headers() + + try: + with open(WS_FILE, 'rb') as fp: + chunk = fp.read(8196) + while chunk: + if not response.write(chunk): + break + chunk = fp.read(8196) + except OSError: + response.write(b'Cannot open') + + response.write_eof() + if response.keep_alive(): + self.keep_alive(True) + + +class ChildProcess: + + def __init__(self, up_read, down_write, args, sock): + self.up_read = up_read + self.down_write = down_write + self.args = args + self.sock = sock + self.clients = [] + + def start(self): + # start server + self.loop = loop = tulip.new_event_loop() + tulip.set_event_loop(loop) + + def stop(): + self.loop.stop() + os._exit(0) + loop.add_signal_handler(signal.SIGINT, stop) + + # heartbeat + self.heartbeat() + + tulip.get_event_loop().run_forever() + os._exit(0) + + @tulip.task + def start_server(self, writer): + socks = yield from self.loop.start_serving( + lambda: HttpServer( + debug=True, keep_alive=75, + parent=writer, clients=self.clients), + sock=self.sock) + print('Starting srv worker process {} on {}'.format( + os.getpid(), socks[0].getsockname())) + + @tulip.task + def heartbeat(self): + # setup pipes + read_transport, read_proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) + + reader = read_proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + self.start_server(writer) + + while True: + msg = yield from reader.read() + if msg is None: + print('Superviser is dead, {} stopping...'.format(os.getpid())) + self.loop.stop() + break + elif msg.tp == websocket.MSG_PING: + writer.pong() + elif msg.tp == websocket.MSG_CLOSE: + break + elif msg.tp == websocket.MSG_TEXT: # broadcast message + for wsc in self.clients: + wsc.send(msg.data.strip().encode()) + + read_transport.close() + write_transport.close() + + +class Worker: + + _started = False + + def __init__(self, sv, loop, args, sock): + self.sv = sv + self.loop = loop + self.args = args + self.sock = sock + self.start() + + def start(self): + assert not self._started + self._started = True + + up_read, up_write = os.pipe() + down_read, down_write = os.pipe() + args, sock = self.args, self.sock + + pid = os.fork() + if pid: + # parent + os.close(up_read) + os.close(down_write) + self.connect(pid, up_write, down_read) + else: + # child + os.close(up_write) + os.close(down_read) + + # cleanup after fork + tulip.set_event_loop(None) + + # setup process + process = ChildProcess(up_read, down_write, args, sock) + process.start() + + @tulip.task + def heartbeat(self, writer): + while True: + yield from tulip.sleep(15) + + if (time.monotonic() - self.ping) < 30: + writer.ping() + else: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + @tulip.task + def chat(self, reader): + while True: + msg = yield from reader.read() + if msg is None: + print('Restart unresponsive worker process: {}'.format( + self.pid)) + self.kill() + self.start() + return + + elif msg.tp == websocket.MSG_PONG: + self.ping = time.monotonic() + + elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers + for worker in self.sv.workers: + if self.pid != worker.pid: + worker.writer.send(msg.data) + + @tulip.task + def connect(self, pid, up_write, down_read): + # setup pipes + read_transport, proto = yield from self.loop.connect_read_pipe( + tulip.StreamProtocol, os.fdopen(down_read, 'rb')) + write_transport, _ = yield from self.loop.connect_write_pipe( + tulip.StreamProtocol, os.fdopen(up_write, 'wb')) + + # websocket protocol + reader = proto.set_parser(websocket.WebSocketParser()) + writer = websocket.WebSocketWriter(write_transport) + + # store info + self.pid = pid + self.ping = time.monotonic() + self.writer = writer + self.rtransport = read_transport + self.wtransport = write_transport + self.chat_task = self.chat(reader) + self.heartbeat_task = self.heartbeat(writer) + + def kill(self): + self._started = False + self.chat_task.cancel() + self.heartbeat_task.cancel() + self.rtransport.close() + self.wtransport.close() + os.kill(self.pid, signal.SIGTERM) + + +class Superviser: + + def __init__(self, args): + self.loop = tulip.get_event_loop() + self.args = args + self.workers = [] + + def start(self): + # bind socket + sock = self.sock = socket.socket() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((self.args.host, self.args.port)) + sock.listen(1024) + sock.setblocking(False) + + # start processes + for idx in range(self.args.workers): + self.workers.append(Worker(self, self.loop, self.args, sock)) + + self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) + self.loop.run_forever() + + +def main(): + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + superviser = Superviser(args) + superviser.start() + + +if __name__ == '__main__': + main() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..3a2c1208 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1009 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT}; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; +} OverlappedObject; + + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + Py_CLEAR(self->read_buffer); + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + case TYPE_ACCEPT: + case TYPE_CONNECT: + case TYPE_DISCONNECT: + Py_RETURN_NONE; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..50068b31 --- /dev/null +++ b/runtests.py @@ -0,0 +1,261 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import re +import sys +import subprocess +import unittest +import textwrap +import importlib.machinery +try: + import coverage +except ImportError: + coverage = None + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def load_tests(testsdir, includes=(), excludes=()): + mods = [mod for mod, _ in load_modules(testsdir)] + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + test_module = getattr(mod, name) + tests = loader.loadTestsFromTestCase(test_module) + if includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in includes)] + if excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in excludes)] + suite.addTests(tests) + + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + if args.coverage: + cov = coverage.coverage(branch=True, + source=['tulip'], + ) + cov.start() + + tests = load_tests(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + try: + if args.forever: + while True: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("open file://{}/htmlcov/index.html for html report".format( + here)) + + +if __name__ == '__main__': + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..e27b3ab9 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,591 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.tulip_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + task = tasks.Task( + self.loop.create_connection(MyProto, 'example.com', 80)) + yield from tasks.wait(task) + exc = task.exception() + self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_mutiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_start_serving_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.start_serving(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_start_serving_host_port_sock(self): + fut = self.loop.start_serving( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_host_port_sock(self): + fut = self.loop.start_serving(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_start_serving_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_start_serving_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..240518c0 --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1574 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils +from tulip import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.EventWaiter(loop=loop), + 2: locks.EventWaiter(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server(self.loop) as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server( + self.loop, use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertTrue( + hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server(self.loop) as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('socket').getsockname()[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server(self.loop) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_start_serving(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.start_serving(factory, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + self.assertEqual(len(socks), 1) + sock = socks[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close start_serving socks + self.loop.stop_serving(sock) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_start_serving_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.start_serving( + factory, '127.0.0.1', 0, ssl=sslcontext) + + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('socket')) + conn = proto.transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + self.assertEqual( + '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + self.loop.stop_serving(sock) + + def test_start_serving_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(TestMyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + def test_start_serving_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.start_serving(MyProto, sock=sock_ob) + sock = self.loop.run_until_complete(f)[0] + host, port = sock.getsockname() + + f = self.loop.start_serving(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + self.loop.stop_serving(sock) + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_start_serving_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.start_serving(TestMyProto, host=None, port=port) + socks = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + for s in socks: + self.loop.stop_serving(s) + + def test_stop_serving(self): + f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + socks = self.loop.run_until_complete(f) + sock = socks[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + self.loop.stop_serving(sock) + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('addr') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('socket')) + conn = transport.get_extra_info('socket') + self.assertTrue(hasattr(conn, 'getsockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop.stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_start_serving_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('tulip.events.tulip_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.start_serving, f) + self.assertRaises( + NotImplementedError, loop.stop_serving, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('tulip.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..7c2abd18 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,317 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.tulip_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +# A fake event loop for tests. All it does is implement a call_soon method +# that immediately invokes the given function. +class _FakeEventLoop: + def call_soon(self, fn, *args): + fn(*args) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=_FakeEventLoop()) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py new file mode 100644 index 00000000..91badfc4 --- /dev/null +++ b/tests/http_client_functional_test.py @@ -0,0 +1,552 @@ +"""Http client functional tests.""" + +import gc +import io +import os.path +import http.cookies +import unittest + +import tulip +import tulip.http +from tulip import test_utils +from tulip.http import client + + +class HttpClientFunctionalTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + + def test_HTTP_200_OK_METHOD(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + for meth in ('get', 'post', 'put', 'delete', 'head'): + r = self.loop.run_until_complete( + client.request(meth, httpd.url('method', meth), + loop=self.loop)) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "%s"' % meth.upper(), content) + self.assertEqual(content1, content2) + r.close() + + def test_use_global_loop(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + try: + tulip.set_event_loop(self.loop) + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'))) + finally: + tulip.set_event_loop(None) + content1 = self.loop.run_until_complete(r.read()) + content2 = self.loop.run_until_complete(r.read()) + content = content1.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "GET"', content) + self.assertEqual(content1, content2) + r.close() + + def test_HTTP_302_REDIRECT_GET(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 2), + loop=self.loop)) + + self.assertEqual(r.status, 200) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_REDIRECT_NON_HTTP(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + self.assertRaises( + ValueError, + self.loop.run_until_complete, + client.request('get', httpd.url('redirect_err'), + loop=self.loop)) + + def test_HTTP_302_REDIRECT_POST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('post', httpd.url('redirect', 2), + data={'some': 'data'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertEqual(r.status, 200) + self.assertIn('"method": "POST"', content) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_302_max_redirects(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('redirect', 5), + max_redirects=2, loop=self.loop)) + + self.assertEqual(r.status, 302) + self.assertEqual(2, httpd['redirects']) + r.close() + + def test_HTTP_200_GET_WITH_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('method', 'get'), + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get') + '?test=true', + params={'q': 'test'}, loop=self.loop)) + content = self.loop.run_until_complete(r.content.read()) + content = content.decode() + + self.assertIn('"query": "test=true&q=test"', content) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, data={'some': 'data'}, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_DATA_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + r = self.loop.run_until_complete( + client.request('post', url, + data={'some': 'data'}, compress=True, + loop=self.loop)) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual('deflate', content['compression']) + self.assertEqual({'some': ['data']}, content['form']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request( + 'post', url, files={'some': f}, chunked=1024, + headers={'Transfer-Encoding': 'chunked'}, + loop=self.loop)) + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_DEFLATE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files={'some': f}, + chunked=1024, compress='deflate', + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual('deflate', content['compression']) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_STR(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f.read())], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + 'some', content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[('some', f)], + loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_LIST_CT(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + files=[('some', f, 'text/plain')])) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + 'some', content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual( + 'text/plain', content['multipart-data'][0]['content-type']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_SINGLE(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, files=[f], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + f.seek(0) + filename = os.path.split(f.name)[-1] + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + filename, content['multipart-data'][0]['name']) + self.assertEqual( + filename, content['multipart-data'][0]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][0]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_IO(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + data = io.BytesIO(b'data') + + r = self.loop.run_until_complete( + client.request('post', url, files=[data], loop=self.loop)) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(1, len(content['multipart-data'])) + self.assertEqual( + {'content-type': 'application/octet-stream', + 'data': 'data', + 'filename': 'unknown', + 'name': 'unknown'}, content['multipart-data'][0]) + self.assertEqual(r.status, 200) + r.close() + + def test_POST_FILES_WITH_DATA(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + url = httpd.url('method', 'post') + + with open(__file__) as f: + r = self.loop.run_until_complete( + client.request('post', url, loop=self.loop, + data={'test': 'true'}, files={'some': f})) + + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(2, len(content['multipart-data'])) + self.assertEqual( + 'test', content['multipart-data'][0]['name']) + self.assertEqual( + 'true', content['multipart-data'][0]['data']) + + f.seek(0) + filename = os.path.split(f.name)[-1] + self.assertEqual( + 'some', content['multipart-data'][1]['name']) + self.assertEqual( + filename, content['multipart-data'][1]['filename']) + self.assertEqual( + f.read(), content['multipart-data'][1]['data']) + self.assertEqual(r.status, 200) + r.close() + + def test_encoding(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'deflate'), + loop=self.loop)) + self.assertEqual(r.status, 200) + + r = self.loop.run_until_complete( + client.request('get', httpd.url('encoding', 'gzip'), + loop=self.loop)) + self.assertEqual(r.status, 200) + r.close() + + def test_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + c = http.cookies.Morsel() + c.set('test3', '456', '456') + + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('method', 'get'), loop=self.loop, + cookies={'test1': '123', 'test2': c})) + self.assertEqual(r.status, 200) + + content = self.loop.run_until_complete(r.content.read()) + self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) + r.close() + + def test_set_cookies(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + resp = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), loop=self.loop)) + self.assertEqual(resp.status, 200) + + self.assertEqual(resp.cookies['c1'].value, 'cookie1') + self.assertEqual(resp.cookies['c2'].value, 'cookie2') + resp.close() + + def test_chunked(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('chunked'), loop=self.loop)) + self.assertEqual(r.status, 200) + self.assertEqual(r['Transfer-Encoding'], 'chunked') + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['path'], '/chunked') + r.close() + + def test_timeout(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['noresponse'] = True + self.assertRaises( + tulip.TimeoutError, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + timeout=0.1, loop=self.loop)) + + def test_request_conn_error(self): + self.assertRaises( + OSError, + self.loop.run_until_complete, + client.request('get', 'http://0.0.0.0:1', + timeout=0.1, loop=self.loop)) + + def test_request_conn_closed(self): + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + httpd['close'] = True + self.assertRaises( + tulip.http.HttpException, + self.loop.run_until_complete, + client.request('get', httpd.url('method', 'get'), + loop=self.loop)) + + def test_keepalive(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive',), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=2') + r.close() + + def test_session_close(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + r = self.loop.run_until_complete( + client.request( + 'get', httpd.url('keepalive') + '?close=1', + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + r = self.loop.run_until_complete( + client.request('get', httpd.url('keepalive'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + self.assertEqual(content['content'], 'requests=1') + r.close() + + def test_session_cookies(self): + from tulip.http import session + s = session.Session() + + with test_utils.run_test_server(self.loop, router=Functional) as httpd: + s.update_cookies({'test': '1'}) + r = self.loop.run_until_complete( + client.request('get', httpd.url('cookies'), + session=s, loop=self.loop)) + self.assertEqual(r.status, 200) + content = self.loop.run_until_complete(r.read(True)) + + self.assertEqual(content['headers']['Cookie'], 'test=1') + r.close() + + cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) + self.assertEqual( + cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) + + +class Functional(test_utils.Router): + + @test_utils.Router.define('/method/([A-Za-z]+)$') + def method(self, match): + meth = match.group(1).upper() + if meth == self._method: + self._response(self._start_response(200)) + else: + self._response(self._start_response(400)) + + @test_utils.Router.define('/redirect_err$') + def redirect_err(self, match): + self._response( + self._start_response(302), + headers={'Location': 'ftp://127.0.0.1/test/'}) + + @test_utils.Router.define('/redirect/([0-9]+)$') + def redirect(self, match): + no = int(match.group(1).upper()) + rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 + + if rno >= no: + self._response( + self._start_response(302), + headers={'Location': '/method/%s' % self._method.lower()}) + else: + self._response( + self._start_response(302), + headers={'Location': self._path}) + + @test_utils.Router.define('/encoding/(gzip|deflate)$') + def encoding(self, match): + mode = match.group(1) + + resp = self._start_response(200) + resp.add_compression_filter(mode) + resp.add_chunking_filter(100) + self._response(resp, headers={'Content-encoding': mode}, chunked=True) + + @test_utils.Router.define('/chunked$') + def chunked(self, match): + resp = self._start_response(200) + resp.add_chunking_filter(100) + self._response(resp, chunked=True) + + @test_utils.Router.define('/keepalive$') + def keepalive(self, match): + self._transport._requests = getattr( + self._transport, '_requests', 0) + 1 + resp = self._start_response(200) + if 'close=' in self._query: + self._response( + resp, 'requests={}'.format(self._transport._requests)) + else: + self._response( + resp, 'requests={}'.format(self._transport._requests), + headers={'CONNECTION': 'keep-alive'}) + + @test_utils.Router.define('/cookies$') + def cookies(self, match): + cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + resp = self._start_response(200) + for cookie in cookies.output(header='').split('\n'): + resp.add_header('Set-Cookie', cookie.strip()) + + self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py new file mode 100644 index 00000000..1aa27244 --- /dev/null +++ b/tests/http_client_test.py @@ -0,0 +1,289 @@ +# -*- coding: utf-8 -*- +"""Tests for tulip/http/client.py""" + +import unittest +import unittest.mock +import urllib.parse + +import tulip +import tulip.http + +from tulip.http.client import HttpRequest, HttpResponse + + +class HttpResponseTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + self.loop.close() + + def test_close(self): + self.response.transport = self.transport + self.response.close() + self.assertIsNone(self.response.transport) + self.assertTrue(self.transport.close.called) + self.response.close() + self.response.close() + + def test_repr(self): + self.response.status = 200 + self.response.reason = 'Ok' + self.assertIn( + '', repr(self.response)) + + +class HttpRequestTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer(loop=self.loop) + + def tearDown(self): + self.loop.close() + + def test_method(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.method, 'GET') + + req = HttpRequest('head', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + req = HttpRequest('HEAD', 'http://python.org/') + self.assertEqual(req.method, 'HEAD') + + def test_version(self): + req = HttpRequest('get', 'http://python.org/', version='1.0') + self.assertEqual(req.version, (1, 0)) + + def test_version_err(self): + self.assertRaises( + ValueError, + HttpRequest, 'get', 'http://python.org/', version='1.c') + + def test_host_port(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 80) + self.assertFalse(req.ssl) + + req = HttpRequest('get', 'https://python.org/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 443) + self.assertTrue(req.ssl) + + req = HttpRequest('get', 'https://python.org:960/') + self.assertEqual(req.host, 'python.org') + self.assertEqual(req.port, 960) + self.assertTrue(req.ssl) + + def test_host_port_err(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://python.org:123e/') + + def test_host_header(self): + req = HttpRequest('get', 'http://python.org/') + self.assertEqual(req.headers['host'], 'python.org') + + req = HttpRequest('get', 'http://python.org/', + headers={'host': 'example.com'}) + self.assertEqual(req.headers['host'], 'example.com') + + def test_headers(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Type': 'text/plain'}) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') + + def test_headers_list(self): + req = HttpRequest('get', 'http://python.org/', + headers=[('Content-Type', 'text/plain')]) + self.assertIn('Content-Type', req.headers) + self.assertEqual(req.headers['Content-Type'], 'text/plain') + + def test_headers_default(self): + req = HttpRequest('get', 'http://python.org/', + headers={'Accept-Encoding': 'deflate'}) + self.assertEqual(req.headers['Accept-Encoding'], 'deflate') + + def test_invalid_url(self): + self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') + + def test_invalid_idna(self): + self.assertRaises( + ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') + + def test_no_path(self): + req = HttpRequest('get', 'http://python.org') + self.assertEqual('/', req.path) + + def test_basic_auth(self): + req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_from_url(self): + req = HttpRequest('get', 'http://nkim:1234@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + req = HttpRequest('get', 'http://nkim@python.org') + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) + + req = HttpRequest( + 'get', 'http://nkim@python.org', auth=('nkim', '1234')) + self.assertIn('Authorization', req.headers) + self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) + + def test_basic_auth_err(self): + self.assertRaises( + ValueError, HttpRequest, + 'get', 'http://python.org', auth=(1, 2, 3)) + + def test_no_content_length(self): + req = HttpRequest('get', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + req = HttpRequest('head', 'http://python.org') + req.send(self.transport) + self.assertEqual('0', req.headers.get('Content-Length')) + + def test_path_is_not_double_encoded(self): + req = HttpRequest('get', "http://0.0.0.0/get/test case") + self.assertEqual(req.path, "/get/test%20case") + + req = HttpRequest('get', "http://0.0.0.0/get/test%20case") + self.assertEqual(req.path, "/get/test%20case") + + def test_params_are_added_before_fragment(self): + req = HttpRequest( + 'GET', "http://example.com/path#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?a=b#fragment") + + req = HttpRequest( + 'GET', + "http://example.com/path?key=value#fragment", params={"a": "b"}) + self.assertEqual( + req.path, "/path?key=value&a=b#fragment") + + def test_cookies(self): + req = HttpRequest( + 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) + self.assertIn('Cookie', req.headers) + self.assertEqual('cookie1=val1', req.headers['cookie']) + + req = HttpRequest( + 'get', 'http://test.com/path', + headers={'cookie': 'cookie1=val1'}, + cookies={'cookie2': 'val2'}) + self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) + + def test_unicode_get(self): + def join(*suffix): + return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) + + url = 'http://python.org' + req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) + self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) + self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) + req = HttpRequest('', url, params={'foo': 'foo'}) + self.assertEqual('/?foo=foo', req.path) + req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) + self.assertEqual('/%C3%B8?foo=foo', req.path) + + def test_query_multivalued_param(self): + for meth in HttpRequest.ALL_METHODS: + req = HttpRequest( + meth, 'http://python.org', + params=(('test', 'foo'), ('test', 'baz'))) + self.assertEqual(req.path, '/?test=foo&test=baz') + + def test_post_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual(b'life=42', req.body[0]) + self.assertEqual('application/x-www-form-urlencoded', + req.headers['content-type']) + + def test_get_with_data(self): + for meth in HttpRequest.GET_METHODS: + req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) + self.assertEqual('/?life=42', req.path) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', compress='deflate') + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + + @unittest.mock.patch('tulip.http.client.tulip') + def test_content_encoding_header(self, m_tulip): + req = HttpRequest('get', 'http://python.org/', + headers={'Content-Encoding': 'deflate'}) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-encoding'], 'chunked') + self.assertEqual(req.headers['Content-encoding'], 'deflate') + + m_tulip.http.Request.return_value\ + .add_compression_filter.assert_called_with('deflate') + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + def test_chunked(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'gzip'}) + req.send(self.transport) + self.assertEqual('gzip', req.headers['Transfer-encoding']) + + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Transfer-encoding': 'chunked'}) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=True) + req.send(self.transport) + + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(8196) + + @unittest.mock.patch('tulip.http.client.tulip') + def test_chunked_explicit_size(self, m_tulip): + req = HttpRequest( + 'get', 'http://python.org/', chunked=1024) + req.send(self.transport) + self.assertEqual('chunked', req.headers['Transfer-encoding']) + m_tulip.http.Request.return_value\ + .add_chunking_filter.assert_called_with(1024) + + def test_chunked_length(self): + req = HttpRequest( + 'get', 'http://python.org/', + headers={'Content-Length': '1000'}, chunked=1024) + req.send(self.transport) + self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') + self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py new file mode 100644 index 00000000..6240ad49 --- /dev/null +++ b/tests/http_parser_test.py @@ -0,0 +1,539 @@ +"""Tests for http/parser.py""" + +from collections import deque +import zlib +import unittest +import unittest.mock + +import tulip +from tulip.http import errors +from tulip.http import protocol + + +class ParseHeadersTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_headers(self): + hdrs = ('', 'test: line\r\n', ' continue\r\n', + 'test2: data\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_parse_headers_multi(self): + hdrs = ('', + 'Set-Cookie: c1=cookie1\r\n', + 'Set-Cookie: c2=cookie2\r\n', '\r\n') + + headers, close, compression = protocol.parse_headers( + hdrs, 8190, 32768, 8190) + + self.assertEqual(list(headers), + [('SET-COOKIE', 'c1=cookie1'), + ('SET-COOKIE', 'c2=cookie2')]) + self.assertIsNone(close) + self.assertIsNone(compression) + + def test_conn_close(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) + self.assertTrue(close) + + def test_conn_keep_alive(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) + self.assertFalse(close) + + def test_conn_other(self): + headers, close, compression = protocol.parse_headers( + ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(close) + + def test_compression_gzip(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('gzip', compression) + + def test_compression_deflate(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) + self.assertEqual('deflate', compression) + + def test_compression_unknown(self): + headers, close, compression = protocol.parse_headers( + ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIsNone(compression) + + def test_max_field_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], + 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_max_continuation_headers_size(self): + with self.assertRaises(errors.LineTooLong) as cm: + protocol.parse_headers( + ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) + self.assertIn("limit request headers fields size", str(cm.exception)) + + def test_invalid_header(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header: test line", str(cm.exception)) + + def test_invalid_name(self): + with self.assertRaises(ValueError) as cm: + protocol.parse_headers( + ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) + self.assertIn("Invalid header name: TEST[]", str(cm.exception)) + + +class DeflateBufferTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_feed_data(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.return_value = b'line' + + dbuf.feed_data(b'data') + self.assertEqual([b'line'], list(buf._buffer)) + + def test_feed_data_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + exc = ValueError() + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.decompress.side_effect = exc + + self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') + + def test_feed_eof(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + + dbuf.feed_eof() + self.assertEqual([b'line'], list(buf._buffer)) + self.assertTrue(buf._eof) + + def test_feed_eof_err(self): + buf = tulip.DataBuffer() + dbuf = protocol.DeflateBuffer(buf, 'deflate') + + dbuf.zlib = unittest.mock.Mock() + dbuf.zlib.flush.return_value = b'line' + dbuf.zlib.eof = False + + self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) + + +class ParsePayloadTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_parse_eof_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_eof_payload(out, buf) + next(p) + p.send(b'data') + try: + p.throw(tulip.EofStream()) + except tulip.EofStream: + pass + + self.assertEqual([b'data'], list(out._buffer)) + + def test_parse_length_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + p.send(b't') + try: + p.send(b'aline') + except StopIteration: + pass + + self.assertEqual(3, len(out._buffer)) + self.assertEqual(b'data', b''.join(out._buffer)) + self.assertEqual(b'line', bytes(buf)) + + def test_parse_length_payload_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_length_payload(out, buf, 4) + next(p) + p.send(b'da') + self.assertRaises( + errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + self.assertEqual(b'', bytes(buf)) + + def test_parse_chunked_payload_chunks(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r') + p.send(b'\n4') + p.send(b'\r') + p.send(b'\n') + p.send(b'line\r\n0\r\n') + self.assertRaises(StopIteration, p.send, b'test\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_incomplete(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + p.send(b'4\r\ndata\r\n') + self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) + + def test_parse_chunked_payload_extension(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + try: + p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + except StopIteration: + pass + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_parse_chunked_payload_size_error(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = protocol.parse_chunked_payload(out, buf) + next(p) + self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') + + def test_http_payload_parser_length_broken(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length_wrong(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) + + def test_http_payload_parser_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'1245') + except StopIteration: + pass + + self.assertEqual(b'12', b''.join(out._buffer)) + self.assertEqual(b'45', bytes(buf)) + + def test_http_payload_parser_no_length(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + self.assertTrue(out._eof) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_http_payload_parser_deflate(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(b'data', b''.join(out._buffer)) + + def test_http_payload_parser_deflate_disabled(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], + None, 'deflate') + p = protocol.http_payload_parser(msg, compression=False) + next(p) + + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, self._COMPRESSED) + self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) + + def test_http_payload_parser_websocket(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, b'1234567890') + self.assertEqual(b'12345678', b''.join(out._buffer)) + + def test_http_payload_parser_chunked(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(StopIteration, p.send, + b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_eof(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [], None, None) + p = protocol.http_payload_parser(msg, readall=True) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'data') + p.send(b'line') + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) + self.assertEqual(b'dataline', b''.join(out._buffer)) + + def test_http_payload_parser_length_zero(self): + msg = protocol.RawRequestMessage( + 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) + p = protocol.http_payload_parser(msg) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + self.assertRaises(StopIteration, p.send, (out, buf)) + self.assertEqual(b'', b''.join(out._buffer)) + + +class ParseRequestTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_request_parser_max_headers(self): + p = protocol.http_request_parser(8190, 20, 8190) + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + + self.assertRaises( + errors.LineTooLong, + p.send, + b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') + + def test_http_request_parser(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get /path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + result = out._buffer[0] + self.assertEqual( + ('GET', '/path', (1, 1), deque(), False, None), result) + + def test_http_request_parser_eof(self): + # http_request_parser does not fail on EofStream() + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + p.send(b'get /path HTTP/1.1\r\n') + try: + p.throw(tulip.EofStream()) + except StopIteration: + pass + self.assertFalse(out._buffer) + + def test_http_request_parser_two_slashes(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'get //path HTTP/1.1\r\n\r\n') + except StopIteration: + pass + self.assertEqual( + ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) + + def test_http_request_parser_bad_status_line(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_request_parser_bad_method(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') + + def test_http_request_parser_bad_version(self): + p = protocol.http_request_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, + p.send, b'GET //get HT/11\r\n\r\n') + + +class ParseResponseTests(unittest.TestCase): + + def setUp(self): + tulip.set_event_loop(None) + + def test_http_response_parser_bad_status_line(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') + + def test_http_response_parser_bad_status_line_eof(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + self.assertRaises( + errors.BadStatusLine, p.throw, tulip.EofStream()) + + def test_http_response_parser_bad_version(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HT/11 200 Ok\r\n\r\n') + self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) + + def test_http_response_parser_no_reason(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + try: + p.send(b'HTTP/1.1 200\r\n\r\n') + except StopIteration: + pass + v, s, r = out._buffer[0][:3] + self.assertEqual(v, (1, 1)) + self.assertEqual(s, 200) + self.assertEqual(r, '') + + def test_http_response_parser_bad(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTT/1\r\n\r\n') + self.assertIn('HTT/1', str(cm.exception)) + + def test_http_response_parser_code_under_100(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 99 test\r\n\r\n') + self.assertIn('HTTP/1.1 99 test', str(cm.exception)) + + def test_http_response_parser_code_above_999(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 9999 test\r\n\r\n') + self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) + + def test_http_response_parser_code_not_int(self): + p = protocol.http_response_parser() + next(p) + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p.send((out, buf)) + with self.assertRaises(errors.BadStatusLine) as cm: + p.send(b'HTTP/1.1 ttt test\r\n\r\n') + self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py new file mode 100644 index 00000000..ec3aaf58 --- /dev/null +++ b/tests/http_protocol_test.py @@ -0,0 +1,400 @@ +"""Tests for http/protocol.py""" + +import unittest +import unittest.mock +import zlib + +import tulip +from tulip.http import protocol + + +class HttpMessageTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + tulip.set_event_loop(None) + + def test_start_request(self): + msg = protocol.Request( + self.transport, 'GET', '/index.html', close=True) + + self.assertIs(msg.transport, self.transport) + self.assertIsNone(msg.status) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') + + def test_start_response(self): + msg = protocol.Response(self.transport, 200, close=True) + + self.assertIs(msg.transport, self.transport) + self.assertEqual(msg.status, 200) + self.assertTrue(msg.closing) + self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') + + def test_force_close(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.closing) + msg.force_close() + self.assertTrue(msg.closing) + + def test_force_chunked(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.chunked) + msg.force_chunked() + self.assertTrue(msg.chunked) + + def test_keep_alive(self): + msg = protocol.Response(self.transport, 200, close=True) + self.assertFalse(msg.keep_alive()) + msg.keepalive = True + self.assertTrue(msg.keep_alive()) + + msg.force_close() + self.assertFalse(msg.keep_alive()) + + def test_keep_alive_http10(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + self.assertFalse(msg.keepalive) + self.assertFalse(msg.keep_alive()) + + msg = protocol.Response(self.transport, 200, http_version=(1, 1)) + self.assertIsNone(msg.keepalive) + + def test_add_header(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_header('content-type', 'plain/html') + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers(self): + msg = protocol.Response(self.transport, 200) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('content-type', 'plain/html')) + self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) + + def test_add_headers_length(self): + msg = protocol.Response(self.transport, 200) + self.assertIsNone(msg.length) + + msg.add_headers(('content-length', '200')) + self.assertEqual(200, msg.length) + + def test_add_headers_upgrade(self): + msg = protocol.Response(self.transport, 200) + self.assertFalse(msg.upgrade) + + msg.add_headers(('connection', 'upgrade')) + self.assertTrue(msg.upgrade) + + def test_add_headers_upgrade_websocket(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('upgrade', 'test')) + self.assertEqual([], list(msg.headers)) + + msg.add_headers(('upgrade', 'websocket')) + self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) + + def test_add_headers_connection_keepalive(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'keep-alive')) + self.assertEqual([], list(msg.headers)) + self.assertTrue(msg.keepalive) + + msg.add_headers(('connection', 'close')) + self.assertFalse(msg.keepalive) + + def test_add_headers_hop_headers(self): + msg = protocol.Response(self.transport, 200) + + msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) + self.assertEqual([], list(msg.headers)) + + def test_default_headers(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('DATE', headers) + self.assertIn('CONNECTION', headers) + + def test_default_headers_server(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('SERVER', headers) + + def test_default_headers_useragent(self): + msg = protocol.Request(self.transport, 'GET', '/') + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('SERVER', headers) + self.assertIn('USER-AGENT', headers) + + def test_default_headers_chunked(self): + msg = protocol.Response(self.transport, 200) + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertNotIn('TRANSFER-ENCODING', headers) + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg._add_default_headers() + + headers = [r for r, _ in msg.headers] + self.assertIn('TRANSFER-ENCODING', headers) + + def test_default_headers_connection_upgrade(self): + msg = protocol.Response(self.transport, 200) + msg.upgrade = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'upgrade')], headers) + + def test_default_headers_connection_close(self): + msg = protocol.Response(self.transport, 200) + msg.force_close() + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'close')], headers) + + def test_default_headers_connection_keep_alive(self): + msg = protocol.Response(self.transport, 200) + msg.keepalive = True + msg._add_default_headers() + + headers = [r for r in msg.headers if r[0] == 'CONNECTION'] + self.assertEqual([('CONNECTION', 'keep-alive')], headers) + + def test_send_headers(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + self.assertFalse(msg.is_headers_sent()) + + msg.send_headers() + + content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) + + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) + self.assertIn(b'CONTENT-TYPE: plain/html', content) + self.assertTrue(msg.headers_sent) + self.assertTrue(msg.is_headers_sent()) + # cleanup + msg.writer.close() + + def test_send_headers_nomore_add(self): + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-type', 'plain/html')) + msg.send_headers() + + self.assertRaises(AssertionError, + msg.add_header, 'content-type', 'plain/html') + # cleanup + msg.writer.close() + + def test_prepare_length(self): + msg = protocol.Response(self.transport, 200) + length = msg._write_length_payload = unittest.mock.Mock() + length.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + + self.assertTrue(length.called) + self.assertTrue((200,), length.call_args[0]) + + def test_prepare_chunked_force(self): + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.add_headers(('content-length', '200')) + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_chunked_no_length(self): + msg = protocol.Response(self.transport, 200) + + chunked = msg._write_chunked_payload = unittest.mock.Mock() + chunked.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(chunked.called) + + def test_prepare_eof(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + + eof = msg._write_eof_payload = unittest.mock.Mock() + eof.return_value = iter([1, 2, 3]) + + msg.send_headers() + self.assertTrue(eof.called) + + def test_write_auto_send_headers(self): + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg._send_headers = True + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + # cleanup + msg.writer.close() + + def test_write_payload_eof(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200, http_version=(1, 0)) + msg.send_headers() + + msg.write(b'data1') + self.assertTrue(msg.headers_sent) + + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'4\r\ndata\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_multiple(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.force_chunked() + msg.send_headers() + + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_length(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '2')) + msg.send_headers() + + msg.write(b'd') + msg.write(b'ata') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'da', content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_filter(self): + write = self.transport.write = unittest.mock.Mock() + + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) + + def test_write_payload_chunked_filter_mutiple_chunks(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(2) + msg.write(b'data1') + msg.write(b'data2') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith( + b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' + b'2\r\na2\r\n0\r\n\r\n')) + + def test_write_payload_chunked_large_chunk(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_chunking_filter(1024) + msg.write(b'data') + msg.write_eof() + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) + + _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) + _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) + + def test_write_payload_deflate_filter(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_deflate_and_chunked(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.send_headers() + + msg.add_compression_filter('deflate') + msg.add_chunking_filter(2) + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', + content.split(b'\r\n\r\n', 1)[-1]) + + def test_write_payload_chunked_and_deflate(self): + write = self.transport.write = unittest.mock.Mock() + msg = protocol.Response(self.transport, 200) + msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) + + msg.add_chunking_filter(2) + msg.add_compression_filter('deflate') + msg.send_headers() + + msg.write(b'data') + msg.write_eof() + + content = b''.join([c[1][0] for c in list(write.mock_calls)]) + self.assertEqual( + self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py new file mode 100644 index 00000000..a9d4d5ed --- /dev/null +++ b/tests/http_server_test.py @@ -0,0 +1,301 @@ +"""Tests for http/server.py""" + +import unittest +import unittest.mock + +import tulip +from tulip.http import server +from tulip.http import errors +from tulip import test_utils + + +class HttpServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_http_error_exception(self): + exc = errors.HttpErrorException(500, message='Internal error') + self.assertEqual(exc.code, 500) + self.assertEqual(exc.message, 'Internal error') + + def test_handle_request(self): + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + rline = unittest.mock.Mock() + rline.version = (1, 1) + message = unittest.mock.Mock() + srv.handle_request(rline, message) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) + + def test_connection_made(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertIsNone(srv._request_handler) + + srv.connection_made(unittest.mock.Mock()) + self.assertIsNotNone(srv._request_handler) + + def test_data_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + + srv.data_received(b'123') + self.assertEqual(b'123', bytes(srv.stream._buffer)) + + srv.data_received(b'456') + self.assertEqual(b'123456', bytes(srv.stream._buffer)) + + def test_eof_received(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.eof_received() + self.assertTrue(srv.stream._eof) + + def test_connection_lost(self): + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(unittest.mock.Mock()) + srv.data_received(b'123') + + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + + handle = srv._request_handler + srv.connection_lost(None) + test_utils.run_briefly(self.loop) + + self.assertIsNone(srv._request_handler) + self.assertTrue(handle.cancelled()) + + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(keep_alive_handle.cancel.called) + + srv.connection_lost(None) + self.assertIsNone(srv._request_handler) + self.assertIsNone(srv._keep_alive_handle) + + def test_srv_keep_alive(self): + srv = server.ServerHttpProtocol(loop=self.loop) + self.assertFalse(srv._keep_alive) + + srv.keep_alive(True) + self.assertTrue(srv._keep_alive) + + srv.keep_alive(False) + self.assertFalse(srv._keep_alive) + + def test_handle_error(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.keep_alive(True) + + srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertIn(b'HTTP/1.1 404 Not Found', content) + self.assertIn(b'X-SERVER: Tulip', content) + self.assertFalse(srv._keep_alive) + + @unittest.mock.patch('tulip.http.server.traceback') + def test_handle_error_traceback_exc(self, m_trace): + transport = unittest.mock.Mock() + log = unittest.mock.Mock() + srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) + srv.connection_made(transport) + + m_trace.format_exc.side_effect = ValueError + + srv.handle_error(500, exc=object()) + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + self.assertTrue( + content.startswith(b'HTTP/1.1 500 Internal Server Error')) + self.assertTrue(log.exception.called) + + def test_handle_error_debug(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.debug = True + srv.connection_made(transport) + + try: + raise ValueError() + except Exception as exc: + srv.handle_error(999, exc=exc) + + content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) + + self.assertIn(b'HTTP/1.1 500 Internal', content) + self.assertIn(b'Traceback (most recent call last):', content) + + def test_handle_error_500(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, loop=self.loop) + srv.connection_made(transport) + + srv.handle_error(500) + self.assertTrue(log.exception.called) + + def test_handle(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(handle.called) + self.assertTrue(transport.close.called) + + def test_handle_coro(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + + called = False + + @tulip.coroutine + def coro(message, payload): + nonlocal called + called = True + srv.eof_received() + + srv.handle_request = coro + srv.connection_made(transport) + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(called) + + def test_handle_cancel(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + + @tulip.coroutine + def cancel(): + srv._request_handler.cancel() + + self.loop.run_until_complete( + tulip.wait([srv._request_handler, cancel()], loop=self.loop)) + self.assertTrue(log.debug.called) + + def test_handle_cancelled(self): + log = unittest.mock.Mock() + transport = unittest.mock.Mock() + + srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) + srv.connection_made(transport) + + srv.handle_request = unittest.mock.Mock() + test_utils.run_briefly(self.loop) # start request_handler task + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + + r_handler = srv._request_handler + srv._request_handler = None # emulate srv.connection_lost() + + self.assertIsNone(self.loop.run_until_complete(r_handler)) + + def test_handle_400(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + srv.handle_error = unittest.mock.Mock() + srv.keep_alive(True) + srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(srv.handle_error.called) + self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertTrue(transport.close.called) + + def test_handle_500(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + handle.side_effect = ValueError + srv.handle_error = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'Host: example.com\r\n\r\n') + self.loop.run_until_complete(srv._request_handler) + + self.assertTrue(srv.handle_error.called) + self.assertTrue(500, srv.handle_error.call_args[0][0]) + + def test_handle_error_no_handle_task(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(loop=self.loop) + srv.keep_alive(True) + srv.connection_made(transport) + srv.connection_lost(None) + + srv.handle_error(300) + self.assertFalse(srv._keep_alive) + + def test_keep_alive(self): + srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) + transport = unittest.mock.Mock() + closed = False + + def close(): + nonlocal closed + closed = True + srv.connection_lost(None) + self.loop.stop() + + transport.close = close + + srv.connection_made(transport) + + handle = srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.1\r\n' + b'CONNECTION: keep-alive\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_forever() + self.assertTrue(handle.called) + self.assertTrue(closed) + + def test_keep_alive_close_existing(self): + transport = unittest.mock.Mock() + srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) + srv.connection_made(transport) + + self.assertIsNone(srv._keep_alive_handle) + keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() + srv.handle_request = unittest.mock.Mock() + + srv.stream.feed_data( + b'GET / HTTP/1.0\r\n' + b'HOST: example.com\r\n\r\n') + + self.loop.run_until_complete(srv._request_handler) + self.assertTrue(keep_alive_handle.cancel.called) + self.assertIsNone(srv._keep_alive_handle) + self.assertTrue(transport.close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py new file mode 100644 index 00000000..39a80091 --- /dev/null +++ b/tests/http_session_test.py @@ -0,0 +1,139 @@ +"""Tests for tulip/http/session.py""" + +import http.cookies +import unittest +import unittest.mock + +import tulip +import tulip.http + +from tulip.http.client import HttpResponse +from tulip.http.session import Session + +from tulip import test_utils + + +class HttpSessionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + tulip.set_event_loop(self.loop) + + self.transport = unittest.mock.Mock() + self.stream = tulip.StreamBuffer() + self.response = HttpResponse('get', 'http://python.org') + + def tearDown(self): + tulip.set_event_loop(None) + self.loop.close() + + def test_del(self): + session = Session() + close = session.close = unittest.mock.Mock() + + del session + self.assertTrue(close.called) + + def test_close(self): + tr = unittest.mock.Mock() + + session = Session() + session._conns[1] = [(tr, object())] + session.close() + + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_get(self): + session = Session() + self.assertEqual(session._get(1), (None, None)) + + tr, proto = unittest.mock.Mock(), object() + session._conns[1] = [(tr, proto)] + self.assertEqual(session._get(1), (tr, proto)) + + def test_release(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = False + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertEqual(session._conns[1][0], (tr, proto)) + self.assertEqual(session.cookies, dict(cookies.items())) + + def test_release_close(self): + session = Session() + resp = unittest.mock.Mock() + resp.message.should_close = True + + cookies = resp.cookies = http.cookies.SimpleCookie() + cookies['c1'] = 'cookie1' + cookies['c2'] = 'cookie2' + + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + session._release(resp, 1, (tr, proto)) + self.assertFalse(session._conns) + self.assertTrue(tr.close.called) + + def test_call_new_conn_exc(self): + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, *args): + raise ValueError() + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + self.assertRaises( + ValueError, + self.loop.run_until_complete, session.start(Req(), Loop(), True)) + + self.assertTrue(tr.close.called) + + def test_call_existing_conn_exc(self): + existing = unittest.mock.Mock() + tr, proto = unittest.mock.Mock(), unittest.mock.Mock() + + class Req: + host = 'host' + port = 80 + ssl = False + + def send(self, transport): + if transport is existing: + transport.close() + raise ValueError() + else: + return Resp() + + class Resp: + @tulip.coroutine + def start(self, *args, **kw): + pass + + class Loop: + @tulip.coroutine + def create_connection(self, *args, **kw): + return tr, proto + + session = Session() + key = ('host', 80, False) + session._conns[key] = [(existing, object())] + + resp = self.loop.run_until_complete(session.start(Req(), Loop())) + self.assertIsInstance(resp, Resp) + self.assertTrue(existing.close.called) + self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py new file mode 100644 index 00000000..319538ae --- /dev/null +++ b/tests/http_websocket_test.py @@ -0,0 +1,439 @@ +"""Tests for http/websocket.py""" + +import base64 +import hashlib +import os +import struct +import unittest +import unittest.mock + +import tulip +from tulip.http import websocket, protocol, errors + + +class WebsocketParserTests(unittest.TestCase): + + def test_parse_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b00000001)) + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) + + def test_parse_frame_length0(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + try: + p.send(struct.pack('!BB', 0b00000001, 0b00000000)) + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b''), (fin, opcode, payload)) + + def test_parse_frame_length2(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 126)) + p.send(struct.pack('!H', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_length4(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 127)) + p.send(struct.pack('!Q', 4)) + try: + p.send(b'1234') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) + + def test_parse_frame_mask(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + p.send(struct.pack('!BB', 0b00000001, 0b10000001)) + p.send(b'0001') + try: + p.send(b'1') + except StopIteration as exc: + fin, opcode, payload = exc.value + + self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) + + def test_parse_frame_header_reversed_bits(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b01100000, 0b00000000)) + + def test_parse_frame_header_control_frame(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00001000, 0b00000000)) + + def test_parse_frame_header_continuation(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b00000000, 0b00000000)) + + def test_parse_frame_header_new_data_err(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b000000000, 0b00000000)) + + def test_parse_frame_header_payload_size(self): + buf = tulip.ParserBuffer() + p = websocket.parse_frame(buf) + next(p) + self.assertRaises( + websocket.WebSocketError, + p.send, struct.pack('!BB', 0b10001000, 0b01111110)) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_ping_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PING, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PING, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_pong_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_PONG, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_info(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'0112345') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_close_frame_invalid(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CLOSE, b'1') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_unknown_frame(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_CONTINUATION, b'') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_text(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_TEXT, b'text') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_simple_binary(self, m_parse_frame): + def parse_frame(buf): + yield + return (1, websocket.OPCODE_BINARY, b'binary') + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_CONTINUATION, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + try: + p.send(b'') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) + + @unittest.mock.patch('tulip.http.websocket.parse_frame') + def test_continuation_err(self, m_parse_frame): + cur = 0 + + def parse_frame(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return (0, websocket.OPCODE_TEXT, b'line1') + else: + return (1, websocket.OPCODE_TEXT, b'line2') + + m_parse_frame.side_effect = parse_frame + buf = tulip.ParserBuffer() + p = websocket.parse_message(buf) + next(p) + p.send(b'') + self.assertRaises(websocket.WebSocketError, p.send, b'') + + @unittest.mock.patch('tulip.http.websocket.parse_message') + def test_parser(self, m_parse_message): + cur = 0 + + def parse_message(buf): + nonlocal cur + yield + if cur == 0: + cur = 1 + return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') + else: + return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') + + m_parse_message.side_effect = parse_message + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + p.send(b'') + self.assertRaises(StopIteration, p.send, b'') + + self.assertEqual( + (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) + self.assertEqual( + (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) + self.assertTrue(out._eof) + + def test_parser_eof(self): + out = tulip.DataBuffer() + buf = tulip.ParserBuffer() + p = websocket.WebSocketParser() + next(p) + p.send((out, buf)) + self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) + self.assertEqual([], list(out._buffer)) + + +class WebsocketWriterTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.writer = websocket.WebSocketWriter(self.transport) + + def test_pong(self): + self.writer.pong() + self.transport.write.assert_called_with(b'\x8a\x00') + + def test_ping(self): + self.writer.ping() + self.transport.write.assert_called_with(b'\x89\x00') + + def test_send_text(self): + self.writer.send(b'text') + self.transport.write.assert_called_with(b'\x81\x04text') + + def test_send_binary(self): + self.writer.send('binary', True) + self.transport.write.assert_called_with(b'\x82\x06binary') + + def test_send_binary_long(self): + self.writer.send(b'b'*127, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) + + def test_send_binary_very_long(self): + self.writer.send(b'b'*65537, True) + self.assertTrue( + self.transport.write.call_args[0][0].startswith( + b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) + + def test_close(self): + self.writer.close(1001, 'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + self.writer.close(1001, b'msg') + self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') + + +class WebSocketHandshakeTests(unittest.TestCase): + + def setUp(self): + self.transport = unittest.mock.Mock() + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, None) + + def test_not_get(self): + self.assertRaises( + errors.HttpErrorException, + websocket.do_handshake, + 'POST', self.message.headers, self.transport) + + def test_no_upgrade(self): + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_no_connection(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'keep-alive')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_version(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '1')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_protocol_key(self): + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', '123')]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + sec_key = base64.b64encode(os.urandom(2)) + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key.decode())]) + self.assertRaises( + errors.BadRequestException, + websocket.do_handshake, + self.message.method, self.message.headers, self.transport) + + def test_handshake(self): + sec_key = base64.b64encode(os.urandom(16)).decode() + + self.headers.extend([('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('SEC-WEBSOCKET-VERSION', '13'), + ('SEC-WEBSOCKET-KEY', sec_key)]) + status, headers, parser, writer = websocket.do_handshake( + self.message.method, self.message.headers, self.transport) + self.assertEqual(status, 101) + + key = base64.b64encode( + hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) + headers = dict(headers) + self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py new file mode 100644 index 00000000..053f5a69 --- /dev/null +++ b/tests/http_wsgi_test.py @@ -0,0 +1,301 @@ +"""Tests for http/wsgi.py""" + +import io +import unittest +import unittest.mock + +import tulip +from tulip.http import wsgi +from tulip.http import protocol + + +class HttpWsgiServerProtocolTests(unittest.TestCase): + + def setUp(self): + self.loop = tulip.new_event_loop() + tulip.set_event_loop(None) + + self.wsgi = unittest.mock.Mock() + self.stream = unittest.mock.Mock() + self.transport = unittest.mock.Mock() + self.transport.get_extra_info.return_value = '127.0.0.1' + + self.headers = [] + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 0), self.headers, True, 'deflate') + self.payload = tulip.DataBuffer() + self.payload.feed_data(b'data') + self.payload.feed_data(b'data') + self.payload.feed_eof() + + def tearDown(self): + self.loop.close() + + def test_ctor(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + self.assertIs(srv.wsgi, self.wsgi) + self.assertFalse(srv.readpayload) + + def _make_one(self, **kw): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) + srv.stream = self.stream + srv.transport = self.transport + return srv.create_wsgi_environ(self.message, self.payload) + + def test_environ(self): + environ = self._make_one() + self.assertEqual(environ['RAW_URI'], '/path') + self.assertEqual(environ['wsgi.async'], True) + + def test_environ_except_header(self): + self.headers.append(('EXPECT', '101-continue')) + self._make_one() + self.assertFalse(self.transport.write.called) + + self.headers[0] = ('EXPECT', '100-continue') + self._make_one() + self.transport.write.assert_called_with( + b'HTTP/1.1 100 Continue\r\n\r\n') + + def test_environ_headers(self): + self.headers.extend( + (('HOST', 'python.org'), + ('SCRIPT_NAME', 'script'), + ('CONTENT-TYPE', 'text/plain'), + ('CONTENT-LENGTH', '209'), + ('X_TEST', '123'), + ('X_TEST', '456'))) + environ = self._make_one(is_ssl=True) + self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') + self.assertEqual(environ['CONTENT_LENGTH'], '209') + self.assertEqual(environ['HTTP_X_TEST'], '123,456') + self.assertEqual(environ['SCRIPT_NAME'], 'script') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + + def test_environ_host_header(self): + self.headers.append(('HOST', 'python.org')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '80') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') + + def test_environ_host_port_header(self): + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + self.headers.append(('HOST', 'python.org:443')) + environ = self._make_one() + + self.assertEqual(environ['HTTP_HOST'], 'python.org:443') + self.assertEqual(environ['SERVER_NAME'], 'python.org') + self.assertEqual(environ['SERVER_PORT'], '443') + self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') + + def test_environ_forward(self): + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = ('127.0.0.1', 443) + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') + self.assertEqual(environ['REMOTE_PORT'], '443') + + self.transport.get_extra_info.return_value = '[::1]' + environ = self._make_one() + + self.assertEqual(environ['REMOTE_ADDR'], '::1') + self.assertEqual(environ['REMOTE_PORT'], '80') + + def test_wsgi_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + self.assertIsInstance(resp, wsgi.WsgiResponse) + + def test_wsgi_response_start_response(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + self.assertEqual(resp.status, '200 OK') + self.assertIsInstance(resp.response, protocol.Response) + + def test_wsgi_response_start_response_exc_status(self): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) + + self.assertRaises( + ValueError, + resp.start_response, + '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) + + @unittest.mock.patch('tulip.http.wsgi.tulip') + def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): + srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + resp = srv.create_wsgi_response(self.message) + resp.start_response( + '101 Switching Protocols', (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'))) + self.assertEqual(resp.status, '101 Switching Protocols') + self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) + + def test_file_wrapper(self): + fobj = io.BytesIO(b'data') + wrapper = wsgi.FileWrapper(fobj, 2) + self.assertIs(wrapper, iter(wrapper)) + self.assertTrue(hasattr(wrapper, 'close')) + + self.assertEqual(next(wrapper), b'da') + self.assertEqual(next(wrapper), b'ta') + self.assertRaises(StopIteration, next, wrapper) + + wrapper = wsgi.FileWrapper(b'data', 2) + self.assertFalse(hasattr(wrapper, 'close')) + + def test_handle_request_futures(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + f1 = tulip.Future(loop=self.loop) + f1.set_result(b'data') + fut = tulip.Future(loop=self.loop) + fut.set_result([f1]) + return fut + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_simple(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, True, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertFalse(srv._keep_alive) + + def test_handle_request_io(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return io.BytesIO(b'data') + + srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) + + def test_handle_request_keep_alive(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [b'data'] + + stream = tulip.StreamReader(loop=self.loop) + stream.feed_data(b'data') + stream.feed_eof() + + self.message = protocol.RawRequestMessage( + 'GET', '/path', (1, 1), self.headers, False, 'deflate') + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) + self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) + self.assertTrue(srv._keep_alive) + + def test_handle_request_readpayload(self): + + def wsgi_app(env, start): + start('200 OK', [('Content-Type', 'text/plain')]) + return [env['wsgi.input'].read()] + + srv = wsgi.WSGIServerHttpProtocol( + wsgi_app, readpayload=True, loop=self.loop) + srv.stream = self.stream + srv.transport = self.transport + + self.loop.run_until_complete( + srv.handle_request(self.message, self.payload)) + + content = b''.join( + [c[1][0] for c in self.transport.write.mock_calls]) + self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) + self.assertTrue(content.endswith(b'data')) diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..9399d759 --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventWaiterTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.EventWaiter(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.EventWaiter(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.EventWaiter() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.EventWaiter(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.EventWaiter(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.EventWaiter(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.EventWaiter(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/parsers_test.py b/tests/parsers_test.py new file mode 100644 index 00000000..debc532c --- /dev/null +++ b/tests/parsers_test.py @@ -0,0 +1,598 @@ +"""Tests for parser.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import parsers +from tulip import tasks + + +class StreamBufferTests(unittest.TestCase): + + DATA = b'line1\nline2\nline3\n' + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_exception(self): + stream = parsers.StreamBuffer() + self.assertIsNone(stream.exception()) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(stream.exception(), exc) + + def test_exception_waiter(self): + stream = parsers.StreamBuffer() + + stream._parser = parsers.lines_parser() + buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) + + exc = ValueError() + stream.set_exception(exc) + self.assertIs(buf.exception(), exc) + + def test_feed_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(self.DATA) + self.assertEqual(self.DATA, bytes(stream._buffer)) + + def test_feed_empty_data(self): + stream = parsers.StreamBuffer() + + stream.feed_data(b'') + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_unset_prev(self): + stream = parsers.StreamBuffer() + stream.set_parser(parsers.lines_parser()) + + unset = stream.unset_parser = unittest.mock.Mock() + stream.set_parser(parsers.lines_parser()) + + self.assertTrue(unset.called) + + def test_set_parser_exception(self): + stream = parsers.StreamBuffer() + + exc = ValueError() + stream.set_exception(exc) + s = stream.set_parser(parsers.lines_parser()) + self.assertIs(s.exception(), exc) + + def test_set_parser_feed_existing(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNotNone(stream._parser) + + stream.unset_parser() + self.assertIsNone(stream._parser) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_set_parser_feed_existing_exc(self): + + def p(): + yield # stream + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof(self): + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + stream.feed_eof() + s = stream.set_parser(parsers.lines_parser()) + + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + + def test_set_parser_feed_existing_eof_exc(self): + + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsInstance(s.exception(), ValueError) + + def test_set_parser_feed_existing_eof_unhandled_eof(self): + + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_eof() + s = stream.set_parser(p()) + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_set_parser_unset(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + stream.unset_parser() + self.assertTrue(s._eof) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_set_parser_feed_existing_stop(self): + def lines_parser(): + out, buf = yield + try: + out.feed_data((yield from buf.readuntil(b'\n'))) + out.feed_data((yield from buf.readuntil(b'\n'))) + finally: + out.feed_eof() + + stream = parsers.StreamBuffer() + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + s = stream.set_parser(lines_parser()) + + self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertIsNone(stream._parser) + self.assertTrue(s._eof) + + def test_feed_parser(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1') + stream.feed_data(b'\r\nline2\r\ndata') + self.assertEqual(b'data', bytes(stream._buffer)) + + stream.feed_eof() + self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'data', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_feed_parser_exc(self): + def p(): + yield # stream + yield # read chunk + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsInstance(s.exception(), ValueError) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_parser_stop(self): + def p(): + yield # stream + yield # chunk + + stream = parsers.StreamBuffer() + stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(stream._parser) + self.assertEqual(b'', bytes(stream._buffer)) + + def test_feed_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + self.assertIsNone(s.exception()) + + stream.feed_eof() + self.assertIsInstance(s.exception(), ValueError) + + def test_feed_eof_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertTrue(s._eof) + + def test_feed_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.feed_eof() + self.assertIsNone(s.exception()) + self.assertTrue(s._eof) + + def test_feed_parser2(self): + stream = parsers.StreamBuffer() + s = stream.set_parser(parsers.lines_parser()) + + stream.feed_data(b'line1\r\nline2\r\n') + stream.feed_eof() + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(s._buffer)) + self.assertEqual(b'', bytes(stream._buffer)) + self.assertTrue(s._eof) + + def test_unset_parser_eof_exc(self): + def p(): + yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + raise ValueError() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsInstance(s.exception(), ValueError) + self.assertIsNone(stream._parser) + + def test_unset_parser_eof_unhandled_eof(self): + def p(): + yield # stream + while True: + yield # read chunk + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertIsNone(s.exception(), ValueError) + self.assertTrue(s._eof) + + def test_unset_parser_stop(self): + def p(): + out, buf = yield # stream + try: + while True: + yield # read chunk + except parsers.EofStream: + out.feed_eof() + + stream = parsers.StreamBuffer() + s = stream.set_parser(p()) + + stream.feed_data(b'line1') + stream.unset_parser() + self.assertTrue(s._eof) + + +class DataBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_feed_data(self): + buffer = parsers.DataBuffer(loop=self.loop) + + item = object() + buffer.feed_data(item) + self.assertEqual([item], list(buffer._buffer)) + + def test_feed_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_eof() + self.assertTrue(buffer._eof) + + def test_read(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_data(item) + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIs(item, data) + + def test_read_eof(self): + buffer = parsers.DataBuffer(loop=self.loop) + read_task = tasks.Task(buffer.read(), loop=self.loop) + + def cb(): + buffer.feed_eof() + self.loop.call_soon(cb) + + data = self.loop.run_until_complete(read_task) + self.assertIsNone(data) + + def test_read_until_eof(self): + item = object() + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(item) + buffer.feed_eof() + + data = self.loop.run_until_complete(buffer.read()) + self.assertIs(data, item) + + data = self.loop.run_until_complete(buffer.read()) + self.assertIsNone(data) + + def test_read_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + buffer.feed_data(object()) + buffer.set_exception(ValueError()) + + self.assertRaises( + ValueError, self.loop.run_until_complete, buffer.read()) + + def test_exception(self): + buffer = parsers.DataBuffer(loop=self.loop) + self.assertIsNone(buffer.exception()) + + exc = ValueError() + buffer.set_exception(exc) + self.assertIs(buffer.exception(), exc) + + def test_exception_waiter(self): + buffer = parsers.DataBuffer(loop=self.loop) + + @tasks.coroutine + def set_err(): + buffer.set_exception(ValueError()) + + t1 = tasks.Task(buffer.read(), loop=self.loop) + t2 = tasks.Task(set_err(), loop=self.loop) + + self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + + self.assertRaises(ValueError, t1.result) + + +class StreamProtocolTests(unittest.TestCase): + + def test_connection_made(self): + tr = unittest.mock.Mock() + + proto = parsers.StreamProtocol() + self.assertIsNone(proto.transport) + + proto.connection_made(tr) + self.assertIs(proto.transport, tr) + + def test_connection_lost(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + proto.connection_lost(None) + self.assertIsNone(proto.transport) + self.assertTrue(proto._eof) + + def test_connection_lost_exc(self): + proto = parsers.StreamProtocol() + proto.connection_made(unittest.mock.Mock()) + + exc = ValueError() + proto.connection_lost(exc) + self.assertIs(proto.exception(), exc) + + +class ParserBuffer(unittest.TestCase): + + def _make_one(self): + return parsers.ParserBuffer() + + def test_shrink(self): + buf = parsers.ParserBuffer() + buf.feed_data(b'data') + + buf._shrink() + self.assertEqual(bytes(buf), b'data') + + buf.offset = 2 + buf._shrink() + self.assertEqual(bytes(buf), b'ta') + self.assertEqual(2, len(buf)) + self.assertEqual(2, buf.size) + self.assertEqual(0, buf.offset) + + def test_feed_data(self): + buf = self._make_one() + buf.feed_data(b'') + self.assertEqual(len(buf), 0) + + buf.feed_data(b'data') + self.assertEqual(len(buf), 4) + self.assertEqual(bytes(buf), b'data') + + def test_read(self): + buf = self._make_one() + p = buf.read(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123') + self.assertEqual(b'4', bytes(buf)) + + def test_readsome(self): + buf = self._make_one() + p = buf.readsome(3) + next(p) + try: + p.send(b'1') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'1') + + p = buf.readsome(2) + next(p) + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + self.assertEqual(res, b'23') + self.assertEqual(b'4', bytes(buf)) + + def test_skip(self): + buf = self._make_one() + p = buf.skip(3) + next(p) + p.send(b'1') + try: + p.send(b'234') + except StopIteration as exc: + res = exc.value + + self.assertIsNone(res) + self.assertEqual(b'4', bytes(buf)) + + def test_readuntil_limit(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'1') + p.send(b'234') + self.assertRaises(ValueError, p.send, b'5') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4) + next(p) + self.assertRaises(ValueError, p.send, b'12345\n6') + + class CustomExc(Exception): + pass + + buf = parsers.ParserBuffer() + p = buf.readuntil(b'\n', 4, CustomExc) + next(p) + self.assertRaises(CustomExc, p.send, b'12345\n6') + + def test_readuntil(self): + buf = self._make_one() + p = buf.readuntil(b'\n', 4) + next(p) + p.send(b'123') + try: + p.send(b'\n456') + except StopIteration as exc: + res = exc.value + + self.assertEqual(res, b'123\n') + self.assertEqual(b'456', bytes(buf)) + + def test_skipuntil(self): + buf = self._make_one() + p = buf.skipuntil(b'\n') + next(p) + p.send(b'123') + try: + p.send(b'\n456\n') + except StopIteration: + pass + self.assertEqual(b'456\n', bytes(buf)) + + p = buf.skipuntil(b'\n') + try: + next(p) + except StopIteration: + pass + self.assertEqual(b'', bytes(buf)) + + def test_lines_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.lines_parser() + next(p) + p.send((out, buf)) + + for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], + list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') + + def test_chunks_parser(self): + out = parsers.DataBuffer(loop=self.loop) + buf = self._make_one() + p = parsers.chunks_parser(5) + next(p) + p.send((out, buf)) + + for d in (b'line1', b'lin', b'e2d', b'ata'): + p.send(d) + + self.assertEqual( + [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) + try: + p.throw(parsers.EofStream()) + except parsers.EofStream: + pass + + self.assertEqual(bytes(buf), b'data') diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..da4dea35 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,393 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport +from tulip import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = tulip.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.tulip_log') + def test_start_serving(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_start_serving_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = tulip.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop.stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor.stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..98ca3199 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,470 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import queues +from tulip import tasks +from tulip import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith(')') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro(), loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + +## def test_cancel_done_future(self): +## fut1 = futures.Future(loop=self.loop) +## fut2 = futures.Future(loop=self.loop) +## fut3 = futures.Future(loop=self.loop) + +## @tasks.coroutine +## def task(): +## yield from fut1 +## try: +## yield from fut2 +## except futures.CancelledError: +## pass +## yield from fut3 + +## t = tasks.Task(task(), loop=self.loop) +## test_utils.run_briefly(self.loop) +## fut1.set_result(None) +## t.cancel() +## test_utils.run_once(self.loop) # process fut1 result, delay cancel +## self.assertFalse(t.done()) +## test_utils.run_once(self.loop) # cancel fut2, but coro still alive +## self.assertFalse(t.done()) +## test_utils.run_briefly(self.loop) # cancel fut3 +## self.assertTrue(t.done()) + +## self.assertEqual(fut1.result(), None) +## self.assertTrue(fut2.cancelled()) +## self.assertTrue(fut3.cancelled()) +## self.assertTrue(t.cancelled()) + +## def test_cancel_in_coro(self): +## @tasks.coroutine +## def task(): +## t.cancel() +## return 12 + +## t = tasks.Task(task(), loop=self.loop) +## self.assertRaises( +## futures.CancelledError, self.loop.run_until_complete, t) +## self.assertTrue(t.done()) +## self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + loop + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch(), loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError) as cm: + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertTrue(fut.cancelled()) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.set_exception(ZeroDivisionError()) + c.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertTrue(fut.cancelled()) + # Does nothing + d.set_result(3) + e.cancel() + f.set_exception(RuntimeError()) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + yield from [] + return 'abc' + fut = tasks.gather(coro(), coro()) + self.assertIs(fut._loop, self.one_loop) + fut = tasks.gather(coro(), coro(), loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..5920cda6 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,59 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import futures +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + self.assertRaises(NotImplementedError, transport.pause_writing) + self.assertRaises(NotImplementedError, transport.resume_writing) + self.assertRaises(NotImplementedError, transport.discard_output) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..f0b42a39 --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,818 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import tempfile +import unittest +import unittest.mock + + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.tulip_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + def test_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_double_pause_resume_writing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.pause_writing() + self.assertFalse(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + tr.resume_writing() + self.assertTrue(tr._writing) + + def test_pause_resume_writing_with_nonempty_buffer(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + self.assertFalse(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + + tr.resume_writing() + self.assertTrue(tr._writing) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'da', b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_on_pause(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + tr.pause_writing() + + tr._write_ready() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([b'da', b'ta'], tr._buffer) + self.assertFalse(tr._writing) + + def test_discard_output(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'da', b'ta'] + self.loop.add_writer(5, tr._write_ready) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + def test_discard_output_without_pending_writes(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr.discard_output() + self.assertTrue(tr._writing) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..ce9b74da --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,81 @@ +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams + + +def connect_read_pipe(loop, file): + stream_reader = streams.StreamReader(loop=loop) + protocol = _StreamReaderProtocol(stream_reader) + loop._make_read_pipe_transport(file, protocol) + return stream_reader + + +class _StreamReaderProtocol(protocols.Protocol): + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_pause_resume_discard(self): + a, b = self.loop._socketpair() + trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) + reader = connect_read_pipe(self.loop, b) + f = tulip.async(reader.readline(), loop=self.loop) + + trans.write(b'msg1\n') + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg1\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg2\n') + with self.assertRaises(tulip.TimeoutError): + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(trans._buffer, [b'msg2\n']) + + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.1) + self.assertEqual(f.result(), b'msg2\n') + f = tulip.async(reader.readline(), loop=self.loop) + + trans.pause_writing() + trans.write(b'msg3\n') + self.assertEqual(trans._buffer, [b'msg3\n']) + trans.discard_output() + self.assertEqual(trans._buffer, []) + + trans.write(b'msg4\n') + self.assertEqual(trans._buffer, [b'msg4\n']) + trans.resume_writing() + self.loop.run_until_complete(f, timeout=0.01) + self.assertEqual(f.result(), b'msg4\n') + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f, timeout=1) + self.assertEqual(f.result(), b'') diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..b23896d3 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils +from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..9de84cb0 --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,28 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .parsers import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + parsers.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..3bccfc83 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,592 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import tulip_log + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, server_side=False) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + l_addr = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + l_addr = sock.getsockname() + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport( + sock, protocol, r_addr, extra={'addr': l_addr}) + return transport, protocol + + # This returns a Task made from self._start_serving_internal(). + # We want start_serving() to return a Task so that it will start + # running right away (when the event loop runs) even if the caller + # doesn't wait for it. Note that this is different from + # e.g. create_connection(), or create_datagram_endpoint(), which + # are a "mere" coroutines and require their caller to wait for + # them. The reason for the difference is that only + # start_serving() creates multiple transports and protocols. + def start_serving(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + coro = self._start_serving_internal(protocol_factory, host, port, + family=family, + flags=flags, + sock=sock, + backlog=backlog, + ssl=ssl, + reuse_address=reuse_address) + return tasks.Task(coro, loop=self) + + @tasks.coroutine + def _start_serving_internal(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl) + return sockets + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter, + extra={}) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, + extra={}, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..7db2514d --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,389 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import tulip_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + tulip_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def start_serving(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """Creates a TCP server bound to host and port and return a + Task whose result will be a list of socket objects which will + later be handled by protocol_factory. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def stop_serving(self, sock): + """Stop listening for incoming connections. Close socket.""" + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..706e8c8a --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import tulip_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + tulip_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py new file mode 100644 index 00000000..a1432dee --- /dev/null +++ b/tulip/http/__init__.py @@ -0,0 +1,16 @@ +# This relies on each of the submodules having an __all__ variable. + +from .client import * +from .errors import * +from .protocol import * +from .server import * +from .session import * +from .wsgi import * + + +__all__ = (client.__all__ + + errors.__all__ + + protocol.__all__ + + server.__all__ + + session.__all__ + + wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py new file mode 100644 index 00000000..ec7cd034 --- /dev/null +++ b/tulip/http/client.py @@ -0,0 +1,572 @@ +"""HTTP Client for Tulip. + +Most basic usage: + + response = yield from tulip.http.request('GET', url) + response['Content-Type'] == 'application/json' + response.status == 200 + + content = yield from response.content.read() +""" + +__all__ = ['request'] + +import base64 +import email.message +import functools +import http.client +import http.cookies +import json +import io +import itertools +import mimetypes +import os +import uuid +import urllib.parse + +import tulip +import tulip.http + + +@tulip.coroutine +def request(method, url, *, + params=None, + data=None, + headers=None, + cookies=None, + files=None, + auth=None, + allow_redirects=True, + max_redirects=10, + encoding='utf-8', + version=(1, 1), + timeout=None, + compress=None, + chunked=None, + session=None, + loop=None): + """Constructs and sends a request. Returns response object. + + method: http method + url: request url + params: (optional) Dictionary or bytes to be sent in the query string + of the new request + data: (optional) Dictionary, bytes, or file-like object to + send in the body of the request + headers: (optional) Dictionary of HTTP Headers to send with the request + cookies: (optional) Dict object to send with the request + files: (optional) Dictionary of 'name': file-like-objects + for multipart encoding upload + auth: (optional) Auth tuple to enable Basic HTTP Auth + timeout: (optional) Float describing the timeout of the request + allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE + redirect following is allowed. + compress: Boolean. Set to True if request has to be compressed + with deflate encoding. + chunked: Boolean or Integer. Set to chunk size for chunked + transfer encoding. + session: tulip.http.Session instance to support connection pooling and + session cookies. + loop: Optional event loop. + + Usage: + + import tulip.http + >> resp = yield from tulip.http.request('GET', 'http://python.org/') + >> resp + + + >> data = yield from resp.content.read() + + """ + redirects = 0 + if loop is None: + loop = tulip.get_event_loop() + + while True: + req = HttpRequest( + method, url, params=params, headers=headers, data=data, + cookies=cookies, files=files, auth=auth, encoding=encoding, + version=version, compress=compress, chunked=chunked) + + if session is None: + conn = start(req, loop) + else: + conn = session.start(req, loop) + + # connection timeout + t = tulip.Task(conn, loop=loop) + th = None + if timeout is not None: + th = loop.call_later(timeout, t.cancel) + try: + resp = yield from t + except tulip.CancelledError: + raise tulip.TimeoutError from None + finally: + if th is not None: + th.cancel() + + # redirects + if resp.status in (301, 302) and allow_redirects: + redirects += 1 + if max_redirects and redirects >= max_redirects: + resp.close() + break + + r_url = resp.get('location') or resp.get('uri') + + scheme = urllib.parse.urlsplit(r_url)[0] + if scheme not in ('http', 'https', ''): + raise ValueError('Can redirect only to http or https') + elif not scheme: + r_url = urllib.parse.urljoin(url, r_url) + + url = urllib.parse.urldefrag(r_url)[0] + if url: + resp.close() + continue + + break + + return resp + + +@tulip.coroutine +def start(req, loop): + transport, p = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + + try: + resp = req.send(transport) + yield from resp.start(p, transport) + except: + transport.close() + raise + + return resp + + +class HttpRequest: + + GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} + POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} + ALL_METHODS = GET_METHODS.union(POST_METHODS) + + DEFAULT_HEADERS = { + 'Accept': '*/*', + 'Accept-Encoding': 'gzip, deflate', + } + + body = b'' + + def __init__(self, method, url, *, + params=None, + headers=None, + data=None, + cookies=None, + files=None, + auth=None, + encoding='utf-8', + version=(1, 1), + compress=None, + chunked=None): + self.method = method.upper() + self.encoding = encoding + + # parser http version '1.1' => (1, 1) + if isinstance(version, str): + v = [l.strip() for l in version.split('.', 1)] + try: + version = int(v[0]), int(v[1]) + except ValueError: + raise ValueError( + 'Can not parse http version number: {}' + .format(version)) from None + self.version = version + + # path + scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) + if not netloc: + raise ValueError('Host could not be detected.') + + if not path: + path = '/' + else: + path = urllib.parse.unquote(path) + + # check domain idna encoding + try: + netloc = netloc.encode('idna').decode('utf-8') + except UnicodeError: + raise ValueError('URL has an invalid label.') + + # basic auth info + if '@' in netloc: + authinfo, netloc = netloc.split('@', 1) + if not auth: + auth = authinfo.split(':', 1) + if len(auth) == 1: + auth.append('') + + # extract host and port + ssl = scheme == 'https' + + if ':' in netloc: + netloc, port_s = netloc.split(':', 1) + try: + port = int(port_s) + except ValueError: + raise ValueError( + 'Port number could not be converted.') from None + else: + if ssl: + port = http.client.HTTPS_PORT + else: + port = http.client.HTTP_PORT + + self.host = netloc + self.port = port + self.ssl = ssl + + # build url query + if isinstance(params, dict): + params = list(params.items()) + + if data and self.method in self.GET_METHODS: + # include data to query + if isinstance(data, dict): + data = data.items() + params = list(itertools.chain(params or (), data)) + data = None + + if params: + params = urllib.parse.urlencode(params) + if query: + query = '%s&%s' % (query, params) + else: + query = params + + # build path + path = urllib.parse.quote(path) + self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) + + # headers + self.headers = email.message.Message() + if headers: + if isinstance(headers, dict): + headers = list(headers.items()) + + for key, value in headers: + self.headers.add_header(key, value) + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in self.headers: + self.headers[hdr] = val + + # host + if 'host' not in self.headers: + self.headers['Host'] = self.host + + # cookies + if cookies: + self.update_cookies(cookies) + + # auth + if auth: + if isinstance(auth, (tuple, list)) and len(auth) == 2: + # basic auth + self.headers['Authorization'] = 'Basic %s' % ( + base64.b64encode( + ('%s:%s' % (auth[0], auth[1])).encode('latin1')) + .strip().decode('latin1')) + else: + raise ValueError("Only basic auth is supported") + + # Content-encoding + enc = self.headers.get('Content-Encoding', '').lower() + if enc: + chunked = True # enable chunked, no need to deal with length + compress = enc + elif compress: + chunked = True # enable chunked, no need to deal with length + compress = compress if isinstance(compress, str) else 'deflate' + self.headers['Content-Encoding'] = compress + + # form data (x-www-form-urlencoded) + if isinstance(data, dict): + data = list(data.items()) + + if data and not files: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') + if 'content-length' not in self.headers and not chunked: + self.headers['content-length'] = str(len(self.body)) + + # files (multipart/form-data) + elif files: + fields = [] + + if data: + for field, val in data: + fields.append((field, str_to_bytes(val))) + + if isinstance(files, dict): + files = list(files.items()) + + for rec in files: + if not isinstance(rec, (tuple, list)): + rec = (rec,) + + ft = None + if len(rec) == 1: + k = guess_filename(rec[0], 'unknown') + fields.append((k, k, rec[0])) + + elif len(rec) == 2: + k, fp = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp)) + + else: + k, fp, ft = rec + fn = guess_filename(fp, k) + fields.append((k, fn, fp, ft)) + + chunked = chunked or 8192 + boundary = uuid.uuid4().hex + + self.body = encode_multipart_data( + fields, bytes(boundary, 'latin1')) + + self.headers['content-type'] = ( + 'multipart/form-data; boundary=%s' % boundary) + + # chunked + te = self.headers.get('transfer-encoding', '').lower() + + if chunked: + if 'content-length' in self.headers: + del self.headers['content-length'] + if 'chunked' not in te: + self.headers['transfer-encoding'] = 'chunked' + + chunked = chunked if type(chunked) is int else 8196 + else: + if 'chunked' in te: + chunked = 8196 + else: + chunked = None + self.headers['content-length'] = str(len(self.body)) + + self._chunked = chunked + self._compress = compress + + def update_cookies(self, cookies): + """Update request cookies header.""" + c = http.cookies.SimpleCookie() + if 'cookie' in self.headers: + c.load(self.headers.get('cookie', '')) + del self.headers['cookie'] + + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(c, name, value) + else: + c[name] = value + + self.headers['cookie'] = c.output(header='', sep=';').strip() + + def send(self, transport): + request = tulip.http.Request( + transport, self.method, self.path, self.version) + + if self._compress: + request.add_compression_filter(self._compress) + + if self._chunked is not None: + request.add_chunking_filter(self._chunked) + + request.add_headers(*self.headers.items()) + request.send_headers() + + if isinstance(self.body, bytes): + self.body = (self.body,) + + for chunk in self.body: + request.write(chunk) + + request.write_eof() + + return HttpResponse(self.method, self.path, self.host) + + +class HttpResponse(http.client.HTTPMessage): + + message = None # RawResponseMessage object + + # from the Status-Line of the response + version = None # HTTP-Version + status = None # Status-Code + reason = None # Reason-Phrase + + cookies = None # Response cookies (Set-Cookie) + + content = None # payload stream + stream = None # input stream + transport = None # current transport + + def __init__(self, method, url, host=''): + super().__init__() + + self.method = method + self.url = url + self.host = host + self._content = None + + def __del__(self): + self.close() + + def __repr__(self): + out = io.StringIO() + print(''.format( + self.host, self.url, self.status, self.reason), file=out) + print(super().__str__(), file=out) + return out.getvalue() + + def start(self, stream, transport): + """Start response processing.""" + self.stream = stream + self.transport = transport + + httpstream = stream.set_parser(tulip.http.http_response_parser()) + + # read response + self.message = yield from httpstream.read() + + # response status + self.version = self.message.version + self.status = self.message.code + self.reason = self.message.reason + + # headers + for hdr, val in self.message.headers: + self.add_header(hdr, val) + + # payload + self.content = stream.set_parser( + tulip.http.http_payload_parser(self.message)) + + # cookies + self.cookies = http.cookies.SimpleCookie() + if 'Set-Cookie' in self: + for hdr in self.get_all('Set-Cookie'): + self.cookies.load(hdr) + + return self + + def close(self): + if self.transport is not None: + self.transport.close() + self.transport = None + + @tulip.coroutine + def read(self, decode=False): + """Read response payload. Decode known types of content.""" + if self._content is None: + buf = [] + total = 0 + chunk = yield from self.content.read() + while chunk: + size = len(chunk) + buf.append((chunk, size)) + total += size + chunk = yield from self.content.read() + + self._content = bytearray(total) + + idx = 0 + content = memoryview(self._content) + for chunk, size in buf: + content[idx:idx+size] = chunk + idx += size + + data = self._content + + if decode: + ct = self.get('content-type', '').lower() + if ct == 'application/json': + data = json.loads(data.decode('utf-8')) + + return data + + +def str_to_bytes(s, encoding='utf-8'): + if isinstance(s, str): + return s.encode(encoding) + return s + + +def guess_filename(obj, default=None): + name = getattr(obj, 'name', None) + if name and name[0] != '<' and name[-1] != '>': + return os.path.split(name)[-1] + return default + + +def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): + """ + Encode a list of fields using the multipart/form-data MIME format. + + fields: + List of (name, value) or (name, filename, io) or + (name, filename, io, MIME type) field tuples. + """ + for rec in fields: + yield b'--' + boundary + b'\r\n' + + field, *rec = rec + + if len(rec) == 1: + data = rec[0] + yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % + (field,)).encode(encoding)) + yield data + b'\r\n' + + else: + if len(rec) == 3: + fn, fp, ct = rec + else: + fn, fp = rec + ct = (mimetypes.guess_type(fn)[0] or + 'application/octet-stream') + + yield ('Content-Disposition: form-data; name="%s"; ' + 'filename="%s"\r\n' % (field, fn)).encode(encoding) + yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) + + if isinstance(fp, str): + fp = fp.encode(encoding) + + if isinstance(fp, bytes): + fp = io.BytesIO(fp) + + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + yield str_to_bytes(chunk) + + yield b'\r\n' + + yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py new file mode 100644 index 00000000..f8b77e9b --- /dev/null +++ b/tulip/http/errors.py @@ -0,0 +1,46 @@ +"""http related errors.""" + +__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', + 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] + +import http.client + + +class HttpException(http.client.HTTPException): + + code = None + headers = () + message = '' + + +class HttpErrorException(HttpException): + + def __init__(self, code, message='', headers=None): + self.code = code + self.headers = headers + self.message = message + + +class BadRequestException(HttpException): + + code = 400 + message = 'Bad Request' + + +class IncompleteRead(BadRequestException, http.client.IncompleteRead): + pass + + +class BadStatusLine(BadRequestException, http.client.BadStatusLine): + pass + + +class LineTooLong(BadRequestException, http.client.LineTooLong): + pass + + +class InvalidHeader(BadRequestException): + + def __init__(self, hdr): + super().__init__('Invalid HTTP Header: {}'.format(hdr)) + self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py new file mode 100644 index 00000000..7081fd59 --- /dev/null +++ b/tulip/http/protocol.py @@ -0,0 +1,756 @@ +"""Http related helper utils.""" + +__all__ = ['HttpMessage', 'Request', 'Response', + 'RawRequestMessage', 'RawResponseMessage', + 'http_request_parser', 'http_response_parser', + 'http_payload_parser'] + +import collections +import functools +import http.server +import itertools +import re +import sys +import zlib +from wsgiref.handlers import format_date_time + +import tulip +from tulip.http import errors + +METHRE = re.compile('[A-Z0-9$-_.]+') +VERSRE = re.compile('HTTP/(\d+).(\d+)') +HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') +CONTINUATION = (' ', '\t') +EOF_MARKER = object() +EOL_MARKER = object() + +RESPONSES = http.server.BaseHTTPRequestHandler.responses + + +RawRequestMessage = collections.namedtuple( + 'RawRequestMessage', + ['method', 'path', 'version', 'headers', 'should_close', 'compression']) + + +RawResponseMessage = collections.namedtuple( + 'RawResponseMessage', + ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) + + +def http_request_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read request status line. Exception errors.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + out, buf = yield + + try: + # read http message (request line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + # request line + line = lines[0] + try: + method, path, version = line.split(None, 2) + except ValueError: + raise errors.BadStatusLine(line) from None + + # method + method = method.upper() + if not METHRE.match(method): + raise errors.BadStatusLine(method) + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(version) + version = (int(match.group(1)), int(match.group(2))) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if version <= (1, 0): + close = True + elif close is None: + close = False + + out.feed_data( + RawRequestMessage( + method, path, version, headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + pass + + +def http_response_parser(max_line_size=8190, + max_headers=32768, max_field_size=8190): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage""" + out, buf = yield + + try: + # read http message (response line + headers) + raw_data = yield from buf.readuntil( + b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) + lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) + + line = lines[0] + try: + version, status = line.split(None, 1) + except ValueError: + raise errors.BadStatusLine(line) from None + else: + try: + status, reason = status.split(None, 1) + except ValueError: + reason = '' + + # version + match = VERSRE.match(version) + if match is None: + raise errors.BadStatusLine(line) + version = (int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit number + try: + status = int(status) + except ValueError: + raise errors.BadStatusLine(line) from None + + if status < 100 or status > 999: + raise errors.BadStatusLine(line) + + # read headers + headers, close, compression = parse_headers( + lines, max_line_size, max_headers, max_field_size) + if close is None: + close = version <= (1, 0) + + out.feed_data( + RawResponseMessage( + version, status, reason.strip(), headers, close, compression)) + out.feed_eof() + except tulip.EofStream: + # Presumably, the server closed the connection before + # sending a valid response. + raise errors.BadStatusLine(b'') from None + + +def parse_headers(lines, max_line_size, max_headers, max_field_size): + """Parses RFC2822 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + close_conn = None + encoding = None + headers = collections.deque() + + lines_idx = 1 + line = lines[1] + + while line not in ('\r\n', '\n'): + header_length = len(line) + + # Parse initial header name : value pair. + try: + name, value = line.split(':', 1) + except ValueError: + raise ValueError('Invalid header: {}'.format(line)) from None + + name = name.strip(' \t').upper() + if HDRRE.search(name): + raise ValueError('Invalid header name: {}'.format(name)) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = line[0] in CONTINUATION + + if continuation: + value = [value] + while continuation: + header_length += len(line) + if header_length > max_field_size: + raise errors.LineTooLong( + 'limit request headers fields size') + value.append(line) + + # next line + lines_idx += 1 + line = lines[lines_idx] + continuation = line[0] in CONTINUATION + value = ''.join(value) + else: + if header_length > max_field_size: + raise errors.LineTooLong('limit request headers fields size') + + value = value.strip() + + # keep-alive and encoding + if name == 'CONNECTION': + v = value.lower() + if v == 'close': + close_conn = True + elif v == 'keep-alive': + close_conn = False + elif name == 'CONTENT-ENCODING': + enc = value.lower() + if enc in ('gzip', 'deflate'): + encoding = enc + + headers.append((name, value)) + + return headers, close_conn, encoding + + +def http_payload_parser(message, length=None, compression=True, readall=False): + out, buf = yield + + # payload params + chunked = False + for name, value in message.headers: + if name == 'CONTENT-LENGTH': + length = value + elif name == 'TRANSFER-ENCODING': + chunked = value.lower() == 'chunked' + elif name == 'SEC-WEBSOCKET-KEY1': + length = 8 + + # payload decompression wrapper + if compression and message.compression: + out = DeflateBuffer(out, message.compression) + + # payload parser + if chunked: + yield from parse_chunked_payload(out, buf) + + elif length is not None: + try: + length = int(length) + except ValueError: + raise errors.InvalidHeader('CONTENT-LENGTH') from None + + if length < 0: + raise errors.InvalidHeader('CONTENT-LENGTH') + elif length > 0: + yield from parse_length_payload(out, buf, length) + else: + if readall: + yield from parse_eof_payload(out, buf) + + out.feed_eof() + + +def parse_chunked_payload(out, buf): + """Chunked transfer encoding parser.""" + try: + while True: + # read next chunk size + #line = yield from buf.readline(8196) + line = yield from buf.readuntil(b'\r\n', 8196) + + i = line.find(b';') + if i >= 0: + line = line[:i] # strip chunk-extensions + else: + line = line.strip() + try: + size = int(line, 16) + except ValueError: + raise errors.IncompleteRead(b'') from None + + if size == 0: # eof marker + break + + # read chunk and feed buffer + while size: + chunk = yield from buf.readsome(size) + out.feed_data(chunk) + size = size - len(chunk) + + # toss the CRLF at the end of the chunk + yield from buf.skip(2) + + # read and discard trailer up to the CRLF terminator + yield from buf.skipuntil(b'\r\n') + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_length_payload(out, buf, length): + """Read specified amount of bytes.""" + try: + while length: + chunk = yield from buf.readsome(length) + out.feed_data(chunk) + length -= len(chunk) + + except tulip.EofStream: + raise errors.IncompleteRead(b'') from None + + +def parse_eof_payload(out, buf): + """Read all bytes untile eof.""" + while True: + out.feed_data((yield from buf.readsome())) + + +class DeflateBuffer: + """DeflateStream decomress stream and feed data into specified stream.""" + + def __init__(self, out, encoding): + self.out = out + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + + self.zlib = zlib.decompressobj(wbits=zlib_mode) + + def feed_data(self, chunk): + try: + chunk = self.zlib.decompress(chunk) + except Exception: + raise errors.IncompleteRead(b'') from None + + if chunk: + self.out.feed_data(chunk) + + def feed_eof(self): + self.out.feed_data(self.zlib.flush()) + if not self.zlib.eof: + raise errors.IncompleteRead(b'') + + self.out.feed_eof() + + +def wrap_payload_filter(func): + """Wraps payload filter and piped filters. + + Filter is a generatator that accepts arbitrary chunks of data, + modify data and emit new stream of data. + + For example we have stream of chunks: ['1', '2', '3', '4', '5'], + we can apply chunking filter to this stream: + + ['1', '2', '3', '4', '5'] + | + response.add_chunking_filter(2) + | + ['12', '34', '5'] + + It is possible to use different filters at the same time. + + For a example to compress incoming stream with 'deflate' encoding + and then split data and emit chunks of 8196 bytes size chunks: + + >> response.add_compression_filter('deflate') + >> response.add_chunking_filter(8196) + + Filters do not alter transfer encoding. + + Filter can receive types types of data, bytes object or EOF_MARKER. + + 1. If filter receives bytes object, it should process data + and yield processed data then yield EOL_MARKER object. + 2. If Filter recevied EOF_MARKER, it should yield remaining + data (buffered) and then yield EOF_MARKER. + """ + @functools.wraps(func) + def wrapper(self, *args, **kw): + new_filter = func(self, *args, **kw) + + filter = self.filter + if filter is not None: + next(new_filter) + self.filter = filter_pipe(filter, new_filter) + else: + self.filter = new_filter + + next(self.filter) + + return wrapper + + +def filter_pipe(filter, filter2): + """Creates pipe between two filters. + + filter_pipe() feeds first filter with incoming data and then + send yielded from first filter data into filter2, results of + filter2 are being emitted. + + 1. If filter_pipe receives bytes object, it sends it to the first filter. + 2. Reads yielded values from the first filter until it receives + EOF_MARKER or EOL_MARKER. + 3. Each of this values is being send to second filter. + 4. Reads yielded values from second filter until it recives EOF_MARKER or + EOL_MARKER. Each of this values yields to writer. + """ + chunk = yield + + while True: + eof = chunk is EOF_MARKER + chunk = filter.send(chunk) + + while chunk is not EOL_MARKER: + chunk = filter2.send(chunk) + + while chunk not in (EOF_MARKER, EOL_MARKER): + yield chunk + chunk = next(filter2) + + if chunk is not EOF_MARKER: + if eof: + chunk = EOF_MARKER + else: + chunk = next(filter) + else: + break + + chunk = yield EOL_MARKER + + +class HttpMessage: + """HttpMessage allows to write headers and payload to a stream. + + For example, lets say we want to read file then compress it with deflate + compression and then send it with chunked transfer encoding, code may look + like this: + + >> response = tulip.http.Response(transport, 200) + + We have to use deflate compression first: + + >> response.add_compression_filter('deflate') + + Then we want to split output stream into chunks of 1024 bytes size: + + >> response.add_chunking_filter(1024) + + We can add headers to response with add_headers() method. add_headers() + does not send data to transport, send_headers() sends request/response + line and then sends headers: + + >> response.add_headers( + .. ('Content-Disposition', 'attachment; filename="..."')) + >> response.send_headers() + + Now we can use chunked writer to write stream to a network stream. + First call to write() method sends response status line and headers, + add_header() and add_headers() method unavailble at this stage: + + >> with open('...', 'rb') as f: + .. chunk = fp.read(8196) + .. while chunk: + .. response.write(chunk) + .. chunk = fp.read(8196) + + >> response.write_eof() + """ + + writer = None + + # 'filter' is being used for altering write() bahaviour, + # add_chunking_filter adds deflate/gzip compression and + # add_compression_filter splits incoming data into a chunks. + filter = None + + HOP_HEADERS = None # Must be set by subclass. + + SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) + + status = None + status_line = b'' + upgrade = False # Connection: UPGRADE + websocket = False # Upgrade: WEBSOCKET + + # subclass can enable auto sending headers with write() call, + # this is useful for wsgi's start_response implementation. + _send_headers = False + + def __init__(self, transport, version, close): + self.transport = transport + self.version = version + self.closing = close + + # disable keep-alive for http/1.0 + if version <= (1, 0): + self.keepalive = False + else: + self.keepalive = None + + self.chunked = False + self.length = None + self.headers = collections.deque() + self.headers_sent = False + + def force_close(self): + self.closing = True + self.keepalive = False + + def force_chunked(self): + self.chunked = True + + def keep_alive(self): + if self.keepalive is None: + return not self.closing + else: + return self.keepalive + + def is_headers_sent(self): + return self.headers_sent + + def add_header(self, name, value): + """Analyze headers. Calculate content length, + removes hop headers, etc.""" + assert not self.headers_sent, 'headers have been sent already' + assert isinstance(name, str), '{!r} is not a string'.format(name) + + name = name.strip().upper() + + if name == 'CONTENT-LENGTH': + self.length = int(value) + + if name == 'CONNECTION': + val = value.lower() + # handle websocket + if 'upgrade' in val: + self.upgrade = True + # connection keep-alive + elif 'close' in val: + self.keepalive = False + elif 'keep-alive' in val and self.version >= (1, 1): + self.keepalive = True + + elif name == 'UPGRADE': + if 'websocket' in value.lower(): + self.websocket = True + self.headers.append((name, value)) + + elif name == 'TRANSFER-ENCODING' and not self.chunked: + self.chunked = value.lower().strip() == 'chunked' + + elif name not in self.HOP_HEADERS: + # ignore hopbyhop headers + self.headers.append((name, value)) + + def add_headers(self, *headers): + """Adds headers to a http message.""" + for name, value in headers: + self.add_header(name, value) + + def send_headers(self): + """Writes headers to a stream. Constructs payload writer.""" + # Chunked response is only for HTTP/1.1 clients or newer + # and there is no Content-Length header is set. + # Do not use chunked responses when the response is guaranteed to + # not have a response body (304, 204). + assert not self.headers_sent, 'headers have been sent already' + self.headers_sent = True + + if (self.chunked is True) or ( + self.length is None and + self.version >= (1, 1) and + self.status not in (304, 204)): + self.chunked = True + self.writer = self._write_chunked_payload() + + elif self.length is not None: + self.writer = self._write_length_payload(self.length) + + else: + self.writer = self._write_eof_payload() + + next(self.writer) + + self._add_default_headers() + + # status + headers + hdrs = ''.join(itertools.chain( + (self.status_line,), + *((k, ': ', v, '\r\n') for k, v in self.headers))) + + self.transport.write(hdrs.encode('ascii') + b'\r\n') + + def _add_default_headers(self): + # set the connection header + if self.upgrade: + connection = 'upgrade' + elif not self.closing if self.keepalive is None else self.keepalive: + connection = 'keep-alive' + else: + connection = 'close' + + if self.chunked: + self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) + + self.headers.appendleft(('CONNECTION', connection)) + + def write(self, chunk): + """write() writes chunk of data to a steram by using different writers. + writer uses filter to modify chunk of data. write_eof() indicates + end of stream. writer can't be used after write_eof() method + being called.""" + assert (isinstance(chunk, (bytes, bytearray)) or + chunk is EOF_MARKER), chunk + + if self._send_headers and not self.headers_sent: + self.send_headers() + + assert self.writer is not None, 'send_headers() is not called.' + + if self.filter: + chunk = self.filter.send(chunk) + while chunk not in (EOF_MARKER, EOL_MARKER): + self.writer.send(chunk) + chunk = next(self.filter) + else: + if chunk is not EOF_MARKER: + self.writer.send(chunk) + + def write_eof(self): + self.write(EOF_MARKER) + try: + self.writer.throw(tulip.EofStream()) + except StopIteration: + pass + + def _write_chunked_payload(self): + """Write data in chunked transfer encoding.""" + while True: + try: + chunk = yield + except tulip.EofStream: + self.transport.write(b'0\r\n\r\n') + break + + self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) + self.transport.write(bytes(chunk)) + self.transport.write(b'\r\n') + + def _write_length_payload(self, length): + """Write specified number of bytes to a stream.""" + while True: + try: + chunk = yield + except tulip.EofStream: + break + + if length: + l = len(chunk) + if length >= l: + self.transport.write(chunk) + else: + self.transport.write(chunk[:length]) + + length = max(0, length-l) + + def _write_eof_payload(self): + while True: + try: + chunk = yield + except tulip.EofStream: + break + + self.transport.write(chunk) + + @wrap_payload_filter + def add_chunking_filter(self, chunk_size=16*1024): + """Split incoming stream into chunks.""" + buf = bytearray() + chunk = yield + + while True: + if chunk is EOF_MARKER: + if buf: + yield buf + + yield EOF_MARKER + + else: + buf.extend(chunk) + + while len(buf) >= chunk_size: + chunk = bytes(buf[:chunk_size]) + del buf[:chunk_size] + yield chunk + + chunk = yield EOL_MARKER + + @wrap_payload_filter + def add_compression_filter(self, encoding='deflate'): + """Compress incoming stream with deflate or gzip encoding.""" + zlib_mode = (16 + zlib.MAX_WBITS + if encoding == 'gzip' else -zlib.MAX_WBITS) + zcomp = zlib.compressobj(wbits=zlib_mode) + + chunk = yield + while True: + if chunk is EOF_MARKER: + yield zcomp.flush() + chunk = yield EOF_MARKER + + else: + yield zcomp.compress(chunk) + chunk = yield EOL_MARKER + + +class Response(HttpMessage): + """Create http response message. + + Transport is a socket stream transport. status is a response status code, + status has to be integer value. http_version is a tuple that represents + http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 + """ + + HOP_HEADERS = { + 'CONNECTION', + 'KEEP-ALIVE', + 'PROXY-AUTHENTICATE', + 'PROXY-AUTHORIZATION', + 'TE', + 'TRAILERS', + 'TRANSFER-ENCODING', + 'UPGRADE', + 'SERVER', + 'DATE', + } + + def __init__(self, transport, status, http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.status = status + self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( + http_version[0], http_version[1], status, RESPONSES[status][0]) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.extend((('DATE', format_date_time(None)), + ('SERVER', self.SERVER_SOFTWARE),)) + + +class Request(HttpMessage): + + HOP_HEADERS = () + + def __init__(self, transport, method, path, + http_version=(1, 1), close=False): + super().__init__(transport, http_version, close) + + self.method = method + self.path = path + self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( + method, path, http_version) + + def _add_default_headers(self): + super()._add_default_headers() + self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py new file mode 100644 index 00000000..fc5621c5 --- /dev/null +++ b/tulip/http/server.py @@ -0,0 +1,215 @@ +"""simple http server.""" + +__all__ = ['ServerHttpProtocol'] + +import http.server +import inspect +import logging +import traceback + +import tulip +from tulip.http import errors + + +RESPONSES = http.server.BaseHTTPRequestHandler.responses +DEFAULT_ERROR_MESSAGE = """ + + + {status} {reason} + + +

{status} {reason}

+ {message} + +""" + + +class ServerHttpProtocol(tulip.Protocol): + """Simple http protocol implementation. + + ServerHttpProtocol handles incoming http request. It reads request line, + request headers and request payload and calls handler_request() method. + By default it always returns with 404 respose. + + ServerHttpProtocol handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + log: custom logging object + debug: enable debug mode + keep_alive: number of seconds before closing keep alive connection + loop: event loop object + """ + _request_count = 0 + _request_handler = None + _keep_alive = False # keep transport open + _keep_alive_handle = None # keep alive timer handle + + def __init__(self, *, log=logging, debug=False, + keep_alive=None, loop=None, **kwargs): + self.__dict__.update(kwargs) + self.log = log + self.debug = debug + + self._keep_alive_period = keep_alive # number of seconds to keep alive + + if keep_alive and loop is None: + loop = tulip.get_event_loop() + self._loop = loop + + def connection_made(self, transport): + self.transport = transport + self.stream = tulip.StreamBuffer(loop=self._loop) + self._request_handler = tulip.Task(self.start(), loop=self._loop) + + def data_received(self, data): + self.stream.feed_data(data) + + def eof_received(self): + self.stream.feed_eof() + + def connection_lost(self, exc): + self.stream.feed_eof() + + if self._request_handler is not None: + self._request_handler.cancel() + self._request_handler = None + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + def keep_alive(self, val): + self._keep_alive = val + + def log_access(self, status, message, *args, **kw): + pass + + def log_debug(self, *args, **kw): + if self.debug: + self.log.debug(*args, **kw) + + def log_exception(self, *args, **kw): + self.log.exception(*args, **kw) + + @tulip.coroutine + def start(self): + """Start processing of incoming requests. + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various excetions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. + """ + + while True: + info = None + message = None + self._request_count += 1 + self._keep_alive = False + + try: + httpstream = self.stream.set_parser( + tulip.http.http_request_parser()) + + message = yield from httpstream.read() + + # cancel keep-alive timer + if self._keep_alive_handle is not None: + self._keep_alive_handle.cancel() + self._keep_alive_handle = None + + payload = self.stream.set_parser( + tulip.http.http_payload_parser(message)) + + handler = self.handle_request(message, payload) + if (inspect.isgenerator(handler) or + isinstance(handler, tulip.Future)): + yield from handler + + except tulip.CancelledError: + self.log_debug('Ignored premature client disconnection.') + break + except errors.HttpException as exc: + self.handle_error(exc.code, info, message, exc, exc.headers) + except Exception as exc: + self.handle_error(500, info, message, exc) + finally: + if self._request_handler: + if self._keep_alive and self._keep_alive_period: + self._keep_alive_handle = self._loop.call_later( + self._keep_alive_period, self.transport.close) + else: + self.transport.close() + self._request_handler = None + break + else: + break + + def handle_error(self, status=500, + message=None, payload=None, exc=None, headers=None): + """Handle errors. + + Returns http response with specific status code. Logs additional + information. It always closes current connection.""" + try: + if self._request_handler is None: + # client has been disconnected during writing. + return + + if status == 500: + self.log_exception("Error handling request") + + try: + reason, msg = RESPONSES[status] + except KeyError: + status = 500 + reason, msg = '???', '' + + if self.debug and exc is not None: + try: + tb = traceback.format_exc() + msg += '

Traceback:

\n
{}
'.format(tb) + except: + pass + + self.log_access(status, message) + + html = DEFAULT_ERROR_MESSAGE.format( + status=status, reason=reason, message=msg) + + response = tulip.http.Response(self.transport, status, close=True) + response.add_headers( + ('Content-Type', 'text/html'), + ('Content-Length', str(len(html)))) + if headers is not None: + response.add_headers(*headers) + response.send_headers() + + response.write(html.encode('ascii')) + response.write_eof() + finally: + self.keep_alive(False) + + def handle_request(self, message, payload): + """Handle a single http request. + + Subclass should override this method. By default it always + returns 404 response. + + info: tulip.http.RequestLine instance + message: tulip.http.RawHttpMessage instance + """ + response = tulip.http.Response( + self.transport, 404, http_version=message.version, close=True) + + body = b'Page Not Found!' + + response.add_headers( + ('Content-Type', 'text/plain'), + ('Content-Length', str(len(body)))) + response.send_headers() + response.write(body) + response.write_eof() + + self.keep_alive(False) + self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py new file mode 100644 index 00000000..9cdd9cea --- /dev/null +++ b/tulip/http/session.py @@ -0,0 +1,103 @@ +"""client session support.""" + +__all__ = ['Session'] + +import functools +import tulip +import http.cookies + + +class Session: + + def __init__(self): + self._conns = {} + self.cookies = http.cookies.SimpleCookie() + + def __del__(self): + self.close() + + def close(self): + """Close all opened transports.""" + for key, data in self._conns.items(): + for transport, proto in data: + transport.close() + + self._conns.clear() + + def update_cookies(self, cookies): + if isinstance(cookies, dict): + cookies = cookies.items() + + for name, value in cookies: + if isinstance(value, http.cookies.Morsel): + # use dict method because SimpleCookie class modifies value + dict.__setitem__(self.cookies, name, value) + else: + self.cookies[name] = value + + @tulip.coroutine + def start(self, req, loop, new_conn=False, set_cookies=True): + key = (req.host, req.port, req.ssl) + + if set_cookies and self.cookies: + req.update_cookies(self.cookies.items()) + + if not new_conn: + transport, proto = self._get(key) + + if new_conn or transport is None: + new = True + transport, proto = yield from loop.create_connection( + functools.partial(tulip.StreamProtocol, loop=loop), + req.host, req.port, ssl=req.ssl) + else: + new = False + + try: + resp = req.send(transport) + yield from resp.start( + proto, TransportWrapper( + self._release, key, transport, proto, resp)) + except: + if new: + transport.close() + raise + + return (yield from self.start(req, loop, set_cookies=False)) + + return resp + + def _get(self, key): + conns = self._conns.get(key) + if conns: + return conns.pop() + + return None, None + + def _release(self, resp, key, conn): + msg = resp.message + if msg.should_close: + conn[0].close() + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append(conn) + conn[1].unset_parser() + + if resp.cookies: + self.update_cookies(resp.cookies.items()) + + +class TransportWrapper: + + def __init__(self, release, key, transport, protocol, response): + self.release = release + self.key = key + self.transport = transport + self.protocol = protocol + self.response = response + + def close(self): + self.release(self.response, self.key, + (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py new file mode 100644 index 00000000..c3dd5872 --- /dev/null +++ b/tulip/http/websocket.py @@ -0,0 +1,233 @@ +"""WebSocket protocol versions 13 and 8.""" + +__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', + 'Message', 'WebSocketError', + 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] + +import base64 +import binascii +import collections +import hashlib +import struct +from tulip.http import errors + +# Frame opcodes defined in the spec. +OPCODE_CONTINUATION = 0x0 +MSG_TEXT = OPCODE_TEXT = 0x1 +MSG_BINARY = OPCODE_BINARY = 0x2 +MSG_CLOSE = OPCODE_CLOSE = 0x8 +MSG_PING = OPCODE_PING = 0x9 +MSG_PONG = OPCODE_PONG = 0xa + +WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +WS_HDRS = ('UPGRADE', 'CONNECTION', + 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') + +Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + +def WebSocketParser(): + out, buf = yield + + while True: + message = yield from parse_message(buf) + out.feed_data(message) + + if message.tp == MSG_CLOSE: + out.feed_eof() + break + + +def parse_frame(buf): + """Return the next frame from the socket.""" + # read header + data = yield from buf.read(2) + first_byte, second_byte = struct.unpack('!BB', data) + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xf + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise + if rsv1 or rsv2 or rsv3: + raise WebSocketError('Received frame with non-zero reserved bits') + + if opcode > 0x7 and fin == 0: + raise WebSocketError('Received fragmented control frame') + + if fin == 0 and opcode == OPCODE_CONTINUATION: + raise WebSocketError( + 'Received new fragment frame with non-zero opcode') + + has_mask = (second_byte >> 7) & 1 + length = (second_byte) & 0x7f + + # Control frames MUST have a payload length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + "Control frame payload cannot be larger than 125 bytes") + + # read payload + if length == 126: + data = yield from buf.read(2) + length = struct.unpack_from('!H', data)[0] + elif length > 126: + data = yield from buf.read(8) + length = struct.unpack_from('!Q', data)[0] + + if has_mask: + mask = yield from buf.read(4) + + if length: + payload = yield from buf.read(length) + else: + payload = b'' + + if has_mask: + payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) + + return fin, opcode, payload + + +def parse_message(buf): + fin, opcode, payload = yield from parse_frame(buf) + + if opcode == OPCODE_CLOSE: + if len(payload) >= 2: + close_code = struct.unpack('!H', payload[:2])[0] + close_message = payload[2:] + return Message(OPCODE_CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) + return Message(OPCODE_CLOSE, '', '') + + elif opcode == OPCODE_PING: + return Message(OPCODE_PING, '', '') + + elif opcode == OPCODE_PONG: + return Message(OPCODE_PONG, '', '') + + elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): + raise WebSocketError("Unexpected opcode={!r}".format(opcode)) + + # load text/binary + data = [payload] + + while not fin: + fin, _opcode, payload = yield from parse_frame(buf) + if _opcode != OPCODE_CONTINUATION: + raise WebSocketError( + 'The opcode in non-fin frame is expected ' + 'to be zero, got {!r}'.format(opcode)) + else: + data.append(payload) + + if opcode == OPCODE_TEXT: + return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') + else: + return Message(OPCODE_BINARY, b''.join(data), '') + + +class WebSocketWriter: + + def __init__(self, transport): + self.transport = transport + + def _send_frame(self, message, opcode): + """Send a frame over the websocket with message as its payload.""" + header = bytes([0x80 | opcode]) + msg_length = len(message) + + if msg_length < 126: + header += bytes([msg_length]) + elif msg_length < (1 << 16): + header += bytes([126]) + struct.pack('!H', msg_length) + else: + header += bytes([127]) + struct.pack('!Q', msg_length) + + self.transport.write(header + message) + + def pong(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PONG) + + def ping(self): + """Send pong message.""" + self._send_frame(b'', OPCODE_PING) + + def send(self, message, binary=False): + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode('utf-8') + if binary: + self._send_frame(message, OPCODE_BINARY) + else: + self._send_frame(message, OPCODE_TEXT) + + def close(self, code=1000, message=b''): + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode('utf-8') + self._send_frame( + struct.pack('!H%ds' % len(message), code, message), + opcode=OPCODE_CLOSE) + + +def do_handshake(method, headers, transport): + """Prepare WebSocket handshake. It return http response code, + response headers, websocket parser, websocket writer. It does not + perform any IO.""" + + # WebSocket accepts only GET + if method.upper() != 'GET': + raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) + + headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) + + if 'websocket' != headers.get('UPGRADE', '').lower().strip(): + raise errors.BadRequestException( + 'No WebSocket UPGRADE hdr: {}\n' + 'Can "Upgrade" only to "WebSocket".'.format( + headers.get('UPGRADE'))) + + if 'upgrade' not in headers.get('CONNECTION', '').lower(): + raise errors.BadRequestException( + 'No CONNECTION upgrade hdr: {}'.format( + headers.get('CONNECTION'))) + + # check supported version + version = headers.get('SEC-WEBSOCKET-VERSION') + if version not in ('13', '8', '7'): + raise errors.BadRequestException( + 'Unsupported version: {}'.format(version)) + + # check client handshake for validity + key = headers.get('SEC-WEBSOCKET-KEY') + try: + if not key or len(base64.b64decode(key)) != 16: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) + except binascii.Error: + raise errors.BadRequestException( + 'Handshake error: {!r}'.format(key)) from None + + # response code, headers, parser, writer + return (101, + (('UPGRADE', 'websocket'), + ('CONNECTION', 'upgrade'), + ('TRANSFER-ENCODING', 'chunked'), + ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), + WebSocketParser(), + WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py new file mode 100644 index 00000000..738e100f --- /dev/null +++ b/tulip/http/wsgi.py @@ -0,0 +1,227 @@ +"""wsgi server. + +TODO: + * proxy protocol + * x-forward security + * wsgi file support (os.sendfile) +""" + +__all__ = ['WSGIServerHttpProtocol'] + +import inspect +import io +import os +import sys +from urllib.parse import unquote, urlsplit + +import tulip +import tulip.http +from tulip.http import server + + +class WSGIServerHttpProtocol(server.ServerHttpProtocol): + """HTTP Server that implements the Python WSGI protocol. + + It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently + depends on 'readpayload' constructor parameter. If readpayload is set to + True, wsgi server reads all incoming data into BytesIO object and + sends it as 'wsgi.input' environ var. If readpayload is set to false + 'wsgi.input' is a StreamReader and application should read incoming + data with "yield from environ['wsgi.input'].read()". It defaults to False. + """ + + SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') + + def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): + super().__init__(*args, **kw) + + self.wsgi = app + self.is_ssl = is_ssl + self.readpayload = readpayload + + def create_wsgi_response(self, message): + return WsgiResponse(self.transport, message) + + def create_wsgi_environ(self, message, payload): + uri_parts = urlsplit(message.path) + url_scheme = 'https' if self.is_ssl else 'http' + + environ = { + 'wsgi.input': payload, + 'wsgi.errors': sys.stderr, + 'wsgi.version': (1, 0), + 'wsgi.async': True, + 'wsgi.multithread': False, + 'wsgi.multiprocess': False, + 'wsgi.run_once': False, + 'wsgi.file_wrapper': FileWrapper, + 'wsgi.url_scheme': url_scheme, + 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, + 'REQUEST_METHOD': message.method, + 'QUERY_STRING': uri_parts.query or '', + 'RAW_URI': message.path, + 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version + } + + # authors should be aware that REMOTE_HOST and REMOTE_ADDR + # may not qualify the remote addr: + # http://www.ietf.org/rfc/rfc3875 + forward = self.transport.get_extra_info('addr', '127.0.0.1') + script_name = self.SCRIPT_NAME + server = forward + + for hdr_name, hdr_value in message.headers: + if hdr_name == 'EXPECT': + # handle expect + if hdr_value.lower() == '100-continue': + self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') + elif hdr_name == 'HOST': + server = hdr_value + elif hdr_name == 'SCRIPT_NAME': + script_name = hdr_value + elif hdr_name == 'CONTENT-TYPE': + environ['CONTENT_TYPE'] = hdr_value + continue + elif hdr_name == 'CONTENT-LENGTH': + environ['CONTENT_LENGTH'] = hdr_value + continue + + key = 'HTTP_%s' % hdr_name.replace('-', '_') + if key in environ: + hdr_value = '%s,%s' % (environ[key], hdr_value) + + environ[key] = hdr_value + + if isinstance(forward, str): + # we only took the last one + # http://en.wikipedia.org/wiki/X-Forwarded-For + if ',' in forward: + forward = forward.rsplit(',', 1)[-1].strip() + + # find host and port on ipv6 address + if '[' in forward and ']' in forward: + host = forward.split(']')[0][1:].lower() + elif ':' in forward and forward.count(':') == 1: + host = forward.split(':')[0].lower() + else: + host = forward + + forward = forward.split(']')[-1] + if ':' in forward and forward.count(':') == 1: + port = forward.split(':', 1)[1] + else: + port = 80 + + remote = (host, port) + else: + remote = forward + + environ['REMOTE_ADDR'] = remote[0] + environ['REMOTE_PORT'] = str(remote[1]) + + if isinstance(server, str): + server = server.split(':') + if len(server) == 1: + server.append('80' if url_scheme == 'http' else '443') + + environ['SERVER_NAME'] = server[0] + environ['SERVER_PORT'] = str(server[1]) + + path_info = uri_parts.path + if script_name: + path_info = path_info.split(script_name, 1)[-1] + + environ['PATH_INFO'] = unquote(path_info) + environ['SCRIPT_NAME'] = script_name + + environ['tulip.reader'] = self.stream + environ['tulip.writer'] = self.transport + + return environ + + @tulip.coroutine + def handle_request(self, message, payload): + """Handle a single HTTP request""" + + if self.readpayload: + wsgiinput = io.BytesIO() + chunk = yield from payload.read() + while chunk: + wsgiinput.write(chunk) + chunk = yield from payload.read() + wsgiinput.seek(0) + payload = wsgiinput + + environ = self.create_wsgi_environ(message, payload) + response = self.create_wsgi_response(message) + + riter = self.wsgi(environ, response.start_response) + if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): + riter = yield from riter + + resp = response.response + try: + for item in riter: + if isinstance(item, tulip.Future): + item = yield from item + resp.write(item) + + resp.write_eof() + finally: + if hasattr(riter, 'close'): + riter.close() + + if resp.keep_alive(): + self.keep_alive(True) + + +class FileWrapper: + """Custom file wrapper.""" + + def __init__(self, fobj, chunk_size=8192): + self.fobj = fobj + self.chunk_size = chunk_size + if hasattr(fobj, 'close'): + self.close = fobj.close + + def __iter__(self): + return self + + def __next__(self): + data = self.fobj.read(self.chunk_size) + if data: + return data + raise StopIteration + + +class WsgiResponse: + """Implementation of start_response() callable as specified by PEP 3333""" + + status = None + + def __init__(self, transport, message): + self.transport = transport + self.message = message + + def start_response(self, status, headers, exc_info=None): + if exc_info: + try: + if self.status: + raise exc_info[1] + finally: + exc_info = None + + status_code = int(status.split(' ', 1)[0]) + + self.status = status + resp = self.response = tulip.http.Response( + self.transport, status_code, + self.message.version, self.message.should_close) + resp.add_headers(*headers) + + # send headers immediately for websocket connection + if status_code == 101 and resp.upgrade and resp.websocket: + resp.send_headers() + else: + resp._send_headers = True + return self.response.write diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..87937ec0 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,403 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.cancelled(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +# TODO: Why not call this Event? +class EventWaiter: + """A EventWaiter implementation, our equivalent to threading.Event + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..b918fe54 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,6 @@ +"""Tulip logging configuration""" + +import logging + + +tulip_log = logging.getLogger("tulip") diff --git a/tulip/parsers.py b/tulip/parsers.py new file mode 100644 index 00000000..43ddc2e9 --- /dev/null +++ b/tulip/parsers.py @@ -0,0 +1,399 @@ +"""Parser is a generator function. + +Parser receives data with generator's send() method and sends data to +destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects +as a parameters of the first send() call, all subsequent send() calls should +send bytes objects. Parser sends parsed 'term' to desitnation buffer with +DataBuffer.feed_data() method. DataBuffer object should implement two methods. +feed_data() - parser uses this method to send parsed protocol data. +feed_eof() - parser uses this method for indication of end of parsing stream. +To indicate end of incoming data stream EofStream exception should be sent +into parser. Parser could throw exceptions. + +There are three stages: + + * Data flow chain: + + 1. Application creates StreamBuffer object for storing incoming data. + 2. StreamBuffer creates ParserBuffer as internal data buffer. + 3. Application create parser and set it into stream buffer: + + parser = http_request_parser() + data_buffer = stream.set_parser(parser) + + 3. At this stage StreamBuffer creates DataBuffer object and passes it + and internal buffer into parser with first send() call. + + def set_parser(self, parser): + next(parser) + data_buffer = DataBuffer() + parser.send((data_buffer, self._buffer)) + return data_buffer + + 4. Application waits data on data_buffer.read() + + while True: + msg = yield form data_buffer.read() + ... + + * Data flow: + + 1. Tulip's transport reads data from socket and sends data to protocol + with data_received() call. + 2. Protocol sends data to StreamBuffer with feed_data() call. + 3. StreamBuffer sends data into parser with generator's send() method. + 4. Parser processes incoming data and sends parsed data + to DataBuffer with feed_data() + 4. Application received parsed data from DataBuffer.read() + + * Eof: + + 1. StreamBuffer recevies eof with feed_eof() call. + 2. StreamBuffer throws EofStream exception into parser. + 3. Then it unsets parser. + +_SocketSocketTransport -> + -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" + +""" +__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', + 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] + +import collections + +from . import tasks +from . import futures +from . import protocols + + +class EofStream(Exception): + """eof stream indication.""" + + +class StreamBuffer: + """StreamBuffer manages incoming bytes stream and protocol parsers. + + StreamBuffer uses ParserBuffer as internal buffer. + + set_parser() sets current parser, it creates DataBuffer object + and sends ParserBuffer and DataBuffer into parser generator. + + unset_parser() sends EofStream into parser and then removes it. + """ + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = ParserBuffer() + self._eof = False + self._parser = None + self._parser_buffer = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + if self._parser_buffer is not None: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + + def feed_data(self, data): + """send data to current parser or store in buffer.""" + if not data: + return + + if self._parser: + try: + self._parser.send(data) + except StopIteration: + self._parser = None + self._parser_buffer = None + except Exception as exc: + self._parser_buffer.set_exception(exc) + self._parser = None + self._parser_buffer = None + else: + self._buffer.feed_data(data) + + def feed_eof(self): + """send eof to all parsers, recursively.""" + if self._parser: + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + + self._parser = None + self._parser_buffer = None + + self._eof = True + + def set_parser(self, p): + """set parser to stream. return parser's DataStream.""" + if self._parser: + self.unset_parser() + + out = DataBuffer(loop=self._loop) + if self._exception: + out.set_exception(self._exception) + return out + + # init generator + next(p) + try: + # initialize parser with data and parser buffers + p.send((out, self._buffer)) + except StopIteration: + pass + except Exception as exc: + out.set_exception(exc) + else: + # parser still require more data + self._parser = p + self._parser_buffer = out + + if self._eof: + self.unset_parser() + + return out + + def unset_parser(self): + """unset parser, send eof to the parser and then remove it.""" + if self._parser is None: + return + + try: + self._parser.throw(EofStream()) + except StopIteration: + pass + except EofStream: + self._parser_buffer.feed_eof() + except Exception as exc: + self._parser_buffer.set_exception(exc) + finally: + self._parser = None + self._parser_buffer = None + + +class StreamProtocol(StreamBuffer, protocols.Protocol): + """Tulip's stream protocol based on StreamBuffer""" + + transport = None + + data_received = StreamBuffer.feed_data + + eof_received = StreamBuffer.feed_eof + + def connection_made(self, transport): + self.transport = transport + + def connection_lost(self, exc): + self.transport = None + + if exc is not None: + self.set_exception(exc) + else: + self.feed_eof() + + +class DataBuffer: + """DataBuffer is a destination for parsed data.""" + + def __init__(self, *, loop=None): + self._loop = loop + self._buffer = collections.deque() + self._eof = False + self._waiter = None + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.done(): + waiter.set_exception(exc) + + def feed_data(self, data): + self._buffer.append(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(True) + + def feed_eof(self): + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + waiter.set_result(False) + + @tasks.coroutine + def read(self): + if self._exception is not None: + raise self._exception + + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) + yield from self._waiter + + if self._buffer: + return self._buffer.popleft() + else: + return None + + +class ParserBuffer(bytearray): + """ParserBuffer is a bytearray extension. + + ParserBuffer provides helper methods for parsers. + """ + + def __init__(self, *args): + super().__init__(*args) + + self.offset = 0 + self.size = 0 + self._writer = self._feed_data() + next(self._writer) + + def _shrink(self): + if self.offset: + del self[:self.offset] + self.offset = 0 + self.size = len(self) + + def _feed_data(self): + while True: + chunk = yield + if chunk: + chunk_len = len(chunk) + self.size += chunk_len + self.extend(chunk) + + # shrink buffer + if (self.offset and len(self) > 5120): + self._shrink() + + def feed_data(self, data): + self._writer.send(data) + + def read(self, size): + """read() reads specified amount of bytes.""" + + while True: + if self.size >= size: + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + return self[start:end] + + self._writer.send((yield)) + + def readsome(self, size=None): + """reads size of less amount of bytes.""" + + while True: + if self.size > 0: + if size is None or self.size < size: + size = self.size + + start, end = self.offset, self.offset + size + self.offset = end + self.size = self.size - size + + return self[start:end] + + self._writer.send((yield)) + + def readuntil(self, stop, limit=None, exc=ValueError): + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + pos = self.find(stop, self.offset) + if pos >= 0: + end = pos + stop_len + size = end - self.offset + if limit is not None and size > limit: + raise exc('Line is too long.') + + start, self.offset = self.offset, end + self.size = self.size - size + + return self[start:end] + else: + if limit is not None and self.size > limit: + raise exc('Line is too long.') + + self._writer.send((yield)) + + def skip(self, size): + """skip() skips specified amount of bytes.""" + + while self.size < size: + self._writer.send((yield)) + + self.size -= size + self.offset += size + + def skipuntil(self, stop): + """skipuntil() reads until `stop` bytes sequence.""" + assert isinstance(stop, bytes) and stop, \ + 'bytes is required: {!r}'.format(stop) + + stop_len = len(stop) + + while True: + stop_line = self.find(stop, self.offset) + if stop_line >= 0: + end = stop_line + stop_len + self.size = self.size - (end - self.offset) + self.offset = end + return + else: + self.size = 0 + self.offset = len(self) - 1 + + self._writer.send((yield)) + + def __bytes__(self): + return bytes(self[self.offset:]) + + +def lines_parser(limit=2**16, exc=ValueError): + """Lines parser. + + lines parser splits a bytes stream into a chunks of data, each chunk ends + with \n symbol.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) + + +def chunks_parser(size=8196): + """Chunks parser. + + chunks parser splits a bytes stream into a specified + size chunks of data.""" + out, buf = yield + + while True: + out.feed_data((yield from buf.read(size))) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..cda87918 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,288 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import tulip_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._writing_disabled = False + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, waiter, extra) + self._loop.call_soon(self._loop_reading) + + def _loop_reading(self, fut=None): + data = None + + try: + if fut is not None: + assert fut is self._read_fut + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + try: + self._protocol.eof_received() + finally: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None and not self._writing_disabled: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + if not self._writing_disabled: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except OSError as exc: + self._fatal_error(exc) + + # TODO: write_eof(), can_write_eof(). + + def abort(self): + self._force_close(None) + + def pause_writing(self): + self._writing_disabled = True + + def resume_writing(self): + self._writing_disabled = False + if self._buffer and self._write_fut is None: + self._loop_writing() + + def discard_output(self): + if self._buffer: + self._buffer = [] + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None): + assert not ssl, 'IocpEventLoop imcompatible with SSL.' + + def loop(f=None): + try: + if f: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, extra={'addr': addr}) + f = self._proactor.accept(sock) + except OSError: + sock.close() + tulip_log.exception('Accept failed') + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def stop_serving(self, sock): + self._proactor.stop_serving(sock) + sock.close() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..d76f25a2 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,100 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..b658e67e --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,284 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.EventWaiter(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..82d22bb6 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,676 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import tulip_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + tulip_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, extra=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, server_side, extra) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl) + + def _accept_connection(self, protocol_factory, sock, ssl=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + tulip_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'addr': addr}) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'addr': addr}) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + def __init__(self, loop, sock, protocol, extra): + super().__init__(extra) + self._extra['socket'] = sock + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._writing = True + self._closing = False # Set when close() called. + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + tulip_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_writer(self._sock_fd) + self._loop.remove_reader(self._sock_fd) + self._buffer.clear() + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = self._sock.recv(16*1024) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except OSError as exc: + self._fatal_error(exc) + return + + if n == len(data): + return + elif n: + data = data[n:] + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return # transmission off + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + return + elif n: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, extra=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, + do_handshake_on_connect=False) + + super().__init__(loop, sslsock, protocol, extra) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing: + try: + data = self._sock.recv(8192) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer = [] + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + # TODO: write_eof(), can_write_eof(). + + +class _SelectorDatagramTransport(_SelectorTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._buffer = collections.deque() + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..b81b1dbe --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,410 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class BaseSelector(metaclass=ABCMeta): + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # this maps file objects to keys - for fast (un)registering + self._fileobj_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + self._fileobj_to_key[fileobj] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + """ + try: + key = self._fileobj_to_key.pop(fileobj) + del self._fd_to_key[key.fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + self._fileobj_to_key.clear() + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + try: + return self._fileobj_to_key[fileobj] + except KeyError: + raise KeyError("{} is not registered".format(fileobj)) from None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(int(1000 * timeout), 0) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..3203b7d6 --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,211 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + self.limit = limit # Max line length. (Security feature.) + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + waiter.set_exception(exc) + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.done(): + waiter.set_result(False) + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + line = b''.join(parts) + self.byte_count -= parts_size + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + yield from self.waiter + + return (yield from self.read(n)) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..94441bea --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,377 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', + ] + +import collections +import concurrent.futures +import functools +import inspect + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + def __init__(self, coro, *, loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(loop=loop) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + assert self._fut_waiter is None + exc = futures.CancelledError() + value = None + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(return_when=FIRST_COMPLETED). + + The fs argument must be a set of Futures. + The timeout argument is like for wait(). + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, waiter.cancel) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + waiter.cancel() + + for f in fs: + f.add_done_callback(_on_completion) + try: + yield from waiter + except futures.CancelledError: + pass + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks + are done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily the + order of results arrival). If one of the tasks is cancelled, the + returned future is immediately cancelled too. If *result_exception* + is True, exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first raised + exception will be immediately propagated to the returned future. + """ + children = [async(fut, loop=loop) for fut in coros_or_futures] + n = len(children) + if n == 0: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + if loop is None: + loop = children[0]._loop + for fut in children: + if fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + outer = futures.Future(loop=loop) + nfinished = 0 + results = [None] * n + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Be sure to mark the result retrieved + fut.exception() + return + if fut._state == futures._CANCELLED: + outer.cancel() + return + elif fut._exception is not None: + if not return_exceptions: + outer.set_exception(fut.exception()) + return + res = fut.exception() + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == n: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..b4af0c89 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,443 @@ +"""Utilities shared by tests.""" + +import cgi +import collections +import contextlib +import gc +import email.parser +import http.server +import json +import logging +import io +import unittest.mock +import os +import re +import socket +import sys +import threading +import traceback +import unittest +import unittest.mock +import urllib.parse +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +import tulip.http +from tulip.http import client +from tulip import base_events +from tulip import events + +from tulip import base_events +from tulip import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def run_briefly(loop): + @tulip.coroutine + def once(): + pass + t = tulip.Task(once(), loop=loop) + loop.run_until_complete(t) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(loop, *, host='127.0.0.1', port=0, + use_ssl=False, router=None): + properties = {} + transports = [] + + class HttpServer: + + def __init__(self, host, port): + self.host = host + self.port = port + self.address = (host, port) + self._url = '{}://{}:{}'.format( + 'https' if use_ssl else 'http', host, port) + + def __getitem__(self, key): + return properties[key] + + def __setitem__(self, key, value): + properties[key] = value + + def url(self, *suffix): + return urllib.parse.urljoin( + self._url, '/'.join(str(s) for s in suffix)) + + class TestHttpServer(tulip.http.ServerHttpProtocol): + + def connection_made(self, transport): + transports.append(transport) + super().connection_made(transport) + + def handle_request(self, message, payload): + if properties.get('close', False): + return + + if properties.get('noresponse', False): + yield from tulip.sleep(99999) + + if router is not None: + body = bytearray() + chunk = yield from payload.read() + while chunk: + body.extend(chunk) + chunk = yield from payload.read() + + rob = router( + self, properties, + self.transport, message, bytes(body)) + rob.dispatch() + + else: + response = tulip.http.Response( + self.transport, 200, message.version) + + text = b'Test message' + response.add_header('Content-type', 'text/plain') + response.add_header('Content-length', str(len(text))) + response.send_headers() + response.write(text) + response.write_eof() + + if use_ssl: + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain(certfile, keyfile) + else: + sslcontext = None + + def run(loop, fut): + thread_loop = tulip.new_event_loop() + tulip.set_event_loop(thread_loop) + + socks = thread_loop.run_until_complete( + thread_loop.start_serving( + lambda: TestHttpServer(keep_alive=0.5), + host, port, ssl=sslcontext)) + + waiter = tulip.Future(loop=thread_loop) + loop.call_soon_threadsafe( + fut.set_result, (thread_loop, waiter, socks[0].getsockname())) + + try: + thread_loop.run_until_complete(waiter) + finally: + # call pending connection_made if present + run_briefly(thread_loop) + + # close opened trnsports + for tr in transports: + tr.close() + + run_briefly(thread_loop) # call close callbacks + + for s in socks: + thread_loop.stop_serving(s) + + thread_loop.stop() + thread_loop.close() + gc.collect() + + fut = tulip.Future(loop=loop) + server_thread = threading.Thread(target=run, args=(loop, fut)) + server_thread.start() + + thread_loop, waiter, addr = loop.run_until_complete(fut) + try: + yield HttpServer(*addr) + finally: + thread_loop.call_soon_threadsafe(waiter.set_result, None) + server_thread.join() + + +class Router: + + _response_version = "1.1" + _responses = http.server.BaseHTTPRequestHandler.responses + + def __init__(self, srv, props, transport, message, payload): + # headers + self._headers = http.client.HTTPMessage() + for hdr, val in message.headers: + self._headers.add_header(hdr, val) + + self._srv = srv + self._props = props + self._transport = transport + self._method = message.method + self._uri = message.path + self._version = message.version + self._compression = message.compression + self._body = payload + + url = urllib.parse.urlsplit(self._uri) + self._path = url.path + self._query = url.query + + @staticmethod + def define(rmatch): + def wrapper(fn): + f_locals = sys._getframe(1).f_locals + mapping = f_locals.setdefault('_mapping', []) + mapping.append((re.compile(rmatch), fn.__name__)) + return fn + + return wrapper + + def dispatch(self): # pragma: no cover + for route, fn in self._mapping: + match = route.match(self._path) + if match is not None: + try: + return getattr(self, fn)(match) + except Exception: + out = io.StringIO() + traceback.print_exc(file=out) + self._response(500, out.getvalue()) + + return + + return self._response(self._start_response(404)) + + def _start_response(self, code): + return tulip.http.Response(self._transport, code) + + def _response(self, response, body=None, headers=None, chunked=False): + r_headers = {} + for key, val in self._headers.items(): + key = '-'.join(p.capitalize() for p in key.split('-')) + r_headers[key] = val + + encoding = self._headers.get('content-encoding', '').lower() + if 'gzip' in encoding: # pragma: no cover + cmod = 'gzip' + elif 'deflate' in encoding: + cmod = 'deflate' + else: + cmod = '' + + resp = { + 'method': self._method, + 'version': '%s.%s' % self._version, + 'path': self._uri, + 'headers': r_headers, + 'origin': self._transport.get_extra_info('addr', ' ')[0], + 'query': self._query, + 'form': {}, + 'compression': cmod, + 'multipart-data': [] + } + if body: # pragma: no cover + resp['content'] = body + + ct = self._headers.get('content-type', '').lower() + + # application/x-www-form-urlencoded + if ct == 'application/x-www-form-urlencoded': + resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) + + # multipart/form-data + elif ct.startswith('multipart/form-data'): # pragma: no cover + out = io.BytesIO() + for key, val in self._headers.items(): + out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) + + out.write(b'\r\n') + out.write(self._body) + out.write(b'\r\n') + out.seek(0) + + message = email.parser.BytesParser().parse(out) + if message.is_multipart(): + for msg in message.get_payload(): + if msg.is_multipart(): + logging.warn('multipart msg is not expected') + else: + key, params = cgi.parse_header( + msg.get('content-disposition', '')) + params['data'] = msg.get_payload() + params['content-type'] = msg.get_content_type() + resp['multipart-data'].append(params) + + body = json.dumps(resp, indent=4, sort_keys=True) + + # default headers + hdrs = [('Connection', 'close'), + ('Content-Type', 'application/json')] + if chunked: + hdrs.append(('Transfer-Encoding', 'chunked')) + else: + hdrs.append(('Content-Length', str(len(body)))) + + # extra headers + if headers: + hdrs.extend(headers.items()) + + if chunked: + response.force_chunked() + + # headers + response.add_headers(*hdrs) + response.send_headers() + + # write payload + response.write(client.str_to_bytes(body)) + response.write_eof() + + # keep-alive + if response.keep_alive(): + self._srv.keep_alive(True) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + self._check_on_close = False + def gen(): + yield + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..56425aa9 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,201 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def pause_writing(self): + """Pause transmission on the transport. + + Subsequent writes are deferred until resume_writing() is called. + """ + raise NotImplementedError + + def resume_writing(self): + """Resume transmission on the transport. """ + raise NotImplementedError + + def discard_output(self): + """Discard any buffered data awaiting transmission on the transport.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.start_serving().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..75131851 --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,555 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import tulip_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + tulip_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._writing = True + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + tulip_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer and self._writing: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + if not self._writing: + return + + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + tulip_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + def pause_writing(self): + if self._writing: + if self._buffer: + self._loop.remove_writer(self._fileno) + self._writing = False + + def resume_writing(self): + if not self._writing: + if self._buffer: + self._loop.add_writer(self._fileno, self._write_ready) + self._writing = True + + def discard_output(self): + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..629b3475 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,203 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import windows_utils +from . import _overlapped +from .log import tulip_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + +class IocpProactor: + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(handle, nbytes) + return self._register(ov, conn, ov.getresult) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + handle = getattr(conn, 'handle', None) + if handle is None: + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(handle, buf) + return self._register(ov, conn, ov.getresult) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(): + ov.getresult() + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, + buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # the socket needs to be locally bound before we call ConnectEx() + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # probably already locally bound; check using getsockname() + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(): + ov.getresult() + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, + 0) + return conn + + return self._register(ov, conn, finish_connect) + + def _register_with_iocp(self, obj): + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + + def _register(self, ov, obj, callback): + f = _OverlappedFuture(ov, loop=self._loop) + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + address = status[3] + f, ov, obj, callback = self._cache.pop(address) + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback() + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop.stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + for (f, ov, obj, callback) in self._cache.values(): + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + tulip_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py new file mode 100644 index 00000000..bf85f31e --- /dev/null +++ b/tulip/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter=itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) From 8cf125d01b56102839005da95f4a00b862a66922 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 11:36:08 -0700 Subject: [PATCH 0607/1502] Improved handling of various cases in Task.cancel(). --- tulip/tasks.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index 5b3311b6..d2eb9cb2 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -77,11 +77,10 @@ def cancel(self): if self.done(): return False if self._fut_waiter is not None: - # XXX: What to do if self._fut_waiter.cancel() returns False? - # If that's anready cancelled future everything is ok. - # What are other possible scenarios? - waiter, self._fut_waiter = self._fut_waiter, None - if waiter.cancel(): + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. return True # It must be the case that self._step is already scheduled. self._must_cancel = True @@ -90,10 +89,8 @@ def cancel(self): def _step(self, value=None, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - if self._must_cancel: - assert self._fut_waiter is None + if self._must_cancel and not isinstance(exc, futures.CancelledError): exc = futures.CancelledError() - value = None coro = self._coro self._fut_waiter = None # Call either coro.throw(exc) or coro.send(value). @@ -149,6 +146,7 @@ def _wakeup(self, future): try: value = future.result() except Exception as exc: + # This may also be a cancellation. self._step(None, exc) else: self._step(value, None) From e1e7ad9b841f3dc01d014e94c74d73fd0221d80b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 15:19:50 -0700 Subject: [PATCH 0608/1502] Replace two incorrect assertTrue(x, y) calls with assertEqual(x, y). --- tests/http_server_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index a9d4d5ed..5c7a97a0 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -226,7 +226,7 @@ def test_handle_400(self): self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) - self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertEqual(400, srv.handle_error.call_args[0][0]) self.assertTrue(transport.close.called) def test_handle_500(self): @@ -244,7 +244,7 @@ def test_handle_500(self): self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) - self.assertTrue(500, srv.handle_error.call_args[0][0]) + self.assertEqual(500, srv.handle_error.call_args[0][0]) def test_handle_error_no_handle_task(self): transport = unittest.mock.Mock() From 9fb7c0467a067cc643810f357ac6d6197331a29b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 15:21:12 -0700 Subject: [PATCH 0609/1502] Fix one incorrect assertTrue(), and clarify a few things. --- tests/http_protocol_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index ec3aaf58..43201a8f 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -78,8 +78,8 @@ def test_add_headers_length(self): msg = protocol.Response(self.transport, 200) self.assertIsNone(msg.length) - msg.add_headers(('content-length', '200')) - self.assertEqual(200, msg.length) + msg.add_headers(('content-length', '42')) + self.assertEqual(42, msg.length) def test_add_headers_upgrade(self): msg = protocol.Response(self.transport, 200) @@ -204,14 +204,14 @@ def test_send_headers_nomore_add(self): def test_prepare_length(self): msg = protocol.Response(self.transport, 200) - length = msg._write_length_payload = unittest.mock.Mock() - length.return_value = iter([1, 2, 3]) + w_l_p = msg._write_length_payload = unittest.mock.Mock() + w_l_p.return_value = iter([1, 2, 3]) - msg.add_headers(('content-length', '200')) + msg.add_headers(('content-length', '42')) msg.send_headers() - self.assertTrue(length.called) - self.assertTrue((200,), length.call_args[0]) + self.assertTrue(w_l_p.called) + self.assertEqual((42,), w_l_p.call_args[0]) def test_prepare_chunked_force(self): msg = protocol.Response(self.transport, 200) @@ -220,7 +220,7 @@ def test_prepare_chunked_force(self): chunked = msg._write_chunked_payload = unittest.mock.Mock() chunked.return_value = iter([1, 2, 3]) - msg.add_headers(('content-length', '200')) + msg.add_headers(('content-length', '42')) msg.send_headers() self.assertTrue(chunked.called) From 58cedd2a9b74f8fbcd292193e8ce5f22f33204f6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 15:22:07 -0700 Subject: [PATCH 0610/1502] Test fixes: wrong assertTrue(), top-level "yield from", test name typos. --- tests/base_events_test.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index e27b3ab9..f137830a 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -303,7 +303,7 @@ def tearDown(self): self.loop.close() @unittest.mock.patch('tulip.base_events.socket') - def test_create_connection_mutiple_errors(self, m_socket): + def test_create_connection_multiple_errors(self, m_socket): class MyProto(protocols.Protocol): pass @@ -329,11 +329,11 @@ def _socket(*args, **kw): self.loop.getaddrinfo = getaddrinfo_task - task = tasks.Task( - self.loop.create_connection(MyProto, 'example.com', 80)) - yield from tasks.wait(task) - exc = task.exception() - self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') def test_create_connection_host_port_sock(self): coro = self.loop.create_connection( @@ -393,7 +393,7 @@ def getaddrinfo_task(*args, **kwds): self.loop.run_until_complete(coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_connection_mutiple_errors_local_addr(self, m_socket): + def test_create_connection_multiple_errors_local_addr(self, m_socket): def bind(addr): if addr[0] == '0.0.0.1': @@ -421,7 +421,7 @@ def getaddrinfo_task(*args, **kwds): with self.assertRaises(OSError) as cm: self.loop.run_until_complete(coro) - self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) self.assertTrue(m_socket.socket.return_value.close.called) def test_create_connection_no_local_addr(self): From 5c3d353457eda61f5515fab96738b6a93fb29218 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 15:23:07 -0700 Subject: [PATCH 0611/1502] Various test fixes I committed first on the newcancel branch. --- tests/base_events_test.py | 16 ++++++++-------- tests/http_protocol_test.py | 16 ++++++++-------- tests/http_server_test.py | 4 ++-- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index deb82af7..b423f329 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -303,7 +303,7 @@ def tearDown(self): self.loop.close() @unittest.mock.patch('tulip.base_events.socket') - def test_create_connection_mutiple_errors(self, m_socket): + def test_create_connection_multiple_errors(self, m_socket): class MyProto(protocols.Protocol): pass @@ -329,11 +329,11 @@ def _socket(*args, **kw): self.loop.getaddrinfo = getaddrinfo_task - task = tasks.Task( - self.loop.create_connection(MyProto, 'example.com', 80)) - yield from tasks.wait(task) - exc = task.exception() - self.assertEqual("Multiple exceptions: err1, err2", str(exc)) + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') def test_create_connection_host_port_sock(self): coro = self.loop.create_connection( @@ -393,7 +393,7 @@ def getaddrinfo_task(*args, **kwds): self.loop.run_until_complete(coro) @unittest.mock.patch('tulip.base_events.socket') - def test_create_connection_mutiple_errors_local_addr(self, m_socket): + def test_create_connection_multiple_errors_local_addr(self, m_socket): def bind(addr): if addr[0] == '0.0.0.1': @@ -421,7 +421,7 @@ def getaddrinfo_task(*args, **kwds): with self.assertRaises(OSError) as cm: self.loop.run_until_complete(coro) - self.assertTrue(str(cm.exception), 'Multiple exceptions: ') + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) self.assertTrue(m_socket.socket.return_value.close.called) def test_create_connection_no_local_addr(self): diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py index ec3aaf58..43201a8f 100644 --- a/tests/http_protocol_test.py +++ b/tests/http_protocol_test.py @@ -78,8 +78,8 @@ def test_add_headers_length(self): msg = protocol.Response(self.transport, 200) self.assertIsNone(msg.length) - msg.add_headers(('content-length', '200')) - self.assertEqual(200, msg.length) + msg.add_headers(('content-length', '42')) + self.assertEqual(42, msg.length) def test_add_headers_upgrade(self): msg = protocol.Response(self.transport, 200) @@ -204,14 +204,14 @@ def test_send_headers_nomore_add(self): def test_prepare_length(self): msg = protocol.Response(self.transport, 200) - length = msg._write_length_payload = unittest.mock.Mock() - length.return_value = iter([1, 2, 3]) + w_l_p = msg._write_length_payload = unittest.mock.Mock() + w_l_p.return_value = iter([1, 2, 3]) - msg.add_headers(('content-length', '200')) + msg.add_headers(('content-length', '42')) msg.send_headers() - self.assertTrue(length.called) - self.assertTrue((200,), length.call_args[0]) + self.assertTrue(w_l_p.called) + self.assertEqual((42,), w_l_p.call_args[0]) def test_prepare_chunked_force(self): msg = protocol.Response(self.transport, 200) @@ -220,7 +220,7 @@ def test_prepare_chunked_force(self): chunked = msg._write_chunked_payload = unittest.mock.Mock() chunked.return_value = iter([1, 2, 3]) - msg.add_headers(('content-length', '200')) + msg.add_headers(('content-length', '42')) msg.send_headers() self.assertTrue(chunked.called) diff --git a/tests/http_server_test.py b/tests/http_server_test.py index 862779b9..a2c8542a 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -225,7 +225,7 @@ def test_handle_400(self): self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) - self.assertTrue(400, srv.handle_error.call_args[0][0]) + self.assertEqual(400, srv.handle_error.call_args[0][0]) self.assertTrue(transport.close.called) def test_handle_500(self): @@ -243,7 +243,7 @@ def test_handle_500(self): self.loop.run_until_complete(srv._request_handler) self.assertTrue(srv.handle_error.called) - self.assertTrue(500, srv.handle_error.call_args[0][0]) + self.assertEqual(500, srv.handle_error.call_args[0][0]) def test_handle_error_no_handle_task(self): transport = unittest.mock.Mock() From 8f7d69ce7477ab987b0868a33f8d9f8d9d3782ef Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 17:11:46 -0700 Subject: [PATCH 0612/1502] More cancellation tests. And catch more cases. --- tests/tasks_test.py | 114 ++++++++++++++++++++++++++++---------------- tulip/tasks.py | 9 +++- 2 files changed, 80 insertions(+), 43 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 4758024c..e0611d23 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -233,47 +233,79 @@ def task(): self.assertTrue(f.cancelled()) self.assertTrue(t.cancelled()) -## def test_cancel_done_future(self): -## fut1 = futures.Future(loop=self.loop) -## fut2 = futures.Future(loop=self.loop) -## fut3 = futures.Future(loop=self.loop) - -## @tasks.coroutine -## def task(): -## yield from fut1 -## try: -## yield from fut2 -## except futures.CancelledError: -## pass -## yield from fut3 - -## t = tasks.Task(task(), loop=self.loop) -## test_utils.run_briefly(self.loop) -## fut1.set_result(None) -## t.cancel() -## test_utils.run_once(self.loop) # process fut1 result, delay cancel -## self.assertFalse(t.done()) -## test_utils.run_once(self.loop) # cancel fut2, but coro still alive -## self.assertFalse(t.done()) -## test_utils.run_briefly(self.loop) # cancel fut3 -## self.assertTrue(t.done()) - -## self.assertEqual(fut1.result(), None) -## self.assertTrue(fut2.cancelled()) -## self.assertTrue(fut3.cancelled()) -## self.assertTrue(t.cancelled()) - -## def test_cancel_in_coro(self): -## @tasks.coroutine -## def task(): -## t.cancel() -## return 12 - -## t = tasks.Task(task(), loop=self.loop) -## self.assertRaises( -## futures.CancelledError, self.loop.run_until_complete, t) -## self.assertTrue(t.done()) -## self.assertFalse(t.cancel()) + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + res = yield from fut3 + return res + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) def test_stop_while_run_in_complete(self): diff --git a/tulip/tasks.py b/tulip/tasks.py index d2eb9cb2..ca2ef0c0 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -89,8 +89,10 @@ def cancel(self): def _step(self, value=None, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - if self._must_cancel and not isinstance(exc, futures.CancelledError): - exc = futures.CancelledError() + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False coro = self._coro self._fut_waiter = None # Call either coro.throw(exc) or coro.send(value). @@ -117,6 +119,9 @@ def _step(self, value=None, exc=None): result._blocking = False result.add_done_callback(self._wakeup) self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False else: self._loop.call_soon( self._step, None, From 1cd0ff1f7e89bf432487722998c40cca43ba6602 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 17:29:29 -0700 Subject: [PATCH 0613/1502] Import coverage fixes by Andrew Svetlov to default branch. --- Makefile | 3 +- runtests.py | 102 +++++++++++++++++++++++++--------------------------- 2 files changed, 50 insertions(+), 55 deletions(-) diff --git a/Makefile b/Makefile index 6064fc63..ed3caf21 100644 --- a/Makefile +++ b/Makefile @@ -16,8 +16,7 @@ testloop: # See runtests.py for coverage installation instructions. cov coverage: - $(PYTHON) runtests.py --coverage tulip -v $(VERBOSE) $(FLAGS) - echo "open file://`pwd`/htmlcov/index.html" + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) check: $(PYTHON) check.py diff --git a/runtests.py b/runtests.py index 484bff09..725bfa2e 100644 --- a/runtests.py +++ b/runtests.py @@ -27,7 +27,12 @@ import sys import subprocess import unittest +import textwrap import importlib.machinery +try: + import coverage +except ImportError: + coverage = None from unittest.signals import installHandler @@ -57,8 +62,8 @@ '--tests', action="store", dest='testsdir', default='tests', help='tests directory') ARGS.add_argument( - '--coverage', action="store", dest='coverage', nargs='?', const='', - help='enable coverage report and provide python files directory') + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') ARGS.add_argument( 'pattern', action="store", nargs="*", help='optional regex patterns to match test ids (default all tests)') @@ -175,6 +180,22 @@ def run(self, test): def runtests(): args = ARGS.parse_args() + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + testsdir = os.path.abspath(args.testsdir) if not os.path.isdir(testsdir): print("Tests directory is not found: {}\n".format(testsdir)) @@ -193,6 +214,12 @@ def runtests(): findleaks = args.findleaks runner_factory = TestRunner if findleaks else unittest.TextTestRunner + if args.coverage: + cov = coverage.coverage(branch=True, + source=['tulip'], + ) + cov.start() + tests = load_tests(args.testsdir, includes, excludes) logger = logging.getLogger() if v == 0: @@ -207,59 +234,28 @@ def runtests(): logger.setLevel(logging.DEBUG) if catchbreak: installHandler() - if args.forever: - while True: + try: + if args.forever: + while True: + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: result = runner_factory(verbosity=v, failfast=failfast).run(tests) - if not result.wasSuccessful(): - sys.exit(1) - else: - result = runner_factory(verbosity=v, - failfast=failfast).run(tests) - sys.exit(not result.wasSuccessful()) - - -def runcoverage(sdir, args): - """ - To install coverage3 for Python 3, you need: - - Setuptools (https://pypi.python.org/pypi/setuptools) - - What worked for me: - - download bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - * curl -O \ - https://bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py - - python3 ez_setup.py - - python3 -m easy_install coverage - """ - try: - import coverage - except ImportError: - print("Coverage package is not found.") - print(runcoverage.__doc__) - return - - sdir = os.path.abspath(sdir) - if not os.path.isdir(sdir): - print("Python files directory is not found: {}\n".format(sdir)) - ARGS.print_help() - return - - mods = [source for _, source in load_modules(sdir)] - coverage = [sys.executable, '-m', 'coverage'] - - try: - subprocess.check_call( - coverage + ['run', '--branch', 'runtests.py'] + args) - except: - pass - else: - subprocess.check_call(coverage + ['html'] + mods) - subprocess.check_call(coverage + ['report'] + mods) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) if __name__ == '__main__': - if '--coverage' in sys.argv: - cov_args, args = COV_ARGS.parse_known_args() - runcoverage(cov_args.coverage, args) - else: - runtests() + runtests() From f8241acd569c89f299affedb9b53ff8c409e4eb5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 17:55:53 -0700 Subject: [PATCH 0614/1502] Kill fake event loop (copy from newcancel branch). --- tests/futures_test.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/futures_test.py b/tests/futures_test.py index c7228c00..786de31c 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -251,14 +251,17 @@ def run(arg): self.assertIs(m_events.get_event_loop.return_value, f2._loop) -# A fake event loop for tests. All it does is implement a call_soon method -# that immediately invokes the given function. -class _FakeEventLoop: - def call_soon(self, fn, *args): - fn(*args) +class FutureDoneCallbackTests(unittest.TestCase): + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) -class FutureDoneCallbackTests(unittest.TestCase): + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) def _make_callback(self, bag, thing): # Create a callback function that appends thing to bag. @@ -267,7 +270,7 @@ def bag_appender(future): return bag_appender def _new_future(self): - return futures.Future(loop=_FakeEventLoop()) + return futures.Future(loop=self.loop) def test_callbacks_invoked_on_set_result(self): bag = [] @@ -277,6 +280,9 @@ def test_callbacks_invoked_on_set_result(self): self.assertEqual(bag, []) f.set_result('foo') + + self.run_briefly() + self.assertEqual(bag, [42, 17]) self.assertEqual(f.result(), 'foo') @@ -288,6 +294,9 @@ def test_callbacks_invoked_on_set_exception(self): self.assertEqual(bag, []) exc = RuntimeError() f.set_exception(exc) + + self.run_briefly() + self.assertEqual(bag, [100]) self.assertEqual(f.exception(), exc) @@ -318,6 +327,9 @@ def test_remove_done_callback(self): self.assertEqual(bag, []) f.set_result('foo') + + self.run_briefly() + self.assertEqual(bag, [2]) self.assertEqual(f.result(), 'foo') From 787fc23afb93259bb2bc8eeaf41c9516f9ae487e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 18:03:40 -0700 Subject: [PATCH 0615/1502] Import Antoine's gather() into the default branch. --- tests/tasks_test.py | 180 ++++++++++++++++++++++++++++++++++++++++++++ tulip/tasks.py | 56 ++++++++++++++ 2 files changed, 236 insertions(+) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 3e1220dc..a83c63b2 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -3,6 +3,7 @@ import gc import unittest import unittest.mock +from unittest.mock import Mock from tulip import events from tulip import futures @@ -1196,5 +1197,184 @@ def coro(): self.assertIsNone(t2.result()) +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertTrue(fut.cancelled()) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.set_exception(ZeroDivisionError()) + c.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertTrue(fut.cancelled()) + # Does nothing + d.set_result(3) + e.cancel() + f.set_exception(RuntimeError()) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + yield from [] + return 'abc' + fut = tasks.gather(coro(), coro()) + self.assertIs(fut._loop, self.one_loop) + fut = tasks.gather(coro(), coro(), loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + if __name__ == '__main__': unittest.main() diff --git a/tulip/tasks.py b/tulip/tasks.py index ca513a10..f488e5da 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -3,6 +3,7 @@ __all__ = ['coroutine', 'task', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', ] import collections @@ -357,3 +358,58 @@ def async(coro_or_future, *, loop=None, timeout=None): return Task(coro_or_future, loop=loop, timeout=timeout) else: raise TypeError('A Future or coroutine is required') + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks + are done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily the + order of results arrival). If one of the tasks is cancelled, the + returned future is immediately cancelled too. If *result_exception* + is True, exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first raised + exception will be immediately propagated to the returned future. + """ + children = [async(fut, loop=loop) for fut in coros_or_futures] + n = len(children) + if n == 0: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + if loop is None: + loop = children[0]._loop + for fut in children: + if fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + outer = futures.Future(loop=loop) + nfinished = 0 + results = [None] * n + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Be sure to mark the result retrieved + fut.exception() + return + if fut._state == futures._CANCELLED: + outer.cancel() + return + elif fut._exception is not None: + if not return_exceptions: + outer.set_exception(fut.exception()) + return + res = fut.exception() + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == n: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer From 764055227c4ae156b62271c636f89eca9bdf98cd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 19:19:34 -0700 Subject: [PATCH 0616/1502] Merge default into newcancel. At least this is uneventful; only 8e0b46263b35. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dcaee96f..a19e3224 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,6 @@ setup(name='tulip', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip'], + packages=['tulip', 'tulip.http'], ext_modules=extensions ) From 6e44162f3e04ad22fe24cdaa4d85930eb0d1646d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 19:48:39 -0700 Subject: [PATCH 0617/1502] Rename EventWaiter -> Event. --- tests/events_test.py | 4 ++-- tests/locks_test.py | 20 ++++++++++---------- tests/queues_test.py | 4 ++-- tulip/locks.py | 9 ++++----- tulip/queues.py | 2 +- 5 files changed, 19 insertions(+), 20 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 240518c0..91cd5db7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -151,8 +151,8 @@ def __init__(self, loop): self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} self.data = {1: b'', 2: b''} self.returncode = None - self.got_data = {1: locks.EventWaiter(loop=loop), - 2: locks.EventWaiter(loop=loop)} + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} def connection_made(self, transport): self.transport = transport diff --git a/tests/locks_test.py b/tests/locks_test.py index 9399d759..7c138eef 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -211,7 +211,7 @@ def test_context_manager_no_yield(self): '"yield from" should be used as context manager expression') -class EventWaiterTests(unittest.TestCase): +class EventTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() @@ -222,29 +222,29 @@ def tearDown(self): def test_ctor_loop(self): loop = unittest.mock.Mock() - ev = locks.EventWaiter(loop=loop) + ev = locks.Event(loop=loop) self.assertIs(ev._loop, loop) - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertIs(ev._loop, self.loop) def test_ctor_noloop(self): try: events.set_event_loop(self.loop) - ev = locks.EventWaiter() + ev = locks.Event() self.assertIs(ev._loop, self.loop) finally: events.set_event_loop(None) def test_repr(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertTrue(repr(ev).endswith('[unset]>')) ev.set() self.assertTrue(repr(ev).endswith('[set]>')) def test_wait(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertFalse(ev.is_set()) result = [] @@ -284,14 +284,14 @@ def c3(result): self.assertIsNone(t3.result()) def test_wait_on_set(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) ev.set() res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) def test_wait_cancel(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) wait = tasks.Task(ev.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) @@ -301,7 +301,7 @@ def test_wait_cancel(self): self.assertFalse(ev._waiters) def test_clear(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertFalse(ev.is_set()) ev.set() @@ -311,7 +311,7 @@ def test_clear(self): self.assertFalse(ev.is_set()) def test_clear_with_waiters(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) result = [] @tasks.coroutine diff --git a/tests/queues_test.py b/tests/queues_test.py index 98ca3199..7241ffdc 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -202,7 +202,7 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(loop=loop) - started = locks.EventWaiter(loop=loop) + started = locks.Event(loop=loop) finished = False @tasks.coroutine @@ -310,7 +310,7 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(maxsize=1, loop=loop) - started = locks.EventWaiter(loop=loop) + started = locks.Event(loop=loop) finished = False @tasks.coroutine diff --git a/tulip/locks.py b/tulip/locks.py index 87937ec0..8f1a1e9a 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -1,6 +1,6 @@ """Synchronization primitives.""" -__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] import collections @@ -117,7 +117,7 @@ def release(self): self._locked = False # Wake up the first waiter who isn't cancelled. for fut in self._waiters: - if not fut.cancelled(): + if not fut.done(): fut.set_result(True) break else: @@ -137,9 +137,8 @@ def __iter__(self): return self -# TODO: Why not call this Event? -class EventWaiter: - """A EventWaiter implementation, our equivalent to threading.Event +class Event: + """An Event implementation, our equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set to true with the set() method and reset to false with the clear() method. diff --git a/tulip/queues.py b/tulip/queues.py index b658e67e..536de1cb 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -237,7 +237,7 @@ class JoinableQueue(Queue): def __init__(self, maxsize=0, *, loop=None): super().__init__(maxsize=maxsize, loop=loop) self._unfinished_tasks = 0 - self._finished = locks.EventWaiter(loop=self._loop) + self._finished = locks.Event(loop=self._loop) self._finished.set() def _format(self): From 93bb8e7fec764fd490ce37859d286aaaab4edd7c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 8 Sep 2013 19:53:01 -0700 Subject: [PATCH 0618/1502] Copy selectors changes from default to newcancel branch. --- tests/base_events_test.py | 1 + tests/selectors_test.py | 13 ++++++++----- tulip/selectors.py | 17 ++++++----------- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index f137830a..b423f329 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -584,6 +584,7 @@ def test_accept_connection_retry(self): @unittest.mock.patch('tulip.selector_events.tulip_log') def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() + sock.fileno.return_value = 10 sock.accept.side_effect = OSError() self.loop._accept_connection(MyProto, sock) diff --git a/tests/selectors_test.py b/tests/selectors_test.py index 68c1c06b..0f74db0f 100644 --- a/tests/selectors_test.py +++ b/tests/selectors_test.py @@ -60,15 +60,20 @@ def test_unregister(self): s.register(fobj, selectors.EVENT_READ) s.unregister(fobj) self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_unregister_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + s = FakeSelector() - self.assertRaises(KeyError, s.unregister, unittest.mock.Mock()) + self.assertRaises(KeyError, s.unregister, fobj) def test_modify_unknown(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + s = FakeSelector() - self.assertRaises(KeyError, s.modify, unittest.mock.Mock(), 1) + self.assertRaises(KeyError, s.modify, fobj, 1) def test_modify(self): fobj = unittest.mock.Mock() @@ -119,7 +124,6 @@ def test_close(self): s.close() self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_context_manager(self): s = FakeSelector() @@ -128,7 +132,6 @@ def test_context_manager(self): sel.register(1, selectors.EVENT_READ) self.assertFalse(s._fd_to_key) - self.assertFalse(s._fileobj_to_key) def test_key_from_fd(self): s = FakeSelector() diff --git a/tulip/selectors.py b/tulip/selectors.py index b81b1dbe..fe027f09 100644 --- a/tulip/selectors.py +++ b/tulip/selectors.py @@ -31,7 +31,7 @@ def _fileobj_to_fd(fileobj): else: try: fd = int(fileobj.fileno()) - except (AttributeError, ValueError): + except (AttributeError, TypeError, ValueError): raise ValueError("Invalid file object: " "{!r}".format(fileobj)) from None if fd < 0: @@ -62,8 +62,6 @@ class BaseSelector(metaclass=ABCMeta): def __init__(self): # this maps file descriptors to keys self._fd_to_key = {} - # this maps file objects to keys - for fast (un)registering - self._fileobj_to_key = {} def register(self, fileobj, events, data=None): """Register a file object. @@ -77,7 +75,7 @@ def register(self, fileobj, events, data=None): SelectorKey instance """ if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {}".format(events)) + raise ValueError("Invalid events: {!r}".format(events)) key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) @@ -86,7 +84,6 @@ def register(self, fileobj, events, data=None): "registered".format(fileobj, key.fd)) self._fd_to_key[key.fd] = key - self._fileobj_to_key[fileobj] = key return key def unregister(self, fileobj): @@ -99,8 +96,7 @@ def unregister(self, fileobj): SelectorKey instance """ try: - key = self._fileobj_to_key.pop(fileobj) - del self._fd_to_key[key.fd] + key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None return key @@ -118,7 +114,7 @@ def modify(self, fileobj, events, data=None): """ # TODO: Subclasses can probably optimize this even further. try: - key = self._fileobj_to_key[fileobj] + key = self._fd_to_key[_fileobj_to_fd(fileobj)] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None if events != key.events or data != key.data: @@ -154,7 +150,6 @@ def close(self): This must be called to make sure that any underlying resource is freed. """ self._fd_to_key.clear() - self._fileobj_to_key.clear() def get_key(self, fileobj): """Return the key associated to a registered file object. @@ -163,9 +158,9 @@ def get_key(self, fileobj): SelectorKey for this file object """ try: - return self._fileobj_to_key[fileobj] + return self._fd_to_key[_fileobj_to_fd(fileobj)] except KeyError: - raise KeyError("{} is not registered".format(fileobj)) from None + raise KeyError("{!r} is not registered".format(fileobj)) from None def __enter__(self): return self From 7b93f9d0c51e3e1ddb2db40c31ef7fb996b34a3f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 07:46:37 -0700 Subject: [PATCH 0619/1502] Land new cancel and timeout code in default. @task is gone. --- tests/events_test.py | 28 ++-- tests/futures_test.py | 9 -- tests/http_server_test.py | 1 + tests/locks_test.py | 286 ++++++++------------------------------ tests/queues_test.py | 140 +++++++------------ tests/tasks_test.py | 268 ++++++++++++----------------------- tulip/base_events.py | 27 +--- tulip/events.py | 6 +- tulip/futures.py | 28 +--- tulip/http/client.py | 9 +- tulip/locks.py | 245 +++++++++++++------------------- tulip/queues.py | 48 +++---- tulip/tasks.py | 157 +++++++++------------ 13 files changed, 401 insertions(+), 851 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 7c342bad..91cd5db7 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -151,8 +151,8 @@ def __init__(self, loop): self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} self.data = {1: b'', 2: b''} self.returncode = None - self.got_data = {1: locks.EventWaiter(loop=loop), - 2: locks.EventWaiter(loop=loop)} + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} def connection_made(self, transport): self.transport = transport @@ -228,18 +228,6 @@ def cb(): self.assertRaises(RuntimeError, self.loop.run_until_complete, task) - def test_run_until_complete_timeout(self): - t0 = self.loop.time() - task = tasks.async(tasks.sleep(0.2, loop=self.loop), loop=self.loop) - self.assertRaises(futures.TimeoutError, - self.loop.run_until_complete, - task, timeout=0.1) - t1 = self.loop.time() - self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) - self.loop.run_until_complete(task) - t2 = self.loop.time() - self.assertTrue(0.18 <= t2-t0 <= 0.22, t2-t0) - def test_call_later(self): results = [] @@ -951,7 +939,7 @@ def main(): return res start = time.monotonic() - t = tasks.Task(main(), timeout=1, loop=self.loop) + t = tasks.Task(main(), loop=self.loop) self.loop.run_forever() elapsed = time.monotonic() - start @@ -986,7 +974,7 @@ def connect(): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) transp.close() self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) @@ -1015,12 +1003,12 @@ def connect(): try: stdin = transp.get_pipe_transport(0) stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) proto.got_data[1].clear() self.assertEqual(b'Python ', proto.data[1]) stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) self.assertEqual(b'Python The Winner', proto.data[1]) finally: transp.close() @@ -1219,13 +1207,13 @@ def connect(): stdin = transp.get_pipe_transport(0) stdout = transp.get_pipe_transport(1) stdin.write(b'test') - self.loop.run_until_complete(proto.got_data[1].wait(1)) + self.loop.run_until_complete(proto.got_data[1].wait()) self.assertEqual(b'OUT:test', proto.data[1]) stdout.close() self.loop.run_until_complete(proto.disconnects[1]) stdin.write(b'xxx') - self.loop.run_until_complete(proto.got_data[2].wait(1)) + self.loop.run_until_complete(proto.got_data[2].wait()) self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) transp.close() diff --git a/tests/futures_test.py b/tests/futures_test.py index 786de31c..18cec8b0 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -132,15 +132,6 @@ def test_repr(self): self.assertIn('<18 more>', r) f_many_callbacks.cancel() - f_pending = futures.Future(loop=self.loop, timeout=10) - self.assertEqual('Future{timeout=10, when=10}', - repr(f_pending)) - f_pending.cancel() - - f_pending = futures.Future(loop=self.loop, timeout=10) - f_pending.cancel() - self.assertEqual('Future{timeout=10}', repr(f_pending)) - def test_copy_state(self): # Test the internal _copy_state method since it's being directly # invoked in other modules. diff --git a/tests/http_server_test.py b/tests/http_server_test.py index a2c8542a..5c7a97a0 100644 --- a/tests/http_server_test.py +++ b/tests/http_server_test.py @@ -69,6 +69,7 @@ def test_connection_lost(self): handle = srv._request_handler srv.connection_lost(None) + test_utils.run_briefly(self.loop) self.assertIsNone(srv._request_handler) self.assertTrue(handle.cancelled()) diff --git a/tests/locks_test.py b/tests/locks_test.py index 529c7268..7c138eef 100644 --- a/tests/locks_test.py +++ b/tests/locks_test.py @@ -115,59 +115,6 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) - def test_acquire_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - lock = locks.Lock(loop=loop) - - self.assertTrue(loop.run_until_complete(lock.acquire())) - - acquired = loop.run_until_complete(lock.acquire(timeout=0.1)) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - lock = locks.Lock(loop=loop) - self.loop.run_until_complete(lock.acquire()) - - loop.call_soon(lock.release) - acquired = loop.run_until_complete(lock.acquire(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - def test_acquire_timeout_mixed(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - lock = locks.Lock(loop=loop) - loop.run_until_complete(lock.acquire()) - tasks.Task(lock.acquire(), loop=loop) - tasks.Task(lock.acquire(), loop=loop) - acquire_task = tasks.Task(lock.acquire(0.01), loop=loop) - tasks.Task(lock.acquire(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - self.assertEqual(3, len(lock._waiters)) - - # wakeup to close waiting coroutines - for i in range(3): - lock.release() - test_utils.run_briefly(loop) - def test_acquire_cancel(self): lock = locks.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) @@ -179,6 +126,54 @@ def test_acquire_cancel(self): self.loop.run_until_complete, task) self.assertFalse(lock._waiters) + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + def test_release_not_acquired(self): lock = locks.Lock(loop=self.loop) @@ -216,7 +211,7 @@ def test_context_manager_no_yield(self): '"yield from" should be used as context manager expression') -class EventWaiterTests(unittest.TestCase): +class EventTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() @@ -227,29 +222,29 @@ def tearDown(self): def test_ctor_loop(self): loop = unittest.mock.Mock() - ev = locks.EventWaiter(loop=loop) + ev = locks.Event(loop=loop) self.assertIs(ev._loop, loop) - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertIs(ev._loop, self.loop) def test_ctor_noloop(self): try: events.set_event_loop(self.loop) - ev = locks.EventWaiter() + ev = locks.Event() self.assertIs(ev._loop, self.loop) finally: events.set_event_loop(None) def test_repr(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertTrue(repr(ev).endswith('[unset]>')) ev.set() self.assertTrue(repr(ev).endswith('[set]>')) def test_wait(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertFalse(ev.is_set()) result = [] @@ -289,63 +284,14 @@ def c3(result): self.assertIsNone(t3.result()) def test_wait_on_set(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) ev.set() res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) - def test_wait_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0.1 - self.assertAlmostEqual(0.11, when) - when = yield 0 - self.assertAlmostEqual(10.2, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - ev = locks.EventWaiter(loop=loop) - - res = loop.run_until_complete(ev.wait(0.1)) - self.assertFalse(res) - self.assertAlmostEqual(0.1, loop.time()) - - ev = locks.EventWaiter(loop=loop) - loop.call_later(0.01, ev.set) - acquired = loop.run_until_complete(ev.wait(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.11, loop.time()) - - def test_wait_timeout_mixed(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - ev = locks.EventWaiter(loop=loop) - tasks.Task(ev.wait(), loop=loop) - tasks.Task(ev.wait(), loop=loop) - acquire_task = tasks.Task(ev.wait(0.1), loop=loop) - tasks.Task(ev.wait(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - self.assertEqual(3, len(ev._waiters)) - - # wakeup to close waiting coroutines - ev.set() - test_utils.run_briefly(loop) - def test_wait_cancel(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) wait = tasks.Task(ev.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) @@ -355,7 +301,7 @@ def test_wait_cancel(self): self.assertFalse(ev._waiters) def test_clear(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) self.assertFalse(ev.is_set()) ev.set() @@ -365,7 +311,7 @@ def test_clear(self): self.assertFalse(ev.is_set()) def test_clear_with_waiters(self): - ev = locks.EventWaiter(loop=self.loop) + ev = locks.Event(loop=self.loop) result = [] @tasks.coroutine @@ -485,23 +431,6 @@ def c3(result): self.assertTrue(t3.done()) self.assertTrue(t3.result()) - def test_wait_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - cond = locks.Condition(loop=loop) - loop.run_until_complete(cond.acquire()) - - wait = loop.run_until_complete(cond.wait(0.1)) - self.assertFalse(wait) - self.assertTrue(cond.locked()) - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_cancel(self): cond = locks.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) @@ -558,49 +487,6 @@ def c1(result): self.assertTrue(t.done()) self.assertTrue(t.result()) - def test_wait_for_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - cond = locks.Condition(loop=loop) - - result = [] - - predicate = unittest.mock.Mock(return_value=False) - - @tasks.coroutine - def c1(result): - yield from cond.acquire() - if (yield from cond.wait_for(predicate, 0.1)): - result.append(1) - else: - result.append(2) - cond.release() - - wait_for = tasks.Task(c1(result), loop=loop) - - test_utils.run_briefly(loop) - self.assertEqual([], result) - - loop.run_until_complete(cond.acquire()) - cond.notify() - cond.release() - test_utils.run_briefly(loop) - self.assertEqual([], result) - - loop.run_until_complete(wait_for) - self.assertEqual([2], result) - self.assertEqual(3, predicate.call_count) - - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_for_unacquired(self): cond = locks.Condition(loop=self.loop) @@ -834,62 +720,6 @@ def c4(result): # cleanup locked semaphore sem.release() - def test_acquire_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0.1 - self.assertAlmostEqual(0.11, when) - when = yield 0 - self.assertAlmostEqual(10.2, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - - acquired = loop.run_until_complete(sem.acquire(0.1)) - self.assertFalse(acquired) - self.assertAlmostEqual(0.1, loop.time()) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - - loop.call_later(0.01, sem.release) - acquired = loop.run_until_complete(sem.acquire(10.1)) - self.assertTrue(acquired) - self.assertAlmostEqual(0.11, loop.time()) - - def test_acquire_timeout_mixed(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - sem = locks.Semaphore(loop=loop) - loop.run_until_complete(sem.acquire()) - tasks.Task(sem.acquire(), loop=loop) - tasks.Task(sem.acquire(), loop=loop) - acquire_task = tasks.Task(sem.acquire(0.1), loop=loop) - tasks.Task(sem.acquire(), loop=loop) - - acquired = loop.run_until_complete(acquire_task) - self.assertFalse(acquired) - - self.assertAlmostEqual(0.1, loop.time()) - - self.assertEqual(3, len(sem._waiters)) - - # wakeup to close waiting coroutines - for i in range(3): - sem.release() - test_utils.run_briefly(loop) - def test_acquire_cancel(self): sem = locks.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) diff --git a/tests/queues_test.py b/tests/queues_test.py index 0dce6653..7241ffdc 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -202,7 +202,7 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(loop=loop) - started = locks.EventWaiter(loop=loop) + started = locks.Event(loop=loop) finished = False @tasks.coroutine @@ -236,37 +236,7 @@ def test_nonblocking_get_exception(self): q = queues.Queue(loop=self.loop) self.assertRaises(queues.Empty, q.get_nowait) - def test_get_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.Queue(loop=loop) - - @tasks.coroutine - def queue_get(): - with self.assertRaises(queues.Empty): - return (yield from q.get(timeout=0.01)) - - # Get works after timeout, with blocking and non-blocking put. - q.put_nowait(1) - self.assertEqual(1, (yield from q.get())) - - t = tasks.Task(q.put(2), loop=loop) - self.assertEqual(2, (yield from q.get())) - - self.assertTrue(t.done()) - self.assertIsNone(t.result()) - - loop.run_until_complete(queue_get()) - self.assertAlmostEqual(0.01, loop.time()) - - def test_get_timeout_cancelled(self): + def test_get_cancelled(self): def gen(): when = yield @@ -282,7 +252,7 @@ def gen(): @tasks.coroutine def queue_get(): - return (yield from q.get(timeout=0.05)) + return (yield from tasks.wait_for(q.get(), 0.05, loop=loop)) @tasks.coroutine def test(): @@ -294,6 +264,28 @@ def test(): self.assertEqual(1, loop.run_until_complete(test())) self.assertAlmostEqual(0.06, loop.time()) + def test_get_cancelled_race(self): + q = queues.Queue(loop=self.loop) + + t1 = tasks.Task(q.get(), loop=self.loop) + t2 = tasks.Task(q.get(), loop=self.loop) + + test_utils.run_briefly(self.loop) + t1.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t1.done()) + q.put_nowait('a') + test_utils.run_briefly(self.loop) + self.assertEqual(t2.result(), 'a') + + def test_get_with_waiting_putters(self): + q = queues.Queue(loop=self.loop, maxsize=1) + t1 = tasks.Task(q.put('a'), loop=self.loop) + t2 = tasks.Task(q.put('b'), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(self.loop.run_until_complete(q.get()), 'a') + self.assertEqual(self.loop.run_until_complete(q.get()), 'b') + class QueuePutTests(_QueueTestBase): @@ -318,7 +310,7 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(maxsize=1, loop=loop) - started = locks.EventWaiter(loop=loop) + started = locks.Event(loop=loop) finished = False @tasks.coroutine @@ -351,47 +343,12 @@ def test_nonblocking_put_exception(self): q.put_nowait(1) self.assertRaises(queues.Full, q.put_nowait, 2) - def test_put_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.01, when) - when = yield 0.01 - self.assertAlmostEqual(0.02, when) - yield 0.01 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.Queue(1, loop=loop) - q.put_nowait(0) - - @tasks.coroutine - def queue_put(): - with self.assertRaises(queues.Full): - return (yield from q.put(1, timeout=0.01)) - - self.assertEqual(0, q.get_nowait()) - - # Put works after timeout, with blocking and non-blocking get. - get_task = tasks.Task(q.get(), loop=loop) - # Let the get start waiting. - yield from tasks.sleep(0.01, loop=loop) - q.put_nowait(2) - self.assertEqual(2, (yield from get_task)) - - q.put_nowait(3) - self.assertEqual(3, q.get_nowait()) - - loop.run_until_complete(queue_put()) - self.assertAlmostEqual(0.02, loop.time()) - - def test_put_timeout_cancelled(self): + def test_put_cancelled(self): q = queues.Queue(loop=self.loop) @tasks.coroutine def queue_put(): - yield from q.put(1, timeout=0.01) + yield from q.put(1) return True @tasks.coroutine @@ -403,6 +360,27 @@ def test(): self.assertTrue(t.done()) self.assertTrue(t.result()) + def test_put_cancelled_race(self): + q = queues.Queue(loop=self.loop, maxsize=1) + + t1 = tasks.Task(q.put('a'), loop=self.loop) + t2 = tasks.Task(q.put('b'), loop=self.loop) + t3 = tasks.Task(q.put('c'), loop=self.loop) + + test_utils.run_briefly(self.loop) + t2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(t2.done()) + self.assertEqual(q.get_nowait(), 'a') + self.assertEqual(q.get_nowait(), 'c') + + def test_put_with_waiting_getters(self): + q = queues.Queue(loop=self.loop) + t = tasks.Task(q.get(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.loop.run_until_complete(q.put('a')) + self.assertEqual(self.loop.run_until_complete(t), 'a') + class LifoQueueTests(_QueueTestBase): @@ -480,26 +458,6 @@ def join(): self.loop.run_until_complete(join()) - def test_join_timeout(self): - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - q = queues.JoinableQueue(loop=loop) - q.put_nowait(1) - - @tasks.coroutine - def join(): - yield from q.join(0.1) - - # Join completes in ~ 0.1 seconds, although no one calls task_done(). - loop.run_until_complete(join()) - self.assertAlmostEqual(0.1, loop.time()) - def test_format(self): q = queues.JoinableQueue(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') diff --git a/tests/tasks_test.py b/tests/tasks_test.py index a83c63b2..e0611d23 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -45,57 +45,6 @@ def notmuch(): self.assertIs(t._loop, loop) loop.close() - def test_task_decorator(self): - @tasks.task - def notmuch(): - yield from [] - return 'ko' - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - - def test_task_decorator_func(self): - @tasks.task - def notmuch(): - return 'ko' - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - - def test_task_decorator_fut(self): - @tasks.task - def notmuch(): - fut = futures.Future(loop=self.loop) - fut.set_result('ko') - return fut - - try: - events.set_event_loop(self.loop) - t = notmuch() - finally: - events.set_event_loop(None) - - self.assertIsInstance(t, tasks.Task) - self.loop.run_until_complete(t) - self.assertTrue(t.done()) - self.assertEqual(t.result(), 'ko') - def test_async_coroutine(self): @tasks.coroutine def notmuch(): @@ -225,9 +174,10 @@ def task(): t = tasks.Task(task(), loop=loop) loop.call_soon(t.cancel) - self.assertRaises( - futures.CancelledError, loop.run_until_complete, t) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) def test_cancel_yield(self): @@ -243,105 +193,118 @@ def task(): self.assertRaises( futures.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) - def test_cancel_done_future(self): - fut1 = futures.Future(loop=self.loop) - fut2 = futures.Future(loop=self.loop) - fut3 = futures.Future(loop=self.loop) + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) @tasks.coroutine def task(): - yield from fut1 - try: - yield from fut2 - except futures.CancelledError: - pass - yield from fut3 + yield from f + return 12 t = tasks.Task(task(), loop=self.loop) - test_utils.run_briefly(self.loop) - fut1.set_result(None) - t.cancel() - test_utils.run_once(self.loop) # process fut1 result, delay cancel - self.assertFalse(t.done()) - test_utils.run_once(self.loop) # cancel fut2, but coro still alive - self.assertFalse(t.done()) - test_utils.run_briefly(self.loop) # cancel fut3 - self.assertTrue(t.done()) - - self.assertEqual(fut1.result(), None) - self.assertTrue(fut2.cancelled()) - self.assertTrue(fut3.cancelled()) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) self.assertTrue(t.cancelled()) - def test_future_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) @tasks.coroutine - def coro(): - yield from tasks.sleep(10.0, loop=loop) + def task(): + yield from f return 12 - t = tasks.Task(coro(), timeout=0.1, loop=loop) + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) - self.assertRaises( - futures.CancelledError, - loop.run_until_complete, t) - self.assertTrue(t.done()) - self.assertFalse(t.cancel()) - self.assertAlmostEqual(0.1, loop.time()) + f.cancel() + t.cancel() - def test_future_timeout_catch(self): + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) @tasks.coroutine - def coro(): - yield from tasks.sleep(10.0, loop=loop) - return 12 + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 - class Cancelled(Exception): - pass + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) @tasks.coroutine - def coro2(): + def task(): + yield from fut1 try: - yield from tasks.Task(coro(), timeout=0.1, loop=loop) + yield from fut2 except futures.CancelledError: - raise Cancelled() + pass + res = yield from fut3 + return res - self.assertRaises( - Cancelled, loop.run_until_complete, coro2()) - self.assertAlmostEqual(0.1, loop.time()) + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) - def test_cancel_in_coro(self): @tasks.coroutine def task(): t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) return 12 - t = tasks.Task(task(), loop=self.loop) + t = tasks.Task(task(), loop=loop) self.assertRaises( - futures.CancelledError, self.loop.run_until_complete, t) + futures.CancelledError, loop.run_until_complete, t) self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. self.assertFalse(t.cancel()) def test_stop_while_run_in_complete(self): @@ -382,57 +345,6 @@ def task(): for w in waiters: w.close() - def test_timeout(self): - - def gen(): - when = yield - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(10.0, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def task(): - yield from tasks.sleep(10.0, loop=loop) - return 42 - - t = tasks.Task(task(), loop=loop) - self.assertRaises( - futures.TimeoutError, loop.run_until_complete, t, 0.1) - self.assertAlmostEqual(0.1, loop.time()) - self.assertFalse(t.done()) - - # move forward to close generator - loop.advance_time(10) - self.assertEqual(42, loop.run_until_complete(t)) - self.assertTrue(t.done()) - - def test_timeout_not(self): - - def gen(): - when = yield - self.assertAlmostEqual(10.0, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - yield 0.1 - - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) - - @tasks.coroutine - def task(): - yield from tasks.sleep(0.1, loop=loop) - return 42 - - t = tasks.Task(task(), loop=loop) - r = loop.run_until_complete(t, 10.0) - self.assertTrue(t.done()) - self.assertEqual(r, 42) - self.assertAlmostEqual(0.1, loop.time()) - def test_wait_for(self): def gen(): @@ -1014,16 +926,14 @@ def test_task_cancel_waiter_future(self): @tasks.coroutine def coro(): - try: - yield from fut - except futures.CancelledError: - pass + yield from fut task = tasks.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) task.cancel() + test_utils.run_briefly(self.loop) self.assertRaises( futures.CancelledError, self.loop.run_until_complete, task) self.assertIsNone(task._fut_waiter) @@ -1105,12 +1015,14 @@ def gen(): def sleeper(): yield from tasks.sleep(10, loop=loop) + base_exc = BaseException() + @tasks.coroutine def notmutch(): try: yield from sleeper() except futures.CancelledError: - raise BaseException() + raise base_exc task = tasks.Task(notmutch(), loop=loop) test_utils.run_briefly(loop) @@ -1121,7 +1033,8 @@ def notmutch(): self.assertRaises(BaseException, test_utils.run_briefly, loop) self.assertTrue(task.done()) - self.assertTrue(task.cancelled()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) def test_iscoroutinefunction(self): def fn(): @@ -1149,8 +1062,7 @@ def wait_for_future(): with self.assertRaises(RuntimeError) as cm: self.loop.run_until_complete(task) - self.assertTrue(fut.done()) - self.assertIs(fut.exception(), cm.exception) + self.assertFalse(fut.done()) def test_yield_vs_yield_from_generator(self): @tasks.coroutine diff --git a/tulip/base_events.py b/tulip/base_events.py index 5ff2d3c9..3bccfc83 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -112,8 +112,8 @@ def run_forever(self): finally: self._running = False - def run_until_complete(self, future, timeout=None): - """Run until the Future is done, or until a timeout. + def run_until_complete(self, future): + """Run until the Future is done. If the argument is a coroutine, it is wrapped in a Task. @@ -121,31 +121,12 @@ def run_until_complete(self, future, timeout=None): with the same coroutine twice -- it would wrap it in two different Tasks and that can't be good. - Return the Future's result, or raise its exception. If the - timeout is reached or stop() is called, raise TimeoutError. + Return the Future's result, or raise its exception. """ future = tasks.async(future, loop=self) future.add_done_callback(_raise_stop_error) - handle_called = False - - if timeout is None: - self.run_forever() - else: - - def stop_loop(): - nonlocal handle_called - handle_called = True - raise _StopError - - handle = self.call_later(timeout, stop_loop) - self.run_forever() - handle.cancel() - + self.run_forever() future.remove_done_callback(_raise_stop_error) - - if handle_called: - raise futures.TimeoutError - if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') diff --git a/tulip/events.py b/tulip/events.py index e292eea2..7db2514d 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -109,14 +109,10 @@ def run_forever(self): """Run the event loop until stop() is called.""" raise NotImplementedError - def run_until_complete(self, future, timeout=None): + def run_until_complete(self, future): """Run the event loop until a Future is done. Return the Future's result, or raise its exception. - - If timeout is not None, run it for at most that long; - if the Future is still not done, raise TimeoutError - (but don't cancel the Future). """ raise NotImplementedError diff --git a/tulip/futures.py b/tulip/futures.py index 8593e9ae..706e8c8a 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -1,7 +1,7 @@ """A Future class similar to the one in PEP 3148.""" __all__ = ['CancelledError', 'TimeoutError', - 'InvalidStateError', 'InvalidTimeoutError', + 'InvalidStateError', 'Future', 'wrap_future', ] @@ -30,11 +30,6 @@ class InvalidStateError(Error): # TODO: Show the future, its state, the method, and the required state. -class InvalidTimeoutError(Error): - """Called result() or exception() with timeout != 0.""" - # TODO: Print a nice error message. - - class _TracebackLogger: """Helper to log a traceback upon destruction if not cleared. @@ -129,15 +124,13 @@ class Future: _state = _PENDING _result = None _exception = None - _timeout = None - _timeout_handle = None _loop = None _blocking = False # proper use of future (yield vs yield from) _tb_logger = None - def __init__(self, *, loop=None, timeout=None): + def __init__(self, *, loop=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -150,10 +143,6 @@ def __init__(self, *, loop=None, timeout=None): self._loop = loop self._callbacks = [] - if timeout is not None: - self._timeout = timeout - self._timeout_handle = self._loop.call_later(timeout, self.cancel) - def __repr__(self): res = self.__class__.__name__ if self._state == _FINISHED: @@ -171,14 +160,6 @@ def __repr__(self): res += '<{}, {}>'.format(self._state, self._callbacks) else: res += '<{}>'.format(self._state) - dct = {} - if self._timeout is not None: - dct['timeout'] = self._timeout - if self._timeout_handle is not None: - dct['when'] = self._timeout_handle._when - if dct: - res += '{' + ', '.join('{}={}'.format(k, dct[k]) - for k in sorted(dct)) + '}' return res def cancel(self): @@ -200,11 +181,6 @@ def _schedule_callbacks(self): The callbacks are scheduled to be called as soon as possible. Also clears the callback list. """ - # Cancel timeout handle - if self._timeout_handle is not None: - self._timeout_handle.cancel() - self._timeout_handle = None - callbacks = self._callbacks[:] if not callbacks: return diff --git a/tulip/http/client.py b/tulip/http/client.py index 2aedfdd1..ec7cd034 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -95,10 +95,17 @@ def request(method, url, *, conn = session.start(req, loop) # connection timeout + t = tulip.Task(conn, loop=loop) + th = None + if timeout is not None: + th = loop.call_later(timeout, t.cancel) try: - resp = yield from tulip.Task(conn, timeout=timeout, loop=loop) + resp = yield from t except tulip.CancelledError: raise tulip.TimeoutError from None + finally: + if th is not None: + th.cancel() # redirects if resp.status in (301, 302) and allow_redirects: diff --git a/tulip/locks.py b/tulip/locks.py index 7c0a8f2a..8f1a1e9a 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -1,6 +1,6 @@ -"""Synchronization primitives""" +"""Synchronization primitives.""" -__all__ = ['Lock', 'EventWaiter', 'Condition', 'Semaphore'] +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] import collections @@ -10,29 +10,31 @@ class Lock: - """The class implementing primitive lock objects. - - A primitive lock is a synchronization primitive that is not owned by - a particular coroutine when locked. A primitive lock is in one of two - states, "locked" or "unlocked". - It is created in the unlocked state. It has two basic methods, - acquire() and release(). When the state is unlocked, acquire() changes - the state to locked and returns immediately. When the state is locked, - acquire() blocks until a call to release() in another coroutine changes - it to unlocked, then the acquire() call resets it to locked and returns. - The release() method should only be called in the locked state; it changes - the state to unlocked and returns immediately. If an attempt is made - to release an unlocked lock, a RuntimeError will be raised. - - When more than one coroutine is blocked in acquire() waiting for the state - to turn to unlocked, only one coroutine proceeds when a release() call - resets the state to unlocked; first coroutine which is blocked in acquire() - is being processed. - - acquire() method is a coroutine and should be called with "yield from" - - Locks also support the context manager protocol. (yield from lock) should - be used as context manager expression. + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. Usage: @@ -51,7 +53,7 @@ class Lock: with (yield from lock): ... - Lock object could be tested for locking state: + Lock objects can be tested for locking state: if not lock.locked(): yield from lock @@ -71,45 +73,34 @@ def __init__(self, *, loop=None): def __repr__(self): res = super().__repr__() - return '<{} [{}]>'.format( - res[1:-1], 'locked' if self._locked else 'unlocked') + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) def locked(self): """Return true if lock is acquired.""" return self._locked @tasks.coroutine - def acquire(self, timeout=None): + def acquire(self): """Acquire a lock. - Acquire method blocks until the lock is unlocked, then set it to - locked and return True. - - When invoked with the floating-point timeout argument set, blocks for - at most the number of seconds specified by timeout and as long as - the lock cannot be acquired. - - The return value is True if the lock is acquired successfully, - False if not (for example if the timeout expired). + This method blocks until the lock is unlocked, then sets it to + locked and returns True. """ if not self._waiters and not self._locked: self._locked = True return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + self._locked = True + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - - self._locked = True - return True def release(self): """Release a lock. @@ -124,8 +115,11 @@ def release(self): """ if self._locked: self._locked = False - if self._waiters: - self._waiters[0].set_result(True) + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break else: raise RuntimeError('Lock is not acquired.') @@ -143,8 +137,8 @@ def __iter__(self): return self -class EventWaiter: - """A EventWaiter implementation, our equivalent to threading.Event +class Event: + """An Event implementation, our equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set to true with the set() method and reset to false with the clear() method. @@ -161,6 +155,7 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): + # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') @@ -187,41 +182,26 @@ def clear(self): self._value = False @tasks.coroutine - def wait(self, timeout=None): - """Block until the internal flag is true. If the internal flag - is true on entry, return immediately. Otherwise, block until another - coroutine calls set() to set the flag to true, or until the optional - timeout occurs. - - When the timeout argument is present and not None, it should be - a floating point number specifying a timeout for the operation in - seconds (or fractions thereof). - - This method returns true if and only if the internal flag has been - set to true, either before the wait call or after the wait starts, - so it will always return True except if a timeout is given and - the operation times out. - - wait() method is a coroutine. + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. """ if self._value: return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - - return True +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. class Condition(Lock): """A Condition implementation. @@ -232,75 +212,55 @@ class Condition(Lock): def __init__(self, *, loop=None): super().__init__(loop=loop) - self._condition_waiters = collections.deque() - @tasks.coroutine - def wait(self, timeout=None): - """Wait until notified or until a timeout occurs. If the calling - coroutine has not acquired the lock when this method is called, - a RuntimeError is raised. + # TODO: Add __repr__() with len(_condition_waiters). - This method releases the underlying lock, and then blocks until it is - awakened by a notify() or notify_all() call for the same condition - variable in another coroutine, or until the optional timeout occurs. - Once awakened or timed out, it re-acquires the lock and returns. + @tasks.coroutine + def wait(self): + """Wait until notified. - When the timeout argument is present and not None, it should be - a floating point number specifying a timeout for the operation - in seconds (or fractions thereof). + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. - The return value is True unless a given timeout expired, in which - case it is False. + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. """ if not self._locked: raise RuntimeError('cannot wait on un-acquired lock') - self.release() - - fut = futures.Future(loop=self._loop, timeout=timeout) - - self._condition_waiters.append(fut) keep_lock = True + self.release() try: - yield from fut - except futures.CancelledError: - self._condition_waiters.remove(fut) - return False + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + except GeneratorExit: keep_lock = False # Prevent yield in finally clause. raise - else: - f = self._condition_waiters.popleft() - assert fut is f finally: if keep_lock: yield from self.acquire() - return True - @tasks.coroutine - def wait_for(self, predicate, timeout=None): - """Wait until a condition evaluates to True. predicate should be a - callable which result will be interpreted as a boolean value. A timeout - may be provided giving the maximum time to wait. + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. """ - endtime = None - waittime = timeout result = predicate() - while not result: - if waittime is not None: - if endtime is None: - endtime = self._loop.time() + waittime - else: - waittime = endtime - self._loop.time() - if waittime <= 0: - break - - yield from self.wait(waittime) + yield from self.wait() result = predicate() - return result def notify(self, n=1): @@ -370,6 +330,7 @@ def __init__(self, value=1, bound=False, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): + # TODO: add waiters:N if > 0. res = super().__repr__() return '<{} [{}]>'.format( res[1:-1], @@ -381,17 +342,14 @@ def locked(self): return self._locked @tasks.coroutine - def acquire(self, timeout=None): - """Acquire a semaphore. acquire() method is a coroutine. - - When invoked without arguments: if the internal counter is larger - than zero on entry, decrement it by one and return immediately. - If it is zero on entry, block, waiting until some other coroutine has - called release() to make it larger than zero. - - When invoked with a timeout other than None, it will block for at - most timeout seconds. If acquire does not complete successfully in - that interval, return false. Return true otherwise. + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. """ if not self._waiters and self._value > 0: self._value -= 1 @@ -399,22 +357,17 @@ def acquire(self, timeout=None): self._locked = True return True - fut = futures.Future(loop=self._loop, timeout=timeout) - + fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: yield from fut - except futures.CancelledError: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: self._waiters.remove(fut) - return False - else: - f = self._waiters.popleft() - assert f is fut - self._value -= 1 - if self._value == 0: - self._locked = True - return True def release(self): """Release a semaphore, incrementing the internal counter by one. @@ -437,6 +390,8 @@ def release(self): break def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? return True def __exit__(self, *args): diff --git a/tulip/queues.py b/tulip/queues.py index 8214d0ec..536de1cb 100644 --- a/tulip/queues.py +++ b/tulip/queues.py @@ -4,7 +4,6 @@ 'Full', 'Empty'] import collections -import concurrent.futures import heapq import queue @@ -70,10 +69,10 @@ def _format(self): result += ' _putters[{}]'.format(len(self._putters)) return result - def _consume_done_getters(self, waiters): + def _consume_done_getters(self): # Delete waiters at the head of the get() queue who've timed out. - while waiters and waiters[0].done(): - waiters.popleft() + while self._getters and self._getters[0].done(): + self._getters.popleft() def _consume_done_putters(self): # Delete waiters at the head of the put() queue who've timed out. @@ -105,16 +104,13 @@ def full(self): return self.qsize() == self._maxsize @coroutine - def put(self, item, timeout=None): + def put(self, item): """Put an item into the queue. - If you yield from put() and timeout is None (the default), wait until a - free slot is available before adding item. - - If a timeout is provided, raise Full if no free slot becomes - available before the timeout. + If you yield from put(), wait until a free slot is available + before adding item. """ - self._consume_done_getters(self._getters) + self._consume_done_getters() if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -127,13 +123,10 @@ def put(self, item, timeout=None): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - waiter = futures.Future(loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) - try: - yield from waiter - except concurrent.futures.CancelledError: - raise Full + yield from waiter else: self._put(item) @@ -143,7 +136,7 @@ def put_nowait(self, item): If no free slot is immediately available, raise Full. """ - self._consume_done_getters(self._getters) + self._consume_done_getters() if self._getters: assert not self._queue, ( 'queue non-empty, why are getters waiting?') @@ -161,14 +154,10 @@ def put_nowait(self, item): self._put(item) @coroutine - def get(self, timeout=None): + def get(self): """Remove and return an item from the queue. - If you yield from get() and timeout is None (the default), wait until a - item is available. - - If a timeout is provided, raise Empty if no item is available - before the timeout. + If you yield from get(), wait until a item is available. """ self._consume_done_putters() if self._putters: @@ -187,13 +176,10 @@ def get(self, timeout=None): elif self.qsize(): return self._get() else: - waiter = futures.Future(loop=self._loop, timeout=timeout) + waiter = futures.Future(loop=self._loop) self._getters.append(waiter) - try: - return (yield from waiter) - except concurrent.futures.CancelledError: - raise Empty + return (yield from waiter) def get_nowait(self): """Remove and return an item from the queue. @@ -251,7 +237,7 @@ class JoinableQueue(Queue): def __init__(self, maxsize=0, *, loop=None): super().__init__(maxsize=maxsize, loop=loop) self._unfinished_tasks = 0 - self._finished = locks.EventWaiter(loop=self._loop) + self._finished = locks.Event(loop=self._loop) self._finished.set() def _format(self): @@ -286,7 +272,7 @@ def task_done(self): self._finished.set() @coroutine - def join(self, timeout=None): + def join(self): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the @@ -295,4 +281,4 @@ def join(self, timeout=None): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait(timeout=timeout) + yield from self._finished.wait() diff --git a/tulip/tasks.py b/tulip/tasks.py index f488e5da..ca2ef0c0 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -1,6 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'task', 'Task', +__all__ = ['coroutine', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', 'gather', @@ -50,32 +50,12 @@ def iscoroutine(obj): return inspect.isgenerator(obj) # TODO: And what? -def task(func): - """Decorator for a coroutine to be wrapped in a Task.""" - if inspect.isgeneratorfunction(func): - coro = func - else: - def coro(*args, **kw): - res = func(*args, **kw) - if isinstance(res, futures.Future) or inspect.isgenerator(res): - res = yield from res - return res - - def task_wrapper(*args, **kwds): - return Task(coro(*args, **kwds)) - - return task_wrapper - - -_marker = object() - - class Task(futures.Future): """A coroutine wrapped in a Future.""" - def __init__(self, coro, *, loop=None, timeout=None): + def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. - super().__init__(loop=loop, timeout=timeout) + super().__init__(loop=loop) self._coro = coro self._fut_waiter = None self._must_cancel = False @@ -94,36 +74,28 @@ def __repr__(self): return res def cancel(self): - if self.done() or self._must_cancel: + if self.done(): return False - self._must_cancel = True - # _step() will call super().cancel() to call the callbacks. if self._fut_waiter is not None: - return self._fut_waiter.cancel() - else: - self._loop.call_soon(self._step_maybe) - return True - - def cancelled(self): - return self._must_cancel or super().cancelled() - - def _step_maybe(self): - # Helper for cancel(). - if not self.done(): - return self._step() + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True - def _step(self, value=_marker, exc=None): + def _step(self, value=None, exc=None): assert not self.done(), \ '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) - - # We'll call either coro.throw(exc) or coro.send(value). - # Task cancel has to be delayed if current waiter future is done. - if self._must_cancel and exc is None and value is _marker: - exc = futures.CancelledError - + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False coro = self._coro - value = None if value is _marker else value self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). try: if exc is not None: result = coro.throw(exc) @@ -132,59 +104,54 @@ def _step(self, value=_marker, exc=None): else: result = next(coro) except StopIteration as exc: - if self._must_cancel: - super().cancel() - else: - self.set_result(exc.value) + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). except Exception as exc: - if self._must_cancel: - super().cancel() - else: - self.set_exception(exc) + self.set_exception(exc) except BaseException as exc: - if self._must_cancel: - super().cancel() - else: - self.set_exception(exc) + self.set_exception(exc) raise else: if isinstance(result, futures.Future): - if not result._blocking: - result.set_exception( + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, RuntimeError( 'yield was used instead of yield from ' 'in task {!r} with {!r}'.format(self, result))) - - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result - - # task cancellation has been delayed. - if self._must_cancel: - self._fut_waiter.cancel() - + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) else: - if inspect.isgenerator(result): - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) - else: - if result is not None: - self._loop.call_soon( - self._step, None, - RuntimeError( - 'Task got bad yield: {!r}'.format(result))) - else: - self._loop.call_soon(self._step_maybe) + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) self = None def _wakeup(self, future): try: value = future.result() except Exception as exc: + # This may also be a cancellation. self._step(None, exc) else: self._step(value, None) @@ -228,11 +195,11 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): @coroutine def wait_for(fut, timeout, *, loop=None): - """Wait for the single Future or coroutine to complete. + """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. - Returns result of the Future or coroutine. Raises TimeoutError when + Returns result of the Future or coroutine. Raises TimeoutError when timeout occurs. Usage: @@ -260,7 +227,10 @@ def _wait(fs, timeout, return_when, loop): The timeout argument is like for wait(). """ assert fs, 'Set of Futures is empty.' - waiter = futures.Future(loop=loop, timeout=timeout) + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, waiter.cancel) counter = len(fs) def _on_completion(f): @@ -270,6 +240,8 @@ def _on_completion(f): return_when == FIRST_COMPLETED or return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() waiter.cancel() for f in fs: @@ -342,20 +314,17 @@ def sleep(delay, result=None, *, loop=None): h.cancel() -def async(coro_or_future, *, loop=None, timeout=None): +def async(coro_or_future, *, loop=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. """ if isinstance(coro_or_future, futures.Future): - if ((loop is not None and loop is not coro_or_future._loop) or - (timeout is not None and timeout != coro_or_future._timeout)): - raise ValueError( - 'loop and timeout arguments must agree with Future') - + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') return coro_or_future elif iscoroutine(coro_or_future): - return Task(coro_or_future, loop=loop, timeout=timeout) + return Task(coro_or_future, loop=loop) else: raise TypeError('A Future or coroutine is required') From 1c4a2fa5240879b29516d88b53cbb8fb8360035c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 07:51:10 -0700 Subject: [PATCH 0620/1502] Add comment about Task invariants. --- tulip/tasks.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tulip/tasks.py b/tulip/tasks.py index ca2ef0c0..b3fc8c69 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -53,6 +53,16 @@ def iscoroutine(obj): class Task(futures.Future): """A coroutine wrapped in a Future.""" + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__(loop=loop) From 46b723669cabea8d425295da27d526fd998d76b0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 09:25:43 -0700 Subject: [PATCH 0621/1502] Three more thorough tests for cancellation (possibly redundant). --- tests/tasks_test.py | 88 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index e0611d23..808f368a 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -1108,6 +1108,94 @@ def coro(): self.assertEqual(res, 'test') self.assertIsNone(t2.result()) + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except futures.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @tasks.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except futures.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_shields_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + d, p = yield from tasks.wait([inner()], loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 101) + + def test_yield_gather_blocks_cancel(self): + # Cancelling outer() cancels gather() but not inner(). + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + yield from tasks.gather(inner(), loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + class GatherTestsBase: From 16fed06fd0ab0360d2d6d7612fd1318f9dfc9e6c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 09:49:37 -0700 Subject: [PATCH 0622/1502] Remove spurious blank line. --- tulip/tasks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index b3fc8c69..ee6d8087 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -61,7 +61,6 @@ class Task(futures.Future): # The only transition from the latter to the former is through # _wakeup(). When _fut_waiter is not None, one of its callbacks # must be _wakeup(). - def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. From 85596cdf4ad6973291b538a4b09c5b4511d2f0b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 11:11:22 -0700 Subject: [PATCH 0623/1502] Change start_serving() into a coroutine. No tests break. --- tulip/base_events.py | 27 +-------------------------- tulip/events.py | 6 +++--- 2 files changed, 4 insertions(+), 29 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index 3bccfc83..6f77d93d 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -382,14 +382,7 @@ def create_datagram_endpoint(self, protocol_factory, sock, protocol, r_addr, extra={'addr': l_addr}) return transport, protocol - # This returns a Task made from self._start_serving_internal(). - # We want start_serving() to return a Task so that it will start - # running right away (when the event loop runs) even if the caller - # doesn't wait for it. Note that this is different from - # e.g. create_connection(), or create_datagram_endpoint(), which - # are a "mere" coroutines and require their caller to wait for - # them. The reason for the difference is that only - # start_serving() creates multiple transports and protocols. + @tasks.coroutine def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, @@ -398,24 +391,6 @@ def start_serving(self, protocol_factory, host=None, port=None, backlog=100, ssl=None, reuse_address=None): - coro = self._start_serving_internal(protocol_factory, host, port, - family=family, - flags=flags, - sock=sock, - backlog=backlog, - ssl=ssl, - reuse_address=reuse_address) - return tasks.Task(coro, loop=self) - - @tasks.coroutine - def _start_serving_internal(self, protocol_factory, host=None, port=None, - *, - family=socket.AF_UNSPEC, - flags=socket.AI_PASSIVE, - sock=None, - backlog=100, - ssl=None, - reuse_address=None): """XXX""" if host is not None or port is not None: if sock is not None: diff --git a/tulip/events.py b/tulip/events.py index 7db2514d..9e715a17 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -169,9 +169,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, def start_serving(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None): - """Creates a TCP server bound to host and port and return a - Task whose result will be a list of socket objects which will - later be handled by protocol_factory. + """A coroutine which creates a TCP server bound to host and + port and whose result will be a list of socket objects which + will later be handled by protocol_factory. If host is an empty string or None all interfaces are assumed and a list of multiple sockets will be returned (most likely From 41d24c92b3f85cbf1831d6498cbaaa687c941e4a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 9 Sep 2013 16:35:38 -0700 Subject: [PATCH 0624/1502] Fix test failures on Windows. --- tests/windows_events_test.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index ce9b74da..b75eebbf 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -5,6 +5,7 @@ from tulip import windows_events from tulip import protocols from tulip import streams +from tulip import test_utils def connect_read_pipe(loop, file): @@ -45,18 +46,17 @@ def test_pause_resume_discard(self): f = tulip.async(reader.readline(), loop=self.loop) trans.write(b'msg1\n') - self.loop.run_until_complete(f, timeout=0.01) + self.loop.run_until_complete(f) self.assertEqual(f.result(), b'msg1\n') f = tulip.async(reader.readline(), loop=self.loop) trans.pause_writing() trans.write(b'msg2\n') - with self.assertRaises(tulip.TimeoutError): - self.loop.run_until_complete(f, timeout=0.01) + test_utils.run_briefly(self.loop) self.assertEqual(trans._buffer, [b'msg2\n']) trans.resume_writing() - self.loop.run_until_complete(f, timeout=0.1) + self.loop.run_until_complete(f) self.assertEqual(f.result(), b'msg2\n') f = tulip.async(reader.readline(), loop=self.loop) @@ -69,7 +69,7 @@ def test_pause_resume_discard(self): trans.write(b'msg4\n') self.assertEqual(trans._buffer, [b'msg4\n']) trans.resume_writing() - self.loop.run_until_complete(f, timeout=0.01) + self.loop.run_until_complete(f) self.assertEqual(f.result(), b'msg4\n') def test_close(self): @@ -77,5 +77,5 @@ def test_close(self): trans = self.loop._make_socket_transport(a, protocols.Protocol()) f = tulip.async(self.loop.sock_recv(b, 100)) trans.close() - self.loop.run_until_complete(f, timeout=1) + self.loop.run_until_complete(f) self.assertEqual(f.result(), b'') From 1c584f1dbb5540b82c1fcc67643413145e928fd7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 10 Sep 2013 14:55:47 -0700 Subject: [PATCH 0625/1502] Fix test reported by Sa?l Ibarra Corretg?. Tweak _wait() docsstring. --- tests/events_test.py | 1 + tulip/tasks.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 91cd5db7..1db35d14 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -318,6 +318,7 @@ def reader(): self.loop.call_soon(w.send, b'abc') test_utils.run_briefly(self.loop) self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) self.loop.call_soon(w.close) self.loop.call_soon(self.loop.stop) self.loop.run_forever() diff --git a/tulip/tasks.py b/tulip/tasks.py index ee6d8087..4d00c044 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -230,10 +230,9 @@ def wait_for(fut, timeout, *, loop=None): @coroutine def _wait(fs, timeout, return_when, loop): - """Internal helper for wait(return_when=FIRST_COMPLETED). + """Internal helper for wait() and _wait_for(). - The fs argument must be a set of Futures. - The timeout argument is like for wait(). + The fs argument must be a collection of Futures. """ assert fs, 'Set of Futures is empty.' waiter = futures.Future(loop=loop) From 673a5de13a18098c630bc98e5c59ea12b190b683 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Fri, 13 Sep 2013 15:40:45 -0700 Subject: [PATCH 0626/1502] raise EofStream in DataBuffer.read() instead of returning None; examples cleanup --- examples/child_process.py | 2 +- examples/crawl.py | 14 +++++------ examples/mpsrv.py | 34 +++++++++++++++------------ examples/srv.py | 4 ++-- examples/tcp_protocol_parser.py | 20 ++++++++++------ examples/wsclient.py | 9 +++++--- examples/wssrv.py | 41 ++++++++++++++++++--------------- tests/http_client_test.py | 9 ++++++++ tests/parsers_test.py | 17 ++++++++++---- tulip/http/client.py | 35 +++++++++++++++++----------- tulip/http/wsgi.py | 9 ++++---- tulip/locks.py | 1 - tulip/parsers.py | 2 +- tulip/test_utils.py | 9 ++++---- 14 files changed, 125 insertions(+), 81 deletions(-) diff --git a/examples/child_process.py b/examples/child_process.py index d4a035bd..5a88faa6 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -52,7 +52,7 @@ def factory(): # Example # -@tulip.task +@tulip.coroutine def main(loop): # program which prints evaluation of each expression from stdin code = r'''if 1: diff --git a/examples/crawl.py b/examples/crawl.py index ac9c25e9..f7d53feb 100755 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -24,9 +24,9 @@ def __init__(self, rooturl, loop, maxtasks=100): # session stores cookies between requests and uses connection pool self.session = tulip.http.Session() - @tulip.task + @tulip.coroutine def run(self): - self.addurls([(self.rooturl, '')]) # Set initial work. + tulip.Task(self.addurls([(self.rooturl, '')])) # Set initial work. yield from tulip.sleep(1) while self.busy: yield from tulip.sleep(1) @@ -34,7 +34,7 @@ def run(self): self.session.close() self.loop.stop() - @tulip.task + @tulip.coroutine def addurls(self, urls): for url, parenturl in urls: url = urllib.parse.urljoin(parenturl, url) @@ -45,12 +45,12 @@ def addurls(self, urls): url not in self.todo): self.todo.add(url) yield from self.sem.acquire() - task = self.process(url) + task = tulip.Task(self.process(url)) task.add_done_callback(lambda t: self.sem.release()) task.add_done_callback(self.tasks.remove) self.tasks.add(task) - @tulip.task + @tulip.coroutine def process(self, url): print('processing:', url) @@ -66,7 +66,7 @@ def process(self, url): if resp.status == 200 and resp.get_content_type() == 'text/html': data = (yield from resp.read()).decode('utf-8', 'replace') urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) - self.addurls([(u, url) for u in urls]) + tulip.Task(self.addurls([(u, url) for u in urls])) resp.close() self.done[url] = True @@ -80,7 +80,7 @@ def main(): loop = tulip.get_event_loop() c = Crawler(sys.argv[1], loop) - c.run() + tulip.Task(c.run()) try: loop.add_signal_handler(signal.SIGINT, loop.stop) diff --git a/examples/mpsrv.py b/examples/mpsrv.py index 6b1ebb8f..c594f5bc 100755 --- a/examples/mpsrv.py +++ b/examples/mpsrv.py @@ -42,7 +42,7 @@ def handle_request(self, message, payload): isdir = os.path.isdir(path) if not path: - raise tulip.http.HttpStatusException(404) + raise tulip.http.HttpErrorException(404) headers = email.message.Message() for hdr, val in message.headers: @@ -50,7 +50,7 @@ def handle_request(self, message, payload): if isdir and not path.endswith('/'): path = path + '/' - raise tulip.http.HttpStatusException( + raise tulip.http.HttpErrorException( 302, headers=(('URI', path), ('Location', path))) response = tulip.http.Response(self.transport, 200) @@ -129,12 +129,12 @@ def stop(): os.getpid(), x.getsockname())) # heartbeat - self.heartbeat() + tulip.Task(self.heartbeat()) tulip.get_event_loop().run_forever() os._exit(0) - @tulip.task + @tulip.coroutine def heartbeat(self): # setup pipes read_transport, read_proto = yield from self.loop.connect_read_pipe( @@ -146,12 +146,14 @@ def heartbeat(self): writer = websocket.WebSocketWriter(write_transport) while True: - msg = yield from reader.read() - if msg is None: + try: + msg = yield from reader.read() + except tulip.EofStream: print('Superviser is dead, {} stopping...'.format(os.getpid())) self.loop.stop() break - elif msg.tp == websocket.MSG_PING: + + if msg.tp == websocket.MSG_PING: writer.pong() elif msg.tp == websocket.MSG_CLOSE: break @@ -196,7 +198,7 @@ def start(self): process = ChildProcess(up_read, down_write, args, sock) process.start() - @tulip.task + @tulip.coroutine def heartbeat(self, writer): while True: yield from tulip.sleep(15) @@ -210,20 +212,22 @@ def heartbeat(self, writer): self.start() return - @tulip.task + @tulip.coroutine def chat(self, reader): while True: - msg = yield from reader.read() - if msg is None: + try: + msg = yield from reader.read() + except tulip.EofStream: print('Restart unresponsive worker process: {}'.format( self.pid)) self.kill() self.start() return - elif msg.tp == websocket.MSG_PONG: + + if msg.tp == websocket.MSG_PONG: self.ping = time.monotonic() - @tulip.task + @tulip.coroutine def connect(self, pid, up_write, down_read): # setup pipes read_transport, proto = yield from self.loop.connect_read_pipe( @@ -240,8 +244,8 @@ def connect(self, pid, up_write, down_read): self.ping = time.monotonic() self.rtransport = read_transport self.wtransport = write_transport - self.chat_task = self.chat(reader) - self.heartbeat_task = self.heartbeat(writer) + self.chat_task = tulip.Task(self.chat(reader)) + self.heartbeat_task = tulip.Task(self.heartbeat(writer)) def kill(self): self._started = False diff --git a/examples/srv.py b/examples/srv.py index e01e407c..e4bf16c1 100755 --- a/examples/srv.py +++ b/examples/srv.py @@ -38,7 +38,7 @@ def handle_request(self, message, payload): isdir = os.path.isdir(path) if not path: - raise tulip.http.HttpStatusException(404) + raise tulip.http.HttpErrorException(404) headers = email.message.Message() for hdr, val in message.headers: @@ -47,7 +47,7 @@ def handle_request(self, message, payload): if isdir and not path.endswith('/'): path = path + '/' - raise tulip.http.HttpStatusException( + raise tulip.http.HttpErrorException( 302, headers=(('URI', path), ('Location', path))) response = tulip.http.Response(self.transport, 200) diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py index a0258613..e4fc59ad 100755 --- a/examples/tcp_protocol_parser.py +++ b/examples/tcp_protocol_parser.py @@ -70,7 +70,7 @@ def connection_made(self, transport): print('Connection made') self.transport = transport self.stream = tulip.StreamBuffer() - self.dispatch() + tulip.Task(self.dispatch()) def data_received(self, data): self.stream.feed_data(data) @@ -81,15 +81,17 @@ def eof_received(self): def connection_lost(self, exc): print('Connection lost') - @tulip.task + @tulip.coroutine def dispatch(self): reader = self.stream.set_parser(my_protocol_parser()) writer = MyProtocolWriter(self.transport) while True: - msg = yield from reader.read() - if msg is None: - break # client has been disconnected + try: + msg = yield from reader.read() + except tulip.EofStream: + # client has been disconnected + break print('Message received: {}'.format(msg)) @@ -102,7 +104,7 @@ def dispatch(self): break -@tulip.task +@tulip.coroutine def start_client(loop, host, port): transport, stream = yield from loop.create_connection( tulip.StreamProtocol, host, port) @@ -113,7 +115,11 @@ def start_client(loop, host, port): message = 'This is the message. It will be echoed.' while True: - msg = yield from reader.read() + try: + msg = yield from reader.read() + except tulip.EofStream: + print('Server has been disconnected.') + break print('Message received: {}'.format(msg)) if msg.tp == MSG_PONG: diff --git a/examples/wsclient.py b/examples/wsclient.py index f5b2ef58..ed7beda5 100755 --- a/examples/wsclient.py +++ b/examples/wsclient.py @@ -59,10 +59,13 @@ def stdin_callback(): @tulip.coroutine def dispatch(): while True: - msg = yield from stream.read() - if msg is None: + try: + msg = yield from stream.read() + except tulip.EofStream: + # server disconnected break - elif msg.tp == websocket.MSG_PING: + + if msg.tp == websocket.MSG_PING: writer.pong() elif msg.tp == websocket.MSG_TEXT: print(msg.data.strip()) diff --git a/examples/wssrv.py b/examples/wssrv.py index f96e0855..8a02a2dd 100755 --- a/examples/wssrv.py +++ b/examples/wssrv.py @@ -59,8 +59,10 @@ def handle_request(self, message, payload): # chat dispatcher while True: - msg = yield from databuffer.read() - if msg is None: # client droped connection + try: + msg = yield from databuffer.read() + except tulip.EofStream: + # client droped connection break if msg.tp == websocket.MSG_PING: @@ -126,12 +128,12 @@ def stop(): loop.add_signal_handler(signal.SIGINT, stop) # heartbeat - self.heartbeat() + tulip.Task(self.heartbeat()) tulip.get_event_loop().run_forever() os._exit(0) - @tulip.task + @tulip.coroutine def start_server(self, writer): socks = yield from self.loop.start_serving( lambda: HttpServer( @@ -141,7 +143,7 @@ def start_server(self, writer): print('Starting srv worker process {} on {}'.format( os.getpid(), socks[0].getsockname())) - @tulip.task + @tulip.coroutine def heartbeat(self): # setup pipes read_transport, read_proto = yield from self.loop.connect_read_pipe( @@ -152,15 +154,17 @@ def heartbeat(self): reader = read_proto.set_parser(websocket.WebSocketParser()) writer = websocket.WebSocketWriter(write_transport) - self.start_server(writer) + tulip.Task(self.start_server(writer)) while True: - msg = yield from reader.read() - if msg is None: + try: + msg = yield from reader.read() + except tulip.EofStream: print('Superviser is dead, {} stopping...'.format(os.getpid())) self.loop.stop() break - elif msg.tp == websocket.MSG_PING: + + if msg.tp == websocket.MSG_PING: writer.pong() elif msg.tp == websocket.MSG_CLOSE: break @@ -196,7 +200,7 @@ def start(self): # parent os.close(up_read) os.close(down_write) - self.connect(pid, up_write, down_read) + tulip.async(self.connect(pid, up_write, down_read)) else: # child os.close(up_write) @@ -209,7 +213,7 @@ def start(self): process = ChildProcess(up_read, down_write, args, sock) process.start() - @tulip.task + @tulip.coroutine def heartbeat(self, writer): while True: yield from tulip.sleep(15) @@ -223,18 +227,19 @@ def heartbeat(self, writer): self.start() return - @tulip.task + @tulip.coroutine def chat(self, reader): while True: - msg = yield from reader.read() - if msg is None: + try: + msg = yield from reader.read() + except tulip.EofStream: print('Restart unresponsive worker process: {}'.format( self.pid)) self.kill() self.start() return - elif msg.tp == websocket.MSG_PONG: + if msg.tp == websocket.MSG_PONG: self.ping = time.monotonic() elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers @@ -242,7 +247,7 @@ def chat(self, reader): if self.pid != worker.pid: worker.writer.send(msg.data) - @tulip.task + @tulip.coroutine def connect(self, pid, up_write, down_read): # setup pipes read_transport, proto = yield from self.loop.connect_read_pipe( @@ -260,8 +265,8 @@ def connect(self, pid, up_write, down_read): self.writer = writer self.rtransport = read_transport self.wtransport = write_transport - self.chat_task = self.chat(reader) - self.heartbeat_task = self.heartbeat(writer) + self.chat_task = tulip.async(self.chat(reader)) + self.heartbeat_task = tulip.async(self.heartbeat(writer)) def kill(self): self._started = False diff --git a/tests/http_client_test.py b/tests/http_client_test.py index 1aa27244..a911c975 100644 --- a/tests/http_client_test.py +++ b/tests/http_client_test.py @@ -226,6 +226,15 @@ def test_get_with_data(self): req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) self.assertEqual('/?life=42', req.path) + def test_bytes_data(self): + for meth in HttpRequest.POST_METHODS: + req = HttpRequest(meth, 'http://python.org/', data=b'binary data') + req.send(self.transport) + self.assertEqual('/', req.path) + self.assertEqual((b'binary data',), req.body) + self.assertEqual('application/octet-stream', + req.headers['content-type']) + @unittest.mock.patch('tulip.http.client.tulip') def test_content_encoding(self, m_tulip): req = HttpRequest('get', 'http://python.org/', compress='deflate') diff --git a/tests/parsers_test.py b/tests/parsers_test.py index debc532c..c6b7cec2 100644 --- a/tests/parsers_test.py +++ b/tests/parsers_test.py @@ -355,8 +355,8 @@ def cb(): buffer.feed_eof() self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertIsNone(data) + self.assertRaises( + parsers.EofStream, self.loop.run_until_complete, read_task) def test_read_until_eof(self): item = object() @@ -367,8 +367,8 @@ def test_read_until_eof(self): data = self.loop.run_until_complete(buffer.read()) self.assertIs(data, item) - data = self.loop.run_until_complete(buffer.read()) - self.assertIsNone(data) + self.assertRaises( + parsers.EofStream, self.loop.run_until_complete, buffer.read()) def test_read_exception(self): buffer = parsers.DataBuffer(loop=self.loop) @@ -428,7 +428,14 @@ def test_connection_lost_exc(self): self.assertIs(proto.exception(), exc) -class ParserBuffer(unittest.TestCase): +class ParserBufferTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() def _make_one(self): return parsers.ParserBuffer() diff --git a/tulip/http/client.py b/tulip/http/client.py index ec7cd034..a28fdc21 100644 --- a/tulip/http/client.py +++ b/tulip/http/client.py @@ -76,7 +76,7 @@ def request(method, url, *, >> resp - >> data = yield from resp.content.read() + >> data = yield from resp.read() """ redirects = 0 @@ -299,13 +299,20 @@ def __init__(self, method, url, *, data = list(data.items()) if data and not files: - if not isinstance(data, str): - data = urllib.parse.urlencode(data, doseq=True) + if isinstance(data, (bytes, bytearray)): + self.body = data + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/octet-stream') + else: + if not isinstance(data, str): + data = urllib.parse.urlencode(data, doseq=True) + + self.body = data.encode(encoding) + if 'content-type' not in self.headers: + self.headers['content-type'] = ( + 'application/x-www-form-urlencoded') - self.body = data.encode(encoding) - if 'content-type' not in self.headers: - self.headers['content-type'] = ( - 'application/x-www-form-urlencoded') if 'content-length' not in self.headers and not chunked: self.headers['content-length'] = str(len(self.body)) @@ -486,12 +493,14 @@ def read(self, decode=False): if self._content is None: buf = [] total = 0 - chunk = yield from self.content.read() - while chunk: - size = len(chunk) - buf.append((chunk, size)) - total += size - chunk = yield from self.content.read() + try: + while True: + chunk = yield from self.content.read() + size = len(chunk) + buf.append((chunk, size)) + total += size + except tulip.EofStream: + pass self._content = bytearray(total) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py index 738e100f..02611f78 100644 --- a/tulip/http/wsgi.py +++ b/tulip/http/wsgi.py @@ -145,10 +145,11 @@ def handle_request(self, message, payload): if self.readpayload: wsgiinput = io.BytesIO() - chunk = yield from payload.read() - while chunk: - wsgiinput.write(chunk) - chunk = yield from payload.read() + try: + while True: + wsgiinput.write((yield from payload.read())) + except tulip.EofStream: + pass wsgiinput.seek(0) payload = wsgiinput diff --git a/tulip/locks.py b/tulip/locks.py index 8f1a1e9a..06edbbc1 100644 --- a/tulip/locks.py +++ b/tulip/locks.py @@ -368,7 +368,6 @@ def acquire(self): finally: self._waiters.remove(fut) - def release(self): """Release a semaphore, incrementing the internal counter by one. When it was zero on entry and another coroutine is waiting for it to diff --git a/tulip/parsers.py b/tulip/parsers.py index 43ddc2e9..8ac05e18 100644 --- a/tulip/parsers.py +++ b/tulip/parsers.py @@ -254,7 +254,7 @@ def read(self): if self._buffer: return self._buffer.popleft() else: - return None + raise EofStream class ParserBuffer(bytearray): diff --git a/tulip/test_utils.py b/tulip/test_utils.py index b4af0c89..cf04f21a 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -98,10 +98,11 @@ def handle_request(self, message, payload): if router is not None: body = bytearray() - chunk = yield from payload.read() - while chunk: - body.extend(chunk) - chunk = yield from payload.read() + try: + while True: + body.extend((yield from payload.read())) + except tulip.EofStream: + pass rob = router( self, properties, From de870f9e2b3bef2aa129e6c4afe28b63a7ba572a Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 19 Sep 2013 21:14:54 -0700 Subject: [PATCH 0627/1502] fix deadlock in _SelectorSocketTransport._write_ready --- tests/selector_events_test.py | 12 ++++++++++++ tulip/selector_events.py | 9 +++++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 06318622..352407c4 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -932,6 +932,18 @@ def test_write_ready_exception(self): transport._write_ready() transport._fatal_error.assert_called_with(err) + @unittest.mock.patch('tulip.selector_events.tulip_log') + def test_write_ready_exception_and_close(self, m_log): + self.sock.send.side_effect = OSError() + remove_writer = self.loop.remove_writer = unittest.mock.Mock() + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close() + transport._buffer.append(b'data') + transport._write_ready() + remove_writer.assert_called_with(self.sock_fd) + def test_pause_writing(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 82d22bb6..5503f8da 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -352,13 +352,16 @@ def _fatal_error(self, exc): self._force_close(exc) def _force_close(self, exc): + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + if self._closing: return + self._closing = True self._conn_lost += 1 - self._loop.remove_writer(self._sock_fd) self._loop.remove_reader(self._sock_fd) - self._buffer.clear() self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): @@ -441,6 +444,7 @@ def _write_ready(self): except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: + self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) else: if n == len(data): @@ -555,6 +559,7 @@ def _on_ready(self): ssl.SSLWantReadError, ssl.SSLWantWriteError): n = 0 except Exception as exc: + self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) return From c13410bf53b95b6d024c8106f269c425b19193b8 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 19 Sep 2013 21:20:42 -0700 Subject: [PATCH 0628/1502] pep8 related fixes --- check.py | 6 +++++- runtests.py | 1 - tests/queues_test.py | 16 ++++++++-------- tests/selector_events_test.py | 35 ++++++++++++++++------------------- tests/tasks_test.py | 4 +--- tests/transports_test.py | 1 - tests/unix_events_test.py | 2 -- tulip/base_events.py | 4 ++-- tulip/test_utils.py | 2 +- tulip/windows_utils.py | 2 +- 10 files changed, 34 insertions(+), 39 deletions(-) diff --git a/check.py b/check.py index 9ab6bcc0..6db82d64 100644 --- a/check.py +++ b/check.py @@ -1,6 +1,8 @@ """Search for lines >= 80 chars or with trailing whitespace.""" -import sys, os +import os +import sys + def main(): args = sys.argv[1:] or os.curdir @@ -15,6 +17,7 @@ def main(): else: process(arg) + def isascii(x): try: x.encode('ascii') @@ -22,6 +25,7 @@ def isascii(x): except UnicodeError: return False + def process(fn): try: f = open(fn) diff --git a/runtests.py b/runtests.py index 725bfa2e..62f55a4f 100644 --- a/runtests.py +++ b/runtests.py @@ -25,7 +25,6 @@ import os import re import sys -import subprocess import unittest import textwrap import importlib.machinery diff --git a/tests/queues_test.py b/tests/queues_test.py index 7241ffdc..437a1c30 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -62,7 +62,7 @@ def add_putter(): q = queues.Queue(maxsize=1, loop=loop) q.put_nowait(1) # Start a task that waits to put. - t = tasks.Task(q.put(2), loop=loop) + tasks.Task(q.put(2), loop=loop) # Let it start waiting. yield from tasks.sleep(0.1, loop=loop) self.assertTrue('_putters[1]' in fn(q)) @@ -280,8 +280,8 @@ def test_get_cancelled_race(self): def test_get_with_waiting_putters(self): q = queues.Queue(loop=self.loop, maxsize=1) - t1 = tasks.Task(q.put('a'), loop=self.loop) - t2 = tasks.Task(q.put('b'), loop=self.loop) + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('b'), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual(self.loop.run_until_complete(q.get()), 'a') self.assertEqual(self.loop.run_until_complete(q.get()), 'b') @@ -363,14 +363,14 @@ def test(): def test_put_cancelled_race(self): q = queues.Queue(loop=self.loop, maxsize=1) - t1 = tasks.Task(q.put('a'), loop=self.loop) - t2 = tasks.Task(q.put('b'), loop=self.loop) - t3 = tasks.Task(q.put('c'), loop=self.loop) + tasks.Task(q.put('a'), loop=self.loop) + tasks.Task(q.put('c'), loop=self.loop) + t = tasks.Task(q.put('b'), loop=self.loop) test_utils.run_briefly(self.loop) - t2.cancel() + t.cancel() test_utils.run_briefly(self.loop) - self.assertTrue(t2.done()) + self.assertTrue(t.done()) self.assertEqual(q.get_nowait(), 'a') self.assertEqual(q.get_nowait(), 'c') diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 352407c4..9596e928 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -15,7 +15,6 @@ from tulip import futures from tulip import selectors from tulip import test_utils -from tulip.events import AbstractEventLoop from tulip.protocols import DatagramProtocol, Protocol from tulip.selector_events import BaseSelectorEventLoop from tulip.selector_events import _SelectorTransport @@ -525,10 +524,10 @@ def test_process_events_read(self): reader._cancelled = False self.loop._add_callback = unittest.mock.Mock() - self.loop._process_events([ - (selectors.SelectorKey(1, 1, selectors.EVENT_READ, (reader, None)), - selectors.EVENT_READ), - ]) + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) self.assertTrue(self.loop._add_callback.called) self.loop._add_callback.assert_called_with(reader) @@ -537,10 +536,10 @@ def test_process_events_read_cancelled(self): reader.cancelled = True self.loop.remove_reader = unittest.mock.Mock() - self.loop._process_events([ - (selectors.SelectorKey(1, 1, selectors.EVENT_READ, (reader, None)), - selectors.EVENT_READ), - ]) + self.loop._process_events( + [(selectors.SelectorKey( + 1, 1, selectors.EVENT_READ, (reader, None)), + selectors.EVENT_READ)]) self.loop.remove_reader.assert_called_with(1) def test_process_events_write(self): @@ -548,11 +547,10 @@ def test_process_events_write(self): writer._cancelled = False self.loop._add_callback = unittest.mock.Mock() - self.loop._process_events([ - (selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, - (None, writer)), - selectors.EVENT_WRITE), - ]) + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) self.loop._add_callback.assert_called_with(writer) def test_process_events_write_cancelled(self): @@ -560,11 +558,10 @@ def test_process_events_write_cancelled(self): writer.cancelled = True self.loop.remove_writer = unittest.mock.Mock() - self.loop._process_events([ - (selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, - (None, writer)), - selectors.EVENT_WRITE), - ]) + self.loop._process_events( + [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, + (None, writer)), + selectors.EVENT_WRITE)]) self.loop.remove_writer.assert_called_with(1) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 808f368a..6cb7b291 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -794,8 +794,6 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - completed = set() - time_shifted = False a = tasks.sleep(0.05, 'a', loop=loop) b = tasks.sleep(0.10, 'b', loop=loop) @@ -1059,7 +1057,7 @@ def wait_for_future(): yield fut task = wait_for_future() - with self.assertRaises(RuntimeError) as cm: + with self.assertRaises(RuntimeError): self.loop.run_until_complete(task) self.assertFalse(fut.done()) diff --git a/tests/transports_test.py b/tests/transports_test.py index 5920cda6..d2688c3a 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -3,7 +3,6 @@ import unittest import unittest.mock -from tulip import futures from tulip import transports diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index f0b42a39..b78a879e 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -7,7 +7,6 @@ import signal import stat import sys -import tempfile import unittest import unittest.mock @@ -648,7 +647,6 @@ def test__write_ready_err(self, m_write, m_logexc): test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(err) - @unittest.mock.patch('os.write') def test__write_ready_closing(self, m_write): tr = unix_events._UnixWritePipeTransport( diff --git a/tulip/base_events.py b/tulip/base_events.py index 6f77d93d..5157b5b0 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -430,8 +430,8 @@ def start_serving(self, protocol_factory, host=None, port=None, sock.bind(sa) except OSError as err: raise OSError(err.errno, 'error while attempting ' - 'to bind on address %r: %s' - % (sa, err.strerror.lower())) + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) completed = True finally: if not completed: diff --git a/tulip/test_utils.py b/tulip/test_utils.py index cf04f21a..e73a1d7b 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -351,9 +351,9 @@ def __init__(self, gen=None): super().__init__() if gen is None: - self._check_on_close = False def gen(): yield + self._check_on_close = False else: self._check_on_close = True diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py index bf85f31e..af9d1418 100644 --- a/tulip/windows_utils.py +++ b/tulip/windows_utils.py @@ -24,7 +24,7 @@ BUFSIZE = 8192 PIPE = subprocess.PIPE -_mmap_counter=itertools.count() +_mmap_counter = itertools.count() # # Replacement for socket.socketpair() From 2fa075e44131d085508bf60a107fa76e0148d571 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Fri, 20 Sep 2013 21:29:20 +0300 Subject: [PATCH 0629/1502] Fix runtests.py --forever for python 3.4. Thanks to Antoine Pitrou. --- runtests.py | 68 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 43 insertions(+), 25 deletions(-) diff --git a/runtests.py b/runtests.py index 62f55a4f..287d0367 100644 --- a/runtests.py +++ b/runtests.py @@ -112,30 +112,46 @@ def list_dir(prefix, dir): return mods -def load_tests(testsdir, includes=(), excludes=()): - mods = [mod for mod, _ in load_modules(testsdir)] - - loader = unittest.TestLoader() - suite = unittest.TestSuite() - - for mod in mods: - for name in set(dir(mod)): - if name.endswith('Tests'): - test_module = getattr(mod, name) - tests = loader.loadTestsFromTestCase(test_module) - if includes: - tests = [test - for test in tests - if any(re.search(pat, test.id()) - for pat in includes)] - if excludes: - tests = [test - for test in tests - if not any(re.search(pat, test.id()) - for pat in excludes)] - suite.addTests(tests) - - return suite +class TestsFinder: + + def __init__(self, testsdir, includes=(), excludes=()): + self._testsdir = testsdir + self._includes = includes + self._excludes = excludes + self.find_available_tests() + + def find_available_tests(self): + """ + Find available test classes without instantiating them. + """ + self._test_factories = [] + mods = [mod for mod, _ in load_modules(self._testsdir)] + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + self._test_factories.append(getattr(mod, name)) + + def load_tests(self): + """ + Load test cases from the available test classes and apply + optional include / exclude filters. + """ + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test_factory in self._test_factories: + tests = loader.loadTestsFromTestCase(test_factory) + if self._includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in self._includes)] + if self._excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in self._excludes)] + suite.addTests(tests) + return suite class TestResult(unittest.TextTestResult): @@ -219,7 +235,7 @@ def runtests(): ) cov.start() - tests = load_tests(args.testsdir, includes, excludes) + finder = TestsFinder(args.testsdir, includes, excludes) logger = logging.getLogger() if v == 0: logger.setLevel(logging.CRITICAL) @@ -236,11 +252,13 @@ def runtests(): try: if args.forever: while True: + tests = finder.load_tests() result = runner_factory(verbosity=v, failfast=failfast).run(tests) if not result.wasSuccessful(): sys.exit(1) else: + tests = finder.load_tests() result = runner_factory(verbosity=v, failfast=failfast).run(tests) sys.exit(not result.wasSuccessful()) From 17b35d08fe70ca8fa505a822d21697c92b15c848 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 23 Sep 2013 10:15:50 -0700 Subject: [PATCH 0630/1502] Add shield(). --- tests/tasks_test.py | 146 ++++++++++++++++++++++++++++++++++++++++---- tulip/tasks.py | 108 +++++++++++++++++++++++++++----- 2 files changed, 229 insertions(+), 25 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 6cb7b291..7aa32a09 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -1168,8 +1168,38 @@ def outer(): test_utils.run_briefly(self.loop) self.assertEqual(proof, 101) - def test_yield_gather_blocks_cancel(self): - # Cancelling outer() cancels gather() but not inner(). + def test_shield_result(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(tasks.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). proof = 0 waiter = futures.Future(loop=self.loop) @@ -1182,7 +1212,7 @@ def inner(): @tasks.coroutine def outer(): nonlocal proof - yield from tasks.gather(inner(), loop=self.loop) + yield from tasks.shield(inner(), loop=self.loop) proof += 100 f = tasks.async(outer(), loop=self.loop) @@ -1194,6 +1224,38 @@ def outer(): test_utils.run_briefly(self.loop) self.assertEqual(proof, 1) + def test_shield_gather(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + parent = tasks.gather(child1, child2, loop=self.loop) + outer = tasks.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + inner1 = tasks.shield(child1, loop=self.loop) + inner2 = tasks.shield(child2, loop=self.loop) + parent = tasks.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + class GatherTestsBase: @@ -1320,7 +1382,8 @@ def test_one_cancellation(self): self._run_loop(self.one_loop) self.assertTrue(fut.done()) cb.assert_called_once_with(fut) - self.assertTrue(fut.cancelled()) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), futures.CancelledError) # Does nothing c.set_result(3) d.cancel() @@ -1333,16 +1396,21 @@ def test_result_exception_one_cancellation(self): cb = Mock() fut.add_done_callback(cb) a.set_result(1) - b.set_exception(ZeroDivisionError()) + zde = ZeroDivisionError() + b.set_exception(zde) c.cancel() self._run_loop(self.one_loop) - self.assertTrue(fut.done()) - cb.assert_called_once_with(fut) - self.assertTrue(fut.cancelled()) - # Does nothing + self.assertFalse(fut.done()) d.set_result(3) e.cancel() - f.set_exception(RuntimeError()) + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], futures.CancelledError) + self.assertIsInstance(res[4], futures.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): @@ -1366,13 +1434,69 @@ def coro(fut=fut): def test_constructor_loop_selection(self): @tasks.coroutine def coro(): - yield from [] return 'abc' fut = tasks.gather(coro(), coro()) self.assertIs(fut._loop, self.one_loop) fut = tasks.gather(coro(), coro(), loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = tasks.async(inner(), loop=self.one_loop) + child2 = tasks.async(inner(), loop=self.one_loop) + gatherer = None + + @tasks.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = tasks.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(futures.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @tasks.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = futures.Future(loop=self.one_loop) + b = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def outer(): + yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + if __name__ == '__main__': unittest.main() diff --git a/tulip/tasks.py b/tulip/tasks.py index 4d00c044..e3814a64 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -337,18 +337,45 @@ def async(coro_or_future, *, loop=None): raise TypeError('A Future or coroutine is required') +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + def gather(*coros_or_futures, loop=None, return_exceptions=False): """Return a future aggregating results from the given coroutines or futures. - All futures must share the same event loop. If all the tasks - are done successfully, the returned future's result is the list of - results (in the order of the original sequence, not necessarily the - order of results arrival). If one of the tasks is cancelled, the - returned future is immediately cancelled too. If *result_exception* - is True, exceptions in the tasks are treated the same as successful - results, and gathered in the result list; otherwise, the first raised - exception will be immediately propagated to the returned future. + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *result_exception* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) """ children = [async(fut, loop=loop) for fut in coros_or_futures] n = len(children) @@ -361,7 +388,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): for fut in children: if fut._loop is not loop: raise ValueError("futures are tied to different event loops") - outer = futures.Future(loop=loop) + outer = _GatheringFuture(children, loop=loop) nfinished = 0 results = [None] * n @@ -369,17 +396,19 @@ def _done_callback(i, fut): nonlocal nfinished if outer._state != futures._PENDING: if fut._exception is not None: - # Be sure to mark the result retrieved + # Mark exception retrieved. fut.exception() return if fut._state == futures._CANCELLED: - outer.cancel() - return + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. if not return_exceptions: - outer.set_exception(fut.exception()) + outer.set_exception(res) return - res = fut.exception() else: res = fut._result results[i] = res @@ -390,3 +419,54 @@ def _done_callback(i, fut): for i, fut in enumerate(children): fut.add_done_callback(functools.partial(_done_callback, i)) return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + # Mark inner's result as retrieved. + inner.cancelled() or inner.exception() + return + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer From e44e7575b3aaf1dc6c4e20d8f388da76dba975a3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 27 Sep 2013 14:58:18 -0700 Subject: [PATCH 0631/1502] StreamReader feed* and set_exc* should check for waiter.cancelled(), not done. --- tests/streams_test.py | 16 ++++++++++++++++ tulip/streams.py | 7 ++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/streams_test.py b/tests/streams_test.py index 2267a0f5..49a43cab 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -338,6 +338,22 @@ def readline(): self.assertRaises(ValueError, t1.result) + def test_exception_cancel(self): + stream = streams.StreamReader(loop=self.loop) + + @tasks.coroutine + def read_a_line(): + yield from stream.readline() + + t = tasks.Task(read_a_line(), loop=self.loop) + test_utils.run_briefly(self.loop) + t.cancel() + test_utils.run_briefly(self.loop) + # The following line fails if set_exception() isn't careful. + stream.set_exception(RuntimeError('message')) + test_utils.run_briefly(self.loop) + self.assertIs(stream.waiter, None) + if __name__ == '__main__': unittest.main() diff --git a/tulip/streams.py b/tulip/streams.py index 3203b7d6..5950da33 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -89,14 +89,15 @@ def set_exception(self, exc): waiter = self.waiter if waiter is not None: self.waiter = None - waiter.set_exception(exc) + if not waiter.cancelled(): + waiter.set_exception(exc) def feed_eof(self): self.eof = True waiter = self.waiter if waiter is not None: self.waiter = None - if not waiter.done(): + if not waiter.cancelled(): waiter.set_result(True) def feed_data(self, data): @@ -109,7 +110,7 @@ def feed_data(self, data): waiter = self.waiter if waiter is not None: self.waiter = None - if not waiter.done(): + if not waiter.cancelled(): waiter.set_result(False) @tasks.coroutine From c5a300691d679b161e6a3b17aca8601741c69c70 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 30 Sep 2013 11:40:01 -0700 Subject: [PATCH 0632/1502] Get rid of pause/resume_writing() and discard_output(). --- tests/selector_events_test.py | 71 ----------------------------------- tests/transports_test.py | 3 -- tests/unix_events_test.py | 67 --------------------------------- tests/windows_events_test.py | 33 ---------------- tulip/proactor_events.py | 20 ++-------- tulip/selector_events.py | 23 +----------- tulip/transports.py | 15 -------- tulip/unix_events.py | 23 +----------- 8 files changed, 5 insertions(+), 250 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 9596e928..1f462c39 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -666,7 +666,6 @@ def test_ctor(self): self.loop.assert_reader(7, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) - self.assertTrue(tr._writing) def test_ctor_with_waiter(self): fut = futures.Future(loop=self.loop) @@ -764,14 +763,6 @@ def test_write_buffer(self): self.assertFalse(self.sock.send.called) self.assertEqual([b'data1', b'data2'], transport._buffer) - def test_write_paused(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._writing = False - transport.write(b'data') - self.assertFalse(self.sock.send.called) - self.assertEqual(transport._buffer, [b'data']) - def test_write_partial(self): data = b'data' self.sock.send.return_value = 2 @@ -854,15 +845,6 @@ def test_write_ready(self): self.assertEqual(self.sock.send.call_args[0], (data,)) self.assertFalse(self.loop.writers) - def test_write_ready_paused(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._writing = False - transport._buffer.append(b'data') - transport._write_ready() - self.assertFalse(self.sock.send.called) - self.assertEqual(transport._buffer, [b'data']) - def test_write_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) @@ -941,59 +923,6 @@ def test_write_ready_exception_and_close(self, m_log): transport._write_ready() remove_writer.assert_called_with(self.sock_fd) - def test_pause_writing(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._buffer.append(b'data') - self.loop.add_writer(self.sock_fd, transport._write_ready) - transport.pause_writing() - self.assertFalse(transport._writing) - self.assertFalse(self.loop.writers) - self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) - - transport.pause_writing() - self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) - - def test_pause_writing_no_buffer(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport.pause_writing() - self.assertFalse(transport._writing) - self.assertEqual(0, self.loop.remove_writer_count[7]) - - def test_resume_writing(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._buffer.append(b'data') - transport.resume_writing() - self.assertFalse(self.loop.writers) - - transport._writing = False - transport.resume_writing() - self.assertTrue(transport._writing) - self.loop.assert_writer(self.sock_fd, transport._write_ready) - - def test_resume_writing_no_buffer(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport._writing = False - transport.resume_writing() - self.assertTrue(transport._writing) - self.assertFalse(self.loop.writers) - - def test_discard_output(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) - transport.discard_output() - self.assertEqual(0, self.loop.remove_writer_count[self.sock_fd]) - - transport._buffer.append(b'data') - self.loop.add_writer(self.sock_fd, transport._write_ready) - transport.discard_output() - self.assertEqual(transport._buffer, []) - self.assertEqual(1, self.loop.remove_writer_count[self.sock_fd]) - self.assertFalse(self.loop.writers) - @unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(unittest.TestCase): diff --git a/tests/transports_test.py b/tests/transports_test.py index d2688c3a..304ec206 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -37,9 +37,6 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, transport.resume) self.assertRaises(NotImplementedError, transport.close) self.assertRaises(NotImplementedError, transport.abort) - self.assertRaises(NotImplementedError, transport.pause_writing) - self.assertRaises(NotImplementedError, transport.resume_writing) - self.assertRaises(NotImplementedError, transport.discard_output) def test_dgram_not_implemented(self): transport = transports.DatagramTransport() diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index b78a879e..9d8d7e54 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -747,70 +747,3 @@ def test_write_eof_pending(self): tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.protocol.connection_lost.called) - - def test_pause_resume_writing(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - tr.pause_writing() - self.assertFalse(tr._writing) - tr.resume_writing() - self.assertTrue(tr._writing) - - def test_double_pause_resume_writing(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - tr.pause_writing() - self.assertFalse(tr._writing) - tr.pause_writing() - self.assertFalse(tr._writing) - tr.resume_writing() - self.assertTrue(tr._writing) - tr.resume_writing() - self.assertTrue(tr._writing) - - def test_pause_resume_writing_with_nonempty_buffer(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - self.loop.add_writer(5, tr._write_ready) - tr._buffer = [b'da', b'ta'] - tr.pause_writing() - self.assertFalse(tr._writing) - self.assertFalse(self.loop.writers) - self.assertEqual([b'da', b'ta'], tr._buffer) - - tr.resume_writing() - self.assertTrue(tr._writing) - self.loop.assert_writer(5, tr._write_ready) - self.assertEqual([b'da', b'ta'], tr._buffer) - - @unittest.mock.patch('os.write') - def test__write_ready_on_pause(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - self.loop.add_writer(5, tr._write_ready) - tr._buffer = [b'da', b'ta'] - tr.pause_writing() - - tr._write_ready() - self.assertFalse(m_write.called) - self.assertFalse(self.loop.writers) - self.assertEqual([b'da', b'ta'], tr._buffer) - self.assertFalse(tr._writing) - - def test_discard_output(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - tr._buffer = [b'da', b'ta'] - self.loop.add_writer(5, tr._write_ready) - tr.discard_output() - self.assertTrue(tr._writing) - self.assertFalse(self.loop.writers) - self.assertEqual([], tr._buffer) - - def test_discard_output_without_pending_writes(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - tr.discard_output() - self.assertTrue(tr._writing) - self.assertFalse(self.loop.writers) - self.assertEqual([], tr._buffer) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index b75eebbf..c365158f 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -39,39 +39,6 @@ def tearDown(self): self.loop.close() self.loop = None - def test_pause_resume_discard(self): - a, b = self.loop._socketpair() - trans = self.loop._make_write_pipe_transport(a, protocols.Protocol()) - reader = connect_read_pipe(self.loop, b) - f = tulip.async(reader.readline(), loop=self.loop) - - trans.write(b'msg1\n') - self.loop.run_until_complete(f) - self.assertEqual(f.result(), b'msg1\n') - f = tulip.async(reader.readline(), loop=self.loop) - - trans.pause_writing() - trans.write(b'msg2\n') - test_utils.run_briefly(self.loop) - self.assertEqual(trans._buffer, [b'msg2\n']) - - trans.resume_writing() - self.loop.run_until_complete(f) - self.assertEqual(f.result(), b'msg2\n') - f = tulip.async(reader.readline(), loop=self.loop) - - trans.pause_writing() - trans.write(b'msg3\n') - self.assertEqual(trans._buffer, [b'msg3\n']) - trans.discard_output() - self.assertEqual(trans._buffer, []) - - trans.write(b'msg4\n') - self.assertEqual(trans._buffer, [b'msg4\n']) - trans.resume_writing() - self.loop.run_until_complete(f) - self.assertEqual(f.result(), b'msg4\n') - def test_close(self): a, b = self.loop._socketpair() trans = self.loop._make_socket_transport(a, protocols.Protocol()) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index cda87918..a83c780a 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -25,7 +25,6 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None): self._buffer = [] self._read_fut = None self._write_fut = None - self._writing_disabled = False self._conn_lost = 0 self._closing = False # Set when close() called. self._loop.call_soon(self._protocol.connection_made, self) @@ -140,7 +139,7 @@ def write(self, data): self._conn_lost += 1 return self._buffer.append(data) - if self._write_fut is None and not self._writing_disabled: + if self._write_fut is None: self._loop_writing() def _loop_writing(self, f=None): @@ -155,9 +154,8 @@ def _loop_writing(self, f=None): if self._closing: self._loop.call_soon(self._call_connection_lost, None) return - if not self._writing_disabled: - self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) except OSError as exc: self._fatal_error(exc) @@ -166,18 +164,6 @@ def _loop_writing(self, f=None): def abort(self): self._force_close(None) - def pause_writing(self): - self._writing_disabled = True - - def resume_writing(self): - self._writing_disabled = False - if self._buffer and self._write_fut is None: - self._loop_writing() - - def discard_output(self): - if self._buffer: - self._buffer = [] - class _ProactorSocketTransport(_ProactorReadPipeTransport, _ProactorWritePipeTransport, diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 5503f8da..3174b279 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -331,7 +331,6 @@ def __init__(self, loop, sock, protocol, extra): self._protocol = protocol self._buffer = [] self._conn_lost = 0 - self._writing = True self._closing = False # Set when close() called. def abort(self): @@ -413,7 +412,7 @@ def write(self, data): self._conn_lost += 1 return - if not self._buffer and self._writing: + if not self._buffer: # Attempt to send it right away first. try: n = self._sock.send(data) @@ -432,9 +431,6 @@ def write(self, data): self._buffer.append(data) def _write_ready(self): - if not self._writing: - return # transmission off - data = b''.join(self._buffer) assert data, 'Data should not be empty' @@ -457,23 +453,6 @@ def _write_ready(self): self._buffer.append(data) # Try again later. - def pause_writing(self): - if self._writing: - if self._buffer: - self._loop.remove_writer(self._sock_fd) - self._writing = False - - def resume_writing(self): - if not self._writing: - if self._buffer: - self._loop.add_writer(self._sock_fd, self._write_ready) - self._writing = True - - def discard_output(self): - if self._buffer: - self._loop.remove_writer(self._sock_fd) - self._buffer.clear() - class _SelectorSslTransport(_SelectorTransport): diff --git a/tulip/transports.py b/tulip/transports.py index 56425aa9..f6eb2820 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -79,21 +79,6 @@ def can_write_eof(self): """Return True if this protocol supports write_eof(), False if not.""" raise NotImplementedError - def pause_writing(self): - """Pause transmission on the transport. - - Subsequent writes are deferred until resume_writing() is called. - """ - raise NotImplementedError - - def resume_writing(self): - """Resume transmission on the transport. """ - raise NotImplementedError - - def discard_output(self): - """Discard any buffered data awaiting transmission on the transport.""" - raise NotImplementedError - def abort(self): """Closes the transport immediately. diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 75131851..250c8fb9 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -276,7 +276,6 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. - self._writing = True self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) @@ -300,7 +299,7 @@ def write(self, data): self._conn_lost += 1 return - if not self._buffer and self._writing: + if not self._buffer: # Attempt to send it right away first. try: n = os.write(self._fileno, data) @@ -319,9 +318,6 @@ def write(self, data): self._buffer.append(data) def _write_ready(self): - if not self._writing: - return - data = b''.join(self._buffer) assert data, 'Data should not be empty' @@ -389,23 +385,6 @@ def _call_connection_lost(self, exc): self._protocol = None self._loop = None - def pause_writing(self): - if self._writing: - if self._buffer: - self._loop.remove_writer(self._fileno) - self._writing = False - - def resume_writing(self): - if not self._writing: - if self._buffer: - self._loop.add_writer(self._fileno, self._write_ready) - self._writing = True - - def discard_output(self): - if self._buffer: - self._loop.remove_writer(self._fileno) - self._buffer.clear() - class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): pipe = None From c580d8f2e4b579369b43ab5069a8919abe418b79 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 30 Sep 2013 12:34:02 -0700 Subject: [PATCH 0633/1502] Add [can_]write_eof() for socket and ssl transports. --- tests/selector_events_test.py | 30 ++++++++++++++++++++++++++++++ tulip/selector_events.py | 24 +++++++++++++++++++++--- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 1f462c39..031a7a28 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -923,6 +923,31 @@ def test_write_ready_exception_and_close(self, m_log): transport._write_ready() remove_writer.assert_called_with(self.sock_fd) + def test_write_eof(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._eof) + self.assertFalse(self.sock.shutdown.called) + self.sock.send.side_effect = lambda _: 4 + tr._write_ready() + self.sock.send.assert_called_with(b'data') + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + @unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(unittest.TestCase): @@ -1156,6 +1181,11 @@ def test_on_ready_send_exc(self): transport._fatal_error.assert_called_with(err) self.assertEqual([], transport._buffer) + def test_write_eof(self): + tr = self._make_one() + self.assertFalse(tr.can_write_eof()) + self.assertRaises(RuntimeError, tr.write_eof) + def test_close(self): tr = self._make_one() tr.close() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 3174b279..cd3b7f94 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -377,6 +377,7 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(loop, sock, protocol, extra) + self._eof = False self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) @@ -402,7 +403,8 @@ def _read_ready(self): self.close() def write(self, data): - assert isinstance(data, bytes), repr(data) + assert isinstance(data, bytes), repr(data)[:100] + assert not self._eof, 'Cannot call write() after write_eof()' if not data: return @@ -447,12 +449,24 @@ def _write_ready(self): self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) return elif n: data = data[n:] self._buffer.append(data) # Try again later. + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + class _SelectorSslTransport(_SelectorTransport): @@ -563,6 +577,12 @@ def write(self, data): self._buffer.append(data) # We could optimize, but the callback can do this for now. + def write_eof(self): + raise RuntimeError('SSL transport does not support write_eof().') + + def can_write_eof(self): + return False + def close(self): if self._closing: return @@ -570,8 +590,6 @@ def close(self): self._conn_lost += 1 self._loop.remove_reader(self._sock_fd) - # TODO: write_eof(), can_write_eof(). - class _SelectorDatagramTransport(_SelectorTransport): From 82cb926f95ac42887feb80fb883d079dbc02918d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 30 Sep 2013 14:34:28 -0700 Subject: [PATCH 0634/1502] Implement plain socket pause()/resume(). --- tests/selector_events_test.py | 12 ++++++++++++ tulip/selector_events.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 031a7a28..a0e1fa5d 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -675,6 +675,18 @@ def test_ctor_with_waiter(self): test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) + def test_pause_resume(self): + tr = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(7 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(7, tr._read_ready) + def test_read_ready(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index cd3b7f94..4db64d7f 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -378,12 +378,26 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(loop, sock, protocol, extra) self._eof = False + self._paused = False self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + def _read_ready(self): try: data = self._sock.recv(16*1024) From 4d9d51fb05283e6ebe167b00674cc6312e7f7c69 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 30 Sep 2013 14:44:56 -0700 Subject: [PATCH 0635/1502] Implement ssl socket pause()/resume(). --- tests/selector_events_test.py | 11 +++++++++++ tulip/selector_events.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index a0e1fa5d..9811b914 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -1030,6 +1030,17 @@ def test_on_handshake_base_exc(self): self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) + def test_pause_resume(self): + tr = self._make_one() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + tr.pause() + self.assertTrue(tr._paused) + self.assertFalse(1 in self.loop.readers) + tr.resume() + self.assertFalse(tr._paused) + self.loop.assert_reader(1, tr._on_ready) + def test_write_no_data(self): transport = self._make_one() transport._buffer.append(b'data') diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 4db64d7f..032c41e3 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -500,6 +500,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, self._waiter = waiter self._rawsock = rawsock self._sslcontext = sslcontext + self._paused = False self._on_handshake() @@ -530,6 +531,25 @@ def _on_handshake(self): if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) + def pause(self): + # XXX This is a bit icky, given the comment at the top of + # _on_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume() again, and things will flow again. + + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._on_ready) + def _on_ready(self): # Because of renegotiations (?), there's no difference between # readable and writable. We just try both. XXX This may be @@ -537,7 +557,7 @@ def _on_ready(self): # should do next. # First try reading. - if not self._closing: + if not self._closing and not self._paused: try: data = self._sock.recv(8192) except (BlockingIOError, InterruptedError, From fe5458a5ebac3fb02da4457ad783ef22f3fa8a8d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 30 Sep 2013 15:47:56 -0700 Subject: [PATCH 0636/1502] Add flow control pushback to StreamReader.feed_data(). --- tulip/streams.py | 55 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/tulip/streams.py b/tulip/streams.py index 5950da33..d0f12e81 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -54,6 +54,9 @@ class StreamReaderProtocol(protocols.Protocol): def __init__(self, stream_reader): self.stream_reader = stream_reader + def connection_made(self, transport): + self.stream_reader.set_transport(transport) + def connection_lost(self, exc): if exc is None: self.stream_reader.feed_eof() @@ -70,7 +73,9 @@ def eof_received(self): class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): - self.limit = limit # Max line length. (Security feature.) + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self.limit = limit if loop is None: loop = events.get_event_loop() self.loop = loop @@ -79,6 +84,8 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): self.eof = False # Whether we're done. self.waiter = None # A future. self._exception = None + self._transport = None + self._paused = False def exception(self): return self._exception @@ -92,6 +99,15 @@ def set_exception(self, exc): if not waiter.cancelled(): waiter.set_exception(exc) + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and self.byte_count <= self.limit: + self._paused = False + self._transport.resume() + def feed_eof(self): self.eof = True waiter = self.waiter @@ -113,6 +129,19 @@ def feed_data(self, data): if not waiter.cancelled(): waiter.set_result(False) + if (self._transport is not None and + not self._paused and + self.byte_count > 2*self.limit): + try: + self._transport.pause() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + @tasks.coroutine def readline(self): if self._exception is not None: @@ -140,6 +169,7 @@ def readline(self): if parts_size > self.limit: self.byte_count -= parts_size + self._maybe_resume_transport() raise ValueError('Line is too long') if self.eof: @@ -148,10 +178,14 @@ def readline(self): if not_enough: assert self.waiter is None self.waiter = futures.Future(loop=self.loop) - yield from self.waiter + try: + yield from self.waiter + finally: + self.waiter = None line = b''.join(parts) self.byte_count -= parts_size + self._maybe_resume_transport() return line @@ -167,17 +201,24 @@ def read(self, n=-1): while not self.eof: assert not self.waiter self.waiter = futures.Future(loop=self.loop) - yield from self.waiter + try: + yield from self.waiter + finally: + self.waiter = None else: if not self.byte_count and not self.eof: assert not self.waiter self.waiter = futures.Future(loop=self.loop) - yield from self.waiter + try: + yield from self.waiter + finally: + self.waiter = None if n < 0 or self.byte_count <= n: data = b''.join(self.buffer) self.buffer.clear() self.byte_count = 0 + self._maybe_resume_transport() return data parts = [] @@ -193,6 +234,7 @@ def read(self, n=-1): parts.append(data) parts_bytes += data_bytes self.byte_count -= data_bytes + self._maybe_resume_transport() return b''.join(parts) @@ -207,6 +249,9 @@ def readexactly(self, n): while self.byte_count < n and not self.eof: assert not self.waiter self.waiter = futures.Future(loop=self.loop) - yield from self.waiter + try: + yield from self.waiter + finally: + self.waiter = None return (yield from self.read(n)) From b7f09691adf22efddfe2543ddf20a83ca3cdacea Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 1 Oct 2013 12:37:05 +0100 Subject: [PATCH 0637/1502] Add write_eof() to proactor write transports. --- tests/proactor_events_test.py | 56 +++++++++++++++++++++++++++++++++++ tulip/proactor_events.py | 22 +++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index da4dea35..1cc2cfc3 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -7,6 +7,7 @@ import tulip from tulip.proactor_events import BaseProactorEventLoop from tulip.proactor_events import _ProactorSocketTransport +from tulip.proactor_events import _ProactorWritePipeTransport from tulip import test_utils @@ -249,6 +250,61 @@ def test_call_connection_lost(self): self.assertTrue(self.protocol.connection_lost.called) self.assertTrue(self.sock.close.called) + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + tr._write_fut.add_done_callback.assert_called_with( + tr._loop_writing) + tr._write_fut = f = tulip.Future(loop=self.loop) + f.set_result(4) + tr._loop_writing(f) + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.sock.send.side_effect = BlockingIOError + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + tr._write_fut.add_done_callback.assert_called_with( + tr._loop_writing) + tr._write_fut = f = tulip.Future(loop=self.loop) + f.set_result(4) + tr._loop_writing(f) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + class BaseProactorEventLoopTests(unittest.TestCase): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index a83c780a..bab169dd 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -27,6 +27,7 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None): self._write_fut = None self._conn_lost = 0 self._closing = False # Set when close() called. + self._eof_written = False self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -130,6 +131,9 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, def write(self, data): assert isinstance(data, bytes), repr(data) + if self._closing or self._eof_written: + raise IOError('close() or write_eof() already called') + if not data: return @@ -153,13 +157,19 @@ def _loop_writing(self, f=None): if not data: if self._closing: self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) return self._write_fut = self._loop._proactor.send(self._sock, data) self._write_fut.add_done_callback(self._loop_writing) except OSError as exc: self._fatal_error(exc) - # TODO: write_eof(), can_write_eof(). + def can_write_eof(self): + return True + + def write_eof(self): + self.close() def abort(self): self._force_close(None) @@ -173,6 +183,16 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, def _set_extra(self, sock): self._extra['socket'] = sock + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + class BaseProactorEventLoop(base_events.BaseEventLoop): From f1b5262949f0b8a09d289b5b1b78ed6c5f623a7c Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 1 Oct 2013 14:50:32 +0100 Subject: [PATCH 0638/1502] Add pause() and resume() to proactor write transports. --- tests/proactor_events_test.py | 50 +++++++++++++++++++++++++---------- tulip/proactor_events.py | 25 +++++++++++++++--- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 1cc2cfc3..6a7391de 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -261,19 +261,16 @@ def test_write_eof(self): tr.close() def test_write_eof_buffer(self): - tr = _ProactorSocketTransport( - self.loop, self.sock, self.protocol) - self.sock.send.side_effect = BlockingIOError + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.side_effect = f tr.write(b'data') tr.write_eof() self.assertTrue(tr._eof_written) self.assertFalse(self.sock.shutdown.called) tr._loop._proactor.send.assert_called_with(self.sock, b'data') - tr._write_fut.add_done_callback.assert_called_with( - tr._loop_writing) - tr._write_fut = f = tulip.Future(loop=self.loop) f.set_result(4) - tr._loop_writing(f) + self.loop._run_once() self.sock.shutdown.assert_called_with(socket.SHUT_WR) tr.close() @@ -288,23 +285,48 @@ def test_write_eof_write_pipe(self): tr.close() def test_write_eof_buffer_write_pipe(self): - tr = _ProactorWritePipeTransport( - self.loop, self.sock, self.protocol) - self.sock.send.side_effect = BlockingIOError + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.side_effect = f tr.write(b'data') tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.sock.shutdown.called) tr._loop._proactor.send.assert_called_with(self.sock, b'data') - tr._write_fut.add_done_callback.assert_called_with( - tr._loop_writing) - tr._write_fut = f = tulip.Future(loop=self.loop) f.set_result(4) - tr._loop_writing(f) + self.loop._run_once() self.loop._run_once() self.assertTrue(self.sock.close.called) tr.close() + def test_pause_resume(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.side_effect = f + self.assertFalse(tr._paused) + tr.write(b'data1') + tr._loop._proactor.send.assert_called_with(self.sock, b'data1') + self.assertEqual(tr._buffer, []) + tr.write(b'data2') + self.assertEqual(tr._buffer, [b'data2']) + tr.pause() + tr.write(b'data3') + self.assertEqual(tr._buffer, [b'data2', b'data3']) + f.set_result(5) + self.loop._run_once() + self.assertEqual(tr._buffer, [b'data2data3']) + self.loop._run_once() + self.assertEqual(tr._buffer, [b'data2data3']) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.side_effect = f + tr.resume() + tr._loop._proactor.send.assert_called_with(self.sock, b'data2data3') + self.assertEqual(tr._buffer, []) + tr.write(b'data4') + self.assertEqual(tr._buffer, [b'data4']) + tr.close() + class BaseProactorEventLoopTests(unittest.TestCase): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index bab169dd..f897f8cd 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -28,6 +28,7 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None): self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False + self._paused = False self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -129,6 +130,21 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + # We don't try to cancel an existing overlapped write. Instead + # we prevent new overlapped writes until resume() is called. + self._paused = True + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + if self._buffer and self._write_fut is None: + self._loop_writing() + def write(self, data): assert isinstance(data, bytes), repr(data) if self._closing or self._eof_written: @@ -143,7 +159,7 @@ def write(self, data): self._conn_lost += 1 return self._buffer.append(data) - if self._write_fut is None: + if self._write_fut is None and not self._paused: self._loop_writing() def _loop_writing(self, f=None): @@ -160,8 +176,11 @@ def _loop_writing(self, f=None): if self._eof_written: self._sock.shutdown(socket.SHUT_WR) return - self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) + if not self._paused: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + else: + self._buffer.append(data) except OSError as exc: self._fatal_error(exc) From e24683348114aad4ba9be61a819cf9e79a5442bf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 1 Oct 2013 08:07:37 -0700 Subject: [PATCH 0639/1502] Unimplemented write_eof() should raise NotImplementedError. --- tests/selector_events_test.py | 2 +- tulip/selector_events.py | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 9811b914..f810f319 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -1207,7 +1207,7 @@ def test_on_ready_send_exc(self): def test_write_eof(self): tr = self._make_one() self.assertFalse(tr.can_write_eof()) - self.assertRaises(RuntimeError, tr.write_eof) + self.assertRaises(NotImplementedError, tr.write_eof) def test_close(self): tr = self._make_one() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 032c41e3..92330e87 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -611,9 +611,6 @@ def write(self, data): self._buffer.append(data) # We could optimize, but the callback can do this for now. - def write_eof(self): - raise RuntimeError('SSL transport does not support write_eof().') - def can_write_eof(self): return False From 421672a2a52c68c206f5e328b77ba5138f14a080 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Thu, 3 Oct 2013 14:52:24 +0100 Subject: [PATCH 0640/1502] Fix pause() and resume() for proactor transports: they should control reading not writing. --- tests/proactor_events_test.py | 42 +++++++++++++++++------------------ tulip/proactor_events.py | 41 ++++++++++++++++------------------ 2 files changed, 40 insertions(+), 43 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 6a7391de..ae38c031 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -263,7 +263,7 @@ def test_write_eof(self): def test_write_eof_buffer(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) f = tulip.Future(loop=self.loop) - tr._loop._proactor.send.side_effect = f + tr._loop._proactor.send.return_value = f tr.write(b'data') tr.write_eof() self.assertTrue(tr._eof_written) @@ -287,7 +287,7 @@ def test_write_eof_write_pipe(self): def test_write_eof_buffer_write_pipe(self): tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) f = tulip.Future(loop=self.loop) - tr._loop._proactor.send.side_effect = f + tr._loop._proactor.send.return_value = f tr.write(b'data') tr.write_eof() self.assertTrue(tr._closing) @@ -302,29 +302,29 @@ def test_write_eof_buffer_write_pipe(self): def test_pause_resume(self): tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol) - f = tulip.Future(loop=self.loop) - tr._loop._proactor.send.side_effect = f + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = tulip.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() self.assertFalse(tr._paused) - tr.write(b'data1') - tr._loop._proactor.send.assert_called_with(self.sock, b'data1') - self.assertEqual(tr._buffer, []) - tr.write(b'data2') - self.assertEqual(tr._buffer, [b'data2']) - tr.pause() - tr.write(b'data3') - self.assertEqual(tr._buffer, [b'data2', b'data3']) - f.set_result(5) self.loop._run_once() - self.assertEqual(tr._buffer, [b'data2data3']) + self.protocol.data_received.assert_called_with(b'data1') self.loop._run_once() - self.assertEqual(tr._buffer, [b'data2data3']) - f = tulip.Future(loop=self.loop) - tr._loop._proactor.send.side_effect = f + self.protocol.data_received.assert_called_with(b'data2') + tr.pause() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') tr.resume() - tr._loop._proactor.send.assert_called_with(self.sock, b'data2data3') - self.assertEqual(tr._buffer, []) - tr.write(b'data4') - self.assertEqual(tr._buffer, [b'data4']) + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') tr.close() diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index f897f8cd..5b631f6f 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -28,7 +28,6 @@ def __init__(self, loop, sock, protocol, waiter=None, extra=None): self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False - self._paused = False self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -82,9 +81,25 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None): super().__init__(loop, sock, protocol, waiter, extra) + self._read_fut = None + self._paused = False self._loop.call_soon(self._loop_reading) + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + def _loop_reading(self, fut=None): + if self._paused: + return data = None try: @@ -130,21 +145,6 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" - def pause(self): - assert not self._closing, 'Cannot pause() when closing' - assert not self._paused, 'Already paused' - # We don't try to cancel an existing overlapped write. Instead - # we prevent new overlapped writes until resume() is called. - self._paused = True - - def resume(self): - assert self._paused, 'Not paused' - self._paused = False - if self._closing: - return - if self._buffer and self._write_fut is None: - self._loop_writing() - def write(self, data): assert isinstance(data, bytes), repr(data) if self._closing or self._eof_written: @@ -159,7 +159,7 @@ def write(self, data): self._conn_lost += 1 return self._buffer.append(data) - if self._write_fut is None and not self._paused: + if self._write_fut is None: self._loop_writing() def _loop_writing(self, f=None): @@ -176,11 +176,8 @@ def _loop_writing(self, f=None): if self._eof_written: self._sock.shutdown(socket.SHUT_WR) return - if not self._paused: - self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) - else: - self._buffer.append(data) + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) except OSError as exc: self._fatal_error(exc) From 243ca9d0f4f5767ad8ae871c9d0f1a29b6ae4019 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Sat, 5 Oct 2013 14:55:49 +0100 Subject: [PATCH 0641/1502] Add support for Windows named pipes. --- overlapped.c | 239 ++++++++++++++++++++++++++++++---- tests/proactor_events_test.py | 9 ++ tests/windows_events_test.py | 75 ++++++++--- tulip/proactor_events.py | 21 ++- tulip/windows_events.py | 218 ++++++++++++++++++++++++++----- tulip/windows_utils.py | 4 +- 6 files changed, 492 insertions(+), 74 deletions(-) diff --git a/overlapped.c b/overlapped.c index 3a2c1208..ae1e77ca 100644 --- a/overlapped.c +++ b/overlapped.c @@ -32,7 +32,31 @@ #define T_HANDLE T_POINTER enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, - TYPE_CONNECT, TYPE_DISCONNECT}; + TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, + TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + union { + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; + }; +} OverlappedObject; + +typedef struct { + OVERLAPPED *Overlapped; + HANDLE IocpHandle; + char Address[1]; +} WaitNamedPipeAndConnectContext; /* * Map Windows error codes to subclasses of OSError @@ -248,6 +272,46 @@ overlapped_BindLocal(PyObject *self, PyObject *args) Py_RETURN_NONE; } +/* + * Windows equivalent of os.strerror() -- compare _ctypes/callproc.c + */ + +PyDoc_STRVAR( + FormatMessage_doc, + "FormatMessage(error_code) -> error_message\n\n" + "Return error message for an error code."); + +static PyObject * +overlapped_FormatMessage(PyObject *ignore, PyObject *args) +{ + DWORD code, n; + WCHAR *lpMsgBuf; + PyObject *res; + + if (!PyArg_ParseTuple(args, F_DWORD, &code)) + return NULL; + + n = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &lpMsgBuf, + 0, + NULL); + if (n) { + while (iswspace(lpMsgBuf[n-1])) + --n; + lpMsgBuf[n] = L'\0'; + res = Py_BuildValue("u", lpMsgBuf); + } else { + res = PyUnicode_FromFormat("unknown error code %u", code); + } + LocalFree(lpMsgBuf); + return res; +} + + /* * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE */ @@ -269,22 +333,6 @@ PyDoc_STRVAR( Overlapped_doc, "Overlapped object"); -typedef struct { - PyObject_HEAD - OVERLAPPED overlapped; - /* For convenience, we store the file handle too */ - HANDLE handle; - /* Error returned by last method call */ - DWORD error; - /* Type of operation */ - DWORD type; - /* Buffer used for reading (optional) */ - PyObject *read_buffer; - /* Buffer used for writing (optional) */ - Py_buffer write_buffer; -} OverlappedObject; - - static PyObject * Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { @@ -358,7 +406,11 @@ Overlapped_dealloc(OverlappedObject *self) if (self->write_buffer.obj) PyBuffer_Release(&self->write_buffer); - Py_CLEAR(self->read_buffer); + switch (self->type) { + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + } PyObject_Del(self); SetLastError(olderr); } @@ -373,7 +425,8 @@ Overlapped_cancel(OverlappedObject *self) { BOOL ret = TRUE; - if (self->type == TYPE_NOT_STARTED) + if (self->type == TYPE_NOT_STARTED + || self->type == TYPE_WAIT_NAMED_PIPE_AND_CONNECT) Py_RETURN_NONE; if (!HasOverlappedIoCompleted(&self->overlapped)) { @@ -445,10 +498,6 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) return NULL; Py_INCREF(self->read_buffer); return self->read_buffer; - case TYPE_ACCEPT: - case TYPE_CONNECT: - case TYPE_DISCONNECT: - Py_RETURN_NONE; default: return PyLong_FromUnsignedLong((unsigned long) transferred); } @@ -853,6 +902,142 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) } } +PyDoc_STRVAR( + Overlapped_ConnectNamedPipe_doc, + "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" + "Start overlapped wait for a client to connect."); + +static PyObject * +Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) +{ + HANDLE Pipe; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Pipe)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_CONNECT_NAMED_PIPE; + self->handle = Pipe; + + Py_BEGIN_ALLOW_THREADS + ret = ConnectNamedPipe(Pipe, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_PIPE_CONNECTED: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +/* Unfortunately there is no way to do an overlapped connect to a + pipe. We instead use WaitNamedPipe() and CreateFile() in a thread + pool thread. If a connection succeeds within a time limit (10 + seconds) then PostQueuedCompletionStatus() is used to return the + pipe handle to the completion port. */ + +static DWORD WINAPI +WaitNamedPipeAndConnectInThread(WaitNamedPipeAndConnectContext *ctx) +{ + HANDLE PipeHandle = INVALID_HANDLE_VALUE; + DWORD Start = GetTickCount(); + DWORD Deadline = Start + 10*1000; + DWORD Error = 0; + DWORD Timeout; + BOOL Success; + + for ( ; ; ) { + Timeout = Deadline - GetTickCount(); + if ((int)Timeout < 0) + break; + Success = WaitNamedPipe(ctx->Address, Timeout); + Error = Success ? ERROR_SUCCESS : GetLastError(); + switch (Error) { + case ERROR_SUCCESS: + PipeHandle = CreateFile(ctx->Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + if (PipeHandle == INVALID_HANDLE_VALUE) + continue; + break; + case ERROR_SEM_TIMEOUT: + continue; + } + break; + } + if (!PostQueuedCompletionStatus(ctx->IocpHandle, Error, + (ULONG_PTR)PipeHandle, ctx->Overlapped)) + CloseHandle(PipeHandle); + free(ctx); + return 0; +} + +PyDoc_STRVAR( + Overlapped_WaitNamedPipeAndConnect_doc, + "WaitNamedPipeAndConnect(addr, iocp_handle) -> Overlapped[pipe_handle]\n\n" + "Start overlapped connection to address, notifying iocp_handle when\n" + "finished"); + +static PyObject * +Overlapped_WaitNamedPipeAndConnect(OverlappedObject *self, PyObject *args) +{ + char *Address; + Py_ssize_t AddressLength; + HANDLE IocpHandle; + OVERLAPPED Overlapped; + BOOL ret; + DWORD err; + WaitNamedPipeAndConnectContext *ctx; + Py_ssize_t ContextLength; + + if (!PyArg_ParseTuple(args, "s#" F_HANDLE F_POINTER, + &Address, &AddressLength, &IocpHandle, &Overlapped)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + ContextLength = (AddressLength + + offsetof(WaitNamedPipeAndConnectContext, Address)); + ctx = calloc(1, ContextLength + 1); + if (ctx == NULL) + return PyErr_NoMemory(); + memcpy(ctx->Address, Address, AddressLength + 1); + ctx->Overlapped = &self->overlapped; + ctx->IocpHandle = IocpHandle; + + self->type = TYPE_WAIT_NAMED_PIPE_AND_CONNECT; + self->handle = NULL; + + Py_BEGIN_ALLOW_THREADS + ret = QueueUserWorkItem(WaitNamedPipeAndConnectInThread, ctx, + WT_EXECUTELONGFUNCTION); + Py_END_ALLOW_THREADS + + mark_as_completed(&self->overlapped); + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + if (!ret) + return SetFromWindowsErr(err); + Py_RETURN_NONE; +} + static PyObject* Overlapped_getaddress(OverlappedObject *self) { @@ -885,6 +1070,11 @@ static PyMethodDef Overlapped_methods[] = { METH_VARARGS, Overlapped_ConnectEx_doc}, {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, + METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {"WaitNamedPipeAndConnect", + (PyCFunction) Overlapped_WaitNamedPipeAndConnect, + METH_VARARGS, Overlapped_WaitNamedPipeAndConnect_doc}, {NULL} }; @@ -954,6 +1144,8 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, GetQueuedCompletionStatus_doc}, {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"FormatMessage", overlapped_FormatMessage, + METH_VARARGS, FormatMessage_doc}, {"BindLocal", overlapped_BindLocal, METH_VARARGS, BindLocal_doc}, {NULL} @@ -998,6 +1190,7 @@ PyInit__overlapped(void) d = PyModule_GetDict(m); WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); WINAPI_CONSTANT(F_HANDLE, NULL); diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index ae38c031..d9eae50e 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -8,6 +8,7 @@ from tulip.proactor_events import BaseProactorEventLoop from tulip.proactor_events import _ProactorSocketTransport from tulip.proactor_events import _ProactorWritePipeTransport +from tulip.proactor_events import _ProactorDuplexPipeTransport from tulip import test_utils @@ -299,6 +300,14 @@ def test_write_eof_buffer_write_pipe(self): self.assertTrue(self.sock.close.called) tr.close() + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + def test_pause_resume(self): tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py index c365158f..675728e0 100644 --- a/tests/windows_events_test.py +++ b/tests/windows_events_test.py @@ -1,3 +1,4 @@ +import os import unittest import tulip @@ -5,28 +6,22 @@ from tulip import windows_events from tulip import protocols from tulip import streams +from tulip import transports from tulip import test_utils -def connect_read_pipe(loop, file): - stream_reader = streams.StreamReader(loop=loop) - protocol = _StreamReaderProtocol(stream_reader) - loop._make_read_pipe_transport(file, protocol) - return stream_reader +class UpperProto(protocols.Protocol): + def __init__(self): + self.buf = [] - -class _StreamReaderProtocol(protocols.Protocol): - def __init__(self, stream_reader): - self.stream_reader = stream_reader - - def connection_lost(self, exc): - self.stream_reader.set_exception(exc) + def connection_made(self, trans): + self.trans = trans def data_received(self, data): - self.stream_reader.feed_data(data) - - def eof_received(self): - self.stream_reader.feed_eof() + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() class ProactorTests(unittest.TestCase): @@ -46,3 +41,51 @@ def test_close(self): trans.close() self.loop.run_until_complete(f) self.assertEqual(f.result(), b'') + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + server2 = windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = streams.StreamReader(loop=self.loop) + protocol = streams.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda:protocol, ADDRESS) + self.assertIsInstance(trans, transports.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + return 'done' diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 5b631f6f..c7b524ea 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -191,6 +191,18 @@ def abort(self): self._force_close(None) +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + class _ProactorSocketTransport(_ProactorReadPipeTransport, _ProactorWritePipeTransport, transports.Transport): @@ -223,6 +235,10 @@ def __init__(self, proactor): def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, sock, protocol, waiter, extra) + def _make_read_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) @@ -294,8 +310,9 @@ def loop(f=None): conn, protocol, extra={'addr': addr}) f = self._proactor.accept(sock) except OSError: - sock.close() - tulip_log.exception('Accept failed') + if sock.fileno() != -1: + tulip_log.exception('Accept failed') + sock.close() except futures.CancelledError: sock.close() else: diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 629b3475..8fbbe103 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -9,6 +9,7 @@ from . import futures from . import proactor_events from . import selector_events +from . import tasks from . import windows_utils from . import _overlapped from .log import tulip_log @@ -41,12 +42,64 @@ def cancel(self): return super().cancel() +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + def _socketpair(self): return windows_utils.socketpair() class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + def __init__(self, proactor=None): if proactor is None: proactor = IocpProactor() @@ -55,8 +108,49 @@ def __init__(self, proactor=None): def _socketpair(self): return windows_utils.socketpair() + @tasks.coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_socket_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @tasks.coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + def loop(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError: + if pipe and pipe.fileno() != -1: + tulip_log.exception('Pipe accept failed') + pipe.close() + except futures.CancelledError: + if pipe: + pipe.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + return [server] + + def stop_serving(self, server): + server.close() + class IocpProactor: + """Proactor implementation using IOCP.""" def __init__(self, concurrency=0xffffffff): self._loop = None @@ -80,71 +174,118 @@ def select(self, timeout=None): def recv(self, conn, nbytes, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) - handle = getattr(conn, 'handle', None) - if handle is None: + if isinstance(conn, socket.socket): ov.WSARecv(conn.fileno(), nbytes, flags) else: - ov.ReadFile(handle, nbytes) - return self._register(ov, conn, ov.getresult) + ov.ReadFile(conn.fileno(), nbytes) + def finish(trans, key, ov): + return ov.getresult() + return self._register(ov, conn, finish) def send(self, conn, buf, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) - handle = getattr(conn, 'handle', None) - if handle is None: + if isinstance(conn, socket.socket): ov.WSASend(conn.fileno(), buf, flags) else: - ov.WriteFile(handle, buf) - return self._register(ov, conn, ov.getresult) + ov.WriteFile(conn.fileno(), buf) + def finish(trans, key, ov): + return ov.getresult() + return self._register(ov, conn, finish) def accept(self, listener): self._register_with_iocp(listener) conn = self._get_accept_socket() ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) - - def finish_accept(): + def finish_accept(trans, key, ov): ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. buf = struct.pack('@P', listener.fileno()) conn.setsockopt(socket.SOL_SOCKET, - _overlapped.SO_UPDATE_ACCEPT_CONTEXT, - buf) + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() - return self._register(ov, listener, finish_accept) def connect(self, conn, address): self._register_with_iocp(conn) - # the socket needs to be locally bound before we call ConnectEx() + # The socket needs to be locally bound before we call ConnectEx(). try: _overlapped.BindLocal(conn.fileno(), len(address)) except OSError as e: if e.winerror != errno.WSAEINVAL: raise - # probably already locally bound; check using getsockname() + # Probably already locally bound; check using getsockname(). if conn.getsockname()[1] == 0: raise ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) - - def finish_connect(): + def finish_connect(trans, key, ov): ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. conn.setsockopt(socket.SOL_SOCKET, - _overlapped.SO_UPDATE_CONNECT_CONTEXT, - 0) + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) return conn - return self._register(ov, conn, finish_connect) + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + ov.ConnectNamedPipe(pipe.fileno()) + def finish(trans, key, ov): + ov.getresult() + return pipe + return self._register(ov, pipe, finish) + + def connect_pipe(self, address): + ov = _overlapped.Overlapped(NULL) + ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + def finish(err, handle, ov): + # err, handle were arguments passed to PostQueuedCompletionStatus() + # in a function run in a thread pool. + if err == _overlapped.ERROR_SEM_TIMEOUT: + # Connection did not succeed within time limit. + msg = _overlapped.FormatMessage(err) + raise ConnectionRefusedError(0, msg, None, err) + elif err != 0: + msg = _overlapped.FormatMessage(err) + raise OSError(0, msg, None, err) + else: + return windows_utils.PipeHandle(handle) + return self._register(ov, None, finish, wait_for_post=True) + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. if obj not in self._registered: self._registered.add(obj) _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) - - def _register(self, ov, obj, callback): + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback, wait_for_post=False): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). f = _OverlappedFuture(ov, loop=self._loop) - self._cache[ov.address] = (f, ov, obj, callback) + if ov.pending or wait_for_post: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + else: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) return f def _get_accept_socket(self): @@ -165,13 +306,21 @@ def _poll(self, timeout=None): status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) if status is None: return - address = status[3] - f, ov, obj, callback = self._cache.pop(address) + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + ms = 0 + continue if obj in self._stopped_serving: f.cancel() elif not f.cancelled(): try: - value = callback() + value = callback(transferred, key, ov) except OSError as e: f.set_exception(e) self._results.append(f) @@ -187,11 +336,18 @@ def stop_serving(self, obj): self._stopped_serving.add(obj) def close(self): - for (f, ov, obj, callback) in self._cache.values(): - try: - ov.cancel() - except OSError: - pass + # Cancel remaining registered operations. + for address, (f, ov, obj, callback) in list(self._cache.items()): + if obj is None: + # The operation was started with connect_pipe() which + # queues a task to Windows' thread pool. This cannot + # be cancelled, so just forget it. + del self._cache[address] + else: + try: + ov.cancel() + except OSError: + pass while self._cache: if not self._poll(1): diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py index af9d1418..04b43e9a 100644 --- a/tulip/windows_utils.py +++ b/tulip/windows_utils.py @@ -124,9 +124,9 @@ def fileno(self): return self._handle def close(self, *, CloseHandle=_winapi.CloseHandle): - if self._handle is not None: + if self._handle != -1: CloseHandle(self._handle) - self._handle = None + self._handle = -1 __del__ = close From b4fe468f33a9c94e822e7b94b4a5188ad24c5539 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 7 Oct 2013 10:52:08 -0700 Subject: [PATCH 0642/1502] remove tulip.http --- examples/crawl.py | 104 ---- examples/curl.py | 24 - examples/mpsrv.py | 293 ----------- examples/srv.py | 163 ------ examples/tcp_protocol_parser.py | 176 ------- examples/websocket.html | 90 ---- examples/wsclient.py | 100 ---- examples/wssrv.py | 314 ----------- tests/events_test.py | 16 +- tests/http_client_functional_test.py | 552 ------------------- tests/http_client_test.py | 298 ----------- tests/http_parser_test.py | 539 ------------------- tests/http_protocol_test.py | 400 -------------- tests/http_server_test.py | 301 ----------- tests/http_session_test.py | 139 ----- tests/http_websocket_test.py | 439 ---------------- tests/http_wsgi_test.py | 301 ----------- tests/parsers_test.py | 605 --------------------- tests/streams_test.py | 7 +- tulip/__init__.py | 2 - tulip/http/__init__.py | 16 - tulip/http/client.py | 581 -------------------- tulip/http/errors.py | 46 -- tulip/http/protocol.py | 756 --------------------------- tulip/http/server.py | 215 -------- tulip/http/session.py | 103 ---- tulip/http/websocket.py | 233 --------- tulip/http/wsgi.py | 228 -------- tulip/parsers.py | 399 -------------- tulip/test_utils.py | 306 ++--------- 30 files changed, 56 insertions(+), 7690 deletions(-) delete mode 100755 examples/crawl.py delete mode 100755 examples/curl.py delete mode 100755 examples/mpsrv.py delete mode 100755 examples/srv.py delete mode 100755 examples/tcp_protocol_parser.py delete mode 100644 examples/websocket.html delete mode 100755 examples/wsclient.py delete mode 100755 examples/wssrv.py delete mode 100644 tests/http_client_functional_test.py delete mode 100644 tests/http_client_test.py delete mode 100644 tests/http_parser_test.py delete mode 100644 tests/http_protocol_test.py delete mode 100644 tests/http_server_test.py delete mode 100644 tests/http_session_test.py delete mode 100644 tests/http_websocket_test.py delete mode 100644 tests/http_wsgi_test.py delete mode 100644 tests/parsers_test.py delete mode 100644 tulip/http/__init__.py delete mode 100644 tulip/http/client.py delete mode 100644 tulip/http/errors.py delete mode 100644 tulip/http/protocol.py delete mode 100644 tulip/http/server.py delete mode 100644 tulip/http/session.py delete mode 100644 tulip/http/websocket.py delete mode 100644 tulip/http/wsgi.py delete mode 100644 tulip/parsers.py diff --git a/examples/crawl.py b/examples/crawl.py deleted file mode 100755 index f7d53feb..00000000 --- a/examples/crawl.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import re -import signal -import sys -import urllib.parse - -import tulip -import tulip.http - - -class Crawler: - - def __init__(self, rooturl, loop, maxtasks=100): - self.rooturl = rooturl - self.loop = loop - self.todo = set() - self.busy = set() - self.done = {} - self.tasks = set() - self.sem = tulip.Semaphore(maxtasks) - - # session stores cookies between requests and uses connection pool - self.session = tulip.http.Session() - - @tulip.coroutine - def run(self): - tulip.Task(self.addurls([(self.rooturl, '')])) # Set initial work. - yield from tulip.sleep(1) - while self.busy: - yield from tulip.sleep(1) - - self.session.close() - self.loop.stop() - - @tulip.coroutine - def addurls(self, urls): - for url, parenturl in urls: - url = urllib.parse.urljoin(parenturl, url) - url, frag = urllib.parse.urldefrag(url) - if (url.startswith(self.rooturl) and - url not in self.busy and - url not in self.done and - url not in self.todo): - self.todo.add(url) - yield from self.sem.acquire() - task = tulip.Task(self.process(url)) - task.add_done_callback(lambda t: self.sem.release()) - task.add_done_callback(self.tasks.remove) - self.tasks.add(task) - - @tulip.coroutine - def process(self, url): - print('processing:', url) - - self.todo.remove(url) - self.busy.add(url) - try: - resp = yield from tulip.http.request( - 'get', url, session=self.session) - except Exception as exc: - print('...', url, 'has error', repr(str(exc))) - self.done[url] = False - else: - if resp.status == 200 and resp.get_content_type() == 'text/html': - data = (yield from resp.read()).decode('utf-8', 'replace') - urls = re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', data) - tulip.Task(self.addurls([(u, url) for u in urls])) - - resp.close() - self.done[url] = True - - self.busy.remove(url) - print(len(self.done), 'completed tasks,', len(self.tasks), - 'still pending, todo', len(self.todo)) - - -def main(): - loop = tulip.get_event_loop() - - c = Crawler(sys.argv[1], loop) - tulip.Task(c.run()) - - try: - loop.add_signal_handler(signal.SIGINT, loop.stop) - except RuntimeError: - pass - loop.run_forever() - print('todo:', len(c.todo)) - print('busy:', len(c.busy)) - print('done:', len(c.done), '; ok:', sum(c.done.values())) - print('tasks:', len(c.tasks)) - - -if __name__ == '__main__': - if '--iocp' in sys.argv: - from tulip import events, windows_events - sys.argv.remove('--iocp') - logging.info('using iocp') - el = windows_events.ProactorEventLoop() - events.set_event_loop(el) - - main() diff --git a/examples/curl.py b/examples/curl.py deleted file mode 100755 index 7063adcd..00000000 --- a/examples/curl.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python3 - -import sys -import tulip -import tulip.http - - -def curl(url): - response = yield from tulip.http.request('GET', url) - print(repr(response)) - - data = yield from response.read() - print(data.decode('utf-8', 'replace')) - - -if __name__ == '__main__': - if '--iocp' in sys.argv: - from tulip import events, windows_events - sys.argv.remove('--iocp') - el = windows_events.ProactorEventLoop() - events.set_event_loop(el) - - loop = tulip.get_event_loop() - loop.run_until_complete(curl(sys.argv[1])) diff --git a/examples/mpsrv.py b/examples/mpsrv.py deleted file mode 100755 index c594f5bc..00000000 --- a/examples/mpsrv.py +++ /dev/null @@ -1,293 +0,0 @@ -#!/usr/bin/env python3 -"""Simple multiprocess http server written using an event loop.""" - -import argparse -import email.message -import os -import socket -import signal -import time -import tulip -import tulip.http -from tulip.http import websocket - -ARGS = argparse.ArgumentParser(description="Run simple http server.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') -ARGS.add_argument( - '--workers', action="store", dest='workers', - default=2, type=int, help='Number of workers.') - - -class HttpServer(tulip.http.ServerHttpProtocol): - - @tulip.coroutine - def handle_request(self, message, payload): - print('{}: method = {!r}; path = {!r}; version = {!r}'.format( - os.getpid(), message.method, message.path, message.version)) - - path = message.path - - if (not (path.isprintable() and path.startswith('/')) or '/.' in path): - path = None - else: - path = '.' + path - if not os.path.exists(path): - path = None - else: - isdir = os.path.isdir(path) - - if not path: - raise tulip.http.HttpErrorException(404) - - headers = email.message.Message() - for hdr, val in message.headers: - headers.add_header(hdr, val) - - if isdir and not path.endswith('/'): - path = path + '/' - raise tulip.http.HttpErrorException( - 302, headers=(('URI', path), ('Location', path))) - - response = tulip.http.Response(self.transport, 200) - response.add_header('Transfer-Encoding', 'chunked') - - # content encoding - accept_encoding = headers.get('accept-encoding', '').lower() - if 'deflate' in accept_encoding: - response.add_header('Content-Encoding', 'deflate') - response.add_compression_filter('deflate') - elif 'gzip' in accept_encoding: - response.add_header('Content-Encoding', 'gzip') - response.add_compression_filter('gzip') - - response.add_chunking_filter(1025) - - if isdir: - response.add_header('Content-type', 'text/html') - response.send_headers() - - response.write(b'
    \r\n') - for name in sorted(os.listdir(path)): - if name.isprintable() and not name.startswith('.'): - try: - bname = name.encode('ascii') - except UnicodeError: - pass - else: - if os.path.isdir(os.path.join(path, name)): - response.write(b'
  • ' + bname + b'/
  • \r\n') - else: - response.write(b'
  • ' + bname + b'
  • \r\n') - response.write(b'
') - else: - response.add_header('Content-type', 'text/plain') - response.send_headers() - - try: - with open(path, 'rb') as fp: - chunk = fp.read(8196) - while chunk: - response.write(chunk) - chunk = fp.read(8196) - except OSError: - response.write(b'Cannot open') - - response.write_eof() - if response.keep_alive(): - self.keep_alive(True) - - -class ChildProcess: - - def __init__(self, up_read, down_write, args, sock): - self.up_read = up_read - self.down_write = down_write - self.args = args - self.sock = sock - - def start(self): - # start server - self.loop = loop = tulip.new_event_loop() - tulip.set_event_loop(loop) - - def stop(): - self.loop.stop() - os._exit(0) - loop.add_signal_handler(signal.SIGINT, stop) - - f = loop.start_serving( - lambda: HttpServer(debug=True, keep_alive=75), sock=self.sock) - x = loop.run_until_complete(f)[0] - print('Starting srv worker process {} on {}'.format( - os.getpid(), x.getsockname())) - - # heartbeat - tulip.Task(self.heartbeat()) - - tulip.get_event_loop().run_forever() - os._exit(0) - - @tulip.coroutine - def heartbeat(self): - # setup pipes - read_transport, read_proto = yield from self.loop.connect_read_pipe( - tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) - write_transport, _ = yield from self.loop.connect_write_pipe( - tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) - - reader = read_proto.set_parser(websocket.WebSocketParser()) - writer = websocket.WebSocketWriter(write_transport) - - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - print('Superviser is dead, {} stopping...'.format(os.getpid())) - self.loop.stop() - break - - if msg.tp == websocket.MSG_PING: - writer.pong() - elif msg.tp == websocket.MSG_CLOSE: - break - - read_transport.close() - write_transport.close() - - -class Worker: - - _started = False - - def __init__(self, loop, args, sock): - self.loop = loop - self.args = args - self.sock = sock - self.start() - - def start(self): - assert not self._started - self._started = True - - up_read, up_write = os.pipe() - down_read, down_write = os.pipe() - args, sock = self.args, self.sock - - pid = os.fork() - if pid: - # parent - os.close(up_read) - os.close(down_write) - self.connect(pid, up_write, down_read) - else: - # child - os.close(up_write) - os.close(down_read) - - # cleanup after fork - tulip.set_event_loop(None) - - # setup process - process = ChildProcess(up_read, down_write, args, sock) - process.start() - - @tulip.coroutine - def heartbeat(self, writer): - while True: - yield from tulip.sleep(15) - - if (time.monotonic() - self.ping) < 30: - writer.ping() - else: - print('Restart unresponsive worker process: {}'.format( - self.pid)) - self.kill() - self.start() - return - - @tulip.coroutine - def chat(self, reader): - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - print('Restart unresponsive worker process: {}'.format( - self.pid)) - self.kill() - self.start() - return - - if msg.tp == websocket.MSG_PONG: - self.ping = time.monotonic() - - @tulip.coroutine - def connect(self, pid, up_write, down_read): - # setup pipes - read_transport, proto = yield from self.loop.connect_read_pipe( - tulip.StreamProtocol, os.fdopen(down_read, 'rb')) - write_transport, _ = yield from self.loop.connect_write_pipe( - tulip.StreamProtocol, os.fdopen(up_write, 'wb')) - - # websocket protocol - reader = proto.set_parser(websocket.WebSocketParser()) - writer = websocket.WebSocketWriter(write_transport) - - # store info - self.pid = pid - self.ping = time.monotonic() - self.rtransport = read_transport - self.wtransport = write_transport - self.chat_task = tulip.Task(self.chat(reader)) - self.heartbeat_task = tulip.Task(self.heartbeat(writer)) - - def kill(self): - self._started = False - self.chat_task.cancel() - self.heartbeat_task.cancel() - self.rtransport.close() - self.wtransport.close() - os.kill(self.pid, signal.SIGTERM) - - -class Superviser: - - def __init__(self, args): - self.loop = tulip.get_event_loop() - self.args = args - self.workers = [] - - def start(self): - # bind socket - sock = self.sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((self.args.host, self.args.port)) - sock.listen(1024) - sock.setblocking(False) - - # start processes - for idx in range(self.args.workers): - self.workers.append(Worker(self.loop, self.args, sock)) - - self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) - self.loop.run_forever() - - -def main(): - args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - superviser = Superviser(args) - superviser.start() - - -if __name__ == '__main__': - main() diff --git a/examples/srv.py b/examples/srv.py deleted file mode 100755 index e4bf16c1..00000000 --- a/examples/srv.py +++ /dev/null @@ -1,163 +0,0 @@ -#!/usr/bin/env python3 -"""Simple server written using an event loop.""" - -import argparse -import email.message -import logging -import os -import sys -try: - import ssl -except ImportError: # pragma: no cover - ssl = None - -assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' - -import tulip -import tulip.http - - -class HttpServer(tulip.http.ServerHttpProtocol): - - @tulip.coroutine - def handle_request(self, message, payload): - print('method = {!r}; path = {!r}; version = {!r}'.format( - message.method, message.path, message.version)) - - path = message.path - - if (not (path.isprintable() and path.startswith('/')) or '/.' in path): - print('bad path', repr(path)) - path = None - else: - path = '.' + path - if not os.path.exists(path): - print('no file', repr(path)) - path = None - else: - isdir = os.path.isdir(path) - - if not path: - raise tulip.http.HttpErrorException(404) - - headers = email.message.Message() - for hdr, val in message.headers: - print(hdr, val) - headers.add_header(hdr, val) - - if isdir and not path.endswith('/'): - path = path + '/' - raise tulip.http.HttpErrorException( - 302, headers=(('URI', path), ('Location', path))) - - response = tulip.http.Response(self.transport, 200) - response.add_header('Transfer-Encoding', 'chunked') - - # content encoding - accept_encoding = headers.get('accept-encoding', '').lower() - if 'deflate' in accept_encoding: - response.add_header('Content-Encoding', 'deflate') - response.add_compression_filter('deflate') - elif 'gzip' in accept_encoding: - response.add_header('Content-Encoding', 'gzip') - response.add_compression_filter('gzip') - - response.add_chunking_filter(1025) - - if isdir: - response.add_header('Content-type', 'text/html') - response.send_headers() - - response.write(b'
    \r\n') - for name in sorted(os.listdir(path)): - if name.isprintable() and not name.startswith('.'): - try: - bname = name.encode('ascii') - except UnicodeError: - pass - else: - if os.path.isdir(os.path.join(path, name)): - response.write(b'
  • ' + bname + b'/
  • \r\n') - else: - response.write(b'
  • ' + bname + b'
  • \r\n') - response.write(b'
') - else: - response.add_header('Content-type', 'text/plain') - response.send_headers() - - try: - with open(path, 'rb') as fp: - chunk = fp.read(8196) - while chunk: - response.write(chunk) - chunk = fp.read(8196) - except OSError: - response.write(b'Cannot open') - - response.write_eof() - if response.keep_alive(): - self.keep_alive(True) - - -ARGS = argparse.ArgumentParser(description="Run simple http server.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') -ARGS.add_argument( - '--iocp', action="store_true", dest='iocp', help='Windows IOCP event loop') -ARGS.add_argument( - '--ssl', action="store_true", dest='ssl', help='Run ssl mode.') -ARGS.add_argument( - '--sslcert', action="store", dest='certfile', help='SSL cert file.') -ARGS.add_argument( - '--sslkey', action="store", dest='keyfile', help='SSL key file.') - - -def main(): - args = ARGS.parse_args() - - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - if args.iocp: - from tulip import windows_events - sys.argv.remove('--iocp') - logging.info('using iocp') - el = windows_events.ProactorEventLoop() - tulip.set_event_loop(el) - - if args.ssl: - here = os.path.join(os.path.dirname(__file__), 'tests') - - if args.certfile: - certfile = args.certfile or os.path.join(here, 'sample.crt') - keyfile = args.keyfile or os.path.join(here, 'sample.key') - else: - certfile = os.path.join(here, 'sample.crt') - keyfile = os.path.join(here, 'sample.key') - - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain(certfile, keyfile) - else: - sslcontext = None - - loop = tulip.get_event_loop() - f = loop.start_serving( - lambda: HttpServer(debug=True, keep_alive=75), args.host, args.port, - ssl=sslcontext) - socks = loop.run_until_complete(f) - print('serving on', socks[0].getsockname()) - try: - loop.run_forever() - except KeyboardInterrupt: - pass - - -if __name__ == '__main__': - main() diff --git a/examples/tcp_protocol_parser.py b/examples/tcp_protocol_parser.py deleted file mode 100755 index e4fc59ad..00000000 --- a/examples/tcp_protocol_parser.py +++ /dev/null @@ -1,176 +0,0 @@ -#!/usr/bin/env python3 -"""Protocol parser example.""" -import argparse -import collections -import tulip -try: - import signal -except ImportError: - signal = None - - -MSG_TEXT = b'text:' -MSG_PING = b'ping:' -MSG_PONG = b'pong:' -MSG_STOP = b'stop:' - -Message = collections.namedtuple('Message', ('tp', 'data')) - - -def my_protocol_parser(): - """Parser is used with StreamBuffer for incremental protocol parsing. - Parser is a generator function, but it is not a coroutine. Usually - parsers are implemented as a state machine. - - more details in tulip/parsers.py - existing parsers: - * http protocol parsers tulip/http/protocol.py - * websocket parser tulip/http/websocket.py - """ - out, buf = yield - - while True: - tp = yield from buf.read(5) - if tp in (MSG_PING, MSG_PONG): - # skip line - yield from buf.skipuntil(b'\r\n') - out.feed_data(Message(tp, None)) - elif tp == MSG_STOP: - out.feed_data(Message(tp, None)) - elif tp == MSG_TEXT: - # read text - text = yield from buf.readuntil(b'\r\n') - out.feed_data(Message(tp, text.strip().decode('utf-8'))) - else: - raise ValueError('Unknown protocol prefix.') - - -class MyProtocolWriter: - - def __init__(self, transport): - self.transport = transport - - def ping(self): - self.transport.write(b'ping:\r\n') - - def pong(self): - self.transport.write(b'pong:\r\n') - - def stop(self): - self.transport.write(b'stop:\r\n') - - def send_text(self, text): - self.transport.write( - 'text:{}\r\n'.format(text.strip()).encode('utf-8')) - - -class EchoServer(tulip.Protocol): - - def connection_made(self, transport): - print('Connection made') - self.transport = transport - self.stream = tulip.StreamBuffer() - tulip.Task(self.dispatch()) - - def data_received(self, data): - self.stream.feed_data(data) - - def eof_received(self): - self.stream.feed_eof() - - def connection_lost(self, exc): - print('Connection lost') - - @tulip.coroutine - def dispatch(self): - reader = self.stream.set_parser(my_protocol_parser()) - writer = MyProtocolWriter(self.transport) - - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - # client has been disconnected - break - - print('Message received: {}'.format(msg)) - - if msg.tp == MSG_PING: - writer.pong() - elif msg.tp == MSG_TEXT: - writer.send_text('Re: ' + msg.data) - elif msg.tp == MSG_STOP: - self.transport.close() - break - - -@tulip.coroutine -def start_client(loop, host, port): - transport, stream = yield from loop.create_connection( - tulip.StreamProtocol, host, port) - reader = stream.set_parser(my_protocol_parser()) - writer = MyProtocolWriter(transport) - writer.ping() - - message = 'This is the message. It will be echoed.' - - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - print('Server has been disconnected.') - break - - print('Message received: {}'.format(msg)) - if msg.tp == MSG_PONG: - writer.send_text(message) - print('data sent:', message) - elif msg.tp == MSG_TEXT: - writer.stop() - print('stop sent') - break - - transport.close() - - -def start_server(loop, host, port): - f = loop.start_serving(EchoServer, host, port) - x = loop.run_until_complete(f)[0] - print('serving on', x.getsockname()) - loop.run_forever() - - -ARGS = argparse.ArgumentParser(description="Protocol parser example.") -ARGS.add_argument( - '--server', action="store_true", dest='server', - default=False, help='Run tcp server') -ARGS.add_argument( - '--client', action="store_true", dest='client', - default=False, help='Run tcp client') -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=9999, type=int, help='Port number') - - -if __name__ == '__main__': - args = ARGS.parse_args() - - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - if (not (args.server or args.client)) or (args.server and args.client): - print('Please specify --server or --client\n') - ARGS.print_help() - else: - loop = tulip.get_event_loop() - if signal is not None: - loop.add_signal_handler(signal.SIGINT, loop.stop) - - if args.server: - start_server(loop, args.host, args.port) - else: - loop.run_until_complete(start_client(loop, args.host, args.port)) diff --git a/examples/websocket.html b/examples/websocket.html deleted file mode 100644 index 6bad7f74..00000000 --- a/examples/websocket.html +++ /dev/null @@ -1,90 +0,0 @@ - - - - - - - - -

Chat!

-
-  | Status: - disconnected -
-
-
-
- - -
- - diff --git a/examples/wsclient.py b/examples/wsclient.py deleted file mode 100755 index ed7beda5..00000000 --- a/examples/wsclient.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -"""websocket cmd client for wssrv.py example.""" -import argparse -import base64 -import hashlib -import os -import signal -import sys - -import tulip -import tulip.http -from tulip.http import websocket -import tulip.selectors - -WS_KEY = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - - -def start_client(loop, url): - name = input('Please enter your name: ').encode() - - sec_key = base64.b64encode(os.urandom(16)) - - # send request - response = yield from tulip.http.request( - 'get', url, - headers={ - 'UPGRADE': 'WebSocket', - 'CONNECTION': 'Upgrade', - 'SEC-WEBSOCKET-VERSION': '13', - 'SEC-WEBSOCKET-KEY': sec_key.decode(), - }, timeout=1.0) - - # websocket handshake - if response.status != 101: - raise ValueError("Handshake error: Invalid response status") - if response.get('upgrade', '').lower() != 'websocket': - raise ValueError("Handshake error - Invalid upgrade header") - if response.get('connection', '').lower() != 'upgrade': - raise ValueError("Handshake error - Invalid connection header") - - key = response.get('sec-websocket-accept', '').encode() - match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()) - if key != match: - raise ValueError("Handshake error - Invalid challenge response") - - # switch to websocket protocol - stream = response.stream.set_parser(websocket.WebSocketParser()) - writer = websocket.WebSocketWriter(response.transport) - - # input reader - def stdin_callback(): - line = sys.stdin.buffer.readline() - if not line: - loop.stop() - else: - writer.send(name + b': ' + line) - loop.add_reader(sys.stdin.fileno(), stdin_callback) - - @tulip.coroutine - def dispatch(): - while True: - try: - msg = yield from stream.read() - except tulip.EofStream: - # server disconnected - break - - if msg.tp == websocket.MSG_PING: - writer.pong() - elif msg.tp == websocket.MSG_TEXT: - print(msg.data.strip()) - elif msg.tp == websocket.MSG_CLOSE: - break - - yield from dispatch() - - -ARGS = argparse.ArgumentParser( - description="websocket console client for wssrv.py example.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') - -if __name__ == '__main__': - args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - url = 'http://{}:{}'.format(args.host, args.port) - - loop = tulip.SelectorEventLoop(tulip.selectors.SelectSelector()) - tulip.set_event_loop(loop) - - loop.add_signal_handler(signal.SIGINT, loop.stop) - tulip.Task(start_client(loop, url)) - loop.run_forever() diff --git a/examples/wssrv.py b/examples/wssrv.py deleted file mode 100755 index 8a02a2dd..00000000 --- a/examples/wssrv.py +++ /dev/null @@ -1,314 +0,0 @@ -#!/usr/bin/env python3 -"""Multiprocess WebSocket http chat example.""" - -import argparse -import os -import socket -import signal -import time -import tulip -import tulip.http -from tulip.http import websocket - -ARGS = argparse.ArgumentParser(description="Run simple http server.") -ARGS.add_argument( - '--host', action="store", dest='host', - default='127.0.0.1', help='Host name') -ARGS.add_argument( - '--port', action="store", dest='port', - default=8080, type=int, help='Port number') -ARGS.add_argument( - '--workers', action="store", dest='workers', - default=2, type=int, help='Number of workers.') - -WS_FILE = os.path.join(os.path.dirname(__file__), 'websocket.html') - - -class HttpServer(tulip.http.ServerHttpProtocol): - - clients = None # list of all active connections - parent = None # process supervisor - # we use it as broadcaster to all workers - - @tulip.coroutine - def handle_request(self, message, payload): - upgrade = False - for hdr, val in message.headers: - if hdr == 'UPGRADE': - upgrade = 'websocket' in val.lower() - break - - if upgrade: - # websocket handshake - status, headers, parser, writer = websocket.do_handshake( - message.method, message.headers, self.transport) - - resp = tulip.http.Response(self.transport, status) - resp.add_headers(*headers) - resp.send_headers() - - # install websocket parser - databuffer = self.stream.set_parser(parser) - - # notify everybody - print('{}: Someone joined.'.format(os.getpid())) - for wsc in self.clients: - wsc.send(b'Someone joined.') - self.clients.append(writer) - self.parent.send(b'Someone joined.') - - # chat dispatcher - while True: - try: - msg = yield from databuffer.read() - except tulip.EofStream: - # client droped connection - break - - if msg.tp == websocket.MSG_PING: - writer.pong() - - elif msg.tp == websocket.MSG_TEXT: - data = msg.data.strip() - print('{}: {}'.format(os.getpid(), data)) - for wsc in self.clients: - if wsc is not writer: - wsc.send(data.encode()) - self.parent.send(data) - - elif msg.tp == websocket.MSG_CLOSE: - break - - # notify everybody - print('{}: Someone disconnected.'.format(os.getpid())) - self.parent.send(b'Someone disconnected.') - self.clients.remove(writer) - for wsc in self.clients: - wsc.send(b'Someone disconnected.') - - else: - # send html page with js chat - response = tulip.http.Response(self.transport, 200) - response.add_header('Transfer-Encoding', 'chunked') - response.add_header('Content-type', 'text/html') - response.send_headers() - - try: - with open(WS_FILE, 'rb') as fp: - chunk = fp.read(8196) - while chunk: - if not response.write(chunk): - break - chunk = fp.read(8196) - except OSError: - response.write(b'Cannot open') - - response.write_eof() - if response.keep_alive(): - self.keep_alive(True) - - -class ChildProcess: - - def __init__(self, up_read, down_write, args, sock): - self.up_read = up_read - self.down_write = down_write - self.args = args - self.sock = sock - self.clients = [] - - def start(self): - # start server - self.loop = loop = tulip.new_event_loop() - tulip.set_event_loop(loop) - - def stop(): - self.loop.stop() - os._exit(0) - loop.add_signal_handler(signal.SIGINT, stop) - - # heartbeat - tulip.Task(self.heartbeat()) - - tulip.get_event_loop().run_forever() - os._exit(0) - - @tulip.coroutine - def start_server(self, writer): - socks = yield from self.loop.start_serving( - lambda: HttpServer( - debug=True, keep_alive=75, - parent=writer, clients=self.clients), - sock=self.sock) - print('Starting srv worker process {} on {}'.format( - os.getpid(), socks[0].getsockname())) - - @tulip.coroutine - def heartbeat(self): - # setup pipes - read_transport, read_proto = yield from self.loop.connect_read_pipe( - tulip.StreamProtocol, os.fdopen(self.up_read, 'rb')) - write_transport, _ = yield from self.loop.connect_write_pipe( - tulip.StreamProtocol, os.fdopen(self.down_write, 'wb')) - - reader = read_proto.set_parser(websocket.WebSocketParser()) - writer = websocket.WebSocketWriter(write_transport) - - tulip.Task(self.start_server(writer)) - - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - print('Superviser is dead, {} stopping...'.format(os.getpid())) - self.loop.stop() - break - - if msg.tp == websocket.MSG_PING: - writer.pong() - elif msg.tp == websocket.MSG_CLOSE: - break - elif msg.tp == websocket.MSG_TEXT: # broadcast message - for wsc in self.clients: - wsc.send(msg.data.strip().encode()) - - read_transport.close() - write_transport.close() - - -class Worker: - - _started = False - - def __init__(self, sv, loop, args, sock): - self.sv = sv - self.loop = loop - self.args = args - self.sock = sock - self.start() - - def start(self): - assert not self._started - self._started = True - - up_read, up_write = os.pipe() - down_read, down_write = os.pipe() - args, sock = self.args, self.sock - - pid = os.fork() - if pid: - # parent - os.close(up_read) - os.close(down_write) - tulip.async(self.connect(pid, up_write, down_read)) - else: - # child - os.close(up_write) - os.close(down_read) - - # cleanup after fork - tulip.set_event_loop(None) - - # setup process - process = ChildProcess(up_read, down_write, args, sock) - process.start() - - @tulip.coroutine - def heartbeat(self, writer): - while True: - yield from tulip.sleep(15) - - if (time.monotonic() - self.ping) < 30: - writer.ping() - else: - print('Restart unresponsive worker process: {}'.format( - self.pid)) - self.kill() - self.start() - return - - @tulip.coroutine - def chat(self, reader): - while True: - try: - msg = yield from reader.read() - except tulip.EofStream: - print('Restart unresponsive worker process: {}'.format( - self.pid)) - self.kill() - self.start() - return - - if msg.tp == websocket.MSG_PONG: - self.ping = time.monotonic() - - elif msg.tp == websocket.MSG_TEXT: # broadcast to all workers - for worker in self.sv.workers: - if self.pid != worker.pid: - worker.writer.send(msg.data) - - @tulip.coroutine - def connect(self, pid, up_write, down_read): - # setup pipes - read_transport, proto = yield from self.loop.connect_read_pipe( - tulip.StreamProtocol, os.fdopen(down_read, 'rb')) - write_transport, _ = yield from self.loop.connect_write_pipe( - tulip.StreamProtocol, os.fdopen(up_write, 'wb')) - - # websocket protocol - reader = proto.set_parser(websocket.WebSocketParser()) - writer = websocket.WebSocketWriter(write_transport) - - # store info - self.pid = pid - self.ping = time.monotonic() - self.writer = writer - self.rtransport = read_transport - self.wtransport = write_transport - self.chat_task = tulip.async(self.chat(reader)) - self.heartbeat_task = tulip.async(self.heartbeat(writer)) - - def kill(self): - self._started = False - self.chat_task.cancel() - self.heartbeat_task.cancel() - self.rtransport.close() - self.wtransport.close() - os.kill(self.pid, signal.SIGTERM) - - -class Superviser: - - def __init__(self, args): - self.loop = tulip.get_event_loop() - self.args = args - self.workers = [] - - def start(self): - # bind socket - sock = self.sock = socket.socket() - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.bind((self.args.host, self.args.port)) - sock.listen(1024) - sock.setblocking(False) - - # start processes - for idx in range(self.args.workers): - self.workers.append(Worker(self, self.loop, self.args, sock)) - - self.loop.add_signal_handler(signal.SIGINT, lambda: self.loop.stop()) - self.loop.run_forever() - - -def main(): - args = ARGS.parse_args() - if ':' in args.host: - args.host, port = args.host.split(':', 1) - args.port = int(port) - - superviser = Superviser(args) - superviser.start() - - -if __name__ == '__main__': - main() diff --git a/tests/events_test.py b/tests/events_test.py index 1db35d14..460d2fe9 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -4,7 +4,6 @@ import gc import io import os -import re import signal import socket try: @@ -342,7 +341,7 @@ def remove_writer(): self.assertGreaterEqual(len(data), 200) def test_sock_client_ops(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: sock = socket.socket() sock.setblocking(False) self.loop.run_until_complete( @@ -356,7 +355,7 @@ def test_sock_client_ops(self): self.loop.sock_recv(sock, 1024)) sock.close() - self.assertTrue(re.match(rb'HTTP/1.0 200 OK', data), data) + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) def test_sock_client_fail(self): # Make sure that we will get an unused port @@ -470,7 +469,7 @@ def my_handler(*args): self.assertEqual(caught, 1) def test_create_connection(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) tr, pr = self.loop.run_until_complete(f) @@ -481,7 +480,7 @@ def test_create_connection(self): tr.close() def test_create_connection_sock(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: sock = None infos = self.loop.run_until_complete( self.loop.getaddrinfo( @@ -510,8 +509,7 @@ def test_create_connection_sock(self): @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): - with test_utils.run_test_server( - self.loop, use_ssl=True) as httpd: + with test_utils.run_test_server(use_ssl=True) as httpd: f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) tr, pr = self.loop.run_until_complete(f) @@ -525,7 +523,7 @@ def test_create_ssl_connection(self): tr.close() def test_create_connection_local_addr(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: port = find_unused_port() f = self.loop.create_connection( lambda: MyProto(loop=self.loop), @@ -536,7 +534,7 @@ def test_create_connection_local_addr(self): tr.close() def test_create_connection_local_addr_in_use(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address, local_addr=httpd.address) diff --git a/tests/http_client_functional_test.py b/tests/http_client_functional_test.py deleted file mode 100644 index 91badfc4..00000000 --- a/tests/http_client_functional_test.py +++ /dev/null @@ -1,552 +0,0 @@ -"""Http client functional tests.""" - -import gc -import io -import os.path -import http.cookies -import unittest - -import tulip -import tulip.http -from tulip import test_utils -from tulip.http import client - - -class HttpClientFunctionalTests(unittest.TestCase): - - def setUp(self): - self.loop = tulip.new_event_loop() - tulip.set_event_loop(None) - - def tearDown(self): - # just in case if we have transport close callbacks - test_utils.run_briefly(self.loop) - - self.loop.close() - gc.collect() - - def test_HTTP_200_OK_METHOD(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - for meth in ('get', 'post', 'put', 'delete', 'head'): - r = self.loop.run_until_complete( - client.request(meth, httpd.url('method', meth), - loop=self.loop)) - content1 = self.loop.run_until_complete(r.read()) - content2 = self.loop.run_until_complete(r.read()) - content = content1.decode() - - self.assertEqual(r.status, 200) - self.assertIn('"method": "%s"' % meth.upper(), content) - self.assertEqual(content1, content2) - r.close() - - def test_use_global_loop(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - try: - tulip.set_event_loop(self.loop) - r = self.loop.run_until_complete( - client.request('get', httpd.url('method', 'get'))) - finally: - tulip.set_event_loop(None) - content1 = self.loop.run_until_complete(r.read()) - content2 = self.loop.run_until_complete(r.read()) - content = content1.decode() - - self.assertEqual(r.status, 200) - self.assertIn('"method": "GET"', content) - self.assertEqual(content1, content2) - r.close() - - def test_HTTP_302_REDIRECT_GET(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('redirect', 2), - loop=self.loop)) - - self.assertEqual(r.status, 200) - self.assertEqual(2, httpd['redirects']) - r.close() - - def test_HTTP_302_REDIRECT_NON_HTTP(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - self.assertRaises( - ValueError, - self.loop.run_until_complete, - client.request('get', httpd.url('redirect_err'), - loop=self.loop)) - - def test_HTTP_302_REDIRECT_POST(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('post', httpd.url('redirect', 2), - data={'some': 'data'}, loop=self.loop)) - content = self.loop.run_until_complete(r.content.read()) - content = content.decode() - - self.assertEqual(r.status, 200) - self.assertIn('"method": "POST"', content) - self.assertEqual(2, httpd['redirects']) - r.close() - - def test_HTTP_302_max_redirects(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('redirect', 5), - max_redirects=2, loop=self.loop)) - - self.assertEqual(r.status, 302) - self.assertEqual(2, httpd['redirects']) - r.close() - - def test_HTTP_200_GET_WITH_PARAMS(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('method', 'get'), - params={'q': 'test'}, loop=self.loop)) - content = self.loop.run_until_complete(r.content.read()) - content = content.decode() - - self.assertIn('"query": "q=test"', content) - self.assertEqual(r.status, 200) - r.close() - - def test_HTTP_200_GET_WITH_MIXED_PARAMS(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request( - 'get', httpd.url('method', 'get') + '?test=true', - params={'q': 'test'}, loop=self.loop)) - content = self.loop.run_until_complete(r.content.read()) - content = content.decode() - - self.assertIn('"query": "test=true&q=test"', content) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_DATA(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - r = self.loop.run_until_complete( - client.request('post', url, data={'some': 'data'}, - loop=self.loop)) - self.assertEqual(r.status, 200) - - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual({'some': ['data']}, content['form']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_DATA_DEFLATE(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - r = self.loop.run_until_complete( - client.request('post', url, - data={'some': 'data'}, compress=True, - loop=self.loop)) - self.assertEqual(r.status, 200) - - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual('deflate', content['compression']) - self.assertEqual({'some': ['data']}, content['form']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request( - 'post', url, files={'some': f}, chunked=1024, - headers={'Transfer-Encoding': 'chunked'}, - loop=self.loop)) - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - filename = os.path.split(f.name)[-1] - - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - 'some', content['multipart-data'][0]['name']) - self.assertEqual( - filename, content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_DEFLATE(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, files={'some': f}, - chunked=1024, compress='deflate', - loop=self.loop)) - - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - filename = os.path.split(f.name)[-1] - - self.assertEqual('deflate', content['compression']) - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - 'some', content['multipart-data'][0]['name']) - self.assertEqual( - filename, content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_STR(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f.read())], - loop=self.loop)) - - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - 'some', content['multipart-data'][0]['name']) - self.assertEqual( - 'some', content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_LIST(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, files=[('some', f)], - loop=self.loop)) - - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - filename = os.path.split(f.name)[-1] - - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - 'some', content['multipart-data'][0]['name']) - self.assertEqual( - filename, content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_LIST_CT(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, loop=self.loop, - files=[('some', f, 'text/plain')])) - - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - filename = os.path.split(f.name)[-1] - - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - 'some', content['multipart-data'][0]['name']) - self.assertEqual( - filename, content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual( - 'text/plain', content['multipart-data'][0]['content-type']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_SINGLE(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, files=[f], loop=self.loop)) - - content = self.loop.run_until_complete(r.read(True)) - - f.seek(0) - filename = os.path.split(f.name)[-1] - - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - filename, content['multipart-data'][0]['name']) - self.assertEqual( - filename, content['multipart-data'][0]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][0]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_IO(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - data = io.BytesIO(b'data') - - r = self.loop.run_until_complete( - client.request('post', url, files=[data], loop=self.loop)) - - content = self.loop.run_until_complete(r.read(True)) - - self.assertEqual(1, len(content['multipart-data'])) - self.assertEqual( - {'content-type': 'application/octet-stream', - 'data': 'data', - 'filename': 'unknown', - 'name': 'unknown'}, content['multipart-data'][0]) - self.assertEqual(r.status, 200) - r.close() - - def test_POST_FILES_WITH_DATA(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - url = httpd.url('method', 'post') - - with open(__file__) as f: - r = self.loop.run_until_complete( - client.request('post', url, loop=self.loop, - data={'test': 'true'}, files={'some': f})) - - content = self.loop.run_until_complete(r.read(True)) - - self.assertEqual(2, len(content['multipart-data'])) - self.assertEqual( - 'test', content['multipart-data'][0]['name']) - self.assertEqual( - 'true', content['multipart-data'][0]['data']) - - f.seek(0) - filename = os.path.split(f.name)[-1] - self.assertEqual( - 'some', content['multipart-data'][1]['name']) - self.assertEqual( - filename, content['multipart-data'][1]['filename']) - self.assertEqual( - f.read(), content['multipart-data'][1]['data']) - self.assertEqual(r.status, 200) - r.close() - - def test_encoding(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('encoding', 'deflate'), - loop=self.loop)) - self.assertEqual(r.status, 200) - - r = self.loop.run_until_complete( - client.request('get', httpd.url('encoding', 'gzip'), - loop=self.loop)) - self.assertEqual(r.status, 200) - r.close() - - def test_cookies(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - c = http.cookies.Morsel() - c.set('test3', '456', '456') - - r = self.loop.run_until_complete( - client.request( - 'get', httpd.url('method', 'get'), loop=self.loop, - cookies={'test1': '123', 'test2': c})) - self.assertEqual(r.status, 200) - - content = self.loop.run_until_complete(r.content.read()) - self.assertIn(b'"Cookie": "test1=123; test3=456"', bytes(content)) - r.close() - - def test_set_cookies(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - resp = self.loop.run_until_complete( - client.request('get', httpd.url('cookies'), loop=self.loop)) - self.assertEqual(resp.status, 200) - - self.assertEqual(resp.cookies['c1'].value, 'cookie1') - self.assertEqual(resp.cookies['c2'].value, 'cookie2') - resp.close() - - def test_chunked(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('chunked'), loop=self.loop)) - self.assertEqual(r.status, 200) - self.assertEqual(r['Transfer-Encoding'], 'chunked') - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual(content['path'], '/chunked') - r.close() - - def test_timeout(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - httpd['noresponse'] = True - self.assertRaises( - tulip.TimeoutError, - self.loop.run_until_complete, - client.request('get', httpd.url('method', 'get'), - timeout=0.1, loop=self.loop)) - - def test_request_conn_error(self): - self.assertRaises( - OSError, - self.loop.run_until_complete, - client.request('get', 'http://0.0.0.0:1', - timeout=0.1, loop=self.loop)) - - def test_request_conn_closed(self): - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - httpd['close'] = True - self.assertRaises( - tulip.http.HttpException, - self.loop.run_until_complete, - client.request('get', httpd.url('method', 'get'), - loop=self.loop)) - - def test_keepalive(self): - from tulip.http import session - s = session.Session() - - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive',), - session=s, loop=self.loop)) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual(content['content'], 'requests=1') - r.close() - - r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive'), - session=s, loop=self.loop)) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual(content['content'], 'requests=2') - r.close() - - def test_session_close(self): - from tulip.http import session - s = session.Session() - - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - r = self.loop.run_until_complete( - client.request( - 'get', httpd.url('keepalive') + '?close=1', - session=s, loop=self.loop)) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual(content['content'], 'requests=1') - r.close() - - r = self.loop.run_until_complete( - client.request('get', httpd.url('keepalive'), - session=s, loop=self.loop)) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) - self.assertEqual(content['content'], 'requests=1') - r.close() - - def test_session_cookies(self): - from tulip.http import session - s = session.Session() - - with test_utils.run_test_server(self.loop, router=Functional) as httpd: - s.update_cookies({'test': '1'}) - r = self.loop.run_until_complete( - client.request('get', httpd.url('cookies'), - session=s, loop=self.loop)) - self.assertEqual(r.status, 200) - content = self.loop.run_until_complete(r.read(True)) - - self.assertEqual(content['headers']['Cookie'], 'test=1') - r.close() - - cookies = sorted([(k, v.value) for k, v in s.cookies.items()]) - self.assertEqual( - cookies, [('c1', 'cookie1'), ('c2', 'cookie2'), ('test', '1')]) - - -class Functional(test_utils.Router): - - @test_utils.Router.define('/method/([A-Za-z]+)$') - def method(self, match): - meth = match.group(1).upper() - if meth == self._method: - self._response(self._start_response(200)) - else: - self._response(self._start_response(400)) - - @test_utils.Router.define('/redirect_err$') - def redirect_err(self, match): - self._response( - self._start_response(302), - headers={'Location': 'ftp://127.0.0.1/test/'}) - - @test_utils.Router.define('/redirect/([0-9]+)$') - def redirect(self, match): - no = int(match.group(1).upper()) - rno = self._props['redirects'] = self._props.get('redirects', 0) + 1 - - if rno >= no: - self._response( - self._start_response(302), - headers={'Location': '/method/%s' % self._method.lower()}) - else: - self._response( - self._start_response(302), - headers={'Location': self._path}) - - @test_utils.Router.define('/encoding/(gzip|deflate)$') - def encoding(self, match): - mode = match.group(1) - - resp = self._start_response(200) - resp.add_compression_filter(mode) - resp.add_chunking_filter(100) - self._response(resp, headers={'Content-encoding': mode}, chunked=True) - - @test_utils.Router.define('/chunked$') - def chunked(self, match): - resp = self._start_response(200) - resp.add_chunking_filter(100) - self._response(resp, chunked=True) - - @test_utils.Router.define('/keepalive$') - def keepalive(self, match): - self._transport._requests = getattr( - self._transport, '_requests', 0) + 1 - resp = self._start_response(200) - if 'close=' in self._query: - self._response( - resp, 'requests={}'.format(self._transport._requests)) - else: - self._response( - resp, 'requests={}'.format(self._transport._requests), - headers={'CONNECTION': 'keep-alive'}) - - @test_utils.Router.define('/cookies$') - def cookies(self, match): - cookies = http.cookies.SimpleCookie() - cookies['c1'] = 'cookie1' - cookies['c2'] = 'cookie2' - - resp = self._start_response(200) - for cookie in cookies.output(header='').split('\n'): - resp.add_header('Set-Cookie', cookie.strip()) - - self._response(resp) diff --git a/tests/http_client_test.py b/tests/http_client_test.py deleted file mode 100644 index a911c975..00000000 --- a/tests/http_client_test.py +++ /dev/null @@ -1,298 +0,0 @@ -# -*- coding: utf-8 -*- -"""Tests for tulip/http/client.py""" - -import unittest -import unittest.mock -import urllib.parse - -import tulip -import tulip.http - -from tulip.http.client import HttpRequest, HttpResponse - - -class HttpResponseTests(unittest.TestCase): - - def setUp(self): - self.loop = tulip.new_event_loop() - tulip.set_event_loop(None) - - self.transport = unittest.mock.Mock() - self.stream = tulip.StreamBuffer(loop=self.loop) - self.response = HttpResponse('get', 'http://python.org') - - def tearDown(self): - self.loop.close() - - def test_close(self): - self.response.transport = self.transport - self.response.close() - self.assertIsNone(self.response.transport) - self.assertTrue(self.transport.close.called) - self.response.close() - self.response.close() - - def test_repr(self): - self.response.status = 200 - self.response.reason = 'Ok' - self.assertIn( - '', repr(self.response)) - - -class HttpRequestTests(unittest.TestCase): - - def setUp(self): - self.loop = tulip.new_event_loop() - tulip.set_event_loop(None) - - self.transport = unittest.mock.Mock() - self.stream = tulip.StreamBuffer(loop=self.loop) - - def tearDown(self): - self.loop.close() - - def test_method(self): - req = HttpRequest('get', 'http://python.org/') - self.assertEqual(req.method, 'GET') - - req = HttpRequest('head', 'http://python.org/') - self.assertEqual(req.method, 'HEAD') - - req = HttpRequest('HEAD', 'http://python.org/') - self.assertEqual(req.method, 'HEAD') - - def test_version(self): - req = HttpRequest('get', 'http://python.org/', version='1.0') - self.assertEqual(req.version, (1, 0)) - - def test_version_err(self): - self.assertRaises( - ValueError, - HttpRequest, 'get', 'http://python.org/', version='1.c') - - def test_host_port(self): - req = HttpRequest('get', 'http://python.org/') - self.assertEqual(req.host, 'python.org') - self.assertEqual(req.port, 80) - self.assertFalse(req.ssl) - - req = HttpRequest('get', 'https://python.org/') - self.assertEqual(req.host, 'python.org') - self.assertEqual(req.port, 443) - self.assertTrue(req.ssl) - - req = HttpRequest('get', 'https://python.org:960/') - self.assertEqual(req.host, 'python.org') - self.assertEqual(req.port, 960) - self.assertTrue(req.ssl) - - def test_host_port_err(self): - self.assertRaises( - ValueError, HttpRequest, 'get', 'http://python.org:123e/') - - def test_host_header(self): - req = HttpRequest('get', 'http://python.org/') - self.assertEqual(req.headers['host'], 'python.org') - - req = HttpRequest('get', 'http://python.org/', - headers={'host': 'example.com'}) - self.assertEqual(req.headers['host'], 'example.com') - - def test_headers(self): - req = HttpRequest('get', 'http://python.org/', - headers={'Content-Type': 'text/plain'}) - self.assertIn('Content-Type', req.headers) - self.assertEqual(req.headers['Content-Type'], 'text/plain') - self.assertEqual(req.headers['Accept-Encoding'], 'gzip, deflate') - - def test_headers_list(self): - req = HttpRequest('get', 'http://python.org/', - headers=[('Content-Type', 'text/plain')]) - self.assertIn('Content-Type', req.headers) - self.assertEqual(req.headers['Content-Type'], 'text/plain') - - def test_headers_default(self): - req = HttpRequest('get', 'http://python.org/', - headers={'Accept-Encoding': 'deflate'}) - self.assertEqual(req.headers['Accept-Encoding'], 'deflate') - - def test_invalid_url(self): - self.assertRaises(ValueError, HttpRequest, 'get', 'hiwpefhipowhefopw') - - def test_invalid_idna(self): - self.assertRaises( - ValueError, HttpRequest, 'get', 'http://\u2061owhefopw.com') - - def test_no_path(self): - req = HttpRequest('get', 'http://python.org') - self.assertEqual('/', req.path) - - def test_basic_auth(self): - req = HttpRequest('get', 'http://python.org', auth=('nkim', '1234')) - self.assertIn('Authorization', req.headers) - self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) - - def test_basic_auth_from_url(self): - req = HttpRequest('get', 'http://nkim:1234@python.org') - self.assertIn('Authorization', req.headers) - self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) - - req = HttpRequest('get', 'http://nkim@python.org') - self.assertIn('Authorization', req.headers) - self.assertEqual('Basic bmtpbTo=', req.headers['Authorization']) - - req = HttpRequest( - 'get', 'http://nkim@python.org', auth=('nkim', '1234')) - self.assertIn('Authorization', req.headers) - self.assertEqual('Basic bmtpbToxMjM0', req.headers['Authorization']) - - def test_basic_auth_err(self): - self.assertRaises( - ValueError, HttpRequest, - 'get', 'http://python.org', auth=(1, 2, 3)) - - def test_no_content_length(self): - req = HttpRequest('get', 'http://python.org') - req.send(self.transport) - self.assertEqual('0', req.headers.get('Content-Length')) - - req = HttpRequest('head', 'http://python.org') - req.send(self.transport) - self.assertEqual('0', req.headers.get('Content-Length')) - - def test_path_is_not_double_encoded(self): - req = HttpRequest('get', "http://0.0.0.0/get/test case") - self.assertEqual(req.path, "/get/test%20case") - - req = HttpRequest('get', "http://0.0.0.0/get/test%20case") - self.assertEqual(req.path, "/get/test%20case") - - def test_params_are_added_before_fragment(self): - req = HttpRequest( - 'GET', "http://example.com/path#fragment", params={"a": "b"}) - self.assertEqual( - req.path, "/path?a=b#fragment") - - req = HttpRequest( - 'GET', - "http://example.com/path?key=value#fragment", params={"a": "b"}) - self.assertEqual( - req.path, "/path?key=value&a=b#fragment") - - def test_cookies(self): - req = HttpRequest( - 'get', 'http://test.com/path', cookies={'cookie1': 'val1'}) - self.assertIn('Cookie', req.headers) - self.assertEqual('cookie1=val1', req.headers['cookie']) - - req = HttpRequest( - 'get', 'http://test.com/path', - headers={'cookie': 'cookie1=val1'}, - cookies={'cookie2': 'val2'}) - self.assertEqual('cookie1=val1; cookie2=val2', req.headers['cookie']) - - def test_unicode_get(self): - def join(*suffix): - return urllib.parse.urljoin('http://python.org/', '/'.join(suffix)) - - url = 'http://python.org' - req = HttpRequest('get', url, params={'foo': 'f\xf8\xf8'}) - self.assertEqual('/?foo=f%C3%B8%C3%B8', req.path) - req = HttpRequest('', url, params={'f\xf8\xf8': 'f\xf8\xf8'}) - self.assertEqual('/?f%C3%B8%C3%B8=f%C3%B8%C3%B8', req.path) - req = HttpRequest('', url, params={'foo': 'foo'}) - self.assertEqual('/?foo=foo', req.path) - req = HttpRequest('', join('\xf8'), params={'foo': 'foo'}) - self.assertEqual('/%C3%B8?foo=foo', req.path) - - def test_query_multivalued_param(self): - for meth in HttpRequest.ALL_METHODS: - req = HttpRequest( - meth, 'http://python.org', - params=(('test', 'foo'), ('test', 'baz'))) - self.assertEqual(req.path, '/?test=foo&test=baz') - - def test_post_data(self): - for meth in HttpRequest.POST_METHODS: - req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) - req.send(self.transport) - self.assertEqual('/', req.path) - self.assertEqual(b'life=42', req.body[0]) - self.assertEqual('application/x-www-form-urlencoded', - req.headers['content-type']) - - def test_get_with_data(self): - for meth in HttpRequest.GET_METHODS: - req = HttpRequest(meth, 'http://python.org/', data={'life': '42'}) - self.assertEqual('/?life=42', req.path) - - def test_bytes_data(self): - for meth in HttpRequest.POST_METHODS: - req = HttpRequest(meth, 'http://python.org/', data=b'binary data') - req.send(self.transport) - self.assertEqual('/', req.path) - self.assertEqual((b'binary data',), req.body) - self.assertEqual('application/octet-stream', - req.headers['content-type']) - - @unittest.mock.patch('tulip.http.client.tulip') - def test_content_encoding(self, m_tulip): - req = HttpRequest('get', 'http://python.org/', compress='deflate') - req.send(self.transport) - self.assertEqual(req.headers['Transfer-encoding'], 'chunked') - self.assertEqual(req.headers['Content-encoding'], 'deflate') - m_tulip.http.Request.return_value\ - .add_compression_filter.assert_called_with('deflate') - - @unittest.mock.patch('tulip.http.client.tulip') - def test_content_encoding_header(self, m_tulip): - req = HttpRequest('get', 'http://python.org/', - headers={'Content-Encoding': 'deflate'}) - req.send(self.transport) - self.assertEqual(req.headers['Transfer-encoding'], 'chunked') - self.assertEqual(req.headers['Content-encoding'], 'deflate') - - m_tulip.http.Request.return_value\ - .add_compression_filter.assert_called_with('deflate') - m_tulip.http.Request.return_value\ - .add_chunking_filter.assert_called_with(8196) - - def test_chunked(self): - req = HttpRequest( - 'get', 'http://python.org/', - headers={'Transfer-encoding': 'gzip'}) - req.send(self.transport) - self.assertEqual('gzip', req.headers['Transfer-encoding']) - - req = HttpRequest( - 'get', 'http://python.org/', - headers={'Transfer-encoding': 'chunked'}) - req.send(self.transport) - self.assertEqual('chunked', req.headers['Transfer-encoding']) - - @unittest.mock.patch('tulip.http.client.tulip') - def test_chunked_explicit(self, m_tulip): - req = HttpRequest( - 'get', 'http://python.org/', chunked=True) - req.send(self.transport) - - self.assertEqual('chunked', req.headers['Transfer-encoding']) - m_tulip.http.Request.return_value\ - .add_chunking_filter.assert_called_with(8196) - - @unittest.mock.patch('tulip.http.client.tulip') - def test_chunked_explicit_size(self, m_tulip): - req = HttpRequest( - 'get', 'http://python.org/', chunked=1024) - req.send(self.transport) - self.assertEqual('chunked', req.headers['Transfer-encoding']) - m_tulip.http.Request.return_value\ - .add_chunking_filter.assert_called_with(1024) - - def test_chunked_length(self): - req = HttpRequest( - 'get', 'http://python.org/', - headers={'Content-Length': '1000'}, chunked=1024) - req.send(self.transport) - self.assertEqual(req.headers['Transfer-Encoding'], 'chunked') - self.assertNotIn('Content-Length', req.headers) diff --git a/tests/http_parser_test.py b/tests/http_parser_test.py deleted file mode 100644 index 6240ad49..00000000 --- a/tests/http_parser_test.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Tests for http/parser.py""" - -from collections import deque -import zlib -import unittest -import unittest.mock - -import tulip -from tulip.http import errors -from tulip.http import protocol - - -class ParseHeadersTests(unittest.TestCase): - - def setUp(self): - tulip.set_event_loop(None) - - def test_parse_headers(self): - hdrs = ('', 'test: line\r\n', ' continue\r\n', - 'test2: data\r\n', '\r\n') - - headers, close, compression = protocol.parse_headers( - hdrs, 8190, 32768, 8190) - - self.assertEqual(list(headers), - [('TEST', 'line\r\n continue'), ('TEST2', 'data')]) - self.assertIsNone(close) - self.assertIsNone(compression) - - def test_parse_headers_multi(self): - hdrs = ('', - 'Set-Cookie: c1=cookie1\r\n', - 'Set-Cookie: c2=cookie2\r\n', '\r\n') - - headers, close, compression = protocol.parse_headers( - hdrs, 8190, 32768, 8190) - - self.assertEqual(list(headers), - [('SET-COOKIE', 'c1=cookie1'), - ('SET-COOKIE', 'c2=cookie2')]) - self.assertIsNone(close) - self.assertIsNone(compression) - - def test_conn_close(self): - headers, close, compression = protocol.parse_headers( - ['', 'connection: close\r\n', '\r\n'], 8190, 32768, 8190) - self.assertTrue(close) - - def test_conn_keep_alive(self): - headers, close, compression = protocol.parse_headers( - ['', 'connection: keep-alive\r\n', '\r\n'], 8190, 32768, 8190) - self.assertFalse(close) - - def test_conn_other(self): - headers, close, compression = protocol.parse_headers( - ['', 'connection: test\r\n', '\r\n'], 8190, 32768, 8190) - self.assertIsNone(close) - - def test_compression_gzip(self): - headers, close, compression = protocol.parse_headers( - ['', 'content-encoding: gzip\r\n', '\r\n'], 8190, 32768, 8190) - self.assertEqual('gzip', compression) - - def test_compression_deflate(self): - headers, close, compression = protocol.parse_headers( - ['', 'content-encoding: deflate\r\n', '\r\n'], 8190, 32768, 8190) - self.assertEqual('deflate', compression) - - def test_compression_unknown(self): - headers, close, compression = protocol.parse_headers( - ['', 'content-encoding: compress\r\n', '\r\n'], 8190, 32768, 8190) - self.assertIsNone(compression) - - def test_max_field_size(self): - with self.assertRaises(errors.LineTooLong) as cm: - protocol.parse_headers( - ['', 'test: line data data\r\n', 'data\r\n', '\r\n'], - 8190, 32768, 5) - self.assertIn("limit request headers fields size", str(cm.exception)) - - def test_max_continuation_headers_size(self): - with self.assertRaises(errors.LineTooLong) as cm: - protocol.parse_headers( - ['', 'test: line\r\n', ' test\r\n', '\r\n'], 8190, 32768, 5) - self.assertIn("limit request headers fields size", str(cm.exception)) - - def test_invalid_header(self): - with self.assertRaises(ValueError) as cm: - protocol.parse_headers( - ['', 'test line\r\n', '\r\n'], 8190, 32768, 8190) - self.assertIn("Invalid header: test line", str(cm.exception)) - - def test_invalid_name(self): - with self.assertRaises(ValueError) as cm: - protocol.parse_headers( - ['', 'test[]: line\r\n', '\r\n'], 8190, 32768, 8190) - self.assertIn("Invalid header name: TEST[]", str(cm.exception)) - - -class DeflateBufferTests(unittest.TestCase): - - def setUp(self): - tulip.set_event_loop(None) - - def test_feed_data(self): - buf = tulip.DataBuffer() - dbuf = protocol.DeflateBuffer(buf, 'deflate') - - dbuf.zlib = unittest.mock.Mock() - dbuf.zlib.decompress.return_value = b'line' - - dbuf.feed_data(b'data') - self.assertEqual([b'line'], list(buf._buffer)) - - def test_feed_data_err(self): - buf = tulip.DataBuffer() - dbuf = protocol.DeflateBuffer(buf, 'deflate') - - exc = ValueError() - dbuf.zlib = unittest.mock.Mock() - dbuf.zlib.decompress.side_effect = exc - - self.assertRaises(errors.IncompleteRead, dbuf.feed_data, b'data') - - def test_feed_eof(self): - buf = tulip.DataBuffer() - dbuf = protocol.DeflateBuffer(buf, 'deflate') - - dbuf.zlib = unittest.mock.Mock() - dbuf.zlib.flush.return_value = b'line' - - dbuf.feed_eof() - self.assertEqual([b'line'], list(buf._buffer)) - self.assertTrue(buf._eof) - - def test_feed_eof_err(self): - buf = tulip.DataBuffer() - dbuf = protocol.DeflateBuffer(buf, 'deflate') - - dbuf.zlib = unittest.mock.Mock() - dbuf.zlib.flush.return_value = b'line' - dbuf.zlib.eof = False - - self.assertRaises(errors.IncompleteRead, dbuf.feed_eof) - - -class ParsePayloadTests(unittest.TestCase): - - def setUp(self): - tulip.set_event_loop(None) - - def test_parse_eof_payload(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_eof_payload(out, buf) - next(p) - p.send(b'data') - try: - p.throw(tulip.EofStream()) - except tulip.EofStream: - pass - - self.assertEqual([b'data'], list(out._buffer)) - - def test_parse_length_payload(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_length_payload(out, buf, 4) - next(p) - p.send(b'da') - p.send(b't') - try: - p.send(b'aline') - except StopIteration: - pass - - self.assertEqual(3, len(out._buffer)) - self.assertEqual(b'data', b''.join(out._buffer)) - self.assertEqual(b'line', bytes(buf)) - - def test_parse_length_payload_eof(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_length_payload(out, buf, 4) - next(p) - p.send(b'da') - self.assertRaises( - errors.IncompleteRead, p.throw, tulip.EofStream) - - def test_parse_chunked_payload(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_chunked_payload(out, buf) - next(p) - try: - p.send(b'4\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') - except StopIteration: - pass - self.assertEqual(b'dataline', b''.join(out._buffer)) - self.assertEqual(b'', bytes(buf)) - - def test_parse_chunked_payload_chunks(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_chunked_payload(out, buf) - next(p) - p.send(b'4\r\ndata\r') - p.send(b'\n4') - p.send(b'\r') - p.send(b'\n') - p.send(b'line\r\n0\r\n') - self.assertRaises(StopIteration, p.send, b'test\r\n') - self.assertEqual(b'dataline', b''.join(out._buffer)) - - def test_parse_chunked_payload_incomplete(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_chunked_payload(out, buf) - next(p) - p.send(b'4\r\ndata\r\n') - self.assertRaises(errors.IncompleteRead, p.throw, tulip.EofStream) - - def test_parse_chunked_payload_extension(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_chunked_payload(out, buf) - next(p) - try: - p.send(b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') - except StopIteration: - pass - self.assertEqual(b'dataline', b''.join(out._buffer)) - - def test_parse_chunked_payload_size_error(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = protocol.parse_chunked_payload(out, buf) - next(p) - self.assertRaises(errors.IncompleteRead, p.send, b'blah\r\n') - - def test_http_payload_parser_length_broken(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', 'qwe')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) - - def test_http_payload_parser_length_wrong(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', '-1')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - self.assertRaises(errors.InvalidHeader, p.send, (out, buf)) - - def test_http_payload_parser_length(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', '2')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - try: - p.send(b'1245') - except StopIteration: - pass - - self.assertEqual(b'12', b''.join(out._buffer)) - self.assertEqual(b'45', bytes(buf)) - - def test_http_payload_parser_no_length(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [], None, None) - p = protocol.http_payload_parser(msg, readall=False) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - self.assertRaises(StopIteration, p.send, (out, buf)) - self.assertEqual(b'', b''.join(out._buffer)) - self.assertTrue(out._eof) - - _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) - - def test_http_payload_parser_deflate(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], - None, 'deflate') - p = protocol.http_payload_parser(msg) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises(StopIteration, p.send, self._COMPRESSED) - self.assertEqual(b'data', b''.join(out._buffer)) - - def test_http_payload_parser_deflate_disabled(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', len(self._COMPRESSED))], - None, 'deflate') - p = protocol.http_payload_parser(msg, compression=False) - next(p) - - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises(StopIteration, p.send, self._COMPRESSED) - self.assertEqual(self._COMPRESSED, b''.join(out._buffer)) - - def test_http_payload_parser_websocket(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('SEC-WEBSOCKET-KEY1', '13')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises(StopIteration, p.send, b'1234567890') - self.assertEqual(b'12345678', b''.join(out._buffer)) - - def test_http_payload_parser_chunked(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('TRANSFER-ENCODING', 'chunked')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises(StopIteration, p.send, - b'4;test\r\ndata\r\n4\r\nline\r\n0\r\ntest\r\n') - self.assertEqual(b'dataline', b''.join(out._buffer)) - - def test_http_payload_parser_eof(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [], None, None) - p = protocol.http_payload_parser(msg, readall=True) - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - p.send(b'data') - p.send(b'line') - self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream()) - self.assertEqual(b'dataline', b''.join(out._buffer)) - - def test_http_payload_parser_length_zero(self): - msg = protocol.RawRequestMessage( - 'GET', '/', (1, 1), [('CONTENT-LENGTH', '0')], None, None) - p = protocol.http_payload_parser(msg) - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - self.assertRaises(StopIteration, p.send, (out, buf)) - self.assertEqual(b'', b''.join(out._buffer)) - - -class ParseRequestTests(unittest.TestCase): - - def setUp(self): - tulip.set_event_loop(None) - - def test_http_request_parser_max_headers(self): - p = protocol.http_request_parser(8190, 20, 8190) - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - - self.assertRaises( - errors.LineTooLong, - p.send, - b'get /path HTTP/1.1\r\ntest: line\r\ntest2: data\r\n\r\n') - - def test_http_request_parser(self): - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - try: - p.send(b'get /path HTTP/1.1\r\n\r\n') - except StopIteration: - pass - result = out._buffer[0] - self.assertEqual( - ('GET', '/path', (1, 1), deque(), False, None), result) - - def test_http_request_parser_eof(self): - # http_request_parser does not fail on EofStream() - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - p.send(b'get /path HTTP/1.1\r\n') - try: - p.throw(tulip.EofStream()) - except StopIteration: - pass - self.assertFalse(out._buffer) - - def test_http_request_parser_two_slashes(self): - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - try: - p.send(b'get //path HTTP/1.1\r\n\r\n') - except StopIteration: - pass - self.assertEqual( - ('GET', '//path', (1, 1), deque(), False, None), out._buffer[0]) - - def test_http_request_parser_bad_status_line(self): - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises( - errors.BadStatusLine, p.send, b'\r\n\r\n') - - def test_http_request_parser_bad_method(self): - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises( - errors.BadStatusLine, - p.send, b'!12%()+=~$ /get HTTP/1.1\r\n\r\n') - - def test_http_request_parser_bad_version(self): - p = protocol.http_request_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises( - errors.BadStatusLine, - p.send, b'GET //get HT/11\r\n\r\n') - - -class ParseResponseTests(unittest.TestCase): - - def setUp(self): - tulip.set_event_loop(None) - - def test_http_response_parser_bad_status_line(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises(errors.BadStatusLine, p.send, b'\r\n\r\n') - - def test_http_response_parser_bad_status_line_eof(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - self.assertRaises( - errors.BadStatusLine, p.throw, tulip.EofStream()) - - def test_http_response_parser_bad_version(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HT/11 200 Ok\r\n\r\n') - self.assertEqual('HT/11 200 Ok\r\n', cm.exception.args[0]) - - def test_http_response_parser_no_reason(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - try: - p.send(b'HTTP/1.1 200\r\n\r\n') - except StopIteration: - pass - v, s, r = out._buffer[0][:3] - self.assertEqual(v, (1, 1)) - self.assertEqual(s, 200) - self.assertEqual(r, '') - - def test_http_response_parser_bad(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTT/1\r\n\r\n') - self.assertIn('HTT/1', str(cm.exception)) - - def test_http_response_parser_code_under_100(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 99 test\r\n\r\n') - self.assertIn('HTTP/1.1 99 test', str(cm.exception)) - - def test_http_response_parser_code_above_999(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 9999 test\r\n\r\n') - self.assertIn('HTTP/1.1 9999 test', str(cm.exception)) - - def test_http_response_parser_code_not_int(self): - p = protocol.http_response_parser() - next(p) - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p.send((out, buf)) - with self.assertRaises(errors.BadStatusLine) as cm: - p.send(b'HTTP/1.1 ttt test\r\n\r\n') - self.assertIn('HTTP/1.1 ttt test', str(cm.exception)) diff --git a/tests/http_protocol_test.py b/tests/http_protocol_test.py deleted file mode 100644 index 43201a8f..00000000 --- a/tests/http_protocol_test.py +++ /dev/null @@ -1,400 +0,0 @@ -"""Tests for http/protocol.py""" - -import unittest -import unittest.mock -import zlib - -import tulip -from tulip.http import protocol - - -class HttpMessageTests(unittest.TestCase): - - def setUp(self): - self.transport = unittest.mock.Mock() - tulip.set_event_loop(None) - - def test_start_request(self): - msg = protocol.Request( - self.transport, 'GET', '/index.html', close=True) - - self.assertIs(msg.transport, self.transport) - self.assertIsNone(msg.status) - self.assertTrue(msg.closing) - self.assertEqual(msg.status_line, 'GET /index.html HTTP/1.1\r\n') - - def test_start_response(self): - msg = protocol.Response(self.transport, 200, close=True) - - self.assertIs(msg.transport, self.transport) - self.assertEqual(msg.status, 200) - self.assertTrue(msg.closing) - self.assertEqual(msg.status_line, 'HTTP/1.1 200 OK\r\n') - - def test_force_close(self): - msg = protocol.Response(self.transport, 200) - self.assertFalse(msg.closing) - msg.force_close() - self.assertTrue(msg.closing) - - def test_force_chunked(self): - msg = protocol.Response(self.transport, 200) - self.assertFalse(msg.chunked) - msg.force_chunked() - self.assertTrue(msg.chunked) - - def test_keep_alive(self): - msg = protocol.Response(self.transport, 200, close=True) - self.assertFalse(msg.keep_alive()) - msg.keepalive = True - self.assertTrue(msg.keep_alive()) - - msg.force_close() - self.assertFalse(msg.keep_alive()) - - def test_keep_alive_http10(self): - msg = protocol.Response(self.transport, 200, http_version=(1, 0)) - self.assertFalse(msg.keepalive) - self.assertFalse(msg.keep_alive()) - - msg = protocol.Response(self.transport, 200, http_version=(1, 1)) - self.assertIsNone(msg.keepalive) - - def test_add_header(self): - msg = protocol.Response(self.transport, 200) - self.assertEqual([], list(msg.headers)) - - msg.add_header('content-type', 'plain/html') - self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) - - def test_add_headers(self): - msg = protocol.Response(self.transport, 200) - self.assertEqual([], list(msg.headers)) - - msg.add_headers(('content-type', 'plain/html')) - self.assertEqual([('CONTENT-TYPE', 'plain/html')], list(msg.headers)) - - def test_add_headers_length(self): - msg = protocol.Response(self.transport, 200) - self.assertIsNone(msg.length) - - msg.add_headers(('content-length', '42')) - self.assertEqual(42, msg.length) - - def test_add_headers_upgrade(self): - msg = protocol.Response(self.transport, 200) - self.assertFalse(msg.upgrade) - - msg.add_headers(('connection', 'upgrade')) - self.assertTrue(msg.upgrade) - - def test_add_headers_upgrade_websocket(self): - msg = protocol.Response(self.transport, 200) - - msg.add_headers(('upgrade', 'test')) - self.assertEqual([], list(msg.headers)) - - msg.add_headers(('upgrade', 'websocket')) - self.assertEqual([('UPGRADE', 'websocket')], list(msg.headers)) - - def test_add_headers_connection_keepalive(self): - msg = protocol.Response(self.transport, 200) - - msg.add_headers(('connection', 'keep-alive')) - self.assertEqual([], list(msg.headers)) - self.assertTrue(msg.keepalive) - - msg.add_headers(('connection', 'close')) - self.assertFalse(msg.keepalive) - - def test_add_headers_hop_headers(self): - msg = protocol.Response(self.transport, 200) - - msg.add_headers(('connection', 'test'), ('transfer-encoding', 't')) - self.assertEqual([], list(msg.headers)) - - def test_default_headers(self): - msg = protocol.Response(self.transport, 200) - msg._add_default_headers() - - headers = [r for r, _ in msg.headers] - self.assertIn('DATE', headers) - self.assertIn('CONNECTION', headers) - - def test_default_headers_server(self): - msg = protocol.Response(self.transport, 200) - msg._add_default_headers() - - headers = [r for r, _ in msg.headers] - self.assertIn('SERVER', headers) - - def test_default_headers_useragent(self): - msg = protocol.Request(self.transport, 'GET', '/') - msg._add_default_headers() - - headers = [r for r, _ in msg.headers] - self.assertNotIn('SERVER', headers) - self.assertIn('USER-AGENT', headers) - - def test_default_headers_chunked(self): - msg = protocol.Response(self.transport, 200) - msg._add_default_headers() - - headers = [r for r, _ in msg.headers] - self.assertNotIn('TRANSFER-ENCODING', headers) - - msg = protocol.Response(self.transport, 200) - msg.force_chunked() - msg._add_default_headers() - - headers = [r for r, _ in msg.headers] - self.assertIn('TRANSFER-ENCODING', headers) - - def test_default_headers_connection_upgrade(self): - msg = protocol.Response(self.transport, 200) - msg.upgrade = True - msg._add_default_headers() - - headers = [r for r in msg.headers if r[0] == 'CONNECTION'] - self.assertEqual([('CONNECTION', 'upgrade')], headers) - - def test_default_headers_connection_close(self): - msg = protocol.Response(self.transport, 200) - msg.force_close() - msg._add_default_headers() - - headers = [r for r in msg.headers if r[0] == 'CONNECTION'] - self.assertEqual([('CONNECTION', 'close')], headers) - - def test_default_headers_connection_keep_alive(self): - msg = protocol.Response(self.transport, 200) - msg.keepalive = True - msg._add_default_headers() - - headers = [r for r in msg.headers if r[0] == 'CONNECTION'] - self.assertEqual([('CONNECTION', 'keep-alive')], headers) - - def test_send_headers(self): - write = self.transport.write = unittest.mock.Mock() - - msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-type', 'plain/html')) - self.assertFalse(msg.is_headers_sent()) - - msg.send_headers() - - content = b''.join([arg[1][0] for arg in list(write.mock_calls)]) - - self.assertTrue(content.startswith(b'HTTP/1.1 200 OK\r\n')) - self.assertIn(b'CONTENT-TYPE: plain/html', content) - self.assertTrue(msg.headers_sent) - self.assertTrue(msg.is_headers_sent()) - # cleanup - msg.writer.close() - - def test_send_headers_nomore_add(self): - msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-type', 'plain/html')) - msg.send_headers() - - self.assertRaises(AssertionError, - msg.add_header, 'content-type', 'plain/html') - # cleanup - msg.writer.close() - - def test_prepare_length(self): - msg = protocol.Response(self.transport, 200) - w_l_p = msg._write_length_payload = unittest.mock.Mock() - w_l_p.return_value = iter([1, 2, 3]) - - msg.add_headers(('content-length', '42')) - msg.send_headers() - - self.assertTrue(w_l_p.called) - self.assertEqual((42,), w_l_p.call_args[0]) - - def test_prepare_chunked_force(self): - msg = protocol.Response(self.transport, 200) - msg.force_chunked() - - chunked = msg._write_chunked_payload = unittest.mock.Mock() - chunked.return_value = iter([1, 2, 3]) - - msg.add_headers(('content-length', '42')) - msg.send_headers() - self.assertTrue(chunked.called) - - def test_prepare_chunked_no_length(self): - msg = protocol.Response(self.transport, 200) - - chunked = msg._write_chunked_payload = unittest.mock.Mock() - chunked.return_value = iter([1, 2, 3]) - - msg.send_headers() - self.assertTrue(chunked.called) - - def test_prepare_eof(self): - msg = protocol.Response(self.transport, 200, http_version=(1, 0)) - - eof = msg._write_eof_payload = unittest.mock.Mock() - eof.return_value = iter([1, 2, 3]) - - msg.send_headers() - self.assertTrue(eof.called) - - def test_write_auto_send_headers(self): - msg = protocol.Response(self.transport, 200, http_version=(1, 0)) - msg._send_headers = True - - msg.write(b'data1') - self.assertTrue(msg.headers_sent) - # cleanup - msg.writer.close() - - def test_write_payload_eof(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200, http_version=(1, 0)) - msg.send_headers() - - msg.write(b'data1') - self.assertTrue(msg.headers_sent) - - msg.write(b'data2') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - b'data1data2', content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_chunked(self): - write = self.transport.write = unittest.mock.Mock() - - msg = protocol.Response(self.transport, 200) - msg.force_chunked() - msg.send_headers() - - msg.write(b'data') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - b'4\r\ndata\r\n0\r\n\r\n', - content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_chunked_multiple(self): - write = self.transport.write = unittest.mock.Mock() - - msg = protocol.Response(self.transport, 200) - msg.force_chunked() - msg.send_headers() - - msg.write(b'data1') - msg.write(b'data2') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n', - content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_length(self): - write = self.transport.write = unittest.mock.Mock() - - msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-length', '2')) - msg.send_headers() - - msg.write(b'd') - msg.write(b'ata') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - b'da', content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_chunked_filter(self): - write = self.transport.write = unittest.mock.Mock() - - msg = protocol.Response(self.transport, 200) - msg.send_headers() - - msg.add_chunking_filter(2) - msg.write(b'data') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertTrue(content.endswith(b'2\r\nda\r\n2\r\nta\r\n0\r\n\r\n')) - - def test_write_payload_chunked_filter_mutiple_chunks(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200) - msg.send_headers() - - msg.add_chunking_filter(2) - msg.write(b'data1') - msg.write(b'data2') - msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertTrue(content.endswith( - b'2\r\nda\r\n2\r\nta\r\n2\r\n1d\r\n2\r\nat\r\n' - b'2\r\na2\r\n0\r\n\r\n')) - - def test_write_payload_chunked_large_chunk(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200) - msg.send_headers() - - msg.add_chunking_filter(1024) - msg.write(b'data') - msg.write_eof() - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertTrue(content.endswith(b'4\r\ndata\r\n0\r\n\r\n')) - - _comp = zlib.compressobj(wbits=-zlib.MAX_WBITS) - _COMPRESSED = b''.join([_comp.compress(b'data'), _comp.flush()]) - - def test_write_payload_deflate_filter(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) - msg.send_headers() - - msg.add_compression_filter('deflate') - msg.write(b'data') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_deflate_and_chunked(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200) - msg.send_headers() - - msg.add_compression_filter('deflate') - msg.add_chunking_filter(2) - - msg.write(b'data') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - b'2\r\nKI\r\n2\r\n,I\r\n2\r\n\x04\x00\r\n0\r\n\r\n', - content.split(b'\r\n\r\n', 1)[-1]) - - def test_write_payload_chunked_and_deflate(self): - write = self.transport.write = unittest.mock.Mock() - msg = protocol.Response(self.transport, 200) - msg.add_headers(('content-length', '{}'.format(len(self._COMPRESSED)))) - - msg.add_chunking_filter(2) - msg.add_compression_filter('deflate') - msg.send_headers() - - msg.write(b'data') - msg.write_eof() - - content = b''.join([c[1][0] for c in list(write.mock_calls)]) - self.assertEqual( - self._COMPRESSED, content.split(b'\r\n\r\n', 1)[-1]) diff --git a/tests/http_server_test.py b/tests/http_server_test.py deleted file mode 100644 index 5c7a97a0..00000000 --- a/tests/http_server_test.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Tests for http/server.py""" - -import unittest -import unittest.mock - -import tulip -from tulip.http import server -from tulip.http import errors -from tulip import test_utils - - -class HttpServerProtocolTests(unittest.TestCase): - - def setUp(self): - self.loop = test_utils.TestLoop() - tulip.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_http_error_exception(self): - exc = errors.HttpErrorException(500, message='Internal error') - self.assertEqual(exc.code, 500) - self.assertEqual(exc.message, 'Internal error') - - def test_handle_request(self): - transport = unittest.mock.Mock() - - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(transport) - - rline = unittest.mock.Mock() - rline.version = (1, 1) - message = unittest.mock.Mock() - srv.handle_request(rline, message) - - content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) - self.assertTrue(content.startswith(b'HTTP/1.1 404 Not Found\r\n')) - - def test_connection_made(self): - srv = server.ServerHttpProtocol(loop=self.loop) - self.assertIsNone(srv._request_handler) - - srv.connection_made(unittest.mock.Mock()) - self.assertIsNotNone(srv._request_handler) - - def test_data_received(self): - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(unittest.mock.Mock()) - - srv.data_received(b'123') - self.assertEqual(b'123', bytes(srv.stream._buffer)) - - srv.data_received(b'456') - self.assertEqual(b'123456', bytes(srv.stream._buffer)) - - def test_eof_received(self): - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(unittest.mock.Mock()) - srv.eof_received() - self.assertTrue(srv.stream._eof) - - def test_connection_lost(self): - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(unittest.mock.Mock()) - srv.data_received(b'123') - - keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() - - handle = srv._request_handler - srv.connection_lost(None) - test_utils.run_briefly(self.loop) - - self.assertIsNone(srv._request_handler) - self.assertTrue(handle.cancelled()) - - self.assertIsNone(srv._keep_alive_handle) - self.assertTrue(keep_alive_handle.cancel.called) - - srv.connection_lost(None) - self.assertIsNone(srv._request_handler) - self.assertIsNone(srv._keep_alive_handle) - - def test_srv_keep_alive(self): - srv = server.ServerHttpProtocol(loop=self.loop) - self.assertFalse(srv._keep_alive) - - srv.keep_alive(True) - self.assertTrue(srv._keep_alive) - - srv.keep_alive(False) - self.assertFalse(srv._keep_alive) - - def test_handle_error(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(transport) - srv.keep_alive(True) - - srv.handle_error(404, headers=(('X-Server', 'Tulip'),)) - content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) - self.assertIn(b'HTTP/1.1 404 Not Found', content) - self.assertIn(b'X-SERVER: Tulip', content) - self.assertFalse(srv._keep_alive) - - @unittest.mock.patch('tulip.http.server.traceback') - def test_handle_error_traceback_exc(self, m_trace): - transport = unittest.mock.Mock() - log = unittest.mock.Mock() - srv = server.ServerHttpProtocol(debug=True, log=log, loop=self.loop) - srv.connection_made(transport) - - m_trace.format_exc.side_effect = ValueError - - srv.handle_error(500, exc=object()) - content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) - self.assertTrue( - content.startswith(b'HTTP/1.1 500 Internal Server Error')) - self.assertTrue(log.exception.called) - - def test_handle_error_debug(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.debug = True - srv.connection_made(transport) - - try: - raise ValueError() - except Exception as exc: - srv.handle_error(999, exc=exc) - - content = b''.join([c[1][0] for c in list(transport.write.mock_calls)]) - - self.assertIn(b'HTTP/1.1 500 Internal', content) - self.assertIn(b'Traceback (most recent call last):', content) - - def test_handle_error_500(self): - log = unittest.mock.Mock() - transport = unittest.mock.Mock() - - srv = server.ServerHttpProtocol(log=log, loop=self.loop) - srv.connection_made(transport) - - srv.handle_error(500) - self.assertTrue(log.exception.called) - - def test_handle(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(transport) - - handle = srv.handle_request = unittest.mock.Mock() - - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - - self.loop.run_until_complete(srv._request_handler) - self.assertTrue(handle.called) - self.assertTrue(transport.close.called) - - def test_handle_coro(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - - called = False - - @tulip.coroutine - def coro(message, payload): - nonlocal called - called = True - srv.eof_received() - - srv.handle_request = coro - srv.connection_made(transport) - - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - self.loop.run_until_complete(srv._request_handler) - self.assertTrue(called) - - def test_handle_cancel(self): - log = unittest.mock.Mock() - transport = unittest.mock.Mock() - - srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) - srv.connection_made(transport) - - srv.handle_request = unittest.mock.Mock() - - @tulip.coroutine - def cancel(): - srv._request_handler.cancel() - - self.loop.run_until_complete( - tulip.wait([srv._request_handler, cancel()], loop=self.loop)) - self.assertTrue(log.debug.called) - - def test_handle_cancelled(self): - log = unittest.mock.Mock() - transport = unittest.mock.Mock() - - srv = server.ServerHttpProtocol(log=log, debug=True, loop=self.loop) - srv.connection_made(transport) - - srv.handle_request = unittest.mock.Mock() - test_utils.run_briefly(self.loop) # start request_handler task - - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - - r_handler = srv._request_handler - srv._request_handler = None # emulate srv.connection_lost() - - self.assertIsNone(self.loop.run_until_complete(r_handler)) - - def test_handle_400(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(transport) - srv.handle_error = unittest.mock.Mock() - srv.keep_alive(True) - srv.stream.feed_data(b'GET / HT/asd\r\n\r\n') - - self.loop.run_until_complete(srv._request_handler) - self.assertTrue(srv.handle_error.called) - self.assertEqual(400, srv.handle_error.call_args[0][0]) - self.assertTrue(transport.close.called) - - def test_handle_500(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.connection_made(transport) - - handle = srv.handle_request = unittest.mock.Mock() - handle.side_effect = ValueError - srv.handle_error = unittest.mock.Mock() - - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'Host: example.com\r\n\r\n') - self.loop.run_until_complete(srv._request_handler) - - self.assertTrue(srv.handle_error.called) - self.assertEqual(500, srv.handle_error.call_args[0][0]) - - def test_handle_error_no_handle_task(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(loop=self.loop) - srv.keep_alive(True) - srv.connection_made(transport) - srv.connection_lost(None) - - srv.handle_error(300) - self.assertFalse(srv._keep_alive) - - def test_keep_alive(self): - srv = server.ServerHttpProtocol(keep_alive=0.1, loop=self.loop) - transport = unittest.mock.Mock() - closed = False - - def close(): - nonlocal closed - closed = True - srv.connection_lost(None) - self.loop.stop() - - transport.close = close - - srv.connection_made(transport) - - handle = srv.handle_request = unittest.mock.Mock() - - srv.stream.feed_data( - b'GET / HTTP/1.1\r\n' - b'CONNECTION: keep-alive\r\n' - b'HOST: example.com\r\n\r\n') - - self.loop.run_forever() - self.assertTrue(handle.called) - self.assertTrue(closed) - - def test_keep_alive_close_existing(self): - transport = unittest.mock.Mock() - srv = server.ServerHttpProtocol(keep_alive=15, loop=self.loop) - srv.connection_made(transport) - - self.assertIsNone(srv._keep_alive_handle) - keep_alive_handle = srv._keep_alive_handle = unittest.mock.Mock() - srv.handle_request = unittest.mock.Mock() - - srv.stream.feed_data( - b'GET / HTTP/1.0\r\n' - b'HOST: example.com\r\n\r\n') - - self.loop.run_until_complete(srv._request_handler) - self.assertTrue(keep_alive_handle.cancel.called) - self.assertIsNone(srv._keep_alive_handle) - self.assertTrue(transport.close.called) diff --git a/tests/http_session_test.py b/tests/http_session_test.py deleted file mode 100644 index 39a80091..00000000 --- a/tests/http_session_test.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Tests for tulip/http/session.py""" - -import http.cookies -import unittest -import unittest.mock - -import tulip -import tulip.http - -from tulip.http.client import HttpResponse -from tulip.http.session import Session - -from tulip import test_utils - - -class HttpSessionTests(unittest.TestCase): - - def setUp(self): - self.loop = test_utils.TestLoop() - tulip.set_event_loop(self.loop) - - self.transport = unittest.mock.Mock() - self.stream = tulip.StreamBuffer() - self.response = HttpResponse('get', 'http://python.org') - - def tearDown(self): - tulip.set_event_loop(None) - self.loop.close() - - def test_del(self): - session = Session() - close = session.close = unittest.mock.Mock() - - del session - self.assertTrue(close.called) - - def test_close(self): - tr = unittest.mock.Mock() - - session = Session() - session._conns[1] = [(tr, object())] - session.close() - - self.assertFalse(session._conns) - self.assertTrue(tr.close.called) - - def test_get(self): - session = Session() - self.assertEqual(session._get(1), (None, None)) - - tr, proto = unittest.mock.Mock(), object() - session._conns[1] = [(tr, proto)] - self.assertEqual(session._get(1), (tr, proto)) - - def test_release(self): - session = Session() - resp = unittest.mock.Mock() - resp.message.should_close = False - - cookies = resp.cookies = http.cookies.SimpleCookie() - cookies['c1'] = 'cookie1' - cookies['c2'] = 'cookie2' - - tr, proto = unittest.mock.Mock(), unittest.mock.Mock() - session._release(resp, 1, (tr, proto)) - self.assertEqual(session._conns[1][0], (tr, proto)) - self.assertEqual(session.cookies, dict(cookies.items())) - - def test_release_close(self): - session = Session() - resp = unittest.mock.Mock() - resp.message.should_close = True - - cookies = resp.cookies = http.cookies.SimpleCookie() - cookies['c1'] = 'cookie1' - cookies['c2'] = 'cookie2' - - tr, proto = unittest.mock.Mock(), unittest.mock.Mock() - session._release(resp, 1, (tr, proto)) - self.assertFalse(session._conns) - self.assertTrue(tr.close.called) - - def test_call_new_conn_exc(self): - tr, proto = unittest.mock.Mock(), unittest.mock.Mock() - - class Req: - host = 'host' - port = 80 - ssl = False - - def send(self, *args): - raise ValueError() - - class Loop: - @tulip.coroutine - def create_connection(self, *args, **kw): - return tr, proto - - session = Session() - self.assertRaises( - ValueError, - self.loop.run_until_complete, session.start(Req(), Loop(), True)) - - self.assertTrue(tr.close.called) - - def test_call_existing_conn_exc(self): - existing = unittest.mock.Mock() - tr, proto = unittest.mock.Mock(), unittest.mock.Mock() - - class Req: - host = 'host' - port = 80 - ssl = False - - def send(self, transport): - if transport is existing: - transport.close() - raise ValueError() - else: - return Resp() - - class Resp: - @tulip.coroutine - def start(self, *args, **kw): - pass - - class Loop: - @tulip.coroutine - def create_connection(self, *args, **kw): - return tr, proto - - session = Session() - key = ('host', 80, False) - session._conns[key] = [(existing, object())] - - resp = self.loop.run_until_complete(session.start(Req(), Loop())) - self.assertIsInstance(resp, Resp) - self.assertTrue(existing.close.called) - self.assertFalse(session._conns[key]) diff --git a/tests/http_websocket_test.py b/tests/http_websocket_test.py deleted file mode 100644 index 319538ae..00000000 --- a/tests/http_websocket_test.py +++ /dev/null @@ -1,439 +0,0 @@ -"""Tests for http/websocket.py""" - -import base64 -import hashlib -import os -import struct -import unittest -import unittest.mock - -import tulip -from tulip.http import websocket, protocol, errors - - -class WebsocketParserTests(unittest.TestCase): - - def test_parse_frame(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 0b00000001)) - try: - p.send(b'1') - except StopIteration as exc: - fin, opcode, payload = exc.value - - self.assertEqual((0, 1, b'1'), (fin, opcode, payload)) - - def test_parse_frame_length0(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - try: - p.send(struct.pack('!BB', 0b00000001, 0b00000000)) - except StopIteration as exc: - fin, opcode, payload = exc.value - - self.assertEqual((0, 1, b''), (fin, opcode, payload)) - - def test_parse_frame_length2(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 126)) - p.send(struct.pack('!H', 4)) - try: - p.send(b'1234') - except StopIteration as exc: - fin, opcode, payload = exc.value - - self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) - - def test_parse_frame_length4(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 127)) - p.send(struct.pack('!Q', 4)) - try: - p.send(b'1234') - except StopIteration as exc: - fin, opcode, payload = exc.value - - self.assertEqual((0, 1, b'1234'), (fin, opcode, payload)) - - def test_parse_frame_mask(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - p.send(struct.pack('!BB', 0b00000001, 0b10000001)) - p.send(b'0001') - try: - p.send(b'1') - except StopIteration as exc: - fin, opcode, payload = exc.value - - self.assertEqual((0, 1, b'\x01'), (fin, opcode, payload)) - - def test_parse_frame_header_reversed_bits(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - self.assertRaises( - websocket.WebSocketError, - p.send, struct.pack('!BB', 0b01100000, 0b00000000)) - - def test_parse_frame_header_control_frame(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - self.assertRaises( - websocket.WebSocketError, - p.send, struct.pack('!BB', 0b00001000, 0b00000000)) - - def test_parse_frame_header_continuation(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - self.assertRaises( - websocket.WebSocketError, - p.send, struct.pack('!BB', 0b00000000, 0b00000000)) - - def test_parse_frame_header_new_data_err(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - self.assertRaises( - websocket.WebSocketError, - p.send, struct.pack('!BB', 0b000000000, 0b00000000)) - - def test_parse_frame_header_payload_size(self): - buf = tulip.ParserBuffer() - p = websocket.parse_frame(buf) - next(p) - self.assertRaises( - websocket.WebSocketError, - p.send, struct.pack('!BB', 0b10001000, 0b01111110)) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_ping_frame(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_PING, b'') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_PING, '', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_pong_frame(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_PONG, b'') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_PONG, '', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_close_frame(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_CLOSE, b'') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_CLOSE, '', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_close_frame_info(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_CLOSE, b'0112345') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_CLOSE, 12337, b'12345')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_close_frame_invalid(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_CLOSE, b'1') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - self.assertRaises(websocket.WebSocketError, p.send, b'') - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_unknown_frame(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_CONTINUATION, b'') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - self.assertRaises(websocket.WebSocketError, p.send, b'') - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_simple_text(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_TEXT, b'text') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_TEXT, 'text', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_simple_binary(self, m_parse_frame): - def parse_frame(buf): - yield - return (1, websocket.OPCODE_BINARY, b'binary') - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_BINARY, b'binary', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_continuation(self, m_parse_frame): - cur = 0 - - def parse_frame(buf): - nonlocal cur - yield - if cur == 0: - cur = 1 - return (0, websocket.OPCODE_TEXT, b'line1') - else: - return (1, websocket.OPCODE_CONTINUATION, b'line2') - - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - p.send(b'') - try: - p.send(b'') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, (websocket.OPCODE_TEXT, 'line1line2', '')) - - @unittest.mock.patch('tulip.http.websocket.parse_frame') - def test_continuation_err(self, m_parse_frame): - cur = 0 - - def parse_frame(buf): - nonlocal cur - yield - if cur == 0: - cur = 1 - return (0, websocket.OPCODE_TEXT, b'line1') - else: - return (1, websocket.OPCODE_TEXT, b'line2') - - m_parse_frame.side_effect = parse_frame - buf = tulip.ParserBuffer() - p = websocket.parse_message(buf) - next(p) - p.send(b'') - self.assertRaises(websocket.WebSocketError, p.send, b'') - - @unittest.mock.patch('tulip.http.websocket.parse_message') - def test_parser(self, m_parse_message): - cur = 0 - - def parse_message(buf): - nonlocal cur - yield - if cur == 0: - cur = 1 - return websocket.Message(websocket.OPCODE_TEXT, b'line1', b'') - else: - return websocket.Message(websocket.OPCODE_CLOSE, b'', b'') - - m_parse_message.side_effect = parse_message - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = websocket.WebSocketParser() - next(p) - p.send((out, buf)) - p.send(b'') - self.assertRaises(StopIteration, p.send, b'') - - self.assertEqual( - (websocket.OPCODE_TEXT, b'line1', b''), out._buffer[0]) - self.assertEqual( - (websocket.OPCODE_CLOSE, b'', b''), out._buffer[1]) - self.assertTrue(out._eof) - - def test_parser_eof(self): - out = tulip.DataBuffer() - buf = tulip.ParserBuffer() - p = websocket.WebSocketParser() - next(p) - p.send((out, buf)) - self.assertRaises(tulip.EofStream, p.throw, tulip.EofStream) - self.assertEqual([], list(out._buffer)) - - -class WebsocketWriterTests(unittest.TestCase): - - def setUp(self): - self.transport = unittest.mock.Mock() - self.writer = websocket.WebSocketWriter(self.transport) - - def test_pong(self): - self.writer.pong() - self.transport.write.assert_called_with(b'\x8a\x00') - - def test_ping(self): - self.writer.ping() - self.transport.write.assert_called_with(b'\x89\x00') - - def test_send_text(self): - self.writer.send(b'text') - self.transport.write.assert_called_with(b'\x81\x04text') - - def test_send_binary(self): - self.writer.send('binary', True) - self.transport.write.assert_called_with(b'\x82\x06binary') - - def test_send_binary_long(self): - self.writer.send(b'b'*127, True) - self.assertTrue( - self.transport.write.call_args[0][0].startswith(b'\x82~\x00\x7fb')) - - def test_send_binary_very_long(self): - self.writer.send(b'b'*65537, True) - self.assertTrue( - self.transport.write.call_args[0][0].startswith( - b'\x82\x7f\x00\x00\x00\x00\x00\x01\x00\x01b')) - - def test_close(self): - self.writer.close(1001, 'msg') - self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') - - self.writer.close(1001, b'msg') - self.transport.write.assert_called_with(b'\x88\x05\x03\xe9msg') - - -class WebSocketHandshakeTests(unittest.TestCase): - - def setUp(self): - self.transport = unittest.mock.Mock() - self.headers = [] - self.message = protocol.RawRequestMessage( - 'GET', '/path', (1, 0), self.headers, True, None) - - def test_not_get(self): - self.assertRaises( - errors.HttpErrorException, - websocket.do_handshake, - 'POST', self.message.headers, self.transport) - - def test_no_upgrade(self): - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - def test_no_connection(self): - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'keep-alive')]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - def test_protocol_version(self): - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade')]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '1')]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - def test_protocol_key(self): - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '13')]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '13'), - ('SEC-WEBSOCKET-KEY', '123')]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - sec_key = base64.b64encode(os.urandom(2)) - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '13'), - ('SEC-WEBSOCKET-KEY', sec_key.decode())]) - self.assertRaises( - errors.BadRequestException, - websocket.do_handshake, - self.message.method, self.message.headers, self.transport) - - def test_handshake(self): - sec_key = base64.b64encode(os.urandom(16)).decode() - - self.headers.extend([('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('SEC-WEBSOCKET-VERSION', '13'), - ('SEC-WEBSOCKET-KEY', sec_key)]) - status, headers, parser, writer = websocket.do_handshake( - self.message.method, self.message.headers, self.transport) - self.assertEqual(status, 101) - - key = base64.b64encode( - hashlib.sha1(sec_key.encode() + websocket.WS_KEY).digest()) - headers = dict(headers) - self.assertEqual(headers['SEC-WEBSOCKET-ACCEPT'], key.decode()) diff --git a/tests/http_wsgi_test.py b/tests/http_wsgi_test.py deleted file mode 100644 index 053f5a69..00000000 --- a/tests/http_wsgi_test.py +++ /dev/null @@ -1,301 +0,0 @@ -"""Tests for http/wsgi.py""" - -import io -import unittest -import unittest.mock - -import tulip -from tulip.http import wsgi -from tulip.http import protocol - - -class HttpWsgiServerProtocolTests(unittest.TestCase): - - def setUp(self): - self.loop = tulip.new_event_loop() - tulip.set_event_loop(None) - - self.wsgi = unittest.mock.Mock() - self.stream = unittest.mock.Mock() - self.transport = unittest.mock.Mock() - self.transport.get_extra_info.return_value = '127.0.0.1' - - self.headers = [] - self.message = protocol.RawRequestMessage( - 'GET', '/path', (1, 0), self.headers, True, 'deflate') - self.payload = tulip.DataBuffer() - self.payload.feed_data(b'data') - self.payload.feed_data(b'data') - self.payload.feed_eof() - - def tearDown(self): - self.loop.close() - - def test_ctor(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - self.assertIs(srv.wsgi, self.wsgi) - self.assertFalse(srv.readpayload) - - def _make_one(self, **kw): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop, **kw) - srv.stream = self.stream - srv.transport = self.transport - return srv.create_wsgi_environ(self.message, self.payload) - - def test_environ(self): - environ = self._make_one() - self.assertEqual(environ['RAW_URI'], '/path') - self.assertEqual(environ['wsgi.async'], True) - - def test_environ_except_header(self): - self.headers.append(('EXPECT', '101-continue')) - self._make_one() - self.assertFalse(self.transport.write.called) - - self.headers[0] = ('EXPECT', '100-continue') - self._make_one() - self.transport.write.assert_called_with( - b'HTTP/1.1 100 Continue\r\n\r\n') - - def test_environ_headers(self): - self.headers.extend( - (('HOST', 'python.org'), - ('SCRIPT_NAME', 'script'), - ('CONTENT-TYPE', 'text/plain'), - ('CONTENT-LENGTH', '209'), - ('X_TEST', '123'), - ('X_TEST', '456'))) - environ = self._make_one(is_ssl=True) - self.assertEqual(environ['CONTENT_TYPE'], 'text/plain') - self.assertEqual(environ['CONTENT_LENGTH'], '209') - self.assertEqual(environ['HTTP_X_TEST'], '123,456') - self.assertEqual(environ['SCRIPT_NAME'], 'script') - self.assertEqual(environ['SERVER_NAME'], 'python.org') - self.assertEqual(environ['SERVER_PORT'], '443') - - def test_environ_host_header(self): - self.headers.append(('HOST', 'python.org')) - environ = self._make_one() - - self.assertEqual(environ['HTTP_HOST'], 'python.org') - self.assertEqual(environ['SERVER_NAME'], 'python.org') - self.assertEqual(environ['SERVER_PORT'], '80') - self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.0') - - def test_environ_host_port_header(self): - self.message = protocol.RawRequestMessage( - 'GET', '/path', (1, 1), self.headers, True, 'deflate') - self.headers.append(('HOST', 'python.org:443')) - environ = self._make_one() - - self.assertEqual(environ['HTTP_HOST'], 'python.org:443') - self.assertEqual(environ['SERVER_NAME'], 'python.org') - self.assertEqual(environ['SERVER_PORT'], '443') - self.assertEqual(environ['SERVER_PROTOCOL'], 'HTTP/1.1') - - def test_environ_forward(self): - self.transport.get_extra_info.return_value = 'localhost,127.0.0.1' - environ = self._make_one() - - self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') - self.assertEqual(environ['REMOTE_PORT'], '80') - - self.transport.get_extra_info.return_value = 'localhost,127.0.0.1:443' - environ = self._make_one() - - self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') - self.assertEqual(environ['REMOTE_PORT'], '443') - - self.transport.get_extra_info.return_value = ('127.0.0.1', 443) - environ = self._make_one() - - self.assertEqual(environ['REMOTE_ADDR'], '127.0.0.1') - self.assertEqual(environ['REMOTE_PORT'], '443') - - self.transport.get_extra_info.return_value = '[::1]' - environ = self._make_one() - - self.assertEqual(environ['REMOTE_ADDR'], '::1') - self.assertEqual(environ['REMOTE_PORT'], '80') - - def test_wsgi_response(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - resp = srv.create_wsgi_response(self.message) - self.assertIsInstance(resp, wsgi.WsgiResponse) - - def test_wsgi_response_start_response(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - resp = srv.create_wsgi_response(self.message) - resp.start_response( - '200 OK', [('CONTENT-TYPE', 'text/plain')]) - self.assertEqual(resp.status, '200 OK') - self.assertIsInstance(resp.response, protocol.Response) - - def test_wsgi_response_start_response_exc(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - resp = srv.create_wsgi_response(self.message) - resp.start_response( - '200 OK', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) - self.assertEqual(resp.status, '200 OK') - self.assertIsInstance(resp.response, protocol.Response) - - def test_wsgi_response_start_response_exc_status(self): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - resp = srv.create_wsgi_response(self.message) - resp.start_response('200 OK', [('CONTENT-TYPE', 'text/plain')]) - - self.assertRaises( - ValueError, - resp.start_response, - '500 Err', [('CONTENT-TYPE', 'text/plain')], ['', ValueError()]) - - @unittest.mock.patch('tulip.http.wsgi.tulip') - def test_wsgi_response_101_upgrade_to_websocket(self, m_tulip): - srv = wsgi.WSGIServerHttpProtocol(self.wsgi, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - resp = srv.create_wsgi_response(self.message) - resp.start_response( - '101 Switching Protocols', (('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'))) - self.assertEqual(resp.status, '101 Switching Protocols') - self.assertTrue(m_tulip.http.Response.return_value.send_headers.called) - - def test_file_wrapper(self): - fobj = io.BytesIO(b'data') - wrapper = wsgi.FileWrapper(fobj, 2) - self.assertIs(wrapper, iter(wrapper)) - self.assertTrue(hasattr(wrapper, 'close')) - - self.assertEqual(next(wrapper), b'da') - self.assertEqual(next(wrapper), b'ta') - self.assertRaises(StopIteration, next, wrapper) - - wrapper = wsgi.FileWrapper(b'data', 2) - self.assertFalse(hasattr(wrapper, 'close')) - - def test_handle_request_futures(self): - - def wsgi_app(env, start): - start('200 OK', [('Content-Type', 'text/plain')]) - f1 = tulip.Future(loop=self.loop) - f1.set_result(b'data') - fut = tulip.Future(loop=self.loop) - fut.set_result([f1]) - return fut - - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - self.loop.run_until_complete( - srv.handle_request(self.message, self.payload)) - - content = b''.join( - [c[1][0] for c in self.transport.write.mock_calls]) - self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) - self.assertTrue(content.endswith(b'data')) - - def test_handle_request_simple(self): - - def wsgi_app(env, start): - start('200 OK', [('Content-Type', 'text/plain')]) - return [b'data'] - - stream = tulip.StreamReader(loop=self.loop) - stream.feed_data(b'data') - stream.feed_eof() - - self.message = protocol.RawRequestMessage( - 'GET', '/path', (1, 1), self.headers, True, 'deflate') - - srv = wsgi.WSGIServerHttpProtocol( - wsgi_app, readpayload=True, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - self.loop.run_until_complete( - srv.handle_request(self.message, self.payload)) - - content = b''.join( - [c[1][0] for c in self.transport.write.mock_calls]) - self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) - self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) - self.assertFalse(srv._keep_alive) - - def test_handle_request_io(self): - - def wsgi_app(env, start): - start('200 OK', [('Content-Type', 'text/plain')]) - return io.BytesIO(b'data') - - srv = wsgi.WSGIServerHttpProtocol(wsgi_app, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - self.loop.run_until_complete( - srv.handle_request(self.message, self.payload)) - - content = b''.join( - [c[1][0] for c in self.transport.write.mock_calls]) - self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) - self.assertTrue(content.endswith(b'data')) - - def test_handle_request_keep_alive(self): - - def wsgi_app(env, start): - start('200 OK', [('Content-Type', 'text/plain')]) - return [b'data'] - - stream = tulip.StreamReader(loop=self.loop) - stream.feed_data(b'data') - stream.feed_eof() - - self.message = protocol.RawRequestMessage( - 'GET', '/path', (1, 1), self.headers, False, 'deflate') - - srv = wsgi.WSGIServerHttpProtocol( - wsgi_app, readpayload=True, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - self.loop.run_until_complete( - srv.handle_request(self.message, self.payload)) - - content = b''.join( - [c[1][0] for c in self.transport.write.mock_calls]) - self.assertTrue(content.startswith(b'HTTP/1.1 200 OK')) - self.assertTrue(content.endswith(b'data\r\n0\r\n\r\n')) - self.assertTrue(srv._keep_alive) - - def test_handle_request_readpayload(self): - - def wsgi_app(env, start): - start('200 OK', [('Content-Type', 'text/plain')]) - return [env['wsgi.input'].read()] - - srv = wsgi.WSGIServerHttpProtocol( - wsgi_app, readpayload=True, loop=self.loop) - srv.stream = self.stream - srv.transport = self.transport - - self.loop.run_until_complete( - srv.handle_request(self.message, self.payload)) - - content = b''.join( - [c[1][0] for c in self.transport.write.mock_calls]) - self.assertTrue(content.startswith(b'HTTP/1.0 200 OK')) - self.assertTrue(content.endswith(b'data')) diff --git a/tests/parsers_test.py b/tests/parsers_test.py deleted file mode 100644 index c6b7cec2..00000000 --- a/tests/parsers_test.py +++ /dev/null @@ -1,605 +0,0 @@ -"""Tests for parser.py""" - -import unittest -import unittest.mock - -from tulip import events -from tulip import parsers -from tulip import tasks - - -class StreamBufferTests(unittest.TestCase): - - DATA = b'line1\nline2\nline3\n' - - def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_exception(self): - stream = parsers.StreamBuffer() - self.assertIsNone(stream.exception()) - - exc = ValueError() - stream.set_exception(exc) - self.assertIs(stream.exception(), exc) - - def test_exception_waiter(self): - stream = parsers.StreamBuffer() - - stream._parser = parsers.lines_parser() - buf = stream._parser_buffer = parsers.DataBuffer(loop=self.loop) - - exc = ValueError() - stream.set_exception(exc) - self.assertIs(buf.exception(), exc) - - def test_feed_data(self): - stream = parsers.StreamBuffer() - - stream.feed_data(self.DATA) - self.assertEqual(self.DATA, bytes(stream._buffer)) - - def test_feed_empty_data(self): - stream = parsers.StreamBuffer() - - stream.feed_data(b'') - self.assertEqual(b'', bytes(stream._buffer)) - - def test_set_parser_unset_prev(self): - stream = parsers.StreamBuffer() - stream.set_parser(parsers.lines_parser()) - - unset = stream.unset_parser = unittest.mock.Mock() - stream.set_parser(parsers.lines_parser()) - - self.assertTrue(unset.called) - - def test_set_parser_exception(self): - stream = parsers.StreamBuffer() - - exc = ValueError() - stream.set_exception(exc) - s = stream.set_parser(parsers.lines_parser()) - self.assertIs(s.exception(), exc) - - def test_set_parser_feed_existing(self): - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - s = stream.set_parser(parsers.lines_parser()) - - self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) - self.assertEqual(b'data', bytes(stream._buffer)) - self.assertIsNotNone(stream._parser) - - stream.unset_parser() - self.assertIsNone(stream._parser) - self.assertEqual(b'data', bytes(stream._buffer)) - self.assertTrue(s._eof) - - def test_set_parser_feed_existing_exc(self): - - def p(): - yield # stream - raise ValueError() - - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - s = stream.set_parser(p()) - self.assertIsInstance(s.exception(), ValueError) - - def test_set_parser_feed_existing_eof(self): - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - stream.feed_eof() - s = stream.set_parser(parsers.lines_parser()) - - self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) - self.assertEqual(b'data', bytes(stream._buffer)) - self.assertIsNone(stream._parser) - - def test_set_parser_feed_existing_eof_exc(self): - - def p(): - yield # stream - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - stream.feed_eof() - s = stream.set_parser(p()) - self.assertIsInstance(s.exception(), ValueError) - - def test_set_parser_feed_existing_eof_unhandled_eof(self): - - def p(): - yield # stream - while True: - yield # read chunk - - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - stream.feed_eof() - s = stream.set_parser(p()) - self.assertIsNone(s.exception()) - self.assertTrue(s._eof) - - def test_set_parser_unset(self): - stream = parsers.StreamBuffer() - s = stream.set_parser(parsers.lines_parser()) - - stream.feed_data(b'line1\r\nline2\r\n') - self.assertEqual( - [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) - self.assertEqual(b'', bytes(stream._buffer)) - stream.unset_parser() - self.assertTrue(s._eof) - self.assertEqual(b'', bytes(stream._buffer)) - - def test_set_parser_feed_existing_stop(self): - def lines_parser(): - out, buf = yield - try: - out.feed_data((yield from buf.readuntil(b'\n'))) - out.feed_data((yield from buf.readuntil(b'\n'))) - finally: - out.feed_eof() - - stream = parsers.StreamBuffer() - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - s = stream.set_parser(lines_parser()) - - self.assertEqual(b'line1\r\nline2\r\n', b''.join(s._buffer)) - self.assertEqual(b'data', bytes(stream._buffer)) - self.assertIsNone(stream._parser) - self.assertTrue(s._eof) - - def test_feed_parser(self): - stream = parsers.StreamBuffer() - s = stream.set_parser(parsers.lines_parser()) - - stream.feed_data(b'line1') - stream.feed_data(b'\r\nline2\r\ndata') - self.assertEqual(b'data', bytes(stream._buffer)) - - stream.feed_eof() - self.assertEqual([bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) - self.assertEqual(b'data', bytes(stream._buffer)) - self.assertTrue(s._eof) - - def test_feed_parser_exc(self): - def p(): - yield # stream - yield # read chunk - raise ValueError() - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - self.assertIsInstance(s.exception(), ValueError) - self.assertEqual(b'', bytes(stream._buffer)) - - def test_feed_parser_stop(self): - def p(): - yield # stream - yield # chunk - - stream = parsers.StreamBuffer() - stream.set_parser(p()) - - stream.feed_data(b'line1') - self.assertIsNone(stream._parser) - self.assertEqual(b'', bytes(stream._buffer)) - - def test_feed_eof_exc(self): - def p(): - yield # stream - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - self.assertIsNone(s.exception()) - - stream.feed_eof() - self.assertIsInstance(s.exception(), ValueError) - - def test_feed_eof_stop(self): - def p(): - out, buf = yield # stream - try: - while True: - yield # read chunk - except parsers.EofStream: - out.feed_eof() - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - stream.feed_eof() - self.assertTrue(s._eof) - - def test_feed_eof_unhandled_eof(self): - def p(): - yield # stream - while True: - yield # read chunk - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - stream.feed_eof() - self.assertIsNone(s.exception()) - self.assertTrue(s._eof) - - def test_feed_parser2(self): - stream = parsers.StreamBuffer() - s = stream.set_parser(parsers.lines_parser()) - - stream.feed_data(b'line1\r\nline2\r\n') - stream.feed_eof() - self.assertEqual( - [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(s._buffer)) - self.assertEqual(b'', bytes(stream._buffer)) - self.assertTrue(s._eof) - - def test_unset_parser_eof_exc(self): - def p(): - yield # stream - try: - while True: - yield # read chunk - except parsers.EofStream: - raise ValueError() - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - stream.unset_parser() - self.assertIsInstance(s.exception(), ValueError) - self.assertIsNone(stream._parser) - - def test_unset_parser_eof_unhandled_eof(self): - def p(): - yield # stream - while True: - yield # read chunk - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - stream.unset_parser() - self.assertIsNone(s.exception(), ValueError) - self.assertTrue(s._eof) - - def test_unset_parser_stop(self): - def p(): - out, buf = yield # stream - try: - while True: - yield # read chunk - except parsers.EofStream: - out.feed_eof() - - stream = parsers.StreamBuffer() - s = stream.set_parser(p()) - - stream.feed_data(b'line1') - stream.unset_parser() - self.assertTrue(s._eof) - - -class DataBufferTests(unittest.TestCase): - - def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def test_feed_data(self): - buffer = parsers.DataBuffer(loop=self.loop) - - item = object() - buffer.feed_data(item) - self.assertEqual([item], list(buffer._buffer)) - - def test_feed_eof(self): - buffer = parsers.DataBuffer(loop=self.loop) - buffer.feed_eof() - self.assertTrue(buffer._eof) - - def test_read(self): - item = object() - buffer = parsers.DataBuffer(loop=self.loop) - read_task = tasks.Task(buffer.read(), loop=self.loop) - - def cb(): - buffer.feed_data(item) - self.loop.call_soon(cb) - - data = self.loop.run_until_complete(read_task) - self.assertIs(item, data) - - def test_read_eof(self): - buffer = parsers.DataBuffer(loop=self.loop) - read_task = tasks.Task(buffer.read(), loop=self.loop) - - def cb(): - buffer.feed_eof() - self.loop.call_soon(cb) - - self.assertRaises( - parsers.EofStream, self.loop.run_until_complete, read_task) - - def test_read_until_eof(self): - item = object() - buffer = parsers.DataBuffer(loop=self.loop) - buffer.feed_data(item) - buffer.feed_eof() - - data = self.loop.run_until_complete(buffer.read()) - self.assertIs(data, item) - - self.assertRaises( - parsers.EofStream, self.loop.run_until_complete, buffer.read()) - - def test_read_exception(self): - buffer = parsers.DataBuffer(loop=self.loop) - buffer.feed_data(object()) - buffer.set_exception(ValueError()) - - self.assertRaises( - ValueError, self.loop.run_until_complete, buffer.read()) - - def test_exception(self): - buffer = parsers.DataBuffer(loop=self.loop) - self.assertIsNone(buffer.exception()) - - exc = ValueError() - buffer.set_exception(exc) - self.assertIs(buffer.exception(), exc) - - def test_exception_waiter(self): - buffer = parsers.DataBuffer(loop=self.loop) - - @tasks.coroutine - def set_err(): - buffer.set_exception(ValueError()) - - t1 = tasks.Task(buffer.read(), loop=self.loop) - t2 = tasks.Task(set_err(), loop=self.loop) - - self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) - - self.assertRaises(ValueError, t1.result) - - -class StreamProtocolTests(unittest.TestCase): - - def test_connection_made(self): - tr = unittest.mock.Mock() - - proto = parsers.StreamProtocol() - self.assertIsNone(proto.transport) - - proto.connection_made(tr) - self.assertIs(proto.transport, tr) - - def test_connection_lost(self): - proto = parsers.StreamProtocol() - proto.connection_made(unittest.mock.Mock()) - proto.connection_lost(None) - self.assertIsNone(proto.transport) - self.assertTrue(proto._eof) - - def test_connection_lost_exc(self): - proto = parsers.StreamProtocol() - proto.connection_made(unittest.mock.Mock()) - - exc = ValueError() - proto.connection_lost(exc) - self.assertIs(proto.exception(), exc) - - -class ParserBufferTests(unittest.TestCase): - - def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) - - def tearDown(self): - self.loop.close() - - def _make_one(self): - return parsers.ParserBuffer() - - def test_shrink(self): - buf = parsers.ParserBuffer() - buf.feed_data(b'data') - - buf._shrink() - self.assertEqual(bytes(buf), b'data') - - buf.offset = 2 - buf._shrink() - self.assertEqual(bytes(buf), b'ta') - self.assertEqual(2, len(buf)) - self.assertEqual(2, buf.size) - self.assertEqual(0, buf.offset) - - def test_feed_data(self): - buf = self._make_one() - buf.feed_data(b'') - self.assertEqual(len(buf), 0) - - buf.feed_data(b'data') - self.assertEqual(len(buf), 4) - self.assertEqual(bytes(buf), b'data') - - def test_read(self): - buf = self._make_one() - p = buf.read(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - self.assertEqual(res, b'123') - self.assertEqual(b'4', bytes(buf)) - - def test_readsome(self): - buf = self._make_one() - p = buf.readsome(3) - next(p) - try: - p.send(b'1') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, b'1') - - p = buf.readsome(2) - next(p) - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - self.assertEqual(res, b'23') - self.assertEqual(b'4', bytes(buf)) - - def test_skip(self): - buf = self._make_one() - p = buf.skip(3) - next(p) - p.send(b'1') - try: - p.send(b'234') - except StopIteration as exc: - res = exc.value - - self.assertIsNone(res) - self.assertEqual(b'4', bytes(buf)) - - def test_readuntil_limit(self): - buf = self._make_one() - p = buf.readuntil(b'\n', 4) - next(p) - p.send(b'1') - p.send(b'234') - self.assertRaises(ValueError, p.send, b'5') - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', 4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', 4) - next(p) - self.assertRaises(ValueError, p.send, b'12345\n6') - - class CustomExc(Exception): - pass - - buf = parsers.ParserBuffer() - p = buf.readuntil(b'\n', 4, CustomExc) - next(p) - self.assertRaises(CustomExc, p.send, b'12345\n6') - - def test_readuntil(self): - buf = self._make_one() - p = buf.readuntil(b'\n', 4) - next(p) - p.send(b'123') - try: - p.send(b'\n456') - except StopIteration as exc: - res = exc.value - - self.assertEqual(res, b'123\n') - self.assertEqual(b'456', bytes(buf)) - - def test_skipuntil(self): - buf = self._make_one() - p = buf.skipuntil(b'\n') - next(p) - p.send(b'123') - try: - p.send(b'\n456\n') - except StopIteration: - pass - self.assertEqual(b'456\n', bytes(buf)) - - p = buf.skipuntil(b'\n') - try: - next(p) - except StopIteration: - pass - self.assertEqual(b'', bytes(buf)) - - def test_lines_parser(self): - out = parsers.DataBuffer(loop=self.loop) - buf = self._make_one() - p = parsers.lines_parser() - next(p) - p.send((out, buf)) - - for d in (b'line1', b'\r\n', b'lin', b'e2\r', b'\ndata'): - p.send(d) - - self.assertEqual( - [bytearray(b'line1\r\n'), bytearray(b'line2\r\n')], - list(out._buffer)) - try: - p.throw(parsers.EofStream()) - except parsers.EofStream: - pass - - self.assertEqual(bytes(buf), b'data') - - def test_chunks_parser(self): - out = parsers.DataBuffer(loop=self.loop) - buf = self._make_one() - p = parsers.chunks_parser(5) - next(p) - p.send((out, buf)) - - for d in (b'line1', b'lin', b'e2d', b'ata'): - p.send(d) - - self.assertEqual( - [bytearray(b'line1'), bytearray(b'line2')], list(out._buffer)) - try: - p.throw(parsers.EofStream()) - except parsers.EofStream: - pass - - self.assertEqual(bytes(buf), b'data') diff --git a/tests/streams_test.py b/tests/streams_test.py index 49a43cab..c8ad7801 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -32,7 +32,7 @@ def test_ctor_global_loop(self, m_events): self.assertIs(stream.loop, m_events.get_event_loop.return_value) def test_open_connection(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: f = streams.open_connection(*httpd.address, loop=self.loop) reader, writer = self.loop.run_until_complete(f) writer.write(b'GET / HTTP/1.0\r\n\r\n') @@ -47,7 +47,7 @@ def test_open_connection(self): @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): - with test_utils.run_test_server(self.loop, use_ssl=True) as httpd: + with test_utils.run_test_server(use_ssl=True) as httpd: try: events.set_event_loop(self.loop) f = streams.open_connection(*httpd.address, ssl=True) @@ -62,7 +62,7 @@ def test_open_connection_no_loop_ssl(self): writer.close() def test_open_connection_error(self): - with test_utils.run_test_server(self.loop) as httpd: + with test_utils.run_test_server() as httpd: f = streams.open_connection(*httpd.address, loop=self.loop) reader, writer = self.loop.run_until_complete(f) writer._protocol.connection_lost(ZeroDivisionError()) @@ -71,6 +71,7 @@ def test_open_connection_error(self): self.loop.run_until_complete(f) writer.close() + test_utils.run_briefly(self.loop) def test_feed_empty_data(self): stream = streams.StreamReader(loop=self.loop) diff --git a/tulip/__init__.py b/tulip/__init__.py index 9de84cb0..faf307fb 100644 --- a/tulip/__init__.py +++ b/tulip/__init__.py @@ -7,7 +7,6 @@ from .events import * from .locks import * from .transports import * -from .parsers import * from .protocols import * from .streams import * from .tasks import * @@ -22,7 +21,6 @@ events.__all__ + locks.__all__ + transports.__all__ + - parsers.__all__ + protocols.__all__ + streams.__all__ + tasks.__all__) diff --git a/tulip/http/__init__.py b/tulip/http/__init__.py deleted file mode 100644 index a1432dee..00000000 --- a/tulip/http/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# This relies on each of the submodules having an __all__ variable. - -from .client import * -from .errors import * -from .protocol import * -from .server import * -from .session import * -from .wsgi import * - - -__all__ = (client.__all__ + - errors.__all__ + - protocol.__all__ + - server.__all__ + - session.__all__ + - wsgi.__all__) diff --git a/tulip/http/client.py b/tulip/http/client.py deleted file mode 100644 index a28fdc21..00000000 --- a/tulip/http/client.py +++ /dev/null @@ -1,581 +0,0 @@ -"""HTTP Client for Tulip. - -Most basic usage: - - response = yield from tulip.http.request('GET', url) - response['Content-Type'] == 'application/json' - response.status == 200 - - content = yield from response.content.read() -""" - -__all__ = ['request'] - -import base64 -import email.message -import functools -import http.client -import http.cookies -import json -import io -import itertools -import mimetypes -import os -import uuid -import urllib.parse - -import tulip -import tulip.http - - -@tulip.coroutine -def request(method, url, *, - params=None, - data=None, - headers=None, - cookies=None, - files=None, - auth=None, - allow_redirects=True, - max_redirects=10, - encoding='utf-8', - version=(1, 1), - timeout=None, - compress=None, - chunked=None, - session=None, - loop=None): - """Constructs and sends a request. Returns response object. - - method: http method - url: request url - params: (optional) Dictionary or bytes to be sent in the query string - of the new request - data: (optional) Dictionary, bytes, or file-like object to - send in the body of the request - headers: (optional) Dictionary of HTTP Headers to send with the request - cookies: (optional) Dict object to send with the request - files: (optional) Dictionary of 'name': file-like-objects - for multipart encoding upload - auth: (optional) Auth tuple to enable Basic HTTP Auth - timeout: (optional) Float describing the timeout of the request - allow_redirects: (optional) Boolean. Set to True if POST/PUT/DELETE - redirect following is allowed. - compress: Boolean. Set to True if request has to be compressed - with deflate encoding. - chunked: Boolean or Integer. Set to chunk size for chunked - transfer encoding. - session: tulip.http.Session instance to support connection pooling and - session cookies. - loop: Optional event loop. - - Usage: - - import tulip.http - >> resp = yield from tulip.http.request('GET', 'http://python.org/') - >> resp - - - >> data = yield from resp.read() - - """ - redirects = 0 - if loop is None: - loop = tulip.get_event_loop() - - while True: - req = HttpRequest( - method, url, params=params, headers=headers, data=data, - cookies=cookies, files=files, auth=auth, encoding=encoding, - version=version, compress=compress, chunked=chunked) - - if session is None: - conn = start(req, loop) - else: - conn = session.start(req, loop) - - # connection timeout - t = tulip.Task(conn, loop=loop) - th = None - if timeout is not None: - th = loop.call_later(timeout, t.cancel) - try: - resp = yield from t - except tulip.CancelledError: - raise tulip.TimeoutError from None - finally: - if th is not None: - th.cancel() - - # redirects - if resp.status in (301, 302) and allow_redirects: - redirects += 1 - if max_redirects and redirects >= max_redirects: - resp.close() - break - - r_url = resp.get('location') or resp.get('uri') - - scheme = urllib.parse.urlsplit(r_url)[0] - if scheme not in ('http', 'https', ''): - raise ValueError('Can redirect only to http or https') - elif not scheme: - r_url = urllib.parse.urljoin(url, r_url) - - url = urllib.parse.urldefrag(r_url)[0] - if url: - resp.close() - continue - - break - - return resp - - -@tulip.coroutine -def start(req, loop): - transport, p = yield from loop.create_connection( - functools.partial(tulip.StreamProtocol, loop=loop), - req.host, req.port, ssl=req.ssl) - - try: - resp = req.send(transport) - yield from resp.start(p, transport) - except: - transport.close() - raise - - return resp - - -class HttpRequest: - - GET_METHODS = {'DELETE', 'GET', 'HEAD', 'OPTIONS'} - POST_METHODS = {'PATCH', 'POST', 'PUT', 'TRACE'} - ALL_METHODS = GET_METHODS.union(POST_METHODS) - - DEFAULT_HEADERS = { - 'Accept': '*/*', - 'Accept-Encoding': 'gzip, deflate', - } - - body = b'' - - def __init__(self, method, url, *, - params=None, - headers=None, - data=None, - cookies=None, - files=None, - auth=None, - encoding='utf-8', - version=(1, 1), - compress=None, - chunked=None): - self.method = method.upper() - self.encoding = encoding - - # parser http version '1.1' => (1, 1) - if isinstance(version, str): - v = [l.strip() for l in version.split('.', 1)] - try: - version = int(v[0]), int(v[1]) - except ValueError: - raise ValueError( - 'Can not parse http version number: {}' - .format(version)) from None - self.version = version - - # path - scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url) - if not netloc: - raise ValueError('Host could not be detected.') - - if not path: - path = '/' - else: - path = urllib.parse.unquote(path) - - # check domain idna encoding - try: - netloc = netloc.encode('idna').decode('utf-8') - except UnicodeError: - raise ValueError('URL has an invalid label.') - - # basic auth info - if '@' in netloc: - authinfo, netloc = netloc.split('@', 1) - if not auth: - auth = authinfo.split(':', 1) - if len(auth) == 1: - auth.append('') - - # extract host and port - ssl = scheme == 'https' - - if ':' in netloc: - netloc, port_s = netloc.split(':', 1) - try: - port = int(port_s) - except ValueError: - raise ValueError( - 'Port number could not be converted.') from None - else: - if ssl: - port = http.client.HTTPS_PORT - else: - port = http.client.HTTP_PORT - - self.host = netloc - self.port = port - self.ssl = ssl - - # build url query - if isinstance(params, dict): - params = list(params.items()) - - if data and self.method in self.GET_METHODS: - # include data to query - if isinstance(data, dict): - data = data.items() - params = list(itertools.chain(params or (), data)) - data = None - - if params: - params = urllib.parse.urlencode(params) - if query: - query = '%s&%s' % (query, params) - else: - query = params - - # build path - path = urllib.parse.quote(path) - self.path = urllib.parse.urlunsplit(('', '', path, query, fragment)) - - # headers - self.headers = email.message.Message() - if headers: - if isinstance(headers, dict): - headers = list(headers.items()) - - for key, value in headers: - self.headers.add_header(key, value) - - for hdr, val in self.DEFAULT_HEADERS.items(): - if hdr not in self.headers: - self.headers[hdr] = val - - # host - if 'host' not in self.headers: - self.headers['Host'] = self.host - - # cookies - if cookies: - self.update_cookies(cookies) - - # auth - if auth: - if isinstance(auth, (tuple, list)) and len(auth) == 2: - # basic auth - self.headers['Authorization'] = 'Basic %s' % ( - base64.b64encode( - ('%s:%s' % (auth[0], auth[1])).encode('latin1')) - .strip().decode('latin1')) - else: - raise ValueError("Only basic auth is supported") - - # Content-encoding - enc = self.headers.get('Content-Encoding', '').lower() - if enc: - chunked = True # enable chunked, no need to deal with length - compress = enc - elif compress: - chunked = True # enable chunked, no need to deal with length - compress = compress if isinstance(compress, str) else 'deflate' - self.headers['Content-Encoding'] = compress - - # form data (x-www-form-urlencoded) - if isinstance(data, dict): - data = list(data.items()) - - if data and not files: - if isinstance(data, (bytes, bytearray)): - self.body = data - if 'content-type' not in self.headers: - self.headers['content-type'] = ( - 'application/octet-stream') - else: - if not isinstance(data, str): - data = urllib.parse.urlencode(data, doseq=True) - - self.body = data.encode(encoding) - if 'content-type' not in self.headers: - self.headers['content-type'] = ( - 'application/x-www-form-urlencoded') - - if 'content-length' not in self.headers and not chunked: - self.headers['content-length'] = str(len(self.body)) - - # files (multipart/form-data) - elif files: - fields = [] - - if data: - for field, val in data: - fields.append((field, str_to_bytes(val))) - - if isinstance(files, dict): - files = list(files.items()) - - for rec in files: - if not isinstance(rec, (tuple, list)): - rec = (rec,) - - ft = None - if len(rec) == 1: - k = guess_filename(rec[0], 'unknown') - fields.append((k, k, rec[0])) - - elif len(rec) == 2: - k, fp = rec - fn = guess_filename(fp, k) - fields.append((k, fn, fp)) - - else: - k, fp, ft = rec - fn = guess_filename(fp, k) - fields.append((k, fn, fp, ft)) - - chunked = chunked or 8192 - boundary = uuid.uuid4().hex - - self.body = encode_multipart_data( - fields, bytes(boundary, 'latin1')) - - self.headers['content-type'] = ( - 'multipart/form-data; boundary=%s' % boundary) - - # chunked - te = self.headers.get('transfer-encoding', '').lower() - - if chunked: - if 'content-length' in self.headers: - del self.headers['content-length'] - if 'chunked' not in te: - self.headers['transfer-encoding'] = 'chunked' - - chunked = chunked if type(chunked) is int else 8196 - else: - if 'chunked' in te: - chunked = 8196 - else: - chunked = None - self.headers['content-length'] = str(len(self.body)) - - self._chunked = chunked - self._compress = compress - - def update_cookies(self, cookies): - """Update request cookies header.""" - c = http.cookies.SimpleCookie() - if 'cookie' in self.headers: - c.load(self.headers.get('cookie', '')) - del self.headers['cookie'] - - if isinstance(cookies, dict): - cookies = cookies.items() - - for name, value in cookies: - if isinstance(value, http.cookies.Morsel): - # use dict method because SimpleCookie class modifies value - dict.__setitem__(c, name, value) - else: - c[name] = value - - self.headers['cookie'] = c.output(header='', sep=';').strip() - - def send(self, transport): - request = tulip.http.Request( - transport, self.method, self.path, self.version) - - if self._compress: - request.add_compression_filter(self._compress) - - if self._chunked is not None: - request.add_chunking_filter(self._chunked) - - request.add_headers(*self.headers.items()) - request.send_headers() - - if isinstance(self.body, bytes): - self.body = (self.body,) - - for chunk in self.body: - request.write(chunk) - - request.write_eof() - - return HttpResponse(self.method, self.path, self.host) - - -class HttpResponse(http.client.HTTPMessage): - - message = None # RawResponseMessage object - - # from the Status-Line of the response - version = None # HTTP-Version - status = None # Status-Code - reason = None # Reason-Phrase - - cookies = None # Response cookies (Set-Cookie) - - content = None # payload stream - stream = None # input stream - transport = None # current transport - - def __init__(self, method, url, host=''): - super().__init__() - - self.method = method - self.url = url - self.host = host - self._content = None - - def __del__(self): - self.close() - - def __repr__(self): - out = io.StringIO() - print(''.format( - self.host, self.url, self.status, self.reason), file=out) - print(super().__str__(), file=out) - return out.getvalue() - - def start(self, stream, transport): - """Start response processing.""" - self.stream = stream - self.transport = transport - - httpstream = stream.set_parser(tulip.http.http_response_parser()) - - # read response - self.message = yield from httpstream.read() - - # response status - self.version = self.message.version - self.status = self.message.code - self.reason = self.message.reason - - # headers - for hdr, val in self.message.headers: - self.add_header(hdr, val) - - # payload - self.content = stream.set_parser( - tulip.http.http_payload_parser(self.message)) - - # cookies - self.cookies = http.cookies.SimpleCookie() - if 'Set-Cookie' in self: - for hdr in self.get_all('Set-Cookie'): - self.cookies.load(hdr) - - return self - - def close(self): - if self.transport is not None: - self.transport.close() - self.transport = None - - @tulip.coroutine - def read(self, decode=False): - """Read response payload. Decode known types of content.""" - if self._content is None: - buf = [] - total = 0 - try: - while True: - chunk = yield from self.content.read() - size = len(chunk) - buf.append((chunk, size)) - total += size - except tulip.EofStream: - pass - - self._content = bytearray(total) - - idx = 0 - content = memoryview(self._content) - for chunk, size in buf: - content[idx:idx+size] = chunk - idx += size - - data = self._content - - if decode: - ct = self.get('content-type', '').lower() - if ct == 'application/json': - data = json.loads(data.decode('utf-8')) - - return data - - -def str_to_bytes(s, encoding='utf-8'): - if isinstance(s, str): - return s.encode(encoding) - return s - - -def guess_filename(obj, default=None): - name = getattr(obj, 'name', None) - if name and name[0] != '<' and name[-1] != '>': - return os.path.split(name)[-1] - return default - - -def encode_multipart_data(fields, boundary, encoding='utf-8', chunk_size=8196): - """ - Encode a list of fields using the multipart/form-data MIME format. - - fields: - List of (name, value) or (name, filename, io) or - (name, filename, io, MIME type) field tuples. - """ - for rec in fields: - yield b'--' + boundary + b'\r\n' - - field, *rec = rec - - if len(rec) == 1: - data = rec[0] - yield (('Content-Disposition: form-data; name="%s"\r\n\r\n' % - (field,)).encode(encoding)) - yield data + b'\r\n' - - else: - if len(rec) == 3: - fn, fp, ct = rec - else: - fn, fp = rec - ct = (mimetypes.guess_type(fn)[0] or - 'application/octet-stream') - - yield ('Content-Disposition: form-data; name="%s"; ' - 'filename="%s"\r\n' % (field, fn)).encode(encoding) - yield ('Content-Type: %s\r\n\r\n' % (ct,)).encode(encoding) - - if isinstance(fp, str): - fp = fp.encode(encoding) - - if isinstance(fp, bytes): - fp = io.BytesIO(fp) - - while True: - chunk = fp.read(chunk_size) - if not chunk: - break - yield str_to_bytes(chunk) - - yield b'\r\n' - - yield b'--' + boundary + b'--\r\n' diff --git a/tulip/http/errors.py b/tulip/http/errors.py deleted file mode 100644 index f8b77e9b..00000000 --- a/tulip/http/errors.py +++ /dev/null @@ -1,46 +0,0 @@ -"""http related errors.""" - -__all__ = ['HttpException', 'HttpErrorException', 'BadRequestException', - 'IncompleteRead', 'BadStatusLine', 'LineTooLong', 'InvalidHeader'] - -import http.client - - -class HttpException(http.client.HTTPException): - - code = None - headers = () - message = '' - - -class HttpErrorException(HttpException): - - def __init__(self, code, message='', headers=None): - self.code = code - self.headers = headers - self.message = message - - -class BadRequestException(HttpException): - - code = 400 - message = 'Bad Request' - - -class IncompleteRead(BadRequestException, http.client.IncompleteRead): - pass - - -class BadStatusLine(BadRequestException, http.client.BadStatusLine): - pass - - -class LineTooLong(BadRequestException, http.client.LineTooLong): - pass - - -class InvalidHeader(BadRequestException): - - def __init__(self, hdr): - super().__init__('Invalid HTTP Header: {}'.format(hdr)) - self.hdr = hdr diff --git a/tulip/http/protocol.py b/tulip/http/protocol.py deleted file mode 100644 index 7081fd59..00000000 --- a/tulip/http/protocol.py +++ /dev/null @@ -1,756 +0,0 @@ -"""Http related helper utils.""" - -__all__ = ['HttpMessage', 'Request', 'Response', - 'RawRequestMessage', 'RawResponseMessage', - 'http_request_parser', 'http_response_parser', - 'http_payload_parser'] - -import collections -import functools -import http.server -import itertools -import re -import sys -import zlib -from wsgiref.handlers import format_date_time - -import tulip -from tulip.http import errors - -METHRE = re.compile('[A-Z0-9$-_.]+') -VERSRE = re.compile('HTTP/(\d+).(\d+)') -HDRRE = re.compile('[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]') -CONTINUATION = (' ', '\t') -EOF_MARKER = object() -EOL_MARKER = object() - -RESPONSES = http.server.BaseHTTPRequestHandler.responses - - -RawRequestMessage = collections.namedtuple( - 'RawRequestMessage', - ['method', 'path', 'version', 'headers', 'should_close', 'compression']) - - -RawResponseMessage = collections.namedtuple( - 'RawResponseMessage', - ['version', 'code', 'reason', 'headers', 'should_close', 'compression']) - - -def http_request_parser(max_line_size=8190, - max_headers=32768, max_field_size=8190): - """Read request status line. Exception errors.BadStatusLine - could be raised in case of any errors in status line. - Returns RawRequestMessage. - """ - out, buf = yield - - try: - # read http message (request line + headers) - raw_data = yield from buf.readuntil( - b'\r\n\r\n', max_headers, errors.LineTooLong) - lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) - - # request line - line = lines[0] - try: - method, path, version = line.split(None, 2) - except ValueError: - raise errors.BadStatusLine(line) from None - - # method - method = method.upper() - if not METHRE.match(method): - raise errors.BadStatusLine(method) - - # version - match = VERSRE.match(version) - if match is None: - raise errors.BadStatusLine(version) - version = (int(match.group(1)), int(match.group(2))) - - # read headers - headers, close, compression = parse_headers( - lines, max_line_size, max_headers, max_field_size) - if version <= (1, 0): - close = True - elif close is None: - close = False - - out.feed_data( - RawRequestMessage( - method, path, version, headers, close, compression)) - out.feed_eof() - except tulip.EofStream: - # Presumably, the server closed the connection before - # sending a valid response. - pass - - -def http_response_parser(max_line_size=8190, - max_headers=32768, max_field_size=8190): - """Read response status line and headers. - - BadStatusLine could be raised in case of any errors in status line. - Returns RawResponseMessage""" - out, buf = yield - - try: - # read http message (response line + headers) - raw_data = yield from buf.readuntil( - b'\r\n\r\n', max_line_size+max_headers, errors.LineTooLong) - lines = raw_data.decode('ascii', 'surrogateescape').splitlines(True) - - line = lines[0] - try: - version, status = line.split(None, 1) - except ValueError: - raise errors.BadStatusLine(line) from None - else: - try: - status, reason = status.split(None, 1) - except ValueError: - reason = '' - - # version - match = VERSRE.match(version) - if match is None: - raise errors.BadStatusLine(line) - version = (int(match.group(1)), int(match.group(2))) - - # The status code is a three-digit number - try: - status = int(status) - except ValueError: - raise errors.BadStatusLine(line) from None - - if status < 100 or status > 999: - raise errors.BadStatusLine(line) - - # read headers - headers, close, compression = parse_headers( - lines, max_line_size, max_headers, max_field_size) - if close is None: - close = version <= (1, 0) - - out.feed_data( - RawResponseMessage( - version, status, reason.strip(), headers, close, compression)) - out.feed_eof() - except tulip.EofStream: - # Presumably, the server closed the connection before - # sending a valid response. - raise errors.BadStatusLine(b'') from None - - -def parse_headers(lines, max_line_size, max_headers, max_field_size): - """Parses RFC2822 headers from a stream. - - Line continuations are supported. Returns list of header name - and value pairs. Header name is in upper case. - """ - close_conn = None - encoding = None - headers = collections.deque() - - lines_idx = 1 - line = lines[1] - - while line not in ('\r\n', '\n'): - header_length = len(line) - - # Parse initial header name : value pair. - try: - name, value = line.split(':', 1) - except ValueError: - raise ValueError('Invalid header: {}'.format(line)) from None - - name = name.strip(' \t').upper() - if HDRRE.search(name): - raise ValueError('Invalid header name: {}'.format(name)) - - # next line - lines_idx += 1 - line = lines[lines_idx] - - # consume continuation lines - continuation = line[0] in CONTINUATION - - if continuation: - value = [value] - while continuation: - header_length += len(line) - if header_length > max_field_size: - raise errors.LineTooLong( - 'limit request headers fields size') - value.append(line) - - # next line - lines_idx += 1 - line = lines[lines_idx] - continuation = line[0] in CONTINUATION - value = ''.join(value) - else: - if header_length > max_field_size: - raise errors.LineTooLong('limit request headers fields size') - - value = value.strip() - - # keep-alive and encoding - if name == 'CONNECTION': - v = value.lower() - if v == 'close': - close_conn = True - elif v == 'keep-alive': - close_conn = False - elif name == 'CONTENT-ENCODING': - enc = value.lower() - if enc in ('gzip', 'deflate'): - encoding = enc - - headers.append((name, value)) - - return headers, close_conn, encoding - - -def http_payload_parser(message, length=None, compression=True, readall=False): - out, buf = yield - - # payload params - chunked = False - for name, value in message.headers: - if name == 'CONTENT-LENGTH': - length = value - elif name == 'TRANSFER-ENCODING': - chunked = value.lower() == 'chunked' - elif name == 'SEC-WEBSOCKET-KEY1': - length = 8 - - # payload decompression wrapper - if compression and message.compression: - out = DeflateBuffer(out, message.compression) - - # payload parser - if chunked: - yield from parse_chunked_payload(out, buf) - - elif length is not None: - try: - length = int(length) - except ValueError: - raise errors.InvalidHeader('CONTENT-LENGTH') from None - - if length < 0: - raise errors.InvalidHeader('CONTENT-LENGTH') - elif length > 0: - yield from parse_length_payload(out, buf, length) - else: - if readall: - yield from parse_eof_payload(out, buf) - - out.feed_eof() - - -def parse_chunked_payload(out, buf): - """Chunked transfer encoding parser.""" - try: - while True: - # read next chunk size - #line = yield from buf.readline(8196) - line = yield from buf.readuntil(b'\r\n', 8196) - - i = line.find(b';') - if i >= 0: - line = line[:i] # strip chunk-extensions - else: - line = line.strip() - try: - size = int(line, 16) - except ValueError: - raise errors.IncompleteRead(b'') from None - - if size == 0: # eof marker - break - - # read chunk and feed buffer - while size: - chunk = yield from buf.readsome(size) - out.feed_data(chunk) - size = size - len(chunk) - - # toss the CRLF at the end of the chunk - yield from buf.skip(2) - - # read and discard trailer up to the CRLF terminator - yield from buf.skipuntil(b'\r\n') - - except tulip.EofStream: - raise errors.IncompleteRead(b'') from None - - -def parse_length_payload(out, buf, length): - """Read specified amount of bytes.""" - try: - while length: - chunk = yield from buf.readsome(length) - out.feed_data(chunk) - length -= len(chunk) - - except tulip.EofStream: - raise errors.IncompleteRead(b'') from None - - -def parse_eof_payload(out, buf): - """Read all bytes untile eof.""" - while True: - out.feed_data((yield from buf.readsome())) - - -class DeflateBuffer: - """DeflateStream decomress stream and feed data into specified stream.""" - - def __init__(self, out, encoding): - self.out = out - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) - - self.zlib = zlib.decompressobj(wbits=zlib_mode) - - def feed_data(self, chunk): - try: - chunk = self.zlib.decompress(chunk) - except Exception: - raise errors.IncompleteRead(b'') from None - - if chunk: - self.out.feed_data(chunk) - - def feed_eof(self): - self.out.feed_data(self.zlib.flush()) - if not self.zlib.eof: - raise errors.IncompleteRead(b'') - - self.out.feed_eof() - - -def wrap_payload_filter(func): - """Wraps payload filter and piped filters. - - Filter is a generatator that accepts arbitrary chunks of data, - modify data and emit new stream of data. - - For example we have stream of chunks: ['1', '2', '3', '4', '5'], - we can apply chunking filter to this stream: - - ['1', '2', '3', '4', '5'] - | - response.add_chunking_filter(2) - | - ['12', '34', '5'] - - It is possible to use different filters at the same time. - - For a example to compress incoming stream with 'deflate' encoding - and then split data and emit chunks of 8196 bytes size chunks: - - >> response.add_compression_filter('deflate') - >> response.add_chunking_filter(8196) - - Filters do not alter transfer encoding. - - Filter can receive types types of data, bytes object or EOF_MARKER. - - 1. If filter receives bytes object, it should process data - and yield processed data then yield EOL_MARKER object. - 2. If Filter recevied EOF_MARKER, it should yield remaining - data (buffered) and then yield EOF_MARKER. - """ - @functools.wraps(func) - def wrapper(self, *args, **kw): - new_filter = func(self, *args, **kw) - - filter = self.filter - if filter is not None: - next(new_filter) - self.filter = filter_pipe(filter, new_filter) - else: - self.filter = new_filter - - next(self.filter) - - return wrapper - - -def filter_pipe(filter, filter2): - """Creates pipe between two filters. - - filter_pipe() feeds first filter with incoming data and then - send yielded from first filter data into filter2, results of - filter2 are being emitted. - - 1. If filter_pipe receives bytes object, it sends it to the first filter. - 2. Reads yielded values from the first filter until it receives - EOF_MARKER or EOL_MARKER. - 3. Each of this values is being send to second filter. - 4. Reads yielded values from second filter until it recives EOF_MARKER or - EOL_MARKER. Each of this values yields to writer. - """ - chunk = yield - - while True: - eof = chunk is EOF_MARKER - chunk = filter.send(chunk) - - while chunk is not EOL_MARKER: - chunk = filter2.send(chunk) - - while chunk not in (EOF_MARKER, EOL_MARKER): - yield chunk - chunk = next(filter2) - - if chunk is not EOF_MARKER: - if eof: - chunk = EOF_MARKER - else: - chunk = next(filter) - else: - break - - chunk = yield EOL_MARKER - - -class HttpMessage: - """HttpMessage allows to write headers and payload to a stream. - - For example, lets say we want to read file then compress it with deflate - compression and then send it with chunked transfer encoding, code may look - like this: - - >> response = tulip.http.Response(transport, 200) - - We have to use deflate compression first: - - >> response.add_compression_filter('deflate') - - Then we want to split output stream into chunks of 1024 bytes size: - - >> response.add_chunking_filter(1024) - - We can add headers to response with add_headers() method. add_headers() - does not send data to transport, send_headers() sends request/response - line and then sends headers: - - >> response.add_headers( - .. ('Content-Disposition', 'attachment; filename="..."')) - >> response.send_headers() - - Now we can use chunked writer to write stream to a network stream. - First call to write() method sends response status line and headers, - add_header() and add_headers() method unavailble at this stage: - - >> with open('...', 'rb') as f: - .. chunk = fp.read(8196) - .. while chunk: - .. response.write(chunk) - .. chunk = fp.read(8196) - - >> response.write_eof() - """ - - writer = None - - # 'filter' is being used for altering write() bahaviour, - # add_chunking_filter adds deflate/gzip compression and - # add_compression_filter splits incoming data into a chunks. - filter = None - - HOP_HEADERS = None # Must be set by subclass. - - SERVER_SOFTWARE = 'Python/{0[0]}.{0[1]} tulip/0.0'.format(sys.version_info) - - status = None - status_line = b'' - upgrade = False # Connection: UPGRADE - websocket = False # Upgrade: WEBSOCKET - - # subclass can enable auto sending headers with write() call, - # this is useful for wsgi's start_response implementation. - _send_headers = False - - def __init__(self, transport, version, close): - self.transport = transport - self.version = version - self.closing = close - - # disable keep-alive for http/1.0 - if version <= (1, 0): - self.keepalive = False - else: - self.keepalive = None - - self.chunked = False - self.length = None - self.headers = collections.deque() - self.headers_sent = False - - def force_close(self): - self.closing = True - self.keepalive = False - - def force_chunked(self): - self.chunked = True - - def keep_alive(self): - if self.keepalive is None: - return not self.closing - else: - return self.keepalive - - def is_headers_sent(self): - return self.headers_sent - - def add_header(self, name, value): - """Analyze headers. Calculate content length, - removes hop headers, etc.""" - assert not self.headers_sent, 'headers have been sent already' - assert isinstance(name, str), '{!r} is not a string'.format(name) - - name = name.strip().upper() - - if name == 'CONTENT-LENGTH': - self.length = int(value) - - if name == 'CONNECTION': - val = value.lower() - # handle websocket - if 'upgrade' in val: - self.upgrade = True - # connection keep-alive - elif 'close' in val: - self.keepalive = False - elif 'keep-alive' in val and self.version >= (1, 1): - self.keepalive = True - - elif name == 'UPGRADE': - if 'websocket' in value.lower(): - self.websocket = True - self.headers.append((name, value)) - - elif name == 'TRANSFER-ENCODING' and not self.chunked: - self.chunked = value.lower().strip() == 'chunked' - - elif name not in self.HOP_HEADERS: - # ignore hopbyhop headers - self.headers.append((name, value)) - - def add_headers(self, *headers): - """Adds headers to a http message.""" - for name, value in headers: - self.add_header(name, value) - - def send_headers(self): - """Writes headers to a stream. Constructs payload writer.""" - # Chunked response is only for HTTP/1.1 clients or newer - # and there is no Content-Length header is set. - # Do not use chunked responses when the response is guaranteed to - # not have a response body (304, 204). - assert not self.headers_sent, 'headers have been sent already' - self.headers_sent = True - - if (self.chunked is True) or ( - self.length is None and - self.version >= (1, 1) and - self.status not in (304, 204)): - self.chunked = True - self.writer = self._write_chunked_payload() - - elif self.length is not None: - self.writer = self._write_length_payload(self.length) - - else: - self.writer = self._write_eof_payload() - - next(self.writer) - - self._add_default_headers() - - # status + headers - hdrs = ''.join(itertools.chain( - (self.status_line,), - *((k, ': ', v, '\r\n') for k, v in self.headers))) - - self.transport.write(hdrs.encode('ascii') + b'\r\n') - - def _add_default_headers(self): - # set the connection header - if self.upgrade: - connection = 'upgrade' - elif not self.closing if self.keepalive is None else self.keepalive: - connection = 'keep-alive' - else: - connection = 'close' - - if self.chunked: - self.headers.appendleft(('TRANSFER-ENCODING', 'chunked')) - - self.headers.appendleft(('CONNECTION', connection)) - - def write(self, chunk): - """write() writes chunk of data to a steram by using different writers. - writer uses filter to modify chunk of data. write_eof() indicates - end of stream. writer can't be used after write_eof() method - being called.""" - assert (isinstance(chunk, (bytes, bytearray)) or - chunk is EOF_MARKER), chunk - - if self._send_headers and not self.headers_sent: - self.send_headers() - - assert self.writer is not None, 'send_headers() is not called.' - - if self.filter: - chunk = self.filter.send(chunk) - while chunk not in (EOF_MARKER, EOL_MARKER): - self.writer.send(chunk) - chunk = next(self.filter) - else: - if chunk is not EOF_MARKER: - self.writer.send(chunk) - - def write_eof(self): - self.write(EOF_MARKER) - try: - self.writer.throw(tulip.EofStream()) - except StopIteration: - pass - - def _write_chunked_payload(self): - """Write data in chunked transfer encoding.""" - while True: - try: - chunk = yield - except tulip.EofStream: - self.transport.write(b'0\r\n\r\n') - break - - self.transport.write('{:x}\r\n'.format(len(chunk)).encode('ascii')) - self.transport.write(bytes(chunk)) - self.transport.write(b'\r\n') - - def _write_length_payload(self, length): - """Write specified number of bytes to a stream.""" - while True: - try: - chunk = yield - except tulip.EofStream: - break - - if length: - l = len(chunk) - if length >= l: - self.transport.write(chunk) - else: - self.transport.write(chunk[:length]) - - length = max(0, length-l) - - def _write_eof_payload(self): - while True: - try: - chunk = yield - except tulip.EofStream: - break - - self.transport.write(chunk) - - @wrap_payload_filter - def add_chunking_filter(self, chunk_size=16*1024): - """Split incoming stream into chunks.""" - buf = bytearray() - chunk = yield - - while True: - if chunk is EOF_MARKER: - if buf: - yield buf - - yield EOF_MARKER - - else: - buf.extend(chunk) - - while len(buf) >= chunk_size: - chunk = bytes(buf[:chunk_size]) - del buf[:chunk_size] - yield chunk - - chunk = yield EOL_MARKER - - @wrap_payload_filter - def add_compression_filter(self, encoding='deflate'): - """Compress incoming stream with deflate or gzip encoding.""" - zlib_mode = (16 + zlib.MAX_WBITS - if encoding == 'gzip' else -zlib.MAX_WBITS) - zcomp = zlib.compressobj(wbits=zlib_mode) - - chunk = yield - while True: - if chunk is EOF_MARKER: - yield zcomp.flush() - chunk = yield EOF_MARKER - - else: - yield zcomp.compress(chunk) - chunk = yield EOL_MARKER - - -class Response(HttpMessage): - """Create http response message. - - Transport is a socket stream transport. status is a response status code, - status has to be integer value. http_version is a tuple that represents - http version, (1, 0) stands for HTTP/1.0 and (1, 1) is for HTTP/1.1 - """ - - HOP_HEADERS = { - 'CONNECTION', - 'KEEP-ALIVE', - 'PROXY-AUTHENTICATE', - 'PROXY-AUTHORIZATION', - 'TE', - 'TRAILERS', - 'TRANSFER-ENCODING', - 'UPGRADE', - 'SERVER', - 'DATE', - } - - def __init__(self, transport, status, http_version=(1, 1), close=False): - super().__init__(transport, http_version, close) - - self.status = status - self.status_line = 'HTTP/{}.{} {} {}\r\n'.format( - http_version[0], http_version[1], status, RESPONSES[status][0]) - - def _add_default_headers(self): - super()._add_default_headers() - self.headers.extend((('DATE', format_date_time(None)), - ('SERVER', self.SERVER_SOFTWARE),)) - - -class Request(HttpMessage): - - HOP_HEADERS = () - - def __init__(self, transport, method, path, - http_version=(1, 1), close=False): - super().__init__(transport, http_version, close) - - self.method = method - self.path = path - self.status_line = '{0} {1} HTTP/{2[0]}.{2[1]}\r\n'.format( - method, path, http_version) - - def _add_default_headers(self): - super()._add_default_headers() - self.headers.append(('USER-AGENT', self.SERVER_SOFTWARE)) diff --git a/tulip/http/server.py b/tulip/http/server.py deleted file mode 100644 index fc5621c5..00000000 --- a/tulip/http/server.py +++ /dev/null @@ -1,215 +0,0 @@ -"""simple http server.""" - -__all__ = ['ServerHttpProtocol'] - -import http.server -import inspect -import logging -import traceback - -import tulip -from tulip.http import errors - - -RESPONSES = http.server.BaseHTTPRequestHandler.responses -DEFAULT_ERROR_MESSAGE = """ - - - {status} {reason} - - -

{status} {reason}

- {message} - -""" - - -class ServerHttpProtocol(tulip.Protocol): - """Simple http protocol implementation. - - ServerHttpProtocol handles incoming http request. It reads request line, - request headers and request payload and calls handler_request() method. - By default it always returns with 404 respose. - - ServerHttpProtocol handles errors in incoming request, like bad - status line, bad headers or incomplete payload. If any error occurs, - connection gets closed. - - log: custom logging object - debug: enable debug mode - keep_alive: number of seconds before closing keep alive connection - loop: event loop object - """ - _request_count = 0 - _request_handler = None - _keep_alive = False # keep transport open - _keep_alive_handle = None # keep alive timer handle - - def __init__(self, *, log=logging, debug=False, - keep_alive=None, loop=None, **kwargs): - self.__dict__.update(kwargs) - self.log = log - self.debug = debug - - self._keep_alive_period = keep_alive # number of seconds to keep alive - - if keep_alive and loop is None: - loop = tulip.get_event_loop() - self._loop = loop - - def connection_made(self, transport): - self.transport = transport - self.stream = tulip.StreamBuffer(loop=self._loop) - self._request_handler = tulip.Task(self.start(), loop=self._loop) - - def data_received(self, data): - self.stream.feed_data(data) - - def eof_received(self): - self.stream.feed_eof() - - def connection_lost(self, exc): - self.stream.feed_eof() - - if self._request_handler is not None: - self._request_handler.cancel() - self._request_handler = None - if self._keep_alive_handle is not None: - self._keep_alive_handle.cancel() - self._keep_alive_handle = None - - def keep_alive(self, val): - self._keep_alive = val - - def log_access(self, status, message, *args, **kw): - pass - - def log_debug(self, *args, **kw): - if self.debug: - self.log.debug(*args, **kw) - - def log_exception(self, *args, **kw): - self.log.exception(*args, **kw) - - @tulip.coroutine - def start(self): - """Start processing of incoming requests. - It reads request line, request headers and request payload, then - calls handle_request() method. Subclass has to override - handle_request(). start() handles various excetions in request - or response handling. Connection is being closed always unless - keep_alive(True) specified. - """ - - while True: - info = None - message = None - self._request_count += 1 - self._keep_alive = False - - try: - httpstream = self.stream.set_parser( - tulip.http.http_request_parser()) - - message = yield from httpstream.read() - - # cancel keep-alive timer - if self._keep_alive_handle is not None: - self._keep_alive_handle.cancel() - self._keep_alive_handle = None - - payload = self.stream.set_parser( - tulip.http.http_payload_parser(message)) - - handler = self.handle_request(message, payload) - if (inspect.isgenerator(handler) or - isinstance(handler, tulip.Future)): - yield from handler - - except tulip.CancelledError: - self.log_debug('Ignored premature client disconnection.') - break - except errors.HttpException as exc: - self.handle_error(exc.code, info, message, exc, exc.headers) - except Exception as exc: - self.handle_error(500, info, message, exc) - finally: - if self._request_handler: - if self._keep_alive and self._keep_alive_period: - self._keep_alive_handle = self._loop.call_later( - self._keep_alive_period, self.transport.close) - else: - self.transport.close() - self._request_handler = None - break - else: - break - - def handle_error(self, status=500, - message=None, payload=None, exc=None, headers=None): - """Handle errors. - - Returns http response with specific status code. Logs additional - information. It always closes current connection.""" - try: - if self._request_handler is None: - # client has been disconnected during writing. - return - - if status == 500: - self.log_exception("Error handling request") - - try: - reason, msg = RESPONSES[status] - except KeyError: - status = 500 - reason, msg = '???', '' - - if self.debug and exc is not None: - try: - tb = traceback.format_exc() - msg += '

Traceback:

\n
{}
'.format(tb) - except: - pass - - self.log_access(status, message) - - html = DEFAULT_ERROR_MESSAGE.format( - status=status, reason=reason, message=msg) - - response = tulip.http.Response(self.transport, status, close=True) - response.add_headers( - ('Content-Type', 'text/html'), - ('Content-Length', str(len(html)))) - if headers is not None: - response.add_headers(*headers) - response.send_headers() - - response.write(html.encode('ascii')) - response.write_eof() - finally: - self.keep_alive(False) - - def handle_request(self, message, payload): - """Handle a single http request. - - Subclass should override this method. By default it always - returns 404 response. - - info: tulip.http.RequestLine instance - message: tulip.http.RawHttpMessage instance - """ - response = tulip.http.Response( - self.transport, 404, http_version=message.version, close=True) - - body = b'Page Not Found!' - - response.add_headers( - ('Content-Type', 'text/plain'), - ('Content-Length', str(len(body)))) - response.send_headers() - response.write(body) - response.write_eof() - - self.keep_alive(False) - self.log_access(404, message) diff --git a/tulip/http/session.py b/tulip/http/session.py deleted file mode 100644 index 9cdd9cea..00000000 --- a/tulip/http/session.py +++ /dev/null @@ -1,103 +0,0 @@ -"""client session support.""" - -__all__ = ['Session'] - -import functools -import tulip -import http.cookies - - -class Session: - - def __init__(self): - self._conns = {} - self.cookies = http.cookies.SimpleCookie() - - def __del__(self): - self.close() - - def close(self): - """Close all opened transports.""" - for key, data in self._conns.items(): - for transport, proto in data: - transport.close() - - self._conns.clear() - - def update_cookies(self, cookies): - if isinstance(cookies, dict): - cookies = cookies.items() - - for name, value in cookies: - if isinstance(value, http.cookies.Morsel): - # use dict method because SimpleCookie class modifies value - dict.__setitem__(self.cookies, name, value) - else: - self.cookies[name] = value - - @tulip.coroutine - def start(self, req, loop, new_conn=False, set_cookies=True): - key = (req.host, req.port, req.ssl) - - if set_cookies and self.cookies: - req.update_cookies(self.cookies.items()) - - if not new_conn: - transport, proto = self._get(key) - - if new_conn or transport is None: - new = True - transport, proto = yield from loop.create_connection( - functools.partial(tulip.StreamProtocol, loop=loop), - req.host, req.port, ssl=req.ssl) - else: - new = False - - try: - resp = req.send(transport) - yield from resp.start( - proto, TransportWrapper( - self._release, key, transport, proto, resp)) - except: - if new: - transport.close() - raise - - return (yield from self.start(req, loop, set_cookies=False)) - - return resp - - def _get(self, key): - conns = self._conns.get(key) - if conns: - return conns.pop() - - return None, None - - def _release(self, resp, key, conn): - msg = resp.message - if msg.should_close: - conn[0].close() - else: - conns = self._conns.get(key) - if conns is None: - conns = self._conns[key] = [] - conns.append(conn) - conn[1].unset_parser() - - if resp.cookies: - self.update_cookies(resp.cookies.items()) - - -class TransportWrapper: - - def __init__(self, release, key, transport, protocol, response): - self.release = release - self.key = key - self.transport = transport - self.protocol = protocol - self.response = response - - def close(self): - self.release(self.response, self.key, - (self.transport, self.protocol)) diff --git a/tulip/http/websocket.py b/tulip/http/websocket.py deleted file mode 100644 index c3dd5872..00000000 --- a/tulip/http/websocket.py +++ /dev/null @@ -1,233 +0,0 @@ -"""WebSocket protocol versions 13 and 8.""" - -__all__ = ['WebSocketParser', 'WebSocketWriter', 'do_handshake', - 'Message', 'WebSocketError', - 'MSG_TEXT', 'MSG_BINARY', 'MSG_CLOSE', 'MSG_PING', 'MSG_PONG'] - -import base64 -import binascii -import collections -import hashlib -import struct -from tulip.http import errors - -# Frame opcodes defined in the spec. -OPCODE_CONTINUATION = 0x0 -MSG_TEXT = OPCODE_TEXT = 0x1 -MSG_BINARY = OPCODE_BINARY = 0x2 -MSG_CLOSE = OPCODE_CLOSE = 0x8 -MSG_PING = OPCODE_PING = 0x9 -MSG_PONG = OPCODE_PONG = 0xa - -WS_KEY = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -WS_HDRS = ('UPGRADE', 'CONNECTION', - 'SEC-WEBSOCKET-VERSION', 'SEC-WEBSOCKET-KEY') - -Message = collections.namedtuple('Message', ['tp', 'data', 'extra']) - - -class WebSocketError(Exception): - """WebSocket protocol parser error.""" - - -def WebSocketParser(): - out, buf = yield - - while True: - message = yield from parse_message(buf) - out.feed_data(message) - - if message.tp == MSG_CLOSE: - out.feed_eof() - break - - -def parse_frame(buf): - """Return the next frame from the socket.""" - # read header - data = yield from buf.read(2) - first_byte, second_byte = struct.unpack('!BB', data) - - fin = (first_byte >> 7) & 1 - rsv1 = (first_byte >> 6) & 1 - rsv2 = (first_byte >> 5) & 1 - rsv3 = (first_byte >> 4) & 1 - opcode = first_byte & 0xf - - # frame-fin = %x0 ; more frames of this message follow - # / %x1 ; final frame of this message - # frame-rsv1 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv2 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - # frame-rsv3 = %x0 ; 1 bit, MUST be 0 unless negotiated otherwise - if rsv1 or rsv2 or rsv3: - raise WebSocketError('Received frame with non-zero reserved bits') - - if opcode > 0x7 and fin == 0: - raise WebSocketError('Received fragmented control frame') - - if fin == 0 and opcode == OPCODE_CONTINUATION: - raise WebSocketError( - 'Received new fragment frame with non-zero opcode') - - has_mask = (second_byte >> 7) & 1 - length = (second_byte) & 0x7f - - # Control frames MUST have a payload length of 125 bytes or less - if opcode > 0x7 and length > 125: - raise WebSocketError( - "Control frame payload cannot be larger than 125 bytes") - - # read payload - if length == 126: - data = yield from buf.read(2) - length = struct.unpack_from('!H', data)[0] - elif length > 126: - data = yield from buf.read(8) - length = struct.unpack_from('!Q', data)[0] - - if has_mask: - mask = yield from buf.read(4) - - if length: - payload = yield from buf.read(length) - else: - payload = b'' - - if has_mask: - payload = bytes(b ^ mask[i % 4] for i, b in enumerate(payload)) - - return fin, opcode, payload - - -def parse_message(buf): - fin, opcode, payload = yield from parse_frame(buf) - - if opcode == OPCODE_CLOSE: - if len(payload) >= 2: - close_code = struct.unpack('!H', payload[:2])[0] - close_message = payload[2:] - return Message(OPCODE_CLOSE, close_code, close_message) - elif payload: - raise WebSocketError( - 'Invalid close frame: {} {} {!r}'.format(fin, opcode, payload)) - return Message(OPCODE_CLOSE, '', '') - - elif opcode == OPCODE_PING: - return Message(OPCODE_PING, '', '') - - elif opcode == OPCODE_PONG: - return Message(OPCODE_PONG, '', '') - - elif opcode not in (OPCODE_TEXT, OPCODE_BINARY): - raise WebSocketError("Unexpected opcode={!r}".format(opcode)) - - # load text/binary - data = [payload] - - while not fin: - fin, _opcode, payload = yield from parse_frame(buf) - if _opcode != OPCODE_CONTINUATION: - raise WebSocketError( - 'The opcode in non-fin frame is expected ' - 'to be zero, got {!r}'.format(opcode)) - else: - data.append(payload) - - if opcode == OPCODE_TEXT: - return Message(OPCODE_TEXT, b''.join(data).decode('utf-8'), '') - else: - return Message(OPCODE_BINARY, b''.join(data), '') - - -class WebSocketWriter: - - def __init__(self, transport): - self.transport = transport - - def _send_frame(self, message, opcode): - """Send a frame over the websocket with message as its payload.""" - header = bytes([0x80 | opcode]) - msg_length = len(message) - - if msg_length < 126: - header += bytes([msg_length]) - elif msg_length < (1 << 16): - header += bytes([126]) + struct.pack('!H', msg_length) - else: - header += bytes([127]) + struct.pack('!Q', msg_length) - - self.transport.write(header + message) - - def pong(self): - """Send pong message.""" - self._send_frame(b'', OPCODE_PONG) - - def ping(self): - """Send pong message.""" - self._send_frame(b'', OPCODE_PING) - - def send(self, message, binary=False): - """Send a frame over the websocket with message as its payload.""" - if isinstance(message, str): - message = message.encode('utf-8') - if binary: - self._send_frame(message, OPCODE_BINARY) - else: - self._send_frame(message, OPCODE_TEXT) - - def close(self, code=1000, message=b''): - """Close the websocket, sending the specified code and message.""" - if isinstance(message, str): - message = message.encode('utf-8') - self._send_frame( - struct.pack('!H%ds' % len(message), code, message), - opcode=OPCODE_CLOSE) - - -def do_handshake(method, headers, transport): - """Prepare WebSocket handshake. It return http response code, - response headers, websocket parser, websocket writer. It does not - perform any IO.""" - - # WebSocket accepts only GET - if method.upper() != 'GET': - raise errors.HttpErrorException(405, headers=(('Allow', 'GET'),)) - - headers = dict(((hdr, val) for hdr, val in headers if hdr in WS_HDRS)) - - if 'websocket' != headers.get('UPGRADE', '').lower().strip(): - raise errors.BadRequestException( - 'No WebSocket UPGRADE hdr: {}\n' - 'Can "Upgrade" only to "WebSocket".'.format( - headers.get('UPGRADE'))) - - if 'upgrade' not in headers.get('CONNECTION', '').lower(): - raise errors.BadRequestException( - 'No CONNECTION upgrade hdr: {}'.format( - headers.get('CONNECTION'))) - - # check supported version - version = headers.get('SEC-WEBSOCKET-VERSION') - if version not in ('13', '8', '7'): - raise errors.BadRequestException( - 'Unsupported version: {}'.format(version)) - - # check client handshake for validity - key = headers.get('SEC-WEBSOCKET-KEY') - try: - if not key or len(base64.b64decode(key)) != 16: - raise errors.BadRequestException( - 'Handshake error: {!r}'.format(key)) - except binascii.Error: - raise errors.BadRequestException( - 'Handshake error: {!r}'.format(key)) from None - - # response code, headers, parser, writer - return (101, - (('UPGRADE', 'websocket'), - ('CONNECTION', 'upgrade'), - ('TRANSFER-ENCODING', 'chunked'), - ('SEC-WEBSOCKET-ACCEPT', base64.b64encode( - hashlib.sha1(key.encode() + WS_KEY).digest()).decode())), - WebSocketParser(), - WebSocketWriter(transport)) diff --git a/tulip/http/wsgi.py b/tulip/http/wsgi.py deleted file mode 100644 index 02611f78..00000000 --- a/tulip/http/wsgi.py +++ /dev/null @@ -1,228 +0,0 @@ -"""wsgi server. - -TODO: - * proxy protocol - * x-forward security - * wsgi file support (os.sendfile) -""" - -__all__ = ['WSGIServerHttpProtocol'] - -import inspect -import io -import os -import sys -from urllib.parse import unquote, urlsplit - -import tulip -import tulip.http -from tulip.http import server - - -class WSGIServerHttpProtocol(server.ServerHttpProtocol): - """HTTP Server that implements the Python WSGI protocol. - - It uses 'wsgi.async' of 'True'. 'wsgi.input' can behave differently - depends on 'readpayload' constructor parameter. If readpayload is set to - True, wsgi server reads all incoming data into BytesIO object and - sends it as 'wsgi.input' environ var. If readpayload is set to false - 'wsgi.input' is a StreamReader and application should read incoming - data with "yield from environ['wsgi.input'].read()". It defaults to False. - """ - - SCRIPT_NAME = os.environ.get('SCRIPT_NAME', '') - - def __init__(self, app, readpayload=False, is_ssl=False, *args, **kw): - super().__init__(*args, **kw) - - self.wsgi = app - self.is_ssl = is_ssl - self.readpayload = readpayload - - def create_wsgi_response(self, message): - return WsgiResponse(self.transport, message) - - def create_wsgi_environ(self, message, payload): - uri_parts = urlsplit(message.path) - url_scheme = 'https' if self.is_ssl else 'http' - - environ = { - 'wsgi.input': payload, - 'wsgi.errors': sys.stderr, - 'wsgi.version': (1, 0), - 'wsgi.async': True, - 'wsgi.multithread': False, - 'wsgi.multiprocess': False, - 'wsgi.run_once': False, - 'wsgi.file_wrapper': FileWrapper, - 'wsgi.url_scheme': url_scheme, - 'SERVER_SOFTWARE': tulip.http.HttpMessage.SERVER_SOFTWARE, - 'REQUEST_METHOD': message.method, - 'QUERY_STRING': uri_parts.query or '', - 'RAW_URI': message.path, - 'SERVER_PROTOCOL': 'HTTP/%s.%s' % message.version - } - - # authors should be aware that REMOTE_HOST and REMOTE_ADDR - # may not qualify the remote addr: - # http://www.ietf.org/rfc/rfc3875 - forward = self.transport.get_extra_info('addr', '127.0.0.1') - script_name = self.SCRIPT_NAME - server = forward - - for hdr_name, hdr_value in message.headers: - if hdr_name == 'EXPECT': - # handle expect - if hdr_value.lower() == '100-continue': - self.transport.write(b'HTTP/1.1 100 Continue\r\n\r\n') - elif hdr_name == 'HOST': - server = hdr_value - elif hdr_name == 'SCRIPT_NAME': - script_name = hdr_value - elif hdr_name == 'CONTENT-TYPE': - environ['CONTENT_TYPE'] = hdr_value - continue - elif hdr_name == 'CONTENT-LENGTH': - environ['CONTENT_LENGTH'] = hdr_value - continue - - key = 'HTTP_%s' % hdr_name.replace('-', '_') - if key in environ: - hdr_value = '%s,%s' % (environ[key], hdr_value) - - environ[key] = hdr_value - - if isinstance(forward, str): - # we only took the last one - # http://en.wikipedia.org/wiki/X-Forwarded-For - if ',' in forward: - forward = forward.rsplit(',', 1)[-1].strip() - - # find host and port on ipv6 address - if '[' in forward and ']' in forward: - host = forward.split(']')[0][1:].lower() - elif ':' in forward and forward.count(':') == 1: - host = forward.split(':')[0].lower() - else: - host = forward - - forward = forward.split(']')[-1] - if ':' in forward and forward.count(':') == 1: - port = forward.split(':', 1)[1] - else: - port = 80 - - remote = (host, port) - else: - remote = forward - - environ['REMOTE_ADDR'] = remote[0] - environ['REMOTE_PORT'] = str(remote[1]) - - if isinstance(server, str): - server = server.split(':') - if len(server) == 1: - server.append('80' if url_scheme == 'http' else '443') - - environ['SERVER_NAME'] = server[0] - environ['SERVER_PORT'] = str(server[1]) - - path_info = uri_parts.path - if script_name: - path_info = path_info.split(script_name, 1)[-1] - - environ['PATH_INFO'] = unquote(path_info) - environ['SCRIPT_NAME'] = script_name - - environ['tulip.reader'] = self.stream - environ['tulip.writer'] = self.transport - - return environ - - @tulip.coroutine - def handle_request(self, message, payload): - """Handle a single HTTP request""" - - if self.readpayload: - wsgiinput = io.BytesIO() - try: - while True: - wsgiinput.write((yield from payload.read())) - except tulip.EofStream: - pass - wsgiinput.seek(0) - payload = wsgiinput - - environ = self.create_wsgi_environ(message, payload) - response = self.create_wsgi_response(message) - - riter = self.wsgi(environ, response.start_response) - if isinstance(riter, tulip.Future) or inspect.isgenerator(riter): - riter = yield from riter - - resp = response.response - try: - for item in riter: - if isinstance(item, tulip.Future): - item = yield from item - resp.write(item) - - resp.write_eof() - finally: - if hasattr(riter, 'close'): - riter.close() - - if resp.keep_alive(): - self.keep_alive(True) - - -class FileWrapper: - """Custom file wrapper.""" - - def __init__(self, fobj, chunk_size=8192): - self.fobj = fobj - self.chunk_size = chunk_size - if hasattr(fobj, 'close'): - self.close = fobj.close - - def __iter__(self): - return self - - def __next__(self): - data = self.fobj.read(self.chunk_size) - if data: - return data - raise StopIteration - - -class WsgiResponse: - """Implementation of start_response() callable as specified by PEP 3333""" - - status = None - - def __init__(self, transport, message): - self.transport = transport - self.message = message - - def start_response(self, status, headers, exc_info=None): - if exc_info: - try: - if self.status: - raise exc_info[1] - finally: - exc_info = None - - status_code = int(status.split(' ', 1)[0]) - - self.status = status - resp = self.response = tulip.http.Response( - self.transport, status_code, - self.message.version, self.message.should_close) - resp.add_headers(*headers) - - # send headers immediately for websocket connection - if status_code == 101 and resp.upgrade and resp.websocket: - resp.send_headers() - else: - resp._send_headers = True - return self.response.write diff --git a/tulip/parsers.py b/tulip/parsers.py deleted file mode 100644 index 8ac05e18..00000000 --- a/tulip/parsers.py +++ /dev/null @@ -1,399 +0,0 @@ -"""Parser is a generator function. - -Parser receives data with generator's send() method and sends data to -destination DataBuffer. Parser receives ParserBuffer and DataBuffer objects -as a parameters of the first send() call, all subsequent send() calls should -send bytes objects. Parser sends parsed 'term' to desitnation buffer with -DataBuffer.feed_data() method. DataBuffer object should implement two methods. -feed_data() - parser uses this method to send parsed protocol data. -feed_eof() - parser uses this method for indication of end of parsing stream. -To indicate end of incoming data stream EofStream exception should be sent -into parser. Parser could throw exceptions. - -There are three stages: - - * Data flow chain: - - 1. Application creates StreamBuffer object for storing incoming data. - 2. StreamBuffer creates ParserBuffer as internal data buffer. - 3. Application create parser and set it into stream buffer: - - parser = http_request_parser() - data_buffer = stream.set_parser(parser) - - 3. At this stage StreamBuffer creates DataBuffer object and passes it - and internal buffer into parser with first send() call. - - def set_parser(self, parser): - next(parser) - data_buffer = DataBuffer() - parser.send((data_buffer, self._buffer)) - return data_buffer - - 4. Application waits data on data_buffer.read() - - while True: - msg = yield form data_buffer.read() - ... - - * Data flow: - - 1. Tulip's transport reads data from socket and sends data to protocol - with data_received() call. - 2. Protocol sends data to StreamBuffer with feed_data() call. - 3. StreamBuffer sends data into parser with generator's send() method. - 4. Parser processes incoming data and sends parsed data - to DataBuffer with feed_data() - 4. Application received parsed data from DataBuffer.read() - - * Eof: - - 1. StreamBuffer recevies eof with feed_eof() call. - 2. StreamBuffer throws EofStream exception into parser. - 3. Then it unsets parser. - -_SocketSocketTransport -> - -> "protocol" -> StreamBuffer -> "parser" -> DataBuffer <- "application" - -""" -__all__ = ['EofStream', 'StreamBuffer', 'StreamProtocol', - 'ParserBuffer', 'DataBuffer', 'lines_parser', 'chunks_parser'] - -import collections - -from . import tasks -from . import futures -from . import protocols - - -class EofStream(Exception): - """eof stream indication.""" - - -class StreamBuffer: - """StreamBuffer manages incoming bytes stream and protocol parsers. - - StreamBuffer uses ParserBuffer as internal buffer. - - set_parser() sets current parser, it creates DataBuffer object - and sends ParserBuffer and DataBuffer into parser generator. - - unset_parser() sends EofStream into parser and then removes it. - """ - - def __init__(self, *, loop=None): - self._loop = loop - self._buffer = ParserBuffer() - self._eof = False - self._parser = None - self._parser_buffer = None - self._exception = None - - def exception(self): - return self._exception - - def set_exception(self, exc): - self._exception = exc - - if self._parser_buffer is not None: - self._parser_buffer.set_exception(exc) - self._parser = None - self._parser_buffer = None - - def feed_data(self, data): - """send data to current parser or store in buffer.""" - if not data: - return - - if self._parser: - try: - self._parser.send(data) - except StopIteration: - self._parser = None - self._parser_buffer = None - except Exception as exc: - self._parser_buffer.set_exception(exc) - self._parser = None - self._parser_buffer = None - else: - self._buffer.feed_data(data) - - def feed_eof(self): - """send eof to all parsers, recursively.""" - if self._parser: - try: - self._parser.throw(EofStream()) - except StopIteration: - pass - except EofStream: - self._parser_buffer.feed_eof() - except Exception as exc: - self._parser_buffer.set_exception(exc) - - self._parser = None - self._parser_buffer = None - - self._eof = True - - def set_parser(self, p): - """set parser to stream. return parser's DataStream.""" - if self._parser: - self.unset_parser() - - out = DataBuffer(loop=self._loop) - if self._exception: - out.set_exception(self._exception) - return out - - # init generator - next(p) - try: - # initialize parser with data and parser buffers - p.send((out, self._buffer)) - except StopIteration: - pass - except Exception as exc: - out.set_exception(exc) - else: - # parser still require more data - self._parser = p - self._parser_buffer = out - - if self._eof: - self.unset_parser() - - return out - - def unset_parser(self): - """unset parser, send eof to the parser and then remove it.""" - if self._parser is None: - return - - try: - self._parser.throw(EofStream()) - except StopIteration: - pass - except EofStream: - self._parser_buffer.feed_eof() - except Exception as exc: - self._parser_buffer.set_exception(exc) - finally: - self._parser = None - self._parser_buffer = None - - -class StreamProtocol(StreamBuffer, protocols.Protocol): - """Tulip's stream protocol based on StreamBuffer""" - - transport = None - - data_received = StreamBuffer.feed_data - - eof_received = StreamBuffer.feed_eof - - def connection_made(self, transport): - self.transport = transport - - def connection_lost(self, exc): - self.transport = None - - if exc is not None: - self.set_exception(exc) - else: - self.feed_eof() - - -class DataBuffer: - """DataBuffer is a destination for parsed data.""" - - def __init__(self, *, loop=None): - self._loop = loop - self._buffer = collections.deque() - self._eof = False - self._waiter = None - self._exception = None - - def exception(self): - return self._exception - - def set_exception(self, exc): - self._exception = exc - - waiter = self._waiter - if waiter is not None: - self._waiter = None - if not waiter.done(): - waiter.set_exception(exc) - - def feed_data(self, data): - self._buffer.append(data) - - waiter = self._waiter - if waiter is not None: - self._waiter = None - waiter.set_result(True) - - def feed_eof(self): - self._eof = True - - waiter = self._waiter - if waiter is not None: - self._waiter = None - waiter.set_result(False) - - @tasks.coroutine - def read(self): - if self._exception is not None: - raise self._exception - - if not self._buffer and not self._eof: - assert not self._waiter - self._waiter = futures.Future(loop=self._loop) - yield from self._waiter - - if self._buffer: - return self._buffer.popleft() - else: - raise EofStream - - -class ParserBuffer(bytearray): - """ParserBuffer is a bytearray extension. - - ParserBuffer provides helper methods for parsers. - """ - - def __init__(self, *args): - super().__init__(*args) - - self.offset = 0 - self.size = 0 - self._writer = self._feed_data() - next(self._writer) - - def _shrink(self): - if self.offset: - del self[:self.offset] - self.offset = 0 - self.size = len(self) - - def _feed_data(self): - while True: - chunk = yield - if chunk: - chunk_len = len(chunk) - self.size += chunk_len - self.extend(chunk) - - # shrink buffer - if (self.offset and len(self) > 5120): - self._shrink() - - def feed_data(self, data): - self._writer.send(data) - - def read(self, size): - """read() reads specified amount of bytes.""" - - while True: - if self.size >= size: - start, end = self.offset, self.offset + size - self.offset = end - self.size = self.size - size - return self[start:end] - - self._writer.send((yield)) - - def readsome(self, size=None): - """reads size of less amount of bytes.""" - - while True: - if self.size > 0: - if size is None or self.size < size: - size = self.size - - start, end = self.offset, self.offset + size - self.offset = end - self.size = self.size - size - - return self[start:end] - - self._writer.send((yield)) - - def readuntil(self, stop, limit=None, exc=ValueError): - assert isinstance(stop, bytes) and stop, \ - 'bytes is required: {!r}'.format(stop) - - stop_len = len(stop) - - while True: - pos = self.find(stop, self.offset) - if pos >= 0: - end = pos + stop_len - size = end - self.offset - if limit is not None and size > limit: - raise exc('Line is too long.') - - start, self.offset = self.offset, end - self.size = self.size - size - - return self[start:end] - else: - if limit is not None and self.size > limit: - raise exc('Line is too long.') - - self._writer.send((yield)) - - def skip(self, size): - """skip() skips specified amount of bytes.""" - - while self.size < size: - self._writer.send((yield)) - - self.size -= size - self.offset += size - - def skipuntil(self, stop): - """skipuntil() reads until `stop` bytes sequence.""" - assert isinstance(stop, bytes) and stop, \ - 'bytes is required: {!r}'.format(stop) - - stop_len = len(stop) - - while True: - stop_line = self.find(stop, self.offset) - if stop_line >= 0: - end = stop_line + stop_len - self.size = self.size - (end - self.offset) - self.offset = end - return - else: - self.size = 0 - self.offset = len(self) - 1 - - self._writer.send((yield)) - - def __bytes__(self): - return bytes(self[self.offset:]) - - -def lines_parser(limit=2**16, exc=ValueError): - """Lines parser. - - lines parser splits a bytes stream into a chunks of data, each chunk ends - with \n symbol.""" - out, buf = yield - - while True: - out.feed_data((yield from buf.readuntil(b'\n', limit, exc))) - - -def chunks_parser(size=8196): - """Chunks parser. - - chunks parser splits a bytes stream into a specified - size chunks of data.""" - out, buf = yield - - while True: - out.feed_data((yield from buf.read(size))) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index e73a1d7b..61001168 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -1,36 +1,23 @@ """Utilities shared by tests.""" -import cgi import collections import contextlib -import gc -import email.parser -import http.server -import json -import logging import io import unittest.mock import os -import re -import socket import sys import threading -import traceback import unittest import unittest.mock -import urllib.parse +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer try: import ssl except ImportError: # pragma: no cover ssl = None import tulip -import tulip.http -from tulip.http import client from tulip import base_events from tulip import events - -from tulip import base_events from tulip import selectors @@ -59,259 +46,56 @@ def run_once(loop): @contextlib.contextmanager -def run_test_server(loop, *, host='127.0.0.1', port=0, - use_ssl=False, router=None): - properties = {} - transports = [] - - class HttpServer: - - def __init__(self, host, port): - self.host = host - self.port = port - self.address = (host, port) - self._url = '{}://{}:{}'.format( - 'https' if use_ssl else 'http', host, port) - - def __getitem__(self, key): - return properties[key] - - def __setitem__(self, key, value): - properties[key] = value - - def url(self, *suffix): - return urllib.parse.urljoin( - self._url, '/'.join(str(s) for s in suffix)) - - class TestHttpServer(tulip.http.ServerHttpProtocol): - - def connection_made(self, transport): - transports.append(transport) - super().connection_made(transport) - - def handle_request(self, message, payload): - if properties.get('close', False): - return - - if properties.get('noresponse', False): - yield from tulip.sleep(99999) - - if router is not None: - body = bytearray() - try: - while True: - body.extend((yield from payload.read())) - except tulip.EofStream: - pass - - rob = router( - self, properties, - self.transport, message, bytes(body)) - rob.dispatch() - - else: - response = tulip.http.Response( - self.transport, 200, message.version) - - text = b'Test message' - response.add_header('Content-type', 'text/plain') - response.add_header('Content-length', str(len(text))) - response.send_headers() - response.write(text) - response.write_eof() - - if use_ssl: - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - keyfile = os.path.join(here, 'sample.key') - certfile = os.path.join(here, 'sample.crt') - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain(certfile, keyfile) - else: - sslcontext = None - - def run(loop, fut): - thread_loop = tulip.new_event_loop() - tulip.set_event_loop(thread_loop) - - socks = thread_loop.run_until_complete( - thread_loop.start_serving( - lambda: TestHttpServer(keep_alive=0.5), - host, port, ssl=sslcontext)) - - waiter = tulip.Future(loop=thread_loop) - loop.call_soon_threadsafe( - fut.set_result, (thread_loop, waiter, socks[0].getsockname())) - - try: - thread_loop.run_until_complete(waiter) - finally: - # call pending connection_made if present - run_briefly(thread_loop) - - # close opened trnsports - for tr in transports: - tr.close() - - run_briefly(thread_loop) # call close callbacks - - for s in socks: - thread_loop.stop_serving(s) - - thread_loop.stop() - thread_loop.close() - gc.collect() - - fut = tulip.Future(loop=loop) - server_thread = threading.Thread(target=run, args=(loop, fut)) - server_thread.start() +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass - thread_loop, waiter, addr = loop.run_until_complete(fut) + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server(host, port, app, + server_class, SilentWSGIRequestHandler) + httpd.address = httpd.server_address + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() try: - yield HttpServer(*addr) + yield httpd finally: - thread_loop.call_soon_threadsafe(waiter.set_result, None) + httpd.shutdown() server_thread.join() -class Router: - - _response_version = "1.1" - _responses = http.server.BaseHTTPRequestHandler.responses - - def __init__(self, srv, props, transport, message, payload): - # headers - self._headers = http.client.HTTPMessage() - for hdr, val in message.headers: - self._headers.add_header(hdr, val) - - self._srv = srv - self._props = props - self._transport = transport - self._method = message.method - self._uri = message.path - self._version = message.version - self._compression = message.compression - self._body = payload - - url = urllib.parse.urlsplit(self._uri) - self._path = url.path - self._query = url.query - - @staticmethod - def define(rmatch): - def wrapper(fn): - f_locals = sys._getframe(1).f_locals - mapping = f_locals.setdefault('_mapping', []) - mapping.append((re.compile(rmatch), fn.__name__)) - return fn - - return wrapper - - def dispatch(self): # pragma: no cover - for route, fn in self._mapping: - match = route.match(self._path) - if match is not None: - try: - return getattr(self, fn)(match) - except Exception: - out = io.StringIO() - traceback.print_exc(file=out) - self._response(500, out.getvalue()) - - return - - return self._response(self._start_response(404)) - - def _start_response(self, code): - return tulip.http.Response(self._transport, code) - - def _response(self, response, body=None, headers=None, chunked=False): - r_headers = {} - for key, val in self._headers.items(): - key = '-'.join(p.capitalize() for p in key.split('-')) - r_headers[key] = val - - encoding = self._headers.get('content-encoding', '').lower() - if 'gzip' in encoding: # pragma: no cover - cmod = 'gzip' - elif 'deflate' in encoding: - cmod = 'deflate' - else: - cmod = '' - - resp = { - 'method': self._method, - 'version': '%s.%s' % self._version, - 'path': self._uri, - 'headers': r_headers, - 'origin': self._transport.get_extra_info('addr', ' ')[0], - 'query': self._query, - 'form': {}, - 'compression': cmod, - 'multipart-data': [] - } - if body: # pragma: no cover - resp['content'] = body - - ct = self._headers.get('content-type', '').lower() - - # application/x-www-form-urlencoded - if ct == 'application/x-www-form-urlencoded': - resp['form'] = urllib.parse.parse_qs(self._body.decode('latin1')) - - # multipart/form-data - elif ct.startswith('multipart/form-data'): # pragma: no cover - out = io.BytesIO() - for key, val in self._headers.items(): - out.write(bytes('{}: {}\r\n'.format(key, val), 'latin1')) - - out.write(b'\r\n') - out.write(self._body) - out.write(b'\r\n') - out.seek(0) - - message = email.parser.BytesParser().parse(out) - if message.is_multipart(): - for msg in message.get_payload(): - if msg.is_multipart(): - logging.warn('multipart msg is not expected') - else: - key, params = cgi.parse_header( - msg.get('content-disposition', '')) - params['data'] = msg.get_payload() - params['content-type'] = msg.get_content_type() - resp['multipart-data'].append(params) - - body = json.dumps(resp, indent=4, sort_keys=True) - - # default headers - hdrs = [('Connection', 'close'), - ('Content-Type', 'application/json')] - if chunked: - hdrs.append(('Transfer-Encoding', 'chunked')) - else: - hdrs.append(('Content-Length', str(len(body)))) - - # extra headers - if headers: - hdrs.extend(headers.items()) - - if chunked: - response.force_chunked() - - # headers - response.add_headers(*hdrs) - response.send_headers() - - # write payload - response.write(client.str_to_bytes(body)) - response.write_eof() - - # keep-alive - if response.keep_alive(): - self._srv.keep_alive(True) - - def make_test_protocol(base): dct = {} for name in dir(base): From bbae286480ac9294b096b09bb40d268f4e822de3 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Mon, 7 Oct 2013 18:44:37 -0700 Subject: [PATCH 0643/1502] remove tulip.http from setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a19e3224..dcaee96f 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,6 @@ setup(name='tulip', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip', 'tulip.http'], + packages=['tulip'], ext_modules=extensions ) From 65a58c979821e6f3534dea9227a7c60892687a4e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Oct 2013 15:25:54 -0700 Subject: [PATCH 0644/1502] Change buffers to queues. --- tests/selector_events_test.py | 59 ++++++++++++++++++----------------- tulip/selector_events.py | 5 ++- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index f810f319..be624366 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -1,5 +1,6 @@ """Tests for selector_events.py""" +import collections import errno import gc import pprint @@ -610,13 +611,13 @@ def test_close_write_buffer(self): def test_force_close(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - tr._buffer = [b'1'] + tr._buffer.append(b'1') self.loop.add_reader(7, unittest.mock.sentinel) self.loop.add_writer(7, unittest.mock.sentinel) tr._force_close(None) self.assertTrue(tr._closing) - self.assertEqual(tr._buffer, []) + self.assertEqual(tr._buffer, collections.deque()) self.assertFalse(self.loop.readers) self.assertFalse(self.loop.writers) @@ -765,7 +766,7 @@ def test_write_no_data(self): transport._buffer.append(b'data') transport.write(b'') self.assertFalse(self.sock.send.called) - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) def test_write_buffer(self): transport = _SelectorSocketTransport( @@ -773,7 +774,7 @@ def test_write_buffer(self): transport._buffer.append(b'data1') transport.write(b'data2') self.assertFalse(self.sock.send.called) - self.assertEqual([b'data1', b'data2'], transport._buffer) + self.assertEqual(collections.deque([b'data1', b'data2']), transport._buffer) def test_write_partial(self): data = b'data' @@ -784,7 +785,7 @@ def test_write_partial(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'ta'], transport._buffer) + self.assertEqual(collections.deque([b'ta']), transport._buffer) def test_write_partial_none(self): data = b'data' @@ -796,7 +797,7 @@ def test_write_partial_none(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) def test_write_tryagain(self): self.sock.send.side_effect = BlockingIOError @@ -807,7 +808,7 @@ def test_write_tryagain(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) @unittest.mock.patch('tulip.selector_events.tulip_log') def test_write_exception(self, m_log): @@ -887,7 +888,7 @@ def test_write_ready_partial(self): self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'ta'], transport._buffer) + self.assertEqual(collections.deque([b'ta']), transport._buffer) def test_write_ready_partial_none(self): data = b'data' @@ -899,19 +900,19 @@ def test_write_ready_partial_none(self): self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) def test_write_ready_tryagain(self): self.sock.send.side_effect = BlockingIOError transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer = [b'data1', b'data2'] + transport._buffer = collections.deque([b'data1', b'data2']) self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual([b'data1data2'], transport._buffer) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) def test_write_ready_exception(self): err = self.sock.send.side_effect = OSError() @@ -951,7 +952,7 @@ def test_write_eof_buffer(self): self.sock.send.side_effect = BlockingIOError tr.write(b'data') tr.write_eof() - self.assertEqual(tr._buffer, [b'data']) + self.assertEqual(tr._buffer, collections.deque([b'data'])) self.assertTrue(tr._eof) self.assertFalse(self.sock.shutdown.called) self.sock.send.side_effect = lambda _: 4 @@ -1045,7 +1046,7 @@ def test_write_no_data(self): transport = self._make_one() transport._buffer.append(b'data') transport.write(b'') - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) def test_write_str(self): transport = self._make_one() @@ -1063,7 +1064,7 @@ def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 transport.write(b'data') - self.assertEqual(transport._buffer, []) + self.assertEqual(transport._buffer, collections.deque()) transport.write(b'data') transport.write(b'data') transport.write(b'data') @@ -1122,34 +1123,34 @@ def test_on_ready_send(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 4 transport = self._make_one() - transport._buffer = [b'data'] + transport._buffer = collections.deque([b'data']) transport._on_ready() - self.assertEqual([], transport._buffer) + self.assertEqual(collections.deque(), transport._buffer) self.assertTrue(self.sslsock.send.called) def test_on_ready_send_none(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 0 transport = self._make_one() - transport._buffer = [b'data1', b'data2'] + transport._buffer = collections.deque([b'data1', b'data2']) transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'data1data2'], transport._buffer) + self.assertEqual(collections.deque([b'data1data2']), transport._buffer) def test_on_ready_send_partial(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 2 transport = self._make_one() - transport._buffer = [b'data1', b'data2'] + transport._buffer = collections.deque([b'data1', b'data2']) transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'ta1data2'], transport._buffer) + self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) def test_on_ready_send_closing_partial(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError self.sslsock.send.return_value = 2 transport = self._make_one() - transport._buffer = [b'data1', b'data2'] + transport._buffer = collections.deque([b'data1', b'data2']) transport._on_ready() self.assertTrue(self.sslsock.send.called) self.assertFalse(self.sslsock.close.called) @@ -1159,7 +1160,7 @@ def test_on_ready_send_closing(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() - transport._buffer = [b'data'] + transport._buffer = collections.deque([b'data']) transport._on_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) @@ -1169,7 +1170,7 @@ def test_on_ready_send_closing_empty_buffer(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() - transport._buffer = [] + transport._buffer = collections.deque() transport._on_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) @@ -1178,31 +1179,31 @@ def test_on_ready_send_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError transport = self._make_one() - transport._buffer = [b'data'] + transport._buffer = collections.deque([b'data']) self.sslsock.send.side_effect = ssl.SSLWantReadError transport._on_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) self.sslsock.send.side_effect = ssl.SSLWantWriteError transport._on_ready() - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) self.sslsock.send.side_effect = BlockingIOError() transport._on_ready() - self.assertEqual([b'data'], transport._buffer) + self.assertEqual(collections.deque([b'data']), transport._buffer) def test_on_ready_send_exc(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError err = self.sslsock.send.side_effect = OSError() transport = self._make_one() - transport._buffer = [b'data'] + transport._buffer = collections.deque([b'data']) transport._fatal_error = unittest.mock.Mock() transport._on_ready() transport._fatal_error.assert_called_with(err) - self.assertEqual([], transport._buffer) + self.assertEqual(collections.deque(), transport._buffer) def test_write_eof(self): tr = self._make_one() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 92330e87..6f6a271c 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -329,7 +329,7 @@ def __init__(self, loop, sock, protocol, extra): self._sock = sock self._sock_fd = sock.fileno() self._protocol = protocol - self._buffer = [] + self._buffer = collections.deque() self._conn_lost = 0 self._closing = False # Set when close() called. @@ -579,7 +579,7 @@ def _on_ready(self): # Now try writing, if there's anything to write. if self._buffer: data = b''.join(self._buffer) - self._buffer = [] + self._buffer.clear() try: n = self._sock.send(data) except (BlockingIOError, InterruptedError, @@ -630,7 +630,6 @@ def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(loop, sock, protocol, extra) self._address = address - self._buffer = collections.deque() self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) From 22218e99bda53fe718d09ce67c533cfe679d1181 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Oct 2013 16:19:18 -0700 Subject: [PATCH 0645/1502] Add all_tasks() and variants, get_stack(), and print_stack(). --- tulip/tasks.py | 130 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) diff --git a/tulip/tasks.py b/tulip/tasks.py index e3814a64..22e506bc 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -10,6 +10,10 @@ import concurrent.futures import functools import inspect +import linecache +import sys +import traceback +import weakref from . import events from . import futures @@ -62,6 +66,54 @@ class Task(futures.Future): # _wakeup(). When _fut_waiter is not None, one of its callbacks # must be _wakeup(). + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + @classmethod + def all_tasks(cls): + """Return a set of all tasks in existence.""" + return set(cls._all_tasks) + + @classmethod + def all_pending_tasks(cls): + """Return a set of all tasks in existence that aren't done yet.""" + return {t for t in cls._all_tasks if not t.done()} + + @classmethod + def all_done_tasks(cls): + """Return a set of all tasks in existence that are done. + + This is the union of all_successful_tasks() and all_failed_tasks(). + """ + return {t for t in cls._all_tasks if t.done()} + + @classmethod + def all_successful_tasks(cls): + """Return a set of all tasks in existence that have a valid result.""" + return {t for t in cls._all_tasks + if t.done() and not t.cancelled() and t.exception() is None} + + @classmethod + def all_failed_tasks(cls): + """Return a set of all tasks in existence that have failed. + + This is the union of all_excepted_tasks() and all_cancelled_tasks(). + """ + return {t for t in cls._all_tasks + if t.done() and (t.cancelled() or t.exception())} + + @classmethod + def all_excepted_tasks(cls): + """Return a set of all tasks in existence that have an exception.""" + return {t for t in cls._all_tasks + if t.done() and not t.cancelled() and + t.exception() is not None} + + @classmethod + def all_cancelled_tasks(cls): + """Return a set of all tasks in existence that were cancelled.""" + return {t for t in cls._all_tasks if t.cancelled()} + def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. super().__init__(loop=loop) @@ -69,6 +121,7 @@ def __init__(self, coro, *, loop=None): self._fut_waiter = None self._must_cancel = False self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) def __repr__(self): res = super().__repr__() @@ -82,6 +135,83 @@ def __repr__(self): res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:] return res + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is active, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum nummber of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output goes; by default it goes to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + def cancel(self): if self.done(): return False From 1e441e949f2ad99bc75f36e7bcd19950a4d80647 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Oct 2013 16:23:20 -0700 Subject: [PATCH 0646/1502] Fold some long lines that crept in. --- t.py | 60 +++++++++++++++++++++++++++++++++++ tests/selector_events_test.py | 3 +- tulip/proactor_events.py | 3 +- 3 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 t.py diff --git a/t.py b/t.py new file mode 100644 index 00000000..89518098 --- /dev/null +++ b/t.py @@ -0,0 +1,60 @@ +from tulip import * + + +@coroutine +def helper(r): + print('--- helper ---') + for t in Task.all_tasks(): + print('[[[') + t.print_stack() + print(']]]') + print('--- end helper ---') + print(Task.all_pending_tasks()) + print(Task.all_done_tasks()) + print(Task.all_successful_tasks()) + print(Task.all_failed_tasks()) + print(Task.all_excepted_tasks()) + line = yield from r.readline() + 1/0 + return line + +def doit(): + l = get_event_loop() + lr = l.run_until_complete + r, w = lr(open_connection('python.org', 80)) + t1 = async(helper(r)) + for t in Task.all_tasks(): t.print_stack() + print(Task.all_pending_tasks()) + print(Task.all_done_tasks()) + print(Task.all_successful_tasks()) + print(Task.all_failed_tasks()) + print(Task.all_excepted_tasks()) + print('---') + l._run_once() + for t in Task.all_tasks(): t.print_stack() + print('---') + w.write(b'GET /\r\n') + w.write_eof() + try: + lr(t1) + except Exception as e: + print('catching', e) + finally: + print(Task.all_pending_tasks()) + print(Task.all_done_tasks()) + print(Task.all_successful_tasks()) + print(Task.all_failed_tasks()) + print(Task.all_excepted_tasks()) + print(Task.all_cancelled_tasks()) + for t in Task.all_tasks(): + print('[[[') + t.print_stack() + print(']]]') + + +def main(): + doit() + + +if __name__ == '__main__': + main() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index be624366..8926b2da 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -774,7 +774,8 @@ def test_write_buffer(self): transport._buffer.append(b'data1') transport.write(b'data2') self.assertFalse(self.sock.send.called) - self.assertEqual(collections.deque([b'data1', b'data2']), transport._buffer) + self.assertEqual(collections.deque([b'data1', b'data2']), + transport._buffer) def test_write_partial(self): data = b'data' diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index c7b524ea..79d2d094 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -237,7 +237,8 @@ def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): - return _ProactorDuplexPipeTransport(self, sock, protocol, waiter, extra) + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) def _make_read_pipe_transport(self, sock, protocol, waiter=None, extra=None): From 4c5456f861db9ac964c0e1777acf4a4bdf5f263a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 8 Oct 2013 16:26:29 -0700 Subject: [PATCH 0647/1502] Move stack example. --- t.py => examples/stacks.py | 23 +++-------------------- 1 file changed, 3 insertions(+), 20 deletions(-) rename t.py => examples/stacks.py (52%) diff --git a/t.py b/examples/stacks.py similarity index 52% rename from t.py rename to examples/stacks.py index 89518098..77a99cf5 100644 --- a/t.py +++ b/examples/stacks.py @@ -1,3 +1,6 @@ +"""Crude demo for print_stack().""" + + from tulip import * @@ -5,15 +8,8 @@ def helper(r): print('--- helper ---') for t in Task.all_tasks(): - print('[[[') t.print_stack() - print(']]]') print('--- end helper ---') - print(Task.all_pending_tasks()) - print(Task.all_done_tasks()) - print(Task.all_successful_tasks()) - print(Task.all_failed_tasks()) - print(Task.all_excepted_tasks()) line = yield from r.readline() 1/0 return line @@ -24,11 +20,6 @@ def doit(): r, w = lr(open_connection('python.org', 80)) t1 = async(helper(r)) for t in Task.all_tasks(): t.print_stack() - print(Task.all_pending_tasks()) - print(Task.all_done_tasks()) - print(Task.all_successful_tasks()) - print(Task.all_failed_tasks()) - print(Task.all_excepted_tasks()) print('---') l._run_once() for t in Task.all_tasks(): t.print_stack() @@ -40,16 +31,8 @@ def doit(): except Exception as e: print('catching', e) finally: - print(Task.all_pending_tasks()) - print(Task.all_done_tasks()) - print(Task.all_successful_tasks()) - print(Task.all_failed_tasks()) - print(Task.all_excepted_tasks()) - print(Task.all_cancelled_tasks()) for t in Task.all_tasks(): - print('[[[') t.print_stack() - print(']]]') def main(): From dfd45ef4ee5031ca1e9dd0b1a1c79953df8163b4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 11:35:52 -0700 Subject: [PATCH 0648/1502] A succession of ever more sophisticated HTTP clients. (More to come.) --- examples/fetch0.py | 32 +++++++++++++ examples/fetch1.py | 74 ++++++++++++++++++++++++++++ examples/fetch2.py | 117 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 examples/fetch0.py create mode 100644 examples/fetch1.py create mode 100644 examples/fetch2.py diff --git a/examples/fetch0.py b/examples/fetch0.py new file mode 100644 index 00000000..84edaa26 --- /dev/null +++ b/examples/fetch0.py @@ -0,0 +1,32 @@ +"""Simplest possible HTTP client.""" + +import sys + +from tulip import * + + +@coroutine +def fetch(): + r, w = yield from open_connection('python.org', 80) + request = 'GET / HTTP/1.0\r\n\r\n' + print('>', request, file=sys.stderr) + w.write(request.encode('latin-1')) + while True: + line = yield from r.readline() + line = line.decode('latin-1').rstrip() + if not line: + break + print('<', line, file=sys.stderr) + print(file=sys.stderr) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch()) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch1.py b/examples/fetch1.py new file mode 100644 index 00000000..d7d23507 --- /dev/null +++ b/examples/fetch1.py @@ -0,0 +1,74 @@ +"""Fetch one URL and write its content to stdout. + +This version adds URL parsing (including SSL) and a Response object. +""" + +import sys +import urllib.parse + +from tulip import * + + +class Response: + + def __init__(self, verbose=True): + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def read(self, reader): + @coroutine + def getline(): + return (yield from reader.readline()).decode('latin-1').rstrip() + status_line = yield from getline() + if self.verbose: print('<', status_line, file=sys.stderr) + self.http_version, status, self.reason = status_line.split(None, 2) + self.status = int(status) + while True: + header_line = yield from getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) # TODO: Continuation lines. + if self.verbose: print(file=sys.stderr) + + +@coroutine +def fetch(url, verbose=True): + parts = urllib.parse.urlparse(url) + if parts.scheme == 'http': + ssl = False + elif parts.scheme == 'https': + ssl = True + else: + print('URL must use http or https.') + sys.exit(1) + port = parts.port + if port is None: + port = 443 if ssl else 80 + path = parts.path or '/' + if parts.query: + path += '?' + parts.query + request = 'GET %s HTTP/1.0\r\n\r\n' % path + if verbose: + print('>', request, file=sys.stderr, end='') + r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + w.write(request.encode('latin-1')) + response = Response(verbose) + yield from response.read(r) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch2.py b/examples/fetch2.py new file mode 100644 index 00000000..cd9bf34c --- /dev/null +++ b/examples/fetch2.py @@ -0,0 +1,117 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a Request object. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.0' + self.method = 'GET' + self.headers = [] + + @coroutine + def connect(self): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + self.reader, self.writer = yield from open_connection(self.hostname, self.port, ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % (self.writer.get_extra_info('socket').getpeername(),)) + + def putline(self, line): + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + if self.verbose: print('>', request, file=sys.stderr) + self.putline(request) + for key, value in self.headers: + self.putline('%s: %s' % (key, value)) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + return (yield from self.reader.readline()).decode('latin-1').rstrip() + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + if self.verbose: print('<', status_line, file=sys.stderr) + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) # TODO: Continuation lines. + if self.verbose: print(file=sys.stderr) + + @coroutine + def read(self): + body = yield from self.reader.read() + return body + + +@coroutine +def fetch(url, verbose=True): + request = Request(url, verbose) + yield from request.connect() + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() From 2ce868c52f59cce62d15d06c024b12ed4e519430 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 12:10:19 -0700 Subject: [PATCH 0649/1502] Use HTTP/1.1; couple of fixes/refinements. --- examples/fetch2.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/examples/fetch2.py b/examples/fetch2.py index cd9bf34c..e9a746b0 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -19,6 +19,7 @@ def __init__(self, url, verbose=True): self.scheme = self.parts.scheme assert self.scheme in ('http', 'https'), repr(url) self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc self.hostname = self.parts.hostname self.port = self.parts.port or (443 if self.ssl else 80) self.path = (self.parts.path or '/') @@ -27,18 +28,23 @@ def __init__(self, url, verbose=True): self.full_path = '%s?%s' % (self.path, self.query) else: self.full_path = self.path - self.http_version = 'HTTP/1.0' + self.http_version = 'HTTP/1.1' self.method = 'GET' self.headers = [] + self.reader = None + self.writer = None @coroutine def connect(self): if self.verbose: print('* Connecting to %s:%s using %s' % - (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) self.reader, self.writer = yield from open_connection(self.hostname, self.port, ssl=self.ssl) if self.verbose: - print('* Connected to %s' % (self.writer.get_extra_info('socket').getpeername(),)) + print('* Connected to %s' % + (self.writer.get_extra_info('socket').getpeername(),), + file=sys.stderr) def putline(self, line): self.writer.write(line.encode('latin-1') + b'\r\n') @@ -48,8 +54,12 @@ def send_request(self): request = '%s %s %s' % (self.method, self.full_path, self.http_version) if self.verbose: print('>', request, file=sys.stderr) self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) for key, value in self.headers: - self.putline('%s: %s' % (key, value)) + line = '%s: %s' % (key, value) + if self.verbose: print('>', line, file=sys.stderr) + self.putline(line) self.putline('') @coroutine @@ -93,7 +103,15 @@ def read_headers(self): @coroutine def read(self): - body = yield from self.reader.read() + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) return body @@ -110,7 +128,7 @@ def fetch(url, verbose=True): def main(): loop = get_event_loop() body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) - print(body.decode('latin-1'), end='') + sys.stdout.buffer.write(body) if __name__ == '__main__': From 7b2a2e71c2f31c6d0d5433fd69f70a8ab575f395 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 12:55:03 -0700 Subject: [PATCH 0650/1502] Another fetch example, using a rudimentary connection pool. --- examples/fetch3.py | 169 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 examples/fetch3.py diff --git a/examples/fetch3.py b/examples/fetch3.py new file mode 100644 index 00000000..2ae79d8a --- /dev/null +++ b/examples/fetch3.py @@ -0,0 +1,169 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a primitive connection pool and redirect following. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class ConnectionPool: + # TODO: Locking? Close idle connections? + + def __init__(self): + self.connections = {} # {(host, port, ssl): (reader, writer)} + + @coroutine + def open_connection(self, host, port, ssl): + port = port or (443 if ssl else 80) + key = (host, port, ssl) + if key in self.connections: + return self.connections[key] + reader, writer = yield from open_connection(host, port, ssl=ssl) + self.connections[key] = reader, writer + return reader, writer + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @coroutine + def connect(self, pool): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) + self.reader, self.writer = yield from pool.open_connection(self.hostname, self.port, ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % + (self.writer.get_extra_info('socket').getpeername(),), + file=sys.stderr) + + def putline(self, line): + if self.verbose: + print('>', line, file=sys.stderr) + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + self.putline(line) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + if self.verbose: print('<', line, file=sys.stderr) + return line + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) # TODO: Continuation lines. + + def get_redirect_url(self, default=None): + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location') or self.get_header('URI') or default + + def get_header(self, key, default=None): + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True, max_redirect=10): + pool = ConnectionPool() + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() From 73e1ef9f85ce0db47acb56f1178522f282056df6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 13:07:23 -0700 Subject: [PATCH 0651/1502] There is no URI header. --- examples/fetch3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index 2ae79d8a..d9fb6b22 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -120,7 +120,7 @@ def read_headers(self): def get_redirect_url(self, default=None): if self.status not in (300, 301, 302, 303, 307): return default - return self.get_header('Location') or self.get_header('URI') or default + return self.get_header('Location', default) def get_header(self, key, default=None): key = key.lower() From 0f981e3f947696ed15045b1cfb5b92c64c16ea8e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 13:54:09 -0700 Subject: [PATCH 0652/1502] Add chunked support. Refactor verbose printing. --- examples/fetch3.py | 47 +++++++++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index d9fb6b22..054141b4 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -1,6 +1,7 @@ """Fetch one URL and write its content to stdout. -This version adds a primitive connection pool and redirect following. +This version adds a primitive connection pool, redirect following and +chunked transfer-encoding. """ import sys @@ -51,21 +52,20 @@ def __init__(self, url, verbose=True): self.reader = None self.writer = None + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + @coroutine def connect(self, pool): - if self.verbose: - print('* Connecting to %s:%s using %s' % - (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), - file=sys.stderr) + self.vprint('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) self.reader, self.writer = yield from pool.open_connection(self.hostname, self.port, ssl=self.ssl) - if self.verbose: - print('* Connected to %s' % - (self.writer.get_extra_info('socket').getpeername(),), - file=sys.stderr) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('socket').getpeername(),)) def putline(self, line): - if self.verbose: - print('>', line, file=sys.stderr) + self.vprint('>', line) self.writer.write(line.encode('latin-1') + b'\r\n') @coroutine @@ -96,10 +96,14 @@ def __init__(self, reader, verbose=True): self.reason = None # 'Ok' self.headers = [] # [('Content-Type', 'text/html')] + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + @coroutine def getline(self): line = (yield from self.reader.readline()).decode('latin-1').rstrip() - if self.verbose: print('<', line, file=sys.stderr) + self.vprint('<', line) return line @coroutine @@ -137,7 +141,24 @@ def read(self): nbytes = int(value) break if nbytes is None: - body = yield from self.reader.read() + if self.get_header('transfer-encoding').lower() == 'chunked': + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + break + parts = size_header.split(b';') + size = int(parts[0], 16) + if not size: + break + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n' + body = b''.join(blocks) + else: + body = self.reader.read() else: body = yield from self.reader.readexactly(nbytes) return body From 2b8b14bab4689e865bdb6c77de16ae245d75721a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 13:56:56 -0700 Subject: [PATCH 0653/1502] Fix a few bugs. --- examples/fetch3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index 054141b4..91ac0e47 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -141,7 +141,7 @@ def read(self): nbytes = int(value) break if nbytes is None: - if self.get_header('transfer-encoding').lower() == 'chunked': + if self.get_header('transfer-encoding', '').lower() == 'chunked': blocks = [] while True: size_header = yield from self.reader.readline() @@ -158,7 +158,7 @@ def read(self): assert crlf == b'\r\n' body = b''.join(blocks) else: - body = self.reader.read() + body = yield from self.reader.read() else: body = yield from self.reader.readexactly(nbytes) return body From 1c296c300cebfb44ec7cc097574533a1610069ae Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 14:25:14 -0700 Subject: [PATCH 0654/1502] Fix connection pool. --- examples/fetch3.py | 24 +++++++++++++++++++----- tulip/protocols.py | 23 +++++++++++++++++++++++ tulip/transports.py | 21 +++++++++++++++++++++ 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index 91ac0e47..26d6102e 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -14,17 +14,31 @@ class ConnectionPool: # TODO: Locking? Close idle connections? - def __init__(self): + def __init__(self, verbose=False): + self.verbose = verbose self.connections = {} # {(host, port, ssl): (reader, writer)} @coroutine def open_connection(self, host, port, ssl): port = port or (443 if ssl else 80) - key = (host, port, ssl) - if key in self.connections: - return self.connections[key] + ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + if self.verbose: + print('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs)), + file=sys.stderr) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = self.connections.get(key) + if conn: + if self.verbose: + print('* Reusing pooled connection', key, file=sys.stderr) + return conn reader, writer = yield from open_connection(host, port, ssl=ssl) + host, port, *_ = writer.get_extra_info('socket').getpeername() + key = host, port, ssl self.connections[key] = reader, writer + if self.verbose: + print('* New connection', key, file=sys.stderr) return reader, writer @@ -166,7 +180,7 @@ def read(self): @coroutine def fetch(url, verbose=True, max_redirect=10): - pool = ConnectionPool() + pool = ConnectionPool(verbose) for _ in range(max_redirect): request = Request(url, verbose) yield from request.connect(pool) diff --git a/tulip/protocols.py b/tulip/protocols.py index d76f25a2..7b741b76 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -29,6 +29,29 @@ def connection_lost(self, exc): aborted or closed). """ + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + class Protocol(BaseProtocol): """ABC representing a protocol. diff --git a/tulip/transports.py b/tulip/transports.py index f6eb2820..cb8bf787 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -49,6 +49,27 @@ def resume(self): class WriteTransport(BaseTransport): """ABC for write-only transports.""" + def set_buffer_limits(self, high, low=None): + """Set the high- and low-water limits for flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The initial defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + def write(self, data): """Write some data bytes to the transport. From 0cb9b1be0c87dffc900c3fd4060c131a7456480c Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Wed, 9 Oct 2013 14:54:46 -0700 Subject: [PATCH 0655/1502] do not shield cancel in wait() --- tests/tasks_test.py | 8 ++++---- tulip/tasks.py | 18 +++++++++++++----- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 7aa32a09..396e4f7f 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -1142,7 +1142,7 @@ def outer(): self.assertEqual(proof, 101) self.assertTrue(waiter.cancelled()) - def test_yield_wait_shields_cancel(self): + def test_yield_wait_does_not_shield_cancel(self): # Cancelling outer() makes wait() return early, leaves inner() # running. proof = 0 @@ -1163,10 +1163,11 @@ def outer(): f = tasks.async(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() - self.loop.run_until_complete(f) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, f) waiter.set_result(None) test_utils.run_briefly(self.loop) - self.assertEqual(proof, 101) + self.assertEqual(proof, 1) def test_shield_result(self): inner = futures.Future(loop=self.loop) @@ -1256,7 +1257,6 @@ def test_gather_shield(self): test_utils.run_briefly(self.loop) - class GatherTestsBase: def setUp(self): diff --git a/tulip/tasks.py b/tulip/tasks.py index 22e506bc..dc9d2f3c 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -11,7 +11,6 @@ import functools import inspect import linecache -import sys import traceback import weakref @@ -358,6 +357,11 @@ def wait_for(fut, timeout, *, loop=None): raise futures.TimeoutError() +def _waiter_timeout(waiter): + if not waiter.done(): + waiter.set_result(False) + + @coroutine def _wait(fs, timeout, return_when, loop): """Internal helper for wait() and _wait_for(). @@ -368,7 +372,7 @@ def _wait(fs, timeout, return_when, loop): waiter = futures.Future(loop=loop) timeout_handle = None if timeout is not None: - timeout_handle = loop.call_later(timeout, waiter.cancel) + timeout_handle = loop.call_later(timeout, _waiter_timeout, waiter) counter = len(fs) def _on_completion(f): @@ -380,14 +384,18 @@ def _on_completion(f): f.exception() is not None)): if timeout_handle is not None: timeout_handle.cancel() - waiter.cancel() + if not waiter.done(): + waiter.set_result(False) for f in fs: f.add_done_callback(_on_completion) + try: yield from waiter - except futures.CancelledError: - pass + finally: + if timeout_handle is not None: + timeout_handle.cancel() + done, pending = set(), set() for f in fs: f.remove_done_callback(_on_completion) From 5d82c82e4c3ce5e684d8680a48a40ad14c8dd1fc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 16:46:14 -0700 Subject: [PATCH 0656/1502] Remove experimental code that was accidentally committed. --- tulip/protocols.py | 23 ----------------------- tulip/transports.py | 21 --------------------- 2 files changed, 44 deletions(-) diff --git a/tulip/protocols.py b/tulip/protocols.py index 7b741b76..d76f25a2 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -29,29 +29,6 @@ def connection_lost(self, exc): aborted or closed). """ - def pause_writing(self): - """Called when the transport's buffer goes over the high-water mark. - - Pause and resume calls are paired -- pause_writing() is called - once when the buffer goes strictly over the high-water mark - (even if subsequent writes increases the buffer size even - more), and eventually resume_writing() is called once when the - buffer size reaches the low-water mark. - - Note that if the buffer size equals the high-water mark, - pause_writing() is not called -- it must go strictly over. - Conversely, resume_writing() is called when the buffer size is - equal or lower than the low-water mark. These end conditions - are important to ensure that things go as expected when either - mark is zero. - """ - - def resume_writing(self): - """Called when the transport's buffer drains below the low-water mark. - - See pause_writing() for details. - """ - class Protocol(BaseProtocol): """ABC representing a protocol. diff --git a/tulip/transports.py b/tulip/transports.py index cb8bf787..f6eb2820 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -49,27 +49,6 @@ def resume(self): class WriteTransport(BaseTransport): """ABC for write-only transports.""" - def set_buffer_limits(self, high, low=None): - """Set the high- and low-water limits for flow control. - - These two values control when to call the protocol's - pause_writing() and resume_writing() methods. If specified, - the low-water limit must be less than or equal to the - high-water limit. Neither value can be negative. - - The initial defaults are implementation-specific. If only the - high-water limit is given, the low-water limit defaults to a - implementation-specific value less than or equal to the - high-water limit. Setting high to zero forces low to zero as - well, and causes pause_writing() to be called whenever the - buffer becomes non-empty. Setting low to zero causes - resume_writing() to be called only once the buffer is empty. - Use of zero for either limit is generally sub-optimal as it - reduces opportunities for doing I/O and computation - concurrently. - """ - raise NotImplementedError - def write(self, data): """Write some data bytes to the transport. From 9abf38db2ecbc856c13af911f9022e22fb5e8925 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Oct 2013 23:35:09 -0700 Subject: [PATCH 0657/1502] Fold long lines. --- examples/fetch1.py | 3 ++- examples/fetch2.py | 7 +++++-- examples/fetch3.py | 16 +++++++++++----- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/fetch1.py b/examples/fetch1.py index d7d23507..57e66e6a 100644 --- a/examples/fetch1.py +++ b/examples/fetch1.py @@ -32,8 +32,9 @@ def getline(): if not header_line: break if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. key, value = header_line.split(':', 1) - self.headers.append((key, value.strip())) # TODO: Continuation lines. + self.headers.append((key, value.strip())) if self.verbose: print(file=sys.stderr) diff --git a/examples/fetch2.py b/examples/fetch2.py index e9a746b0..2ea1c695 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -40,7 +40,9 @@ def connect(self): print('* Connecting to %s:%s using %s' % (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), file=sys.stderr) - self.reader, self.writer = yield from open_connection(self.hostname, self.port, ssl=self.ssl) + self.reader, self.writer = yield from open_connection(self.hostname, + self.port, + ssl=self.ssl) if self.verbose: print('* Connected to %s' % (self.writer.get_extra_info('socket').getpeername(),), @@ -97,8 +99,9 @@ def read_headers(self): if not header_line: break if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. key, value = header_line.split(':', 1) - self.headers.append((key, value.strip())) # TODO: Continuation lines. + self.headers.append((key, value.strip())) if self.verbose: print(file=sys.stderr) @coroutine diff --git a/examples/fetch3.py b/examples/fetch3.py index 26d6102e..ea5e298d 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -74,24 +74,29 @@ def vprint(self, *args): def connect(self, pool): self.vprint('* Connecting to %s:%s using %s' % (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) - self.reader, self.writer = yield from pool.open_connection(self.hostname, self.port, ssl=self.ssl) + self.reader, self.writer = \ + yield from pool.open_connection(self.hostname, + self.port, + ssl=self.ssl) self.vprint('* Connected to %s' % (self.writer.get_extra_info('socket').getpeername(),)) + @coroutine def putline(self, line): self.vprint('>', line) self.writer.write(line.encode('latin-1') + b'\r\n') + yield from self.writer.drain() @coroutine def send_request(self): request = '%s %s %s' % (self.method, self.full_path, self.http_version) - self.putline(request) + yield from self.putline(request) if 'host' not in {key.lower() for key, _ in self.headers}: self.headers.insert(0, ('Host', self.netloc)) for key, value in self.headers: line = '%s: %s' % (key, value) - self.putline(line) - self.putline('') + yield from self.putline(line) + yield from self.putline('') @coroutine def get_response(self): @@ -132,8 +137,9 @@ def read_headers(self): header_line = yield from self.getline() if not header_line: break + # TODO: Continuation lines. key, value = header_line.split(':', 1) - self.headers.append((key, value.strip())) # TODO: Continuation lines. + self.headers.append((key, value.strip())) def get_redirect_url(self, default=None): if self.status not in (300, 301, 302, 303, 307): From 50b67e03dc61b50401d0ab65d53d7d24ec1a0bdc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 09:25:58 -0700 Subject: [PATCH 0658/1502] Keep only all_tasks() and make it specific to an event loop. --- tulip/tasks.py | 47 ++++++----------------------------------------- 1 file changed, 6 insertions(+), 41 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index dc9d2f3c..af9330d4 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -69,49 +69,14 @@ class Task(futures.Future): _all_tasks = weakref.WeakSet() @classmethod - def all_tasks(cls): - """Return a set of all tasks in existence.""" - return set(cls._all_tasks) + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. - @classmethod - def all_pending_tasks(cls): - """Return a set of all tasks in existence that aren't done yet.""" - return {t for t in cls._all_tasks if not t.done()} - - @classmethod - def all_done_tasks(cls): - """Return a set of all tasks in existence that are done. - - This is the union of all_successful_tasks() and all_failed_tasks(). - """ - return {t for t in cls._all_tasks if t.done()} - - @classmethod - def all_successful_tasks(cls): - """Return a set of all tasks in existence that have a valid result.""" - return {t for t in cls._all_tasks - if t.done() and not t.cancelled() and t.exception() is None} - - @classmethod - def all_failed_tasks(cls): - """Return a set of all tasks in existence that have failed. - - This is the union of all_excepted_tasks() and all_cancelled_tasks(). + By default all tasks for the current event loop are returned. """ - return {t for t in cls._all_tasks - if t.done() and (t.cancelled() or t.exception())} - - @classmethod - def all_excepted_tasks(cls): - """Return a set of all tasks in existence that have an exception.""" - return {t for t in cls._all_tasks - if t.done() and not t.cancelled() and - t.exception() is not None} - - @classmethod - def all_cancelled_tasks(cls): - """Return a set of all tasks in existence that were cancelled.""" - return {t for t in cls._all_tasks if t.cancelled()} + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} def __init__(self, coro, *, loop=None): assert inspect.isgenerator(coro) # Must be a coroutine *object*. From 9e69d9a4742f7d779d38f83de4a5e9f5d0e449c9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 13:53:56 -0700 Subject: [PATCH 0659/1502] Handle BrokenPipeError. Share max_size. Add source/sink examples. --- examples/sink.py | 39 ++++++++++++++++++++++++++++++++++ examples/source.py | 46 ++++++++++++++++++++++++++++++++++++++++ tulip/selector_events.py | 17 +++++++++++---- 3 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 examples/sink.py create mode 100644 examples/source.py diff --git a/examples/sink.py b/examples/sink.py new file mode 100644 index 00000000..ec78c3bb --- /dev/null +++ b/examples/sink.py @@ -0,0 +1,39 @@ +"""Test server that accepts connections and reads all data off them.""" + +import sys + +from tulip import * + +def dprint(*args): + print('sink:', *args, file=sys.stderr) + +class Server(Protocol): + + def connection_made(self, tr): + dprint('connection from', tr.get_extra_info('addr')) + self.tr = tr + self.total = 0 + + def data_received(self, data): + self.total += len(data) + dprint('received', len(data), 'bytes; total', self.total) + if self.total > 1e6: + dprint('closing due to too much data') + self.tr.close() + + def connection_lost(self, how): + dprint('closed', repr(how)) + +@coroutine +def start(loop): + ss = yield from loop.start_serving(Server, 'localhost', 1111) + return ss + +def main(): + loop = get_event_loop() + ss = loop.run_until_complete(start(loop)) + dprint('serving', [s.getsockname() for s in ss]) + loop.run_forever() + +if __name__ == '__main__': + main() diff --git a/examples/source.py b/examples/source.py new file mode 100644 index 00000000..98e63502 --- /dev/null +++ b/examples/source.py @@ -0,0 +1,46 @@ +"""Test client that connects and sends infinite data.""" + +import sys + +from tulip import * + +def dprint(*args): + print('source:', *args, file=sys.stderr) + +class Client(Protocol): + + data = b'x'*16*1024 + + def connection_made(self, tr): + dprint('connecting to', tr.get_extra_info('addr')) + self.tr = tr + self.lost = False + self.loop = get_event_loop() + self.waiter = Future() + self.write_some_data() + + def write_some_data(self): + dprint('writing', len(self.data), 'bytes') + self.tr.write(self.data) + if not self.lost: + self.loop.call_soon(self.write_some_data) + + def connection_lost(self, exc): + dprint('lost connection', repr(exc)) + self.lost = True + self.waiter.set_result(None) + +@coroutine +def start(loop): + tr, pr = yield from loop.create_connection(Client, 'localhost', 1111) + dprint('tr =', tr) + dprint('pr =', pr) + res = yield from pr.waiter + return res + +def main(): + loop = get_event_loop() + loop.run_until_complete(start(loop)) + +if __name__ == '__main__': + main() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 6f6a271c..cb22a9ff 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -322,6 +322,8 @@ def stop_serving(self, sock): class _SelectorTransport(transports.Transport): + max_size = 256 * 1024 # Buffer size passed to recv(). + def __init__(self, loop, sock, protocol, extra): super().__init__(extra) self._extra['socket'] = sock @@ -400,7 +402,7 @@ def resume(self): def _read_ready(self): try: - data = self._sock.recv(16*1024) + data = self._sock.recv(self.max_size) except (BlockingIOError, InterruptedError): pass except ConnectionResetError as exc: @@ -434,6 +436,9 @@ def write(self, data): n = self._sock.send(data) except (BlockingIOError, InterruptedError): n = 0 + except BrokenPipeError as exc: + self._force_close(exc) + return except OSError as exc: self._fatal_error(exc) return @@ -455,6 +460,8 @@ def _write_ready(self): n = self._sock.send(data) except (BlockingIOError, InterruptedError): self._buffer.append(data) + except BrokenPipeError as exc: + self._force_close(exc) except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -559,7 +566,7 @@ def _on_ready(self): # First try reading. if not self._closing and not self._paused: try: - data = self._sock.recv(8192) + data = self._sock.recv(self.max_size) except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): pass @@ -585,6 +592,10 @@ def _on_ready(self): except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): n = 0 + except BrokenPipeError as exc: + self._loop.remove_writer(self._sock_fd) + self._force_close(exc) + return except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -624,8 +635,6 @@ def close(self): class _SelectorDatagramTransport(_SelectorTransport): - max_size = 256 * 1024 # max bytes we read in one eventloop iteration - def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(loop, sock, protocol, extra) From 84acb6200e4fcdcf10e8c7662d47d8fb9bff02ab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 14:00:09 -0700 Subject: [PATCH 0660/1502] Treat ConnectionResetError the same as BrokenPipeError. --- tulip/selector_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index cb22a9ff..9d2c0c81 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -436,7 +436,7 @@ def write(self, data): n = self._sock.send(data) except (BlockingIOError, InterruptedError): n = 0 - except BrokenPipeError as exc: + except (BrokenPipeError, ConnectionResetError) as exc: self._force_close(exc) return except OSError as exc: @@ -460,7 +460,7 @@ def _write_ready(self): n = self._sock.send(data) except (BlockingIOError, InterruptedError): self._buffer.append(data) - except BrokenPipeError as exc: + except (BrokenPipeError, ConnectionResetError) as exc: self._force_close(exc) except Exception as exc: self._loop.remove_writer(self._sock_fd) @@ -592,7 +592,7 @@ def _on_ready(self): except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): n = 0 - except BrokenPipeError as exc: + except (BrokenPipeError, ConnectionResetError) as exc: self._loop.remove_writer(self._sock_fd) self._force_close(exc) return From 00fd8b13e36c5e7020b7769f6dcd1afecfb79fb3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 14:21:16 -0700 Subject: [PATCH 0661/1502] Fix logic around lost connection. --- examples/source.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/source.py b/examples/source.py index 98e63502..3b9e04ab 100644 --- a/examples/source.py +++ b/examples/source.py @@ -20,10 +20,12 @@ def connection_made(self, tr): self.write_some_data() def write_some_data(self): + if self.lost: + dprint('lost already') + return dprint('writing', len(self.data), 'bytes') self.tr.write(self.data) - if not self.lost: - self.loop.call_soon(self.write_some_data) + self.loop.call_soon(self.write_some_data) def connection_lost(self, exc): dprint('lost connection', repr(exc)) From 920e482862b0d9f0cfc2d0f0d58bed2f277efd91 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 15:14:31 -0700 Subject: [PATCH 0662/1502] First crack at create_server(). --- examples/sink.py | 8 ++-- tests/base_events_test.py | 20 +++++----- tests/events_test.py | 71 +++++++++++++++++------------------ tests/proactor_events_test.py | 8 ++-- tulip/base_events.py | 20 ++++++++-- tulip/events.py | 24 +++++++----- tulip/proactor_events.py | 4 +- tulip/selector_events.py | 2 +- tulip/transports.py | 2 +- tulip/windows_events.py | 6 +-- 10 files changed, 92 insertions(+), 73 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index ec78c3bb..1d256e28 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -26,13 +26,13 @@ def connection_lost(self, how): @coroutine def start(loop): - ss = yield from loop.start_serving(Server, 'localhost', 1111) - return ss + svr = yield from loop.create_server(Server, 'localhost', 1111) + return svr def main(): loop = get_event_loop() - ss = loop.run_until_complete(start(loop)) - dprint('serving', [s.getsockname() for s in ss]) + svr = loop.run_until_complete(start(loop)) + dprint('serving', [s.getsockname() for s in svr.sockets]) loop.run_forever() if __name__ == '__main__': diff --git a/tests/base_events_test.py b/tests/base_events_test.py index b423f329..940dee5f 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -443,7 +443,7 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) - def test_start_serving_empty_host(self): + def test_create_server_empty_host(self): # if host is empty string use None instead host = object() @@ -457,28 +457,28 @@ def getaddrinfo_task(*args, **kwds): return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task - fut = self.loop.start_serving(MyProto, '', 0) + fut = self.loop.create_server(MyProto, '', 0) self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertIsNone(host) - def test_start_serving_host_port_sock(self): - fut = self.loop.start_serving( + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( MyProto, '0.0.0.0', 0, sock=object()) self.assertRaises(ValueError, self.loop.run_until_complete, fut) - def test_start_serving_no_host_port_sock(self): - fut = self.loop.start_serving(MyProto) + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) self.assertRaises(ValueError, self.loop.run_until_complete, fut) - def test_start_serving_no_getaddrinfo(self): + def test_create_server_no_getaddrinfo(self): getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() getaddrinfo.return_value = [] - f = self.loop.start_serving(MyProto, '0.0.0.0', 0) + f = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, f) @unittest.mock.patch('tulip.base_events.socket') - def test_start_serving_cant_bind(self, m_socket): + def test_create_server_cant_bind(self, m_socket): class Err(OSError): strerror = 'error' @@ -488,7 +488,7 @@ class Err(OSError): m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err - fut = self.loop.start_serving(MyProto, '0.0.0.0', 0) + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) diff --git a/tests/events_test.py b/tests/events_test.py index 460d2fe9..e1fb020b 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -543,7 +543,7 @@ def test_create_connection_local_addr_in_use(self): self.assertEqual(cm.exception.errno, errno.EADDRINUSE) self.assertIn(str(httpd.address), cm.exception.strerror) - def test_start_serving(self): + def test_create_server(self): proto = None def factory(): @@ -551,10 +551,10 @@ def factory(): proto = MyProto() return proto - f = self.loop.start_serving(factory, '0.0.0.0', 0) - socks = self.loop.run_until_complete(f) - self.assertEqual(len(socks), 1) - sock = socks[0] + f = self.loop.create_server(factory, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] host, port = sock.getsockname() self.assertEqual(host, '0.0.0.0') client = socket.socket() @@ -585,11 +585,11 @@ def factory(): # recv()/send() on the serving socket client.close() - # close start_serving socks - self.loop.stop_serving(sock) + # close server + server.close() @unittest.skipIf(ssl is None, 'No ssl module') - def test_start_serving_ssl(self): + def test_create_server_ssl(self): proto = None class ClientMyProto(MyProto): @@ -609,10 +609,11 @@ def factory(): certfile=os.path.join(here, 'sample.crt'), keyfile=os.path.join(here, 'sample.key')) - f = self.loop.start_serving( + f = self.loop.create_server( factory, '127.0.0.1', 0, ssl=sslcontext) - sock = self.loop.run_until_complete(f)[0] + server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') @@ -643,9 +644,9 @@ def factory(): client.close() # stop serving - self.loop.stop_serving(sock) + server.close() - def test_start_serving_sock(self): + def test_create_server_sock(self): proto = futures.Future(loop=self.loop) class TestMyProto(MyProto): @@ -657,8 +658,9 @@ def connection_made(self, transport): sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) - f = self.loop.start_serving(TestMyProto, sock=sock_ob) - sock = self.loop.run_until_complete(f)[0] + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] self.assertIs(sock, sock_ob) host, port = sock.getsockname() @@ -667,27 +669,27 @@ def connection_made(self, transport): client.connect(('127.0.0.1', port)) client.send(b'xxx') client.close() + server.close() - self.loop.stop_serving(sock) - - def test_start_serving_addr_in_use(self): + def test_create_server_addr_in_use(self): sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock_ob.bind(('0.0.0.0', 0)) - f = self.loop.start_serving(MyProto, sock=sock_ob) - sock = self.loop.run_until_complete(f)[0] + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() - f = self.loop.start_serving(MyProto, host=host, port=port) + f = self.loop.create_server(MyProto, host=host, port=port) with self.assertRaises(OSError) as cm: self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) - self.loop.stop_serving(sock) + server.close() @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') - def test_start_serving_dual_stack(self): + def test_create_server_dual_stack(self): f_proto = futures.Future(loop=self.loop) class TestMyProto(MyProto): @@ -699,8 +701,8 @@ def connection_made(self, transport): while True: try: port = find_unused_port() - f = self.loop.start_serving(TestMyProto, host=None, port=port) - socks = self.loop.run_until_complete(f) + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) except OSError as ex: if ex.errno == errno.EADDRINUSE: try_count += 1 @@ -725,13 +727,12 @@ def connection_made(self, transport): proto.transport.close() client.close() - for s in socks: - self.loop.stop_serving(s) + server.close() - def test_stop_serving(self): - f = self.loop.start_serving(MyProto, '0.0.0.0', 0) - socks = self.loop.run_until_complete(f) - sock = socks[0] + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() client = socket.socket() @@ -739,7 +740,7 @@ def test_stop_serving(self): client.send(b'xxx') client.close() - self.loop.stop_serving(sock) + server.close() client = socket.socket() self.assertRaises( @@ -946,7 +947,7 @@ def main(): self.assertEqual(t.result(), 'cancelled') self.assertRaises(futures.CancelledError, f.result) self.assertTrue(ov is None or not ov.pending) - self.loop.stop_serving(r) + self.loop._stop_serving(r) r.close() w.close() @@ -1236,7 +1237,7 @@ def create_event_loop(self): def test_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") - def test_start_serving_ssl(self): + def test_create_server_ssl(self): raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") def test_reader_callback(self): @@ -1433,9 +1434,7 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, loop.create_connection, f) self.assertRaises( - NotImplementedError, loop.start_serving, f) - self.assertRaises( - NotImplementedError, loop.stop_serving, f) + NotImplementedError, loop.create_server, f) self.assertRaises( NotImplementedError, loop.create_datagram_endpoint, f) self.assertRaises( diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index d9eae50e..009b0697 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -433,7 +433,7 @@ def test_process_events(self): self.loop._process_events([]) @unittest.mock.patch('tulip.proactor_events.tulip_log') - def test_start_serving(self, m_log): + def test_create_server(self, m_log): pf = unittest.mock.Mock() call_soon = self.loop.call_soon = unittest.mock.Mock() @@ -460,7 +460,7 @@ def test_start_serving(self, m_log): self.assertTrue(self.sock.close.called) self.assertTrue(m_log.exception.called) - def test_start_serving_cancel(self): + def test_create_server_cancel(self): pf = unittest.mock.Mock() call_soon = self.loop.call_soon = unittest.mock.Mock() @@ -475,6 +475,6 @@ def test_start_serving_cancel(self): def test_stop_serving(self): sock = unittest.mock.Mock() - self.loop.stop_serving(sock) + self.loop._stop_serving(sock) self.assertTrue(sock.close.called) - self.proactor.stop_serving.assert_called_with(sock) + self.proactor._stop_serving.assert_called_with(sock) diff --git a/tulip/base_events.py b/tulip/base_events.py index 5157b5b0..5c685403 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -30,7 +30,7 @@ from .log import tulip_log -__all__ = ['BaseEventLoop'] +__all__ = ['BaseEventLoop', 'Server'] # Argument for default thread pool executor creation. @@ -45,6 +45,20 @@ def _raise_stop_error(*args): raise _StopError +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self.loop = loop + self.sockets = sockets + + def close(self): + sockets = self.sockets + if sockets is not None: + self.sockets = None + for sock in sockets: + self.loop._stop_serving(sock) + + class BaseEventLoop(events.AbstractEventLoop): def __init__(self): @@ -383,7 +397,7 @@ def create_datagram_endpoint(self, protocol_factory, return transport, protocol @tasks.coroutine - def start_serving(self, protocol_factory, host=None, port=None, + def create_server(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, @@ -447,7 +461,7 @@ def start_serving(self, protocol_factory, host=None, port=None, sock.listen(backlog) sock.setblocking(False) self._start_serving(protocol_factory, sock, ssl) - return sockets + return Server(self, sockets) @tasks.coroutine def connect_read_pipe(self, protocol_factory, pipe): diff --git a/tulip/events.py b/tulip/events.py index 9e715a17..3615dccc 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -5,7 +5,8 @@ """ __all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', - 'AbstractEventLoop', 'TimerHandle', 'Handle', 'make_handle', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', ] @@ -100,6 +101,14 @@ def __ne__(self, other): return NotImplemented if equal is NotImplemented else not equal +class AbstractServer: + """Abstract server returned by create_service().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + class AbstractEventLoop: """Abstract event loop.""" @@ -166,12 +175,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *, local_addr=None): raise NotImplementedError - def start_serving(self, protocol_factory, host=None, port=None, *, + def create_server(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None): - """A coroutine which creates a TCP server bound to host and - port and whose result will be a list of socket objects which - will later be handled by protocol_factory. + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. If host is an empty string or None all interfaces are assumed and a list of multiple sockets will be returned (most likely @@ -199,10 +209,6 @@ def start_serving(self, protocol_factory, host=None, port=None, *, """ raise NotImplementedError - def stop_serving(self, sock): - """Stop listening for incoming connections. Close socket.""" - raise NotImplementedError - def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 79d2d094..4665dfca 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -323,6 +323,6 @@ def loop(f=None): def _process_events(self, event_list): pass # XXX hard work currently done in poll - def stop_serving(self, sock): - self._proactor.stop_serving(sock) + def _stop_serving(self, sock): + self._proactor._stop_serving(sock) sock.close() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 9d2c0c81..8431f485 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -315,7 +315,7 @@ def _process_events(self, event_list): else: self._add_callback(writer) - def stop_serving(self, sock): + def _stop_serving(self, sock): self.remove_reader(sock.fileno()) sock.close() diff --git a/tulip/transports.py b/tulip/transports.py index f6eb2820..bf3adee7 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -100,7 +100,7 @@ class Transport(ReadTransport, WriteTransport): The user never instantiates a transport directly; they call a utility function, passing it a protocol factory and other information necessary to create the transport and protocol. (E.g. - EventLoop.create_connection() or EventLoop.start_serving().) + EventLoop.create_connection() or EventLoop.create_server().) The utility function will asynchronously create a transport and a protocol and hook them up by calling the protocol's diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 8fbbe103..190c796a 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -145,7 +145,7 @@ def loop(f=None): self.call_soon(loop) return [server] - def stop_serving(self, server): + def _stop_serving(self, server): server.close() @@ -329,9 +329,9 @@ def _poll(self, timeout=None): self._results.append(f) ms = 0 - def stop_serving(self, obj): + def _stop_serving(self, obj): # obj is a socket or pipe handle. It will be closed in - # BaseProactorEventLoop.stop_serving() which will make any + # BaseProactorEventLoop._stop_serving() which will make any # pending operations fail quickly. self._stopped_serving.add(obj) From 8af37f81a2d26cc9e5b40fe063b38e8256a41de9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 15:53:02 -0700 Subject: [PATCH 0663/1502] Add Server.wait_close() to wait until it is closed. --- examples/sink.py | 20 ++++++++++++++------ examples/source.py | 6 +++++- tulip/base_events.py | 14 ++++++++++++++ tulip/events.py | 4 ++++ 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index 1d256e28..6835ee77 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -1,13 +1,15 @@ -"""Test server that accepts connections and reads all data off them.""" +"""Test service that accepts connections and reads all data off them.""" import sys from tulip import * +server = None + def dprint(*args): print('sink:', *args, file=sys.stderr) -class Server(Protocol): +class Service(Protocol): def connection_made(self, tr): dprint('connection from', tr.get_extra_info('addr')) @@ -15,6 +17,11 @@ def connection_made(self, tr): self.total = 0 def data_received(self, data): + if data == b'stop': + # Magic data that closes the service. + server.close() + self.tr.close() + return self.total += len(data) dprint('received', len(data), 'bytes; total', self.total) if self.total > 1e6: @@ -26,14 +33,15 @@ def connection_lost(self, how): @coroutine def start(loop): - svr = yield from loop.create_server(Server, 'localhost', 1111) + svr = yield from loop.create_server(Service, 'localhost', 1111) return svr def main(): loop = get_event_loop() - svr = loop.run_until_complete(start(loop)) - dprint('serving', [s.getsockname() for s in svr.sockets]) - loop.run_forever() + global server + server = loop.run_until_complete(start(loop)) + dprint('serving', [s.getsockname() for s in server.sockets]) + loop.run_until_complete(server.wait_closed()) if __name__ == '__main__': main() diff --git a/examples/source.py b/examples/source.py index 3b9e04ab..8041ae42 100644 --- a/examples/source.py +++ b/examples/source.py @@ -17,7 +17,11 @@ def connection_made(self, tr): self.lost = False self.loop = get_event_loop() self.waiter = Future() - self.write_some_data() + if '--stop' in sys.argv[1:]: + self.tr.write(b'stop') + self.tr.close() + else: + self.write_some_data() def write_some_data(self): if self.lost: diff --git a/tulip/base_events.py b/tulip/base_events.py index 5c685403..db52b757 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -50,6 +50,7 @@ class Server(events.AbstractServer): def __init__(self, loop, sockets): self.loop = loop self.sockets = sockets + self.waiters = [] def close(self): sockets = self.sockets @@ -57,6 +58,19 @@ def close(self): self.sockets = None for sock in sockets: self.loop._stop_serving(sock) + waiters = self.waiters + self.waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @tasks.coroutine + def wait_closed(self): + if self.sockets is None or self.waiters is None: + return + waiter = futures.Future(loop=self.loop) + self.waiters.append(waiter) + yield from waiter class BaseEventLoop(events.AbstractEventLoop): diff --git a/tulip/events.py b/tulip/events.py index 3615dccc..bded631b 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -108,6 +108,10 @@ def close(self): """Stop serving. This leaves existing connections open.""" return NotImplemented + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + class AbstractEventLoop: """Abstract event loop.""" From e729ce37c3dd4e80ea407a783188d0b8df4ae6d1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 10 Oct 2013 16:04:41 -0700 Subject: [PATCH 0664/1502] Small example tweaks. --- examples/sink.py | 5 +++-- examples/source.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index 6835ee77..0b2fd243 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -12,13 +12,14 @@ def dprint(*args): class Service(Protocol): def connection_made(self, tr): - dprint('connection from', tr.get_extra_info('addr')) + dprint('connection from', tr.get_extra_info('socket').getpeername()) + dprint('my socket is', tr.get_extra_info('socket').getsockname()) self.tr = tr self.total = 0 def data_received(self, data): if data == b'stop': - # Magic data that closes the service. + dprint('stopping server') server.close() self.tr.close() return diff --git a/examples/source.py b/examples/source.py index 8041ae42..2f65358d 100644 --- a/examples/source.py +++ b/examples/source.py @@ -12,7 +12,8 @@ class Client(Protocol): data = b'x'*16*1024 def connection_made(self, tr): - dprint('connecting to', tr.get_extra_info('addr')) + dprint('connecting to', tr.get_extra_info('socket').getpeername()) + dprint('my socket is', tr.get_extra_info('socket').getsockname()) self.tr = tr self.lost = False self.loop = get_event_loop() @@ -38,7 +39,7 @@ def connection_lost(self, exc): @coroutine def start(loop): - tr, pr = yield from loop.create_connection(Client, 'localhost', 1111) + tr, pr = yield from loop.create_connection(Client, '127.0.0.1', 1111) dprint('tr =', tr) dprint('pr =', pr) res = yield from pr.waiter From c4107763e59bf556bab615187bfe349ef4f7ffd7 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 11 Oct 2013 14:35:02 +0100 Subject: [PATCH 0665/1502] Fixes to make iocp loop handle ERROR_NETNAME_DELETED and make sink.py/source.py run with iocp. --- examples/sink.py | 5 +++++ examples/source.py | 7 ++++++- overlapped.c | 1 + tulip/proactor_events.py | 9 ++++++--- tulip/windows_events.py | 16 ++++++++++++++-- 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index 0b2fd243..108340a1 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -38,11 +38,16 @@ def start(loop): return svr def main(): + if '--iocp' in sys.argv: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) loop = get_event_loop() global server server = loop.run_until_complete(start(loop)) dprint('serving', [s.getsockname() for s in server.sockets]) loop.run_until_complete(server.wait_closed()) + loop.close() if __name__ == '__main__': main() diff --git a/examples/source.py b/examples/source.py index 2f65358d..ac146d40 100644 --- a/examples/source.py +++ b/examples/source.py @@ -5,7 +5,7 @@ from tulip import * def dprint(*args): - print('source:', *args, file=sys.stderr) + print('source:', *args, file=sys.stderr) class Client(Protocol): @@ -46,8 +46,13 @@ def start(loop): return res def main(): + if '--iocp' in sys.argv: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) loop = get_event_loop() loop.run_until_complete(start(loop)) + loop.close() if __name__ == '__main__': main() diff --git a/overlapped.c b/overlapped.c index ae1e77ca..b5be63a0 100644 --- a/overlapped.c +++ b/overlapped.c @@ -1190,6 +1190,7 @@ PyInit__overlapped(void) d = PyModule_GetDict(m); WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 4665dfca..f1fe6d2a 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -104,7 +104,8 @@ def _loop_reading(self, fut=None): try: if fut is not None: - assert fut is self._read_fut + assert self._read_fut is fut or (self._read_fut is None and + self._closing) self._read_fut = None data = fut.result() # deliver data later in "finally" clause @@ -147,8 +148,8 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, def write(self, data): assert isinstance(data, bytes), repr(data) - if self._closing or self._eof_written: - raise IOError('close() or write_eof() already called') + if self._eof_written: + raise IOError('write_eof() already called') if not data: return @@ -178,6 +179,8 @@ def _loop_writing(self, f=None): return self._write_fut = self._loop._proactor.send(self._sock, data) self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) except OSError as exc: self._fatal_error(exc) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 190c796a..7253bb4b 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -179,7 +179,13 @@ def recv(self, conn, nbytes, flags=0): else: ov.ReadFile(conn.fileno(), nbytes) def finish(trans, key, ov): - return ov.getresult() + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise return self._register(ov, conn, finish) def send(self, conn, buf, flags=0): @@ -190,7 +196,13 @@ def send(self, conn, buf, flags=0): else: ov.WriteFile(conn.fileno(), buf) def finish(trans, key, ov): - return ov.getresult() + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise return self._register(ov, conn, finish) def accept(self, listener): From 0c32e035f6f142bcf03f6e6ce003545f53d00603 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Oct 2013 08:03:46 -0700 Subject: [PATCH 0666/1502] Subtle fix to write error handling; clarify logic. --- tulip/selector_events.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 8431f485..032dd9ad 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -419,7 +419,7 @@ def _read_ready(self): self.close() def write(self, data): - assert isinstance(data, bytes), repr(data)[:100] + assert isinstance(data, bytes), repr(type(data)) assert not self._eof, 'Cannot call write() after write_eof()' if not data: return @@ -442,11 +442,11 @@ def write(self, data): except OSError as exc: self._fatal_error(exc) return - - if n == len(data): - return - elif n: + else: data = data[n:] + if not data: + return + # Start async I/O. self._loop.add_writer(self._sock_fd, self._write_ready) self._buffer.append(data) @@ -461,20 +461,20 @@ def _write_ready(self): except (BlockingIOError, InterruptedError): self._buffer.append(data) except (BrokenPipeError, ConnectionResetError) as exc: + self._loop.remove_writer(self._sock_fd) self._force_close(exc) except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) else: - if n == len(data): + data = data[n:] + if not data: self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) elif self._eof: self._sock.shutdown(socket.SHUT_WR) return - elif n: - data = data[n:] self._buffer.append(data) # Try again later. @@ -609,7 +609,7 @@ def _on_ready(self): self._call_connection_lost(None) def write(self, data): - assert isinstance(data, bytes), repr(data) + assert isinstance(data, bytes), repr(type(data)) if not data: return @@ -653,7 +653,7 @@ def _read_ready(self): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - assert isinstance(data, bytes), repr(data) + assert isinstance(data, bytes), repr(type(data)) if not data: return From 26c33c1c548064aa9b0f6b6d4870707db7f58043 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Oct 2013 08:13:20 -0700 Subject: [PATCH 0667/1502] Change Server.wait_closed() to wait for all connections to close. --- tulip/base_events.py | 34 +++++++++++++++++++++++++--------- tulip/proactor_events.py | 26 +++++++++++++++++--------- tulip/selector_events.py | 35 +++++++++++++++++++++-------------- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/tulip/base_events.py b/tulip/base_events.py index db52b757..ccaecb62 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -50,19 +50,34 @@ class Server(events.AbstractServer): def __init__(self, loop, sockets): self.loop = loop self.sockets = sockets + self.active_count = 0 self.waiters = [] + def attach(self, transport): + assert self.sockets is not None + self.active_count += 1 + + def detach(self, transport): + assert self.active_count > 0 + self.active_count -= 1 + if self.active_count == 0 and self.sockets is None: + self._wakeup() + def close(self): sockets = self.sockets if sockets is not None: self.sockets = None for sock in sockets: self.loop._stop_serving(sock) - waiters = self.waiters - self.waiters = None - for waiter in waiters: - if not waiter.done(): - waiter.set_result(waiter) + if self.active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self.waiters + self.waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) @tasks.coroutine def wait_closed(self): @@ -83,12 +98,12 @@ def __init__(self): self._running = False def _make_socket_transport(self, sock, protocol, waiter=None, *, - extra=None): + extra=None, server=None): """Create socket transport.""" raise NotImplementedError def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, - server_side=False, extra=None): + server_side=False, extra=None, server=None): """Create SSL transport.""" raise NotImplementedError @@ -471,11 +486,12 @@ def create_server(self, protocol_factory, host=None, port=None, 'host and port was not specified and no sock specified') sockets = [sock] + server = Server(self, sockets) for sock in sockets: sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl) - return Server(self, sockets) + self._start_serving(protocol_factory, sock, ssl, server) + return server @tasks.coroutine def connect_read_pipe(self, protocol_factory, pipe): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index f1fe6d2a..5c4bfc7c 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -16,18 +16,21 @@ class _ProactorBasePipeTransport(transports.BaseTransport): """Base class for pipe and socket transports.""" - def __init__(self, loop, sock, protocol, waiter=None, extra=None): + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): super().__init__(extra) self._set_extra(sock) self._loop = loop self._sock = sock self._protocol = protocol + self._server = server self._buffer = [] self._read_fut = None self._write_fut = None self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False + if self._server is not None: + self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) @@ -73,14 +76,18 @@ def _call_connection_lost(self, exc): if hasattr(self._sock, 'shutdown'): self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() + server = self._server + if server is not None: + server.detach(self) + self._server = None class _ProactorReadPipeTransport(_ProactorBasePipeTransport, transports.ReadTransport): """Transport for read pipes.""" - def __init__(self, loop, sock, protocol, waiter=None, extra=None): - super().__init__(loop, sock, protocol, waiter, extra) + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) self._read_fut = None self._paused = False self._loop.call_soon(self._loop_reading) @@ -235,8 +242,8 @@ def __init__(self, proactor): proactor.set_loop(self) self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None): - return _ProactorSocketTransport(self, sock, protocol, waiter, extra) + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): @@ -302,16 +309,16 @@ def _loop_self_reading(self, f=None): def _write_to_self(self): self._csock.send(b'x') - def _start_serving(self, protocol_factory, sock, ssl=None): - assert not ssl, 'IocpEventLoop imcompatible with SSL.' + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + assert not ssl, 'IocpEventLoop is incompatible with SSL.' def loop(f=None): try: - if f: + if f is not None: conn, addr = f.result() protocol = protocol_factory() self._make_socket_transport( - conn, protocol, extra={'addr': addr}) + conn, protocol, extra={'addr': addr}, server=server) f = self._proactor.accept(sock) except OSError: if sock.fileno() != -1: @@ -321,6 +328,7 @@ def loop(f=None): sock.close() else: f.add_done_callback(loop) + self.call_soon(loop) def _process_events(self, event_list): diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 032dd9ad..7ffbc526 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -36,13 +36,13 @@ def __init__(self, selector=None): self._make_self_pipe() def _make_socket_transport(self, sock, protocol, waiter=None, *, - extra=None): - return _SelectorSocketTransport(self, sock, protocol, waiter, extra) + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, - server_side=False, extra=None): + server_side=False, extra=None, server=None): return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, server_side, extra) + self, rawsock, protocol, sslcontext, waiter, server_side, extra, server) def _make_datagram_transport(self, sock, protocol, address=None, extra=None): @@ -85,11 +85,11 @@ def _write_to_self(self): except (BlockingIOError, InterruptedError): pass - def _start_serving(self, protocol_factory, sock, ssl=None): + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): self.add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock, ssl) + protocol_factory, sock, ssl, server) - def _accept_connection(self, protocol_factory, sock, ssl=None): + def _accept_connection(self, protocol_factory, sock, ssl=None, server=None): try: conn, addr = sock.accept() conn.setblocking(False) @@ -106,10 +106,10 @@ def _accept_connection(self, protocol_factory, sock, ssl=None): if ssl: self._make_ssl_transport( conn, protocol_factory(), ssl, None, - server_side=True, extra={'addr': addr}) + server_side=True, extra={'addr': addr}, server=server) else: self._make_socket_transport( - conn, protocol_factory(), extra={'addr': addr}) + conn, protocol_factory(), extra={'addr': addr}, server=server) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -324,16 +324,19 @@ class _SelectorTransport(transports.Transport): max_size = 256 * 1024 # Buffer size passed to recv(). - def __init__(self, loop, sock, protocol, extra): + def __init__(self, loop, sock, protocol, extra, server=None): super().__init__(extra) self._extra['socket'] = sock self._loop = loop self._sock = sock self._sock_fd = sock.fileno() self._protocol = protocol + self._server = server self._buffer = collections.deque() self._conn_lost = 0 self._closing = False # Set when close() called. + if server is not None: + server.attach(self) def abort(self): self._force_close(None) @@ -373,12 +376,16 @@ def _call_connection_lost(self, exc): self._sock = None self._protocol = None self._loop = None + server = self._server + if server is not None: + server.detach(self) + self._server = None class _SelectorSocketTransport(_SelectorTransport): - def __init__(self, loop, sock, protocol, waiter=None, extra=None): - super().__init__(loop, sock, protocol, extra) + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False @@ -492,7 +499,7 @@ def can_write_eof(self): class _SelectorSslTransport(_SelectorTransport): def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, - server_side=False, extra=None): + server_side=False, extra=None, server=None): if server_side: assert isinstance( sslcontext, ssl.SSLContext), 'Must pass an SSLContext' @@ -502,7 +509,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, do_handshake_on_connect=False) - super().__init__(loop, sslsock, protocol, extra) + super().__init__(loop, sslsock, protocol, extra, server) self._waiter = waiter self._rawsock = rawsock From b3029282329c6e7f36fb65ffaef30b7558e0e06f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Oct 2013 08:38:42 -0700 Subject: [PATCH 0668/1502] Make source and sink more alike. --- examples/sink.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index 108340a1..40326d3c 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -34,8 +34,10 @@ def connection_lost(self, how): @coroutine def start(loop): - svr = yield from loop.create_server(Service, 'localhost', 1111) - return svr + global server + server = yield from loop.create_server(Service, 'localhost', 1111) + dprint('serving', [s.getsockname() for s in server.sockets]) + yield from server.wait_closed() def main(): if '--iocp' in sys.argv: @@ -43,10 +45,7 @@ def main(): loop = ProactorEventLoop() set_event_loop(loop) loop = get_event_loop() - global server - server = loop.run_until_complete(start(loop)) - dprint('serving', [s.getsockname() for s in server.sockets]) - loop.run_until_complete(server.wait_closed()) + loop.run_until_complete(start(loop)) loop.close() if __name__ == '__main__': From 27c9469ddce7b26142fed221e3dc5c8c48ba4a2e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Oct 2013 09:42:31 -0700 Subject: [PATCH 0669/1502] Make a start with standardized extra info. --- examples/fetch2.py | 2 +- examples/fetch3.py | 4 ++-- examples/sink.py | 4 ++-- examples/source.py | 4 ++-- tests/events_test.py | 27 ++++++++++----------------- tulip/base_events.py | 17 +++++------------ tulip/proactor_events.py | 18 +++++++++++++----- tulip/selector_events.py | 23 +++++++++++++++++------ 8 files changed, 52 insertions(+), 47 deletions(-) diff --git a/examples/fetch2.py b/examples/fetch2.py index 2ea1c695..8badf1db 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -45,7 +45,7 @@ def connect(self): ssl=self.ssl) if self.verbose: print('* Connected to %s' % - (self.writer.get_extra_info('socket').getpeername(),), + (self.writer.get_extra_info('getpeername'),), file=sys.stderr) def putline(self, line): diff --git a/examples/fetch3.py b/examples/fetch3.py index ea5e298d..8142f094 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -34,7 +34,7 @@ def open_connection(self, host, port, ssl): print('* Reusing pooled connection', key, file=sys.stderr) return conn reader, writer = yield from open_connection(host, port, ssl=ssl) - host, port, *_ = writer.get_extra_info('socket').getpeername() + host, port, *_ = writer.get_extra_info('getpeername') key = host, port, ssl self.connections[key] = reader, writer if self.verbose: @@ -79,7 +79,7 @@ def connect(self, pool): self.port, ssl=self.ssl) self.vprint('* Connected to %s' % - (self.writer.get_extra_info('socket').getpeername(),)) + (self.writer.get_extra_info('getpeername'),)) @coroutine def putline(self, line): diff --git a/examples/sink.py b/examples/sink.py index 40326d3c..855a4aa1 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -12,8 +12,8 @@ def dprint(*args): class Service(Protocol): def connection_made(self, tr): - dprint('connection from', tr.get_extra_info('socket').getpeername()) - dprint('my socket is', tr.get_extra_info('socket').getsockname()) + dprint('connection from', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) self.tr = tr self.total = 0 diff --git a/examples/source.py b/examples/source.py index ac146d40..a0134139 100644 --- a/examples/source.py +++ b/examples/source.py @@ -12,8 +12,8 @@ class Client(Protocol): data = b'x'*16*1024 def connection_made(self, tr): - dprint('connecting to', tr.get_extra_info('socket').getpeername()) - dprint('my socket is', tr.get_extra_info('socket').getsockname()) + dprint('connecting to', tr.get_extra_info('getpeername')) + dprint('my socket is', tr.get_extra_info('getsockname')) self.tr = tr self.lost = False self.loop = get_event_loop() diff --git a/tests/events_test.py b/tests/events_test.py index e1fb020b..667348c0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -516,8 +516,7 @@ def test_create_ssl_connection(self): self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertTrue( - hasattr(tr.get_extra_info('socket'), 'getsockname')) + self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -529,7 +528,7 @@ def test_create_connection_local_addr(self): lambda: MyProto(loop=self.loop), *httpd.address, local_addr=(httpd.address[0], port)) tr, pr = self.loop.run_until_complete(f) - expected = pr.transport.get_extra_info('socket').getsockname()[1] + expected = pr.transport.get_extra_info('sockname')[1] self.assertEqual(port, expected) tr.close() @@ -569,11 +568,9 @@ def factory(): self.assertEqual(3, proto.nbytes) # extra info is available - self.assertIsNotNone(proto.transport.get_extra_info('socket')) - conn = proto.transport.get_extra_info('socket') - self.assertTrue(hasattr(conn, 'getsockname')) - self.assertEqual( - '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) # close connection proto.transport.close() @@ -628,11 +625,9 @@ def factory(): self.assertEqual(3, proto.nbytes) # extra info is available - self.assertIsNotNone(proto.transport.get_extra_info('socket')) - conn = proto.transport.get_extra_info('socket') - self.assertTrue(hasattr(conn, 'getsockname')) - self.assertEqual( - '127.0.0.1', proto.transport.get_extra_info('addr')[0]) + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) # close connection proto.transport.close() @@ -759,7 +754,7 @@ def datagram_received(self, data, addr): coro = self.loop.create_datagram_endpoint( TestMyDatagramProto, local_addr=('127.0.0.1', 0)) s_transport, server = self.loop.run_until_complete(coro) - host, port = s_transport.get_extra_info('addr') + host, port = s_transport.get_extra_info('sockname') coro = self.loop.create_datagram_endpoint( lambda: MyDatagramProto(loop=self.loop), @@ -776,9 +771,7 @@ def datagram_received(self, data, addr): self.assertEqual(8, client.nbytes) # extra info is available - self.assertIsNotNone(transport.get_extra_info('socket')) - conn = transport.get_extra_info('socket') - self.assertTrue(hasattr(conn, 'getsockname')) + self.assertIsNotNone(transport.get_extra_info('sockname')) # close connection transport.close() diff --git a/tulip/base_events.py b/tulip/base_events.py index ccaecb62..9e75439b 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -397,7 +397,6 @@ def create_datagram_endpoint(self, protocol_factory, for ((family, proto), (local_address, remote_address)) in addr_pairs_info: sock = None - l_addr = None r_addr = None try: sock = socket.socket( @@ -407,7 +406,6 @@ def create_datagram_endpoint(self, protocol_factory, if local_addr: sock.bind(local_address) - l_addr = sock.getsockname() if remote_addr: yield from self.sock_connect(sock, remote_address) r_addr = remote_address @@ -421,8 +419,7 @@ def create_datagram_endpoint(self, protocol_factory, raise exceptions[0] protocol = protocol_factory() - transport = self._make_datagram_transport( - sock, protocol, r_addr, extra={'addr': l_addr}) + transport = self._make_datagram_transport(sock, protocol, r_addr) return transport, protocol @tasks.coroutine @@ -497,8 +494,7 @@ def create_server(self, protocol_factory, host=None, port=None, def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) - transport = self._make_read_pipe_transport(pipe, protocol, waiter, - extra={}) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) yield from waiter return transport, protocol @@ -506,8 +502,7 @@ def connect_read_pipe(self, protocol_factory, pipe): def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) - transport = self._make_write_pipe_transport(pipe, protocol, waiter, - extra={}) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) yield from waiter return transport, protocol @@ -521,8 +516,7 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, assert isinstance(cmd, str), cmd protocol = protocol_factory() transport = yield from self._make_subprocess_transport( - protocol, cmd, True, stdin, stdout, stderr, bufsize, - extra={}, **kwargs) + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) return transport, protocol @tasks.coroutine @@ -534,8 +528,7 @@ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, assert not shell, "shell must be False" protocol = protocol_factory() transport = yield from self._make_subprocess_transport( - protocol, args, False, stdin, stdout, stderr, bufsize, - extra={}, **kwargs) + protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) return transport, protocol def _add_callback(self, handle): diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 5c4bfc7c..7c49ae0d 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -16,7 +16,8 @@ class _ProactorBasePipeTransport(transports.BaseTransport): """Base class for pipe and socket transports.""" - def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): super().__init__(extra) self._set_extra(sock) self._loop = loop @@ -86,7 +87,8 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, transports.ReadTransport): """Transport for read pipes.""" - def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): super().__init__(loop, sock, protocol, waiter, extra, server) self._read_fut = None self._paused = False @@ -220,6 +222,9 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, def _set_extra(self, sock): self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + self._extra['peername'] = sock.getpeername() def can_write_eof(self): return True @@ -242,8 +247,10 @@ def __init__(self, proactor): proactor.set_loop(self) self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): - return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): @@ -318,7 +325,8 @@ def loop(f=None): conn, addr = f.result() protocol = protocol_factory() self._make_socket_transport( - conn, protocol, extra={'addr': addr}, server=server) + conn, protocol, + extra={'peername': addr}, server=server) f = self._proactor.accept(sock) except OSError: if sock.fileno() != -1: diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 7ffbc526..00b02521 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -37,12 +37,14 @@ def __init__(self, selector=None): def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): - return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, server_side=False, extra=None, server=None): return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, server_side, extra, server) + self, rawsock, protocol, sslcontext, waiter, server_side, + extra, server) def _make_datagram_transport(self, sock, protocol, address=None, extra=None): @@ -89,7 +91,8 @@ def _start_serving(self, protocol_factory, sock, ssl=None, server=None): self.add_reader(sock.fileno(), self._accept_connection, protocol_factory, sock, ssl, server) - def _accept_connection(self, protocol_factory, sock, ssl=None, server=None): + def _accept_connection(self, protocol_factory, sock, ssl=None, + server=None): try: conn, addr = sock.accept() conn.setblocking(False) @@ -106,10 +109,11 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, server=None): if ssl: self._make_ssl_transport( conn, protocol_factory(), ssl, None, - server_side=True, extra={'addr': addr}, server=server) + server_side=True, extra={'peername': addr}, server=server) else: self._make_socket_transport( - conn, protocol_factory(), extra={'addr': addr}, server=server) + conn, protocol_factory(), extra={'peername': addr}, + server=server) # It's now up to the protocol to handle the connection. def add_reader(self, fd, callback, *args): @@ -327,6 +331,12 @@ class _SelectorTransport(transports.Transport): def __init__(self, loop, sock, protocol, extra, server=None): super().__init__(extra) self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None self._loop = loop self._sock = sock self._sock_fd = sock.fileno() @@ -384,7 +394,8 @@ def _call_connection_lost(self, exc): class _SelectorSocketTransport(_SelectorTransport): - def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): super().__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False From f103676dcd624b65cbcf04c7005f1fb79aef4537 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 11 Oct 2013 10:10:06 -0700 Subject: [PATCH 0670/1502] Fix sock/peer name request --- examples/source.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/source.py b/examples/source.py index a0134139..7a5011da 100644 --- a/examples/source.py +++ b/examples/source.py @@ -12,8 +12,8 @@ class Client(Protocol): data = b'x'*16*1024 def connection_made(self, tr): - dprint('connecting to', tr.get_extra_info('getpeername')) - dprint('my socket is', tr.get_extra_info('getsockname')) + dprint('connecting to', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) self.tr = tr self.lost = False self.loop = get_event_loop() From 2e5933311b962be20d27f616b6b777f46fe69fe1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 12 Oct 2013 11:08:41 -0700 Subject: [PATCH 0671/1502] Add SNI support to SSL. Patch from Issue 74 by Aymeric Augustin. --- tests/selector_events_test.py | 9 +++++++++ tulip/base_events.py | 6 ++++-- tulip/selector_events.py | 19 +++++++++++++------ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 8926b2da..d117f97b 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -1223,6 +1223,15 @@ def test_close(self): self.assertEqual(tr._conn_lost, 1) self.assertEqual(1, self.loop.remove_reader_count[1]) + @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support') + def test_server_hostname(self): + _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, + server_hostname='localhost') + self.sslcontext.wrap_socket.assert_called_with( + self.sock, do_handshake_on_connect=False, server_side=False, + server_hostname='localhost') + class SelectorDatagramTransportTests(unittest.TestCase): diff --git a/tulip/base_events.py b/tulip/base_events.py index 9e75439b..dcd97617 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -103,7 +103,8 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, raise NotImplementedError def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, - server_side=False, extra=None, server=None): + server_side=False, server_hostname=None, + extra=None, server=None): """Create SSL transport.""" raise NotImplementedError @@ -347,7 +348,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, if ssl: sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( - sock, protocol, sslcontext, waiter, server_side=False) + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=host) else: transport = self._make_socket_transport(sock, protocol, waiter) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 00b02521..37af46db 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -41,10 +41,11 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, - server_side=False, extra=None, server=None): + server_side=False, server_hostname=None, + extra=None, server=None): return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, server_side, - extra, server) + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) def _make_datagram_transport(self, sock, protocol, address=None, extra=None): @@ -510,15 +511,21 @@ def can_write_eof(self): class _SelectorSslTransport(_SelectorTransport): def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, - server_side=False, extra=None, server=None): + server_side=False, server_hostname=None, + extra=None, server=None): if server_side: assert isinstance( sslcontext, ssl.SSLContext), 'Must pass an SSLContext' else: # Client-side may pass ssl=True to use a default context. sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslsock = sslcontext.wrap_socket(rawsock, server_side=server_side, - do_handshake_on_connect=False) + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname is not None and not server_side and ssl.HAS_SNI: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) super().__init__(loop, sslsock, protocol, extra, server) From 89fe3076e21665fbaf992273061581161ac61ccb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Oct 2013 14:09:01 -0700 Subject: [PATCH 0672/1502] Make test_create_datagram_endpoint() less flaky. --- tests/events_test.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 667348c0..7063efb3 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -763,9 +763,15 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', client.state) transport.sendto(b'xxx') - test_utils.run_briefly(self.loop) + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) self.assertEqual(3, server.nbytes) - test_utils.run_briefly(self.loop) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) # received self.assertEqual(8, client.nbytes) From f4a119f160f1388372f5da3f6d9299b69b520359 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Oct 2013 14:12:46 -0700 Subject: [PATCH 0673/1502] Rename tulip package to asyncio. --- .hgeol | 4 + .hgignore | 12 + Makefile | 34 + NOTES | 176 ++++ README | 21 + TODO | 163 ++++ asyncio/__init__.py | 26 + asyncio/base_events.py | 606 +++++++++++++ asyncio/constants.py | 4 + asyncio/events.py | 399 +++++++++ asyncio/futures.py | 338 +++++++ asyncio/locks.py | 401 +++++++++ asyncio/log.py | 6 + asyncio/proactor_events.py | 347 ++++++++ asyncio/protocols.py | 100 +++ asyncio/queues.py | 284 ++++++ asyncio/selector_events.py | 742 ++++++++++++++++ asyncio/selectors.py | 405 +++++++++ asyncio/streams.py | 257 ++++++ asyncio/tasks.py | 575 ++++++++++++ asyncio/test_utils.py | 228 +++++ asyncio/transports.py | 186 ++++ asyncio/unix_events.py | 534 +++++++++++ asyncio/windows_events.py | 371 ++++++++ asyncio/windows_utils.py | 181 ++++ check.py | 45 + examples/child_process.py | 127 +++ examples/fetch0.py | 32 + examples/fetch1.py | 75 ++ examples/fetch2.py | 138 +++ examples/fetch3.py | 210 +++++ examples/sink.py | 52 ++ examples/source.py | 58 ++ examples/stacks.py | 43 + examples/tcp_echo.py | 113 +++ examples/udp_echo.py | 98 ++ overlapped.c | 1203 +++++++++++++++++++++++++ runtests.py | 278 ++++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 592 +++++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1571 +++++++++++++++++++++++++++++++++ tests/futures_test.py | 329 +++++++ tests/locks_test.py | 765 ++++++++++++++++ tests/proactor_events_test.py | 480 ++++++++++ tests/queues_test.py | 470 ++++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1473 +++++++++++++++++++++++++++++++ tests/selectors_test.py | 145 +++ tests/streams_test.py | 360 ++++++++ tests/tasks_test.py | 1502 +++++++++++++++++++++++++++++++ tests/transports_test.py | 55 ++ tests/unix_events_test.py | 749 ++++++++++++++++ tests/windows_events_test.py | 91 ++ tests/windows_utils_test.py | 132 +++ 59 files changed, 17652 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 asyncio/__init__.py create mode 100644 asyncio/base_events.py create mode 100644 asyncio/constants.py create mode 100644 asyncio/events.py create mode 100644 asyncio/futures.py create mode 100644 asyncio/locks.py create mode 100644 asyncio/log.py create mode 100644 asyncio/proactor_events.py create mode 100644 asyncio/protocols.py create mode 100644 asyncio/queues.py create mode 100644 asyncio/selector_events.py create mode 100644 asyncio/selectors.py create mode 100644 asyncio/streams.py create mode 100644 asyncio/tasks.py create mode 100644 asyncio/test_utils.py create mode 100644 asyncio/transports.py create mode 100644 asyncio/unix_events.py create mode 100644 asyncio/windows_events.py create mode 100644 asyncio/windows_utils.py create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100644 examples/fetch0.py create mode 100644 examples/fetch1.py create mode 100644 examples/fetch2.py create mode 100644 examples/fetch3.py create mode 100644 examples/sink.py create mode 100644 examples/source.py create mode 100644 examples/stacks.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/udp_echo.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..ed3caf21 --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..8f2b6373 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/asyncio/__init__.py b/asyncio/__init__.py new file mode 100644 index 00000000..513aa958 --- /dev/null +++ b/asyncio/__init__.py @@ -0,0 +1,26 @@ +"""The asyncio package, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/asyncio/base_events.py b/asyncio/base_events.py new file mode 100644 index 00000000..32457ebe --- /dev/null +++ b/asyncio/base_events.py @@ -0,0 +1,606 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import asyncio_log + + +__all__ = ['BaseEventLoop', 'Server'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self.loop = loop + self.sockets = sockets + self.active_count = 0 + self.waiters = [] + + def attach(self, transport): + assert self.sockets is not None + self.active_count += 1 + + def detach(self, transport): + assert self.active_count > 0 + self.active_count -= 1 + if self.active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is not None: + self.sockets = None + for sock in sockets: + self.loop._stop_serving(sock) + if self.active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self.waiters + self.waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @tasks.coroutine + def wait_closed(self): + if self.sockets is None or self.waiters is None: + return + waiter = futures.Future(loop=self.loop) + self.waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=host) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, r_addr) + return transport, protocol + + @tasks.coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + asyncio_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/asyncio/constants.py b/asyncio/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/asyncio/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/asyncio/events.py b/asyncio/events.py new file mode 100644 index 00000000..d2ca80c4 --- /dev/null +++ b/asyncio/events.py @@ -0,0 +1,399 @@ +"""Event loop and event loop policy. + +Beyond the PEP: +- Only the main thread has a default event loop. +""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import asyncio_log + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + asyncio_log.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractServer: + """Abstract server returned by create_service().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/asyncio/futures.py b/asyncio/futures.py new file mode 100644 index 00000000..99a043b4 --- /dev/null +++ b/asyncio/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import asyncio_log + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + asyncio_log.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/asyncio/locks.py b/asyncio/locks.py new file mode 100644 index 00000000..06edbbc1 --- /dev/null +++ b/asyncio/locks.py @@ -0,0 +1,401 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class Event: + """An Event implementation, our equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/asyncio/log.py b/asyncio/log.py new file mode 100644 index 00000000..54dc784e --- /dev/null +++ b/asyncio/log.py @@ -0,0 +1,6 @@ +"""Logging configuration.""" + +import logging + + +asyncio_log = logging.getLogger("asyncio") diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py new file mode 100644 index 00000000..034f405b --- /dev/null +++ b/asyncio/proactor_events.py @@ -0,0 +1,347 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import asyncio_log + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server.attach(self) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + asyncio_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._read_fut = None + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + try: + self._protocol.eof_received() + finally: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if self._eof_written: + raise IOError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + return + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + self._extra['peername'] = sock.getpeername() + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + asyncio_log.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + assert not ssl, 'IocpEventLoop is incompatible with SSL.' + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + f = self._proactor.accept(sock) + except OSError: + if sock.fileno() != -1: + asyncio_log.exception('Accept failed') + sock.close() + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def _stop_serving(self, sock): + self._proactor._stop_serving(sock) + sock.close() diff --git a/asyncio/protocols.py b/asyncio/protocols.py new file mode 100644 index 00000000..d76f25a2 --- /dev/null +++ b/asyncio/protocols.py @@ -0,0 +1,100 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + The default implementation does nothing. + + TODO: By default close the transport. But we don't have the + transport as an instance variable (connection_made() may not + set it). + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/asyncio/queues.py b/asyncio/queues.py new file mode 100644 index 00000000..536de1cb --- /dev/null +++ b/asyncio/queues.py @@ -0,0 +1,284 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py new file mode 100644 index 00000000..d0677b9f --- /dev/null +++ b/asyncio/selector_events.py @@ -0,0 +1,742 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import asyncio_log + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + asyncio_log.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl, server) + + def _accept_connection(self, protocol_factory, sock, ssl=None, + server=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + asyncio_log.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'peername': addr}, + server=server) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + def __init__(self, loop, sock, protocol, extra, server=None): + super().__init__(extra) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._server = server + self._buffer = collections.deque() + self._conn_lost = 0 + self._closing = False # Set when close() called. + if server is not None: + server.attach(self) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # should be called from exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + + if self._closing: + return + + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def pause(self): + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + assert not self._eof, 'Cannot call write() after write_eof()' + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except (BrokenPipeError, ConnectionResetError) as exc: + self._force_close(exc) + return + except OSError as exc: + self._fatal_error(exc) + return + else: + data = data[n:] + if not data: + return + # Start async I/O. + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except (BrokenPipeError, ConnectionResetError) as exc: + self._loop.remove_writer(self._sock_fd) + self._force_close(exc) + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + else: + data = data[n:] + if not data: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + return + + self._buffer.append(data) # Try again later. + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname is not None and not server_side and ssl.HAS_SNI: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + self._paused = False + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def pause(self): + # XXX This is a bit icky, given the comment at the top of + # _on_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume() again, and things will flow again. + + assert not self._closing, 'Cannot pause() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing and not self._paused: + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except ConnectionResetError as exc: + self._force_close(exc) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except (BrokenPipeError, ConnectionResetError) as exc: + self._loop.remove_writer(self._sock_fd) + self._force_close(exc) + return + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def can_write_eof(self): + return False + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + +class _SelectorDatagramTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + + super()._force_close(exc) diff --git a/asyncio/selectors.py b/asyncio/selectors.py new file mode 100644 index 00000000..fe027f09 --- /dev/null +++ b/asyncio/selectors.py @@ -0,0 +1,405 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class BaseSelector(metaclass=ABCMeta): + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + """ + try: + key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + try: + return self._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(int(1000 * timeout), 0) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/asyncio/streams.py b/asyncio/streams.py new file mode 100644 index 00000000..d0f12e81 --- /dev/null +++ b/asyncio/streams.py @@ -0,0 +1,257 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_made(self, transport): + self.stream_reader.set_transport(transport) + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self.limit = limit + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and self.byte_count <= self.limit: + self._paused = False + self._transport.resume() + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(False) + + if (self._transport is not None and + not self._paused and + self.byte_count > 2*self.limit): + try: + self._transport.pause() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + line = b''.join(parts) + self.byte_count -= parts_size + self._maybe_resume_transport() + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self._maybe_resume_transport() + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + self._maybe_resume_transport() + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + return (yield from self.read(n)) diff --git a/asyncio/tasks.py b/asyncio/tasks.py new file mode 100644 index 00000000..7ece2b9d --- /dev/null +++ b/asyncio/tasks.py @@ -0,0 +1,575 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', + ] + +import collections +import concurrent.futures +import functools +import inspect +import linecache +import traceback +import weakref + +from . import events +from . import futures + + +def coroutine(func): + """Decorator to mark coroutines. + + Decorator wraps non generator functions and returns generator wrapper. + If non generator function returns generator of Future it yield-from it. + + TODO: This is a feel-good API only. It is not enforced. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + coro._is_coroutine = True # Not sure who can use this. + return coro + + +# TODO: Do we need this? +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return (inspect.isgeneratorfunction(func) and + getattr(func, '_is_coroutine', False)) + + +# TODO: Do we need this? +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return inspect.isgenerator(obj) # TODO: And what? + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert inspect.isgenerator(coro) # Must be a coroutine *object*. + super().__init__(loop=loop) + self._coro = coro + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is active, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum nummber of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output goes; by default it goes to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from asyncio.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from asyncio.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + fut = async(fut, loop=loop) + + done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) + if done: + return done.pop().result() + + raise futures.TimeoutError() + + +def _waiter_timeout(waiter): + if not waiter.done(): + waiter.set_result(False) + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _waiter_timeout, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(False) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *result_exception* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + children = [async(fut, loop=loop) for fut in coros_or_futures] + n = len(children) + if n == 0: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + if loop is None: + loop = children[0]._loop + for fut in children: + if fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * n + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Mark exception retrieved. + fut.exception() + return + if fut._state == futures._CANCELLED: + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == n: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + # Mark inner's result as retrieved. + inner.cancelled() or inner.exception() + return + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py new file mode 100644 index 00000000..1841f852 --- /dev/null +++ b/asyncio/test_utils.py @@ -0,0 +1,228 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import unittest.mock +import os +import sys +import threading +import unittest +import unittest.mock +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import tasks +from . import base_events +from . import events +from . import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def run_briefly(loop): + @tasks.coroutine + def once(): + pass + t = tasks.Task(once(), loop=loop) + loop.run_until_complete(t) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server(host, port, app, + server_class, SilentWSGIRequestHandler) + httpd.address = httpd.server_address + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/asyncio/transports.py b/asyncio/transports.py new file mode 100644 index 00000000..bf3adee7 --- /dev/null +++ b/asyncio/transports.py @@ -0,0 +1,186 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume() is called. + """ + raise NotImplementedError + + def resume(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py new file mode 100644 index 00000000..49761b99 --- /dev/null +++ b/asyncio/unix_events.py @@ -0,0 +1,534 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import asyncio_log + + +__all__ = ['SelectorEventLoop'] + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + asyncio_log.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + asyncio_log.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + while True: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + continue + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # covered by + # SelectorEventLoopTests.test__sig_chld_unknown_status + # from tests/unix_events_test.py + # bug in coverage.py version 3.6 ??? + continue # pragma: no cover + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + asyncio_log.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause(self): + self._loop.remove_reader(self._fileno) + + def resume(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + assert not self._closing + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + asyncio_log.warning('os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + assert not self._closing + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + asyncio_log.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 1), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial(_UnixReadSubprocessPipeProto, self, 2), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py new file mode 100644 index 00000000..2897dd0e --- /dev/null +++ b/asyncio/windows_events.py @@ -0,0 +1,371 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from . import _overlapped +from .log import asyncio_log + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @tasks.coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_socket_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @tasks.coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + def loop(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError: + if pipe and pipe.fileno() != -1: + asyncio_log.exception('Pipe accept failed') + pipe.close() + except futures.CancelledError: + if pipe: + pipe.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + return [server] + + def _stop_serving(self, server): + server.close() + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket() + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), len(address)) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + ov.ConnectNamedPipe(pipe.fileno()) + def finish(trans, key, ov): + ov.getresult() + return pipe + return self._register(ov, pipe, finish) + + def connect_pipe(self, address): + ov = _overlapped.Overlapped(NULL) + ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + def finish(err, handle, ov): + # err, handle were arguments passed to PostQueuedCompletionStatus() + # in a function run in a thread pool. + if err == _overlapped.ERROR_SEM_TIMEOUT: + # Connection did not succeed within time limit. + msg = _overlapped.FormatMessage(err) + raise ConnectionRefusedError(0, msg, None, err) + elif err != 0: + msg = _overlapped.FormatMessage(err) + raise OSError(0, msg, None, err) + else: + return windows_utils.PipeHandle(handle) + return self._register(ov, None, finish, wait_for_post=True) + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback, wait_for_post=False): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if ov.pending or wait_for_post: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + else: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + return f + + def _get_accept_socket(self): + s = socket.socket() + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + ms = 0 + continue + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (f, ov, obj, callback) in list(self._cache.items()): + if obj is None: + # The operation was started with connect_pipe() which + # queues a task to Windows' thread pool. This cannot + # be cancelled, so just forget it. + del self._cache[address] + else: + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + asyncio_log.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py new file mode 100644 index 00000000..04b43e9a --- /dev/null +++ b/asyncio/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter = itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle != -1: + CloseHandle(self._handle) + self._handle = -1 + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) diff --git a/check.py b/check.py new file mode 100644 index 00000000..6db82d64 --- /dev/null +++ b/check.py @@ -0,0 +1,45 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import os +import sys + + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..5a88faa6 --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.coroutine +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py new file mode 100644 index 00000000..84edaa26 --- /dev/null +++ b/examples/fetch0.py @@ -0,0 +1,32 @@ +"""Simplest possible HTTP client.""" + +import sys + +from tulip import * + + +@coroutine +def fetch(): + r, w = yield from open_connection('python.org', 80) + request = 'GET / HTTP/1.0\r\n\r\n' + print('>', request, file=sys.stderr) + w.write(request.encode('latin-1')) + while True: + line = yield from r.readline() + line = line.decode('latin-1').rstrip() + if not line: + break + print('<', line, file=sys.stderr) + print(file=sys.stderr) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch()) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch1.py b/examples/fetch1.py new file mode 100644 index 00000000..57e66e6a --- /dev/null +++ b/examples/fetch1.py @@ -0,0 +1,75 @@ +"""Fetch one URL and write its content to stdout. + +This version adds URL parsing (including SSL) and a Response object. +""" + +import sys +import urllib.parse + +from tulip import * + + +class Response: + + def __init__(self, verbose=True): + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def read(self, reader): + @coroutine + def getline(): + return (yield from reader.readline()).decode('latin-1').rstrip() + status_line = yield from getline() + if self.verbose: print('<', status_line, file=sys.stderr) + self.http_version, status, self.reason = status_line.split(None, 2) + self.status = int(status) + while True: + header_line = yield from getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + +@coroutine +def fetch(url, verbose=True): + parts = urllib.parse.urlparse(url) + if parts.scheme == 'http': + ssl = False + elif parts.scheme == 'https': + ssl = True + else: + print('URL must use http or https.') + sys.exit(1) + port = parts.port + if port is None: + port = 443 if ssl else 80 + path = parts.path or '/' + if parts.query: + path += '?' + parts.query + request = 'GET %s HTTP/1.0\r\n\r\n' % path + if verbose: + print('>', request, file=sys.stderr, end='') + r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + w.write(request.encode('latin-1')) + response = Response(verbose) + yield from response.read(r) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch2.py b/examples/fetch2.py new file mode 100644 index 00000000..8badf1db --- /dev/null +++ b/examples/fetch2.py @@ -0,0 +1,138 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a Request object. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @coroutine + def connect(self): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) + self.reader, self.writer = yield from open_connection(self.hostname, + self.port, + ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % + (self.writer.get_extra_info('getpeername'),), + file=sys.stderr) + + def putline(self, line): + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + if self.verbose: print('>', request, file=sys.stderr) + self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + if self.verbose: print('>', line, file=sys.stderr) + self.putline(line) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + return (yield from self.reader.readline()).decode('latin-1').rstrip() + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + if self.verbose: print('<', status_line, file=sys.stderr) + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True): + request = Request(url, verbose) + yield from request.connect() + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fetch3.py b/examples/fetch3.py new file mode 100644 index 00000000..8142f094 --- /dev/null +++ b/examples/fetch3.py @@ -0,0 +1,210 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a primitive connection pool, redirect following and +chunked transfer-encoding. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class ConnectionPool: + # TODO: Locking? Close idle connections? + + def __init__(self, verbose=False): + self.verbose = verbose + self.connections = {} # {(host, port, ssl): (reader, writer)} + + @coroutine + def open_connection(self, host, port, ssl): + port = port or (443 if ssl else 80) + ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + if self.verbose: + print('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs)), + file=sys.stderr) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = self.connections.get(key) + if conn: + if self.verbose: + print('* Reusing pooled connection', key, file=sys.stderr) + return conn + reader, writer = yield from open_connection(host, port, ssl=ssl) + host, port, *_ = writer.get_extra_info('getpeername') + key = host, port, ssl + self.connections[key] = reader, writer + if self.verbose: + print('* New connection', key, file=sys.stderr) + return reader, writer + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def connect(self, pool): + self.vprint('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + self.reader, self.writer = \ + yield from pool.open_connection(self.hostname, + self.port, + ssl=self.ssl) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('getpeername'),)) + + @coroutine + def putline(self, line): + self.vprint('>', line) + self.writer.write(line.encode('latin-1') + b'\r\n') + yield from self.writer.drain() + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + yield from self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def getline(self): + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.vprint('<', line) + return line + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=None): + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=None): + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding', '').lower() == 'chunked': + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + break + parts = size_header.split(b';') + size = int(parts[0], 16) + if not size: + break + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n' + body = b''.join(blocks) + else: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True, max_redirect=10): + pool = ConnectionPool(verbose) + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/sink.py b/examples/sink.py new file mode 100644 index 00000000..855a4aa1 --- /dev/null +++ b/examples/sink.py @@ -0,0 +1,52 @@ +"""Test service that accepts connections and reads all data off them.""" + +import sys + +from tulip import * + +server = None + +def dprint(*args): + print('sink:', *args, file=sys.stderr) + +class Service(Protocol): + + def connection_made(self, tr): + dprint('connection from', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.total = 0 + + def data_received(self, data): + if data == b'stop': + dprint('stopping server') + server.close() + self.tr.close() + return + self.total += len(data) + dprint('received', len(data), 'bytes; total', self.total) + if self.total > 1e6: + dprint('closing due to too much data') + self.tr.close() + + def connection_lost(self, how): + dprint('closed', repr(how)) + +@coroutine +def start(loop): + global server + server = yield from loop.create_server(Service, 'localhost', 1111) + dprint('serving', [s.getsockname() for s in server.sockets]) + yield from server.wait_closed() + +def main(): + if '--iocp' in sys.argv: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + loop = get_event_loop() + loop.run_until_complete(start(loop)) + loop.close() + +if __name__ == '__main__': + main() diff --git a/examples/source.py b/examples/source.py new file mode 100644 index 00000000..7a5011da --- /dev/null +++ b/examples/source.py @@ -0,0 +1,58 @@ +"""Test client that connects and sends infinite data.""" + +import sys + +from tulip import * + +def dprint(*args): + print('source:', *args, file=sys.stderr) + +class Client(Protocol): + + data = b'x'*16*1024 + + def connection_made(self, tr): + dprint('connecting to', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.lost = False + self.loop = get_event_loop() + self.waiter = Future() + if '--stop' in sys.argv[1:]: + self.tr.write(b'stop') + self.tr.close() + else: + self.write_some_data() + + def write_some_data(self): + if self.lost: + dprint('lost already') + return + dprint('writing', len(self.data), 'bytes') + self.tr.write(self.data) + self.loop.call_soon(self.write_some_data) + + def connection_lost(self, exc): + dprint('lost connection', repr(exc)) + self.lost = True + self.waiter.set_result(None) + +@coroutine +def start(loop): + tr, pr = yield from loop.create_connection(Client, '127.0.0.1', 1111) + dprint('tr =', tr) + dprint('pr =', pr) + res = yield from pr.waiter + return res + +def main(): + if '--iocp' in sys.argv: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + loop = get_event_loop() + loop.run_until_complete(start(loop)) + loop.close() + +if __name__ == '__main__': + main() diff --git a/examples/stacks.py b/examples/stacks.py new file mode 100644 index 00000000..77a99cf5 --- /dev/null +++ b/examples/stacks.py @@ -0,0 +1,43 @@ +"""Crude demo for print_stack().""" + + +from tulip import * + + +@coroutine +def helper(r): + print('--- helper ---') + for t in Task.all_tasks(): + t.print_stack() + print('--- end helper ---') + line = yield from r.readline() + 1/0 + return line + +def doit(): + l = get_event_loop() + lr = l.run_until_complete + r, w = lr(open_connection('python.org', 80)) + t1 = async(helper(r)) + for t in Task.all_tasks(): t.print_stack() + print('---') + l._run_once() + for t in Task.all_tasks(): t.print_stack() + print('---') + w.write(b'GET /\r\n') + w.write_eof() + try: + lr(t1) + except Exception as e: + print('catching', e) + finally: + for t in Task.all_tasks(): + t.print_stack() + + +def main(): + doit() + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..b5be63a0 --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1203 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, + TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + union { + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; + }; +} OverlappedObject; + +typedef struct { + OVERLAPPED *Overlapped; + HANDLE IocpHandle; + char Address[1]; +} WaitNamedPipeAndConnectContext; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "If length_of_address_tuple is 2 then an AF_INET address is used.\n" + "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int TupleLength; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + return NULL; + + if (TupleLength == 2) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (TupleLength == 4) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Windows equivalent of os.strerror() -- compare _ctypes/callproc.c + */ + +PyDoc_STRVAR( + FormatMessage_doc, + "FormatMessage(error_code) -> error_message\n\n" + "Return error message for an error code."); + +static PyObject * +overlapped_FormatMessage(PyObject *ignore, PyObject *args) +{ + DWORD code, n; + WCHAR *lpMsgBuf; + PyObject *res; + + if (!PyArg_ParseTuple(args, F_DWORD, &code)) + return NULL; + + n = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &lpMsgBuf, + 0, + NULL); + if (n) { + while (iswspace(lpMsgBuf[n-1])) + --n; + lpMsgBuf[n] = L'\0'; + res = Py_BuildValue("u", lpMsgBuf); + } else { + res = PyUnicode_FromFormat("unknown error code %u", code); + } + LocalFree(lpMsgBuf); + return res; +} + + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + switch (self->type) { + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + } + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED + || self->type == TYPE_WAIT_NAMED_PIPE_AND_CONNECT) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_ConnectNamedPipe_doc, + "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" + "Start overlapped wait for a client to connect."); + +static PyObject * +Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) +{ + HANDLE Pipe; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Pipe)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_CONNECT_NAMED_PIPE; + self->handle = Pipe; + + Py_BEGIN_ALLOW_THREADS + ret = ConnectNamedPipe(Pipe, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_PIPE_CONNECTED: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +/* Unfortunately there is no way to do an overlapped connect to a + pipe. We instead use WaitNamedPipe() and CreateFile() in a thread + pool thread. If a connection succeeds within a time limit (10 + seconds) then PostQueuedCompletionStatus() is used to return the + pipe handle to the completion port. */ + +static DWORD WINAPI +WaitNamedPipeAndConnectInThread(WaitNamedPipeAndConnectContext *ctx) +{ + HANDLE PipeHandle = INVALID_HANDLE_VALUE; + DWORD Start = GetTickCount(); + DWORD Deadline = Start + 10*1000; + DWORD Error = 0; + DWORD Timeout; + BOOL Success; + + for ( ; ; ) { + Timeout = Deadline - GetTickCount(); + if ((int)Timeout < 0) + break; + Success = WaitNamedPipe(ctx->Address, Timeout); + Error = Success ? ERROR_SUCCESS : GetLastError(); + switch (Error) { + case ERROR_SUCCESS: + PipeHandle = CreateFile(ctx->Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + if (PipeHandle == INVALID_HANDLE_VALUE) + continue; + break; + case ERROR_SEM_TIMEOUT: + continue; + } + break; + } + if (!PostQueuedCompletionStatus(ctx->IocpHandle, Error, + (ULONG_PTR)PipeHandle, ctx->Overlapped)) + CloseHandle(PipeHandle); + free(ctx); + return 0; +} + +PyDoc_STRVAR( + Overlapped_WaitNamedPipeAndConnect_doc, + "WaitNamedPipeAndConnect(addr, iocp_handle) -> Overlapped[pipe_handle]\n\n" + "Start overlapped connection to address, notifying iocp_handle when\n" + "finished"); + +static PyObject * +Overlapped_WaitNamedPipeAndConnect(OverlappedObject *self, PyObject *args) +{ + char *Address; + Py_ssize_t AddressLength; + HANDLE IocpHandle; + OVERLAPPED Overlapped; + BOOL ret; + DWORD err; + WaitNamedPipeAndConnectContext *ctx; + Py_ssize_t ContextLength; + + if (!PyArg_ParseTuple(args, "s#" F_HANDLE F_POINTER, + &Address, &AddressLength, &IocpHandle, &Overlapped)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + ContextLength = (AddressLength + + offsetof(WaitNamedPipeAndConnectContext, Address)); + ctx = calloc(1, ContextLength + 1); + if (ctx == NULL) + return PyErr_NoMemory(); + memcpy(ctx->Address, Address, AddressLength + 1); + ctx->Overlapped = &self->overlapped; + ctx->IocpHandle = IocpHandle; + + self->type = TYPE_WAIT_NAMED_PIPE_AND_CONNECT; + self->handle = NULL; + + Py_BEGIN_ALLOW_THREADS + ret = QueueUserWorkItem(WaitNamedPipeAndConnectInThread, ctx, + WT_EXECUTELONGFUNCTION); + Py_END_ALLOW_THREADS + + mark_as_completed(&self->overlapped); + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + if (!ret) + return SetFromWindowsErr(err); + Py_RETURN_NONE; +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, + METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {"WaitNamedPipeAndConnect", + (PyCFunction) Overlapped_WaitNamedPipeAndConnect, + METH_VARARGS, Overlapped_WaitNamedPipeAndConnect_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"FormatMessage", overlapped_FormatMessage, + METH_VARARGS, FormatMessage_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); + WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..287d0367 --- /dev/null +++ b/runtests.py @@ -0,0 +1,278 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import re +import sys +import unittest +import textwrap +import importlib.machinery +try: + import coverage +except ImportError: + coverage = None + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +class TestsFinder: + + def __init__(self, testsdir, includes=(), excludes=()): + self._testsdir = testsdir + self._includes = includes + self._excludes = excludes + self.find_available_tests() + + def find_available_tests(self): + """ + Find available test classes without instantiating them. + """ + self._test_factories = [] + mods = [mod for mod, _ in load_modules(self._testsdir)] + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + self._test_factories.append(getattr(mod, name)) + + def load_tests(self): + """ + Load test cases from the available test classes and apply + optional include / exclude filters. + """ + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test_factory in self._test_factories: + tests = loader.loadTestsFromTestCase(test_factory) + if self._includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in self._includes)] + if self._excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in self._excludes)] + suite.addTests(tests) + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + if args.coverage: + cov = coverage.coverage(branch=True, + source=['tulip'], + ) + cov.start() + + finder = TestsFinder(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + try: + if args.forever: + while True: + tests = finder.load_tests() + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + tests = finder.load_tests() + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) + + +if __name__ == '__main__': + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..5393bd42 --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,592 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from asyncio import base_events +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import tasks +from asyncio import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('asyncio.base_events.time') + @unittest.mock.patch('asyncio.base_events.asyncio_log') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + self.transport.close() + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('asyncio.selector_events.asyncio_log') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..7b12700d --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1571 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from asyncio import futures +from asyncio import events +from asyncio import transports +from asyncio import protocols +from asyncio import selector_events +from asyncio import tasks +from asyncio import test_utils +from asyncio import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + self.transport.close() + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.create_server(factory, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_create_server_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_create_server_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from asyncio import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from asyncio import selectors + from asyncio import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('asyncio.events.asyncio_log') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..9b5108c4 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,329 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('asyncio.futures.asyncio_log') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..31b4d64b --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import tasks +from asyncio import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.Event() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.Event(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.Event(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..c52ade05 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,480 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + def test_create_server(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_create_server_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..24805570 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,470 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from asyncio import events +from asyncio import futures +from asyncio import locks +from asyncio import queues +from asyncio import tasks +from asyncio import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith(')') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + t = MyTask(coro(), loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + res = yield from fut3 + return res + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + loop + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + task = tasks.Task(notmuch(), loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + yield coro() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except futures.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @tasks.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except futures.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + d, p = yield from tasks.wait([inner()], loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(tasks.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + yield from tasks.shield(inner(), loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + parent = tasks.gather(child1, child2, loop=self.loop) + outer = tasks.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + inner1 = tasks.shield(child1, loop=self.loop) + inner2 = tasks.shield(child2, loop=self.loop) + parent = tasks.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), futures.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], futures.CancelledError) + self.assertIsInstance(res[4], futures.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + return 'abc' + fut = tasks.gather(coro(), coro()) + self.assertIs(fut._loop, self.one_loop) + fut = tasks.gather(coro(), coro(), loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = tasks.async(inner(), loop=self.one_loop) + child2 = tasks.async(inner(), loop=self.one_loop) + gatherer = None + + @tasks.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = tasks.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(futures.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @tasks.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = futures.Future(loop=self.one_loop) + b = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def outer(): + yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..fce2e6f5 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,55 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause) + self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..42dd919b --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,749 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import unittest +import unittest.mock + + +from asyncio import events +from asyncio import futures +from asyncio import protocols +from asyncio import test_utils +from asyncio import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('asyncio.unix_events.signal') + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + m_log.warning.assert_called_with( + 'os.write(pipe, data) raised exception.') + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..589df8a4 --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,91 @@ +import os +import unittest + +import asyncio + +from asyncio import windows_events +from asyncio import protocols +from asyncio import streams +from asyncio import transports +from asyncio import test_utils + + +class UpperProto(protocols.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + server2 = windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = streams.StreamReader(loop=self.loop) + protocol = streams.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda:protocol, ADDRESS) + self.assertIsInstance(trans, transports.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + return 'done' diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..24c407f4 --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,132 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from asyncio import windows_utils +from asyncio import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) From 3328a58e7c81fb86af2424b851de05e7b409ec00 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Oct 2013 15:52:27 -0700 Subject: [PATCH 0674/1502] Add fakery so "from asyncio import selectors" always works. --- asyncio/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 513aa958..afc444d9 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -2,6 +2,13 @@ import sys +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +try: + import selectors # Will also be exported. +except ImportError: + from . import selectors + # This relies on each of the submodules having an __all__ variable. from .futures import * from .events import * From 27dbf40b89ac6311d5f1fe6967cc3c82c8c246e7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 14 Oct 2013 15:54:49 -0700 Subject: [PATCH 0675/1502] Rename tests from foo_test.py to test_foo.py, to match stdlib preference. --- runtests.py | 2 +- tests/{base_events_test.py => test_base_events.py} | 0 tests/{events_test.py => test_events.py} | 0 tests/{futures_test.py => test_futures.py} | 0 tests/{locks_test.py => test_locks.py} | 0 tests/{proactor_events_test.py => test_proactor_events.py} | 0 tests/{queues_test.py => test_queues.py} | 0 tests/{selector_events_test.py => test_selector_events.py} | 0 tests/{selectors_test.py => test_selectors.py} | 0 tests/{streams_test.py => test_streams.py} | 0 tests/{tasks_test.py => test_tasks.py} | 0 tests/{transports_test.py => test_transports.py} | 0 tests/{unix_events_test.py => test_unix_events.py} | 0 tests/{windows_events_test.py => test_windows_events.py} | 0 tests/{windows_utils_test.py => test_windows_utils.py} | 0 15 files changed, 1 insertion(+), 1 deletion(-) rename tests/{base_events_test.py => test_base_events.py} (100%) rename tests/{events_test.py => test_events.py} (100%) rename tests/{futures_test.py => test_futures.py} (100%) rename tests/{locks_test.py => test_locks.py} (100%) rename tests/{proactor_events_test.py => test_proactor_events.py} (100%) rename tests/{queues_test.py => test_queues.py} (100%) rename tests/{selector_events_test.py => test_selector_events.py} (100%) rename tests/{selectors_test.py => test_selectors.py} (100%) rename tests/{streams_test.py => test_streams.py} (100%) rename tests/{tasks_test.py => test_tasks.py} (100%) rename tests/{transports_test.py => test_transports.py} (100%) rename tests/{unix_events_test.py => test_unix_events.py} (100%) rename tests/{windows_events_test.py => test_windows_events.py} (100%) rename tests/{windows_utils_test.py => test_windows_utils.py} (100%) diff --git a/runtests.py b/runtests.py index 287d0367..b85cce83 100644 --- a/runtests.py +++ b/runtests.py @@ -5,7 +5,7 @@ Patterns are matched against the fully qualified name of the test, including package, module, class and method, -e.g. 'tests.events_test.PolicyTests.testPolicy'. +e.g. 'tests.test_events.PolicyTests.testPolicy'. For full help, try --help. diff --git a/tests/base_events_test.py b/tests/test_base_events.py similarity index 100% rename from tests/base_events_test.py rename to tests/test_base_events.py diff --git a/tests/events_test.py b/tests/test_events.py similarity index 100% rename from tests/events_test.py rename to tests/test_events.py diff --git a/tests/futures_test.py b/tests/test_futures.py similarity index 100% rename from tests/futures_test.py rename to tests/test_futures.py diff --git a/tests/locks_test.py b/tests/test_locks.py similarity index 100% rename from tests/locks_test.py rename to tests/test_locks.py diff --git a/tests/proactor_events_test.py b/tests/test_proactor_events.py similarity index 100% rename from tests/proactor_events_test.py rename to tests/test_proactor_events.py diff --git a/tests/queues_test.py b/tests/test_queues.py similarity index 100% rename from tests/queues_test.py rename to tests/test_queues.py diff --git a/tests/selector_events_test.py b/tests/test_selector_events.py similarity index 100% rename from tests/selector_events_test.py rename to tests/test_selector_events.py diff --git a/tests/selectors_test.py b/tests/test_selectors.py similarity index 100% rename from tests/selectors_test.py rename to tests/test_selectors.py diff --git a/tests/streams_test.py b/tests/test_streams.py similarity index 100% rename from tests/streams_test.py rename to tests/test_streams.py diff --git a/tests/tasks_test.py b/tests/test_tasks.py similarity index 100% rename from tests/tasks_test.py rename to tests/test_tasks.py diff --git a/tests/transports_test.py b/tests/test_transports.py similarity index 100% rename from tests/transports_test.py rename to tests/test_transports.py diff --git a/tests/unix_events_test.py b/tests/test_unix_events.py similarity index 100% rename from tests/unix_events_test.py rename to tests/test_unix_events.py diff --git a/tests/windows_events_test.py b/tests/test_windows_events.py similarity index 100% rename from tests/windows_events_test.py rename to tests/test_windows_events.py diff --git a/tests/windows_utils_test.py b/tests/test_windows_utils.py similarity index 100% rename from tests/windows_utils_test.py rename to tests/test_windows_utils.py From 114e38f4324c143fdfd3794e31b842404c3b14ea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 09:17:54 -0700 Subject: [PATCH 0676/1502] In test_utils.py, try two locations for test key/cert files. --- asyncio/test_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 1841f852..d0ba95d6 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -61,7 +61,14 @@ def handle_error(self, request, client_address): class SSLWSGIServer(SilentWSGIServer): def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the sample key and certificate files) differs + # between the stdlib and stand-alone Tulip/asyncio. + # Prefer our own if we can find it. here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') keyfile = os.path.join(here, 'sample.key') certfile = os.path.join(here, 'sample.crt') ssock = ssl.wrap_socket(request, From 7a06ad43baeb4c8922105c49eeedc58ab4e8869a Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 15 Oct 2013 09:30:45 -0700 Subject: [PATCH 0677/1502] fix infinite loop in SelectorEventLoop._sig_chld() --- tests/unix_events_test.py | 6 ++++++ tulip/unix_events.py | 37 +++++++++++++++++-------------------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 9d8d7e54..bc893230 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -290,6 +290,12 @@ def test__sig_chld_unknown_status_in_handler(self, m_waitpid, m_log.exception.assert_called_with( 'Unknown exception in SIGCHLD handler') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_process_error(self, m_waitpid): + m_waitpid.side_effect = ChildProcessError + self.loop._sig_chld() + self.assertTrue(m_waitpid.called) + class UnixReadPipeTransportTests(unittest.TestCase): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 250c8fb9..0d8068a5 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -163,26 +163,23 @@ def _reg_sigchld(self): def _sig_chld(self): try: - while True: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except ChildProcessError: - break - if pid == 0: - continue - elif os.WIFSIGNALED(status): - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - returncode = os.WEXITSTATUS(status) - else: - # covered by - # SelectorEventLoopTests.test__sig_chld_unknown_status - # from tests/unix_events_test.py - # bug in coverage.py version 3.6 ??? - continue # pragma: no cover - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + return + if pid == 0: + self.call_soon(self._sig_chld) + return + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + self.call_soon(self._sig_chld) + return + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) except Exception: tulip_log.exception('Unknown exception in SIGCHLD handler') From 9b92051c9826400c29ec8c92e9f84d9a955d1295 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 09:52:04 -0700 Subject: [PATCH 0678/1502] Merge _sig_chld() fix into asyncio branch. --- asyncio/unix_events.py | 37 +++++++++++++++++-------------------- tests/test_unix_events.py | 6 ++++++ 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 49761b99..b4bd15d7 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -163,26 +163,23 @@ def _reg_sigchld(self): def _sig_chld(self): try: - while True: - try: - pid, status = os.waitpid(0, os.WNOHANG) - except ChildProcessError: - break - if pid == 0: - continue - elif os.WIFSIGNALED(status): - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - returncode = os.WEXITSTATUS(status) - else: - # covered by - # SelectorEventLoopTests.test__sig_chld_unknown_status - # from tests/unix_events_test.py - # bug in coverage.py version 3.6 ??? - continue # pragma: no cover - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + return + if pid == 0: + self.call_soon(self._sig_chld) + return + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + self.call_soon(self._sig_chld) + return + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) except Exception: asyncio_log.exception('Unknown exception in SIGCHLD handler') diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 42dd919b..e8a9ec39 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -290,6 +290,12 @@ def test__sig_chld_unknown_status_in_handler(self, m_waitpid, m_log.exception.assert_called_with( 'Unknown exception in SIGCHLD handler') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_process_error(self, m_waitpid): + m_waitpid.side_effect = ChildProcessError + self.loop._sig_chld() + self.assertTrue(m_waitpid.called) + class UnixReadPipeTransportTests(unittest.TestCase): From 9a77db16c56e0cf473f5218db3d1cb8169fb9971 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 10:11:55 -0700 Subject: [PATCH 0679/1502] Extra info for SSL connections: context, peercert, cipher, compression. --- tulip/selector_events.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 37af46db..1d5cb293 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -534,6 +534,9 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, self._sslcontext = sslcontext self._paused = False + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + self._on_handshake() def _on_handshake(self): @@ -555,6 +558,13 @@ def _on_handshake(self): if self._waiter is not None: self._waiter.set_exception(exc) raise + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=self._sock.getpeercert(), + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._loop.add_reader(self._sock_fd, self._on_ready) From c8109d381ada31c41df5d7b8cc008225b7e5d3f9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 10:16:27 -0700 Subject: [PATCH 0680/1502] Merge SSL extra info. --- asyncio/selector_events.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d0677b9f..84963eb2 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -534,6 +534,9 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, self._sslcontext = sslcontext self._paused = False + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + self._on_handshake() def _on_handshake(self): @@ -555,6 +558,13 @@ def _on_handshake(self): if self._waiter is not None: self._waiter.set_exception(exc) raise + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=self._sock.getpeercert(), + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._loop.add_reader(self._sock_fd, self._on_ready) From dbf146e6b5ef3688c54620bcf7556b1cfebb6e14 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 11:44:42 -0700 Subject: [PATCH 0681/1502] Port asyncio rename and new extra info to Windows. --- asyncio/proactor_events.py | 10 ++++++++-- setup.cfg | 2 +- setup.py | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 034f405b..e27882ea 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -222,9 +222,15 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, def _set_extra(self, sock): self._extra['socket'] = sock - self._extra['sockname'] = sock.getsockname() + try: + self._extra['sockname'] = sock.getsockname() + except (socket.error, AttributeError): + pass if 'peername' not in self._extra: - self._extra['peername'] = sock.getpeername() + try: + self._extra['peername'] = sock.getpeername() + except (socket.error, AttributeError): + pass def can_write_eof(self): return True diff --git a/setup.cfg b/setup.cfg index 0260f9d5..172844ce 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [build_ext] -build_lib=tulip +build_lib=asyncio diff --git a/setup.py b/setup.py index dcaee96f..fad16e7a 100644 --- a/setup.py +++ b/setup.py @@ -6,9 +6,9 @@ ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) extensions.append(ext) -setup(name='tulip', +setup(name='asyncio', description="reference implementation of PEP 3156", url='http://www.python.org/dev/peps/pep-3156/', - packages=['tulip'], + packages=['asyncio'], ext_modules=extensions ) From 63bb8605f0b45669bf9baade3b695197f1f5704d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 13:07:07 -0700 Subject: [PATCH 0682/1502] Change eof_ready() to return True to keep transport open. --- tests/base_events_test.py | 1 - tests/events_test.py | 1 - tests/selector_events_test.py | 12 ++++++++++++ tulip/proactor_events.py | 5 ++--- tulip/protocols.py | 8 +++----- tulip/selector_events.py | 5 ++--- 6 files changed, 19 insertions(+), 13 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 940dee5f..609edb0b 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -256,7 +256,6 @@ def data_received(self, data): def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' - self.transport.close() def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state diff --git a/tests/events_test.py b/tests/events_test.py index 7063efb3..4969cea0 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -110,7 +110,6 @@ def data_received(self, data): def eof_received(self): assert self.state == ['INITIAL', 'CONNECTED'], self.state self.state.append('EOF') - self.transport.close() def connection_lost(self, exc): assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index d117f97b..894ddbe6 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -708,6 +708,18 @@ def test_read_ready_eof(self): self.protocol.eof_received.assert_called_with() transport.close.assert_called_with() + def test_read_ready_eof_keep_open(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + @unittest.mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): self.sock.recv.side_effect = BlockingIOError diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 7c49ae0d..24080cea 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -145,9 +145,8 @@ def _loop_reading(self, fut=None): if data: self._protocol.data_received(data) elif data is not None: - try: - self._protocol.eof_received() - finally: + keep_open = self._protocol.eof_received() + if not keep_open: self.close() diff --git a/tulip/protocols.py b/tulip/protocols.py index d76f25a2..a94abbe5 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -60,11 +60,9 @@ def data_received(self, data): def eof_received(self): """Called when the other end calls write_eof() or equivalent. - The default implementation does nothing. - - TODO: By default close the transport. But we don't have the - transport as an instance variable (connection_made() may not - set it). + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. """ diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 1d5cb293..40d70689 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -432,9 +432,8 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: - try: - self._protocol.eof_received() - finally: + keep_open = self._protocol.eof_received() + if not keep_open: self.close() def write(self, data): From 7b5d41c1ef90e1633a74bd9d86ec76039dc6dc98 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 13:08:03 -0700 Subject: [PATCH 0683/1502] Merge eof_received() change. --- asyncio/proactor_events.py | 5 ++--- asyncio/protocols.py | 8 +++----- asyncio/selector_events.py | 5 ++--- tests/test_base_events.py | 1 - tests/test_events.py | 1 - tests/test_selector_events.py | 12 ++++++++++++ 6 files changed, 19 insertions(+), 13 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index e27882ea..348de033 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -145,9 +145,8 @@ def _loop_reading(self, fut=None): if data: self._protocol.data_received(data) elif data is not None: - try: - self._protocol.eof_received() - finally: + keep_open = self._protocol.eof_received() + if not keep_open: self.close() diff --git a/asyncio/protocols.py b/asyncio/protocols.py index d76f25a2..a94abbe5 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -60,11 +60,9 @@ def data_received(self, data): def eof_received(self): """Called when the other end calls write_eof() or equivalent. - The default implementation does nothing. - - TODO: By default close the transport. But we don't have the - transport as an instance variable (connection_made() may not - set it). + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. """ diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 84963eb2..98a8f20f 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -432,9 +432,8 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: - try: - self._protocol.eof_received() - finally: + keep_open = self._protocol.eof_received() + if not keep_open: self.close() def write(self, data): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 5393bd42..e2f12724 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -256,7 +256,6 @@ def data_received(self, data): def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' - self.transport.close() def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state diff --git a/tests/test_events.py b/tests/test_events.py index 7b12700d..3ba984f3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -110,7 +110,6 @@ def data_received(self, data): def eof_received(self): assert self.state == ['INITIAL', 'CONNECTED'], self.state self.state.append('EOF') - self.transport.close() def connection_lost(self, exc): assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 1cd34dd2..0225e132 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -708,6 +708,18 @@ def test_read_ready_eof(self): self.protocol.eof_received.assert_called_with() transport.close.assert_called_with() + def test_read_ready_eof_keep_open(self): + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.close = unittest.mock.Mock() + + self.sock.recv.return_value = b'' + self.protocol.eof_received.return_value = True + transport._read_ready() + + self.protocol.eof_received.assert_called_with() + self.assertFalse(transport.close.called) + @unittest.mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): self.sock.recv.side_effect = BlockingIOError From 6f1b16cdd13996201797e638b4e04ad7c3198b91 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 15 Oct 2013 21:48:27 +0100 Subject: [PATCH 0684/1502] Fix how the iocp eventloop accepts AF_INET6 connections. --- tulip/windows_events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 7253bb4b..853fec70 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -113,8 +113,8 @@ def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) pipe = yield from f protocol = protocol_factory() - trans = self._make_socket_transport(pipe, protocol, - extra={'addr': address}) + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) return trans, protocol @tasks.coroutine @@ -207,7 +207,7 @@ def finish(trans, key, ov): def accept(self, listener): self._register_with_iocp(listener) - conn = self._get_accept_socket() + conn = self._get_accept_socket(listener.family) ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(trans, key, ov): @@ -300,8 +300,8 @@ def _register(self, ov, obj, callback, wait_for_post=False): f.set_result(value) return f - def _get_accept_socket(self): - s = socket.socket() + def _get_accept_socket(self, family): + s = socket.socket(family) s.settimeout(0) return s From 027bd19fc37f322b583cacd527eb648a3809dc8c Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 15 Oct 2013 22:22:00 +0100 Subject: [PATCH 0685/1502] Make BindLocal() take an argument for the family of the socket. --- overlapped.c | 13 ++++++------- tulip/windows_events.py | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/overlapped.c b/overlapped.c index b5be63a0..6a1d9e4a 100644 --- a/overlapped.c +++ b/overlapped.c @@ -233,29 +233,28 @@ overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) PyDoc_STRVAR( BindLocal_doc, - "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "BindLocal(handle, family) -> None\n\n" "Bind a socket handle to an arbitrary local port.\n" - "If length_of_address_tuple is 2 then an AF_INET address is used.\n" - "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + "family should AF_INET or AF_INET6.\n"); static PyObject * overlapped_BindLocal(PyObject *self, PyObject *args) { SOCKET Socket; - int TupleLength; + int Family; BOOL ret; - if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &Family)) return NULL; - if (TupleLength == 2) { + if (Family == AF_INET) { struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = 0; addr.sin_addr.S_un.S_addr = INADDR_ANY; ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; - } else if (TupleLength == 4) { + } else if (Family == AF_INET6) { struct sockaddr_in6 addr; memset(&addr, 0, sizeof(addr)); addr.sin6_family = AF_INET6; diff --git a/tulip/windows_events.py b/tulip/windows_events.py index 853fec70..f776b16c 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -224,7 +224,7 @@ def connect(self, conn, address): self._register_with_iocp(conn) # The socket needs to be locally bound before we call ConnectEx(). try: - _overlapped.BindLocal(conn.fileno(), len(address)) + _overlapped.BindLocal(conn.fileno(), conn.family) except OSError as e: if e.winerror != errno.WSAEINVAL: raise From 1bd0f74ad549be5a10d998a102548ded6bdeef1b Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Tue, 15 Oct 2013 22:23:48 +0100 Subject: [PATCH 0686/1502] Merge. --- asyncio/windows_events.py | 12 ++++++------ overlapped.c | 13 ++++++------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 2897dd0e..4cd7b060 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -113,8 +113,8 @@ def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) pipe = yield from f protocol = protocol_factory() - trans = self._make_socket_transport(pipe, protocol, - extra={'addr': address}) + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) return trans, protocol @tasks.coroutine @@ -207,7 +207,7 @@ def finish(trans, key, ov): def accept(self, listener): self._register_with_iocp(listener) - conn = self._get_accept_socket() + conn = self._get_accept_socket(listener.family) ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(trans, key, ov): @@ -224,7 +224,7 @@ def connect(self, conn, address): self._register_with_iocp(conn) # The socket needs to be locally bound before we call ConnectEx(). try: - _overlapped.BindLocal(conn.fileno(), len(address)) + _overlapped.BindLocal(conn.fileno(), conn.family) except OSError as e: if e.winerror != errno.WSAEINVAL: raise @@ -300,8 +300,8 @@ def _register(self, ov, obj, callback, wait_for_post=False): f.set_result(value) return f - def _get_accept_socket(self): - s = socket.socket() + def _get_accept_socket(self, family): + s = socket.socket(family) s.settimeout(0) return s diff --git a/overlapped.c b/overlapped.c index b5be63a0..6a1d9e4a 100644 --- a/overlapped.c +++ b/overlapped.c @@ -233,29 +233,28 @@ overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) PyDoc_STRVAR( BindLocal_doc, - "BindLocal(handle, length_of_address_tuple) -> None\n\n" + "BindLocal(handle, family) -> None\n\n" "Bind a socket handle to an arbitrary local port.\n" - "If length_of_address_tuple is 2 then an AF_INET address is used.\n" - "If length_of_address_tuple is 4 then an AF_INET6 address is used."); + "family should AF_INET or AF_INET6.\n"); static PyObject * overlapped_BindLocal(PyObject *self, PyObject *args) { SOCKET Socket; - int TupleLength; + int Family; BOOL ret; - if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &TupleLength)) + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &Family)) return NULL; - if (TupleLength == 2) { + if (Family == AF_INET) { struct sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; addr.sin_port = 0; addr.sin_addr.S_un.S_addr = INADDR_ANY; ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; - } else if (TupleLength == 4) { + } else if (Family == AF_INET6) { struct sockaddr_in6 addr; memset(&addr, 0, sizeof(addr)); addr.sin6_family = AF_INET6; From 0f12665e39879985b9e98d33a32a82b3e40a0414 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Oct 2013 16:52:49 -0700 Subject: [PATCH 0687/1502] Look for _overlapped module in two places. --- asyncio/windows_events.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 4cd7b060..1d0ad26b 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -11,9 +11,13 @@ from . import selector_events from . import tasks from . import windows_utils -from . import _overlapped from .log import asyncio_log +try: + import _overlapped +except ImportError: + from . import _overlapped + __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] From 4e9b8d0c801ff9ebaf29bf70d447ee8398ad0f0f Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Tue, 15 Oct 2013 22:20:05 -0700 Subject: [PATCH 0688/1502] allow to write to closed _UnixWritePipeTransport --- tests/unix_events_test.py | 11 +++++++++++ tulip/unix_events.py | 28 +++++++++++++++++----------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index bc893230..3aa9db11 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -572,6 +572,17 @@ def test_write_err(self, m_write, m_log): m_log.warning.assert_called_with( 'os.write(pipe, data) raised exception.') + @unittest.mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + def test__read_ready(self): tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, self.protocol) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 0d8068a5..fb6bf237 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -21,7 +21,11 @@ from .log import tulip_log -__all__ = ['SelectorEventLoop'] +__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] + +STDIN = 0 +STDOUT = 1 +STDERR = 2 if sys.platform == 'win32': # pragma: no cover @@ -281,18 +285,17 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): def _read_ready(self): # pipe was closed by peer - self._close() def write(self, data): assert isinstance(data, bytes), repr(data) - assert not self._closing if not data: return - if self._conn_lost: + if self._conn_lost or self._closing: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('os.write(pipe, data) raised exception.') + tulip_log.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') self._conn_lost += 1 return @@ -424,11 +427,11 @@ def __init__(self, loop, protocol, args, shell, self._pipes = {} if stdin == subprocess.PIPE: - self._pipes[0] = None + self._pipes[STDIN] = None if stdout == subprocess.PIPE: - self._pipes[1] = None + self._pipes[STDOUT] = None if stderr == subprocess.PIPE: - self._pipes[2] = None + self._pipes[STDERR] = None self._pending_calls = collections.deque() self._finished = False self._returncode = None @@ -471,15 +474,18 @@ def _post_init(self): loop = self._loop if proc.stdin is not None: transp, proto = yield from loop.connect_write_pipe( - functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + functools.partial( + _UnixWriteSubprocessPipeProto, self, STDIN), proc.stdin) if proc.stdout is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial(_UnixReadSubprocessPipeProto, self, 1), + functools.partial( + _UnixReadSubprocessPipeProto, self, STDOUT), proc.stdout) if proc.stderr is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial(_UnixReadSubprocessPipeProto, self, 2), + functools.partial( + _UnixReadSubprocessPipeProto, self, STDERR), proc.stderr) if not self._pipes: self._try_connected() From e38367b1a0e36a83bc0f0ce61a7e4641d8b62a10 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 08:35:00 -0700 Subject: [PATCH 0689/1502] Remove comment about "Beyond the PEP". --- tulip/events.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tulip/events.py b/tulip/events.py index bded631b..a3fdd380 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -1,8 +1,4 @@ -"""Event loop and event loop policy. - -Beyond the PEP: -- Only the main thread has a default event loop. -""" +"""Event loop and event loop policy.""" __all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', 'AbstractEventLoop', 'AbstractServer', From bd3bc95e034e54bd74ff7d28ade1830b1f03d6d5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 08:35:41 -0700 Subject: [PATCH 0690/1502] Add OP_NO_SSLv2 to default SSLContext options. --- tulip/selector_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 40d70689..548c5049 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -517,7 +517,10 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, sslcontext, ssl.SSLContext), 'Must pass an SSLContext' else: # Client-side may pass ssl=True to use a default context. - sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + # The default is the same as used by urllib. + if sslcontext is None: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 wrap_kwargs = { 'server_side': server_side, 'do_handshake_on_connect': False, From a59bf0cbdb21c11b732cba36bd4b46f3e2d7ab9a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 08:36:12 -0700 Subject: [PATCH 0691/1502] Make write_eof() idempotent. --- tulip/unix_events.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tulip/unix_events.py b/tulip/unix_events.py index fb6bf237..1808256e 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -347,8 +347,12 @@ def _write_ready(self): def can_write_eof(self): return True + # TODO: Make the relationships between write_eof(), close(), + # abort(), _fatal_error() and _close() more straightforward. + def write_eof(self): - assert not self._closing + if self._closing: + return assert self._pipe self._closing = True if not self._buffer: From 4d6d94803a7bd0951326518c83cb657d749b6aa6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 08:41:25 -0700 Subject: [PATCH 0692/1502] Fix expected error message in test. --- tests/unix_events_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 3aa9db11..49c5ded7 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -569,8 +569,9 @@ def test_write_err(self, m_write, m_log): tr.write(b'data') tr.write(b'data') tr.write(b'data') + # This is a bit overspecified. :-( m_log.warning.assert_called_with( - 'os.write(pipe, data) raised exception.') + 'pipe closed by peer or os.write(pipe, data) raised exception.') @unittest.mock.patch('os.write') def test_write_close(self, m_write): From dd32d59065de3553849a372c90590ae41cdeb1a5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 08:43:25 -0700 Subject: [PATCH 0693/1502] Merge --- asyncio/events.py | 6 +----- asyncio/selector_events.py | 5 ++++- asyncio/unix_events.py | 34 ++++++++++++++++++++++------------ tests/test_unix_events.py | 14 +++++++++++++- 4 files changed, 40 insertions(+), 19 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index d2ca80c4..9724615b 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -1,8 +1,4 @@ -"""Event loop and event loop policy. - -Beyond the PEP: -- Only the main thread has a default event loop. -""" +"""Event loop and event loop policy.""" __all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', 'AbstractEventLoop', 'AbstractServer', diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 98a8f20f..3b7831ac 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -517,7 +517,10 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, sslcontext, ssl.SSLContext), 'Must pass an SSLContext' else: # Client-side may pass ssl=True to use a default context. - sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23) + # The default is the same as used by urllib. + if sslcontext is None: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 wrap_kwargs = { 'server_side': server_side, 'do_handshake_on_connect': False, diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index b4bd15d7..a3a8e112 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -21,7 +21,11 @@ from .log import asyncio_log -__all__ = ['SelectorEventLoop'] +__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] + +STDIN = 0 +STDOUT = 1 +STDERR = 2 if sys.platform == 'win32': # pragma: no cover @@ -281,18 +285,17 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): def _read_ready(self): # pipe was closed by peer - self._close() def write(self, data): assert isinstance(data, bytes), repr(data) - assert not self._closing if not data: return - if self._conn_lost: + if self._conn_lost or self._closing: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('os.write(pipe, data) raised exception.') + asyncio_log.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') self._conn_lost += 1 return @@ -344,8 +347,12 @@ def _write_ready(self): def can_write_eof(self): return True + # TODO: Make the relationships between write_eof(), close(), + # abort(), _fatal_error() and _close() more straightforward. + def write_eof(self): - assert not self._closing + if self._closing: + return assert self._pipe self._closing = True if not self._buffer: @@ -424,11 +431,11 @@ def __init__(self, loop, protocol, args, shell, self._pipes = {} if stdin == subprocess.PIPE: - self._pipes[0] = None + self._pipes[STDIN] = None if stdout == subprocess.PIPE: - self._pipes[1] = None + self._pipes[STDOUT] = None if stderr == subprocess.PIPE: - self._pipes[2] = None + self._pipes[STDERR] = None self._pending_calls = collections.deque() self._finished = False self._returncode = None @@ -471,15 +478,18 @@ def _post_init(self): loop = self._loop if proc.stdin is not None: transp, proto = yield from loop.connect_write_pipe( - functools.partial(_UnixWriteSubprocessPipeProto, self, 0), + functools.partial( + _UnixWriteSubprocessPipeProto, self, STDIN), proc.stdin) if proc.stdout is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial(_UnixReadSubprocessPipeProto, self, 1), + functools.partial( + _UnixReadSubprocessPipeProto, self, STDOUT), proc.stdout) if proc.stderr is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial(_UnixReadSubprocessPipeProto, self, 2), + functools.partial( + _UnixReadSubprocessPipeProto, self, STDERR), proc.stderr) if not self._pipes: self._try_connected() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index e8a9ec39..ea678624 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -569,8 +569,20 @@ def test_write_err(self, m_write, m_log): tr.write(b'data') tr.write(b'data') tr.write(b'data') + # This is a bit overspecified. :-( m_log.warning.assert_called_with( - 'os.write(pipe, data) raised exception.') + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @unittest.mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) def test__read_ready(self): tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, From c399e1e783c632b1040fb81f9154eeb7219374ab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 11:17:08 -0700 Subject: [PATCH 0694/1502] Fix misspelled peername in examples. Also comment out drain() call. --- examples/fetch2.py | 2 +- examples/fetch3.py | 7 ++++--- tulip/selector_events.py | 4 ++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/fetch2.py b/examples/fetch2.py index 8badf1db..ca250d61 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -45,7 +45,7 @@ def connect(self): ssl=self.ssl) if self.verbose: print('* Connected to %s' % - (self.writer.get_extra_info('getpeername'),), + (self.writer.get_extra_info('peername'),), file=sys.stderr) def putline(self, line): diff --git a/examples/fetch3.py b/examples/fetch3.py index 8142f094..3b2c8ae0 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -34,7 +34,7 @@ def open_connection(self, host, port, ssl): print('* Reusing pooled connection', key, file=sys.stderr) return conn reader, writer = yield from open_connection(host, port, ssl=ssl) - host, port, *_ = writer.get_extra_info('getpeername') + host, port, *_ = writer.get_extra_info('peername') key = host, port, ssl self.connections[key] = reader, writer if self.verbose: @@ -79,13 +79,13 @@ def connect(self, pool): self.port, ssl=self.ssl) self.vprint('* Connected to %s' % - (self.writer.get_extra_info('getpeername'),)) + (self.writer.get_extra_info('peername'),)) @coroutine def putline(self, line): self.vprint('>', line) self.writer.write(line.encode('latin-1') + b'\r\n') - yield from self.writer.drain() + ##yield from self.writer.drain() @coroutine def send_request(self): @@ -197,6 +197,7 @@ def fetch(url, verbose=True, max_redirect=10): if not next_url: break url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) return body diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 548c5049..447fcc5b 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -519,8 +519,12 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, # Client-side may pass ssl=True to use a default context. # The default is the same as used by urllib. if sslcontext is None: + import sys; print('default ssl context', file=sys.stderr) + ##import pdb; pdb.set_trace() sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED wrap_kwargs = { 'server_side': server_side, 'do_handshake_on_connect': False, From c86a83a5d22222b55af9e662edafef66de9180b3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 11:18:30 -0700 Subject: [PATCH 0695/1502] Kill accidental debug code. --- tulip/selector_events.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 447fcc5b..053afda5 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -519,8 +519,6 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, # Client-side may pass ssl=True to use a default context. # The default is the same as used by urllib. if sslcontext is None: - import sys; print('default ssl context', file=sys.stderr) - ##import pdb; pdb.set_trace() sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.set_default_verify_paths() From f39da151649ae191e19752272a5f0137ecb275fb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 11:42:59 -0700 Subject: [PATCH 0696/1502] Make tests that connect to a dummy SSL server use a dumy certificate. --- tests/events_test.py | 6 ++++-- tests/streams_test.py | 3 ++- tulip/test_utils.py | 7 +++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/events_test.py b/tests/events_test.py index 4969cea0..3b667aa1 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -510,7 +510,8 @@ def test_create_connection_sock(self): def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: f = self.loop.create_connection( - lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + lambda: MyProto(loop=self.loop), *httpd.address, + ssl=test_utils.dummy_ssl_context()) tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -613,7 +614,8 @@ def factory(): host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') - f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + f_c = self.loop.create_connection(ClientMyProto, host, port, + ssl=test_utils.dummy_ssl_context()) client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') diff --git a/tests/streams_test.py b/tests/streams_test.py index c8ad7801..c5c2ff5d 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -50,7 +50,8 @@ def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: try: events.set_event_loop(self.loop) - f = streams.open_connection(*httpd.address, ssl=True) + f = streams.open_connection(*httpd.address, + ssl=test_utils.dummy_ssl_context()) reader, writer = self.loop.run_until_complete(f) finally: events.set_event_loop(None) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 61001168..6555142c 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -27,6 +27,13 @@ from socket import socketpair # pragma: no cover +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + def run_briefly(loop): @tulip.coroutine def once(): From 13b91c475b4c2562061d56422db1ec4298f6772d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 11:48:18 -0700 Subject: [PATCH 0697/1502] Verify hostname if requested. --- tulip/selector_events.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 053afda5..98e0a948 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -533,6 +533,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, super().__init__(loop, sslsock, protocol, extra, server) + self._server_hostname = server_hostname self._waiter = waiter self._rawsock = rawsock self._sslcontext = sslcontext @@ -563,8 +564,20 @@ def _on_handshake(self): self._waiter.set_exception(exc) raise + # Verify hostname if requested. + peercert = self._sock.getpeercert() + if (self._server_hostname is not None and + self._sslcontext.verify_mode == ssl.CERT_REQUIRED): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + # Add extra info that becomes available after handshake. - self._extra.update(peercert=self._sock.getpeercert(), + self._extra.update(peercert=peercert, cipher=self._sock.cipher(), compression=self._sock.compression(), ) From 8f9dcea2172f76b562b9411d1be19a9d4e76c2ae Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 11:52:47 -0700 Subject: [PATCH 0698/1502] Merge --- asyncio/selector_events.py | 17 ++++++++++++++++- asyncio/test_utils.py | 7 +++++++ examples/fetch2.py | 2 +- examples/fetch3.py | 7 ++++--- tests/test_events.py | 6 ++++-- tests/test_streams.py | 3 ++- 6 files changed, 34 insertions(+), 8 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 3b7831ac..bae9a493 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -521,6 +521,8 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, if sslcontext is None: sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED wrap_kwargs = { 'server_side': server_side, 'do_handshake_on_connect': False, @@ -531,6 +533,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, super().__init__(loop, sslsock, protocol, extra, server) + self._server_hostname = server_hostname self._waiter = waiter self._rawsock = rawsock self._sslcontext = sslcontext @@ -561,8 +564,20 @@ def _on_handshake(self): self._waiter.set_exception(exc) raise + # Verify hostname if requested. + peercert = self._sock.getpeercert() + if (self._server_hostname is not None and + self._sslcontext.verify_mode == ssl.CERT_REQUIRED): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + # Add extra info that becomes available after handshake. - self._extra.update(peercert=self._sock.getpeercert(), + self._extra.update(peercert=peercert, cipher=self._sock.cipher(), compression=self._sock.compression(), ) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index d0ba95d6..7c361978 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -27,6 +27,13 @@ from socket import socketpair # pragma: no cover +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + def run_briefly(loop): @tasks.coroutine def once(): diff --git a/examples/fetch2.py b/examples/fetch2.py index 8badf1db..ca250d61 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -45,7 +45,7 @@ def connect(self): ssl=self.ssl) if self.verbose: print('* Connected to %s' % - (self.writer.get_extra_info('getpeername'),), + (self.writer.get_extra_info('peername'),), file=sys.stderr) def putline(self, line): diff --git a/examples/fetch3.py b/examples/fetch3.py index 8142f094..3b2c8ae0 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -34,7 +34,7 @@ def open_connection(self, host, port, ssl): print('* Reusing pooled connection', key, file=sys.stderr) return conn reader, writer = yield from open_connection(host, port, ssl=ssl) - host, port, *_ = writer.get_extra_info('getpeername') + host, port, *_ = writer.get_extra_info('peername') key = host, port, ssl self.connections[key] = reader, writer if self.verbose: @@ -79,13 +79,13 @@ def connect(self, pool): self.port, ssl=self.ssl) self.vprint('* Connected to %s' % - (self.writer.get_extra_info('getpeername'),)) + (self.writer.get_extra_info('peername'),)) @coroutine def putline(self, line): self.vprint('>', line) self.writer.write(line.encode('latin-1') + b'\r\n') - yield from self.writer.drain() + ##yield from self.writer.drain() @coroutine def send_request(self): @@ -197,6 +197,7 @@ def fetch(url, verbose=True, max_redirect=10): if not next_url: break url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) return body diff --git a/tests/test_events.py b/tests/test_events.py index 3ba984f3..053bc35b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -510,7 +510,8 @@ def test_create_connection_sock(self): def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: f = self.loop.create_connection( - lambda: MyProto(loop=self.loop), *httpd.address, ssl=True) + lambda: MyProto(loop=self.loop), *httpd.address, + ssl=test_utils.dummy_ssl_context()) tr, pr = self.loop.run_until_complete(f) self.assertTrue(isinstance(tr, transports.Transport)) self.assertTrue(isinstance(pr, protocols.Protocol)) @@ -613,7 +614,8 @@ def factory(): host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') - f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=True) + f_c = self.loop.create_connection(ClientMyProto, host, port, + ssl=test_utils.dummy_ssl_context()) client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') diff --git a/tests/test_streams.py b/tests/test_streams.py index 9cd7eade..011a09da 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -50,7 +50,8 @@ def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: try: events.set_event_loop(self.loop) - f = streams.open_connection(*httpd.address, ssl=True) + f = streams.open_connection(*httpd.address, + ssl=test_utils.dummy_ssl_context()) reader, writer = self.loop.run_until_complete(f) finally: events.set_event_loop(None) From 11017be9cac0311f4ee9e70461c611107efcc6a5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 12:03:07 -0700 Subject: [PATCH 0699/1502] Update examples to use asyncio, not tulip. --- examples/child_process.py | 40 +++++++++++++++++++-------------------- examples/fetch0.py | 2 +- examples/fetch1.py | 2 +- examples/fetch2.py | 2 +- examples/fetch3.py | 2 +- examples/sink.py | 4 ++-- examples/source.py | 4 ++-- examples/stacks.py | 2 +- examples/tcp_echo.py | 18 +++++++++--------- examples/udp_echo.py | 10 +++++----- 10 files changed, 43 insertions(+), 43 deletions(-) diff --git a/examples/child_process.py b/examples/child_process.py index 5a88faa6..ef31e68b 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -8,18 +8,18 @@ import sys try: - import tulip + import asyncio except ImportError: - # tulip is not installed + # asyncio is not installed sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - import tulip + import asyncio -from tulip import streams -from tulip import protocols +from asyncio import streams +from asyncio import protocols if sys.platform == 'win32': - from tulip.windows_utils import Popen, PIPE - from tulip.windows_events import ProactorEventLoop + from asyncio.windows_utils import Popen, PIPE + from asyncio.windows_events import ProactorEventLoop else: from subprocess import Popen, PIPE @@ -27,20 +27,20 @@ # Return a write-only transport wrapping a writable pipe # -@tulip.coroutine +@asyncio.coroutine def connect_write_pipe(file): - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() protocol = protocols.Protocol() - transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, file) return transport # # Wrap a readable pipe in a stream # -@tulip.coroutine +@asyncio.coroutine def connect_read_pipe(file): - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() stream_reader = streams.StreamReader(loop=loop) def factory(): return streams.StreamReaderProtocol(stream_reader) @@ -52,7 +52,7 @@ def factory(): # Example # -@tulip.coroutine +@asyncio.coroutine def main(loop): # program which prints evaluation of each expression from stdin code = r'''if 1: @@ -88,8 +88,8 @@ def writeall(fd, buf): # interact with subprocess name = {stdout:'OUT', stderr:'ERR'} - registered = {tulip.Task(stderr.readline()): stderr, - tulip.Task(stdout.readline()): stdout} + registered = {asyncio.Task(stderr.readline()): stderr, + asyncio.Task(stdout.readline()): stdout} while registered: # write command cmd = next(commands, None) @@ -102,8 +102,8 @@ def writeall(fd, buf): # get and print lines from stdout, stderr timeout = None while registered: - done, pending = yield from tulip.wait( - registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + done, pending = yield from asyncio.wait( + registered, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) if not done: break for f in done: @@ -111,7 +111,7 @@ def writeall(fd, buf): res = f.result() print(name[stream], res.decode('ascii').rstrip()) if res != b'': - registered[tulip.Task(stream.readline())] = stream + registered[asyncio.Task(stream.readline())] = stream timeout = 0.0 stdout_transport.close() @@ -120,8 +120,8 @@ def writeall(fd, buf): if __name__ == '__main__': if sys.platform == 'win32': loop = ProactorEventLoop() - tulip.set_event_loop(loop) + asyncio.set_event_loop(loop) else: - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() loop.run_until_complete(main(loop)) loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py index 84edaa26..ac4d5d95 100644 --- a/examples/fetch0.py +++ b/examples/fetch0.py @@ -2,7 +2,7 @@ import sys -from tulip import * +from asyncio import * @coroutine diff --git a/examples/fetch1.py b/examples/fetch1.py index 57e66e6a..6d99262b 100644 --- a/examples/fetch1.py +++ b/examples/fetch1.py @@ -6,7 +6,7 @@ import sys import urllib.parse -from tulip import * +from asyncio import * class Response: diff --git a/examples/fetch2.py b/examples/fetch2.py index ca250d61..0899123f 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -7,7 +7,7 @@ import urllib.parse from http.client import BadStatusLine -from tulip import * +from asyncio import * class Request: diff --git a/examples/fetch3.py b/examples/fetch3.py index 3b2c8ae0..fac880fe 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -8,7 +8,7 @@ import urllib.parse from http.client import BadStatusLine -from tulip import * +from asyncio import * class ConnectionPool: diff --git a/examples/sink.py b/examples/sink.py index 855a4aa1..bb29be24 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -2,7 +2,7 @@ import sys -from tulip import * +from asyncio import * server = None @@ -41,7 +41,7 @@ def start(loop): def main(): if '--iocp' in sys.argv: - from tulip.windows_events import ProactorEventLoop + from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) loop = get_event_loop() diff --git a/examples/source.py b/examples/source.py index 7a5011da..6bfdcb0f 100644 --- a/examples/source.py +++ b/examples/source.py @@ -2,7 +2,7 @@ import sys -from tulip import * +from asyncio import * def dprint(*args): print('source:', *args, file=sys.stderr) @@ -47,7 +47,7 @@ def start(loop): def main(): if '--iocp' in sys.argv: - from tulip.windows_events import ProactorEventLoop + from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) loop = get_event_loop() diff --git a/examples/stacks.py b/examples/stacks.py index 77a99cf5..371d31f2 100644 --- a/examples/stacks.py +++ b/examples/stacks.py @@ -1,7 +1,7 @@ """Crude demo for print_stack().""" -from tulip import * +from asyncio import * @coroutine diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 39db5cca..6082ef73 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -1,14 +1,14 @@ #!/usr/bin/env python3 """TCP echo server example.""" import argparse -import tulip +import asyncio try: import signal except ImportError: signal = None -class EchoServer(tulip.Protocol): +class EchoServer(asyncio.Protocol): TIMEOUT = 5.0 @@ -21,7 +21,7 @@ def connection_made(self, transport): self.transport = transport # start 5 seconds timeout timer - self.h_timeout = tulip.get_event_loop().call_later( + self.h_timeout = asyncio.get_event_loop().call_later( self.TIMEOUT, self.timeout) def data_received(self, data): @@ -30,7 +30,7 @@ def data_received(self, data): # restart timeout timer self.h_timeout.cancel() - self.h_timeout = tulip.get_event_loop().call_later( + self.h_timeout = asyncio.get_event_loop().call_later( self.TIMEOUT, self.timeout) def eof_received(self): @@ -41,7 +41,7 @@ def connection_lost(self, exc): self.h_timeout.cancel() -class EchoClient(tulip.Protocol): +class EchoClient(asyncio.Protocol): message = 'This is the message. It will be echoed.' @@ -54,18 +54,18 @@ def data_received(self, data): print('data received:', data) # disconnect after 10 seconds - tulip.get_event_loop().call_later(10.0, self.transport.close) + asyncio.get_event_loop().call_later(10.0, self.transport.close) def eof_received(self): pass def connection_lost(self, exc): print('connection lost:', exc) - tulip.get_event_loop().stop() + asyncio.get_event_loop().stop() def start_client(loop, host, port): - t = tulip.Task(loop.create_connection(EchoClient, host, port)) + t = asyncio.Task(loop.create_connection(EchoClient, host, port)) loop.run_until_complete(t) @@ -101,7 +101,7 @@ def start_server(loop, host, port): print('Please specify --server or --client\n') ARGS.print_help() else: - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() if signal is not None: loop.add_signal_handler(signal.SIGINT, loop.stop) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 0347bfbd..8e95d292 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -2,7 +2,7 @@ """UDP echo example.""" import argparse import sys -import tulip +import asyncio try: import signal except ImportError: @@ -45,18 +45,18 @@ def connection_refused(self, exc): def connection_lost(self, exc): print('closing transport', exc) - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() loop.stop() def start_server(loop, addr): - t = tulip.Task(loop.create_datagram_endpoint( + t = asyncio.Task(loop.create_datagram_endpoint( MyServerUdpEchoProtocol, local_addr=addr)) loop.run_until_complete(t) def start_client(loop, addr): - t = tulip.Task(loop.create_datagram_endpoint( + t = asyncio.Task(loop.create_datagram_endpoint( MyClientUdpEchoProtocol, remote_addr=addr)) loop.run_until_complete(t) @@ -86,7 +86,7 @@ def start_client(loop, addr): print('Please specify --server or --client\n') ARGS.print_help() else: - loop = tulip.get_event_loop() + loop = asyncio.get_event_loop() if signal is not None: loop.add_signal_handler(signal.SIGINT, loop.stop) From 3973ae5723443aca8f27718fc0514050f924fe0b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 13:22:39 -0700 Subject: [PATCH 0700/1502] Skip windows tests using unittest.SkipTest(). --- tests/test_windows_events.py | 4 ++++ tests/test_windows_utils.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 589df8a4..4b04073e 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,6 +1,10 @@ import os +import sys import unittest +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + import asyncio from asyncio import windows_events diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 24c407f4..4b960861 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -4,6 +4,10 @@ import test.support import unittest import unittest.mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + import _winapi from asyncio import windows_utils From 3ccae93bc815eefd516b47cc485f1486c38822fb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 16 Oct 2013 23:59:00 -0700 Subject: [PATCH 0701/1502] Add script to quickly sync files to cpython tree. --- update_stdlib.sh | 54 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100755 update_stdlib.sh diff --git a/update_stdlib.sh b/update_stdlib.sh new file mode 100755 index 00000000..70e28f35 --- /dev/null +++ b/update_stdlib.sh @@ -0,0 +1,54 @@ +#!/bin/bash + +# Script to copy asyncio files to the standard library tree. +# Optional argument is the root of the Python 3.4 tree. +# Assumes you have already created Lib/asyncio and +# Lib/test/test_asyncio in the destination tree. + +CPYTHON=${1-$HOME/cpython} + +if [ ! -d $CPYTHON ] +then + echo Bad destination $CPYTHON + exit 1 +fi + +if [ ! -f asyncio/__init__.py ] +then + echo Bad current directory + exit 1 +fi + +maybe_copy() +{ + SRC=$1 + DST=$CPYTHON/$2 + if cmp $DST $SRC + then + return + fi + echo ======== $SRC === $DST ======== + diff -u $DST $SRC + echo -n "Copy $SRC? [y/N] " + read X + case $X in + [yY]*) echo Copying $SRC; cp $SRC $DST;; + *) echo Not copying $SRC;; + esac +} + +for i in `(cd asyncio && ls *.py)` +do + if [ $i == selectors.py ] + then + continue + fi + maybe_copy asyncio/$i Lib/asyncio/$i +done + +for i in `(cd tests && ls *.py sample.???)` +do + maybe_copy tests/$i Lib/test/test_asyncio/$i +done + +maybe_copy overlapped.c Modules/overlapped.c From e354f0e738ff6f4a5f71a3b37dc12685c5d5901a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 10:56:15 -0700 Subject: [PATCH 0702/1502] Various test tweaks to be more precise about coroutines. --- tests/base_events_test.py | 5 ++--- tests/events_test.py | 1 + tests/tasks_test.py | 27 ++++++++++++++++++++++----- tulip/test_utils.py | 8 ++++++-- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 609edb0b..478d821e 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -44,9 +44,8 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, self.loop._make_write_pipe_transport, m, m) - self.assertRaises( - NotImplementedError, - next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) def test__add_callback_handle(self): h = events.Handle(lambda: False, ()) diff --git a/tests/events_test.py b/tests/events_test.py index 3b667aa1..1d2a60dc 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -926,6 +926,7 @@ def test_prompt_cancellation(self): ov = getattr(f, 'ov', None) self.assertTrue(ov is None or ov.pending) + @tasks.coroutine def main(): try: self.loop.call_soon(f.cancel) diff --git a/tests/tasks_test.py b/tests/tasks_test.py index 396e4f7f..b8dcd7ec 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -136,8 +136,10 @@ class MyTask(tasks.Task, T): def __repr__(self): return super().__repr__() - t = MyTask(coro(), loop=self.loop) + gen = coro() + t = MyTask(gen, loop=self.loop) self.assertEqual(repr(t), 'T[]()') + gen.close() def test_task_basics(self): @tasks.coroutine @@ -942,10 +944,12 @@ def test_step_in_completed_task(self): def notmuch(): return 'ko' - task = tasks.Task(notmuch(), loop=self.loop) + gen = notmuch() + task = tasks.Task(gen, loop=self.loop) task.set_result('ok') self.assertRaises(AssertionError, task._step) + gen.close() def test_step_result(self): @tasks.coroutine @@ -1069,7 +1073,11 @@ def coro(): @tasks.coroutine def wait_for_future(): - yield coro() + gen = coro() + try: + yield gen + finally: + gen.close() task = wait_for_future() self.assertRaises( @@ -1426,6 +1434,7 @@ def tearDown(self): def wrap_futures(self, *futures): coros = [] for fut in futures: + @tasks.coroutine def coro(fut=fut): return (yield from fut) coros.append(coro()) @@ -1435,10 +1444,18 @@ def test_constructor_loop_selection(self): @tasks.coroutine def coro(): return 'abc' - fut = tasks.gather(coro(), coro()) + gen1 = coro() + gen2 = coro() + fut = tasks.gather(gen1, gen2) self.assertIs(fut._loop, self.one_loop) - fut = tasks.gather(coro(), coro(), loop=self.other_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = tasks.gather(gen3, gen4, loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. diff --git a/tulip/test_utils.py b/tulip/test_utils.py index 6555142c..f3cd8fbb 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -38,8 +38,12 @@ def run_briefly(loop): @tulip.coroutine def once(): pass - t = tulip.Task(once(), loop=loop) - loop.run_until_complete(t) + gen = once() + t = tulip.Task(gen, loop=loop) + try: + loop.run_until_complete(t) + finally: + gen.close() def run_once(loop): From 6a02a3f601ad616f6744be9498264dd4593e057b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 10:56:46 -0700 Subject: [PATCH 0703/1502] Add a _DEBUG feature to @coroutine to catch un-waited-for coroutine calls. --- tulip/tasks.py | 79 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 13 deletions(-) diff --git a/tulip/tasks.py b/tulip/tasks.py index af9330d4..b020d9cb 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -16,15 +16,61 @@ from . import events from . import futures +from .log import tulip_log + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = False + + +class CoroWrapper: + """Wrapper for coroutine in _DEBUG mode.""" + + __slot__ = ['gen', 'func'] + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + def __del__(self): + frame = self.gen.gi_frame + if frame is not None and frame.f_lasti == -1: + func = self.func + code = func.__code__ + filename = code.co_filename + lineno = code.co_firstlineno + tulip_log.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) def coroutine(func): """Decorator to mark coroutines. - Decorator wraps non generator functions and returns generator wrapper. - If non generator function returns generator of Future it yield-from it. - - TODO: This is a feel-good API only. It is not enforced. + If the coroutine is not yielded from before it is destroyed, + an error message is logged. """ if inspect.isgeneratorfunction(func): coro = func @@ -36,21 +82,28 @@ def coro(*args, **kw): res = yield from res return res - coro._is_coroutine = True # Not sure who can use this. - return coro + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + w.__name__ = coro.__name__ + w.__doc__ = coro.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper -# TODO: Do we need this? def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" - return (inspect.isgeneratorfunction(func) and - getattr(func, '_is_coroutine', False)) + return getattr(func, '_is_coroutine', False) -# TODO: Do we need this? def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return inspect.isgenerator(obj) # TODO: And what? + return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) class Task(futures.Future): @@ -79,9 +132,9 @@ def all_tasks(cls, loop=None): return {t for t in cls._all_tasks if t._loop is loop} def __init__(self, coro, *, loop=None): - assert inspect.isgenerator(coro) # Must be a coroutine *object*. + assert iscoroutine(coro), repr(coro) # Not a coroutine function! super().__init__(loop=loop) - self._coro = coro + self._coro = iter(coro) # Use the iterator just in case. self._fut_waiter = None self._must_cancel = False self._loop.call_soon(self._step) From 9a628953142b09defc163e5ad3bd8c4b1845531a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 11:03:03 -0700 Subject: [PATCH 0704/1502] Merge --- asyncio/tasks.py | 79 ++++++++++++++++++++++++++++++++------- asyncio/test_utils.py | 8 +++- tests/test_base_events.py | 5 +-- tests/test_events.py | 1 + tests/test_tasks.py | 27 ++++++++++--- 5 files changed, 97 insertions(+), 23 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 7ece2b9d..cfe409a3 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -16,15 +16,61 @@ from . import events from . import futures +from .log import asyncio_log + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = False + + +class CoroWrapper: + """Wrapper for coroutine in _DEBUG mode.""" + + __slot__ = ['gen', 'func'] + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + def __del__(self): + frame = self.gen.gi_frame + if frame is not None and frame.f_lasti == -1: + func = self.func + code = func.__code__ + filename = code.co_filename + lineno = code.co_firstlineno + asyncio_log.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) def coroutine(func): """Decorator to mark coroutines. - Decorator wraps non generator functions and returns generator wrapper. - If non generator function returns generator of Future it yield-from it. - - TODO: This is a feel-good API only. It is not enforced. + If the coroutine is not yielded from before it is destroyed, + an error message is logged. """ if inspect.isgeneratorfunction(func): coro = func @@ -36,21 +82,28 @@ def coro(*args, **kw): res = yield from res return res - coro._is_coroutine = True # Not sure who can use this. - return coro + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + w.__name__ = coro.__name__ + w.__doc__ = coro.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper -# TODO: Do we need this? def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" - return (inspect.isgeneratorfunction(func) and - getattr(func, '_is_coroutine', False)) + return getattr(func, '_is_coroutine', False) -# TODO: Do we need this? def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return inspect.isgenerator(obj) # TODO: And what? + return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) class Task(futures.Future): @@ -79,9 +132,9 @@ def all_tasks(cls, loop=None): return {t for t in cls._all_tasks if t._loop is loop} def __init__(self, coro, *, loop=None): - assert inspect.isgenerator(coro) # Must be a coroutine *object*. + assert iscoroutine(coro), repr(coro) # Not a coroutine function! super().__init__(loop=loop) - self._coro = coro + self._coro = iter(coro) # Use the iterator just in case. self._fut_waiter = None self._must_cancel = False self._loop.call_soon(self._step) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 7c361978..f4fb802b 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -38,8 +38,12 @@ def run_briefly(loop): @tasks.coroutine def once(): pass - t = tasks.Task(once(), loop=loop) - loop.run_until_complete(t) + gen = once() + t = tasks.Task(gen, loop=loop) + try: + loop.run_until_complete(t) + finally: + gen.close() def run_once(loop): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index e2f12724..d48d12cd 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -44,9 +44,8 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, self.loop._make_write_pipe_transport, m, m) - self.assertRaises( - NotImplementedError, - next, self.loop._make_subprocess_transport(m, m, m, m, m, m, m)) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) def test__add_callback_handle(self): h = events.Handle(lambda: False, ()) diff --git a/tests/test_events.py b/tests/test_events.py index 053bc35b..243f4001 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -926,6 +926,7 @@ def test_prompt_cancellation(self): ov = getattr(f, 'ov', None) self.assertTrue(ov is None or ov.pending) + @tasks.coroutine def main(): try: self.loop.call_soon(f.cancel) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4b8fa903..ab455960 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -136,8 +136,10 @@ class MyTask(tasks.Task, T): def __repr__(self): return super().__repr__() - t = MyTask(coro(), loop=self.loop) + gen = coro() + t = MyTask(gen, loop=self.loop) self.assertEqual(repr(t), 'T[]()') + gen.close() def test_task_basics(self): @tasks.coroutine @@ -942,10 +944,12 @@ def test_step_in_completed_task(self): def notmuch(): return 'ko' - task = tasks.Task(notmuch(), loop=self.loop) + gen = notmuch() + task = tasks.Task(gen, loop=self.loop) task.set_result('ok') self.assertRaises(AssertionError, task._step) + gen.close() def test_step_result(self): @tasks.coroutine @@ -1069,7 +1073,11 @@ def coro(): @tasks.coroutine def wait_for_future(): - yield coro() + gen = coro() + try: + yield gen + finally: + gen.close() task = wait_for_future() self.assertRaises( @@ -1426,6 +1434,7 @@ def tearDown(self): def wrap_futures(self, *futures): coros = [] for fut in futures: + @tasks.coroutine def coro(fut=fut): return (yield from fut) coros.append(coro()) @@ -1435,10 +1444,18 @@ def test_constructor_loop_selection(self): @tasks.coroutine def coro(): return 'abc' - fut = tasks.gather(coro(), coro()) + gen1 = coro() + gen2 = coro() + fut = tasks.gather(gen1, gen2) self.assertIs(fut._loop, self.one_loop) - fut = tasks.gather(coro(), coro(), loop=self.other_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = tasks.gather(gen3, gen4, loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. From 1f90bbdc4f3b21e59226671120141959cc1f6da9 Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 17 Oct 2013 12:19:02 -0700 Subject: [PATCH 0705/1502] custom implementation for wait_for --- tests/queues_test.py | 4 ++-- tests/tasks_test.py | 1 - tulip/tasks.py | 30 +++++++++++++++++++----------- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/queues_test.py b/tests/queues_test.py index 437a1c30..ae1e3a14 100644 --- a/tests/queues_test.py +++ b/tests/queues_test.py @@ -242,7 +242,7 @@ def gen(): when = yield self.assertAlmostEqual(0.01, when) when = yield 0.01 - self.assertAlmostEqual(0.06, when) + self.assertAlmostEqual(0.061, when) yield 0.05 loop = test_utils.TestLoop(gen) @@ -252,7 +252,7 @@ def gen(): @tasks.coroutine def queue_get(): - return (yield from tasks.wait_for(q.get(), 0.05, loop=loop)) + return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) @tasks.coroutine def test(): diff --git a/tests/tasks_test.py b/tests/tasks_test.py index b8dcd7ec..161bff81 100644 --- a/tests/tasks_test.py +++ b/tests/tasks_test.py @@ -374,7 +374,6 @@ def foo(): self.assertFalse(fut.done()) self.assertAlmostEqual(0.1, loop.time()) - loop # wait for result res = loop.run_until_complete( tasks.wait_for(fut, 0.3, loop=loop)) diff --git a/tulip/tasks.py b/tulip/tasks.py index b020d9cb..c998a7ed 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -349,6 +349,11 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): return (yield from _wait(fs, timeout, return_when, loop)) +def _release_waiter(waiter, value=True, *args): + if not waiter.done(): + waiter.set_result(value) + + @coroutine def wait_for(fut, timeout, *, loop=None): """Wait for the single Future or coroutine to complete, with timeout. @@ -366,18 +371,21 @@ def wait_for(fut, timeout, *, loop=None): if loop is None: loop = events.get_event_loop() - fut = async(fut, loop=loop) - - done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) - if done: - return done.pop().result() - - raise futures.TimeoutError() + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) + cb = functools.partial(_release_waiter, waiter, True) + fut = async(fut, loop=loop) + fut.add_done_callback(cb) -def _waiter_timeout(waiter): - if not waiter.done(): - waiter.set_result(False) + try: + if (yield from waiter): + return fut.result() + else: + fut.remove_done_callback(cb) + raise futures.TimeoutError() + finally: + timeout_handle.cancel() @coroutine @@ -390,7 +398,7 @@ def _wait(fs, timeout, return_when, loop): waiter = futures.Future(loop=loop) timeout_handle = None if timeout is not None: - timeout_handle = loop.call_later(timeout, _waiter_timeout, waiter) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) counter = len(fs) def _on_completion(f): From b1aba984570173390ed35e106b94f12f8778fcb7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 12:37:17 -0700 Subject: [PATCH 0706/1502] Merge --- asyncio/tasks.py | 30 +++++++++++++++++++----------- tests/test_queues.py | 4 ++-- tests/test_tasks.py | 1 - 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index cfe409a3..2c8579fa 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -349,6 +349,11 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): return (yield from _wait(fs, timeout, return_when, loop)) +def _release_waiter(waiter, value=True, *args): + if not waiter.done(): + waiter.set_result(value) + + @coroutine def wait_for(fut, timeout, *, loop=None): """Wait for the single Future or coroutine to complete, with timeout. @@ -366,18 +371,21 @@ def wait_for(fut, timeout, *, loop=None): if loop is None: loop = events.get_event_loop() - fut = async(fut, loop=loop) - - done, pending = yield from _wait([fut], timeout, FIRST_COMPLETED, loop) - if done: - return done.pop().result() - - raise futures.TimeoutError() + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) + cb = functools.partial(_release_waiter, waiter, True) + fut = async(fut, loop=loop) + fut.add_done_callback(cb) -def _waiter_timeout(waiter): - if not waiter.done(): - waiter.set_result(False) + try: + if (yield from waiter): + return fut.result() + else: + fut.remove_done_callback(cb) + raise futures.TimeoutError() + finally: + timeout_handle.cancel() @coroutine @@ -390,7 +398,7 @@ def _wait(fs, timeout, return_when, loop): waiter = futures.Future(loop=loop) timeout_handle = None if timeout is not None: - timeout_handle = loop.call_later(timeout, _waiter_timeout, waiter) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) counter = len(fs) def _on_completion(f): diff --git a/tests/test_queues.py b/tests/test_queues.py index 24805570..8af4ee7f 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -242,7 +242,7 @@ def gen(): when = yield self.assertAlmostEqual(0.01, when) when = yield 0.01 - self.assertAlmostEqual(0.06, when) + self.assertAlmostEqual(0.061, when) yield 0.05 loop = test_utils.TestLoop(gen) @@ -252,7 +252,7 @@ def gen(): @tasks.coroutine def queue_get(): - return (yield from tasks.wait_for(q.get(), 0.05, loop=loop)) + return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) @tasks.coroutine def test(): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ab455960..57fb0537 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -374,7 +374,6 @@ def foo(): self.assertFalse(fut.done()) self.assertAlmostEqual(0.1, loop.time()) - loop # wait for result res = loop.run_until_complete( tasks.wait_for(fut, 0.3, loop=loop)) From a65e46389d1b2de7dbd225e2228ff7514947c44e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 13:42:02 -0700 Subject: [PATCH 0707/1502] Fix indent. --- tulip/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tulip/test_utils.py b/tulip/test_utils.py index f3cd8fbb..fdf629d6 100644 --- a/tulip/test_utils.py +++ b/tulip/test_utils.py @@ -43,7 +43,7 @@ def once(): try: loop.run_until_complete(t) finally: - gen.close() + gen.close() def run_once(loop): From 40fea19b145bfe229a41afb5c73fd5ee0ecb795c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 13:42:24 -0700 Subject: [PATCH 0708/1502] Merge --- asyncio/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index f4fb802b..91bbedba 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -43,7 +43,7 @@ def once(): try: loop.run_until_complete(t) finally: - gen.close() + gen.close() def run_once(loop): From 1b8b62de71770d229f29edb04ea9d651e0fcb2af Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 14:49:52 -0700 Subject: [PATCH 0709/1502] Make CPython tests run on Windows. --- tests/streams_test.py | 5 ++++- tests/unix_events_test.py | 3 +++ tests/windows_utils_test.py | 6 +++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/streams_test.py b/tests/streams_test.py index c5c2ff5d..e562bf92 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -1,9 +1,12 @@ """Tests for streams.py.""" import gc -import ssl import unittest import unittest.mock +try: + import ssl +except ImportError: + ssl = None from tulip import events from tulip import streams diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index 49c5ded7..c40455aa 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -10,6 +10,9 @@ import unittest import unittest.mock +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + from tulip import events from tulip import futures diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py index b23896d3..bd61463f 100644 --- a/tests/windows_utils_test.py +++ b/tests/windows_utils_test.py @@ -7,7 +7,11 @@ import _winapi from tulip import windows_utils -from tulip import _overlapped + +try: + import _overlapped +except ImportError: + from tulip import _overlapped class WinsocketpairTests(unittest.TestCase): From 78c3a28b3673b8dfdfbb7442a95b662537a34e44 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 14:51:22 -0700 Subject: [PATCH 0710/1502] Merge --- tests/test_streams.py | 5 ++++- tests/test_unix_events.py | 3 +++ tests/test_windows_utils.py | 6 +++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_streams.py b/tests/test_streams.py index 011a09da..31d81514 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,9 +1,12 @@ """Tests for streams.py.""" import gc -import ssl import unittest import unittest.mock +try: + import ssl +except ImportError: + ssl = None from asyncio import events from asyncio import streams diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index ea678624..6dbd47f6 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -10,6 +10,9 @@ import unittest import unittest.mock +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + from asyncio import events from asyncio import futures diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 4b960861..3b6b0368 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -11,7 +11,11 @@ import _winapi from asyncio import windows_utils -from asyncio import _overlapped + +try: + import _overlapped +except ImportError: + from asyncio import _overlapped class WinsocketpairTests(unittest.TestCase): From f8da66f9192a55ce77b000331b239ab2ded392a4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 15:10:58 -0700 Subject: [PATCH 0711/1502] Add proper argument parsing to source/sink examples. --- examples/sink.py | 29 ++++++++++++++++++++++++----- examples/source.py | 38 ++++++++++++++++++++++++++++++++------ 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index 855a4aa1..119af6eb 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -1,14 +1,28 @@ """Test service that accepts connections and reads all data off them.""" +import argparse import sys from tulip import * +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') + server = None + def dprint(*args): print('sink:', *args, file=sys.stderr) + class Service(Protocol): def connection_made(self, tr): @@ -32,21 +46,26 @@ def data_received(self, data): def connection_lost(self, how): dprint('closed', repr(how)) + @coroutine -def start(loop): +def start(loop, host, port): global server - server = yield from loop.create_server(Service, 'localhost', 1111) + server = yield from loop.create_server(Service, host, port) dprint('serving', [s.getsockname() for s in server.sockets]) yield from server.wait_closed() + def main(): - if '--iocp' in sys.argv: + args = ARGS.parse_args() + if args.iocp: from tulip.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) - loop = get_event_loop() - loop.run_until_complete(start(loop)) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) loop.close() + if __name__ == '__main__': main() diff --git a/examples/source.py b/examples/source.py index 7a5011da..f5d0291d 100644 --- a/examples/source.py +++ b/examples/source.py @@ -1,12 +1,32 @@ """Test client that connects and sends infinite data.""" +import argparse import sys from tulip import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') + +args = None + + def dprint(*args): print('source:', *args, file=sys.stderr) + class Client(Protocol): data = b'x'*16*1024 @@ -18,7 +38,7 @@ def connection_made(self, tr): self.lost = False self.loop = get_event_loop() self.waiter = Future() - if '--stop' in sys.argv[1:]: + if args.stop: self.tr.write(b'stop') self.tr.close() else: @@ -37,22 +57,28 @@ def connection_lost(self, exc): self.lost = True self.waiter.set_result(None) + @coroutine -def start(loop): - tr, pr = yield from loop.create_connection(Client, '127.0.0.1', 1111) +def start(loop, host, port): + tr, pr = yield from loop.create_connection(Client, host, port) dprint('tr =', tr) dprint('pr =', pr) res = yield from pr.waiter return res + def main(): - if '--iocp' in sys.argv: + global args + args = ARGS.parse_args() + if args.iocp: from tulip.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) - loop = get_event_loop() - loop.run_until_complete(start(loop)) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) loop.close() + if __name__ == '__main__': main() From 9faf524e074266a4894b286ea9bafc8a6e2635f7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 15:17:40 -0700 Subject: [PATCH 0712/1502] Add size argument to source example. --- examples/source.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/source.py b/examples/source.py index f5d0291d..c4842240 100644 --- a/examples/source.py +++ b/examples/source.py @@ -19,6 +19,9 @@ ARGS.add_argument( '--port', action='store', dest='port', default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') args = None @@ -29,8 +32,6 @@ def dprint(*args): class Client(Protocol): - data = b'x'*16*1024 - def connection_made(self, tr): dprint('connecting to', tr.get_extra_info('peername')) dprint('my socket is', tr.get_extra_info('sockname')) @@ -42,15 +43,16 @@ def connection_made(self, tr): self.tr.write(b'stop') self.tr.close() else: - self.write_some_data() + data = b'x' * args.size + self.write_some_data(data) - def write_some_data(self): + def write_some_data(self, data): if self.lost: dprint('lost already') return - dprint('writing', len(self.data), 'bytes') - self.tr.write(self.data) - self.loop.call_soon(self.write_some_data) + dprint('writing', len(data), 'bytes') + self.tr.write(data) + self.loop.call_soon(self.write_some_data, data) def connection_lost(self, exc): dprint('lost connection', repr(exc)) From 1019b4f0fc8315804f2b38d3945ef0b73527b80c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 15:28:33 -0700 Subject: [PATCH 0713/1502] Rename tulip_log to logger. --- tests/base_events_test.py | 4 ++-- tests/events_test.py | 2 +- tests/futures_test.py | 12 ++++++------ tests/proactor_events_test.py | 6 +++--- tests/selector_events_test.py | 12 ++++++------ tests/unix_events_test.py | 14 +++++++------- tulip/base_events.py | 4 ++-- tulip/events.py | 6 +++--- tulip/futures.py | 6 +++--- tulip/log.py | 3 ++- tulip/proactor_events.py | 10 +++++----- tulip/selector_events.py | 14 +++++++------- tulip/tasks.py | 6 +++--- tulip/unix_events.py | 16 ++++++++-------- tulip/windows_events.py | 6 +++--- 15 files changed, 61 insertions(+), 60 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 478d821e..38fe07aa 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -183,7 +183,7 @@ def test__run_once(self): self.assertTrue(self.loop._process_events.called) @unittest.mock.patch('tulip.base_events.time') - @unittest.mock.patch('tulip.base_events.tulip_log') + @unittest.mock.patch('tulip.base_events.logger') def test__run_once_logging(self, m_logging, m_time): # Log to INFO level if timeout > 1.0 sec. idx = -1 @@ -579,7 +579,7 @@ def test_accept_connection_retry(self): self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - @unittest.mock.patch('tulip.selector_events.tulip_log') + @unittest.mock.patch('tulip.selector_events.logger') def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.fileno.return_value = 10 diff --git a/tests/events_test.py b/tests/events_test.py index 1d2a60dc..67b2cfc6 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -1320,7 +1320,7 @@ def callback(*args): self.assertRaises( AssertionError, events.make_handle, h1, ()) - @unittest.mock.patch('tulip.events.tulip_log') + @unittest.mock.patch('tulip.events.logger') def test_callback_with_exception(self, log): def callback(): raise ValueError() diff --git a/tests/futures_test.py b/tests/futures_test.py index 18cec8b0..13a1dd93 100644 --- a/tests/futures_test.py +++ b/tests/futures_test.py @@ -170,20 +170,20 @@ def test(): self.assertRaises(AssertionError, test) fut.cancel() - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_abandoned(self, m_log): fut = futures.Future(loop=self.loop) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_result_unretrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_result(42) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_result_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_result(42) @@ -191,7 +191,7 @@ def test_tb_logger_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_exception_unretrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -199,7 +199,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_exception_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -207,7 +207,7 @@ def test_tb_logger_exception_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('tulip.futures.tulip_log') + @unittest.mock.patch('tulip.futures.logger') def test_tb_logger_exception_result_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 009b0697..27b83a3f 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -135,7 +135,7 @@ def test_loop_writing(self): self.loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) - @unittest.mock.patch('tulip.proactor_events.tulip_log') + @unittest.mock.patch('tulip.proactor_events.logger') def test_loop_writing_err(self, m_log): err = self.loop._proactor.send.side_effect = OSError() tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -207,7 +207,7 @@ def test_close_buffer(self): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) - @unittest.mock.patch('tulip.proactor_events.tulip_log') + @unittest.mock.patch('tulip.proactor_events.logger') def test_fatal_error(self, m_logging): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._force_close = unittest.mock.Mock() @@ -432,7 +432,7 @@ def test_write_to_self(self): def test_process_events(self): self.loop._process_events([]) - @unittest.mock.patch('tulip.proactor_events.tulip_log') + @unittest.mock.patch('tulip.proactor_events.logger') def test_create_server(self, m_log): pf = unittest.mock.Mock() call_soon = self.loop.call_soon = unittest.mock.Mock() diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 894ddbe6..9f0b117e 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -626,7 +626,7 @@ def test_force_close(self): self.assertFalse(self.loop.readers) self.assertEqual(1, self.loop.remove_reader_count[7]) - @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('tulip.log.logger.exception') def test_fatal_error(self, m_exc): exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) @@ -823,7 +823,7 @@ def test_write_tryagain(self): self.loop.assert_writer(7, transport._write_ready) self.assertEqual(collections.deque([b'data']), transport._buffer) - @unittest.mock.patch('tulip.selector_events.tulip_log') + @unittest.mock.patch('tulip.selector_events.logger') def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() @@ -937,7 +937,7 @@ def test_write_ready_exception(self): transport._write_ready() transport._fatal_error.assert_called_with(err) - @unittest.mock.patch('tulip.selector_events.tulip_log') + @unittest.mock.patch('tulip.selector_events.logger') def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() remove_writer = self.loop.remove_writer = unittest.mock.Mock() @@ -1072,7 +1072,7 @@ def test_write_closing(self): transport.write(b'data') self.assertEqual(transport._conn_lost, 2) - @unittest.mock.patch('tulip.selector_events.tulip_log') + @unittest.mock.patch('tulip.selector_events.logger') def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 @@ -1325,7 +1325,7 @@ def test_sendto_tryagain(self): self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) - @unittest.mock.patch('tulip.selector_events.tulip_log') + @unittest.mock.patch('tulip.selector_events.logger') def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = OSError() @@ -1475,7 +1475,7 @@ def test_sendto_ready_connection_refused_connection(self): self.assertTrue(transport._fatal_error.called) - @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('tulip.log.logger.exception') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index c40455aa..a5f7de45 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -87,7 +87,7 @@ class Err(OSError): signal.SIGINT, lambda: True) @unittest.mock.patch('tulip.unix_events.signal') - @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('tulip.unix_events.logger') def test_add_signal_handler_install_error2(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG @@ -104,7 +104,7 @@ class Err(OSError): self.assertEqual(1, m_signal.set_wakeup_fd.call_count) @unittest.mock.patch('tulip.unix_events.signal') - @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('tulip.unix_events.logger') def test_add_signal_handler_install_error3(self, m_logging, m_signal): class Err(OSError): errno = errno.EINVAL @@ -149,7 +149,7 @@ def test_remove_signal_handler_2(self, m_signal): m_signal.signal.call_args[0]) @unittest.mock.patch('tulip.unix_events.signal') - @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('tulip.unix_events.logger') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -270,7 +270,7 @@ def test__sig_chld_unknown_status(self, m_waitpid, self.assertFalse(m_WEXITSTATUS.called) self.assertFalse(m_WTERMSIG.called) - @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('tulip.unix_events.logger') @unittest.mock.patch('os.WTERMSIG') @unittest.mock.patch('os.WEXITSTATUS') @unittest.mock.patch('os.WIFSIGNALED') @@ -360,7 +360,7 @@ def test__read_ready_blocked(self, m_read): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) - @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('tulip.log.logger.exception') @unittest.mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport( @@ -550,7 +550,7 @@ def test_write_again(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('tulip.unix_events.tulip_log') + @unittest.mock.patch('tulip.unix_events.logger') @unittest.mock.patch('os.write') def test_write_err(self, m_write, m_log): tr = unix_events._UnixWritePipeTransport( @@ -648,7 +648,7 @@ def test__write_ready_empty(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('tulip.log.tulip_log.exception') + @unittest.mock.patch('tulip.log.logger.exception') @unittest.mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( diff --git a/tulip/base_events.py b/tulip/base_events.py index dcd97617..5f1bff71 100644 --- a/tulip/base_events.py +++ b/tulip/base_events.py @@ -27,7 +27,7 @@ from . import events from . import futures from . import tasks -from .log import tulip_log +from .log import logger __all__ = ['BaseEventLoop', 'Server'] @@ -580,7 +580,7 @@ def _run_once(self): level = logging.INFO else: level = logging.DEBUG - tulip_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) self._process_events(event_list) # Handle 'later' callbacks that are ready. diff --git a/tulip/events.py b/tulip/events.py index a3fdd380..6ca5668c 100644 --- a/tulip/events.py +++ b/tulip/events.py @@ -12,7 +12,7 @@ import threading import socket -from .log import tulip_log +from .log import logger class Handle: @@ -36,8 +36,8 @@ def _run(self): try: self._callback(*self._args) except Exception: - tulip_log.exception('Exception in callback %s %r', - self._callback, self._args) + logger.exception('Exception in callback %s %r', + self._callback, self._args) self = None # Needed to break cycles when an exception occurs. diff --git a/tulip/futures.py b/tulip/futures.py index 706e8c8a..db278386 100644 --- a/tulip/futures.py +++ b/tulip/futures.py @@ -10,7 +10,7 @@ import traceback from . import events -from .log import tulip_log +from .log import logger # States for Future. _PENDING = 'PENDING' @@ -99,8 +99,8 @@ def clear(self): def __del__(self): if self.tb: - tulip_log.error('Future/Task exception was never retrieved:\n%s', - ''.join(self.tb)) + logger.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) class Future: diff --git a/tulip/log.py b/tulip/log.py index b918fe54..5f3534c7 100644 --- a/tulip/log.py +++ b/tulip/log.py @@ -3,4 +3,5 @@ import logging -tulip_log = logging.getLogger("tulip") +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 24080cea..7ce9bbb0 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -10,7 +10,7 @@ from . import constants from . import futures from . import transports -from .log import tulip_log +from .log import logger class _ProactorBasePipeTransport(transports.BaseTransport): @@ -50,7 +50,7 @@ def close(self): self._read_fut.cancel() def _fatal_error(self, exc): - tulip_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): @@ -164,7 +164,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return self._buffer.append(data) @@ -240,7 +240,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): def __init__(self, proactor): super().__init__() - tulip_log.debug('Using proactor: %s', proactor.__class__.__name__) + logger.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias proactor.set_loop(self) @@ -329,7 +329,7 @@ def loop(f=None): f = self._proactor.accept(sock) except OSError: if sock.fileno() != -1: - tulip_log.exception('Accept failed') + logger.exception('Accept failed') sock.close() except futures.CancelledError: sock.close() diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 98e0a948..e8ae8854 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -17,7 +17,7 @@ from . import futures from . import selectors from . import transports -from .log import tulip_log +from .log import logger class BaseSelectorEventLoop(base_events.BaseEventLoop): @@ -31,7 +31,7 @@ def __init__(self, selector=None): if selector is None: selector = selectors.DefaultSelector() - tulip_log.debug('Using selector: %s', selector.__class__.__name__) + logger.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._make_self_pipe() @@ -105,7 +105,7 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, sock.close() # There's nowhere to send the error, so just log it. # TODO: Someone will want an error handler for this. - tulip_log.exception('Accept failed') + logger.exception('Accept failed') else: if ssl: self._make_ssl_transport( @@ -363,7 +363,7 @@ def close(self): def _fatal_error(self, exc): # should be called from exception handler only - tulip_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): @@ -444,7 +444,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -667,7 +667,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -714,7 +714,7 @@ def sendto(self, data, addr=None): if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return diff --git a/tulip/tasks.py b/tulip/tasks.py index c998a7ed..7aba698c 100644 --- a/tulip/tasks.py +++ b/tulip/tasks.py @@ -16,7 +16,7 @@ from . import events from . import futures -from .log import tulip_log +from .log import logger # If you set _DEBUG to true, @coroutine will wrap the resulting # generator objects in a CoroWrapper instance (defined below). That @@ -62,8 +62,8 @@ def __del__(self): code = func.__code__ filename = code.co_filename lineno = code.co_firstlineno - tulip_log.error('Coroutine %r defined at %s:%s was never yielded from', - func.__name__, filename, lineno) + logger.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) def coroutine(func): diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 1808256e..34b2aea0 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -18,7 +18,7 @@ from . import selector_events from . import tasks from . import transports -from .log import tulip_log +from .log import logger __all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] @@ -79,7 +79,7 @@ def add_signal_handler(self, sig, callback, *args): try: signal.set_wakeup_fd(-1) except ValueError as nexc: - tulip_log.info('set_wakeup_fd(-1) failed: %s', nexc) + logger.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: raise RuntimeError('sig {} cannot be caught'.format(sig)) @@ -124,7 +124,7 @@ def remove_signal_handler(self, sig): try: signal.set_wakeup_fd(-1) except ValueError as exc: - tulip_log.info('set_wakeup_fd(-1) failed: %s', exc) + logger.info('set_wakeup_fd(-1) failed: %s', exc) return True @@ -185,7 +185,7 @@ def _sig_chld(self): if transp is not None: transp._process_exited(returncode) except Exception: - tulip_log.exception('Unknown exception in SIGCHLD handler') + logger.exception('Unknown exception in SIGCHLD handler') def _subprocess_closed(self, transport): pid = transport.get_pid() @@ -244,7 +244,7 @@ def close(self): def _fatal_error(self, exc): # should be called by exception handler only - tulip_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc): @@ -294,8 +294,8 @@ def write(self, data): if self._conn_lost or self._closing: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - tulip_log.warning('pipe closed by peer or ' - 'os.write(pipe, data) raised exception.') + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') self._conn_lost += 1 return @@ -369,7 +369,7 @@ def abort(self): def _fatal_error(self, exc): # should be called by exception handler only - tulip_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc=None): diff --git a/tulip/windows_events.py b/tulip/windows_events.py index f776b16c..ad53a8e7 100644 --- a/tulip/windows_events.py +++ b/tulip/windows_events.py @@ -12,7 +12,7 @@ from . import tasks from . import windows_utils from . import _overlapped -from .log import tulip_log +from .log import logger __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] @@ -135,7 +135,7 @@ def loop(f=None): f = self._proactor.accept_pipe(pipe) except OSError: if pipe and pipe.fileno() != -1: - tulip_log.exception('Pipe accept failed') + logger.exception('Pipe accept failed') pipe.close() except futures.CancelledError: if pipe: @@ -363,7 +363,7 @@ def close(self): while self._cache: if not self._poll(1): - tulip_log.debug('taking long time to close proactor') + logger.debug('taking long time to close proactor') self._results = [] if self._iocp is not None: From 4408d676ec85408e4a0b10e199eb03c32d4bcd0b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 17 Oct 2013 15:37:47 -0700 Subject: [PATCH 0714/1502] Merge --- asyncio/base_events.py | 4 +-- asyncio/events.py | 6 ++-- asyncio/futures.py | 6 ++-- asyncio/log.py | 3 +- asyncio/proactor_events.py | 10 +++---- asyncio/selector_events.py | 14 ++++----- asyncio/tasks.py | 6 ++-- asyncio/unix_events.py | 16 +++++------ asyncio/windows_events.py | 6 ++-- examples/sink.py | 29 +++++++++++++++---- examples/source.py | 54 ++++++++++++++++++++++++++--------- tests/test_base_events.py | 4 +-- tests/test_events.py | 2 +- tests/test_futures.py | 12 ++++---- tests/test_proactor_events.py | 6 ++-- tests/test_selector_events.py | 12 ++++---- tests/test_unix_events.py | 14 ++++----- 17 files changed, 126 insertions(+), 78 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 32457ebe..5f1bff71 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -27,7 +27,7 @@ from . import events from . import futures from . import tasks -from .log import asyncio_log +from .log import logger __all__ = ['BaseEventLoop', 'Server'] @@ -580,7 +580,7 @@ def _run_once(self): level = logging.INFO else: level = logging.DEBUG - asyncio_log.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) self._process_events(event_list) # Handle 'later' callbacks that are ready. diff --git a/asyncio/events.py b/asyncio/events.py index 9724615b..6ca5668c 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -12,7 +12,7 @@ import threading import socket -from .log import asyncio_log +from .log import logger class Handle: @@ -36,8 +36,8 @@ def _run(self): try: self._callback(*self._args) except Exception: - asyncio_log.exception('Exception in callback %s %r', - self._callback, self._args) + logger.exception('Exception in callback %s %r', + self._callback, self._args) self = None # Needed to break cycles when an exception occurs. diff --git a/asyncio/futures.py b/asyncio/futures.py index 99a043b4..db278386 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -10,7 +10,7 @@ import traceback from . import events -from .log import asyncio_log +from .log import logger # States for Future. _PENDING = 'PENDING' @@ -99,8 +99,8 @@ def clear(self): def __del__(self): if self.tb: - asyncio_log.error('Future/Task exception was never retrieved:\n%s', - ''.join(self.tb)) + logger.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) class Future: diff --git a/asyncio/log.py b/asyncio/log.py index 54dc784e..23a7074a 100644 --- a/asyncio/log.py +++ b/asyncio/log.py @@ -3,4 +3,5 @@ import logging -asyncio_log = logging.getLogger("asyncio") +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 348de033..c1347b7d 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -10,7 +10,7 @@ from . import constants from . import futures from . import transports -from .log import asyncio_log +from .log import logger class _ProactorBasePipeTransport(transports.BaseTransport): @@ -50,7 +50,7 @@ def close(self): self._read_fut.cancel() def _fatal_error(self, exc): - asyncio_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): @@ -164,7 +164,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return self._buffer.append(data) @@ -246,7 +246,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): def __init__(self, proactor): super().__init__() - asyncio_log.debug('Using proactor: %s', proactor.__class__.__name__) + logger.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias proactor.set_loop(self) @@ -335,7 +335,7 @@ def loop(f=None): f = self._proactor.accept(sock) except OSError: if sock.fileno() != -1: - asyncio_log.exception('Accept failed') + logger.exception('Accept failed') sock.close() except futures.CancelledError: sock.close() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index bae9a493..e8ae8854 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -17,7 +17,7 @@ from . import futures from . import selectors from . import transports -from .log import asyncio_log +from .log import logger class BaseSelectorEventLoop(base_events.BaseEventLoop): @@ -31,7 +31,7 @@ def __init__(self, selector=None): if selector is None: selector = selectors.DefaultSelector() - asyncio_log.debug('Using selector: %s', selector.__class__.__name__) + logger.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector self._make_self_pipe() @@ -105,7 +105,7 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, sock.close() # There's nowhere to send the error, so just log it. # TODO: Someone will want an error handler for this. - asyncio_log.exception('Accept failed') + logger.exception('Accept failed') else: if ssl: self._make_ssl_transport( @@ -363,7 +363,7 @@ def close(self): def _fatal_error(self, exc): # should be called from exception handler only - asyncio_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): @@ -444,7 +444,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -667,7 +667,7 @@ def write(self, data): if self._conn_lost: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return @@ -714,7 +714,7 @@ def sendto(self, data, addr=None): if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('socket.send() raised exception.') + logger.warning('socket.send() raised exception.') self._conn_lost += 1 return diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 2c8579fa..63850178 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -16,7 +16,7 @@ from . import events from . import futures -from .log import asyncio_log +from .log import logger # If you set _DEBUG to true, @coroutine will wrap the resulting # generator objects in a CoroWrapper instance (defined below). That @@ -62,8 +62,8 @@ def __del__(self): code = func.__code__ filename = code.co_filename lineno = code.co_firstlineno - asyncio_log.error('Coroutine %r defined at %s:%s was never yielded from', - func.__name__, filename, lineno) + logger.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) def coroutine(func): diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index a3a8e112..34b2aea0 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -18,7 +18,7 @@ from . import selector_events from . import tasks from . import transports -from .log import asyncio_log +from .log import logger __all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] @@ -79,7 +79,7 @@ def add_signal_handler(self, sig, callback, *args): try: signal.set_wakeup_fd(-1) except ValueError as nexc: - asyncio_log.info('set_wakeup_fd(-1) failed: %s', nexc) + logger.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: raise RuntimeError('sig {} cannot be caught'.format(sig)) @@ -124,7 +124,7 @@ def remove_signal_handler(self, sig): try: signal.set_wakeup_fd(-1) except ValueError as exc: - asyncio_log.info('set_wakeup_fd(-1) failed: %s', exc) + logger.info('set_wakeup_fd(-1) failed: %s', exc) return True @@ -185,7 +185,7 @@ def _sig_chld(self): if transp is not None: transp._process_exited(returncode) except Exception: - asyncio_log.exception('Unknown exception in SIGCHLD handler') + logger.exception('Unknown exception in SIGCHLD handler') def _subprocess_closed(self, transport): pid = transport.get_pid() @@ -244,7 +244,7 @@ def close(self): def _fatal_error(self, exc): # should be called by exception handler only - asyncio_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc): @@ -294,8 +294,8 @@ def write(self, data): if self._conn_lost or self._closing: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - asyncio_log.warning('pipe closed by peer or ' - 'os.write(pipe, data) raised exception.') + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') self._conn_lost += 1 return @@ -369,7 +369,7 @@ def abort(self): def _fatal_error(self, exc): # should be called by exception handler only - asyncio_log.exception('Fatal error for %s', self) + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc=None): diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 1d0ad26b..bbeada87 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -11,7 +11,7 @@ from . import selector_events from . import tasks from . import windows_utils -from .log import asyncio_log +from .log import logger try: import _overlapped @@ -139,7 +139,7 @@ def loop(f=None): f = self._proactor.accept_pipe(pipe) except OSError: if pipe and pipe.fileno() != -1: - asyncio_log.exception('Pipe accept failed') + logger.exception('Pipe accept failed') pipe.close() except futures.CancelledError: if pipe: @@ -367,7 +367,7 @@ def close(self): while self._cache: if not self._poll(1): - asyncio_log.debug('taking long time to close proactor') + logger.debug('taking long time to close proactor') self._results = [] if self._iocp is not None: diff --git a/examples/sink.py b/examples/sink.py index bb29be24..eab6ffa4 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -1,14 +1,28 @@ """Test service that accepts connections and reads all data off them.""" +import argparse import sys from asyncio import * +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') + server = None + def dprint(*args): print('sink:', *args, file=sys.stderr) + class Service(Protocol): def connection_made(self, tr): @@ -32,21 +46,26 @@ def data_received(self, data): def connection_lost(self, how): dprint('closed', repr(how)) + @coroutine -def start(loop): +def start(loop, host, port): global server - server = yield from loop.create_server(Service, 'localhost', 1111) + server = yield from loop.create_server(Service, host, port) dprint('serving', [s.getsockname() for s in server.sockets]) yield from server.wait_closed() + def main(): - if '--iocp' in sys.argv: + args = ARGS.parse_args() + if args.iocp: from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) - loop = get_event_loop() - loop.run_until_complete(start(loop)) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) loop.close() + if __name__ == '__main__': main() diff --git a/examples/source.py b/examples/source.py index 6bfdcb0f..adaeeb35 100644 --- a/examples/source.py +++ b/examples/source.py @@ -1,15 +1,36 @@ """Test client that connects and sends infinite data.""" +import argparse import sys from asyncio import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + +args = None + + def dprint(*args): print('source:', *args, file=sys.stderr) -class Client(Protocol): - data = b'x'*16*1024 +class Client(Protocol): def connection_made(self, tr): dprint('connecting to', tr.get_extra_info('peername')) @@ -18,41 +39,48 @@ def connection_made(self, tr): self.lost = False self.loop = get_event_loop() self.waiter = Future() - if '--stop' in sys.argv[1:]: + if args.stop: self.tr.write(b'stop') self.tr.close() else: - self.write_some_data() + data = b'x' * args.size + self.write_some_data(data) - def write_some_data(self): + def write_some_data(self, data): if self.lost: dprint('lost already') return - dprint('writing', len(self.data), 'bytes') - self.tr.write(self.data) - self.loop.call_soon(self.write_some_data) + dprint('writing', len(data), 'bytes') + self.tr.write(data) + self.loop.call_soon(self.write_some_data, data) def connection_lost(self, exc): dprint('lost connection', repr(exc)) self.lost = True self.waiter.set_result(None) + @coroutine -def start(loop): - tr, pr = yield from loop.create_connection(Client, '127.0.0.1', 1111) +def start(loop, host, port): + tr, pr = yield from loop.create_connection(Client, host, port) dprint('tr =', tr) dprint('pr =', pr) res = yield from pr.waiter return res + def main(): - if '--iocp' in sys.argv: + global args + args = ARGS.parse_args() + if args.iocp: from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) - loop = get_event_loop() - loop.run_until_complete(start(loop)) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) loop.close() + if __name__ == '__main__': main() diff --git a/tests/test_base_events.py b/tests/test_base_events.py index d48d12cd..e62f4764 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -183,7 +183,7 @@ def test__run_once(self): self.assertTrue(self.loop._process_events.called) @unittest.mock.patch('asyncio.base_events.time') - @unittest.mock.patch('asyncio.base_events.asyncio_log') + @unittest.mock.patch('asyncio.base_events.logger') def test__run_once_logging(self, m_logging, m_time): # Log to INFO level if timeout > 1.0 sec. idx = -1 @@ -579,7 +579,7 @@ def test_accept_connection_retry(self): self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - @unittest.mock.patch('asyncio.selector_events.asyncio_log') + @unittest.mock.patch('asyncio.selector_events.logger') def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.fileno.return_value = 10 diff --git a/tests/test_events.py b/tests/test_events.py index 243f4001..a9a92712 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1320,7 +1320,7 @@ def callback(*args): self.assertRaises( AssertionError, events.make_handle, h1, ()) - @unittest.mock.patch('asyncio.events.asyncio_log') + @unittest.mock.patch('asyncio.events.logger') def test_callback_with_exception(self, log): def callback(): raise ValueError() diff --git a/tests/test_futures.py b/tests/test_futures.py index 9b5108c4..ccea2ffd 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -170,20 +170,20 @@ def test(): self.assertRaises(AssertionError, test) fut.cancel() - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_abandoned(self, m_log): fut = futures.Future(loop=self.loop) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_result_unretrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_result(42) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_result_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_result(42) @@ -191,7 +191,7 @@ def test_tb_logger_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_exception_unretrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -199,7 +199,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_exception_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -207,7 +207,7 @@ def test_tb_logger_exception_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.asyncio_log') + @unittest.mock.patch('asyncio.futures.logger') def test_tb_logger_exception_result_retrieved(self, m_log): fut = futures.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index c52ade05..e4dd609c 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -135,7 +135,7 @@ def test_loop_writing(self): self.loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) - @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + @unittest.mock.patch('asyncio.proactor_events.logger') def test_loop_writing_err(self, m_log): err = self.loop._proactor.send.side_effect = OSError() tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -207,7 +207,7 @@ def test_close_buffer(self): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) - @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + @unittest.mock.patch('asyncio.proactor_events.logger') def test_fatal_error(self, m_logging): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._force_close = unittest.mock.Mock() @@ -432,7 +432,7 @@ def test_write_to_self(self): def test_process_events(self): self.loop._process_events([]) - @unittest.mock.patch('asyncio.proactor_events.asyncio_log') + @unittest.mock.patch('asyncio.proactor_events.logger') def test_create_server(self, m_log): pf = unittest.mock.Mock() call_soon = self.loop.call_soon = unittest.mock.Mock() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 0225e132..1465cd24 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -626,7 +626,7 @@ def test_force_close(self): self.assertFalse(self.loop.readers) self.assertEqual(1, self.loop.remove_reader_count[7]) - @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('asyncio.log.logger.exception') def test_fatal_error(self, m_exc): exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) @@ -823,7 +823,7 @@ def test_write_tryagain(self): self.loop.assert_writer(7, transport._write_ready) self.assertEqual(collections.deque([b'data']), transport._buffer) - @unittest.mock.patch('asyncio.selector_events.asyncio_log') + @unittest.mock.patch('asyncio.selector_events.logger') def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() @@ -937,7 +937,7 @@ def test_write_ready_exception(self): transport._write_ready() transport._fatal_error.assert_called_with(err) - @unittest.mock.patch('asyncio.selector_events.asyncio_log') + @unittest.mock.patch('asyncio.selector_events.logger') def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() remove_writer = self.loop.remove_writer = unittest.mock.Mock() @@ -1072,7 +1072,7 @@ def test_write_closing(self): transport.write(b'data') self.assertEqual(transport._conn_lost, 2) - @unittest.mock.patch('asyncio.selector_events.asyncio_log') + @unittest.mock.patch('asyncio.selector_events.logger') def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 @@ -1325,7 +1325,7 @@ def test_sendto_tryagain(self): self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) - @unittest.mock.patch('asyncio.selector_events.asyncio_log') + @unittest.mock.patch('asyncio.selector_events.logger') def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = OSError() @@ -1475,7 +1475,7 @@ def test_sendto_ready_connection_refused_connection(self): self.assertTrue(transport._fatal_error.called) - @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('asyncio.log.logger.exception') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 6dbd47f6..227366d9 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -87,7 +87,7 @@ class Err(OSError): signal.SIGINT, lambda: True) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('asyncio.unix_events.logger') def test_add_signal_handler_install_error2(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG @@ -104,7 +104,7 @@ class Err(OSError): self.assertEqual(1, m_signal.set_wakeup_fd.call_count) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('asyncio.unix_events.logger') def test_add_signal_handler_install_error3(self, m_logging, m_signal): class Err(OSError): errno = errno.EINVAL @@ -149,7 +149,7 @@ def test_remove_signal_handler_2(self, m_signal): m_signal.signal.call_args[0]) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('asyncio.unix_events.logger') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -270,7 +270,7 @@ def test__sig_chld_unknown_status(self, m_waitpid, self.assertFalse(m_WEXITSTATUS.called) self.assertFalse(m_WTERMSIG.called) - @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('asyncio.unix_events.logger') @unittest.mock.patch('os.WTERMSIG') @unittest.mock.patch('os.WEXITSTATUS') @unittest.mock.patch('os.WIFSIGNALED') @@ -360,7 +360,7 @@ def test__read_ready_blocked(self, m_read): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) - @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('asyncio.log.logger.exception') @unittest.mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport( @@ -550,7 +550,7 @@ def test_write_again(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('asyncio.unix_events.asyncio_log') + @unittest.mock.patch('asyncio.unix_events.logger') @unittest.mock.patch('os.write') def test_write_err(self, m_write, m_log): tr = unix_events._UnixWritePipeTransport( @@ -648,7 +648,7 @@ def test__write_ready_empty(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('asyncio.log.asyncio_log.exception') + @unittest.mock.patch('asyncio.log.logger.exception') @unittest.mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( From 03e6f5f9033996460daa49be2f934e833c456e9b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 07:42:58 -0700 Subject: [PATCH 0715/1502] Rename Transport.pause/resume to pause_reading/resume_reading. --- tests/proactor_events_test.py | 6 +++--- tests/selector_events_test.py | 12 ++++++------ tests/transports_test.py | 4 ++-- tests/unix_events_test.py | 8 ++++---- tulip/proactor_events.py | 6 +++--- tulip/selector_events.py | 14 +++++++------- tulip/streams.py | 4 ++-- tulip/transports.py | 6 +++--- tulip/unix_events.py | 4 ++-- 9 files changed, 32 insertions(+), 32 deletions(-) diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py index 27b83a3f..c1240838 100644 --- a/tests/proactor_events_test.py +++ b/tests/proactor_events_test.py @@ -308,7 +308,7 @@ def test_write_eof_duplex_pipe(self): tr.write_eof() tr.close() - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol) futures = [] @@ -323,12 +323,12 @@ def test_pause_resume(self): self.protocol.data_received.assert_called_with(b'data1') self.loop._run_once() self.protocol.data_received.assert_called_with(b'data2') - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) for i in range(10): self.loop._run_once() self.protocol.data_received.assert_called_with(b'data2') - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop._run_once() self.protocol.data_received.assert_called_with(b'data3') diff --git a/tests/selector_events_test.py b/tests/selector_events_test.py index 9f0b117e..25521c0e 100644 --- a/tests/selector_events_test.py +++ b/tests/selector_events_test.py @@ -676,15 +676,15 @@ def test_ctor_with_waiter(self): test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = _SelectorSocketTransport( self.loop, self.sock, self.protocol) self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(7 in self.loop.readers) - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) @@ -1044,14 +1044,14 @@ def test_on_handshake_base_exc(self): self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = self._make_one() self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._on_ready) - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(1 in self.loop.readers) - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._on_ready) diff --git a/tests/transports_test.py b/tests/transports_test.py index 304ec206..38492931 100644 --- a/tests/transports_test.py +++ b/tests/transports_test.py @@ -33,8 +33,8 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, transport.write, 'data') self.assertRaises(NotImplementedError, transport.write_eof) self.assertRaises(NotImplementedError, transport.can_write_eof) - self.assertRaises(NotImplementedError, transport.pause) - self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) self.assertRaises(NotImplementedError, transport.close) self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py index a5f7de45..5659d31f 100644 --- a/tests/unix_events_test.py +++ b/tests/unix_events_test.py @@ -375,21 +375,21 @@ def test__read_ready_error(self, m_read, m_logexc): m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.read') - def test_pause(self, m_read): + def test_pause_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) m = unittest.mock.Mock() self.loop.add_reader(5, m) - tr.pause() + tr.pause_reading() self.assertFalse(self.loop.readers) @unittest.mock.patch('os.read') - def test_resume(self, m_read): + def test_resume_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - tr.resume() + tr.resume_reading() self.loop.assert_reader(5, tr._read_ready) @unittest.mock.patch('os.read') diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py index 7ce9bbb0..2a2701a8 100644 --- a/tulip/proactor_events.py +++ b/tulip/proactor_events.py @@ -94,12 +94,12 @@ def __init__(self, loop, sock, protocol, waiter=None, self._paused = False self._loop.call_soon(self._loop_reading) - def pause(self): - assert not self._closing, 'Cannot pause() when closing' + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: diff --git a/tulip/selector_events.py b/tulip/selector_events.py index e8ae8854..2edac65b 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -406,13 +406,13 @@ def __init__(self, loop, sock, protocol, waiter=None, if waiter is not None: self._loop.call_soon(waiter.set_result, None) - def pause(self): - assert not self._closing, 'Cannot pause() when closing' + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True self._loop.remove_reader(self._sock_fd) - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: @@ -590,19 +590,19 @@ def _on_handshake(self): if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) - def pause(self): + def pause_reading(self): # XXX This is a bit icky, given the comment at the top of # _on_ready(). Is it possible to evoke a deadlock? I don't # know, although it doesn't look like it; write() will still # accept more data for the buffer and eventually the app will - # call resume() again, and things will flow again. + # call resume_reading() again, and things will flow again. - assert not self._closing, 'Cannot pause() when closing' + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True self._loop.remove_reader(self._sock_fd) - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: diff --git a/tulip/streams.py b/tulip/streams.py index d0f12e81..9915aa5c 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -106,7 +106,7 @@ def set_transport(self, transport): def _maybe_resume_transport(self): if self._paused and self.byte_count <= self.limit: self._paused = False - self._transport.resume() + self._transport.resume_reading() def feed_eof(self): self.eof = True @@ -133,7 +133,7 @@ def feed_data(self, data): not self._paused and self.byte_count > 2*self.limit): try: - self._transport.pause() + self._transport.pause_reading() except NotImplementedError: # The transport can't be paused. # We'll just have to buffer all data. diff --git a/tulip/transports.py b/tulip/transports.py index bf3adee7..f1a71800 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -29,15 +29,15 @@ def close(self): class ReadTransport(BaseTransport): """ABC for read-only transports.""" - def pause(self): + def pause_reading(self): """Pause the receiving end. No data will be passed to the protocol's data_received() - method until resume() is called. + method until resume_reading() is called. """ raise NotImplementedError - def resume(self): + def resume_reading(self): """Resume the receiving end. Data received will once again be passed to the protocol's diff --git a/tulip/unix_events.py b/tulip/unix_events.py index 34b2aea0..a234f4fa 100644 --- a/tulip/unix_events.py +++ b/tulip/unix_events.py @@ -232,10 +232,10 @@ def _read_ready(self): self._loop.call_soon(self._protocol.eof_received) self._loop.call_soon(self._call_connection_lost, None) - def pause(self): + def pause_reading(self): self._loop.remove_reader(self._fileno) - def resume(self): + def resume_reading(self): self._loop.add_reader(self._fileno, self._read_ready) def close(self): From aaa348fe2d81b6ec6d6b23fada1da4e057752281 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 07:43:44 -0700 Subject: [PATCH 0716/1502] Merge --- asyncio/proactor_events.py | 6 +++--- asyncio/selector_events.py | 14 +++++++------- asyncio/streams.py | 4 ++-- asyncio/transports.py | 6 +++--- asyncio/unix_events.py | 4 ++-- tests/test_proactor_events.py | 6 +++--- tests/test_selector_events.py | 12 ++++++------ tests/test_transports.py | 4 ++-- tests/test_unix_events.py | 8 ++++---- 9 files changed, 32 insertions(+), 32 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index c1347b7d..665569f0 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -94,12 +94,12 @@ def __init__(self, loop, sock, protocol, waiter=None, self._paused = False self._loop.call_soon(self._loop_reading) - def pause(self): - assert not self._closing, 'Cannot pause() when closing' + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index e8ae8854..2edac65b 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -406,13 +406,13 @@ def __init__(self, loop, sock, protocol, waiter=None, if waiter is not None: self._loop.call_soon(waiter.set_result, None) - def pause(self): - assert not self._closing, 'Cannot pause() when closing' + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True self._loop.remove_reader(self._sock_fd) - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: @@ -590,19 +590,19 @@ def _on_handshake(self): if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) - def pause(self): + def pause_reading(self): # XXX This is a bit icky, given the comment at the top of # _on_ready(). Is it possible to evoke a deadlock? I don't # know, although it doesn't look like it; write() will still # accept more data for the buffer and eventually the app will - # call resume() again, and things will flow again. + # call resume_reading() again, and things will flow again. - assert not self._closing, 'Cannot pause() when closing' + assert not self._closing, 'Cannot pause_reading() when closing' assert not self._paused, 'Already paused' self._paused = True self._loop.remove_reader(self._sock_fd) - def resume(self): + def resume_reading(self): assert self._paused, 'Not paused' self._paused = False if self._closing: diff --git a/asyncio/streams.py b/asyncio/streams.py index d0f12e81..9915aa5c 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -106,7 +106,7 @@ def set_transport(self, transport): def _maybe_resume_transport(self): if self._paused and self.byte_count <= self.limit: self._paused = False - self._transport.resume() + self._transport.resume_reading() def feed_eof(self): self.eof = True @@ -133,7 +133,7 @@ def feed_data(self, data): not self._paused and self.byte_count > 2*self.limit): try: - self._transport.pause() + self._transport.pause_reading() except NotImplementedError: # The transport can't be paused. # We'll just have to buffer all data. diff --git a/asyncio/transports.py b/asyncio/transports.py index bf3adee7..f1a71800 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -29,15 +29,15 @@ def close(self): class ReadTransport(BaseTransport): """ABC for read-only transports.""" - def pause(self): + def pause_reading(self): """Pause the receiving end. No data will be passed to the protocol's data_received() - method until resume() is called. + method until resume_reading() is called. """ raise NotImplementedError - def resume(self): + def resume_reading(self): """Resume the receiving end. Data received will once again be passed to the protocol's diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 34b2aea0..a234f4fa 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -232,10 +232,10 @@ def _read_ready(self): self._loop.call_soon(self._protocol.eof_received) self._loop.call_soon(self._call_connection_lost, None) - def pause(self): + def pause_reading(self): self._loop.remove_reader(self._fileno) - def resume(self): + def resume_reading(self): self._loop.add_reader(self._fileno, self._read_ready) def close(self): diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index e4dd609c..05d1606c 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -308,7 +308,7 @@ def test_write_eof_duplex_pipe(self): tr.write_eof() tr.close() - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = _ProactorSocketTransport( self.loop, self.sock, self.protocol) futures = [] @@ -323,12 +323,12 @@ def test_pause_resume(self): self.protocol.data_received.assert_called_with(b'data1') self.loop._run_once() self.protocol.data_received.assert_called_with(b'data2') - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) for i in range(10): self.loop._run_once() self.protocol.data_received.assert_called_with(b'data2') - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop._run_once() self.protocol.data_received.assert_called_with(b'data3') diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 1465cd24..53728b8d 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -676,15 +676,15 @@ def test_ctor_with_waiter(self): test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = _SelectorSocketTransport( self.loop, self.sock, self.protocol) self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(7 in self.loop.readers) - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) @@ -1044,14 +1044,14 @@ def test_on_handshake_base_exc(self): self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) - def test_pause_resume(self): + def test_pause_resume_reading(self): tr = self._make_one() self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._on_ready) - tr.pause() + tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(1 in self.loop.readers) - tr.resume() + tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._on_ready) diff --git a/tests/test_transports.py b/tests/test_transports.py index fce2e6f5..53071afd 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -33,8 +33,8 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, transport.write, 'data') self.assertRaises(NotImplementedError, transport.write_eof) self.assertRaises(NotImplementedError, transport.can_write_eof) - self.assertRaises(NotImplementedError, transport.pause) - self.assertRaises(NotImplementedError, transport.resume) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) self.assertRaises(NotImplementedError, transport.close) self.assertRaises(NotImplementedError, transport.abort) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 227366d9..ccabeea1 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -375,21 +375,21 @@ def test__read_ready_error(self, m_read, m_logexc): m_logexc.assert_called_with('Fatal error for %s', tr) @unittest.mock.patch('os.read') - def test_pause(self, m_read): + def test_pause_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) m = unittest.mock.Mock() self.loop.add_reader(5, m) - tr.pause() + tr.pause_reading() self.assertFalse(self.loop.readers) @unittest.mock.patch('os.read') - def test_resume(self, m_read): + def test_resume_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - tr.resume() + tr.resume_reading() self.loop.assert_reader(5, tr._read_ready) @unittest.mock.patch('os.read') From 5bf22a4086bc00a947ba8aa4fa2036ea18511dfe Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 07:54:05 -0700 Subject: [PATCH 0717/1502] More lenient delay expectancy in test_call_later(). --- tests/events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/events_test.py b/tests/events_test.py index 67b2cfc6..f5f51e5b 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -238,7 +238,7 @@ def callback(arg): self.loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.2, t1-t0) def test_call_soon(self): results = [] From e7998924caa1d390e3d742ae6843ab7eaab176c6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 07:54:37 -0700 Subject: [PATCH 0718/1502] Merge --- tests/test_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index a9a92712..2e89f72c 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -238,7 +238,7 @@ def callback(arg): self.loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.2, t1-t0) def test_call_soon(self): results = [] From 4c995bd91f8eec3ec5fe7d8bc2b72c4fd1d3b9e7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 08:33:33 -0700 Subject: [PATCH 0719/1502] Report total bytes written. --- examples/source.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/source.py b/examples/source.py index c4842240..1515dd1f 100644 --- a/examples/source.py +++ b/examples/source.py @@ -32,6 +32,8 @@ def dprint(*args): class Client(Protocol): + total = 0 + def connection_made(self, tr): dprint('connecting to', tr.get_extra_info('peername')) dprint('my socket is', tr.get_extra_info('sockname')) @@ -50,7 +52,8 @@ def write_some_data(self, data): if self.lost: dprint('lost already') return - dprint('writing', len(data), 'bytes') + self.total += len(data) + dprint('writing', len(data), 'bytes; total', self.total) self.tr.write(data) self.loop.call_soon(self.write_some_data, data) @@ -65,8 +68,7 @@ def start(loop, host, port): tr, pr = yield from loop.create_connection(Client, host, port) dprint('tr =', tr) dprint('pr =', pr) - res = yield from pr.waiter - return res + yield from pr.waiter def main(): From de2c702e579167ccda2c4fb1b40ef37b864235c8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 10:01:25 -0700 Subject: [PATCH 0720/1502] Rationalize error handling, fixing a race condition. --- tulip/selector_events.py | 51 ++++++++++++++-------------------------- 1 file changed, 18 insertions(+), 33 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 2edac65b..084d9be7 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -344,7 +344,7 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._protocol = protocol self._server = server self._buffer = collections.deque() - self._conn_lost = 0 + self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. if server is not None: server.attach(self) @@ -356,27 +356,27 @@ def close(self): if self._closing: return self._closing = True - self._conn_lost += 1 self._loop.remove_reader(self._sock_fd) if not self._buffer: + self._conn_lost += 1 self._loop.call_soon(self._call_connection_lost, None) def _fatal_error(self, exc): - # should be called from exception handler only - logger.exception('Fatal error for %s', self) + # Should be called from exception handler only. + if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): + if self._conn_lost: + return if self._buffer: self._buffer.clear() self._loop.remove_writer(self._sock_fd) - - if self._closing: - return - - self._closing = True + if not self._closing: + self._closing = True + self._loop.remove_reader(self._sock_fd) self._conn_lost += 1 - self._loop.remove_reader(self._sock_fd) self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): @@ -424,8 +424,6 @@ def _read_ready(self): data = self._sock.recv(self.max_size) except (BlockingIOError, InterruptedError): pass - except ConnectionResetError as exc: - self._force_close(exc) except Exception as exc: self._fatal_error(exc) else: @@ -453,17 +451,15 @@ def write(self, data): try: n = self._sock.send(data) except (BlockingIOError, InterruptedError): - n = 0 - except (BrokenPipeError, ConnectionResetError) as exc: - self._force_close(exc) - return - except OSError as exc: + pass + except Exception as exc: self._fatal_error(exc) return else: data = data[n:] if not data: return + # Start async I/O. self._loop.add_writer(self._sock_fd, self._write_ready) @@ -478,9 +474,6 @@ def _write_ready(self): n = self._sock.send(data) except (BlockingIOError, InterruptedError): self._buffer.append(data) - except (BrokenPipeError, ConnectionResetError) as exc: - self._loop.remove_writer(self._sock_fd) - self._force_close(exc) except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -493,7 +486,6 @@ def _write_ready(self): elif self._eof: self._sock.shutdown(socket.SHUT_WR) return - self._buffer.append(data) # Try again later. def write_eof(self): @@ -622,8 +614,6 @@ def _on_ready(self): except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): pass - except ConnectionResetError as exc: - self._force_close(exc) except Exception as exc: self._fatal_error(exc) else: @@ -644,10 +634,6 @@ def _on_ready(self): except (BlockingIOError, InterruptedError, ssl.SSLWantReadError, ssl.SSLWantWriteError): n = 0 - except (BrokenPipeError, ConnectionResetError) as exc: - self._loop.remove_writer(self._sock_fd) - self._force_close(exc) - return except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -726,12 +712,12 @@ def sendto(self, data, addr=None): else: self._sock.sendto(data, addr) return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) except ConnectionRefusedError as exc: if self._address: self._fatal_error(exc) return - except (BlockingIOError, InterruptedError): - self._loop.add_writer(self._sock_fd, self._sendto_ready) except Exception as exc: self._fatal_error(exc) return @@ -746,13 +732,13 @@ def _sendto_ready(self): self._sock.send(data) else: self._sock.sendto(data, addr) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break except ConnectionRefusedError as exc: if self._address: self._fatal_error(exc) return - except (BlockingIOError, InterruptedError): - self._buffer.appendleft((data, addr)) # Try again later. - break except Exception as exc: self._fatal_error(exc) return @@ -765,5 +751,4 @@ def _sendto_ready(self): def _force_close(self, exc): if self._address and isinstance(exc, ConnectionRefusedError): self._protocol.connection_refused(exc) - super()._force_close(exc) From 3b1a0581febaa1769bb9ef4203e7fd0022489b99 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 11:23:07 -0700 Subject: [PATCH 0721/1502] Improved argument handling. --- .hgeol | 4 + .hgignore | 12 + Makefile | 34 + NOTES | 176 ++++ README | 21 + TODO | 163 ++++ check.py | 45 + examples/child_process.py | 127 +++ examples/fetch0.py | 32 + examples/fetch1.py | 75 ++ examples/fetch2.py | 138 +++ examples/fetch3.py | 211 +++++ examples/sink.py | 76 ++ examples/source.py | 90 ++ examples/stacks.py | 43 + examples/tcp_echo.py | 113 +++ examples/udp_echo.py | 98 ++ overlapped.c | 1202 +++++++++++++++++++++++++ runtests.py | 278 ++++++ setup.cfg | 2 + setup.py | 14 + tests/base_events_test.py | 590 +++++++++++++ tests/echo.py | 6 + tests/echo2.py | 6 + tests/echo3.py | 9 + tests/events_test.py | 1573 +++++++++++++++++++++++++++++++++ tests/futures_test.py | 329 +++++++ tests/locks_test.py | 765 ++++++++++++++++ tests/proactor_events_test.py | 480 ++++++++++ tests/queues_test.py | 470 ++++++++++ tests/sample.crt | 14 + tests/sample.key | 15 + tests/selector_events_test.py | 1485 +++++++++++++++++++++++++++++++ tests/selectors_test.py | 145 +++ tests/streams_test.py | 364 ++++++++ tests/tasks_test.py | 1518 +++++++++++++++++++++++++++++++ tests/transports_test.py | 55 ++ tests/unix_events_test.py | 770 ++++++++++++++++ tests/windows_events_test.py | 91 ++ tests/windows_utils_test.py | 136 +++ tulip/__init__.py | 26 + tulip/base_events.py | 606 +++++++++++++ tulip/constants.py | 4 + tulip/events.py | 395 +++++++++ tulip/futures.py | 338 +++++++ tulip/locks.py | 401 +++++++++ tulip/log.py | 7 + tulip/proactor_events.py | 346 ++++++++ tulip/protocols.py | 98 ++ tulip/queues.py | 284 ++++++ tulip/selector_events.py | 754 ++++++++++++++++ tulip/selectors.py | 405 +++++++++ tulip/streams.py | 257 ++++++ tulip/tasks.py | 636 +++++++++++++ tulip/test_utils.py | 239 +++++ tulip/transports.py | 186 ++++ tulip/unix_events.py | 541 ++++++++++++ tulip/windows_events.py | 371 ++++++++ tulip/windows_utils.py | 181 ++++ 59 files changed, 17850 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 Makefile create mode 100644 NOTES create mode 100644 README create mode 100644 TODO create mode 100644 check.py create mode 100644 examples/child_process.py create mode 100644 examples/fetch0.py create mode 100644 examples/fetch1.py create mode 100644 examples/fetch2.py create mode 100644 examples/fetch3.py create mode 100644 examples/sink.py create mode 100644 examples/source.py create mode 100644 examples/stacks.py create mode 100755 examples/tcp_echo.py create mode 100755 examples/udp_echo.py create mode 100644 overlapped.c create mode 100644 runtests.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/base_events_test.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/events_test.py create mode 100644 tests/futures_test.py create mode 100644 tests/locks_test.py create mode 100644 tests/proactor_events_test.py create mode 100644 tests/queues_test.py create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/selector_events_test.py create mode 100644 tests/selectors_test.py create mode 100644 tests/streams_test.py create mode 100644 tests/tasks_test.py create mode 100644 tests/transports_test.py create mode 100644 tests/unix_events_test.py create mode 100644 tests/windows_events_test.py create mode 100644 tests/windows_utils_test.py create mode 100644 tulip/__init__.py create mode 100644 tulip/base_events.py create mode 100644 tulip/constants.py create mode 100644 tulip/events.py create mode 100644 tulip/futures.py create mode 100644 tulip/locks.py create mode 100644 tulip/log.py create mode 100644 tulip/proactor_events.py create mode 100644 tulip/protocols.py create mode 100644 tulip/queues.py create mode 100644 tulip/selector_events.py create mode 100644 tulip/selectors.py create mode 100644 tulip/streams.py create mode 100644 tulip/tasks.py create mode 100644 tulip/test_utils.py create mode 100644 tulip/transports.py create mode 100644 tulip/unix_events.py create mode 100644 tulip/windows_events.py create mode 100644 tulip/windows_utils.py diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..99870025 --- /dev/null +++ b/.hgignore @@ -0,0 +1,12 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..ed3caf21 --- /dev/null +++ b/Makefile @@ -0,0 +1,34 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) + +check: + $(PYTHON) check.py + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -f .coverage + rm -rf htmlcov diff --git a/NOTES b/NOTES new file mode 100644 index 00000000..3b94ba96 --- /dev/null +++ b/NOTES @@ -0,0 +1,176 @@ +Notes from PyCon 2013 sprints +============================= + +- Cancellation. If a task creates several subtasks, and then the + parent task fails, should the subtasks be cancelled? (How do we + even establish the parent/subtask relationship?) + +- Adam Sah suggests that there might be a need for scheduling + (especially when multiple frameworks share an event loop). He + points to lottery scheduling but also mentions that's just one of + the options. However, after posting on python-tulip, it appears + none of the other frameworks have scheduling, and nobody seems to + miss it. + +- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't + think connected UDP is worth supporting, it doesn't do anything + except tell the kernel about the default target address for + sendto(). Basically he says all UDP end points are servers. He + sent me his own UDP event loop so I might glean some tricks from it. + He says we should treat EINTR the same as EAGAIN and friends. (We + should use the exceptions dedicated to errno checking, BTW.) HE + said to make sure we use SO_REUSEADDR (I think we already do). He + said to set the max datagram sizes pretty large (anything larger + than the declared limit is dropped on the floor). He reminds us of + the importance of being able to pick a valid, unused port by binding + to port 0 and then using getsockname(). He has an idea where he's + like to be able to kill all registered callbacks (i.e. Handles) + belonging to a certain "context". I think this can be done at the + application level (you'd have to wrap everything that returns a + Handle and collect these handles in some set or other datastructure) + but if someone thinks it's interesting we could imagine having some + kind of notion of context part of the event loop state, + e.g. associated with a Task (see Cancellation point above). He + brought up uTP (Micro Transport Protocol), a reimplementation of TCP + over UDP with more refined congestion control. + +- Mumblings about UNIX domain sockets and IPv6 addresses being + 4-tuples. The former can be handled by passing in a socket. There + seem to be no real use cases for the latter that can't be dealt with + by passing in suitably esoteric strings for the hostname. + getaddrinfo() will produce the appropriate 4-tuple and connect() + will accept it. + +- Mumblings on the list about add vs. set. + + +Notes from the second Tulip/Twisted meet-up +=========================================== + +Rackspace, 12/11/2012 +Glyph, Brian Warner, David Reid, Duncan McGreggor, others + +Flow control +------------ + +- Pause/resume on transport manages data_received. + +- There's also an API to tell the transport whom to pause when the + write calls are overwhelming it: IConsumer.registerProducer(). + +- There's also something called pipes but it's built on top of the + old interface. + +- Twisted has variations on the basic flow control that I should + ignore. + +Half_close +---------- + +- This sends an EOF after writing some stuff. + +- Can't write any more. + +- Problem with TLS is known (the RFC sadly specifies this behavior). + +- It must be dynamimcally discoverable whether the transport supports + half_close, since the protocol may have to do something different to + make up for its missing (e.g. use chunked encoding). Twisted uses + an interface check for this and also hasattr(trans, 'halfClose') + but a flag (or flag method) is fine too. + +Constructing transport and protocol +----------------------------------- + +- There are good reasons for passing a function to the transport + construction helper that creates the protocol. (You need these + anyway for server-side protocols.) The sequence of events is + something like + + . open socket + . create transport (pass it a socket?) + . create protocol (pass it nothing) + . proto.make_connection(transport); this does: + . self.transport = transport + . self.connection_made(transport) + + But it seems okay to skip make_connection and setting .transport. + Note that make_connection() is a concrete method on the Protocol + implementation base class, while connection_made() is an abstract + method on IProtocol. + +Event Loop +---------- + +- We discussed the sequence of actions in the event loop. I think in the + end we're fine with what Tulip currently does. There are two choices: + + Tulip: + . run ready callbacks until there aren't any left + . poll, adding more callbacks to the ready list + . add now-ready delayed callbacks to the ready list + . go to top + + Tornado: + . run all currently ready callbacks (but not new ones added during this) + . (the rest is the same) + + The difference is that in the Tulip version, CPU bound callbacks + that keep adding more to the queue will starve I/O (and yielding to + other tasks won't actually cause I/O to happen unless you do + e.g. sleep(0.001)). OTOH this may be good because it means there's + less overhead if you frequently split operations in two. + +- I think Twisted does it Tornado style (in a convoluted way :-), but + it may not matter, and it's important to leave this vague so + implementations can do what's best for their platform. (E.g. if the + event loop is built into the OS there are different trade-offs.) + +System call cost +---------------- + +- System calls on MacOS are expensive, on Linux they are cheap. + +- Optimal buffer size ~16K. + +- Try joining small buffer pieces together, but expect to be tuning + this later. + +Futures +------- + +- Futures are the most robust API for async stuff, you can check + errors etc. So let's do this. + +- Just don't implement wait(). + +- For the basics, however, (recv/send, mostly), don't use Futures but use + basic callbacks, transport/protocol style. + +- make_connection() (by any name) can return a Future, it makes it + easier to check for errors. + +- This means revisiting the Tulip proactor branch (IOCP). + +- The semantics of add_done_callback() are fuzzy about in which thread + the callback will be called. (It may be the current thread or + another one.) We don't like that. But always inserting a + call_soon() indirection may be expensive? Glyph suggested changing + the add_done_callback() method name to something else to indicate + the changed promise. + +- Separately, I've been thinking about having two versions of + call_soon() -- a more heavy-weight one to be called from other + threads that also writes a byte to the self-pipe. + +Signals +------- + +- There was a side conversation about signals. A signal handler is + similar to another thread, so probably should use (the heavy-weight + version of) call_soon() to schedule the real callback and not do + anything else. + +- Glyph vaguely recalled some trickiness with the self-pipe. We + should be able to fix this afterwards if necessary, it shouldn't + affect the API design. diff --git a/README b/README new file mode 100644 index 00000000..8f2b6373 --- /dev/null +++ b/README @@ -0,0 +1,21 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'tulip' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + + +--Guido van Rossum diff --git a/TODO b/TODO new file mode 100644 index 00000000..c6d4eead --- /dev/null +++ b/TODO @@ -0,0 +1,163 @@ +# -*- Mode: text -*- + +TO DO LARGER TASKS + +- Need more examples. + +- Benchmarkable but more realistic HTTP server? + +- Example of using UDP. + +- Write up a tutorial for the scheduling API. + +- More systematic approach to logging. Logger objects? What about + heavy-duty logging, tracking essentially all task state changes? + +- Restructure directory, move demos and benchmarks to subdirectories. + + +TO DO LATER + +- When multiple tasks are accessing the same socket, they should + either get interleaved I/O or an immediate exception; it should not + compromise the integrity of the scheduler or the app or leave a task + hanging. + +- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. + +- Add the simplest API possible to run a generator with a timeout. + +- Ensure multiple tasks can do atomic writes to the same pipe (since + UNIX guarantees that short writes to pipes are atomic). + +- Ensure some easy way of distributing accepted connections across tasks. + +- Be wary of thread-local storage. There should be a standard API to + get the current Context (which holds current task, event loop, and + maybe more) and a standard meta-API to change how that standard API + works (i.e. without monkey-patching). + +- See how much of asyncore I've already replaced. + +- Could BufferedReader reuse the standard io module's readers??? + +- Support ZeroMQ "sockets" which are user objects. Though possibly + this can be supported by getting the underlying fd? See + http://mail.python.org/pipermail/python-ideas/2012-October/017532.html + OTOH see + https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py + +- Study goroutines (again). + +- Benchmarks: http://nichol.as/benchmark-of-python-web-servers + + +FROM OLDER LIST + +- Multiple readers/writers per socket? (At which level? pollster, + eventloop, or scheduler?) + +- Could poll() usefully be an iterator? + +- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? + +- Optimize register/unregister calls away if they cancel each other out? + +- Add explicit wait queue to wait for Task's completion, instead of + callbacks? + +- Look at pyfdpdlib's ioloop.py: + http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py + + +MISTAKES I MADE + +- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) + +- Forgot to add bare yield at end of internal function, after block(). + +- Forgot to call add_done_callback(). + +- Forgot to pass an undoer to block(), bug only found when cancelled. + +- Subtle accounting mistake in a callback. + +- Used context.eventloop from a different thread, forgetting about TLS. + +- Nasty race: eventloop.ready may contain both an I/O callback and a + cancel callback. How to avoid? Keep the DelayedCall in ready. Is + that enough? + +- If a toplevel task raises an error it just stops and nothing is logged + unless you have debug logging on. This confused me. (Then again, + previously I logged whenever a task raised an error, and that was too + chatty...) + +- Forgot to set the connection socket returned by accept() in + nonblocking mode. + +- Nastiest so far (cost me about a day): A race condition in + call_in_thread() where the Future's done_callback (which was + task.unblock()) would run immediately at the time when + add_done_callback() was called, and this screwed over the task + state. Solution: wrap the callback in eventloop.call_later(). + Ironically, I had a comment stating there might be a race condition. + +- Another bug where I was calling unblock() for the current thread + immediately after calling block(), before yielding. + +- readexactly() wasn't checking for EOF, so could be looping. + (Worse, the first fix I attempted was wrong.) + +- Spent a day trying to understand why a tentative patch trying to + move the recv() implementation into the eventloop (or the pollster) + resulted in problems cancelling a recv() call. Ultimately the + problem is that the cancellation mechanism is part of the coroutine + scheduler, which simply throws an exception into a task when it next + runs, and there isn't anything to be interrupted in the eventloop; + but the eventloop still has a reader registered (which will never + fire because I suspended the server -- that's my test case :-). + Then, the eventloop keeps running until the last file descriptor is + unregistered. What contributed to this disaster? + * I didn't build the whole infrastructure, just played with recv() + * I don't have unittests + * I don't have good logging to see what is going + +- In sockets.py, in some SSL error handling code, used the wrong + variable (sock instead of sslsock). A linter would have found this. + +- In polling.py, in KqueuePollster.register_writer(), a copy/paste + error where I was testing for "if fd not in self.readers" instead of + writers. This only came out when I had both a reader and a writer + for the same fd. + +- Submitted some changes prematurely (forgot to pass the filename on + hg ci). + +- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work + as I expected. I ran into this with the origininal sockets.py and + again in transport.py. + +- Having the same callback for both reading and writing has a problem: + it may be scheduled twice, and if the first call closes the socket, + the second runs into trouble. + + +MISTAKES I MADE IN TULIP V2 + +- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. + Spot the bug in these four lines: + + def _schedule_callbacks(self): + callbacks = self._callbacks[:] + self._callbacks[:] = [] + for callback in self._callbacks: + self._event_loop.call_soon(callback, self) + + The good news is that I found it with a unittest (albeit not the + unittest intended to exercise this particular method :-( ). + +- In _make_self_pipe_or_sock(), called _pollster.register_reader() + instead of add_reader(), trying to optimize something but breaking + things instead (since the -- internal -- API of register_reader() + had changed). diff --git a/check.py b/check.py new file mode 100644 index 00000000..6db82d64 --- /dev/null +++ b/check.py @@ -0,0 +1,45 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import os +import sys + + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..5a88faa6 --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,127 @@ +""" +Example of asynchronous interaction with a child python process. + +Note that on Windows we must use the IOCP event loop. +""" + +import os +import sys + +try: + import tulip +except ImportError: + # tulip is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import tulip + +from tulip import streams +from tulip import protocols + +if sys.platform == 'win32': + from tulip.windows_utils import Popen, PIPE + from tulip.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@tulip.coroutine +def connect_write_pipe(file): + loop = tulip.get_event_loop() + protocol = protocols.Protocol() + transport, _ = yield from loop.connect_write_pipe(tulip.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@tulip.coroutine +def connect_read_pipe(file): + loop = tulip.get_event_loop() + stream_reader = streams.StreamReader(loop=loop) + def factory(): + return streams.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@tulip.coroutine +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {tulip.Task(stderr.readline()): stderr, + tulip.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from tulip.wait( + registered, timeout=timeout, return_when=tulip.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[tulip.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + tulip.set_event_loop(loop) + else: + loop = tulip.get_event_loop() + loop.run_until_complete(main(loop)) + loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py new file mode 100644 index 00000000..84edaa26 --- /dev/null +++ b/examples/fetch0.py @@ -0,0 +1,32 @@ +"""Simplest possible HTTP client.""" + +import sys + +from tulip import * + + +@coroutine +def fetch(): + r, w = yield from open_connection('python.org', 80) + request = 'GET / HTTP/1.0\r\n\r\n' + print('>', request, file=sys.stderr) + w.write(request.encode('latin-1')) + while True: + line = yield from r.readline() + line = line.decode('latin-1').rstrip() + if not line: + break + print('<', line, file=sys.stderr) + print(file=sys.stderr) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch()) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch1.py b/examples/fetch1.py new file mode 100644 index 00000000..57e66e6a --- /dev/null +++ b/examples/fetch1.py @@ -0,0 +1,75 @@ +"""Fetch one URL and write its content to stdout. + +This version adds URL parsing (including SSL) and a Response object. +""" + +import sys +import urllib.parse + +from tulip import * + + +class Response: + + def __init__(self, verbose=True): + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def read(self, reader): + @coroutine + def getline(): + return (yield from reader.readline()).decode('latin-1').rstrip() + status_line = yield from getline() + if self.verbose: print('<', status_line, file=sys.stderr) + self.http_version, status, self.reason = status_line.split(None, 2) + self.status = int(status) + while True: + header_line = yield from getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + +@coroutine +def fetch(url, verbose=True): + parts = urllib.parse.urlparse(url) + if parts.scheme == 'http': + ssl = False + elif parts.scheme == 'https': + ssl = True + else: + print('URL must use http or https.') + sys.exit(1) + port = parts.port + if port is None: + port = 443 if ssl else 80 + path = parts.path or '/' + if parts.query: + path += '?' + parts.query + request = 'GET %s HTTP/1.0\r\n\r\n' % path + if verbose: + print('>', request, file=sys.stderr, end='') + r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + w.write(request.encode('latin-1')) + response = Response(verbose) + yield from response.read(r) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch2.py b/examples/fetch2.py new file mode 100644 index 00000000..ca250d61 --- /dev/null +++ b/examples/fetch2.py @@ -0,0 +1,138 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a Request object. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @coroutine + def connect(self): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) + self.reader, self.writer = yield from open_connection(self.hostname, + self.port, + ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % + (self.writer.get_extra_info('peername'),), + file=sys.stderr) + + def putline(self, line): + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + if self.verbose: print('>', request, file=sys.stderr) + self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + if self.verbose: print('>', line, file=sys.stderr) + self.putline(line) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + return (yield from self.reader.readline()).decode('latin-1').rstrip() + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + if self.verbose: print('<', status_line, file=sys.stderr) + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True): + request = Request(url, verbose) + yield from request.connect() + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fetch3.py b/examples/fetch3.py new file mode 100644 index 00000000..3b2c8ae0 --- /dev/null +++ b/examples/fetch3.py @@ -0,0 +1,211 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a primitive connection pool, redirect following and +chunked transfer-encoding. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from tulip import * + + +class ConnectionPool: + # TODO: Locking? Close idle connections? + + def __init__(self, verbose=False): + self.verbose = verbose + self.connections = {} # {(host, port, ssl): (reader, writer)} + + @coroutine + def open_connection(self, host, port, ssl): + port = port or (443 if ssl else 80) + ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + if self.verbose: + print('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs)), + file=sys.stderr) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = self.connections.get(key) + if conn: + if self.verbose: + print('* Reusing pooled connection', key, file=sys.stderr) + return conn + reader, writer = yield from open_connection(host, port, ssl=ssl) + host, port, *_ = writer.get_extra_info('peername') + key = host, port, ssl + self.connections[key] = reader, writer + if self.verbose: + print('* New connection', key, file=sys.stderr) + return reader, writer + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def connect(self, pool): + self.vprint('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + self.reader, self.writer = \ + yield from pool.open_connection(self.hostname, + self.port, + ssl=self.ssl) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('peername'),)) + + @coroutine + def putline(self, line): + self.vprint('>', line) + self.writer.write(line.encode('latin-1') + b'\r\n') + ##yield from self.writer.drain() + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + yield from self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def getline(self): + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.vprint('<', line) + return line + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=None): + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=None): + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding', '').lower() == 'chunked': + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + break + parts = size_header.split(b';') + size = int(parts[0], 16) + if not size: + break + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n' + body = b''.join(blocks) + else: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True, max_redirect=10): + pool = ConnectionPool(verbose) + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) + return body + + +def main(): + loop = get_event_loop() + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/sink.py b/examples/sink.py new file mode 100644 index 00000000..c94574a6 --- /dev/null +++ b/examples/sink.py @@ -0,0 +1,76 @@ +"""Test service that accepts connections and reads all data off them.""" + +import argparse +import sys + +from tulip import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--maxsize', action='store', dest='maxsize', + default=16*1024*1024, type=int, help='Max total data size') + +server = None +args = None + + +def dprint(*args): + print('sink:', *args, file=sys.stderr) + + +class Service(Protocol): + + def connection_made(self, tr): + dprint('connection from', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.total = 0 + + def data_received(self, data): + if data == b'stop': + dprint('stopping server') + server.close() + self.tr.close() + return + self.total += len(data) + dprint('received', len(data), 'bytes; total', self.total) + if self.total > args.maxsize: + dprint('closing due to too much data') + self.tr.close() + + def connection_lost(self, how): + dprint('closed', repr(how)) + + +@coroutine +def start(loop, host, port): + global server + server = yield from loop.create_server(Service, host, port) + dprint('serving', [s.getsockname() for s in server.sockets]) + yield from server.wait_closed() + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/source.py b/examples/source.py new file mode 100644 index 00000000..933739bf --- /dev/null +++ b/examples/source.py @@ -0,0 +1,90 @@ +"""Test client that connects and sends infinite data.""" + +import argparse +import sys + +from tulip import * + + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + +args = None + + +def dprint(*args): + print('source:', *args, file=sys.stderr) + + +class Client(Protocol): + + total = 0 + + def connection_made(self, tr): + dprint('connecting to', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.lost = False + self.loop = get_event_loop() + self.waiter = Future() + if args.stop: + self.tr.write(b'stop') + self.tr.close() + else: + self.data = b'x'*args.size + self.write_some_data() + + def write_some_data(self): + if self.lost: + dprint('lost already') + return + data = self.data + size = len(data) + self.total += size + dprint('writing', size, 'bytes; total', self.total) + self.tr.write(data) + self.loop.call_soon(self.write_some_data) + + def connection_lost(self, exc): + dprint('lost connection', repr(exc)) + self.lost = True + self.waiter.set_result(None) + + +@coroutine +def start(loop, host, port): + tr, pr = yield from loop.create_connection(Client, host, port) + dprint('tr =', tr) + dprint('pr =', pr) + yield from pr.waiter + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args.host, args.port)) + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/stacks.py b/examples/stacks.py new file mode 100644 index 00000000..77a99cf5 --- /dev/null +++ b/examples/stacks.py @@ -0,0 +1,43 @@ +"""Crude demo for print_stack().""" + + +from tulip import * + + +@coroutine +def helper(r): + print('--- helper ---') + for t in Task.all_tasks(): + t.print_stack() + print('--- end helper ---') + line = yield from r.readline() + 1/0 + return line + +def doit(): + l = get_event_loop() + lr = l.run_until_complete + r, w = lr(open_connection('python.org', 80)) + t1 = async(helper(r)) + for t in Task.all_tasks(): t.print_stack() + print('---') + l._run_once() + for t in Task.all_tasks(): t.print_stack() + print('---') + w.write(b'GET /\r\n') + w.write_eof() + try: + lr(t1) + except Exception as e: + print('catching', e) + finally: + for t in Task.all_tasks(): + t.print_stack() + + +def main(): + doit() + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..39db5cca --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import tulip +try: + import signal +except ImportError: + signal = None + + +class EchoServer(tulip.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = tulip.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(tulip.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + tulip.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + tulip.get_event_loop().stop() + + +def start_client(loop, host, port): + t = tulip.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.start_serving(EchoServer, host, port) + x = loop.run_until_complete(f)[0] + print('serving on', x.getsockname()) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + loop.run_forever() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..0347bfbd --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import tulip +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def connection_refused(self, exc): + print('Connection refused:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = tulip.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + loop.run_until_complete(t) + + +def start_client(loop, addr): + t = tulip.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = tulip.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + loop.run_forever() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..6a1d9e4a --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1202 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, + TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + union { + /* Buffer used for reading (optional) */ + PyObject *read_buffer; + /* Buffer used for writing (optional) */ + Py_buffer write_buffer; + }; +} OverlappedObject; + +typedef struct { + OVERLAPPED *Overlapped; + HANDLE IocpHandle; + char Address[1]; +} WaitNamedPipeAndConnectContext; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, family) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "family should AF_INET or AF_INET6.\n"); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int Family; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &Family)) + return NULL; + + if (Family == AF_INET) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (Family == AF_INET6) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Windows equivalent of os.strerror() -- compare _ctypes/callproc.c + */ + +PyDoc_STRVAR( + FormatMessage_doc, + "FormatMessage(error_code) -> error_message\n\n" + "Return error message for an error code."); + +static PyObject * +overlapped_FormatMessage(PyObject *ignore, PyObject *args) +{ + DWORD code, n; + WCHAR *lpMsgBuf; + PyObject *res; + + if (!PyArg_ParseTuple(args, F_DWORD, &code)) + return NULL; + + n = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &lpMsgBuf, + 0, + NULL); + if (n) { + while (iswspace(lpMsgBuf[n-1])) + --n; + lpMsgBuf[n] = L'\0'; + res = Py_BuildValue("u", lpMsgBuf); + } else { + res = PyUnicode_FromFormat("unknown error code %u", code); + } + LocalFree(lpMsgBuf); + return res; +} + + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + + switch (self->type) { + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + } + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED + || self->type == TYPE_WAIT_NAMED_PIPE_AND_CONNECT) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if (self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_ConnectNamedPipe_doc, + "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" + "Start overlapped wait for a client to connect."); + +static PyObject * +Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) +{ + HANDLE Pipe; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Pipe)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_CONNECT_NAMED_PIPE; + self->handle = Pipe; + + Py_BEGIN_ALLOW_THREADS + ret = ConnectNamedPipe(Pipe, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_PIPE_CONNECTED: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +/* Unfortunately there is no way to do an overlapped connect to a + pipe. We instead use WaitNamedPipe() and CreateFile() in a thread + pool thread. If a connection succeeds within a time limit (10 + seconds) then PostQueuedCompletionStatus() is used to return the + pipe handle to the completion port. */ + +static DWORD WINAPI +WaitNamedPipeAndConnectInThread(WaitNamedPipeAndConnectContext *ctx) +{ + HANDLE PipeHandle = INVALID_HANDLE_VALUE; + DWORD Start = GetTickCount(); + DWORD Deadline = Start + 10*1000; + DWORD Error = 0; + DWORD Timeout; + BOOL Success; + + for ( ; ; ) { + Timeout = Deadline - GetTickCount(); + if ((int)Timeout < 0) + break; + Success = WaitNamedPipe(ctx->Address, Timeout); + Error = Success ? ERROR_SUCCESS : GetLastError(); + switch (Error) { + case ERROR_SUCCESS: + PipeHandle = CreateFile(ctx->Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + if (PipeHandle == INVALID_HANDLE_VALUE) + continue; + break; + case ERROR_SEM_TIMEOUT: + continue; + } + break; + } + if (!PostQueuedCompletionStatus(ctx->IocpHandle, Error, + (ULONG_PTR)PipeHandle, ctx->Overlapped)) + CloseHandle(PipeHandle); + free(ctx); + return 0; +} + +PyDoc_STRVAR( + Overlapped_WaitNamedPipeAndConnect_doc, + "WaitNamedPipeAndConnect(addr, iocp_handle) -> Overlapped[pipe_handle]\n\n" + "Start overlapped connection to address, notifying iocp_handle when\n" + "finished"); + +static PyObject * +Overlapped_WaitNamedPipeAndConnect(OverlappedObject *self, PyObject *args) +{ + char *Address; + Py_ssize_t AddressLength; + HANDLE IocpHandle; + OVERLAPPED Overlapped; + BOOL ret; + DWORD err; + WaitNamedPipeAndConnectContext *ctx; + Py_ssize_t ContextLength; + + if (!PyArg_ParseTuple(args, "s#" F_HANDLE F_POINTER, + &Address, &AddressLength, &IocpHandle, &Overlapped)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + ContextLength = (AddressLength + + offsetof(WaitNamedPipeAndConnectContext, Address)); + ctx = calloc(1, ContextLength + 1); + if (ctx == NULL) + return PyErr_NoMemory(); + memcpy(ctx->Address, Address, AddressLength + 1); + ctx->Overlapped = &self->overlapped; + ctx->IocpHandle = IocpHandle; + + self->type = TYPE_WAIT_NAMED_PIPE_AND_CONNECT; + self->handle = NULL; + + Py_BEGIN_ALLOW_THREADS + ret = QueueUserWorkItem(WaitNamedPipeAndConnectInThread, ctx, + WT_EXECUTELONGFUNCTION); + Py_END_ALLOW_THREADS + + mark_as_completed(&self->overlapped); + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + if (!ret) + return SetFromWindowsErr(err); + Py_RETURN_NONE; +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, + METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {"WaitNamedPipeAndConnect", + (PyCFunction) Overlapped_WaitNamedPipeAndConnect, + METH_VARARGS, Overlapped_WaitNamedPipeAndConnect_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"FormatMessage", overlapped_FormatMessage, + METH_VARARGS, FormatMessage_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); + WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..287d0367 --- /dev/null +++ b/runtests.py @@ -0,0 +1,278 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.events_test.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import re +import sys +import unittest +import textwrap +import importlib.machinery +try: + import coverage +except ImportError: + coverage = None + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except Exception as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +class TestsFinder: + + def __init__(self, testsdir, includes=(), excludes=()): + self._testsdir = testsdir + self._includes = includes + self._excludes = excludes + self.find_available_tests() + + def find_available_tests(self): + """ + Find available test classes without instantiating them. + """ + self._test_factories = [] + mods = [mod for mod, _ in load_modules(self._testsdir)] + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + self._test_factories.append(getattr(mod, name)) + + def load_tests(self): + """ + Load test cases from the available test classes and apply + optional include / exclude filters. + """ + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test_factory in self._test_factories: + tests = loader.loadTestsFromTestCase(test_factory) + if self._includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in self._includes)] + if self._excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in self._excludes)] + suite.addTests(tests) + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + if args.coverage: + cov = coverage.coverage(branch=True, + source=['tulip'], + ) + cov.start() + + finder = TestsFinder(args.testsdir, includes, excludes) + logger = logging.getLogger() + if v == 0: + logger.setLevel(logging.CRITICAL) + elif v == 1: + logger.setLevel(logging.ERROR) + elif v == 2: + logger.setLevel(logging.WARNING) + elif v == 3: + logger.setLevel(logging.INFO) + elif v >= 4: + logger.setLevel(logging.DEBUG) + if catchbreak: + installHandler() + try: + if args.forever: + while True: + tests = finder.load_tests() + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + tests = finder.load_tests() + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) + + +if __name__ == '__main__': + runtests() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..0260f9d5 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +build_lib=tulip diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..dcaee96f --- /dev/null +++ b/setup.py @@ -0,0 +1,14 @@ +import os +from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + extensions.append(ext) + +setup(name='tulip', + description="reference implementation of PEP 3156", + url='http://www.python.org/dev/peps/pep-3156/', + packages=['tulip'], + ext_modules=extensions + ) diff --git a/tests/base_events_test.py b/tests/base_events_test.py new file mode 100644 index 00000000..38fe07aa --- /dev/null +++ b/tests/base_events_test.py @@ -0,0 +1,590 @@ +"""Tests for base_events.py""" + +import logging +import socket +import time +import unittest +import unittest.mock + +from tulip import base_events +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import tasks +from tulip import test_utils + + +class BaseEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = unittest.mock.Mock() + events.set_event_loop(None) + + def test_not_implemented(self): + m = unittest.mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, self.loop._read_from_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test__add_callback_handle(self): + h = events.Handle(lambda: False, ()) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_timer(self): + h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + + self.loop._add_callback(h) + self.assertIn(h, self.loop._scheduled) + + def test__add_callback_cancelled_handle(self): + h = events.Handle(lambda: False, ()) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = unittest.mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = unittest.mock.Mock() + self.loop.run_in_executor = unittest.mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, events.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, events.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = unittest.mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = unittest.mock.Mock() + when = self.loop.time() + 0.1 + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + t1 = self.loop.time() + self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.Handle(cb, ()), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, events.TimerHandle(10, cb, ())) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = events.Handle(cb, ()) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, futures.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = events.Handle(cb, ()) + f = futures.Future(loop=self.loop) + executor = unittest.mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = unittest.mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + + h1.cancel() + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.99 < t < 10.1, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + @unittest.mock.patch('tulip.base_events.time') + @unittest.mock.patch('tulip.base_events.logger') + def test__run_once_logging(self, m_logging, m_time): + # Log to INFO level if timeout > 1.0 sec. + idx = -1 + data = [10.0, 10.0, 12.0, 13.0] + + def monotonic(): + nonlocal data, idx + idx += 1 + return data[idx] + + m_time.monotonic = monotonic + m_logging.INFO = logging.INFO + m_logging.DEBUG = logging.DEBUG + + self.loop._scheduled.append( + events.TimerHandle(11.0, lambda: True, ())) + self.loop._process_events = unittest.mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + + idx = -1 + data = [10.0, 10.0, 10.3, 13.0] + self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + + self.loop._process_events = unittest.mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test_run_until_complete_type_error(self): + self.assertRaises( + TypeError, self.loop.run_until_complete, 'blah') + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = futures.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(unittest.TestCase): + + def setUp(self): + self.loop = events.new_event_loop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(protocols.Protocol): + pass + + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @tasks.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @tasks.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @tasks.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + protocols.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @unittest.mock.patch('tulip.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = unittest.mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @unittest.mock.patch('tulip.selector_events.logger') + def test_accept_connection_exception(self, m_log): + sock = unittest.mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(sock.close.called) + self.assertTrue(m_log.exception.called) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..f6ac0a30 --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..f1f7ea7c --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,9 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/events_test.py b/tests/events_test.py new file mode 100644 index 00000000..f5f51e5b --- /dev/null +++ b/tests/events_test.py @@ -0,0 +1,1573 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +import unittest.mock +from test.support import find_unused_port + + +from tulip import futures +from tulip import events +from tulip import transports +from tulip import protocols +from tulip import selector_events +from tulip import tasks +from tulip import test_utils +from tulip import locks + + +class MyProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(protocols.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def connection_refused(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(protocols.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(protocols.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = futures.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(protocols.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = futures.Future(loop=loop) + self.completed = futures.Future(loop=loop) + self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: locks.Event(loop=loop), + 2: locks.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + events.set_event_loop(None) + + def tearDown(self): + # just in case if we have transport close callbacks + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + + def test_run_until_complete_stopped(self): + @tasks.coroutine + def cb(): + self.loop.stop() + yield from tasks.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.08 <= t1-t0 <= 0.2, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + bytes_read = [] + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.append(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.send, b'def') + test_utils.run_briefly(self.loop) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(b''.join(bytes_read), b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) + test_utils.run_briefly(self.loop) + + def remove_writer(): + self.assertTrue(self.loop.remove_writer(w.fileno())) + + self.loop.call_soon(remove_writer) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + w.close() + data = r.recv(256*1024) + r.close() + self.assertGreaterEqual(len(data), 200) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_briefly(self.loop) + self.assertEqual(caught, 1) + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.call_later(0.015, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address, + ssl=test_utils.dummy_ssl_context()) + tr, pr = self.loop.run_until_complete(f) + self.assertTrue(isinstance(tr, transports.Transport)) + self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto() + return proto + + f = self.loop.create_server(factory, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_briefly(self.loop) # windows iocp + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = None + + class ClientMyProto(MyProto): + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + here = os.path.dirname(__file__) + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + + f_c = self.loop.create_connection(ClientMyProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_create_server_sock(self): + proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + def test_create_server_dual_stack(self): + f_proto = futures.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = futures.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertEqual('INITIALIZED', client.state) + transport.sendto(b'xxx') + for _ in range(1000): + if server.nbytes: + break + test_utils.run_briefly(self.loop) + self.assertEqual(3, server.nbytes) + for _ in range(1000): + if client.nbytes: + break + test_utils.run_briefly(self.loop) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + return + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_briefly(self.loop) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_briefly(self.loop) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + @tasks.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(rpipe, 1024) + self.assertEqual(b'1', data) + + os.close(rpipe) + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + self.assertTrue(ov is None or ov.pending) + + @tasks.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except futures.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = tasks.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(futures.CancelledError, f.result) + self.assertTrue(ov is None or not ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exec(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_interactive(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_shell(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo "Python"') + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_exitcode(self): + proto = None + + @tasks.coroutine + def connect(): + nonlocal proto + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_after_finish(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_kill(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGKILL, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_send_signal(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_stderr_redirect_to_stdout(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_close_client_stream(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + + transp.close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGTERM, proto.returncode) + + +if sys.platform == 'win32': + from tulip import windows_events + + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return windows_events.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") +else: + from tulip import selectors + from tulip import unix_events + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop( + selectors.KqueueSelector()) + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + def create_event_loop(self): + return unix_events.SelectorEventLoop(selectors.SelectSelector()) + + +class HandleTests(unittest.TestCase): + + def test_handle(self): + def callback(*args): + return args + + args = () + h = events.Handle(callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.startswith( + 'Handle(' + '.callback')) + self.assertTrue(r.endswith('())'), r) + + def test_make_handle(self): + def callback(*args): + return args + h1 = events.Handle(callback, ()) + self.assertRaises( + AssertionError, events.make_handle, h1, ()) + + @unittest.mock.patch('tulip.events.logger') + def test_callback_with_exception(self, log): + def callback(): + raise ValueError() + + h = events.Handle(callback, ()) + h._run() + self.assertTrue(log.exception.called) + + +class TimerTests(unittest.TestCase): + + def test_hash(self): + when = time.monotonic() + h = events.TimerHandle(when, lambda: False, ()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = () + when = time.monotonic() + h = events.TimerHandle(when, callback, args) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())')) + + h.cancel() + self.assertTrue(h._cancelled) + + r = repr(h) + self.assertTrue(r.endswith('())'), r) + + self.assertRaises(AssertionError, + events.TimerHandle, None, callback, args) + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when, callback, ()) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = events.TimerHandle(when, callback, ()) + h2 = events.TimerHandle(when + 10.0, callback, ()) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = events.Handle(callback, ()) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + loop = events.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + unittest.mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + unittest.mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = unittest.mock.Mock() + p = protocols.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = protocols.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = protocols.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = events.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + + def test_get_event_loop(self): + policy = events.DefaultEventLoopPolicy() + self.assertIsNone(policy._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + + self.assertIs(policy._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = events.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @unittest.mock.patch('tulip.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = events.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = events.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, events.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = events.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = events.get_event_loop_policy() + self.assertIsInstance(policy, events.AbstractEventLoopPolicy) + self.assertIs(policy, events.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, events.set_event_loop_policy, object()) + + old_policy = events.get_event_loop_policy() + + policy = events.DefaultEventLoopPolicy() + events.set_event_loop_policy(policy) + self.assertIs(policy, events.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/futures_test.py b/tests/futures_test.py new file mode 100644 index 00000000..13a1dd93 --- /dev/null +++ b/tests/futures_test.py @@ -0,0 +1,329 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import threading +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import test_utils + + +def _fakefunc(f): + return f + + +class FutureTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_initial_state(self): + f = futures.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + try: + events.set_event_loop(self.loop) + f = futures.Future() + self.assertIs(f._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_constructor_positional(self): + # Make sure Future does't accept a positional argument + self.assertRaises(TypeError, futures.Future, 42) + + def test_cancel(self): + f = futures.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(futures.CancelledError, f.exception) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = futures.Future(loop=self.loop) + self.assertRaises(futures.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(futures.InvalidStateError, f.set_result, None) + self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_yield_from_twice(self): + f = futures.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_repr(self): + f_pending = futures.Future(loop=self.loop) + self.assertEqual(repr(f_pending), 'Future') + f_pending.cancel() + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), 'Future') + + f_result = futures.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), 'Future') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), 'Future') + self.assertIs(f_exception.exception(), exc) + + f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks.add_done_callback(_fakefunc) + self.assertIn('Future', r) + f_many_callbacks.cancel() + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = futures.Future(loop=self.loop) + f.set_result(10) + + newf = futures.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = futures.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = futures.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = futures.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = futures.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = futures.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_abandoned(self, m_log): + fut = futures.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_result_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_exception_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @unittest.mock.patch('tulip.futures.logger') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = futures.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, futures.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = futures.Future(loop=self.loop) + f2 = futures.wrap_future(f1) + self.assertIs(f1, f2) + + @unittest.mock.patch('tulip.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = futures.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + +class FutureDoneCallbackTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return futures.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/locks_test.py b/tests/locks_test.py new file mode 100644 index 00000000..7c138eef --- /dev/null +++ b/tests/locks_test.py @@ -0,0 +1,765 @@ +"""Tests for lock.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import tasks +from tulip import test_utils + + +class LockTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + lock = locks.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = locks.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + lock = locks.Lock() + self.assertIs(lock._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + + @tasks.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + + def test_lock(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = locks.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @tasks.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = locks.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = tasks.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = futures.Future(loop=self.loop) + ta = tasks.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = tasks.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = tasks.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = locks.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = locks.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = locks.Lock(loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_no_yield(self): + lock = locks.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + +class EventTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + ev = locks.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = locks.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + ev = locks.Event() + self.assertIs(ev._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + ev = locks.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + + def test_wait(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @tasks.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @tasks.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = tasks.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = locks.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = locks.Event(loop=self.loop) + + wait = tasks.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = locks.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = locks.Event(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = tasks.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + cond = locks.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = locks.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + cond = locks.Condition() + self.assertIs(cond._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_wait(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = locks.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = tasks.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._condition_waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = locks.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = tasks.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = locks.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = locks.Condition(loop=self.loop) + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @tasks.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = locks.Condition(loop=self.loop) + + result = [] + + @tasks.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @tasks.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = locks.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + +class SemaphoreTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_ctor_loop(self): + loop = unittest.mock.Mock() + sem = locks.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = locks.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + try: + events.set_event_loop(self.loop) + sem = locks.Semaphore() + self.assertIs(sem._loop, self.loop) + finally: + events.set_event_loop(None) + + def test_repr(self): + sem = locks.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + + def test_semaphore(self): + sem = locks.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, locks.Semaphore, -1) + + def test_acquire(self): + sem = locks.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @tasks.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @tasks.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @tasks.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @tasks.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = tasks.Task(c1(result), loop=self.loop) + t2 = tasks.Task(c2(result), loop=self.loop) + t3 = tasks.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = tasks.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + + def test_acquire_cancel(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = tasks.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + futures.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = locks.Semaphore(bound=True, loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = locks.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = locks.Semaphore(2, loop=self.loop) + + @tasks.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/proactor_events_test.py b/tests/proactor_events_test.py new file mode 100644 index 00000000..c1240838 --- /dev/null +++ b/tests/proactor_events_test.py @@ -0,0 +1,480 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +import unittest.mock + +import tulip +from tulip.proactor_events import BaseProactorEventLoop +from tulip.proactor_events import _ProactorSocketTransport +from tulip.proactor_events import _ProactorWritePipeTransport +from tulip.proactor_events import _ProactorDuplexPipeTransport +from tulip import test_utils + + +class ProactorSocketTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.proactor = unittest.mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(tulip.Protocol) + self.sock = unittest.mock.Mock(socket.socket) + + def test_ctor(self): + fut = tulip.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = tulip.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = unittest.mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = unittest.mock.Mock() + tr._force_close = unittest.mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with(err) + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertTrue(tr._loop_writing.called) + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr._loop_writing = unittest.mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, [b'data']) + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @unittest.mock.patch('tulip.proactor_events.logger') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = unittest.mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with(err) + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, []) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = tulip.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = unittest.mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @unittest.mock.patch('tulip.proactor_events.logger') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = unittest.mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.exception.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = unittest.mock.Mock() + write_fut = tr._write_fut = unittest.mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual([], tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = tulip.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume_reading(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = tulip.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause_reading() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + +class BaseProactorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.sock = unittest.mock.Mock(socket.socket) + self.proactor = unittest.mock.Mock() + + self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + unittest.mock.Mock(), unittest.mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = unittest.mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = unittest.mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = unittest.mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'x') + + def test_process_events(self): + self.loop._process_events([]) + + @unittest.mock.patch('tulip.proactor_events.logger') + def test_create_server(self, m_log): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = unittest.mock.Mock() + fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + + make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.exception.called) + + def test_create_server_cancel(self): + pf = unittest.mock.Mock() + call_soon = self.loop.call_soon = unittest.mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = tulip.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = unittest.mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) diff --git a/tests/queues_test.py b/tests/queues_test.py new file mode 100644 index 00000000..ae1e3a14 --- /dev/null +++ b/tests/queues_test.py @@ -0,0 +1,470 @@ +"""Tests for queues.py""" + +import unittest +import unittest.mock + +from tulip import events +from tulip import futures +from tulip import locks +from tulip import queues +from tulip import tasks +from tulip import test_utils + + +class _QueueTestBase(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + q = queues.Queue(loop=loop) + self.assertTrue(fn(q).startswith(')') + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), 'Task()') + self.assertRaises(futures.CancelledError, + self.loop.run_until_complete, t) + self.assertEqual(repr(t), 'Task()') + t = tasks.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertEqual(repr(t), "Task()") + + def test_task_repr_custom(self): + @tasks.coroutine + def coro(): + pass + + class T(futures.Future): + def __repr__(self): + return 'T[]' + + class MyTask(tasks.Task, T): + def __repr__(self): + return super().__repr__() + + gen = coro() + t = MyTask(gen, loop=self.loop) + self.assertEqual(repr(t), 'T[]()') + gen.close() + + def test_task_basics(self): + @tasks.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @tasks.coroutine + def inner1(): + return 42 + + @tasks.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + yield from tasks.sleep(10.0, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(futures.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @tasks.coroutine + def task(): + yield + yield + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from f + return 12 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + return 42 + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + fut3 = futures.Future(loop=self.loop) + + @tasks.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except futures.CancelledError: + pass + res = yield from fut3 + return res + + t = tasks.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = events.new_event_loop() + self.addCleanup(loop.close) + + @tasks.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from tasks.sleep(100, loop=loop) + return 12 + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + futures.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + x = 0 + waiters = [] + + @tasks.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(tasks.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = tasks.Task(task(), loop=loop) + self.assertRaises( + RuntimeError, loop.run_until_complete, t) + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.4, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + fut = tasks.Task(foo(), loop=loop) + + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + + self.assertFalse(fut.done()) + self.assertAlmostEqual(0.1, loop.time()) + + # wait for result + res = loop.run_until_complete( + tasks.wait_for(fut, 0.3, loop=loop)) + self.assertEqual(res, 'done') + self.assertAlmostEqual(0.2, loop.time()) + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def foo(): + yield from tasks.sleep(0.2, loop=loop) + return 'done' + + events.set_event_loop(loop) + try: + fut = tasks.Task(foo(), loop=loop) + with self.assertRaises(futures.TimeoutError): + loop.run_until_complete(tasks.wait_for(fut, 0.01)) + finally: + events.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertFalse(fut.done()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(fut) + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + events.set_event_loop(loop) + try: + res = loop.run_until_complete( + tasks.Task(foo(), loop=loop)) + finally: + events.set_event_loop(None) + + self.assertEqual(res, 42) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait(set(), loop=self.loop)) + + self.assertRaises( + ValueError, self.loop.run_until_complete, + tasks.wait([tasks.sleep(10.0, loop=self.loop)], + return_when=-1, loop=self.loop)) + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @tasks.coroutine + def coro1(): + yield + + @tasks.coroutine + def coro2(): + yield + yield + + a = tasks.Task(coro1(), loop=self.loop) + b = tasks.Task(coro2(), loop=self.loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, task already has exception + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.Task( + tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + # first_exception, exception during waiting + a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + + @tasks.coroutine + def exc(): + yield from tasks.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = tasks.Task(exc(), loop=loop) + task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = tasks.Task(sleeper(), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + @tasks.coroutine + def foo(): + done, pending = yield from tasks.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + tasks.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + completed = set() + time_shifted = False + + @tasks.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from tasks.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.12, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0.1 + self.assertAlmostEqual(0.12, when) + yield 0.02 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.1, 'a', loop=loop) + b = tasks.sleep(0.15, 'b', loop=loop) + + @tasks.coroutine + def foo(): + values = [] + for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + try: + v = yield from f + values.append((1, v)) + except futures.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(tasks.wait([a, b], loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = tasks.sleep(0.05, 'a', loop=loop) + b = tasks.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(tasks.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = tasks.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(dt, arg): + yield from tasks.sleep(dt/2, loop=loop) + res = yield from tasks.sleep(dt/2, arg, loop=loop) + return res + + t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(self, delay, callback, *args): + nonlocal handle + handle = orig_call_later(self, delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + sleepfut = None + + @tasks.coroutine + def sleep(dt): + nonlocal sleepfut + sleepfut = tasks.sleep(dt, loop=loop) + yield from sleepfut + + @tasks.coroutine + def doit(): + sleeper = tasks.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except futures.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro(): + yield from fut + + task = tasks.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @tasks.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = tasks.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @tasks.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(futures.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @tasks.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = tasks.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @tasks.coroutine + def notmutch(): + raise BaseException() + + task = tasks.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + @tasks.coroutine + def sleeper(): + yield from tasks.sleep(10, loop=loop) + + base_exc = BaseException() + + @tasks.coroutine + def notmutch(): + try: + yield from sleeper() + except futures.CancelledError: + raise base_exc + + task = tasks.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(tasks.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(tasks.iscoroutinefunction(fn1)) + + @tasks.coroutine + def fn2(): + yield + self.assertTrue(tasks.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @tasks.coroutine + def coro(): + yield + + @tasks.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @tasks.coroutine + def func(): + return 'test' + + self.assertTrue(tasks.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(tasks.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = futures.Future(loop=self.loop) + + @tasks.coroutine + def func(): + return fut + + @tasks.coroutine + def coro(): + fut.set_result('test') + + t1 = tasks.Task(func(), loop=self.loop) + t2 = tasks.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except futures.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @tasks.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except futures.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + d, p = yield from tasks.wait([inner()], loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + futures.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = futures.Future(loop=self.loop) + outer = tasks.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = futures.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(tasks.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = futures.Future(loop=self.loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @tasks.coroutine + def outer(): + nonlocal proof + yield from tasks.shield(inner(), loop=self.loop) + proof += 100 + + f = tasks.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + parent = tasks.gather(child1, child2, loop=self.loop) + outer = tasks.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = futures.Future(loop=self.loop) + child2 = futures.Future(loop=self.loop) + inner1 = tasks.shield(child1, loop=self.loop) + inner2 = tasks.shield(child2, loop=self.loop) + parent = tasks.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = test_utils.TestLoop() + self.other_loop = test_utils.TestLoop() + + def tearDown(self): + self.one_loop.close() + self.other_loop.close() + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] + fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = Mock() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_return_exceptions(self): + a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] + fut = tasks.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + +class FutureGatherTests(GatherTestsBase, unittest.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + events.set_event_loop(self.one_loop) + self.addCleanup(events.set_event_loop, None) + fut = tasks.gather(*seq_or_iter) + self.assertIsInstance(fut, futures.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = futures.Future(loop=self.one_loop) + fut2 = futures.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + tasks.gather(fut1, fut2) + with self.assertRaises(ValueError): + tasks.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [futures.Future(loop=self.other_loop) for i in range(3)] + fut = tasks.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = tasks.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] + fut = tasks.gather(a, b, c, d, e) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), futures.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + for i in range(6)] + fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + cb = Mock() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], futures.CancelledError) + self.assertIsInstance(res[4], futures.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): + + def setUp(self): + super().setUp() + events.set_event_loop(self.one_loop) + + def tearDown(self): + events.set_event_loop(None) + super().tearDown() + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @tasks.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @tasks.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = tasks.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + gen1.close() + gen2.close() + gen3 = coro() + gen4 = coro() + fut = tasks.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + gen3.close() + gen4.close() + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = tasks.async(inner(), loop=self.one_loop) + child2 = tasks.async(inner(), loop=self.one_loop) + gatherer = None + + @tasks.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = tasks.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(futures.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @tasks.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = futures.Future(loop=self.one_loop) + b = futures.Future(loop=self.one_loop) + + @tasks.coroutine + def outer(): + yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + + f = tasks.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transports_test.py b/tests/transports_test.py new file mode 100644 index 00000000..38492931 --- /dev/null +++ b/tests/transports_test.py @@ -0,0 +1,55 @@ +"""Tests for transports.py.""" + +import unittest +import unittest.mock + +from tulip import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = transports.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = transports.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = transports.Transport() + transport.write = unittest.mock.Mock() + + transport.writelines(['line1', 'line2', 'line3']) + self.assertEqual(3, transport.write.call_count) + + def test_not_implemented(self): + transport = transports.Transport() + + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = transports.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = transports.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) diff --git a/tests/unix_events_test.py b/tests/unix_events_test.py new file mode 100644 index 00000000..5659d31f --- /dev/null +++ b/tests/unix_events_test.py @@ -0,0 +1,770 @@ +"""Tests for unix_events.py.""" + +import gc +import errno +import io +import pprint +import signal +import stat +import sys +import unittest +import unittest.mock + +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + + +from tulip import events +from tulip import futures +from tulip import protocols +from tulip import test_utils +from tulip import unix_events + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopTests(unittest.TestCase): + + def setUp(self): + self.loop = unix_events.SelectorEventLoop() + events.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1, ()) + + def test_handle_signal_cancelled_handler(self): + h = events.Handle(unittest.mock.Mock(), ()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertTrue(isinstance(h, events.Handle)) + self.assertEqual(h._callback, cb) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logger') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logger') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @unittest.mock.patch('tulip.unix_events.signal') + @unittest.mock.patch('tulip.unix_events.logger') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('tulip.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(3) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = True + m_WTERMSIG.return_value = 1 + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + transp._process_exited.assert_called_with(-1) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(0, object()), ChildProcessError] + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_not_registered_subprocess(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = True + m_WIFSIGNALED.return_value = False + m_WEXITSTATUS.return_value = 3 + + self.loop._sig_chld() + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + m_waitpid.side_effect = [(7, object()), ChildProcessError] + m_WIFEXITED.return_value = False + m_WIFSIGNALED.return_value = False + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('tulip.unix_events.logger') + @unittest.mock.patch('os.WTERMSIG') + @unittest.mock.patch('os.WEXITSTATUS') + @unittest.mock.patch('os.WIFSIGNALED') + @unittest.mock.patch('os.WIFEXITED') + @unittest.mock.patch('os.waitpid') + def test__sig_chld_unknown_status_in_handler(self, m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + m_log): + m_waitpid.side_effect = Exception + transp = unittest.mock.Mock() + self.loop._subprocesses[7] = transp + + self.loop._sig_chld() + self.assertFalse(transp._process_exited.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m_WEXITSTATUS.called) + m_log.exception.assert_called_with( + 'Unknown exception in SIGCHLD handler') + + @unittest.mock.patch('os.waitpid') + def test__sig_chld_process_error(self, m_waitpid): + m_waitpid.side_effect = ChildProcessError + self.loop._sig_chld() + self.assertTrue(m_waitpid.called) + + +class UnixReadPipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @unittest.mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @unittest.mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @unittest.mock.patch('tulip.log.logger.exception') + @unittest.mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = unittest.mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with('Fatal error for %s', tr) + + @unittest.mock.patch('os.read') + def test_pause_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = unittest.mock.Mock() + self.loop.add_reader(5, m) + tr.pause_reading() + self.assertFalse(self.loop.readers) + + @unittest.mock.patch('os.read') + def test_resume_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume_reading() + self.loop.assert_reader(5, tr._read_ready) + + @unittest.mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = unittest.mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @unittest.mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = unittest.mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @unittest.mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + +class UnixWritePipeTransportTests(unittest.TestCase): + + def setUp(self): + self.loop = test_utils.TestLoop() + self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher.start() + self.addCleanup(fcntl_patcher.stop) + + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = futures.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @unittest.mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.unix_events.logger') + @unittest.mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = unittest.mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with(err) + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @unittest.mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @unittest.mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @unittest.mock.patch('tulip.log.logger.exception') + @unittest.mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with('Fatal error for %s', tr) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @unittest.mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @unittest.mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), + pprint.pformat(gc.get_referrers(self.protocol))) + self.assertIsNone(tr._loop) + self.assertEqual(2, sys.getrefcount(self.loop), + pprint.pformat(gc.get_referrers(self.loop))) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = unittest.mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/windows_events_test.py b/tests/windows_events_test.py new file mode 100644 index 00000000..675728e0 --- /dev/null +++ b/tests/windows_events_test.py @@ -0,0 +1,91 @@ +import os +import unittest + +import tulip + +from tulip import windows_events +from tulip import protocols +from tulip import streams +from tulip import transports +from tulip import test_utils + + +class UpperProto(protocols.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(unittest.TestCase): + + def setUp(self): + self.loop = windows_events.ProactorEventLoop() + tulip.set_event_loop(None) + + def tearDown(self): + self.loop.close() + self.loop = None + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, protocols.Protocol()) + f = tulip.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + server2 = windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = streams.StreamReader(loop=self.loop) + protocol = streams.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda:protocol, ADDRESS) + self.assertIsInstance(trans, transports.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + protocols.Protocol, ADDRESS) + + return 'done' diff --git a/tests/windows_utils_test.py b/tests/windows_utils_test.py new file mode 100644 index 00000000..bd61463f --- /dev/null +++ b/tests/windows_utils_test.py @@ -0,0 +1,136 @@ +"""Tests for window_utils""" + +import sys +import test.support +import unittest +import unittest.mock +import _winapi + +from tulip import windows_utils + +try: + import _overlapped +except ImportError: + from tulip import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + + csock.close() + ssock.close() + + @unittest.mock.patch('tulip.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + res = _winapi.WaitForMultipleObjects(events, True, 2000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) diff --git a/tulip/__init__.py b/tulip/__init__.py new file mode 100644 index 00000000..faf307fb --- /dev/null +++ b/tulip/__init__.py @@ -0,0 +1,26 @@ +"""Tulip 2.0, tracking PEP 3156.""" + +import sys + +# This relies on each of the submodules having an __all__ variable. +from .futures import * +from .events import * +from .locks import * +from .transports import * +from .protocols import * +from .streams import * +from .tasks import * + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * +else: + from .unix_events import * # pragma: no cover + + +__all__ = (futures.__all__ + + events.__all__ + + locks.__all__ + + transports.__all__ + + protocols.__all__ + + streams.__all__ + + tasks.__all__) diff --git a/tulip/base_events.py b/tulip/base_events.py new file mode 100644 index 00000000..5f1bff71 --- /dev/null +++ b/tulip/base_events.py @@ -0,0 +1,606 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of IO events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import logging +import socket +import subprocess +import time +import os +import sys + +from . import events +from . import futures +from . import tasks +from .log import logger + + +__all__ = ['BaseEventLoop', 'Server'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _raise_stop_error(*args): + raise _StopError + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self.loop = loop + self.sockets = sockets + self.active_count = 0 + self.waiters = [] + + def attach(self, transport): + assert self.sockets is not None + self.active_count += 1 + + def detach(self, transport): + assert self.active_count > 0 + self.active_count -= 1 + if self.active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is not None: + self.sockets = None + for sock in sockets: + self.loop._stop_serving(sock) + if self.active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self.waiters + self.waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @tasks.coroutine + def wait_closed(self): + if self.sockets is None or self.waiters is None: + return + waiter = futures.Future(loop=self.loop) + self.waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _read_from_self(self): + """XXX""" + raise NotImplementedError + + def _write_to_self(self): + """XXX""" + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def run_forever(self): + """Run until stop() is called.""" + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + XXX TBD: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + future = tasks.async(future, loop=self) + future.add_done_callback(_raise_stop_error) + self.run_forever() + future.remove_done_callback(_raise_stop_error) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. + Callback scheduled after stop() is called won't. However, + those callbacks will run if run() is called again later. + """ + self.call_soon(_raise_stop_error) + + def is_running(self): + """Returns running status of event loop.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock.""" + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always a relative time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + return self.call_at(self.time() + delay, callback, *args) + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time.""" + timer = events.TimerHandle(when, callback, args) + heapq.heappush(self._scheduled, timer) + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue, callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = events.make_handle(callback, args) + self._ready.append(handle) + return handle + + def call_soon_threadsafe(self, callback, *args): + """XXX""" + handle = self.call_soon(callback, *args) + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @tasks.coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=host) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @tasks.coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join addresss by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + transport = self._make_datagram_transport(sock, protocol, r_addr) + return transport, protocol + + @tasks.coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """XXX""" + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + @tasks.coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + yield from waiter + return transport, protocol + + @tasks.coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert shell, "shell must be True" + assert isinstance(cmd, str), cmd + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + @tasks.coroutine + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=False, bufsize=0, + **kwargs): + assert not universal_newlines, "universal_newlines must be False" + assert not shell, "shell must be False" + protocol = protocol_factory() + transport = yield from self._make_subprocess_transport( + protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) + return transport, protocol + + def _add_callback(self, handle): + """Add a Handle to ready or scheduled.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + if isinstance(handle, events.TimerHandle): + heapq.heappush(self._scheduled, handle) + else: + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + heapq.heappop(self._scheduled) + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + deadline = max(0, when - self.time()) + if timeout is None: + timeout = deadline + else: + timeout = min(timeout, deadline) + + # TODO: Instrumentation only in debug mode? + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else '{:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + now = self.time() + while self._scheduled: + handle = self._scheduled[0] + if handle._when > now: + break + handle = heapq.heappop(self._scheduled) + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is threadsafe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if not handle._cancelled: + handle._run() + handle = None # Needed to break cycles when an exception occurs. diff --git a/tulip/constants.py b/tulip/constants.py new file mode 100644 index 00000000..79c3b931 --- /dev/null +++ b/tulip/constants.py @@ -0,0 +1,4 @@ +"""Constants.""" + + +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 diff --git a/tulip/events.py b/tulip/events.py new file mode 100644 index 00000000..6ca5668c --- /dev/null +++ b/tulip/events.py @@ -0,0 +1,395 @@ +"""Event loop and event loop policy.""" + +__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + ] + +import subprocess +import sys +import threading +import socket + +from .log import logger + + +class Handle: + """Object returned by callback registration methods.""" + + def __init__(self, callback, args): + self._callback = callback + self._args = args + self._cancelled = False + + def __repr__(self): + res = 'Handle({}, {})'.format(self._callback, self._args) + if self._cancelled: + res += '' + return res + + def cancel(self): + self._cancelled = True + + def _run(self): + try: + self._callback(*self._args) + except Exception: + logger.exception('Exception in callback %s %r', + self._callback, self._args) + self = None # Needed to break cycles when an exception occurs. + + +def make_handle(callback, args): + # TODO: Inline this? + assert not isinstance(callback, Handle), 'A Handle is not a callback' + return Handle(callback, args) + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + def __init__(self, when, callback, args): + assert when is not None + super().__init__(callback, args) + + self._when = when + + def __repr__(self): + res = 'TimerHandle({}, {}, {})'.format(self._when, + self._callback, + self._args) + if self._cancelled: + res += '' + + return res + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + +class AbstractServer: + """Abstract server returned by create_service().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in eventloop. + + protocol_factory should instantiate object with Protocol interface. + pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + ReadTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in eventloop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport ABC""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """XXX""" + raise NotImplementedError + + def set_event_loop(self, loop): + """XXX""" + raise NotImplementedError + + def new_event_loop(self): + """XXX""" + raise NotImplementedError + + +class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop = None + _set_called = False + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._loop is None and + not self._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self._loop = self.new_event_loop() + assert self._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + # TODO: The isinstance() test violates the PEP. + self._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + if sys.platform == 'win32': # pragma: no cover + from . import windows_events + return windows_events.SelectorEventLoop() + else: # pragma: no cover + from . import unix_events + return unix_events.SelectorEventLoop() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + + +def get_event_loop_policy(): + """XXX""" + global _event_loop_policy + if _event_loop_policy is None: + _event_loop_policy = DefaultEventLoopPolicy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """XXX""" + global _event_loop_policy + # TODO: The isinstance() test violates the PEP. + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """XXX""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """XXX""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """XXX""" + return get_event_loop_policy().new_event_loop() diff --git a/tulip/futures.py b/tulip/futures.py new file mode 100644 index 00000000..db278386 --- /dev/null +++ b/tulip/futures.py @@ -0,0 +1,338 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import traceback + +from . import events +from .log import logger + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ['exc', 'tb'] + + def __init__(self, exc): + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + logger.error('Future/Task exception was never retrieved:\n%s', + ''.join(self.tb)) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + + _blocking = False # proper use of future (yield vs yield from) + + _tb_logger = None + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + + def __repr__(self): + res = self.__class__.__name__ + if self._state == _FINISHED: + if self._exception is not None: + res += ''.format(self._exception) + else: + res += ''.format(self._result) + elif self._callbacks: + size = len(self._callbacks) + if size > 2: + res += '<{}, [{}, <{} more>, {}]>'.format( + self._state, self._callbacks[0], + size-2, self._callbacks[-1]) + else: + res += '<{}, {}>'.format(self._state, self._callbacks) + else: + res += '<{}>'.format(self._state) + return res + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._exception = exception + self._tb_logger = _TracebackLogger(exception) + self._state = _FINISHED + self._schedule_callbacks() + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + + if loop is None: + loop = events.get_event_loop() + + new_future = Future(loop=loop) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/tulip/locks.py b/tulip/locks.py new file mode 100644 index 00000000..06edbbc1 --- /dev/null +++ b/tulip/locks.py @@ -0,0 +1,401 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] + +import collections + +from . import events +from . import futures +from . import tasks + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context manager protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return true if lock is acquired.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + if not self._locked: + raise RuntimeError( + '"yield from" should be used as context manager expression') + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self + + +class Event: + """An Event implementation, our equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + + def is_set(self): + """Return true if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @tasks.coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. +class Condition(Lock): + """A Condition implementation. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + """ + + def __init__(self, *, loop=None): + super().__init__(loop=loop) + self._condition_waiters = collections.deque() + + # TODO: Add __repr__() with len(_condition_waiters). + + @tasks.coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self._locked: + raise RuntimeError('cannot wait on un-acquired lock') + + keep_lock = True + self.release() + try: + fut = futures.Future(loop=self._loop) + self._condition_waiters.append(fut) + try: + yield from fut + return True + finally: + self._condition_waiters.remove(fut) + + except GeneratorExit: + keep_lock = False # Prevent yield in finally clause. + raise + finally: + if keep_lock: + yield from self.acquire() + + @tasks.coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self._locked: + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._condition_waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._condition_waiters)) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context manager protocol. + + The first optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + + The second optional argument determins can semophore be released more than + initial internal counter value; it defaults to False. If the value given + is True and number of release() is more than number of successfull + acquire() calls ValueError is raised. + """ + + def __init__(self, value=1, bound=False, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be > 0") + self._value = value + self._bound = bound + self._bound_value = value + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + # TODO: add waiters:N if > 0. + res = super().__repr__() + return '<{} [{}]>'.format( + res[1:-1], + 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value)) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._locked + + @tasks.coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + if self._value == 0: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + if self._value == 0: + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + + If Semaphore is create with "bound" paramter equals true, then + release() method checks to make sure its current value doesn't exceed + its initial value. If it does, ValueError is raised. + """ + if self._bound and self._value >= self._bound_value: + raise ValueError('Semaphore released too many times') + + self._value += 1 + self._locked = False + + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + # TODO: This is questionable. How do we know the user actually + # wrote "with (yield from sema)" instead of "with sema"? + return True + + def __exit__(self, *args): + self.release() + + def __iter__(self): + yield from self.acquire() + return self diff --git a/tulip/log.py b/tulip/log.py new file mode 100644 index 00000000..5f3534c7 --- /dev/null +++ b/tulip/log.py @@ -0,0 +1,7 @@ +"""Tulip logging configuration""" + +import logging + + +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/tulip/proactor_events.py b/tulip/proactor_events.py new file mode 100644 index 00000000..2a2701a8 --- /dev/null +++ b/tulip/proactor_events.py @@ -0,0 +1,346 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import logger + + +class _ProactorBasePipeTransport(transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra) + self._set_extra(sock) + self._loop = loop + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = [] + self._read_fut = None + self._write_fut = None + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server.attach(self) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc): + logger.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._buffer = [] + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._read_fut = None + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' + assert not self._paused, 'Already paused' + self._paused = True + + def resume_reading(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + +class _ProactorWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if self._eof_written: + raise IOError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + self._buffer.append(data) + if self._write_fut is None: + self._loop_writing() + + def _loop_writing(self, f=None): + try: + assert f is self._write_fut + self._write_fut = None + if f: + f.result() + data = b''.join(self._buffer) + self._buffer = [] + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + return + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc) + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + self._extra['peername'] = sock.getpeername() + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logger.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + + def close(self): + if self._proactor is not None: + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'x') + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + assert not ssl, 'IocpEventLoop is incompatible with SSL.' + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + f = self._proactor.accept(sock) + except OSError: + if sock.fileno() != -1: + logger.exception('Accept failed') + sock.close() + except futures.CancelledError: + sock.close() + else: + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def _stop_serving(self, sock): + self._proactor._stop_serving(sock) + sock.close() diff --git a/tulip/protocols.py b/tulip/protocols.py new file mode 100644 index 00000000..a94abbe5 --- /dev/null +++ b/tulip/protocols.py @@ -0,0 +1,98 @@ +"""Abstract Protocol class.""" + +__all__ = ['Protocol', 'DatagramProtocol'] + + +class BaseProtocol: + """ABC for base protocol class. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + +class Protocol(BaseProtocol): + """ABC representing a protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + +class DatagramProtocol(BaseProtocol): + """ABC representing a datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def connection_refused(self, exc): + """Connection is refused.""" + + +class SubprocessProtocol(BaseProtocol): + """ABC representing a protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when subprocess write a data into stdout/stderr pipes. + + fd is int file dascriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited. + """ diff --git a/tulip/queues.py b/tulip/queues.py new file mode 100644 index 00000000..536de1cb --- /dev/null +++ b/tulip/queues.py @@ -0,0 +1,284 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'Full', 'Empty'] + +import collections +import heapq +import queue + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +# Re-export queue.Full and .Empty exceptions. +Full = queue.Full +Empty = queue.Empty + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded Tulip application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() == self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise Full. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize == self.qsize(): + raise Full + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter.set_result, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise Full. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise Empty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/tulip/selector_events.py b/tulip/selector_events.py new file mode 100644 index 00000000..084d9be7 --- /dev/null +++ b/tulip/selector_events.py @@ -0,0 +1,754 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +import collections +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import logger + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + logger.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, address, extra) + + def close(self): + if self._selector is not None: + self._close_self_pipe() + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _read_from_self(self): + try: + self._ssock.recv(1) + except (BlockingIOError, InterruptedError): + pass + + def _write_to_self(self): + try: + self._csock.send(b'x') + except (BlockingIOError, InterruptedError): + pass + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, ssl, server) + + def _accept_connection(self, protocol_factory, sock, ssl=None, + server=None): + try: + conn, addr = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + pass # False alarm. + except Exception: + # Bad error. Stop serving. + self.remove_reader(sock.fileno()) + sock.close() + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + logger.exception('Accept failed') + else: + if ssl: + self._make_ssl_transport( + conn, protocol_factory(), ssl, None, + server_side=True, extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'peername': addr}, + server=server) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + handle = events.make_handle(callback, args) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """XXX""" + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """XXX""" + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """XXX""" + # That address better not require a lookup! We're not calling + # self.getaddrinfo() for you here. But verifying this is + # complicated; the socket module doesn't have a pattern for + # IPv6 addresses (there are too many forms, apparently). + fut = futures.Future(loop=self) + self._sock_connect(fut, False, sock, address) + return fut + + def _sock_connect(self, fut, registered, sock, address): + # TODO: Use getaddrinfo() to look up the address, to avoid the + # trap of hanging the entire event loop when the address + # requires doing a DNS lookup. (OTOH, the caller should + # already have done this, so it would be nice if we could + # easily tell whether the address needs looking up or not. I + # know how to do this for IPv4, but IPv6 addresses have many + # syntaxes.) + fd = sock.fileno() + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + try: + if not registered: + # First time around. + sock.connect(address) + else: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed') + except (BlockingIOError, InterruptedError): + self.add_writer(fd, self._sock_connect, fut, True, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """XXX""" + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + def __init__(self, loop, sock, protocol, extra, server=None): + super().__init__(extra) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._loop = loop + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._server = server + self._buffer = collections.deque() + self._conn_lost = 0 # Set when call to connection_lost scheduled. + self._closing = False # Set when close() called. + if server is not None: + server.attach(self) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc): + # Should be called from exception handler only. + if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + logger.exception('Fatal error for %s', self) + self._force_close(exc) + + def _force_close(self, exc): + if self._conn_lost: + return + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + if not self._closing: + self._closing = True + self._loop.remove_reader(self._sock_fd) + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server.detach(self) + self._server = None + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def pause_reading(self): + assert not self._closing, 'Cannot pause_reading() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume_reading(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + assert not self._eof, 'Cannot call write() after write_eof()' + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + return + else: + data = data[n:] + if not data: + return + + # Start async I/O. + self._loop.add_writer(self._sock_fd, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + else: + data = data[n:] + if not data: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + return + self._buffer.append(data) # Try again later. + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if server_side: + assert isinstance( + sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + else: + # Client-side may pass ssl=True to use a default context. + # The default is the same as used by urllib. + if sslcontext is None: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname is not None and not server_side and ssl.HAS_SNI: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + + self._server_hostname = server_hostname + self._waiter = waiter + self._rawsock = rawsock + self._sslcontext = sslcontext + self._paused = False + + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + + self._on_handshake() + + def _on_handshake(self): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, self._on_handshake) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, self._on_handshake) + return + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + except BaseException as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + raise + + # Verify hostname if requested. + peercert = self._sock.getpeercert() + if (self._server_hostname is not None and + self._sslcontext.verify_mode == ssl.CERT_REQUIRED): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_writer(self._sock_fd, self._on_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + self._loop.call_soon(self._waiter.set_result, None) + + def pause_reading(self): + # XXX This is a bit icky, given the comment at the top of + # _on_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume_reading() again, and things will flow again. + + assert not self._closing, 'Cannot pause_reading() when closing' + assert not self._paused, 'Already paused' + self._paused = True + self._loop.remove_reader(self._sock_fd) + + def resume_reading(self): + assert self._paused, 'Not paused' + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._on_ready) + + def _on_ready(self): + # Because of renegotiations (?), there's no difference between + # readable and writable. We just try both. XXX This may be + # incorrect; we probably need to keep state about what we + # should do next. + + # First try reading. + if not self._closing and not self._paused: + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + try: + self._protocol.eof_received() + finally: + self.close() + + # Now try writing, if there's anything to write. + if self._buffer: + data = b''.join(self._buffer) + self._buffer.clear() + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError, + ssl.SSLWantReadError, ssl.SSLWantWriteError): + n = 0 + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._fatal_error(exc) + return + + if n < len(data): + self._buffer.append(data[n:]) + + if self._closing and not self._buffer: + self._loop.remove_writer(self._sock_fd) + self._call_connection_lost(None) + + def write(self, data): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + self._buffer.append(data) + # We could optimize, but the callback can do this for now. + + def can_write_eof(self): + return False + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + self._loop.remove_reader(self._sock_fd) + + +class _SelectorDatagramTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, address=None, extra=None): + super().__init__(loop, sock, protocol, extra) + + self._address = address + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc) + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + assert isinstance(data, bytes), repr(type(data)) + if not data: + return + + if self._address: + assert addr in (None, self._address) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except Exception as exc: + self._fatal_error(exc) + return + + self._buffer.append((data, addr)) + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except ConnectionRefusedError as exc: + if self._address: + self._fatal_error(exc) + return + except Exception as exc: + self._fatal_error(exc) + return + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def _force_close(self, exc): + if self._address and isinstance(exc, ConnectionRefusedError): + self._protocol.connection_refused(exc) + super()._force_close(exc) diff --git a/tulip/selectors.py b/tulip/selectors.py new file mode 100644 index 00000000..fe027f09 --- /dev/null +++ b/tulip/selectors.py @@ -0,0 +1,405 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple +import functools +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class BaseSelector(metaclass=ABCMeta): + """Base selector class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + performant implementation on the current platform. + """ + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + """ + try: + key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + """ + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events or data != key.data: + # TODO: If only the data changed, use a shortcut that only + # updates the data. + self.unregister(fileobj) + return self.register(fileobj, events, data) + else: + return key + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError() + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + self._fd_to_key.clear() + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + try: + return self._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(BaseSelector): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(BaseSelector): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(int(1000 * timeout), 0) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(BaseSelector): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._epoll.unregister(key.fd) + return key + + def select(self, timeout=None): + timeout = -1 if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._epoll.close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(BaseSelector): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + self._kqueue.control([kev], 0, 0) + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + super().close() + self._kqueue.close() + + +# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/tulip/streams.py b/tulip/streams.py new file mode 100644 index 00000000..9915aa5c --- /dev/null +++ b/tulip/streams.py @@ -0,0 +1,257 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] + +import collections + +from . import events +from . import futures +from . import protocols +from . import tasks + + +_DEFAULT_LIMIT = 2**16 + + +@tasks.coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + Transport. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + return reader, transport # (reader, writer) + + +class StreamReaderProtocol(protocols.Protocol): + """Trivial helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader): + self.stream_reader = stream_reader + + def connection_made(self, transport): + self.stream_reader.set_transport(transport) + + def connection_lost(self, exc): + if exc is None: + self.stream_reader.feed_eof() + else: + self.stream_reader.set_exception(exc) + + def data_received(self, data): + self.stream_reader.feed_data(data) + + def eof_received(self): + self.stream_reader.feed_eof() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self.limit = limit + if loop is None: + loop = events.get_event_loop() + self.loop = loop + self.buffer = collections.deque() # Deque of bytes objects. + self.byte_count = 0 # Bytes in buffer. + self.eof = False # Whether we're done. + self.waiter = None # A future. + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and self.byte_count <= self.limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self.eof = True + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(True) + + def feed_data(self, data): + if not data: + return + + self.buffer.append(data) + self.byte_count += len(data) + + waiter = self.waiter + if waiter is not None: + self.waiter = None + if not waiter.cancelled(): + waiter.set_result(False) + + if (self._transport is not None and + not self._paused and + self.byte_count > 2*self.limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + @tasks.coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + parts = [] + parts_size = 0 + not_enough = True + + while not_enough: + while self.buffer and not_enough: + data = self.buffer.popleft() + ichar = data.find(b'\n') + if ichar < 0: + parts.append(data) + parts_size += len(data) + else: + ichar += 1 + head, tail = data[:ichar], data[ichar:] + if tail: + self.buffer.appendleft(tail) + not_enough = False + parts.append(head) + parts_size += len(head) + + if parts_size > self.limit: + self.byte_count -= parts_size + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self.eof: + break + + if not_enough: + assert self.waiter is None + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + line = b''.join(parts) + self.byte_count -= parts_size + self._maybe_resume_transport() + + return line + + @tasks.coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + while not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + else: + if not self.byte_count and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + if n < 0 or self.byte_count <= n: + data = b''.join(self.buffer) + self.buffer.clear() + self.byte_count = 0 + self._maybe_resume_transport() + return data + + parts = [] + parts_bytes = 0 + while self.buffer and parts_bytes < n: + data = self.buffer.popleft() + data_bytes = len(data) + if n < parts_bytes + data_bytes: + data_bytes = n - parts_bytes + data, rest = data[:data_bytes], data[data_bytes:] + self.buffer.appendleft(rest) + + parts.append(data) + parts_bytes += data_bytes + self.byte_count -= data_bytes + self._maybe_resume_transport() + + return b''.join(parts) + + @tasks.coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + if n <= 0: + return b'' + + while self.byte_count < n and not self.eof: + assert not self.waiter + self.waiter = futures.Future(loop=self.loop) + try: + yield from self.waiter + finally: + self.waiter = None + + return (yield from self.read(n)) diff --git a/tulip/tasks.py b/tulip/tasks.py new file mode 100644 index 00000000..7aba698c --- /dev/null +++ b/tulip/tasks.py @@ -0,0 +1,636 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['coroutine', 'Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', + ] + +import collections +import concurrent.futures +import functools +import inspect +import linecache +import traceback +import weakref + +from . import events +from . import futures +from .log import logger + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = False + + +class CoroWrapper: + """Wrapper for coroutine in _DEBUG mode.""" + + __slot__ = ['gen', 'func'] + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + def __del__(self): + frame = self.gen.gi_frame + if frame is not None and frame.f_lasti == -1: + func = self.func + code = func.__code__ + filename = code.co_filename + lineno = code.co_firstlineno + logger.error('Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + w.__name__ = coro.__name__ + w.__doc__ = coro.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return getattr(func, '_is_coroutine', False) + + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert iscoroutine(coro), repr(coro) # Not a coroutine function! + super().__init__(loop=loop) + self._coro = iter(coro) # Use the iterator just in case. + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + def __repr__(self): + res = super().__repr__() + if (self._must_cancel and + self._state == futures._PENDING and + ')'.format(self._coro.__name__) + res[i:] + return res + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is active, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum nummber of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output goes; by default it goes to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + self = None + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from tulip.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + + if loop is None: + loop = events.get_event_loop() + + fs = set(async(f, loop=loop) for f in fs) + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + return (yield from _wait(fs, timeout, return_when, loop)) + + +def _release_waiter(waiter, value=True, *args): + if not waiter.done(): + waiter.set_result(value) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. Raises TimeoutError when + timeout occurs. + + Usage: + + result = yield from tulip.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) + cb = functools.partial(_release_waiter, waiter, True) + + fut = async(fut, loop=loop) + fut.add_done_callback(cb) + + try: + if (yield from waiter): + return fut.result() + else: + fut.remove_done_callback(cb) + raise futures.TimeoutError() + finally: + timeout_handle.cancel() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(False) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values, when waited for, are Futures. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + Raises TimeoutError if the timeout occurs before all Futures are + done. + + Note: The futures 'f' are not necessarily members of fs. + """ + loop = loop if loop is not None else events.get_event_loop() + deadline = None if timeout is None else loop.time() + timeout + todo = set(async(f, loop=loop) for f in fs) + completed = collections.deque() + + @coroutine + def _wait_for_one(): + while not completed: + timeout = None + if deadline is not None: + timeout = deadline - loop.time() + if timeout < 0: + raise futures.TimeoutError() + done, pending = yield from _wait( + todo, timeout, FIRST_COMPLETED, loop) + # Multiple callers might be waiting for the same events + # and getting the same outcome. Dedupe by updating todo. + for f in done: + if f in todo: + todo.remove(f) + completed.append(f) + f = completed.popleft() + return f.result() # May raise. + + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, future.set_result, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif iscoroutine(coro_or_future): + return Task(coro_or_future, loop=loop) + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *result_exception* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + children = [async(fut, loop=loop) for fut in coros_or_futures] + n = len(children) + if n == 0: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + if loop is None: + loop = children[0]._loop + for fut in children: + if fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * n + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Mark exception retrieved. + fut.exception() + return + if fut._state == futures._CANCELLED: + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == n: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + # Mark inner's result as retrieved. + inner.cancelled() or inner.exception() + return + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/tulip/test_utils.py b/tulip/test_utils.py new file mode 100644 index 00000000..fdf629d6 --- /dev/null +++ b/tulip/test_utils.py @@ -0,0 +1,239 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import unittest.mock +import os +import sys +import threading +import unittest +import unittest.mock +from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +import tulip +from tulip import base_events +from tulip import events +from tulip import selectors + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + +def run_briefly(loop): + @tulip.coroutine + def once(): + pass + gen = once() + t = tulip.Task(gen, loop=loop) + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + + class SilentWSGIRequestHandler(WSGIRequestHandler): + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + class SilentWSGIServer(WSGIServer): + def handle_error(self, request, client_address): + pass + + class SSLWSGIServer(SilentWSGIServer): + def finish_request(self, request, client_address): + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + keyfile = os.path.join(here, 'sample.key') + certfile = os.path.join(here, 'sample.crt') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = SSLWSGIServer if use_ssl else SilentWSGIServer + httpd = make_server(host, port, app, + server_class, SilentWSGIRequestHandler) + httpd.address = httpd.server_address + server_thread = threading.Thread(target=httpd.serve_forever) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + server_thread.join() + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = unittest.mock.Mock(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def select(self, timeout): + return [] + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value retuned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.make_handle(callback, args) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.make_handle(callback, args) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass diff --git a/tulip/transports.py b/tulip/transports.py new file mode 100644 index 00000000..f1a71800 --- /dev/null +++ b/tulip/transports.py @@ -0,0 +1,186 @@ +"""Abstract Transport class.""" + +__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] + + +class BaseTransport: + """Base ABC for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Closes the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """ABC for read-only transports.""" + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + raise NotImplementedError + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """ABC for write-only transports.""" + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation just calls write() for each item in + the list/iterable. + """ + for data in list_of_data: + self.write(data) + + def write_eof(self): + """Closes the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this protocol supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """ABC representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """ABC for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Closes the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError diff --git a/tulip/unix_events.py b/tulip/unix_events.py new file mode 100644 index 00000000..a234f4fa --- /dev/null +++ b/tulip/unix_events.py @@ -0,0 +1,541 @@ +"""Selector eventloop for Unix with signal handling.""" + +import collections +import errno +import fcntl +import functools +import os +import signal +import socket +import stat +import subprocess +import sys + + +from . import constants +from . import events +from . import protocols +from . import selector_events +from . import tasks +from . import transports +from .log import logger + + +__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] + +STDIN = 0 +STDOUT = 1 +STDERR = 2 + + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop + + Adds signal handling to SelectorEventLoop + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + self._subprocesses = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + handler = self._signal_handlers.get(signal.SIGCHLD) + if handler is not None: + self.remove_signal_handler(signal.SIGCHLD) + super().close() + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + self._check_signal(sig) + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except ValueError as exc: + raise RuntimeError(str(exc)) + + handle = events.make_handle(callback, args) + self._signal_handlers[sig] = handle + + try: + signal.signal(sig, self._handle_signal) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as nexc: + logger.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig, arg): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except ValueError as exc: + logger.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + self._reg_sigchld() + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + self._subprocesses[transp.get_pid()] = transp + yield from transp._post_init() + return transp + + def _reg_sigchld(self): + if signal.SIGCHLD not in self._signal_handlers: + self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + def _sig_chld(self): + try: + try: + pid, status = os.waitpid(0, os.WNOHANG) + except ChildProcessError: + return + if pid == 0: + self.call_soon(self._sig_chld) + return + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + self.call_soon(self._sig_chld) + return + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) + except Exception: + logger.exception('Unknown exception in SIGCHLD handler') + + def _subprocess_closed(self, transport): + pid = transport.get_pid() + self._subprocesses.pop(pid, None) + + +def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one eventloop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) + else: + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause_reading(self): + self._loop.remove_reader(self._fileno) + + def resume_reading(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logger.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): + raise ValueError("Pipe transport is for pipes only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + self._loop.call_soon(waiter.set_result, None) + + def _read_ready(self): + # pipe was closed by peer + self._close() + + def write(self, data): + assert isinstance(data, bytes), repr(data) + if not data: + return + + if self._conn_lost or self._closing: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc) + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc) + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + if self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + # TODO: Make the relationships between write_eof(), close(), + # abort(), _fatal_error() and _close() more straightforward. + + def write_eof(self): + if self._closing: + return + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc): + # should be called by exception handler only + logger.exception('Fatal error for %s', self) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + +class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) + + def eof_received(self): + pass + + +class _UnixSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[STDIN] = None + if stdout == subprocess.PIPE: + self._pipes[STDOUT] = None + if stderr == subprocess.PIPE: + self._pipes[STDERR] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + functools.partial( + _UnixWriteSubprocessPipeProto, self, STDIN), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial( + _UnixReadSubprocessPipeProto, self, STDOUT), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + functools.partial( + _UnixReadSubprocessPipeProto, self, STDERR), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None diff --git a/tulip/windows_events.py b/tulip/windows_events.py new file mode 100644 index 00000000..ad53a8e7 --- /dev/null +++ b/tulip/windows_events.py @@ -0,0 +1,371 @@ +"""Selector and proactor eventloops for Windows.""" + +import errno +import socket +import weakref +import struct +import _winapi + +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from . import _overlapped +from .log import logger + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + self.ov = ov + + def cancel(self): + try: + self.ov.cancel() + except OSError: + pass + return super().cancel() + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class SelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @tasks.coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @tasks.coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + def loop(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError: + if pipe and pipe.fileno() != -1: + logger.exception('Pipe accept failed') + pipe.close() + except futures.CancelledError: + if pipe: + pipe.close() + else: + f.add_done_callback(loop) + self.call_soon(loop) + return [server] + + def _stop_serving(self, server): + server.close() + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + def finish(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, conn, finish) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket(listener.family) + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), conn.family) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + ov.ConnectNamedPipe(pipe.fileno()) + def finish(trans, key, ov): + ov.getresult() + return pipe + return self._register(ov, pipe, finish) + + def connect_pipe(self, address): + ov = _overlapped.Overlapped(NULL) + ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + def finish(err, handle, ov): + # err, handle were arguments passed to PostQueuedCompletionStatus() + # in a function run in a thread pool. + if err == _overlapped.ERROR_SEM_TIMEOUT: + # Connection did not succeed within time limit. + msg = _overlapped.FormatMessage(err) + raise ConnectionRefusedError(0, msg, None, err) + elif err != 0: + msg = _overlapped.FormatMessage(err) + raise OSError(0, msg, None, err) + else: + return windows_utils.PipeHandle(handle) + return self._register(ov, None, finish, wait_for_post=True) + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback, wait_for_post=False): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if ov.pending or wait_for_post: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + else: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + return f + + def _get_accept_socket(self, family): + s = socket.socket(family) + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + ms = int(timeout * 1000 + 0.5) + if ms >= INFINITE: + raise ValueError("timeout too big") + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + ms = 0 + continue + if obj in self._stopped_serving: + f.cancel() + elif not f.cancelled(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + ms = 0 + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (f, ov, obj, callback) in list(self._cache.items()): + if obj is None: + # The operation was started with connect_pipe() which + # queues a task to Windows' thread pool. This cannot + # be cancelled, so just forget it. + del self._cache[address] + else: + try: + ov.cancel() + except OSError: + pass + + while self._cache: + if not self._poll(1): + logger.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None diff --git a/tulip/windows_utils.py b/tulip/windows_utils.py new file mode 100644 index 00000000..04b43e9a --- /dev/null +++ b/tulip/windows_utils.py @@ -0,0 +1,181 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + +# +# Constants/globals +# + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +_mmap_counter = itertools.count() + +# +# Replacement for socket.socketpair() +# + +def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + lsock.bind(('localhost', 0)) + lsock.listen(1) + addr, port = lsock.getsockname() + csock = socket.socket(family, type, proto) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + except Exception: + lsock.close() + csock.close() + raise + ssock, _ = lsock.accept() + csock.setblocking(True) + lsock.close() + return (ssock, csock) + +# +# Replacement for os.pipe() using handles instead of fds +# + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + +# +# Wrapper for a pipe handle +# + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle != -1: + CloseHandle(self._handle) + self._handle = -1 + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + +# +# Replacement for subprocess.Popen using overlapped pipe handles +# + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + try: + super().__init__(args, bufsize=0, universal_newlines=False, + stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) From 53d8ac9de6770431e1d19eaf4f2b0f0df92c54e4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 14:50:45 -0700 Subject: [PATCH 0722/1502] Fix example to use create_server() instead of start_serving(). --- examples/tcp_echo.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 39db5cca..9ecc4805 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -70,9 +70,9 @@ def start_client(loop, host, port): def start_server(loop, host, port): - f = loop.start_serving(EchoServer, host, port) - x = loop.run_until_complete(f)[0] - print('serving on', x.getsockname()) + f = loop.create_server(EchoServer, host, port) + s = loop.run_until_complete(f) + print('serving on', s.sockets[0].getsockname()) ARGS = argparse.ArgumentParser(description="TCP Echo example.") From ba7e4bb5b3bc7fcf710482e0b5408a26d57bc343 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 15:04:45 -0700 Subject: [PATCH 0723/1502] Write flow control. Also a somewhat major streams overhaul. --- examples/source1.py | 88 +++++++++++++++++ tests/streams_test.py | 42 ++++---- tulip/protocols.py | 28 ++++++ tulip/selector_events.py | 69 ++++++++++--- tulip/streams.py | 208 ++++++++++++++++++++++++++++----------- tulip/transports.py | 25 +++++ 6 files changed, 369 insertions(+), 91 deletions(-) create mode 100644 examples/source1.py diff --git a/examples/source1.py b/examples/source1.py new file mode 100644 index 00000000..4e05964f --- /dev/null +++ b/examples/source1.py @@ -0,0 +1,88 @@ +"""Like source.py, but uses streams.""" + +import argparse +import sys + +from tulip import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + + +class Debug: + """A clever little class that suppresses repetitive messages.""" + + overwriting = False + label = 'stream1:' + + def print(self, *args): + if self.overwriting: + print(file=sys.stderr) + self.overwriting = 0 + print(self.label, *args, file=sys.stderr) + + def oprint(self, *args): + self.overwriting += 1 + end = '\n' + if self.overwriting >= 3: + if self.overwriting == 3: + print(self.label, '[...]', file=sys.stderr) + end = '\r' + print(self.label, *args, file=sys.stderr, end=end, flush=True) + + +@coroutine +def start(loop, args): + d = Debug() + total = 0 + r, w = yield from open_connection(args.host, args.port) + d.print('r =', r) + d.print('w =', w) + if args.stop: + w.write(b'stop') + w.close() + else: + size = args.size + data = b'x'*size + try: + while True: + total += size + d.oprint('writing', size, 'bytes; total', total) + w.write(data) + f = w.drain() + if f: + d.print('pausing') + yield from f + except (ConnectionResetError, BrokenPipeError) as exc: + d.print('caught', repr(exc)) + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from tulip.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + loop.run_until_complete(start(loop, args)) + loop.close() + + +if __name__ == '__main__': + main() diff --git a/tests/streams_test.py b/tests/streams_test.py index e562bf92..1d525575 100644 --- a/tests/streams_test.py +++ b/tests/streams_test.py @@ -32,7 +32,7 @@ def tearDown(self): @unittest.mock.patch('tulip.streams.events') def test_ctor_global_loop(self, m_events): stream = streams.StreamReader() - self.assertIs(stream.loop, m_events.get_event_loop.return_value) + self.assertIs(stream._loop, m_events.get_event_loop.return_value) def test_open_connection(self): with test_utils.run_test_server() as httpd: @@ -81,13 +81,13 @@ def test_feed_empty_data(self): stream = streams.StreamReader(loop=self.loop) stream.feed_data(b'') - self.assertEqual(0, stream.byte_count) + self.assertEqual(0, stream._byte_count) def test_feed_data_byte_count(self): stream = streams.StreamReader(loop=self.loop) stream.feed_data(self.DATA) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_read_zero(self): # Read zero bytes. @@ -96,7 +96,7 @@ def test_read_zero(self): data = self.loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_read(self): # Read bytes. @@ -109,7 +109,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_line_breaks(self): # Read bytes without line breaks. @@ -120,7 +120,7 @@ def test_read_line_breaks(self): data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) - self.assertEqual(5, stream.byte_count) + self.assertEqual(5, stream._byte_count) def test_read_eof(self): # Read bytes, stop at eof. @@ -133,7 +133,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(b'', data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_until_eof(self): # Read all bytes until eof. @@ -149,7 +149,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_read_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -176,7 +176,7 @@ def cb(): line = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) - self.assertEqual(len(b'\n chunk4')-1, stream.byte_count) + self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) def test_readline_limit_with_existing_data(self): stream = streams.StreamReader(3, loop=self.loop) @@ -185,7 +185,7 @@ def test_readline_limit_with_existing_data(self): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'line2\n'], list(stream.buffer)) + self.assertEqual([b'line2\n'], list(stream._buffer)) stream = streams.StreamReader(3, loop=self.loop) stream.feed_data(b'li') @@ -194,8 +194,8 @@ def test_readline_limit_with_existing_data(self): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'li'], list(stream.buffer)) - self.assertEqual(2, stream.byte_count) + self.assertEqual([b'li'], list(stream._buffer)) + self.assertEqual(2, stream._byte_count) def test_readline_limit(self): stream = streams.StreamReader(7, loop=self.loop) @@ -209,8 +209,8 @@ def cb(): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'chunk3\n'], list(stream.buffer)) - self.assertEqual(7, stream.byte_count) + self.assertEqual([b'chunk3\n'], list(stream._buffer)) + self.assertEqual(7, stream._byte_count) def test_readline_line_byte_count(self): stream = streams.StreamReader(loop=self.loop) @@ -220,7 +220,7 @@ def test_readline_line_byte_count(self): line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) - self.assertEqual(len(self.DATA) - len(b'line1\n'), stream.byte_count) + self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) def test_readline_eof(self): stream = streams.StreamReader(loop=self.loop) @@ -248,7 +248,7 @@ def test_readline_read_byte_count(self): self.assertEqual(b'line2\nl', data) self.assertEqual( len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), - stream.byte_count) + stream._byte_count) def test_readline_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -268,11 +268,11 @@ def test_readexactly_zero_or_less(self): data = self.loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) data = self.loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_readexactly(self): # Read exact number of bytes. @@ -289,7 +289,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) - self.assertEqual(len(self.DATA), stream.byte_count) + self.assertEqual(len(self.DATA), stream._byte_count) def test_readexactly_eof(self): # Read exact number of bytes (eof). @@ -304,7 +304,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream.byte_count) + self.assertFalse(stream._byte_count) def test_readexactly_exception(self): stream = streams.StreamReader(loop=self.loop) @@ -357,7 +357,7 @@ def read_a_line(): # The following line fails if set_exception() isn't careful. stream.set_exception(RuntimeError('message')) test_utils.run_briefly(self.loop) - self.assertIs(stream.waiter, None) + self.assertIs(stream._waiter, None) if __name__ == '__main__': diff --git a/tulip/protocols.py b/tulip/protocols.py index a94abbe5..d3a86859 100644 --- a/tulip/protocols.py +++ b/tulip/protocols.py @@ -29,6 +29,34 @@ def connection_lost(self, exc): aborted or closed). """ + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + + NOTE: This is the only Protocol callback that is not called + through EventLoop.call_soon() -- if it were, it would have no + effect when it's most needed (when the app keeps writing + without yielding until pause_writing() is called). + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + class Protocol(BaseProtocol): """ABC representing a protocol. diff --git a/tulip/selector_events.py b/tulip/selector_events.py index 084d9be7..adf8d382 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -346,8 +346,10 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._buffer = collections.deque() self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. - if server is not None: - server.attach(self) + self._protocol_paused = False + self.set_write_buffer_limits() + if self._server is not None: + self._server.attach(self) def abort(self): self._force_close(None) @@ -392,6 +394,40 @@ def _call_connection_lost(self, exc): server.detach(self) self._server = None + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception: + tulip_log.exception('pause_writing() failed') + + def _maybe_resume_protocol(self): + if self._protocol_paused and self.get_write_buffer_size() <= self._low_water: + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception: + tulip_log.exception('resume_writing() failed') + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + assert 0 <= low <= high, repr((low, high)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + class _SelectorSocketTransport(_SelectorTransport): @@ -447,7 +483,7 @@ def write(self, data): return if not self._buffer: - # Attempt to send it right away first. + # Optimization: try to send now. try: n = self._sock.send(data) except (BlockingIOError, InterruptedError): @@ -459,34 +495,36 @@ def write(self, data): data = data[n:] if not data: return - - # Start async I/O. + # Not all was written; register write handler. self._loop.add_writer(self._sock_fd, self._write_ready) + # Add it to the buffer. self._buffer.append(data) + self._maybe_pause_protocol() def _write_ready(self): data = b''.join(self._buffer) assert data, 'Data should not be empty' - self._buffer.clear() + self._buffer.clear() # Optimistically; may have to put it back later. try: n = self._sock.send(data) except (BlockingIOError, InterruptedError): - self._buffer.append(data) + self._buffer.append(data) # Still need to write this. except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) else: data = data[n:] - if not data: + if data: + self._buffer.append(data) # Still need to write this. + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) elif self._eof: self._sock.shutdown(socket.SHUT_WR) - return - self._buffer.append(data) # Try again later. def write_eof(self): if self._eof: @@ -642,6 +680,8 @@ def _on_ready(self): if n < len(data): self._buffer.append(data[n:]) + self._maybe_resume_protocol() # May append to buffer. + if self._closing and not self._buffer: self._loop.remove_writer(self._sock_fd) self._call_connection_lost(None) @@ -657,8 +697,9 @@ def write(self, data): self._conn_lost += 1 return - self._buffer.append(data) # We could optimize, but the callback can do this for now. + self._buffer.append(data) + self._maybe_pause_protocol() def can_write_eof(self): return False @@ -675,11 +716,13 @@ class _SelectorDatagramTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(loop, sock, protocol, extra) - self._address = address self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + def _read_ready(self): try: data, addr = self._sock.recvfrom(self.max_size) @@ -723,6 +766,7 @@ def sendto(self, data, addr=None): return self._buffer.append((data, addr)) + self._maybe_pause_protocol() def _sendto_ready(self): while self._buffer: @@ -743,6 +787,7 @@ def _sendto_ready(self): self._fatal_error(exc) return + self._maybe_resume_protocol() # May append to buffer. if not self._buffer: self._loop.remove_writer(self._sock_fd) if self._closing: diff --git a/tulip/streams.py b/tulip/streams.py index 9915aa5c..e9953682 100644 --- a/tulip/streams.py +++ b/tulip/streams.py @@ -39,7 +39,8 @@ def open_connection(host=None, port=None, *, protocol = StreamReaderProtocol(reader) transport, _ = yield from loop.create_connection( lambda: protocol, host, port, **kwds) - return reader, transport # (reader, writer) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer class StreamReaderProtocol(protocols.Protocol): @@ -52,22 +53,113 @@ class StreamReaderProtocol(protocols.Protocol): """ def __init__(self, stream_reader): - self.stream_reader = stream_reader + self._stream_reader = stream_reader + self._drain_waiter = None + self._paused = False def connection_made(self, transport): - self.stream_reader.set_transport(transport) + self._stream_reader.set_transport(transport) def connection_lost(self, exc): if exc is None: - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() else: - self.stream_reader.set_exception(exc) + self._stream_reader.set_exception(exc) + # Also wake up the writing side. + if self._paused: + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) def data_received(self, data): - self.stream_reader.feed_data(data) + self._stream_reader.feed_data(data) def eof_received(self): - self.stream_reader.feed_eof() + self._stream_reader.feed_eof() + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport attribute which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + self._reader = reader + self._loop = loop + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + def drain(self): + """This method has an unusual return value. + + The intended use is to write + + w.write(data) + yield from w.drain() + + When there's nothing to wait for, drain() returns (), and the + yield-from continues immediately. When the transport buffer + is full (the protocol is paused), drain() creates and returns + a Future and the yield-from will block until that Future is + completed, which will happen when the buffer is (partially) + drained and the protocol is resumed. + """ + if self._reader._exception is not None: + raise self._writer._exception + if self._transport._conn_lost: # Uses private variable. + raise ConnectionResetError('Connection lost') + if not self._protocol._paused: + return () + waiter = self._protocol._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._protocol._drain_waiter = waiter + return waiter class StreamReader: @@ -75,14 +167,14 @@ class StreamReader: def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. - self.limit = limit + self._limit = limit if loop is None: loop = events.get_event_loop() - self.loop = loop - self.buffer = collections.deque() # Deque of bytes objects. - self.byte_count = 0 # Bytes in buffer. - self.eof = False # Whether we're done. - self.waiter = None # A future. + self._loop = loop + self._buffer = collections.deque() # Deque of bytes objects. + self._byte_count = 0 # Bytes in buffer. + self._eof = False # Whether we're done. + self._waiter = None # A future. self._exception = None self._transport = None self._paused = False @@ -93,9 +185,9 @@ def exception(self): def set_exception(self, exc): self._exception = exc - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_exception(exc) @@ -104,15 +196,15 @@ def set_transport(self, transport): self._transport = transport def _maybe_resume_transport(self): - if self._paused and self.byte_count <= self.limit: + if self._paused and self._byte_count <= self._limit: self._paused = False self._transport.resume_reading() def feed_eof(self): - self.eof = True - waiter = self.waiter + self._eof = True + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(True) @@ -120,18 +212,18 @@ def feed_data(self, data): if not data: return - self.buffer.append(data) - self.byte_count += len(data) + self._buffer.append(data) + self._byte_count += len(data) - waiter = self.waiter + waiter = self._waiter if waiter is not None: - self.waiter = None + self._waiter = None if not waiter.cancelled(): waiter.set_result(False) if (self._transport is not None and not self._paused and - self.byte_count > 2*self.limit): + self._byte_count > 2*self._limit): try: self._transport.pause_reading() except NotImplementedError: @@ -152,8 +244,8 @@ def readline(self): not_enough = True while not_enough: - while self.buffer and not_enough: - data = self.buffer.popleft() + while self._buffer and not_enough: + data = self._buffer.popleft() ichar = data.find(b'\n') if ichar < 0: parts.append(data) @@ -162,29 +254,29 @@ def readline(self): ichar += 1 head, tail = data[:ichar], data[ichar:] if tail: - self.buffer.appendleft(tail) + self._buffer.appendleft(tail) not_enough = False parts.append(head) parts_size += len(head) - if parts_size > self.limit: - self.byte_count -= parts_size + if parts_size > self._limit: + self._byte_count -= parts_size self._maybe_resume_transport() raise ValueError('Line is too long') - if self.eof: + if self._eof: break if not_enough: - assert self.waiter is None - self.waiter = futures.Future(loop=self.loop) + assert self._waiter is None + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None line = b''.join(parts) - self.byte_count -= parts_size + self._byte_count -= parts_size self._maybe_resume_transport() return line @@ -198,42 +290,42 @@ def read(self, n=-1): return b'' if n < 0: - while not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None else: - if not self.byte_count and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + if not self._byte_count and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None - if n < 0 or self.byte_count <= n: - data = b''.join(self.buffer) - self.buffer.clear() - self.byte_count = 0 + if n < 0 or self._byte_count <= n: + data = b''.join(self._buffer) + self._buffer.clear() + self._byte_count = 0 self._maybe_resume_transport() return data parts = [] parts_bytes = 0 - while self.buffer and parts_bytes < n: - data = self.buffer.popleft() + while self._buffer and parts_bytes < n: + data = self._buffer.popleft() data_bytes = len(data) if n < parts_bytes + data_bytes: data_bytes = n - parts_bytes data, rest = data[:data_bytes], data[data_bytes:] - self.buffer.appendleft(rest) + self._buffer.appendleft(rest) parts.append(data) parts_bytes += data_bytes - self.byte_count -= data_bytes + self._byte_count -= data_bytes self._maybe_resume_transport() return b''.join(parts) @@ -246,12 +338,12 @@ def readexactly(self, n): if n <= 0: return b'' - while self.byte_count < n and not self.eof: - assert not self.waiter - self.waiter = futures.Future(loop=self.loop) + while self._byte_count < n and not self._eof: + assert not self._waiter + self._waiter = futures.Future(loop=self._loop) try: - yield from self.waiter + yield from self._waiter finally: - self.waiter = None + self._waiter = None return (yield from self.read(n)) diff --git a/tulip/transports.py b/tulip/transports.py index f1a71800..8c6b1896 100644 --- a/tulip/transports.py +++ b/tulip/transports.py @@ -49,6 +49,31 @@ def resume_reading(self): class WriteTransport(BaseTransport): """ABC for write-only transports.""" + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + raise NotImplementedError + def write(self, data): """Write some data bytes to the transport. From 28d6a26d0972a05b7a7115244a2eeb0eb3990cec Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 15:07:14 -0700 Subject: [PATCH 0724/1502] Tentative fix for Windows ssl breakage on hostname mismatch. --- tulip/selector_events.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tulip/selector_events.py b/tulip/selector_events.py index adf8d382..63164f05 100644 --- a/tulip/selector_events.py +++ b/tulip/selector_events.py @@ -584,16 +584,23 @@ def _on_handshake(self): self._loop.add_writer(self._sock_fd, self._on_handshake) return except Exception as exc: + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) return except BaseException as exc: + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) raise + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + # Verify hostname if requested. peercert = self._sock.getpeercert() if (self._server_hostname is not None and @@ -612,8 +619,6 @@ def _on_handshake(self): compression=self._sock.compression(), ) - self._loop.remove_reader(self._sock_fd) - self._loop.remove_writer(self._sock_fd) self._loop.add_reader(self._sock_fd, self._on_ready) self._loop.add_writer(self._sock_fd, self._on_ready) self._loop.call_soon(self._protocol.connection_made, self) From 021e2e0b10306a1162f29fd424ac4b8cc4f82835 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 15:11:22 -0700 Subject: [PATCH 0725/1502] Relax some test timeouts (http://bugs.python.org/issue19285). --- tests/base_events_test.py | 2 +- tests/events_test.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/base_events_test.py b/tests/base_events_test.py index 38fe07aa..60fb007d 100644 --- a/tests/base_events_test.py +++ b/tests/base_events_test.py @@ -121,7 +121,7 @@ def cb(): t0 = self.loop.time() self.loop.run_forever() t1 = self.loop.time() - self.assertTrue(0.09 <= t1-t0 <= 0.12, t1-t0) + self.assertTrue(0.09 <= t1-t0 <= 0.9, t1-t0) def test_run_once_in_executor_handle(self): def cb(): diff --git a/tests/events_test.py b/tests/events_test.py index f5f51e5b..5abd0c4d 100644 --- a/tests/events_test.py +++ b/tests/events_test.py @@ -215,7 +215,7 @@ def test_run_until_complete(self): t0 = self.loop.time() self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) t1 = self.loop.time() - self.assertTrue(0.08 <= t1-t0 <= 0.12, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) def test_run_until_complete_stopped(self): @tasks.coroutine @@ -238,7 +238,7 @@ def callback(arg): self.loop.run_forever() t1 = time.monotonic() self.assertEqual(results, ['hello world']) - self.assertTrue(0.08 <= t1-t0 <= 0.2, t1-t0) + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) def test_call_soon(self): results = [] @@ -462,8 +462,8 @@ def my_handler(*args): self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) - signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. - self.loop.call_later(0.015, self.loop.stop) + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.loop.call_later(0.5, self.loop.stop) self.loop.run_forever() self.assertEqual(caught, 1) From 2ecf8a3ad4b0e28cc867938c92fc41184fea6df7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Oct 2013 17:08:15 -0700 Subject: [PATCH 0726/1502] Offer to copy selectors.py too. --- update_stdlib.sh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/update_stdlib.sh b/update_stdlib.sh index 70e28f35..ceaa8f54 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -41,9 +41,10 @@ for i in `(cd asyncio && ls *.py)` do if [ $i == selectors.py ] then - continue + maybe_copy asyncio/$i Lib/$i + else + maybe_copy asyncio/$i Lib/asyncio/$i fi - maybe_copy asyncio/$i Lib/asyncio/$i done for i in `(cd tests && ls *.py sample.???)` From 8fc5b6d9c8da2d21c133373183eef77b4a860f08 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:08:37 -0700 Subject: [PATCH 0727/1502] Temporarily skip some subprocess tests that fail on AIX. --- tests/test_events.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 4921e7fa..7254e2dc 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -983,6 +983,9 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") + # Issue #19293 + @unittest.skipIf(sys.platform.startswith("aix"), + 'cannot be interrupted with signal on AIX') def test_subprocess_interactive(self): proto = None transp = None @@ -1081,6 +1084,9 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") + # Issue #19293 + @unittest.skipIf(sys.platform.startswith("aix"), + 'cannot be interrupted with signal on AIX') def test_subprocess_kill(self): proto = None transp = None @@ -1104,6 +1110,9 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") + # Issue #19293 + @unittest.skipIf(sys.platform.startswith("aix"), + 'cannot be interrupted with signal on AIX') def test_subprocess_send_signal(self): proto = None transp = None From 114bcc32e5389cce1e3fbaf363ac7308e62f7dea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:10:17 -0700 Subject: [PATCH 0728/1502] Verify hostname if verify_mode is CERT_OPTIONAL too. --- asyncio/selector_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 63164f05..dee23064 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -604,7 +604,7 @@ def _on_handshake(self): # Verify hostname if requested. peercert = self._sock.getpeercert() if (self._server_hostname is not None and - self._sslcontext.verify_mode == ssl.CERT_REQUIRED): + self._sslcontext.verify_mode != ssl.CERT_NONE): try: ssl.match_hostname(peercert, self._server_hostname) except Exception as exc: From f72a4c6abeb092841cb59608619d8beacab29137 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:12:26 -0700 Subject: [PATCH 0729/1502] CPython issue #19299: fix refleak test failures in test_asyncio. --- asyncio/base_events.py | 8 ++++++++ asyncio/proactor_events.py | 1 + asyncio/selector_events.py | 1 + 3 files changed, 10 insertions(+) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5f1bff71..2e007137 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -185,6 +185,14 @@ def stop(self): """ self.call_soon(_raise_stop_error) + def close(self): + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) + def is_running(self): """Returns running status of event loop.""" return self._running diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 665569f0..cb8625d9 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -276,6 +276,7 @@ def close(self): self._proactor.close() self._proactor = None self._selector = None + super().close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index dee23064..6cffdd4e 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -56,6 +56,7 @@ def close(self): self._close_self_pipe() self._selector.close() self._selector = None + super().close() def _socketpair(self): raise NotImplementedError From 20aef323ce57ecb4cbb331498498f3c028c8c966 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:14:00 -0700 Subject: [PATCH 0730/1502] Skip test_asyncio dual stack test when IPv6 not supported. --- tests/test_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 7254e2dc..098cf71f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -17,7 +17,7 @@ import errno import unittest import unittest.mock -from test.support import find_unused_port +from test.support import find_unused_port, IPV6_ENABLED from asyncio import futures @@ -684,7 +684,7 @@ def test_create_server_addr_in_use(self): server.close() - @unittest.skipUnless(socket.has_ipv6, 'IPv6 not supported') + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_server_dual_stack(self): f_proto = futures.Future(loop=self.loop) From 719de27920ee26fa18bf9773014a8d5806aac5d9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:16:12 -0700 Subject: [PATCH 0731/1502] CPython issue #19305: fix sporadic test_asyncio failure on FreeBSD 10.0. --- asyncio/test_utils.py | 15 +++++++++++++++ tests/test_events.py | 7 +++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 91bbedba..d650c447 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -7,6 +7,7 @@ import os import sys import threading +import time import unittest import unittest.mock from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer @@ -46,6 +47,20 @@ def once(): gen.close() +def run_until(loop, pred, timeout=None): + if timeout is not None: + deadline = time.time() + timeout + while not pred(): + if timeout is not None: + timeout = deadline - time.time() + if timeout <= 0: + return False + loop.run_until_complete(tasks.sleep(timeout, loop=loop)) + else: + run_briefly(loop) + return True + + def run_once(loop): """loop.stop() schedules _raise_stop_error() and run_forever() runs until _raise_stop_error() callback. diff --git a/tests/test_events.py b/tests/test_events.py index 098cf71f..f0f4810f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -558,13 +558,14 @@ def factory(): self.assertEqual(host, '0.0.0.0') client = socket.socket() client.connect(('127.0.0.1', port)) - client.send(b'xxx') + client.sendall(b'xxx') test_utils.run_briefly(self.loop) self.assertIsInstance(proto, MyProto) self.assertEqual('INITIAL', proto.state) test_utils.run_briefly(self.loop) self.assertEqual('CONNECTED', proto.state) - test_utils.run_briefly(self.loop) # windows iocp + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) self.assertEqual(3, proto.nbytes) # extra info is available @@ -623,6 +624,8 @@ def factory(): self.assertIsInstance(proto, MyProto) test_utils.run_briefly(self.loop) self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) self.assertEqual(3, proto.nbytes) # extra info is available From 051438233dd3782ba0cd090988e9ff8fa0048205 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:18:21 -0700 Subject: [PATCH 0732/1502] Ignore error from socket() if getaddrinfo() returned an unusable protocol/family combo. --- asyncio/base_events.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 2e007137..37d50aa2 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -464,7 +464,11 @@ def create_server(self, protocol_factory, host=None, port=None, try: for res in infos: af, socktype, proto, canonname, sa = res - sock = socket.socket(af, socktype, proto) + try: + sock = socket.socket(af, socktype, proto) + except socket.error: + # Assume it's a bad family/type/protocol combination. + continue sockets.append(sock) if reuse_address: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, From e6946b932305e8150f2485609d48dfac0ca758bb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:19:22 -0700 Subject: [PATCH 0733/1502] Break out of loop on EOF in echo test programs. --- tests/echo.py | 2 ++ tests/echo3.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tests/echo.py b/tests/echo.py index f6ac0a30..006364bb 100644 --- a/tests/echo.py +++ b/tests/echo.py @@ -3,4 +3,6 @@ if __name__ == '__main__': while True: buf = os.read(0, 1024) + if not buf: + break os.write(1, buf) diff --git a/tests/echo3.py b/tests/echo3.py index f1f7ea7c..06449673 100644 --- a/tests/echo3.py +++ b/tests/echo3.py @@ -3,6 +3,8 @@ if __name__ == '__main__': while True: buf = os.read(0, 1024) + if not buf: + break try: os.write(1, b'OUT:'+buf) except OSError as ex: From c87eda0aa0478925542fa30ff573d2f1b46abbaa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:21:11 -0700 Subject: [PATCH 0734/1502] CPython issue #19309: make waitpid() wait for processes from all groups. --- asyncio/unix_events.py | 2 +- tests/test_events.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index a234f4fa..7623f789 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -168,7 +168,7 @@ def _reg_sigchld(self): def _sig_chld(self): try: try: - pid, status = os.waitpid(0, os.WNOHANG) + pid, status = os.waitpid(-1, os.WNOHANG) except ChildProcessError: return if pid == 0: diff --git a/tests/test_events.py b/tests/test_events.py index f0f4810f..10ddabb8 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1233,6 +1233,26 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGTERM, proto.returncode) + @unittest.skipIf(sys.platform == 'win32', + "Don't support subprocess for Windows yet") + def test_subprocess_wait_no_same_group(self): + proto = None + transp = None + + @tasks.coroutine + def connect(): + nonlocal proto + # start the new process in a new session + transp, proto = yield from self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + if sys.platform == 'win32': from asyncio import windows_events From 8ac07748b34be2aa8a4a9806637933a1206751d3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:22:16 -0700 Subject: [PATCH 0735/1502] Make various asyncio test files individually runnable. --- tests/test_base_events.py | 4 ++++ tests/test_proactor_events.py | 4 ++++ tests/test_selector_events.py | 4 ++++ tests/test_selectors.py | 4 ++++ tests/test_transports.py | 4 ++++ tests/test_unix_events.py | 4 ++++ tests/test_windows_events.py | 4 ++++ tests/test_windows_utils.py | 4 ++++ 8 files changed, 32 insertions(+) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index db15244e..fd48fdd5 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -588,3 +588,7 @@ def test_accept_connection_exception(self, m_log): self.loop._accept_connection(MyProto, sock) self.assertTrue(sock.close.called) self.assertTrue(m_log.exception.called) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 05d1606c..5a2a51c4 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -478,3 +478,7 @@ def test_stop_serving(self): self.loop._stop_serving(sock) self.assertTrue(sock.close.called) self.proactor._stop_serving.assert_called_with(sock) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 53728b8d..fbd5d723 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1483,3 +1483,7 @@ def test_fatal_error_connected(self, m_exc): transport._fatal_error(err) self.protocol.connection_refused.assert_called_with(err) m_exc.assert_called_with('Fatal error for %s', transport) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 2f7dc69d..db5b3ece 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -143,3 +143,7 @@ def test_key_from_fd(self): if hasattr(selectors.DefaultSelector, 'fileno'): def test_fileno(self): self.assertIsInstance(selectors.DefaultSelector().fileno(), int) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_transports.py b/tests/test_transports.py index 53071afd..f96445c1 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -53,3 +53,7 @@ def test_subprocess_transport_not_implemented(self): self.assertRaises(NotImplementedError, transport.send_signal, 1) self.assertRaises(NotImplementedError, transport.terminate) self.assertRaises(NotImplementedError, transport.kill) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index ccabeea1..834df811 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -768,3 +768,7 @@ def test_write_eof_pending(self): tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.protocol.connection_lost.called) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 4b04073e..969360c1 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -93,3 +93,7 @@ def _test_pipe(self): protocols.Protocol, ADDRESS) return 'done' + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 3b6b0368..f721d318 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -138,3 +138,7 @@ def test_popen(self): # allow for partial reads... self.assertTrue(msg.upper().rstrip().startswith(out)) self.assertTrue(b"stderr".startswith(err)) + + +if __name__ == '__main__': + unittest.main() From dbc8bcab99cb28991ea0bc97a187f5c50bc8e747 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:23:06 -0700 Subject: [PATCH 0736/1502] Skip test_create_datagram_endpoint_no_matching_family if IPv6 unsupported. --- tests/test_base_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index fd48fdd5..9f36896f 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -5,6 +5,7 @@ import time import unittest import unittest.mock +from test.support import find_unused_port, IPV6_ENABLED from asyncio import base_events from asyncio import events @@ -533,6 +534,7 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): self.assertRaises( OSError, self.loop.run_until_complete, coro) + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( protocols.DatagramProtocol, From f43bf3cd26800b290ec8c4df46722fa59c94487b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:24:09 -0700 Subject: [PATCH 0737/1502] CPython issue #19310: fix child processes reaping logic (CF Natali). --- asyncio/unix_events.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 7623f789..b4e26992 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -167,23 +167,25 @@ def _reg_sigchld(self): def _sig_chld(self): try: - try: - pid, status = os.waitpid(-1, os.WNOHANG) - except ChildProcessError: - return - if pid == 0: - self.call_soon(self._sig_chld) - return - elif os.WIFSIGNALED(status): - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - returncode = os.WEXITSTATUS(status) - else: - self.call_soon(self._sig_chld) - return - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) + # because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + break + if pid == 0: + break + elif os.WIFSIGNALED(status): + returncode = -os.WTERMSIG(status) + elif os.WIFEXITED(status): + returncode = os.WEXITSTATUS(status) + else: + # shouldn't happen + continue + transp = self._subprocesses.get(pid) + if transp is not None: + transp._process_exited(returncode) except Exception: logger.exception('Unknown exception in SIGCHLD handler') From b1416df10d461639ea02731b321b6aa41e448b99 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 20 Oct 2013 20:25:04 -0700 Subject: [PATCH 0738/1502] CPython issue #19297: fix resource warnings. Patch by Vajrasky Kok. --- asyncio/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index d650c447..c278dd17 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -126,6 +126,7 @@ def app(environ, start_response): yield httpd finally: httpd.shutdown() + httpd.server_close() server_thread.join() From fe0b17a16c33d7338f200f83f9d77a3b8d50f377 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Oct 2013 14:49:02 -0700 Subject: [PATCH 0739/1502] If waitpid() returns a weird status, the process is still dead. Also tidy up a few comment and replace functools.partial with lambda. --- asyncio/unix_events.py | 26 +++++++++++++------------- tests/test_unix_events.py | 2 +- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index b4e26992..8c0e09ad 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -3,7 +3,6 @@ import collections import errno import fcntl -import functools import os import signal import socket @@ -167,22 +166,26 @@ def _reg_sigchld(self): def _sig_chld(self): try: - # because of signal coalescing, we must keep calling waitpid() as - # long as we're able to reap a child + # Because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child. while True: try: pid, status = os.waitpid(-1, os.WNOHANG) except ChildProcessError: - break + break # No more child processes exist. if pid == 0: - break + break # All remaining child processes are still alive. elif os.WIFSIGNALED(status): + # A child process died because of a signal. returncode = -os.WTERMSIG(status) elif os.WIFEXITED(status): + # A child process exited (e.g. sys.exit()). returncode = os.WEXITSTATUS(status) else: - # shouldn't happen - continue + # A child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + returncode = status transp = self._subprocesses.get(pid) if transp is not None: transp._process_exited(returncode) @@ -480,18 +483,15 @@ def _post_init(self): loop = self._loop if proc.stdin is not None: transp, proto = yield from loop.connect_write_pipe( - functools.partial( - _UnixWriteSubprocessPipeProto, self, STDIN), + lambda: _UnixWriteSubprocessPipeProto(self, STDIN), proc.stdin) if proc.stdout is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial( - _UnixReadSubprocessPipeProto, self, STDOUT), + lambda: _UnixReadSubprocessPipeProto(self, STDOUT), proc.stdout) if proc.stderr is not None: transp, proto = yield from loop.connect_read_pipe( - functools.partial( - _UnixReadSubprocessPipeProto, self, STDERR), + lambda: _UnixReadSubprocessPipeProto(self, STDERR), proc.stderr) if not self._pipes: self._try_connected() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 834df811..27e70c6d 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -266,7 +266,7 @@ def test__sig_chld_unknown_status(self, m_waitpid, self.loop._subprocesses[7] = transp self.loop._sig_chld() - self.assertFalse(transp._process_exited.called) + self.assertTrue(transp._process_exited.called) self.assertFalse(m_WEXITSTATUS.called) self.assertFalse(m_WTERMSIG.called) From f6f3d77399fdc0e9276d3e4cec7d636de3cc122c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Oct 2013 20:33:31 -0700 Subject: [PATCH 0740/1502] Skip test_subprocess_close_client_stream on AIX (to avoid hang). --- tests/test_events.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 10ddabb8..1a4fe9cf 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1200,6 +1200,9 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") + # Issue #19293 + @unittest.skipIf(sys.platform.startswith("aix"), + 'cannot be interrupted with signal on AIX') def test_subprocess_close_client_stream(self): proto = None transp = None From eaf48c0ceae2aae5eda8c965c732d2234f9f8e69 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Oct 2013 20:36:17 -0700 Subject: [PATCH 0741/1502] Switch subprocess stdin to a socketpair, attempting to fix CPython issue #19293 (AIX hang). --- asyncio/unix_events.py | 29 +++++++++++++++++++++++++---- tests/test_unix_events.py | 7 +++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 8c0e09ad..3807680f 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -213,6 +213,9 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + if not (stat.S_ISFIFO(mode) or stat.S_ISSOCK(mode)): + raise ValueError("Pipe transport is for pipes/sockets only.") _set_nonblocking(self._fileno) self._protocol = protocol self._closing = False @@ -275,21 +278,29 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() - if not stat.S_ISFIFO(os.fstat(self._fileno).st_mode): - raise ValueError("Pipe transport is for pipes only.") + mode = os.fstat(self._fileno).st_mode + is_socket = stat.S_ISSOCK(mode) + is_pipe = stat.S_ISFIFO(mode) + if not (is_socket or is_pipe): + raise ValueError("Pipe transport is for pipes/sockets only.") _set_nonblocking(self._fileno) self._protocol = protocol self._buffer = [] self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. - self._loop.add_reader(self._fileno, self._read_ready) + + # On AIX, the reader trick only works for sockets. + # On other platforms it works for pipes and sockets. + # (Exception: OS X 10.4? Issue #19294.) + if is_socket or not sys.platform.startswith("aix"): + self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: self._loop.call_soon(waiter.set_result, None) def _read_ready(self): - # pipe was closed by peer + # Pipe was closed by peer. self._close() def write(self, data): @@ -435,8 +446,15 @@ def __init__(self, loop, protocol, args, shell, self._loop = loop self._pipes = {} + stdin_w = None if stdin == subprocess.PIPE: self._pipes[STDIN] = None + # Use a socket pair for stdin, since not all platforms + # support selecting read events on the write end of a + # socket (which we use in order to detect closing of the + # other end). Notably this is needed on AIX, and works + # just fine on other platforms. + stdin, stdin_w = self._loop._socketpair() if stdout == subprocess.PIPE: self._pipes[STDOUT] = None if stderr == subprocess.PIPE: @@ -448,6 +466,9 @@ def __init__(self, loop, protocol, args, shell, self._proc = subprocess.Popen( args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize) self._extra['subprocess'] = self._proc def close(self): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 27e70c6d..f29e7afe 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -312,6 +312,13 @@ def setUp(self): fcntl_patcher.start() self.addCleanup(fcntl_patcher.stop) + fstat_patcher = unittest.mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = unittest.mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + def test_ctor(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) From 74a367d5a4ae412d55ef261e95379ae52e1871d0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Oct 2013 21:01:59 -0700 Subject: [PATCH 0742/1502] Unsilence several tests that no longer hang on AIX, and silence a new AIX hang. --- tests/test_events.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 1a4fe9cf..3924a2f9 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -887,6 +887,9 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") + # Issue #19293 + @unittest.skipIf(sys.platform.startswith("aix"), + 'cannot be interrupted with signal on AIX') def test_write_pipe_disconnect_on_close(self): proto = None transport = None @@ -986,9 +989,6 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") - # Issue #19293 - @unittest.skipIf(sys.platform.startswith("aix"), - 'cannot be interrupted with signal on AIX') def test_subprocess_interactive(self): proto = None transp = None @@ -1087,9 +1087,6 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") - # Issue #19293 - @unittest.skipIf(sys.platform.startswith("aix"), - 'cannot be interrupted with signal on AIX') def test_subprocess_kill(self): proto = None transp = None @@ -1113,9 +1110,6 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") - # Issue #19293 - @unittest.skipIf(sys.platform.startswith("aix"), - 'cannot be interrupted with signal on AIX') def test_subprocess_send_signal(self): proto = None transp = None @@ -1200,9 +1194,6 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't support subprocess for Windows yet") - # Issue #19293 - @unittest.skipIf(sys.platform.startswith("aix"), - 'cannot be interrupted with signal on AIX') def test_subprocess_close_client_stream(self): proto = None transp = None From 6ea686a2f529f67013cef08501258070d0e2fe82 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 21 Oct 2013 21:28:58 -0700 Subject: [PATCH 0743/1502] Fix CPython issue #19293 (hangs on AIX). --- tests/test_events.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 3924a2f9..98896e81 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -887,9 +887,6 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") - # Issue #19293 - @unittest.skipIf(sys.platform.startswith("aix"), - 'cannot be interrupted with signal on AIX') def test_write_pipe_disconnect_on_close(self): proto = None transport = None @@ -899,8 +896,8 @@ def factory(): proto = MyWritePipeProto(loop=self.loop) return proto - rpipe, wpipe = os.pipe() - pipeobj = io.open(wpipe, 'wb', 1024) + rsock, wsock = self.loop._socketpair() + pipeobj = io.open(wsock.detach(), 'wb', 1024) @tasks.coroutine def connect(): @@ -916,11 +913,10 @@ def connect(): self.assertEqual('CONNECTED', proto.state) transport.write(b'1') - test_utils.run_briefly(self.loop) - data = os.read(rpipe, 1024) + data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024)) self.assertEqual(b'1', data) - os.close(rpipe) + rsock.close() self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) From 13d7f672626cb13bf9ec2ca3a4fb63d60a3bfaf6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 22 Oct 2013 10:33:07 -0700 Subject: [PATCH 0744/1502] Changes by Sonald Stufft to build pypi distros. Yay! --- .hgignore | 2 ++ MANIFEST.in | 7 +++++++ Makefile | 7 +++++++ pypi.bat | 1 + setup.cfg | 2 -- setup.py | 30 ++++++++++++++++++++++-------- 6 files changed, 39 insertions(+), 10 deletions(-) create mode 100644 MANIFEST.in create mode 100644 pypi.bat delete mode 100644 setup.cfg diff --git a/.hgignore b/.hgignore index 99870025..6d1136f2 100644 --- a/.hgignore +++ b/.hgignore @@ -10,3 +10,5 @@ venv$ distribute_setup.py$ distribute-\d+.\d+.\d+.tar.gz$ build$ +dist$ +.*\.egg-info$ diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..317dcc3d --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,7 @@ +include Makefile +include *.c *.py + +recursive-include examples *.py +recursive-include tests *.crt +recursive-include tests *.key +recursive-include tests *.py diff --git a/Makefile b/Makefile index ed3caf21..74c38cae 100644 --- a/Makefile +++ b/Makefile @@ -30,5 +30,12 @@ clean: rm -f `find . -type f -name '#*#' ` rm -f `find . -type f -name '*.orig' ` rm -f `find . -type f -name '*.rej' ` + rm -rf dist rm -f .coverage rm -rf htmlcov + rm -f MANIFEST + + +# Make distributions for Python 3.3 +pypi: clean + python3.3 setup.py sdist upload diff --git a/pypi.bat b/pypi.bat new file mode 100644 index 00000000..5218ace3 --- /dev/null +++ b/pypi.bat @@ -0,0 +1 @@ +c:\Python33\python.exe setup.py bdist_wheel upload diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 172844ce..00000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[build_ext] -build_lib=asyncio diff --git a/setup.py b/setup.py index fad16e7a..011db099 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,28 @@ import os -from distutils.core import setup, Extension +from setuptools import setup, Extension extensions = [] if os.name == 'nt': - ext = Extension('_overlapped', ['overlapped.c'], libraries=['ws2_32']) + ext = Extension( + 'asyncio._overlapped', ['overlapped.c'], libraries=['ws2_32'], + ) extensions.append(ext) -setup(name='asyncio', - description="reference implementation of PEP 3156", - url='http://www.python.org/dev/peps/pep-3156/', - packages=['asyncio'], - ext_modules=extensions - ) +setup( + name="asyncio", + version="0.1.1", + + description="reference implementation of PEP 3156", + long_description=open("README").read(), + url="http://www.python.org/dev/peps/pep-3156/", + + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + ], + + packages=["asyncio"], + + ext_modules=extensions, +) From fc1e2589604fcbbee297ba6b00ab21eb2dcda918 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 22 Oct 2013 10:38:45 -0700 Subject: [PATCH 0745/1502] Added tag 0.1.1 for changeset 27f5ccf1ba62 From ad2b700392f69cb661dc6573ae3110db95cee2cf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 23 Oct 2013 08:51:37 -0700 Subject: [PATCH 0746/1502] Make it not a fatal error when accept() raises EMFILE. Fixes issue #78. --- asyncio/selector_events.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 6cffdd4e..60bd3c80 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -5,6 +5,7 @@ """ import collections +import errno import socket try: import ssl @@ -100,7 +101,11 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, conn.setblocking(False) except (BlockingIOError, InterruptedError): pass # False alarm. - except Exception: + except Exception as exc: + if isinstance(exc, OSError) and exc.errno == errno.EMFILE: + # Too many filedescriptors. Don't die. + logger.error('Out of FDs accepting connections') + return # Bad error. Stop serving. self.remove_reader(sock.fileno()) sock.close() From e98ad367b770f86159b3f4483e58a3ca3791613c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Oct 2013 16:32:09 -0700 Subject: [PATCH 0747/1502] When not closing the connection after receiving EOF, still remove the read handler. --- asyncio/selector_events.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 60bd3c80..0defe811 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -473,7 +473,12 @@ def _read_ready(self): self._protocol.data_received(data) else: keep_open = self._protocol.eof_received() - if not keep_open: + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. + self._loop.remove_reader(self._sock_fd) + else: self.close() def write(self, data): From 9fb9a1c2770cd4599159f6827a447f801d5ca7bb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Oct 2013 16:41:48 -0700 Subject: [PATCH 0748/1502] Update some comments. --- asyncio/events.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 6ca5668c..62f8e949 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -42,7 +42,7 @@ def _run(self): def make_handle(callback, args): - # TODO: Inline this? + # TODO: Inline this? Or make it a private EventLoop method? assert not isinstance(callback, Handle), 'A Handle is not a callback' return Handle(callback, args) @@ -338,7 +338,6 @@ def get_event_loop(self): def set_event_loop(self, loop): """Set the event loop.""" - # TODO: The isinstance() test violates the PEP. self._set_called = True assert loop is None or isinstance(loop, AbstractEventLoop) self._loop = loop @@ -375,7 +374,6 @@ def get_event_loop_policy(): def set_event_loop_policy(policy): """XXX""" global _event_loop_policy - # TODO: The isinstance() test violates the PEP. assert policy is None or isinstance(policy, AbstractEventLoopPolicy) _event_loop_policy = policy From 24151267b44818c8b0d57c8aec59e54d0b2ff6f0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 24 Oct 2013 16:44:48 -0700 Subject: [PATCH 0749/1502] Just log any exception coming out of accept(). Fixes issue #78. --- asyncio/selector_events.py | 7 ------- tests/test_base_events.py | 2 +- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 0defe811..7ec07491 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -102,13 +102,6 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, except (BlockingIOError, InterruptedError): pass # False alarm. except Exception as exc: - if isinstance(exc, OSError) and exc.errno == errno.EMFILE: - # Too many filedescriptors. Don't die. - logger.error('Out of FDs accepting connections') - return - # Bad error. Stop serving. - self.remove_reader(sock.fileno()) - sock.close() # There's nowhere to send the error, so just log it. # TODO: Someone will want an error handler for this. logger.exception('Accept failed') diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9f36896f..09a5fcb2 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -588,8 +588,8 @@ def test_accept_connection_exception(self, m_log): sock.accept.side_effect = OSError() self.loop._accept_connection(MyProto, sock) - self.assertTrue(sock.close.called) self.assertTrue(m_log.exception.called) + self.assertFalse(sock.close.called) if __name__ == '__main__': From 53f842abc067bd3b241af6841b5ca1b568142df0 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Fri, 25 Oct 2013 19:06:41 +0100 Subject: [PATCH 0750/1502] Make the IOCP proactor support "waitable" handles. --- asyncio/windows_events.py | 40 ++++++++ overlapped.c | 176 +++++++++++++++++++++++++++++++++++ tests/test_windows_events.py | 39 ++++++++ 3 files changed, 255 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index bbeada87..1ffac999 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -46,6 +46,22 @@ def cancel(self): return super().cancel() +class _WaitHandleFuture(futures.Future): + """Subclass of Future which represents a wait handle.""" + + def __init__(self, wait_handle, *, loop=None): + super().__init__(loop=loop) + self._wait_handle = wait_handle + + def cancel(self): + super().cancel() + try: + _overlapped.UnregisterWait(self._wait_handle) + except OSError as e: + if e.winerror != _overlapped.ERROR_IO_PENDING: + raise + + class PipeServer(object): """Class representing a pipe server. @@ -271,6 +287,30 @@ def finish(err, handle, ov): return windows_utils.PipeHandle(handle) return self._register(ov, None, finish, wait_for_post=True) + def wait_for_handle(self, handle, timeout=None): + if timeout is None: + ms = _winapi.INFINITE + else: + ms = int(timeout * 1000 + 0.5) + + # We only create ov so we can use ov.address as a key for the cache. + ov = _overlapped.Overlapped(NULL) + wh = _overlapped.RegisterWaitWithQueue( + handle, self._iocp, ov.address, ms) + f = _WaitHandleFuture(wh, loop=self._loop) + + def finish(timed_out, _, ov): + if not f.cancelled(): + try: + _overlapped.UnregisterWait(wh) + except OSError as e: + if e.winerror != _overlapped.ERROR_IO_PENDING: + raise + return not timed_out + + self._cache[ov.address] = (f, ov, None, finish) + return f + def _register_with_iocp(self, obj): # To get notifications of finished ops on this objects sent to the # completion port, were must register the handle. diff --git a/overlapped.c b/overlapped.c index 6a1d9e4a..625c76ef 100644 --- a/overlapped.c +++ b/overlapped.c @@ -227,6 +227,172 @@ overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) Py_RETURN_NONE; } +/* + * Wait for a handle + */ + +struct PostCallbackData { + HANDLE CompletionPort; + LPOVERLAPPED Overlapped; +}; + +static VOID CALLBACK +PostToQueueCallback(PVOID lpParameter, BOOL TimerOrWaitFired) +{ + struct PostCallbackData *p = (struct PostCallbackData*) lpParameter; + + PostQueuedCompletionStatus(p->CompletionPort, TimerOrWaitFired, + 0, p->Overlapped); + /* ignore possible error! */ + PyMem_Free(p); +} + +PyDoc_STRVAR( + RegisterWaitWithQueue_doc, + "RegisterWaitWithQueue(Object, CompletionPort, Overlapped, Timeout)\n" + " -> WaitHandle\n\n" + "Register wait for Object; when complete CompletionPort is notified.\n"); + +static PyObject * +overlapped_RegisterWaitWithQueue(PyObject *self, PyObject *args) +{ + HANDLE NewWaitObject; + HANDLE Object; + ULONG Milliseconds; + struct PostCallbackData data, *pdata; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_POINTER F_DWORD, + &Object, + &data.CompletionPort, + &data.Overlapped, + &Milliseconds)) + return NULL; + + pdata = PyMem_Malloc(sizeof(struct PostCallbackData)); + if (pdata == NULL) + return SetFromWindowsErr(0); + + *pdata = data; + + if (!RegisterWaitForSingleObject( + &NewWaitObject, Object, (WAITORTIMERCALLBACK)PostToQueueCallback, + pdata, Milliseconds, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE)) + { + PyMem_Free(pdata); + return SetFromWindowsErr(0); + } + + return Py_BuildValue(F_HANDLE, NewWaitObject); +} + +PyDoc_STRVAR( + UnregisterWait_doc, + "UnregisterWait(WaitHandle) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWait(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &WaitHandle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWait(WaitHandle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Event functions -- currently only used by tests + */ + +PyDoc_STRVAR( + CreateEvent_doc, + "CreateEvent(EventAttributes, ManualReset, InitialState, Name)" + " -> Handle\n\n" + "Create an event. EventAttributes must be None.\n"); + +static PyObject * +overlapped_CreateEvent(PyObject *self, PyObject *args) +{ + PyObject *EventAttributes; + BOOL ManualReset; + BOOL InitialState; + Py_UNICODE *Name; + HANDLE Event; + + if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "Z", + &EventAttributes, &ManualReset, + &InitialState, &Name)) + return NULL; + + if (EventAttributes != Py_None) { + PyErr_SetString(PyExc_ValueError, "EventAttributes must be None"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + Event = CreateEventW(NULL, ManualReset, InitialState, Name); + Py_END_ALLOW_THREADS + + if (Event == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, Event); +} + +PyDoc_STRVAR( + SetEvent_doc, + "SetEvent(Handle) -> None\n\n" + "Set event.\n"); + +static PyObject * +overlapped_SetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = SetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + ResetEvent_doc, + "ResetEvent(Handle) -> None\n\n" + "Reset event.\n"); + +static PyObject * +overlapped_ResetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = ResetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + /* * Bind socket handle to local port without doing slow getaddrinfo() */ @@ -1147,6 +1313,16 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, FormatMessage_doc}, {"BindLocal", overlapped_BindLocal, METH_VARARGS, BindLocal_doc}, + {"RegisterWaitWithQueue", overlapped_RegisterWaitWithQueue, + METH_VARARGS, RegisterWaitWithQueue_doc}, + {"UnregisterWait", overlapped_UnregisterWait, + METH_VARARGS, UnregisterWait_doc}, + {"CreateEvent", overlapped_CreateEvent, + METH_VARARGS, CreateEvent_doc}, + {"SetEvent", overlapped_SetEvent, + METH_VARARGS, SetEvent_doc}, + {"ResetEvent", overlapped_ResetEvent, + METH_VARARGS, ResetEvent_doc}, {NULL} }; diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 969360c1..17146a36 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,6 +1,7 @@ import os import sys import unittest +import _winapi if sys.platform != 'win32': raise unittest.SkipTest('Windows only') @@ -8,10 +9,12 @@ import asyncio from asyncio import windows_events +from asyncio import futures from asyncio import protocols from asyncio import streams from asyncio import transports from asyncio import test_utils +from asyncio import _overlapped class UpperProto(protocols.Protocol): @@ -94,6 +97,42 @@ def _test_pipe(self): return 'done' + def test_wait_for_handle(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with 0.2s timeout; + # result should be False at timeout + f = self.loop._proactor.wait_for_handle(event, 0.2) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertFalse(f.result()) + self.assertTrue(0.18 < elapsed < 0.22, elapsed) + + _overlapped.SetEvent(event) + + # Wait for for set event; + # result should be True immediately + f = self.loop._proactor.wait_for_handle(event, 10) + start = self.loop.time() + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(f.result()) + self.assertTrue(0 <= elapsed < 0.02, elapsed) + + _overlapped.ResetEvent(event) + + # Wait for unset event with a cancelled future; + # CancelledError should be raised immediately + f = self.loop._proactor.wait_for_handle(event, 10) + f.cancel() + start = self.loop.time() + with self.assertRaises(futures.CancelledError): + self.loop.run_until_complete(f) + elapsed = self.loop.time() - start + self.assertTrue(0 <= elapsed < 0.02, elapsed) + if __name__ == '__main__': unittest.main() From e21c56fa6466ae418c05230e43057940095a144f Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Sun, 27 Oct 2013 15:14:17 +0000 Subject: [PATCH 0751/1502] Add support for running subprocesses on Windows with the IOCP event loop. --- asyncio/base_subprocess.py | 166 +++++++++++++++++++++++++++++++++++++ asyncio/proactor_events.py | 11 ++- asyncio/unix_events.py | 145 +------------------------------- asyncio/windows_events.py | 28 +++++++ asyncio/windows_utils.py | 19 ++++- tests/test_events.py | 105 +++++++++++++++-------- 6 files changed, 291 insertions(+), 183 deletions(-) create mode 100644 asyncio/base_subprocess.py diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py new file mode 100644 index 00000000..d15fb159 --- /dev/null +++ b/asyncio/base_subprocess.py @@ -0,0 +1,166 @@ +import collections +import subprocess + +from . import protocols +from . import tasks +from . import transports + + +STDIN = 0 +STDOUT = 1 +STDERR = 2 + + +class BaseSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[STDIN] = None + if stdout == subprocess.PIPE: + self._pipes[STDOUT] = None + if stderr == subprocess.PIPE: + self._pipes[STDERR] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, + stderr=stderr, bufsize=bufsize, **kwargs) + self._extra['subprocess'] = self._proc + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + raise NotImplementedError + + def _make_write_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def _make_read_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._proc.pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @tasks.coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + transp, proto = yield from loop.connect_write_pipe( + lambda: WriteSubprocessPipeProto(self, STDIN), + proc.stdin) + if proc.stdout is not None: + transp, proto = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, STDOUT), + proc.stdout) + if proc.stderr is not None: + transp, proto = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, STDERR), + proc.stderr) + if not self._pipes: + self._try_connected() + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _try_connected(self): + assert self._pending_calls is not None + if all(p is not None and p.connected for p in self._pipes.values()): + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + self._returncode = returncode + self._loop._subprocess_closed(self) + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None + + +class WriteSubprocessPipeProto(protocols.BaseProtocol): + pipe = None + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.connected = False + self.disconnected = False + proc._pipes[fd] = self + + def connection_made(self, transport): + self.connected = True + self.pipe = transport + self.proc._try_connected() + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + def eof_received(self): + pass + + +class ReadSubprocessPipeProto(WriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index cb8625d9..ce226b9b 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -267,8 +267,15 @@ def _make_read_pipe_transport(self, sock, protocol, waiter=None, return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) def _make_write_pipe_transport(self, sock, protocol, waiter=None, - extra=None): - return _ProactorWritePipeTransport(self, sock, protocol, waiter, extra) + extra=None, check_for_hangup=True): + if check_for_hangup: + # We want connection_lost() to be called when other end closes + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + else: + # If other end closes we may not notice for a long time + return _ProactorWritePipeTransport(self, sock, protocol, waiter, + extra) def close(self): if self._proactor is not None: diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 3807680f..c95ad488 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -1,6 +1,5 @@ """Selector eventloop for Unix with signal handling.""" -import collections import errno import fcntl import os @@ -11,6 +10,7 @@ import sys +from . import base_subprocess from . import constants from . import events from . import protocols @@ -406,159 +406,20 @@ def _call_connection_lost(self, exc): self._loop = None -class _UnixWriteSubprocessPipeProto(protocols.BaseProtocol): - pipe = None +class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): - def __init__(self, proc, fd): - self.proc = proc - self.fd = fd - self.connected = False - self.disconnected = False - proc._pipes[fd] = self - - def connection_made(self, transport): - self.connected = True - self.pipe = transport - self.proc._try_connected() - - def connection_lost(self, exc): - self.disconnected = True - self.proc._pipe_connection_lost(self.fd, exc) - - -class _UnixReadSubprocessPipeProto(_UnixWriteSubprocessPipeProto, - protocols.Protocol): - - def data_received(self, data): - self.proc._pipe_data_received(self.fd, data) - - def eof_received(self): - pass - - -class _UnixSubprocessTransport(transports.SubprocessTransport): - - def __init__(self, loop, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): - super().__init__(extra) - self._protocol = protocol - self._loop = loop - - self._pipes = {} + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): stdin_w = None if stdin == subprocess.PIPE: - self._pipes[STDIN] = None # Use a socket pair for stdin, since not all platforms # support selecting read events on the write end of a # socket (which we use in order to detect closing of the # other end). Notably this is needed on AIX, and works # just fine on other platforms. stdin, stdin_w = self._loop._socketpair() - if stdout == subprocess.PIPE: - self._pipes[STDOUT] = None - if stderr == subprocess.PIPE: - self._pipes[STDERR] = None - self._pending_calls = collections.deque() - self._finished = False - self._returncode = None - self._proc = subprocess.Popen( args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, universal_newlines=False, bufsize=bufsize, **kwargs) if stdin_w is not None: stdin.close() self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize) - self._extra['subprocess'] = self._proc - - def close(self): - for proto in self._pipes.values(): - proto.pipe.close() - if self._returncode is None: - self.terminate() - - def get_pid(self): - return self._proc.pid - - def get_returncode(self): - return self._returncode - - def get_pipe_transport(self, fd): - if fd in self._pipes: - return self._pipes[fd].pipe - else: - return None - - def send_signal(self, signal): - self._proc.send_signal(signal) - - def terminate(self): - self._proc.terminate() - - def kill(self): - self._proc.kill() - - @tasks.coroutine - def _post_init(self): - proc = self._proc - loop = self._loop - if proc.stdin is not None: - transp, proto = yield from loop.connect_write_pipe( - lambda: _UnixWriteSubprocessPipeProto(self, STDIN), - proc.stdin) - if proc.stdout is not None: - transp, proto = yield from loop.connect_read_pipe( - lambda: _UnixReadSubprocessPipeProto(self, STDOUT), - proc.stdout) - if proc.stderr is not None: - transp, proto = yield from loop.connect_read_pipe( - lambda: _UnixReadSubprocessPipeProto(self, STDERR), - proc.stderr) - if not self._pipes: - self._try_connected() - - def _call(self, cb, *data): - if self._pending_calls is not None: - self._pending_calls.append((cb, data)) - else: - self._loop.call_soon(cb, *data) - - def _try_connected(self): - assert self._pending_calls is not None - if all(p is not None and p.connected for p in self._pipes.values()): - self._loop.call_soon(self._protocol.connection_made, self) - for callback, data in self._pending_calls: - self._loop.call_soon(callback, *data) - self._pending_calls = None - - def _pipe_connection_lost(self, fd, exc): - self._call(self._protocol.pipe_connection_lost, fd, exc) - self._try_finish() - - def _pipe_data_received(self, fd, data): - self._call(self._protocol.pipe_data_received, fd, data) - - def _process_exited(self, returncode): - assert returncode is not None, returncode - assert self._returncode is None, self._returncode - self._returncode = returncode - self._loop._subprocess_closed(self) - self._call(self._protocol.process_exited) - self._try_finish() - - def _try_finish(self): - assert not self._finished - if self._returncode is None: - return - if all(p is not None and p.disconnected - for p in self._pipes.values()): - self._finished = True - self._loop.call_soon(self._call_connection_lost, None) - - def _call_connection_lost(self, exc): - try: - self._protocol.connection_lost(exc) - finally: - self._proc = None - self._protocol = None - self._loop = None diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 1ffac999..fc4ae6b9 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -2,10 +2,12 @@ import errno import socket +import subprocess import weakref import struct import _winapi +from . import base_subprocess from . import futures from . import proactor_events from . import selector_events @@ -168,6 +170,19 @@ def loop(f=None): def _stop_serving(self, server): server.close() + @tasks.coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + transp = _WindowsSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + yield from transp._post_init() + return transp + + def _subprocess_closed(self, transport): + pass + class IocpProactor: """Proactor implementation using IOCP.""" @@ -413,3 +428,16 @@ def close(self): if self._iocp is not None: _winapi.CloseHandle(self._iocp) self._iocp = None + + +class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + self._proc = windows_utils.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + bufsize=bufsize, **kwargs) + def callback(f): + returncode = self._proc.poll() + self._process_exited(returncode) + f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) + f.add_done_callback(callback) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index 04b43e9a..2fc3f7a9 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -24,6 +24,7 @@ BUFSIZE = 8192 PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT _mmap_counter = itertools.count() # @@ -146,24 +147,34 @@ class Popen(subprocess.Popen): The stdin, stdout, stderr are None or instances of PipeHandle. """ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + assert not kwds.get('universal_newlines') + assert kwds.get('bufsize', 0) == 0 stdin_rfd = stdout_wfd = stderr_wfd = None stdin_wh = stdout_rh = stderr_rh = None if stdin == PIPE: - stdin_rh, stdin_wh = pipe(overlapped=(False, True)) + stdin_rh, stdin_wh = pipe(overlapped=(False, True), duplex=True) stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + else: + stdin_rfd = stdin if stdout == PIPE: stdout_rh, stdout_wh = pipe(overlapped=(True, False)) stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + else: + stdout_wfd = stdout if stderr == PIPE: stderr_rh, stderr_wh = pipe(overlapped=(True, False)) stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + elif stderr == STDOUT: + stderr_wfd = stdout_wfd + else: + stderr_wfd = stderr try: - super().__init__(args, bufsize=0, universal_newlines=False, - stdin=stdin_rfd, stdout=stdout_wfd, + super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, stderr=stderr_wfd, **kwds) except: for h in (stdin_wh, stdout_rh, stderr_rh): - _winapi.CloseHandle(h) + if h is not None: + _winapi.CloseHandle(h) raise else: if stdin_wh is not None: diff --git a/tests/test_events.py b/tests/test_events.py index 98896e81..fd2af2e1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -955,8 +955,23 @@ def main(): r.close() w.close() - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") + +class SubprocessTestsMixin: + + def check_terminated(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + self.assertNotEqual(0, returncode) + else: + self.assertEqual(-signal.SIGTERM, returncode) + + def check_killed(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + self.assertNotEqual(0, returncode) + else: + self.assertEqual(-signal.SIGKILL, returncode) + def test_subprocess_exec(self): proto = None transp = None @@ -980,11 +995,9 @@ def connect(): self.loop.run_until_complete(proto.got_data[1].wait()) transp.close() self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGTERM, proto.returncode) + self.check_terminated(proto.returncode) self.assertEqual(b'Python The Winner', proto.data[1]) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_interactive(self): proto = None transp = None @@ -1017,10 +1030,8 @@ def connect(): transp.close() self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGTERM, proto.returncode) + self.check_terminated(proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_shell(self): proto = None transp = None @@ -1030,7 +1041,7 @@ def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_shell( functools.partial(MySubprocessProtocol, self.loop), - 'echo "Python"') + 'echo Python') self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(connect()) @@ -1040,10 +1051,9 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(0, proto.returncode) self.assertTrue(all(f.done() for f in proto.disconnects.values())) - self.assertEqual({1: b'Python\n', 2: b''}, proto.data) + self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python') + self.assertEqual(proto.data[2], b'') - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_exitcode(self): proto = None @@ -1059,8 +1069,6 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_close_after_finish(self): proto = None transp = None @@ -1081,8 +1089,6 @@ def connect(): self.assertEqual(7, proto.returncode) self.assertIsNone(transp.close()) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_kill(self): proto = None transp = None @@ -1102,10 +1108,30 @@ def connect(): transp.kill() self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGKILL, proto.returncode) + self.check_killed(proto.returncode) + + def test_subprocess_terminate(self): + proto = None + transp = None + + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + @tasks.coroutine + def connect(): + nonlocal proto, transp + transp, proto = yield from self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + self.assertIsInstance(proto, MySubprocessProtocol) + + self.loop.run_until_complete(connect()) + self.loop.run_until_complete(proto.connected) + + transp.terminate() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_subprocess_send_signal(self): proto = None transp = None @@ -1127,8 +1153,6 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGHUP, proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_stderr(self): proto = None transp = None @@ -1156,8 +1180,6 @@ def connect(): self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) self.assertEqual(0, proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_stderr_redirect_to_stdout(self): proto = None transp = None @@ -1188,8 +1210,6 @@ def connect(): transp.close() self.assertEqual(0, proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_close_client_stream(self): proto = None transp = None @@ -1217,14 +1237,18 @@ def connect(): self.loop.run_until_complete(proto.disconnects[1]) stdin.write(b'xxx') self.loop.run_until_complete(proto.got_data[2].wait()) - self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) - + if sys.platform != 'win32': + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + else: + # After closing the read-end of a pipe, writing to the + # write-end using os.write() fails with errno==EINVAL and + # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using + # WriteFile() we get ERROR_BROKEN_PIPE as expected.) + self.assertEqual(b'ERR:OSError', proto.data[2]) transp.close() self.loop.run_until_complete(proto.completed) - self.assertEqual(-signal.SIGTERM, proto.returncode) + self.check_terminated(proto.returncode) - @unittest.skipIf(sys.platform == 'win32', - "Don't support subprocess for Windows yet") def test_subprocess_wait_no_same_group(self): proto = None transp = None @@ -1252,7 +1276,10 @@ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): return windows_events.SelectorEventLoop() - class ProactorEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + + class ProactorEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): def create_event_loop(self): return windows_events.ProactorEventLoop() @@ -1283,26 +1310,34 @@ def test_create_datagram_endpoint(self): from asyncio import unix_events if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class KqueueEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): def create_event_loop(self): return unix_events.SelectorEventLoop( selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class EPollEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class PollEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.PollSelector()) # Should always exist. - class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + unittest.TestCase): def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.SelectSelector()) From 62df46eaba3d7e18e65e99b1539e9b7095054541 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 29 Oct 2013 10:38:59 -0700 Subject: [PATCH 0752/1502] Tweak import of _overlapped and add instructions README (mostly for myself :-). --- README | 15 +++++++++++++++ asyncio/__init__.py | 12 ++++++++++-- asyncio/windows_events.py | 6 +----- tests/test_windows_utils.py | 6 +----- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/README b/README index a7f5e0ce..25ed832e 100644 --- a/README +++ b/README @@ -17,5 +17,20 @@ To run tests: To run coverage (coverage package is required): - make coverage +On Windows, things are a little more complicated. Assume 'P' is your +Python binary (for example C:\Python33\python.exe). + +You must first build the _overlapped.pyd extension and have it placed +in the asyncio directory, as follows: + + C> P setup.py build --build-lib . + +Then you can run the tests as follows: + + C> P runtests.py + +And coverage as follows: + + C> P runtests.py --coverage --Guido van Rossum diff --git a/asyncio/__init__.py b/asyncio/__init__.py index afc444d9..0d288d5a 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -4,10 +4,18 @@ # The selectors module is in the stdlib in Python 3.4 but not in 3.3. # Do this first, so the other submodules can use "from . import selectors". +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. try: - import selectors # Will also be exported. -except ImportError: from . import selectors +except ImportError: + import selectors # Will also be exported. + +if sys.platform == 'win32': + # Similar thing for _overlapped. + try: + from . import _overlapped + except ImportError: + import _overlapped # Will also be exported. # This relies on each of the submodules having an __all__ variable. from .futures import * diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index fc4ae6b9..b70b353c 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -14,11 +14,7 @@ from . import tasks from . import windows_utils from .log import logger - -try: - import _overlapped -except ImportError: - from . import _overlapped +from . import _overlapped __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index f721d318..e013fbdd 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -11,11 +11,7 @@ import _winapi from asyncio import windows_utils - -try: - import _overlapped -except ImportError: - from asyncio import _overlapped +from asyncio import _overlapped class WinsocketpairTests(unittest.TestCase): From 6a9fc1b9bf707c6f377938c5416a99ba79702fc6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 29 Oct 2013 10:46:14 -0700 Subject: [PATCH 0753/1502] Do not attempt to import _winapi on UNIX. --- tests/test_windows_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 17146a36..553ea343 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,11 +1,12 @@ import os import sys import unittest -import _winapi if sys.platform != 'win32': raise unittest.SkipTest('Windows only') +import _winapi + import asyncio from asyncio import windows_events From e820833dd4ec116447aec2d0f5b093a3a8572685 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 29 Oct 2013 13:07:36 -0700 Subject: [PATCH 0754/1502] If setuptools cannot be imported, try distutils.core. --- setup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 011db099..b27f7860 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,10 @@ import os -from setuptools import setup, Extension +try: + from setuptools import setup, Extension +except ImportError: + # Use distutils.core as a fallback. + # We won't be able to build the Wheel file on Windows. + from distutils.core import setup, Extension extensions = [] if os.name == 'nt': From 938e5e4d1d80c866aa45d6885aaa70acf18a570d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 29 Oct 2013 16:46:35 -0700 Subject: [PATCH 0755/1502] Better instructions for running setup.py on Windows. --- README | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README b/README index 25ed832e..34812e08 100644 --- a/README +++ b/README @@ -23,7 +23,7 @@ Python binary (for example C:\Python33\python.exe). You must first build the _overlapped.pyd extension and have it placed in the asyncio directory, as follows: - C> P setup.py build --build-lib . + C> P setup.py build_ext --inplace Then you can run the tests as follows: From 5178c36a7e4ff6b67e2b094620a76c35b14edc39 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 30 Oct 2013 11:08:39 -0700 Subject: [PATCH 0756/1502] Temporarily stop accepting whenever accept() returns certain errors. Fixes issue #78. --- asyncio/constants.py | 5 ++++- asyncio/selector_events.py | 17 ++++++++++++++--- tests/test_base_events.py | 11 ++++++++++- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/asyncio/constants.py b/asyncio/constants.py index 79c3b931..f9e12328 100644 --- a/asyncio/constants.py +++ b/asyncio/constants.py @@ -1,4 +1,7 @@ """Constants.""" - +# After the connection is lost, log warnings after this many write()s. LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 + +# Seconds to wait before retrying accept(). +ACCEPT_RETRY_DELAY = 1 diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7ec07491..f7bc61ac 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -99,12 +99,23 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, try: conn, addr = sock.accept() conn.setblocking(False) - except (BlockingIOError, InterruptedError): + except (BlockingIOError, InterruptedError, ConnectionAbortedError): pass # False alarm. - except Exception as exc: + except OSError as exc: # There's nowhere to send the error, so just log it. # TODO: Someone will want an error handler for this. - logger.exception('Accept failed') + if exc.errno in (errno.EMFILE, errno.ENFILE, + errno.ENOBUFS, errno.ENOMEM): + # Some platforms (e.g. Linux keep reporting the FD as + # ready, so we remove the read handler temporarily. + # We'll try again in a while. + logger.exception('Accept out of system resource (%s)', exc) + self.remove_reader(sock.fileno()) + self.call_later(constants.ACCEPT_RETRY_DELAY, + self._start_serving, + protocol_factory, sock, ssl, server) + else: + raise # The event loop will catch, log and ignore it. else: if ssl: self._make_ssl_transport( diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 09a5fcb2..f4d16d9b 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -1,5 +1,6 @@ """Tests for base_events.py""" +import errno import logging import socket import time @@ -8,6 +9,7 @@ from test.support import find_unused_port, IPV6_ENABLED from asyncio import base_events +from asyncio import constants from asyncio import events from asyncio import futures from asyncio import protocols @@ -585,11 +587,18 @@ def test_accept_connection_retry(self): def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - sock.accept.side_effect = OSError() + sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + self.loop.remove_reader = unittest.mock.Mock() + self.loop.call_later = unittest.mock.Mock() self.loop._accept_connection(MyProto, sock) self.assertTrue(m_log.exception.called) self.assertFalse(sock.close.called) + self.loop.remove_reader.assert_called_with(10) + self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, + # self.loop._start_serving + unittest.mock.ANY, + MyProto, sock, None, None) if __name__ == '__main__': From bad4f77801b98ee4bffd1c8310ff3ae8bc22f83d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 30 Oct 2013 11:10:02 -0700 Subject: [PATCH 0757/1502] Fold some long lines. --- asyncio/selector_events.py | 3 ++- asyncio/tasks.py | 5 +++-- examples/child_process.py | 3 ++- examples/source1.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index f7bc61ac..e61a88d8 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -416,7 +416,8 @@ def _maybe_pause_protocol(self): tulip_log.exception('pause_writing() failed') def _maybe_resume_protocol(self): - if self._protocol_paused and self.get_write_buffer_size() <= self._low_water: + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): self._protocol_paused = False try: self._protocol.resume_writing() diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 63850178..2a21a4b9 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -62,8 +62,9 @@ def __del__(self): code = func.__code__ filename = code.co_filename lineno = code.co_firstlineno - logger.error('Coroutine %r defined at %s:%s was never yielded from', - func.__name__, filename, lineno) + logger.error( + 'Coroutine %r defined at %s:%s was never yielded from', + func.__name__, filename, lineno) def coroutine(func): diff --git a/examples/child_process.py b/examples/child_process.py index ef31e68b..8a7ed0e6 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -103,7 +103,8 @@ def writeall(fd, buf): timeout = None while registered: done, pending = yield from asyncio.wait( - registered, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) + registered, timeout=timeout, + return_when=asyncio.FIRST_COMPLETED) if not done: break for f in done: diff --git a/examples/source1.py b/examples/source1.py index 4e05964f..6471d819 100644 --- a/examples/source1.py +++ b/examples/source1.py @@ -43,7 +43,7 @@ def oprint(self, *args): print(self.label, '[...]', file=sys.stderr) end = '\r' print(self.label, *args, file=sys.stderr, end=end, flush=True) - + @coroutine def start(loop, args): From 6bdcc00939c8cb07a8286d1be298b5b19ae12ea9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 30 Oct 2013 12:49:42 -0700 Subject: [PATCH 0758/1502] Update selectors.py from cpython: add get_map() method (Natali + Pitrou). --- asyncio/selectors.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index fe027f09..3e6c2adc 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod -from collections import namedtuple +from collections import namedtuple, Mapping import functools import select import sys @@ -44,6 +44,25 @@ def _fileobj_to_fd(fileobj): selected event mask and attached data.""" +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + return self._selector._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __iter__(self): + return iter(self._selector._fd_to_key) + + class BaseSelector(metaclass=ABCMeta): """Base selector class. @@ -62,6 +81,8 @@ class BaseSelector(metaclass=ABCMeta): def __init__(self): # this maps file descriptors to keys self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) def register(self, fileobj, events, data=None): """Register a file object. @@ -162,6 +183,10 @@ def get_key(self, fileobj): except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None + def get_map(self): + """Return a mapping of file objects to selector keys.""" + return self._map + def __enter__(self): return self From c57cbe222f5aa81054742f5f124f9e2d970f17a3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 30 Oct 2013 13:59:43 -0700 Subject: [PATCH 0759/1502] Add server_hostname as create_connection() argument, with secure default. --- asyncio/base_events.py | 23 ++++++++++++++-- asyncio/events.py | 2 +- asyncio/selector_events.py | 4 +-- tests/test_base_events.py | 54 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 37d50aa2..f18a5565 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -275,8 +275,27 @@ def getnameinfo(self, sockaddr, flags=0): @tasks.coroutine def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None): + local_addr=None, server_hostname=None): """XXX""" + if server_hostname is not None and not ssl: + raise ValueError('server_hostname is only meaningful with ssl') + + if server_hostname is None and ssl: + # Use host as default for server_hostname. It is an error + # if host is empty or not set, e.g. when an + # already-connected socket was passed or when only a port + # is given. To avoid this error, you can pass + # server_hostname='' -- this will bypass the hostname + # check. (This also means that if host is a numeric + # IP/IPv6 address, we will attempt to verify that exact + # address; this will probably fail, but it is possible to + # create a certificate for a specific IP address, so we + # don't judge it here.) + if not host: + raise ValueError('You must set server_hostname ' + 'when using ssl without a host') + server_hostname = host + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -357,7 +376,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, - server_side=False, server_hostname=host) + server_side=False, server_hostname=server_hostname) else: transport = self._make_socket_transport(sock, protocol, waiter) diff --git a/asyncio/events.py b/asyncio/events.py index 62f8e949..a47253a6 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -172,7 +172,7 @@ def getnameinfo(self, sockaddr, flags=0): def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None): + local_addr=None, server_hostname=None): raise NotImplementedError def create_server(self, protocol_factory, host=None, port=None, *, diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index e61a88d8..44430b21 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -573,7 +573,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, 'server_side': server_side, 'do_handshake_on_connect': False, } - if server_hostname is not None and not server_side and ssl.HAS_SNI: + if server_hostname and not server_side and ssl.HAS_SNI: wrap_kwargs['server_hostname'] = server_hostname sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) @@ -619,7 +619,7 @@ def _on_handshake(self): # Verify hostname if requested. peercert = self._sock.getpeercert() - if (self._server_hostname is not None and + if (self._server_hostname and self._sslcontext.verify_mode != ssl.CERT_NONE): try: ssl.match_hostname(peercert, self._server_hostname) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index f4d16d9b..5c120ff1 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -444,6 +444,60 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) + def test_create_connection_server_hostname_default(self): + self.loop.getaddrinfo = unittest.mock.Mock() + def mock_getaddrinfo(*args, **kwds): + f = futures.Future(loop=self.loop) + f.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, '', ('1.2.3.4', 80))]) + return f + self.loop.getaddrinfo.side_effect = mock_getaddrinfo + self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect.return_value = () + self.loop._make_ssl_transport = unittest.mock.Mock() + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): + waiter.set_result(None) + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport + ANY = unittest.mock.ANY + # First try the default server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) + self.loop.run_until_complete(coro) + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org') + # Next try an explicit server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='perl.com') + self.loop.run_until_complete(coro) + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com') + # Finally try an explicit empty server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='') + self.loop.run_until_complete(coro) + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='') + + def test_create_connection_server_hostname_errors(self): + # When not using ssl, server_hostname must be None (but '' is OK). + coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='python.org') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + # When using ssl, server_hostname may be None if host is non-empty. + coro = self.loop.create_connection(MyProto, '', 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, None, ssl=True, sock=socket.socket()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + def test_create_server_empty_host(self): # if host is empty string use None instead host = object() From 67205ec51004676a2bf2bfac912140f393764af7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 31 Oct 2013 11:02:47 -0700 Subject: [PATCH 0760/1502] Close resources owned by subclass before calling super().close(). --- asyncio/selectors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 3e6c2adc..3638e854 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -351,8 +351,8 @@ def select(self, timeout=None): return ready def close(self): - super().close() self._epoll.close() + super().close() if hasattr(select, 'kqueue'): @@ -414,8 +414,8 @@ def select(self, timeout=None): return ready def close(self): - super().close() self._kqueue.close() + super().close() # Choose the best implementation: roughly, epoll|kqueue > poll > select. From b7c411cfaa3741de5adbae30bb8a22a48e3e502d Mon Sep 17 00:00:00 2001 From: Nikolay Kim Date: Thu, 31 Oct 2013 11:38:11 -0700 Subject: [PATCH 0761/1502] refactor ssl transport ready loop --- asyncio/selector_events.py | 94 +++++++++++++----------- tests/test_selector_events.py | 134 +++++++++++++++++++++------------- 2 files changed, 136 insertions(+), 92 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 44430b21..a975dbb7 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -286,7 +286,7 @@ def _sock_connect(self, fut, registered, sock, address): err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: # Jump to the except clause below. - raise OSError(err, 'Connect call failed') + raise OSError(err, 'Connect call failed %s' % (address,)) except (BlockingIOError, InterruptedError): self.add_writer(fd, self._sock_connect, fut, True, sock, address) except Exception as exc: @@ -413,7 +413,7 @@ def _maybe_pause_protocol(self): try: self._protocol.pause_writing() except Exception: - tulip_log.exception('pause_writing() failed') + logger.exception('pause_writing() failed') def _maybe_resume_protocol(self): if (self._protocol_paused and @@ -422,7 +422,7 @@ def _maybe_resume_protocol(self): try: self._protocol.resume_writing() except Exception: - tulip_log.exception('resume_writing() failed') + logger.exception('resume_writing() failed') def set_write_buffer_limits(self, high=None, low=None): if high is None: @@ -635,15 +635,16 @@ def _on_handshake(self): compression=self._sock.compression(), ) - self._loop.add_reader(self._sock_fd, self._on_ready) - self._loop.add_writer(self._sock_fd, self._on_ready) + self._read_wants_write = False + self._write_wants_read = False + self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: self._loop.call_soon(self._waiter.set_result, None) def pause_reading(self): # XXX This is a bit icky, given the comment at the top of - # _on_ready(). Is it possible to evoke a deadlock? I don't + # _read_ready(). Is it possible to evoke a deadlock? I don't # know, although it doesn't look like it; write() will still # accept more data for the buffer and eventually the app will # call resume_reading() again, and things will flow again. @@ -658,41 +659,55 @@ def resume_reading(self): self._paused = False if self._closing: return - self._loop.add_reader(self._sock_fd, self._on_ready) + self._loop.add_reader(self._sock_fd, self._read_ready) - def _on_ready(self): - # Because of renegotiations (?), there's no difference between - # readable and writable. We just try both. XXX This may be - # incorrect; we probably need to keep state about what we - # should do next. + def _read_ready(self): + if self._write_wants_read: + self._write_wants_read = False + self._write_ready() - # First try reading. - if not self._closing and not self._paused: - try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, - ssl.SSLWantReadError, ssl.SSLWantWriteError): - pass - except Exception as exc: - self._fatal_error(exc) + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + pass + except ssl.SSLWantWriteError: + self._read_wants_write = True + self._loop.remove_reader(self._sock_fd) + self._loop.add_writer(self._sock_fd, self._write_ready) + except Exception as exc: + self._fatal_error(exc) + else: + if data: + self._protocol.data_received(data) else: - if data: - self._protocol.data_received(data) - else: - try: - self._protocol.eof_received() - finally: - self.close() + try: + self._protocol.eof_received() + finally: + self.close() + + def _write_ready(self): + if self._read_wants_write: + self._read_wants_write = False + self._read_ready() + + if not (self._paused or self._closing): + self._loop.add_reader(self._sock_fd, self._read_ready) - # Now try writing, if there's anything to write. if self._buffer: data = b''.join(self._buffer) self._buffer.clear() try: n = self._sock.send(data) except (BlockingIOError, InterruptedError, - ssl.SSLWantReadError, ssl.SSLWantWriteError): + ssl.SSLWantWriteError): n = 0 + except ssl.SSLWantReadError: + n = 0 + self._loop.remove_writer(self._sock_fd) + self._write_wants_read = True except Exception as exc: self._loop.remove_writer(self._sock_fd) self._fatal_error(exc) @@ -701,11 +716,12 @@ def _on_ready(self): if n < len(data): self._buffer.append(data[n:]) - self._maybe_resume_protocol() # May append to buffer. + self._maybe_resume_protocol() # May append to buffer. - if self._closing and not self._buffer: + if not self._buffer: self._loop.remove_writer(self._sock_fd) - self._call_connection_lost(None) + if self._closing: + self._call_connection_lost(None) def write(self, data): assert isinstance(data, bytes), repr(type(data)) @@ -718,20 +734,16 @@ def write(self, data): self._conn_lost += 1 return - # We could optimize, but the callback can do this for now. + if not self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. self._buffer.append(data) self._maybe_pause_protocol() def can_write_eof(self): return False - def close(self): - if self._closing: - return - self._closing = True - self._conn_lost += 1 - self._loop.remove_reader(self._sock_fd) - class _SelectorDatagramTransport(_SelectorTransport): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index fbd5d723..3b8238d5 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1003,8 +1003,7 @@ def test_on_handshake(self): self.loop, self.sock, self.protocol, self.sslcontext, waiter=waiter) self.assertTrue(self.sslsock.do_handshake.called) - self.loop.assert_reader(1, tr._on_ready) - self.loop.assert_writer(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) test_utils.run_briefly(self.loop) self.assertIsNone(waiter.result()) @@ -1047,13 +1046,13 @@ def test_on_handshake_base_exc(self): def test_pause_resume_reading(self): tr = self._make_one() self.assertFalse(tr._paused) - self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) tr.pause_reading() self.assertTrue(tr._paused) self.assertFalse(1 in self.loop.readers) tr.resume_reading() self.assertFalse(tr._paused) - self.loop.assert_reader(1, tr._on_ready) + self.loop.assert_reader(1, tr._read_ready) def test_write_no_data(self): transport = self._make_one() @@ -1084,140 +1083,173 @@ def test_write_exception(self, m_log): transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') - def test_on_ready_recv(self): + def test_read_ready_recv(self): self.sslsock.recv.return_value = b'data' transport = self._make_one() - transport._on_ready() + transport._read_ready() self.assertTrue(self.sslsock.recv.called) self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) - def test_on_ready_recv_eof(self): + def test_read_ready_write_wants_read(self): + self.loop.add_writer = unittest.mock.Mock() + self.sslsock.recv.side_effect = BlockingIOError + transport = self._make_one() + transport._write_wants_read = True + transport._write_ready = unittest.mock.Mock() + transport._buffer.append(b'data') + transport._read_ready() + + self.assertFalse(transport._write_wants_read) + transport._write_ready.assert_called_with() + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() transport.close = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() - def test_on_ready_recv_conn_reset(self): + def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() transport._force_close = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport._force_close.assert_called_with(err) - def test_on_ready_recv_retry(self): + def test_read_ready_recv_retry(self): self.sslsock.recv.side_effect = ssl.SSLWantReadError transport = self._make_one() - transport._on_ready() + transport._read_ready() self.assertTrue(self.sslsock.recv.called) self.assertFalse(self.protocol.data_received.called) - self.sslsock.recv.side_effect = ssl.SSLWantWriteError - transport._on_ready() - self.assertFalse(self.protocol.data_received.called) - self.sslsock.recv.side_effect = BlockingIOError - transport._on_ready() + transport._read_ready() self.assertFalse(self.protocol.data_received.called) self.sslsock.recv.side_effect = InterruptedError - transport._on_ready() + transport._read_ready() + self.assertFalse(self.protocol.data_received.called) + + def test_read_ready_recv_write(self): + self.loop.remove_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.sslsock.recv.side_effect = ssl.SSLWantWriteError + transport = self._make_one() + transport._read_ready() self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._read_wants_write) - def test_on_ready_recv_exc(self): + self.loop.remove_reader.assert_called_with(transport._sock_fd) + self.loop.add_writer.assert_called_with( + transport._sock_fd, transport._write_ready) + + def test_read_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() transport._fatal_error = unittest.mock.Mock() - transport._on_ready() + transport._read_ready() transport._fatal_error.assert_called_with(err) - def test_on_ready_send(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport._buffer = collections.deque([b'data']) - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque(), transport._buffer) self.assertTrue(self.sslsock.send.called) - def test_on_ready_send_none(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_none(self): self.sslsock.send.return_value = 0 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertEqual(collections.deque([b'data1data2']), transport._buffer) - def test_on_ready_send_partial(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) - def test_on_ready_send_closing_partial(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() transport._buffer = collections.deque([b'data1', b'data2']) - transport._on_ready() + transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertFalse(self.sslsock.close.called) - def test_on_ready_send_closing(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() transport._buffer = collections.deque([b'data']) - transport._on_ready() + transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) - def test_on_ready_send_closing_empty_buffer(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_closing_empty_buffer(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() transport._buffer = collections.deque() - transport._on_ready() + transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) - def test_on_ready_send_retry(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError - + def test_write_ready_send_retry(self): transport = self._make_one() transport._buffer = collections.deque([b'data']) - self.sslsock.send.side_effect = ssl.SSLWantReadError - transport._on_ready() - self.assertTrue(self.sslsock.send.called) - self.assertEqual(collections.deque([b'data']), transport._buffer) - self.sslsock.send.side_effect = ssl.SSLWantWriteError - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque([b'data']), transport._buffer) self.sslsock.send.side_effect = BlockingIOError() - transport._on_ready() + transport._write_ready() self.assertEqual(collections.deque([b'data']), transport._buffer) - def test_on_ready_send_exc(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + def test_write_ready_send_read(self): + transport = self._make_one() + transport._buffer = collections.deque([b'data']) + + self.loop.remove_writer = unittest.mock.Mock() + self.sslsock.send.side_effect = ssl.SSLWantReadError + transport._write_ready() + self.assertFalse(self.protocol.data_received.called) + self.assertTrue(transport._write_wants_read) + self.loop.remove_writer.assert_called_with(transport._sock_fd) + + def test_write_ready_send_exc(self): err = self.sslsock.send.side_effect = OSError() transport = self._make_one() transport._buffer = collections.deque([b'data']) transport._fatal_error = unittest.mock.Mock() - transport._on_ready() + transport._write_ready() transport._fatal_error.assert_called_with(err) self.assertEqual(collections.deque(), transport._buffer) + def test_write_ready_read_wants_write(self): + self.loop.add_reader = unittest.mock.Mock() + self.sslsock.send.side_effect = BlockingIOError + transport = self._make_one() + transport._read_wants_write = True + transport._read_ready = unittest.mock.Mock() + transport._write_ready() + + self.assertFalse(transport._read_wants_write) + transport._read_ready.assert_called_with() + self.loop.add_reader.assert_called_with( + transport._sock_fd, transport._read_ready) + def test_write_eof(self): tr = self._make_one() self.assertFalse(tr.can_write_eof()) From a429474af62f5ed8db5f24311b499d7beebfff2b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 31 Oct 2013 11:45:03 -0700 Subject: [PATCH 0762/1502] Fix coverage -- it was still using tulip instead of asyncio. --- runtests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index b85cce83..469ac86a 100644 --- a/runtests.py +++ b/runtests.py @@ -231,7 +231,7 @@ def runtests(): if args.coverage: cov = coverage.coverage(branch=True, - source=['tulip'], + source=['asyncio'], ) cov.start() From 69de2ea9042f8009f74d3eebc9baf81199036d14 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 31 Oct 2013 15:53:28 -0700 Subject: [PATCH 0763/1502] Document EventLoop.close(). --- asyncio/base_events.py | 5 +++++ asyncio/events.py | 13 +++++++++++++ tests/test_events.py | 2 ++ 3 files changed, 20 insertions(+) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index f18a5565..6e409ead 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -186,6 +186,11 @@ def stop(self): self.call_soon(_raise_stop_error) def close(self): + """Close the event loop. + + This clears the queues and shuts down the executor, + but does not wait for the executor to finish. + """ self._ready.clear() self._scheduled.clear() executor = self._default_executor diff --git a/asyncio/events.py b/asyncio/events.py index a47253a6..7ebc3cb4 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -137,6 +137,17 @@ def is_running(self): """Return whether the event loop is currently running.""" raise NotImplementedError + def close(self): + """Close the loop. + + The loop should not be running. + + This is idempotent and irreversible. + + No other methods should be called after this one. + """ + raise NotImplementedError + # Methods scheduling callbacks. All these return Handles. def call_soon(self, callback, *args): @@ -214,6 +225,8 @@ def create_datagram_endpoint(self, protocol_factory, family=0, proto=0, flags=0): raise NotImplementedError + # Pipes and subprocesses. + def connect_read_pipe(self, protocol_factory, pipe): """Register read pipe in eventloop. diff --git a/tests/test_events.py b/tests/test_events.py index fd2af2e1..83d73973 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1471,6 +1471,8 @@ def test_not_implemented(self): NotImplementedError, loop.stop) self.assertRaises( NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.close) self.assertRaises( NotImplementedError, loop.call_later, None, None) self.assertRaises( From 28bc7aa702369bef1d7dc7365d12c21da9218c97 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 31 Oct 2013 16:44:19 -0700 Subject: [PATCH 0764/1502] Log a warning when eof_received() returns true and using ssl. --- asyncio/selector_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index a975dbb7..c5fc5eb7 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -684,7 +684,10 @@ def _read_ready(self): self._protocol.data_received(data) else: try: - self._protocol.eof_received() + keep_open = self._protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') finally: self.close() From d055dd6e9229c829e93a6169c14dfb3b10ee6eaf Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 31 Oct 2013 17:13:02 -0700 Subject: [PATCH 0765/1502] Fold some long lines. --- asyncio/base_events.py | 2 +- tests/test_base_events.py | 26 ++++++++++++++++---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 6e409ead..a73b3d39 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -300,7 +300,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise ValueError('You must set server_hostname ' 'when using ssl without a host') server_hostname = host - + if host is not None or port is not None: if sock is not None: raise ValueError( diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 5c120ff1..8610c781 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -455,7 +455,8 @@ def mock_getaddrinfo(*args, **kwds): self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.return_value = () self.loop._make_ssl_transport = unittest.mock.Mock() - def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, + **kwds): waiter.set_result(None) self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = unittest.mock.ANY @@ -463,17 +464,19 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) self.loop.run_until_complete(coro) - self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, - server_side=False, - server_hostname='python.org') + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org') # Next try an explicit server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com') self.loop.run_until_complete(coro) - self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, - server_side=False, - server_hostname='perl.com') + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com') # Finally try an explicit empty server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, @@ -485,9 +488,11 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): def test_create_connection_server_hostname_errors(self): # When not using ssl, server_hostname must be None (but '' is OK). - coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='') + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='') self.assertRaises(ValueError, self.loop.run_until_complete, coro) - coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='python.org') + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='python.org') self.assertRaises(ValueError, self.loop.run_until_complete, coro) # When using ssl, server_hostname may be None if host is non-empty. @@ -495,7 +500,8 @@ def test_create_connection_server_hostname_errors(self): self.assertRaises(ValueError, self.loop.run_until_complete, coro) coro = self.loop.create_connection(MyProto, None, 80, ssl=True) self.assertRaises(ValueError, self.loop.run_until_complete, coro) - coro = self.loop.create_connection(MyProto, None, None, ssl=True, sock=socket.socket()) + coro = self.loop.create_connection(MyProto, None, None, + ssl=True, sock=socket.socket()) self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_server_empty_host(self): From 194f19d5b348fd09132c6747e1b7e9ecfd78afab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 08:14:03 -0700 Subject: [PATCH 0766/1502] Satisfy (most) pep8 whitespace requirements. --- Makefile | 5 +++++ asyncio/windows_events.py | 16 ++++++++++++++++ asyncio/windows_utils.py | 20 ++++++++++---------- tests/test_base_events.py | 4 ++++ tests/test_events.py | 1 - tests/test_windows_events.py | 2 +- 6 files changed, 36 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 74c38cae..448796ef 100644 --- a/Makefile +++ b/Makefile @@ -21,6 +21,10 @@ cov coverage: check: $(PYTHON) check.py +# Requires "pip install pep8". +pep8: check + pep8 --ignore E125,E127,E226 tests asyncio + clean: rm -rf `find . -name __pycache__` rm -f `find . -type f -name '*.py[co]' ` @@ -33,6 +37,7 @@ clean: rm -rf dist rm -f .coverage rm -rf htmlcov + rm -rf build rm -f MANIFEST diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index b70b353c..d7444bdf 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -138,6 +138,7 @@ def create_pipe_connection(self, protocol_factory, address): @tasks.coroutine def start_serving_pipe(self, protocol_factory, address): server = PipeServer(address) + def loop(f=None): pipe = None try: @@ -160,6 +161,7 @@ def loop(f=None): pipe.close() else: f.add_done_callback(loop) + self.call_soon(loop) return [server] @@ -209,6 +211,7 @@ def recv(self, conn, nbytes, flags=0): ov.WSARecv(conn.fileno(), nbytes, flags) else: ov.ReadFile(conn.fileno(), nbytes) + def finish(trans, key, ov): try: return ov.getresult() @@ -217,6 +220,7 @@ def finish(trans, key, ov): raise ConnectionResetError(*exc.args) else: raise + return self._register(ov, conn, finish) def send(self, conn, buf, flags=0): @@ -226,6 +230,7 @@ def send(self, conn, buf, flags=0): ov.WSASend(conn.fileno(), buf, flags) else: ov.WriteFile(conn.fileno(), buf) + def finish(trans, key, ov): try: return ov.getresult() @@ -234,6 +239,7 @@ def finish(trans, key, ov): raise ConnectionResetError(*exc.args) else: raise + return self._register(ov, conn, finish) def accept(self, listener): @@ -241,6 +247,7 @@ def accept(self, listener): conn = self._get_accept_socket(listener.family) ov = _overlapped.Overlapped(NULL) ov.AcceptEx(listener.fileno(), conn.fileno()) + def finish_accept(trans, key, ov): ov.getresult() # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. @@ -249,6 +256,7 @@ def finish_accept(trans, key, ov): _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() + return self._register(ov, listener, finish_accept) def connect(self, conn, address): @@ -264,26 +272,31 @@ def connect(self, conn, address): raise ov = _overlapped.Overlapped(NULL) ov.ConnectEx(conn.fileno(), address) + def finish_connect(trans, key, ov): ov.getresult() # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) return conn + return self._register(ov, conn, finish_connect) def accept_pipe(self, pipe): self._register_with_iocp(pipe) ov = _overlapped.Overlapped(NULL) ov.ConnectNamedPipe(pipe.fileno()) + def finish(trans, key, ov): ov.getresult() return pipe + return self._register(ov, pipe, finish) def connect_pipe(self, address): ov = _overlapped.Overlapped(NULL) ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + def finish(err, handle, ov): # err, handle were arguments passed to PostQueuedCompletionStatus() # in a function run in a thread pool. @@ -296,6 +309,7 @@ def finish(err, handle, ov): raise OSError(0, msg, None, err) else: return windows_utils.PipeHandle(handle) + return self._register(ov, None, finish, wait_for_post=True) def wait_for_handle(self, handle, timeout=None): @@ -432,8 +446,10 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): self._proc = windows_utils.Popen( args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, **kwargs) + def callback(f): returncode = self._proc.poll() self._process_exited(returncode) + f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) f.add_done_callback(callback) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index 2fc3f7a9..aa1c0648 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -18,18 +18,18 @@ __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] -# + # Constants/globals -# + BUFSIZE = 8192 PIPE = subprocess.PIPE STDOUT = subprocess.STDOUT _mmap_counter = itertools.count() -# + # Replacement for socket.socketpair() -# + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): """A socket pair usable as a self-pipe, for Windows. @@ -57,9 +57,9 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): lsock.close() return (ssock, csock) -# + # Replacement for os.pipe() using handles instead of fds -# + def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): """Like os.pipe() but with overlapped support and using handles not fds.""" @@ -105,9 +105,9 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): _winapi.CloseHandle(h2) raise -# + # Wrapper for a pipe handle -# + class PipeHandle: """Wrapper for an overlapped pipe handle which is vaguely file-object like. @@ -137,9 +137,9 @@ def __enter__(self): def __exit__(self, t, v, tb): self.close() -# + # Replacement for subprocess.Popen using overlapped pipe handles -# + class Popen(subprocess.Popen): """Replacement for subprocess.Popen using overlapped pipe handles. diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 8610c781..f093ed0c 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -446,18 +446,22 @@ def getaddrinfo_task(*args, **kwds): def test_create_connection_server_hostname_default(self): self.loop.getaddrinfo = unittest.mock.Mock() + def mock_getaddrinfo(*args, **kwds): f = futures.Future(loop=self.loop) f.set_result([(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', ('1.2.3.4', 80))]) return f + self.loop.getaddrinfo.side_effect = mock_getaddrinfo self.loop.sock_connect = unittest.mock.Mock() self.loop.sock_connect.return_value = () self.loop._make_ssl_transport = unittest.mock.Mock() + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): waiter.set_result(None) + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = unittest.mock.ANY # First try the default server_hostname. diff --git a/tests/test_events.py b/tests/test_events.py index 83d73973..12889419 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1276,7 +1276,6 @@ class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): return windows_events.SelectorEventLoop() - class ProactorEventLoopTests(EventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 553ea343..f5147de2 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -77,7 +77,7 @@ def _test_pipe(self): stream_reader = streams.StreamReader(loop=self.loop) protocol = streams.StreamReaderProtocol(stream_reader) trans, proto = yield from self.loop.create_pipe_connection( - lambda:protocol, ADDRESS) + lambda: protocol, ADDRESS) self.assertIsInstance(trans, transports.Transport) self.assertEqual(protocol, proto) clients.append((stream_reader, trans)) From 05fb98265aeaed8279f72470f48fd046e3815f43 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 09:53:40 -0700 Subject: [PATCH 0767/1502] Better-looking errors when ssl module cannot be imported. Fixes issue #77. After a patch by Arno Faure. --- asyncio/base_events.py | 2 ++ asyncio/selector_events.py | 31 +++++++++++++++++++------------ tests/test_selector_events.py | 20 ++++++++++++++++++++ 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index a73b3d39..f2d117bd 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -466,6 +466,8 @@ def create_server(self, protocol_factory, host=None, port=None, ssl=None, reuse_address=None): """XXX""" + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') if host is not None or port is not None: if sock is not None: raise ValueError( diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index c5fc5eb7..3bad1980 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -90,12 +90,13 @@ def _write_to_self(self): except (BlockingIOError, InterruptedError): pass - def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): self.add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock, ssl, server) + protocol_factory, sock, sslcontext, server) - def _accept_connection(self, protocol_factory, sock, ssl=None, - server=None): + def _accept_connection(self, protocol_factory, sock, + sslcontext=None, server=None): try: conn, addr = sock.accept() conn.setblocking(False) @@ -113,13 +114,13 @@ def _accept_connection(self, protocol_factory, sock, ssl=None, self.remove_reader(sock.fileno()) self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, - protocol_factory, sock, ssl, server) + protocol_factory, sock, sslcontext, server) else: raise # The event loop will catch, log and ignore it. else: - if ssl: + if sslcontext: self._make_ssl_transport( - conn, protocol_factory(), ssl, None, + conn, protocol_factory(), sslcontext, None, server_side=True, extra={'peername': addr}, server=server) else: self._make_socket_transport( @@ -558,17 +559,23 @@ class _SelectorSslTransport(_SelectorTransport): def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, server_side=False, server_hostname=None, extra=None, server=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + if server_side: - assert isinstance( - sslcontext, ssl.SSLContext), 'Must pass an SSLContext' + if not sslcontext: + raise ValueError('Server side ssl needs a valid SSLContext') else: - # Client-side may pass ssl=True to use a default context. - # The default is the same as used by urllib. - if sslcontext is None: + if not sslcontext: + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is the same as used by urllib with + # cadefault=True. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.set_default_verify_paths() sslcontext.verify_mode = ssl.CERT_REQUIRED + wrap_kwargs = { 'server_side': server_side, 'do_handshake_on_connect': False, diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 3b8238d5..04a7d0c5 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -43,6 +43,7 @@ def test_make_socket_transport(self): self.assertIsInstance( self.loop._make_socket_transport(m, m), _SelectorSocketTransport) + @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): m = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock() @@ -52,6 +53,16 @@ def test_make_ssl_transport(self): self.assertIsInstance( self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) + @unittest.mock.patch('asyncio.selector_events.ssl', None) + def test_make_ssl_transport_without_ssl_error(self): + m = unittest.mock.Mock() + self.loop.add_reader = unittest.mock.Mock() + self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_writer = unittest.mock.Mock() + with self.assertRaises(RuntimeError): + self.loop._make_ssl_transport(m, m, m, m) + def test_close(self): ssock = self.loop._ssock ssock.fileno.return_value = 7 @@ -1277,6 +1288,15 @@ def test_server_hostname(self): server_hostname='localhost') +class SelectorSslWithoutSslTransportTests(unittest.TestCase): + + @unittest.mock.patch('asyncio.selector_events.ssl', None) + def test_ssl_transport_requires_ssl_module(self): + Mock = unittest.mock.Mock + with self.assertRaises(RuntimeError): + transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) + + class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): From 6ddd94b25cef98401331cf1ee4c295e42ef0add3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 10:01:39 -0700 Subject: [PATCH 0768/1502] Slight rearrangement of tests for server_hostname=... --- tests/test_base_events.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index f093ed0c..9b883c52 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -444,7 +444,7 @@ def getaddrinfo_task(*args, **kwds): self.assertRaises( OSError, self.loop.run_until_complete, coro) - def test_create_connection_server_hostname_default(self): + def test_create_connection_ssl_server_hostname_default(self): self.loop.getaddrinfo = unittest.mock.Mock() def mock_getaddrinfo(*args, **kwds): @@ -490,8 +490,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, server_side=False, server_hostname='') - def test_create_connection_server_hostname_errors(self): - # When not using ssl, server_hostname must be None (but '' is OK). + def test_create_connection_no_ssl_server_hostname_errors(self): + # When not using ssl, server_hostname must be None. coro = self.loop.create_connection(MyProto, 'python.org', 80, server_hostname='') self.assertRaises(ValueError, self.loop.run_until_complete, coro) @@ -499,6 +499,7 @@ def test_create_connection_server_hostname_errors(self): server_hostname='python.org') self.assertRaises(ValueError, self.loop.run_until_complete, coro) + def test_create_connection_ssl_server_hostname_errors(self): # When using ssl, server_hostname may be None if host is non-empty. coro = self.loop.create_connection(MyProto, '', 80, ssl=True) self.assertRaises(ValueError, self.loop.run_until_complete, coro) From 53e4794c6790d14a32969ad2f28426b5b24a0117 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 10:23:23 -0700 Subject: [PATCH 0769/1502] Add pool-closing to fetch3 example. --- examples/fetch3.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index fac880fe..fa9ebb01 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -18,6 +18,10 @@ def __init__(self, verbose=False): self.verbose = verbose self.connections = {} # {(host, port, ssl): (reader, writer)} + def close(self): + for _, writer in self.connections.values(): + writer.close() + @coroutine def open_connection(self, host, port, ssl): port = port or (443 if ssl else 80) @@ -187,18 +191,21 @@ def read(self): @coroutine def fetch(url, verbose=True, max_redirect=10): pool = ConnectionPool(verbose) - for _ in range(max_redirect): - request = Request(url, verbose) - yield from request.connect(pool) - yield from request.send_request() - response = yield from request.get_response() - body = yield from response.read() - next_url = response.get_redirect_url() - if not next_url: - break - url = urllib.parse.urljoin(url, next_url) - print('redirect to', url, file=sys.stderr) - return body + try: + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) + return body + finally: + pool.close() def main(): From a4f2204a57a26fa5d2b91bfb82e2f5d6f44cede7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 1 Nov 2013 15:51:10 -0700 Subject: [PATCH 0770/1502] Add limited TLS capability to source/sink examples. --- examples/sink.py | 20 ++++++++++++++++++-- examples/source.py | 10 +++++++++- examples/source1.py | 14 +++++++++++--- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/examples/sink.py b/examples/sink.py index b5edc3aa..4b223fdd 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -1,11 +1,15 @@ """Test service that accepts connections and reads all data off them.""" import argparse +import os import sys from asyncio import * ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS with a self-signed cert') ARGS.add_argument( '--iocp', action='store_true', dest='iocp', default=False, help='Use IOCP event loop (Windows only)') @@ -54,8 +58,20 @@ def connection_lost(self, how): @coroutine def start(loop, host, port): global server - server = yield from loop.create_server(Service, host, port) - dprint('serving', [s.getsockname() for s in server.sockets]) + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'sample.crt'), + keyfile=os.path.join(here, 'sample.key')) + + server = yield from loop.create_server(Service, host, port, ssl=sslctx) + dprint('serving TLS' if sslctx else 'serving', + [s.getsockname() for s in server.sockets]) yield from server.wait_closed() diff --git a/examples/source.py b/examples/source.py index c36f1478..9aff8c1d 100644 --- a/examples/source.py +++ b/examples/source.py @@ -4,9 +4,13 @@ import sys from asyncio import * +from asyncio import test_utils ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') ARGS.add_argument( '--iocp', action='store_true', dest='iocp', default=False, help='Use IOCP event loop (Windows only)') @@ -67,7 +71,11 @@ def connection_lost(self, exc): @coroutine def start(loop, host, port): - tr, pr = yield from loop.create_connection(Client, host, port) + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + tr, pr = yield from loop.create_connection(Client, host, port, + ssl=sslctx) dprint('tr =', tr) dprint('pr =', pr) yield from pr.waiter diff --git a/examples/source1.py b/examples/source1.py index 6471d819..b8f89790 100644 --- a/examples/source1.py +++ b/examples/source1.py @@ -3,9 +3,13 @@ import argparse import sys -from tulip import * +from asyncio import * +from asyncio import test_utils ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') ARGS.add_argument( '--iocp', action='store_true', dest='iocp', default=False, help='Use IOCP event loop (Windows only)') @@ -49,7 +53,11 @@ def oprint(self, *args): def start(loop, args): d = Debug() total = 0 - r, w = yield from open_connection(args.host, args.port) + sslctx = None + if args.tls: + d.print('using dummy SSLContext') + sslctx = test_utils.dummy_ssl_context() + r, w = yield from open_connection(args.host, args.port, ssl=sslctx) d.print('r =', r) d.print('w =', w) if args.stop: @@ -75,7 +83,7 @@ def main(): global args args = ARGS.parse_args() if args.iocp: - from tulip.windows_events import ProactorEventLoop + from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) else: From 65abd4a80699e001ab757a959f4253af74cd86b4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 3 Nov 2013 14:07:25 -0800 Subject: [PATCH 0771/1502] Locks improvements by Arnaud Faure: better repr(), change Condition structure. --- asyncio/locks.py | 78 +++++++++++++++++++++++++++++++-------------- tests/test_locks.py | 71 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 25 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 06edbbc1..ac851e5c 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -155,9 +155,11 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - # TODO: add waiters:N if > 0. res = super().__repr__() - return '<{} [{}]>'.format(res[1:-1], 'set' if self._value else 'unset') + extra = 'set' if self._value else 'unset' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) def is_set(self): """Return true if and only if the internal flag is true.""" @@ -201,20 +203,38 @@ def wait(self): self._waiters.remove(fut) -# TODO: Why is this a Lock subclass? threading.Condition *has* a lock. -class Condition(Lock): - """A Condition implementation. +class Condition: + """A Condition implementation, our equivalent to threading.Condition. This class implements condition variable objects. A condition variable allows one or more coroutines to wait until they are notified by another coroutine. + + A new Lock object is created and used as the underlying lock. """ def __init__(self, *, loop=None): - super().__init__(loop=loop) - self._condition_waiters = collections.deque() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() - # TODO: Add __repr__() with len(_condition_waiters). + # Lock as an attribute as in threading.Condition. + lock = Lock(loop=self._loop) + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) @tasks.coroutine def wait(self): @@ -228,19 +248,19 @@ def wait(self): the same condition variable in another coroutine. Once awakened, it re-acquires the lock and returns True. """ - if not self._locked: + if not self.locked(): raise RuntimeError('cannot wait on un-acquired lock') keep_lock = True self.release() try: fut = futures.Future(loop=self._loop) - self._condition_waiters.append(fut) + self._waiters.append(fut) try: yield from fut return True finally: - self._condition_waiters.remove(fut) + self._waiters.remove(fut) except GeneratorExit: keep_lock = False # Prevent yield in finally clause. @@ -275,11 +295,11 @@ def notify(self, n=1): wait() call until it can reacquire the lock. Since notify() does not release the lock, its caller should. """ - if not self._locked: + if not self.locked(): raise RuntimeError('cannot notify on un-acquired lock') idx = 0 - for fut in self._condition_waiters: + for fut in self._waiters: if idx >= n: break @@ -293,7 +313,17 @@ def notify_all(self): calling thread has not acquired the lock when this method is called, a RuntimeError is raised. """ - self.notify(len(self._condition_waiters)) + self.notify(len(self._waiters)) + + def __enter__(self): + return self._lock.__enter__() + + def __exit__(self, *args): + return self._lock.__exit__(*args) + + def __iter__(self): + yield from self.acquire() + return self class Semaphore: @@ -310,10 +340,10 @@ class Semaphore: counter; it defaults to 1. If the value given is less than 0, ValueError is raised. - The second optional argument determins can semophore be released more than - initial internal counter value; it defaults to False. If the value given - is True and number of release() is more than number of successfull - acquire() calls ValueError is raised. + The second optional argument determines if the semaphore can be released + more than initial internal counter value; it defaults to False. If the + value given is True and number of release() is more than number of + successful acquire() calls ValueError is raised. """ def __init__(self, value=1, bound=False, *, loop=None): @@ -330,12 +360,12 @@ def __init__(self, value=1, bound=False, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - # TODO: add waiters:N if > 0. res = super().__repr__() - return '<{} [{}]>'.format( - res[1:-1], - 'locked' if self._locked else 'unlocked,value:{}'.format( - self._value)) + extra = 'locked' if self._locked else 'unlocked,value:{}'.format( + self._value) + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) def locked(self): """Returns True if semaphore can not be acquired immediately.""" @@ -373,7 +403,7 @@ def release(self): When it was zero on entry and another coroutine is waiting for it to become larger than zero again, wake up that coroutine. - If Semaphore is create with "bound" paramter equals true, then + If Semaphore is created with "bound" parameter equals true, then release() method checks to make sure its current value doesn't exceed its initial value. If it does, ValueError is raised. """ diff --git a/tests/test_locks.py b/tests/test_locks.py index 31b4d64b..19ef877a 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -2,6 +2,7 @@ import unittest import unittest.mock +import re from asyncio import events from asyncio import futures @@ -10,6 +11,15 @@ from asyncio import test_utils +STR_RGX_REPR = ( + r'^<(?P.*?) object at (?P
.*?)' + r'\[(?P' + r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?' + r')\]>\Z' +) +RGX_REPR = re.compile(STR_RGX_REPR) + + class LockTests(unittest.TestCase): def setUp(self): @@ -38,6 +48,7 @@ def test_ctor_noloop(self): def test_repr(self): lock = locks.Lock(loop=self.loop) self.assertTrue(repr(lock).endswith('[unlocked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) @tasks.coroutine def acquire_lock(): @@ -45,6 +56,7 @@ def acquire_lock(): self.loop.run_until_complete(acquire_lock()) self.assertTrue(repr(lock).endswith('[locked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) def test_lock(self): lock = locks.Lock(loop=self.loop) @@ -239,9 +251,16 @@ def test_ctor_noloop(self): def test_repr(self): ev = locks.Event(loop=self.loop) self.assertTrue(repr(ev).endswith('[unset]>')) + match = RGX_REPR.match(repr(ev)) + self.assertEqual(match.group('extras'), 'unset') ev.set() self.assertTrue(repr(ev).endswith('[set]>')) + self.assertTrue(RGX_REPR.match(repr(ev))) + + ev._waiters.append(unittest.mock.Mock()) + self.assertTrue('waiters:1' in repr(ev)) + self.assertTrue(RGX_REPR.match(repr(ev))) def test_wait(self): ev = locks.Event(loop=self.loop) @@ -440,7 +459,7 @@ def test_wait_cancel(self): self.assertRaises( futures.CancelledError, self.loop.run_until_complete, wait) - self.assertFalse(cond._condition_waiters) + self.assertFalse(cond._waiters) self.assertTrue(cond.locked()) def test_wait_unacquired(self): @@ -600,6 +619,45 @@ def test_notify_all_unacquired(self): cond = locks.Condition(loop=self.loop) self.assertRaises(RuntimeError, cond.notify_all) + def test_repr(self): + cond = locks.Condition(loop=self.loop) + self.assertTrue('unlocked' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + self.loop.run_until_complete(cond.acquire()) + self.assertTrue('locked' in repr(cond)) + + cond._waiters.append(unittest.mock.Mock()) + self.assertTrue('waiters:1' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + cond._waiters.append(unittest.mock.Mock()) + self.assertTrue('waiters:2' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + def test_context_manager(self): + cond = locks.Condition(loop=self.loop) + + @tasks.coroutine + def acquire_cond(): + return (yield from cond) + + with self.loop.run_until_complete(acquire_cond()): + self.assertTrue(cond.locked()) + + self.assertFalse(cond.locked()) + + def test_context_manager_no_yield(self): + cond = locks.Condition(loop=self.loop) + + try: + with cond: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + class SemaphoreTests(unittest.TestCase): @@ -629,9 +687,20 @@ def test_ctor_noloop(self): def test_repr(self): sem = locks.Semaphore(loop=self.loop) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + self.assertTrue(RGX_REPR.match(repr(sem))) self.loop.run_until_complete(sem.acquire()) self.assertTrue(repr(sem).endswith('[locked]>')) + self.assertTrue('waiters' not in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(unittest.mock.Mock()) + self.assertTrue('waiters:1' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(unittest.mock.Mock()) + self.assertTrue('waiters:2' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) def test_semaphore(self): sem = locks.Semaphore(loop=self.loop) From 16e33e1dc34e1e2e3d4622ef6a5df56284bc623a Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Mon, 4 Nov 2013 13:16:23 -0800 Subject: [PATCH 0772/1502] Relax test for process return code on Windows. --- tests/test_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 12889419..4af9aa93 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -961,14 +961,14 @@ class SubprocessTestsMixin: def check_terminated(self, returncode): if sys.platform == 'win32': self.assertIsInstance(returncode, int) - self.assertNotEqual(0, returncode) + # expect 1 but sometimes get 0 else: self.assertEqual(-signal.SIGTERM, returncode) def check_killed(self, returncode): if sys.platform == 'win32': self.assertIsInstance(returncode, int) - self.assertNotEqual(0, returncode) + # expect 1 but sometimes get 0 else: self.assertEqual(-signal.SIGKILL, returncode) From 11ddf5f8681dd202d05e40d098e3ba8289ff046d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 4 Nov 2013 15:34:55 -0800 Subject: [PATCH 0773/1502] Refactor SIGCHLD handler, by Anthony Baire. Fixes issue 67. --- asyncio/events.py | 70 ++- asyncio/unix_events.py | 396 +++++++++++++-- asyncio/windows_events.py | 14 +- tests/test_events.py | 44 +- tests/test_unix_events.py | 987 +++++++++++++++++++++++++++++++++----- 5 files changed, 1312 insertions(+), 199 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 7ebc3cb4..36ae312b 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -1,10 +1,11 @@ """Event loop and event loop policy.""" -__all__ = ['AbstractEventLoopPolicy', 'DefaultEventLoopPolicy', +__all__ = ['AbstractEventLoopPolicy', 'AbstractEventLoop', 'AbstractServer', 'Handle', 'TimerHandle', 'get_event_loop_policy', 'set_event_loop_policy', 'get_event_loop', 'set_event_loop', 'new_event_loop', + 'get_child_watcher', 'set_child_watcher', ] import subprocess @@ -318,8 +319,18 @@ def new_event_loop(self): """XXX""" raise NotImplementedError + # Child processes handling (Unix only). -class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): + def get_child_watcher(self): + """XXX""" + raise NotImplementedError + + def set_child_watcher(self, watcher): + """XXX""" + raise NotImplementedError + + +class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): """Default policy implementation for accessing the event loop. In this policy, each thread has its own event loop. However, we @@ -332,28 +343,34 @@ class DefaultEventLoopPolicy(threading.local, AbstractEventLoopPolicy): associated). """ - _loop = None - _set_called = False + _loop_factory = None + + class _Local(threading.local): + _loop = None + _set_called = False + + def __init__(self): + self._local = self._Local() def get_event_loop(self): """Get the event loop. This may be None or an instance of EventLoop. """ - if (self._loop is None and - not self._set_called and + if (self._local._loop is None and + not self._local._set_called and isinstance(threading.current_thread(), threading._MainThread)): - self._loop = self.new_event_loop() - assert self._loop is not None, \ + self._local._loop = self.new_event_loop() + assert self._local._loop is not None, \ ('There is no current event loop in thread %r.' % threading.current_thread().name) - return self._loop + return self._local._loop def set_event_loop(self, loop): """Set the event loop.""" - self._set_called = True + self._local._set_called = True assert loop is None or isinstance(loop, AbstractEventLoop) - self._loop = loop + self._local._loop = loop def new_event_loop(self): """Create a new event loop. @@ -361,12 +378,7 @@ def new_event_loop(self): You must call set_event_loop() to make this the current event loop. """ - if sys.platform == 'win32': # pragma: no cover - from . import windows_events - return windows_events.SelectorEventLoop() - else: # pragma: no cover - from . import unix_events - return unix_events.SelectorEventLoop() + return self._loop_factory() # Event loop policy. The policy itself is always global, even if the @@ -375,12 +387,22 @@ def new_event_loop(self): # call to get_event_loop_policy(). _event_loop_policy = None +# Lock for protecting the on-the-fly creation of the event loop policy. +_lock = threading.Lock() + + +def _init_event_loop_policy(): + global _event_loop_policy + with _lock: + if _event_loop_policy is None: # pragma: no branch + from . import DefaultEventLoopPolicy + _event_loop_policy = DefaultEventLoopPolicy() + def get_event_loop_policy(): """XXX""" - global _event_loop_policy if _event_loop_policy is None: - _event_loop_policy = DefaultEventLoopPolicy() + _init_event_loop_policy() return _event_loop_policy @@ -404,3 +426,13 @@ def set_event_loop(loop): def new_event_loop(): """XXX""" return get_event_loop_policy().new_event_loop() + + +def get_child_watcher(): + """XXX""" + return get_event_loop_policy().get_child_watcher() + + +def set_child_watcher(watcher): + """XXX""" + return get_event_loop_policy().set_child_watcher(watcher) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index c95ad488..dd57fe8e 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -8,6 +8,7 @@ import stat import subprocess import sys +import threading from . import base_subprocess @@ -20,7 +21,10 @@ from .log import logger -__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR'] +__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'DefaultEventLoopPolicy', + ] STDIN = 0 STDOUT = 1 @@ -31,7 +35,7 @@ raise ImportError('Signals are not really supported on Windows') -class SelectorEventLoop(selector_events.BaseSelectorEventLoop): +class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Unix event loop Adds signal handling to SelectorEventLoop @@ -40,17 +44,10 @@ class SelectorEventLoop(selector_events.BaseSelectorEventLoop): def __init__(self, selector=None): super().__init__(selector) self._signal_handlers = {} - self._subprocesses = {} def _socketpair(self): return socket.socketpair() - def close(self): - handler = self._signal_handlers.get(signal.SIGCHLD) - if handler is not None: - self.remove_signal_handler(signal.SIGCHLD) - super().close() - def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. @@ -152,49 +149,20 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): - self._reg_sigchld() - transp = _UnixSubprocessTransport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs) - self._subprocesses[transp.get_pid()] = transp + with events.get_child_watcher() as watcher: + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs) + watcher.add_child_handler(transp.get_pid(), + self._child_watcher_callback, transp) yield from transp._post_init() return transp - def _reg_sigchld(self): - if signal.SIGCHLD not in self._signal_handlers: - self.add_signal_handler(signal.SIGCHLD, self._sig_chld) + def _child_watcher_callback(self, pid, returncode, transp): + self.call_soon_threadsafe(transp._process_exited, returncode) - def _sig_chld(self): - try: - # Because of signal coalescing, we must keep calling waitpid() as - # long as we're able to reap a child. - while True: - try: - pid, status = os.waitpid(-1, os.WNOHANG) - except ChildProcessError: - break # No more child processes exist. - if pid == 0: - break # All remaining child processes are still alive. - elif os.WIFSIGNALED(status): - # A child process died because of a signal. - returncode = -os.WTERMSIG(status) - elif os.WIFEXITED(status): - # A child process exited (e.g. sys.exit()). - returncode = os.WEXITSTATUS(status) - else: - # A child exited, but we don't understand its status. - # This shouldn't happen, but if it does, let's just - # return that status; perhaps that helps debug it. - returncode = status - transp = self._subprocesses.get(pid) - if transp is not None: - transp._process_exited(returncode) - except Exception: - logger.exception('Unknown exception in SIGCHLD handler') - - def _subprocess_closed(self, transport): - pid = transport.get_pid() - self._subprocesses.pop(pid, None) + def _subprocess_closed(self, transp): + pass def _set_nonblocking(fd): @@ -423,3 +391,335 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): if stdin_w is not None: stdin.close() self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize) + + +class AbstractChildWatcher: + """Abstract base class for monitoring child processes. + + Objects derived from this class monitor a collection of subprocesses and + report their termination or interruption by a signal. + + New callbacks are registered with .add_child_handler(). Starting a new + process must be done within a 'with' block to allow the watcher to suspend + its activity until the new process if fully registered (this is needed to + prevent a race condition in some implementations). + + Example: + with watcher: + proc = subprocess.Popen("sleep 1") + watcher.add_child_handler(proc.pid, callback) + + Notes: + Implementations of this class must be thread-safe. + + Since child watcher objects may catch the SIGCHLD signal and call + waitpid(-1), there should be only one active object per process. + """ + + def add_child_handler(self, pid, callback, *args): + """Register a new child handler. + + Arrange for callback(pid, returncode, *args) to be called when + process 'pid' terminates. Specifying another callback for the same + process replaces the previous handler. + + Note: callback() must be thread-safe + """ + raise NotImplementedError() + + def remove_child_handler(self, pid): + """Removes the handler for process 'pid'. + + The function returns True if the handler was successfully removed, + False if there was nothing to remove.""" + + raise NotImplementedError() + + def set_loop(self, loop): + """Reattach the watcher to another event loop. + + Note: loop may be None + """ + raise NotImplementedError() + + def close(self): + """Close the watcher. + + This must be called to make sure that any underlying resource is freed. + """ + raise NotImplementedError() + + def __enter__(self): + """Enter the watcher's context and allow starting new processes + + This function must return self""" + raise NotImplementedError() + + def __exit__(self, a, b, c): + """Exit the watcher's context""" + raise NotImplementedError() + + +class BaseChildWatcher(AbstractChildWatcher): + + def __init__(self, loop): + self._loop = None + self._callbacks = {} + + self.set_loop(loop) + + def close(self): + self.set_loop(None) + self._callbacks.clear() + + def _do_waitpid(self, expected_pid): + raise NotImplementedError() + + def _do_waitpid_all(self): + raise NotImplementedError() + + def set_loop(self, loop): + assert loop is None or isinstance(loop, events.AbstractEventLoop) + + if self._loop is not None: + self._loop.remove_signal_handler(signal.SIGCHLD) + + self._loop = loop + if loop is not None: + loop.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + # Prevent a race condition in case a child terminated + # during the switch. + self._do_waitpid_all() + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _sig_chld(self): + try: + self._do_waitpid_all() + except Exception: + logger.exception('Unknown exception in SIGCHLD handler') + + def _compute_returncode(self, status): + if os.WIFSIGNALED(status): + # The child process died because of a signal. + return -os.WTERMSIG(status) + elif os.WIFEXITED(status): + # The child process exited (e.g sys.exit()). + return os.WEXITSTATUS(status) + else: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status + + +class SafeChildWatcher(BaseChildWatcher): + """'Safe' child watcher implementation. + + This implementation avoids disrupting other code spawning processes by + polling explicitly each process in the SIGCHLD handler instead of calling + os.waitpid(-1). + + This is a safe solution but it has a significant overhead when handling a + big number of children (O(n) each time SIGCHLD is raised) + """ + + def __enter__(self): + return self + + def __exit__(self, a, b, c): + pass + + def add_child_handler(self, pid, callback, *args): + self._callbacks[pid] = callback, args + + # Prevent a race condition in case the child is already terminated. + self._do_waitpid(pid) + + def _do_waitpid_all(self): + + for pid in list(self._callbacks): + self._do_waitpid(pid) + + def _do_waitpid(self, expected_pid): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, os.WNOHANG) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + if pid == 0: + # The child process is still alive. + return + + returncode = self._compute_returncode(status) + + try: + callback, args = self._callbacks.pop(pid) + except KeyError: # pragma: no cover + # May happen if .remove_child_handler() is called + # after os.waitpid() returns. + pass + else: + callback(pid, returncode, *args) + + +class FastChildWatcher(BaseChildWatcher): + """'Fast' child watcher implementation. + + This implementation reaps every terminated processes by calling + os.waitpid(-1) directly, possibly breaking other code spawning processes + and waiting for their termination. + + There is no noticeable overhead when handling a big number of children + (O(1) each time a child terminates). + """ + def __init__(self, loop): + super().__init__(loop) + + self._lock = threading.Lock() + self._zombies = {} + self._forks = 0 + + def close(self): + super().close() + self._zombies.clear() + + def __enter__(self): + with self._lock: + self._forks += 1 + + return self + + def __exit__(self, a, b, c): + with self._lock: + self._forks -= 1 + + if self._forks or not self._zombies: + return + + collateral_victims = str(self._zombies) + self._zombies.clear() + + logger.warning( + "Caught subprocesses termination from unknown pids: %s", + collateral_victims) + + def add_child_handler(self, pid, callback, *args): + assert self._forks, "Must use the context manager" + + self._callbacks[pid] = callback, args + + try: + # Ensure that the child is not already terminated. + # (raise KeyError if still alive) + returncode = self._zombies.pop(pid) + + # Child is dead, therefore we can fire the callback immediately. + # First we remove it from the dict. + # (raise KeyError if .remove_child_handler() was called in-between) + del self._callbacks[pid] + except KeyError: + pass + else: + callback(pid, returncode, *args) + + def _do_waitpid_all(self): + # Because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child. + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + # No more child processes exist. + return + else: + if pid == 0: + # A child process is still alive. + return + + returncode = self._compute_returncode(status) + + try: + callback, args = self._callbacks.pop(pid) + except KeyError: + # unknown child + with self._lock: + if self._forks: + # It may not be registered yet. + self._zombies[pid] = returncode + continue + + logger.warning( + "Caught subprocess termination from unknown pid: " + "%d -> %d", pid, returncode) + else: + callback(pid, returncode, *args) + + +class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + """XXX""" + _loop_factory = _UnixSelectorEventLoop + + def __init__(self): + super().__init__() + self._watcher = None + + def _init_watcher(self): + with events._lock: + if self._watcher is None: # pragma: no branch + if isinstance(threading.current_thread(), + threading._MainThread): + self._watcher = SafeChildWatcher(self._local._loop) + else: + self._watcher = SafeChildWatcher(None) + + def set_event_loop(self, loop): + """Set the event loop. + + As a side effect, if a child watcher was set before, then calling + .set_event_loop() from the main thread will call .set_loop(loop) on the + child watcher. + """ + + super().set_event_loop(loop) + + if self._watcher is not None and \ + isinstance(threading.current_thread(), threading._MainThread): + self._watcher.set_loop(loop) + + def get_child_watcher(self): + """Get the child watcher + + If not yet set, a SafeChildWatcher object is automatically created. + """ + if self._watcher is None: + self._init_watcher() + + return self._watcher + + def set_child_watcher(self, watcher): + """Set the child watcher""" + + assert watcher is None or isinstance(watcher, AbstractChildWatcher) + + if self._watcher is not None: + self._watcher.close() + + self._watcher = watcher + +SelectorEventLoop = _UnixSelectorEventLoop +DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index d7444bdf..ae3e44f4 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -7,6 +7,7 @@ import struct import _winapi +from . import events from . import base_subprocess from . import futures from . import proactor_events @@ -17,7 +18,9 @@ from . import _overlapped -__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor'] +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', + ] NULL = 0 @@ -108,7 +111,7 @@ def close(self): __del__ = close -class SelectorEventLoop(selector_events.BaseSelectorEventLoop): +class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Windows version of selector event loop.""" def _socketpair(self): @@ -453,3 +456,10 @@ def callback(f): f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) f.add_done_callback(callback) + + +class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = SelectorEventLoop + +SelectorEventLoop = _WindowsSelectorEventLoop +DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy diff --git a/tests/test_events.py b/tests/test_events.py index 4af9aa93..00bd4085 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1308,8 +1308,17 @@ def test_create_datagram_endpoint(self): from asyncio import selectors from asyncio import unix_events + class UnixEventLoopTestsMixin(EventLoopTestsMixin): + def setUp(self): + super().setUp() + events.set_child_watcher(unix_events.SafeChildWatcher(self.loop)) + + def tearDown(self): + events.set_child_watcher(None) + super().tearDown() + if hasattr(selectors, 'KqueueSelector'): - class KqueueEventLoopTests(EventLoopTestsMixin, + class KqueueEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): @@ -1318,7 +1327,7 @@ def create_event_loop(self): selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): - class EPollEventLoopTests(EventLoopTestsMixin, + class EPollEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): @@ -1326,7 +1335,7 @@ def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): - class PollEventLoopTests(EventLoopTestsMixin, + class PollEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): @@ -1334,7 +1343,7 @@ def create_event_loop(self): return unix_events.SelectorEventLoop(selectors.PollSelector()) # Should always exist. - class SelectEventLoopTests(EventLoopTestsMixin, + class SelectEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): @@ -1557,25 +1566,36 @@ def test_empty(self): class PolicyTests(unittest.TestCase): + def create_policy(self): + if sys.platform == "win32": + from asyncio import windows_events + return windows_events.DefaultEventLoopPolicy() + else: + from asyncio import unix_events + return unix_events.DefaultEventLoopPolicy() + def test_event_loop_policy(self): policy = events.AbstractEventLoopPolicy() self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.new_event_loop) + self.assertRaises(NotImplementedError, policy.get_child_watcher) + self.assertRaises(NotImplementedError, policy.set_child_watcher, + object()) def test_get_event_loop(self): - policy = events.DefaultEventLoopPolicy() - self.assertIsNone(policy._loop) + policy = self.create_policy() + self.assertIsNone(policy._local._loop) loop = policy.get_event_loop() self.assertIsInstance(loop, events.AbstractEventLoop) - self.assertIs(policy._loop, loop) + self.assertIs(policy._local._loop, loop) self.assertIs(loop, policy.get_event_loop()) loop.close() def test_get_event_loop_after_set_none(self): - policy = events.DefaultEventLoopPolicy() + policy = self.create_policy() policy.set_event_loop(None) self.assertRaises(AssertionError, policy.get_event_loop) @@ -1583,7 +1603,7 @@ def test_get_event_loop_after_set_none(self): def test_get_event_loop_thread(self, m_current_thread): def f(): - policy = events.DefaultEventLoopPolicy() + policy = self.create_policy() self.assertRaises(AssertionError, policy.get_event_loop) th = threading.Thread(target=f) @@ -1591,14 +1611,14 @@ def f(): th.join() def test_new_event_loop(self): - policy = events.DefaultEventLoopPolicy() + policy = self.create_policy() loop = policy.new_event_loop() self.assertIsInstance(loop, events.AbstractEventLoop) loop.close() def test_set_event_loop(self): - policy = events.DefaultEventLoopPolicy() + policy = self.create_policy() old_loop = policy.get_event_loop() self.assertRaises(AssertionError, policy.set_event_loop, object()) @@ -1621,7 +1641,7 @@ def test_set_event_loop_policy(self): old_policy = events.get_event_loop_policy() - policy = events.DefaultEventLoopPolicy() + policy = self.create_policy() events.set_event_loop_policy(policy) self.assertIs(policy, events.get_event_loop_policy()) self.assertIsNot(policy, old_policy) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index f29e7afe..a4d835e3 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -3,10 +3,12 @@ import gc import errno import io +import os import pprint import signal import stat import sys +import threading import unittest import unittest.mock @@ -181,124 +183,6 @@ class Err(OSError): self.assertRaises( RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): - m_waitpid.side_effect = [(7, object()), ChildProcessError] - m_WIFEXITED.return_value = True - m_WIFSIGNALED.return_value = False - m_WEXITSTATUS.return_value = 3 - transp = unittest.mock.Mock() - self.loop._subprocesses[7] = transp - - self.loop._sig_chld() - transp._process_exited.assert_called_with(3) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld_signal(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): - m_waitpid.side_effect = [(7, object()), ChildProcessError] - m_WIFEXITED.return_value = False - m_WIFSIGNALED.return_value = True - m_WTERMSIG.return_value = 1 - transp = unittest.mock.Mock() - self.loop._subprocesses[7] = transp - - self.loop._sig_chld() - transp._process_exited.assert_called_with(-1) - self.assertFalse(m_WEXITSTATUS.called) - - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld_zero_pid(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): - m_waitpid.side_effect = [(0, object()), ChildProcessError] - transp = unittest.mock.Mock() - self.loop._subprocesses[7] = transp - - self.loop._sig_chld() - self.assertFalse(transp._process_exited.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WTERMSIG.called) - self.assertFalse(m_WEXITSTATUS.called) - - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld_not_registered_subprocess(self, m_waitpid, - m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): - m_waitpid.side_effect = [(7, object()), ChildProcessError] - m_WIFEXITED.return_value = True - m_WIFSIGNALED.return_value = False - m_WEXITSTATUS.return_value = 3 - - self.loop._sig_chld() - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld_unknown_status(self, m_waitpid, - m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): - m_waitpid.side_effect = [(7, object()), ChildProcessError] - m_WIFEXITED.return_value = False - m_WIFSIGNALED.return_value = False - transp = unittest.mock.Mock() - self.loop._subprocesses[7] = transp - - self.loop._sig_chld() - self.assertTrue(transp._process_exited.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('asyncio.unix_events.logger') - @unittest.mock.patch('os.WTERMSIG') - @unittest.mock.patch('os.WEXITSTATUS') - @unittest.mock.patch('os.WIFSIGNALED') - @unittest.mock.patch('os.WIFEXITED') - @unittest.mock.patch('os.waitpid') - def test__sig_chld_unknown_status_in_handler(self, m_waitpid, - m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG, - m_log): - m_waitpid.side_effect = Exception - transp = unittest.mock.Mock() - self.loop._subprocesses[7] = transp - - self.loop._sig_chld() - self.assertFalse(transp._process_exited.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WTERMSIG.called) - self.assertFalse(m_WEXITSTATUS.called) - m_log.exception.assert_called_with( - 'Unknown exception in SIGCHLD handler') - - @unittest.mock.patch('os.waitpid') - def test__sig_chld_process_error(self, m_waitpid): - m_waitpid.side_effect = ChildProcessError - self.loop._sig_chld() - self.assertTrue(m_waitpid.called) - class UnixReadPipeTransportTests(unittest.TestCase): @@ -777,5 +661,872 @@ def test_write_eof_pending(self): self.assertFalse(self.protocol.connection_lost.called) +class AbstractChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + watcher = unix_events.AbstractChildWatcher() + self.assertRaises( + NotImplementedError, watcher.add_child_handler, f, f) + self.assertRaises( + NotImplementedError, watcher.remove_child_handler, f) + self.assertRaises( + NotImplementedError, watcher.set_loop, f) + self.assertRaises( + NotImplementedError, watcher.close) + self.assertRaises( + NotImplementedError, watcher.__enter__) + self.assertRaises( + NotImplementedError, watcher.__exit__, f, f, f) + + +class BaseChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = unittest.mock.Mock() + watcher = unix_events.BaseChildWatcher(None) + self.assertRaises( + NotImplementedError, watcher._do_waitpid, f) + + +class ChildWatcherTestsMixin: + instance = None + + ignore_warnings = unittest.mock.patch.object(unix_events.logger, "warning") + + def setUp(self): + self.loop = test_utils.TestLoop() + self.running = False + self.zombies = {} + + assert ChildWatcherTestsMixin.instance is None + ChildWatcherTestsMixin.instance = self + + with unittest.mock.patch.object( + self.loop, "add_signal_handler") as self.m_add_signal_handler: + self.watcher = self.create_watcher(self.loop) + + def tearDown(self): + ChildWatcherTestsMixin.instance = None + + def waitpid(pid, flags): + self = ChildWatcherTestsMixin.instance + if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1: + self.assertGreater(pid, 0) + try: + if pid < 0: + return self.zombies.popitem() + else: + return pid, self.zombies.pop(pid) + except KeyError: + pass + if self.running: + return 0, 0 + else: + raise ChildProcessError() + + def add_zombie(self, pid, returncode): + self.zombies[pid] = returncode + 32768 + + def WIFEXITED(status): + return status >= 32768 + + def WIFSIGNALED(status): + return 32700 < status < 32768 + + def WEXITSTATUS(status): + self = ChildWatcherTestsMixin.instance + self.assertTrue(type(self).WIFEXITED(status)) + return status - 32768 + + def WTERMSIG(status): + self = ChildWatcherTestsMixin.instance + self.assertTrue(type(self).WIFSIGNALED(status)) + return 32768 - status + + def test_create_watcher(self): + self.m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + # register a child + callback = unittest.mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(42, callback, 9, 10, 14) + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child is running + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child terminates (returncode 12) + self.running = False + self.add_zombie(42, 12) + self.watcher._sig_chld() + + self.assertTrue(m_WIFEXITED.called) + self.assertTrue(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + callback.assert_called_once_with(42, 12, 9, 10, 14) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WEXITSTATUS.reset_mock() + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(42, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG): + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(43, callback1, 7, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(44, callback2, 147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # childen are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child 1 terminates (signal 3) + self.add_zombie(43, -3) + self.watcher._sig_chld() + + callback1.assert_called_once_with(43, -3, 7, 8) + self.assertFalse(callback2.called) + self.assertTrue(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertTrue(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WTERMSIG.reset_mock() + callback1.reset_mock() + + # child 2 still running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child 2 terminates (code 108) + self.add_zombie(44, 108) + self.running = False + self.watcher._sig_chld() + + callback2.assert_called_once_with(44, 108, 147, 18) + self.assertFalse(callback1.called) + self.assertTrue(m_WIFEXITED.called) + self.assertTrue(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WEXITSTATUS.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(43, 14) + self.add_zombie(44, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_two_children_terminating_together( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(45, callback1, 17, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(46, callback2, 1147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # childen are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child 1 terminates (code 78) + # child 2 terminates (signal 5) + self.add_zombie(45, 78) + self.add_zombie(46, -5) + self.running = False + self.watcher._sig_chld() + + callback1.assert_called_once_with(45, 78, 17, 8) + callback2.assert_called_once_with(46, -5, 1147, 18) + self.assertTrue(m_WIFSIGNALED.called) + self.assertTrue(m_WIFEXITED.called) + self.assertTrue(m_WEXITSTATUS.called) + self.assertTrue(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WTERMSIG.reset_mock() + m_WEXITSTATUS.reset_mock() + callback1.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(45, 14) + self.add_zombie(46, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_race_condition( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + # register a child + callback = unittest.mock.Mock() + + with self.watcher: + # child terminates before being registered + self.add_zombie(50, 4) + self.watcher._sig_chld() + + self.watcher.add_child_handler(50, callback, 1, 12) + + callback.assert_called_once_with(50, 4, 1, 12) + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(50, -1) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_replace_handler( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(51, callback1, 19) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # register the same child again + with self.watcher: + self.watcher.add_child_handler(51, callback2, 21) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child terminates (signal 8) + self.running = False + self.add_zombie(51, -8) + self.watcher._sig_chld() + + callback2.assert_called_once_with(51, -8, 21) + self.assertFalse(callback1.called) + self.assertTrue(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertTrue(m_WTERMSIG.called) + + m_WIFSIGNALED.reset_mock() + m_WIFEXITED.reset_mock() + m_WTERMSIG.reset_mock() + callback2.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(51, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m_WTERMSIG.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_remove_handler(self, m_waitpid, m_WIFEXITED, + m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + callback = unittest.mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(52, callback, 1984) + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # unregister the child + self.watcher.remove_child_handler(52) + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child terminates (code 99) + self.running = False + self.add_zombie(52, 99) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_unknown_status(self, m_waitpid, m_WIFEXITED, + m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + callback = unittest.mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(53, callback, -19) + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # terminate with unknown status + self.zombies[53] = 1178 + self.running = False + self.watcher._sig_chld() + + callback.assert_called_once_with(53, 1178, -19) + self.assertTrue(m_WIFEXITED.called) + self.assertTrue(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + callback.reset_mock() + m_WIFEXITED.reset_mock() + m_WIFSIGNALED.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(53, 101) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_remove_child_handler(self, m_waitpid, m_WIFEXITED, + m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + callback3 = unittest.mock.Mock() + + # register children + with self.watcher: + self.running = True + self.watcher.add_child_handler(54, callback1, 1) + self.watcher.add_child_handler(55, callback2, 2) + self.watcher.add_child_handler(56, callback3, 3) + + # remove child handler 1 + self.assertTrue(self.watcher.remove_child_handler(54)) + + # remove child handler 2 multiple times + self.assertTrue(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + + # all children terminate + self.add_zombie(54, 0) + self.add_zombie(55, 1) + self.add_zombie(56, 2) + self.running = False + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(56, 2, 3) + + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_unhandled_exception(self, m_waitpid): + callback = unittest.mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(57, callback) + + # raise an exception + m_waitpid.side_effect = ValueError + + with unittest.mock.patch.object(unix_events.logger, + "exception") as m_exception: + + self.assertEqual(self.watcher._sig_chld(), None) + self.assertTrue(m_exception.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_child_reaped_elsewhere( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + + # register a child + callback = unittest.mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(58, callback) + + self.assertFalse(callback.called) + self.assertFalse(m_WIFEXITED.called) + self.assertFalse(m_WIFSIGNALED.called) + self.assertFalse(m_WEXITSTATUS.called) + self.assertFalse(m_WTERMSIG.called) + + # child terminates + self.running = False + self.add_zombie(58, 4) + + # waitpid is called elsewhere + os.waitpid(58, os.WNOHANG) + + m_waitpid.reset_mock() + + # sigchld + with self.ignore_warnings: + self.watcher._sig_chld() + + callback.assert_called(m_waitpid) + if isinstance(self.watcher, unix_events.FastChildWatcher): + # here the FastChildWatche enters a deadlock + # (there is no way to prevent it) + self.assertFalse(callback.called) + else: + callback.assert_called_once_with(58, 255) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_sigchld_unknown_pid_during_registration( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + + # register two children + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + with self.ignore_warnings, self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, 7) + # an unknown child terminates + self.add_zombie(593, 17) + + self.watcher._sig_chld() + + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) + + callback1.assert_called_once_with(591, 7) + self.assertFalse(callback2.called) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_set_loop( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + + # register a child + callback = unittest.mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(60, callback) + + # attach a new loop + old_loop = self.loop + self.loop = test_utils.TestLoop() + + with unittest.mock.patch.object( + old_loop, + "remove_signal_handler") as m_old_remove_signal_handler, \ + unittest.mock.patch.object( + self.loop, + "add_signal_handler") as m_new_add_signal_handler: + + self.watcher.set_loop(self.loop) + + m_old_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + m_new_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + # child terminates + self.running = False + self.add_zombie(60, 9) + self.watcher._sig_chld() + + callback.assert_called_once_with(60, 9) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_set_loop_race_condition( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + + # register 3 children + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + callback3 = unittest.mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(61, callback1) + self.watcher.add_child_handler(62, callback2) + self.watcher.add_child_handler(622, callback3) + + # detach the loop + old_loop = self.loop + self.loop = None + + with unittest.mock.patch.object( + old_loop, "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.set_loop(None) + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + + # child 1 & 2 terminate + self.add_zombie(61, 11) + self.add_zombie(62, -5) + + # SIGCHLD was not catched + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(callback3.called) + + # attach a new loop + self.loop = test_utils.TestLoop() + + with unittest.mock.patch.object( + self.loop, "add_signal_handler") as m_add_signal_handler: + + self.watcher.set_loop(self.loop) + + m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + callback1.assert_called_once_with(61, 11) # race condition! + callback2.assert_called_once_with(62, -5) # race condition! + self.assertFalse(callback3.called) + + callback1.reset_mock() + callback2.reset_mock() + + # child 3 terminates + self.running = False + self.add_zombie(622, 19) + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(622, 19) + + @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) + @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) + @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) + @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) + @unittest.mock.patch('os.waitpid', wraps=waitpid) + def test_close( + self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, + m_WTERMSIG): + + # register two children + callback1 = unittest.mock.Mock() + callback2 = unittest.mock.Mock() + + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(63, 9) + # other child terminates + self.add_zombie(65, 18) + self.watcher._sig_chld() + + self.watcher.add_child_handler(63, callback1) + self.watcher.add_child_handler(64, callback1) + + self.assertEqual(len(self.watcher._callbacks), 1) + if isinstance(self.watcher, unix_events.FastChildWatcher): + self.assertEqual(len(self.watcher._zombies), 1) + + with unittest.mock.patch.object( + self.loop, + "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.close() + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + self.assertFalse(self.watcher._callbacks) + if isinstance(self.watcher, unix_events.FastChildWatcher): + self.assertFalse(self.watcher._zombies) + + +class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): + def create_watcher(self, loop): + return unix_events.SafeChildWatcher(loop) + + +class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): + def create_watcher(self, loop): + return unix_events.FastChildWatcher(loop) + + +class PolicyTests(unittest.TestCase): + + def create_policy(self): + return unix_events.DefaultEventLoopPolicy() + + def test_get_child_watcher(self): + policy = self.create_policy() + self.assertIsNone(policy._watcher) + + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + + self.assertIs(policy._watcher, watcher) + + self.assertIs(watcher, policy.get_child_watcher()) + self.assertIsNone(watcher._loop) + + def test_get_child_watcher_after_set(self): + policy = self.create_policy() + watcher = unix_events.FastChildWatcher(None) + + policy.set_child_watcher(watcher) + self.assertIs(policy._watcher, watcher) + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_with_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + self.assertIsNone(policy._watcher) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + self.assertIs(watcher._loop, loop) + + loop.close() + + def test_get_child_watcher_thread(self): + + def f(): + policy.set_event_loop(policy.new_event_loop()) + + self.assertIsInstance(policy.get_event_loop(), + events.AbstractEventLoop) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + self.assertIsNone(watcher._loop) + + policy.get_event_loop().close() + + policy = self.create_policy() + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_child_watcher_replace_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + watcher = policy.get_child_watcher() + + self.assertIs(watcher._loop, loop) + + new_loop = policy.new_event_loop() + policy.set_event_loop(new_loop) + + self.assertIs(watcher._loop, new_loop) + + policy.set_event_loop(None) + + self.assertIs(watcher._loop, None) + + loop.close() + new_loop.close() + + if __name__ == '__main__': unittest.main() From f3ffb8eee496de4aa4050105d3d08accbe4db9c1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 4 Nov 2013 15:40:45 -0800 Subject: [PATCH 0774/1502] Fix policy refactoring last-minute change for Windows. --- asyncio/windows_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index ae3e44f4..64fe3861 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -458,8 +458,11 @@ def callback(f): f.add_done_callback(callback) +SelectorEventLoop = _WindowsSelectorEventLoop + + class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): _loop_factory = SelectorEventLoop -SelectorEventLoop = _WindowsSelectorEventLoop + DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy From dfca215c884d31319ddcedeff78d97a6e04a338c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 6 Nov 2013 20:24:10 -0800 Subject: [PATCH 0775/1502] Add close() back to Unix selector event loop, to remove all signal handlers. --- asyncio/unix_events.py | 5 +++++ tests/test_unix_events.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index dd57fe8e..f4cf6e7f 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -48,6 +48,11 @@ def __init__(self, selector=None): def _socketpair(self): return socket.socketpair() + def close(self): + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + super().close() + def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index a4d835e3..42eba8d6 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -183,6 +183,22 @@ class Err(OSError): self.assertRaises( RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + @unittest.mock.patch('asyncio.unix_events.signal') + def test_close(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGCHLD, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 2) + + m_signal.set_wakeup_fd.reset_mock() + + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + m_signal.set_wakeup_fd.assert_called_once_with(-1) + class UnixReadPipeTransportTests(unittest.TestCase): From f481dd128154d21ea716d66c9ca185441c231e25 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 7 Nov 2013 08:41:36 -0800 Subject: [PATCH 0776/1502] Optimize BaseSelector.modify(). Patch by Arnaud Faure (without tests). --- asyncio/selectors.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 3638e854..3971502e 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -138,11 +138,14 @@ def modify(self, fileobj, events, data=None): key = self._fd_to_key[_fileobj_to_fd(fileobj)] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None - if events != key.events or data != key.data: - # TODO: If only the data changed, use a shortcut that only - # updates the data. + if events != key.events: self.unregister(fileobj) return self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key else: return key From 39fd5e17c28278aae732b3aeb010f18b67de1417 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 7 Nov 2013 09:20:46 -0800 Subject: [PATCH 0777/1502] Improved tests for selectors.py by Arnaud Faure. --- tests/test_selectors.py | 67 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 66 insertions(+), 1 deletion(-) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index db5b3ece..78012897 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -13,8 +13,54 @@ def select(self, timeout=None): raise NotImplementedError -class BaseSelectorTests(unittest.TestCase): +class _SelectorMappingTests(unittest.TestCase): + + def test_len(self): + s = FakeSelector() + map = selectors._SelectorMapping(s) + self.assertTrue(map.__len__() == 0) + + f = unittest.mock.Mock() + f.fileno.return_value = 10 + s.register(f, selectors.EVENT_READ, None) + self.assertTrue(len(map) == 1) + + def test_getitem(self): + s = FakeSelector() + map = selectors._SelectorMapping(s) + f = unittest.mock.Mock() + f.fileno.return_value = 10 + s.register(f, selectors.EVENT_READ, None) + attended = selectors.SelectorKey(f, 10, selectors.EVENT_READ, None) + self.assertEqual(attended, map.__getitem__(f)) + + def test_getitem_key_error(self): + s = FakeSelector() + map = selectors._SelectorMapping(s) + self.assertTrue(len(map) == 0) + f = unittest.mock.Mock() + f.fileno.return_value = 10 + s.register(f, selectors.EVENT_READ, None) + self.assertRaises(KeyError, map.__getitem__, 5) + + def test_iter(self): + s = FakeSelector() + map = selectors._SelectorMapping(s) + self.assertTrue(len(map) == 0) + f = unittest.mock.Mock() + f.fileno.return_value = 5 + s.register(f, selectors.EVENT_READ, None) + counter = 0 + for fileno in map.__iter__(): + self.assertEqual(5, fileno) + counter += 1 + for idx in map: + self.assertEqual(f, map[idx].fileobj) + self.assertEqual(1, counter) + + +class BaseSelectorTests(unittest.TestCase): def test_fileobj_to_fd(self): self.assertEqual(10, selectors._fileobj_to_fd(10)) @@ -25,6 +71,9 @@ def test_fileobj_to_fd(self): f.fileno.side_effect = AttributeError self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + f.fileno.return_value = -1 + self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + def test_selector_key_repr(self): key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None) self.assertEqual( @@ -103,6 +152,22 @@ def test_modify_data(self): selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2), s.get_key(fobj)) + def test_modify_data_use_a_shortcut(self): + fobj = unittest.mock.Mock() + fobj.fileno.return_value = 10 + + d1 = object() + d2 = object() + + s = FakeSelector() + key = s.register(fobj, selectors.EVENT_READ, d1) + + s.unregister = unittest.mock.Mock() + s.register = unittest.mock.Mock() + key2 = s.modify(fobj, selectors.EVENT_READ, d2) + self.assertFalse(s.unregister.called) + self.assertFalse(s.register.called) + def test_modify_same(self): fobj = unittest.mock.Mock() fobj.fileno.return_value = 10 From 0e86c1942368310b02aaff834025a96b84d45bc0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 7 Nov 2013 09:26:32 -0800 Subject: [PATCH 0778/1502] Skip test_selectors.py when copying to stdlib. --- update_stdlib.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/update_stdlib.sh b/update_stdlib.sh index ceaa8f54..828de2c8 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -49,6 +49,10 @@ done for i in `(cd tests && ls *.py sample.???)` do + if [ $i == test_selectors.py ] + then + continue + fi maybe_copy tests/$i Lib/test/test_asyncio/$i done From 219dec2cc05074c240659d54060e690f6ebb5afc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 Nov 2013 15:47:40 -0800 Subject: [PATCH 0779/1502] Fix from Anthony Baire for CPython issue 19566. --- asyncio/unix_events.py | 68 +++++++++++++++++++++++---------------- tests/test_events.py | 4 ++- tests/test_unix_events.py | 29 +++++++++-------- 3 files changed, 60 insertions(+), 41 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index f4cf6e7f..b611efd1 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -440,10 +440,13 @@ def remove_child_handler(self, pid): raise NotImplementedError() - def set_loop(self, loop): - """Reattach the watcher to another event loop. + def attach_loop(self, loop): + """Attach the watcher to an event loop. - Note: loop may be None + If the watcher was previously attached to an event loop, then it is + first detached before attaching to the new loop. + + Note: loop may be None. """ raise NotImplementedError() @@ -467,15 +470,11 @@ def __exit__(self, a, b, c): class BaseChildWatcher(AbstractChildWatcher): - def __init__(self, loop): + def __init__(self): self._loop = None - self._callbacks = {} - - self.set_loop(loop) def close(self): - self.set_loop(None) - self._callbacks.clear() + self.attach_loop(None) def _do_waitpid(self, expected_pid): raise NotImplementedError() @@ -483,7 +482,7 @@ def _do_waitpid(self, expected_pid): def _do_waitpid_all(self): raise NotImplementedError() - def set_loop(self, loop): + def attach_loop(self, loop): assert loop is None or isinstance(loop, events.AbstractEventLoop) if self._loop is not None: @@ -497,13 +496,6 @@ def set_loop(self, loop): # during the switch. self._do_waitpid_all() - def remove_child_handler(self, pid): - try: - del self._callbacks[pid] - return True - except KeyError: - return False - def _sig_chld(self): try: self._do_waitpid_all() @@ -535,6 +527,14 @@ class SafeChildWatcher(BaseChildWatcher): big number of children (O(n) each time SIGCHLD is raised) """ + def __init__(self): + super().__init__() + self._callbacks = {} + + def close(self): + self._callbacks.clear() + super().close() + def __enter__(self): return self @@ -547,6 +547,13 @@ def add_child_handler(self, pid, callback, *args): # Prevent a race condition in case the child is already terminated. self._do_waitpid(pid) + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + def _do_waitpid_all(self): for pid in list(self._callbacks): @@ -592,16 +599,17 @@ class FastChildWatcher(BaseChildWatcher): There is no noticeable overhead when handling a big number of children (O(1) each time a child terminates). """ - def __init__(self, loop): - super().__init__(loop) - + def __init__(self): + super().__init__() + self._callbacks = {} self._lock = threading.Lock() self._zombies = {} self._forks = 0 def close(self): - super().close() + self._callbacks.clear() self._zombies.clear() + super().close() def __enter__(self): with self._lock: @@ -642,6 +650,13 @@ def add_child_handler(self, pid, callback, *args): else: callback(pid, returncode, *args) + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + def _do_waitpid_all(self): # Because of signal coalescing, we must keep calling waitpid() as # long as we're able to reap a child. @@ -686,25 +701,24 @@ def __init__(self): def _init_watcher(self): with events._lock: if self._watcher is None: # pragma: no branch + self._watcher = SafeChildWatcher() if isinstance(threading.current_thread(), threading._MainThread): - self._watcher = SafeChildWatcher(self._local._loop) - else: - self._watcher = SafeChildWatcher(None) + self._watcher.attach_loop(self._local._loop) def set_event_loop(self, loop): """Set the event loop. As a side effect, if a child watcher was set before, then calling - .set_event_loop() from the main thread will call .set_loop(loop) on the - child watcher. + .set_event_loop() from the main thread will call .attach_loop(loop) on + the child watcher. """ super().set_event_loop(loop) if self._watcher is not None and \ isinstance(threading.current_thread(), threading._MainThread): - self._watcher.set_loop(loop) + self._watcher.attach_loop(loop) def get_child_watcher(self): """Get the child watcher diff --git a/tests/test_events.py b/tests/test_events.py index 00bd4085..7b9839ce 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1311,7 +1311,9 @@ def test_create_datagram_endpoint(self): class UnixEventLoopTestsMixin(EventLoopTestsMixin): def setUp(self): super().setUp() - events.set_child_watcher(unix_events.SafeChildWatcher(self.loop)) + watcher = unix_events.SafeChildWatcher() + watcher.attach_loop(self.loop) + events.set_child_watcher(watcher) def tearDown(self): events.set_child_watcher(None) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 42eba8d6..ea1c08cf 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -687,7 +687,7 @@ def test_not_implemented(self): self.assertRaises( NotImplementedError, watcher.remove_child_handler, f) self.assertRaises( - NotImplementedError, watcher.set_loop, f) + NotImplementedError, watcher.attach_loop, f) self.assertRaises( NotImplementedError, watcher.close) self.assertRaises( @@ -700,7 +700,7 @@ class BaseChildWatcherTests(unittest.TestCase): def test_not_implemented(self): f = unittest.mock.Mock() - watcher = unix_events.BaseChildWatcher(None) + watcher = unix_events.BaseChildWatcher() self.assertRaises( NotImplementedError, watcher._do_waitpid, f) @@ -720,10 +720,13 @@ def setUp(self): with unittest.mock.patch.object( self.loop, "add_signal_handler") as self.m_add_signal_handler: - self.watcher = self.create_watcher(self.loop) + self.watcher = self.create_watcher() + self.watcher.attach_loop(self.loop) - def tearDown(self): - ChildWatcherTestsMixin.instance = None + def cleanup(): + ChildWatcherTestsMixin.instance = None + + self.addCleanup(cleanup) def waitpid(pid, flags): self = ChildWatcherTestsMixin.instance @@ -1334,7 +1337,7 @@ def test_set_loop( self.loop, "add_signal_handler") as m_new_add_signal_handler: - self.watcher.set_loop(self.loop) + self.watcher.attach_loop(self.loop) m_old_remove_signal_handler.assert_called_once_with( signal.SIGCHLD) @@ -1375,7 +1378,7 @@ def test_set_loop_race_condition( with unittest.mock.patch.object( old_loop, "remove_signal_handler") as m_remove_signal_handler: - self.watcher.set_loop(None) + self.watcher.attach_loop(None) m_remove_signal_handler.assert_called_once_with( signal.SIGCHLD) @@ -1395,7 +1398,7 @@ def test_set_loop_race_condition( with unittest.mock.patch.object( self.loop, "add_signal_handler") as m_add_signal_handler: - self.watcher.set_loop(self.loop) + self.watcher.attach_loop(self.loop) m_add_signal_handler.assert_called_once_with( signal.SIGCHLD, self.watcher._sig_chld) @@ -1457,13 +1460,13 @@ def test_close( class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): - def create_watcher(self, loop): - return unix_events.SafeChildWatcher(loop) + def create_watcher(self): + return unix_events.SafeChildWatcher() class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): - def create_watcher(self, loop): - return unix_events.FastChildWatcher(loop) + def create_watcher(self): + return unix_events.FastChildWatcher() class PolicyTests(unittest.TestCase): @@ -1485,7 +1488,7 @@ def test_get_child_watcher(self): def test_get_child_watcher_after_set(self): policy = self.create_policy() - watcher = unix_events.FastChildWatcher(None) + watcher = unix_events.FastChildWatcher() policy.set_child_watcher(watcher) self.assertIs(policy._watcher, watcher) From 34b79e8fb1ebf8f48596d20c8df791acddae53d9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 13 Nov 2013 20:16:49 -0800 Subject: [PATCH 0780/1502] Relax timing requirement. Fixes CPython issue 19579. --- tests/test_base_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9b883c52..3c4d52e6 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -170,7 +170,7 @@ def cb(): f.cancel() # Don't complain about abandoned Future. def test__run_once(self): - h1 = events.TimerHandle(time.monotonic() + 0.1, lambda: True, ()) + h1 = events.TimerHandle(time.monotonic() + 5.0, lambda: True, ()) h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) h1.cancel() @@ -181,7 +181,7 @@ def test__run_once(self): self.loop._run_once() t = self.loop._selector.select.call_args[0][0] - self.assertTrue(9.99 < t < 10.1, t) + self.assertTrue(9.9 < t < 10.1, t) self.assertEqual([h2], self.loop._scheduled) self.assertTrue(self.loop._process_events.called) From c3a40f341972fe6454258bc1de0a9f57fb36be88 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 Nov 2013 10:05:09 -0800 Subject: [PATCH 0781/1502] Avoid ResourceWarning. Fix for CPython issue 19580 by Vajrasky Kok. --- tests/test_base_events.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 3c4d52e6..5178d543 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -458,16 +458,26 @@ def mock_getaddrinfo(*args, **kwds): self.loop.sock_connect.return_value = () self.loop._make_ssl_transport = unittest.mock.Mock() + class _SelectorTransportMock: + _sock = None + + def close(self): + self._sock.close() + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, **kwds): waiter.set_result(None) + transport = _SelectorTransportMock() + transport._sock = sock + return transport self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport ANY = unittest.mock.ANY # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) - self.loop.run_until_complete(coro) + transport, _ = self.loop.run_until_complete(coro) + transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, @@ -476,7 +486,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, server_hostname='perl.com') - self.loop.run_until_complete(coro) + transport, _ = self.loop.run_until_complete(coro) + transport.close() self.loop._make_ssl_transport.assert_called_with( ANY, ANY, ANY, ANY, server_side=False, @@ -485,7 +496,8 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, server_hostname='') - self.loop.run_until_complete(coro) + transport, _ = self.loop.run_until_complete(coro) + transport.close() self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, server_side=False, server_hostname='') @@ -505,8 +517,10 @@ def test_create_connection_ssl_server_hostname_errors(self): self.assertRaises(ValueError, self.loop.run_until_complete, coro) coro = self.loop.create_connection(MyProto, None, 80, ssl=True) self.assertRaises(ValueError, self.loop.run_until_complete, coro) + sock = socket.socket() coro = self.loop.create_connection(MyProto, None, None, - ssl=True, sock=socket.socket()) + ssl=True, sock=sock) + self.addCleanup(sock.close) self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_server_empty_host(self): From bb6916ad9456ffcc9dc831b62879e6fe4dab2488 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 Nov 2013 13:31:02 -0800 Subject: [PATCH 0782/1502] Use more specific asserts for some tests. From CPython issue 19589 by Serhiy Storchaka. --- tests/test_events.py | 18 ++++++++++-------- tests/test_tasks.py | 2 +- tests/test_unix_events.py | 2 +- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 7b9839ce..2338546a 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -472,8 +472,8 @@ def test_create_connection(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) tr, pr = self.loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertIsInstance(tr, transports.Transport) + self.assertIsInstance(pr, protocols.Protocol) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -500,8 +500,8 @@ def test_create_connection_sock(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), sock=sock) tr, pr = self.loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertIsInstance(tr, transports.Transport) + self.assertIsInstance(pr, protocols.Protocol) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -513,8 +513,8 @@ def test_create_ssl_connection(self): lambda: MyProto(loop=self.loop), *httpd.address, ssl=test_utils.dummy_ssl_context()) tr, pr = self.loop.run_until_complete(f) - self.assertTrue(isinstance(tr, transports.Transport)) - self.assertTrue(isinstance(pr, protocols.Protocol)) + self.assertIsInstance(tr, transports.Transport) + self.assertIsInstance(pr, protocols.Protocol) self.assertTrue('ssl' in tr.__class__.__name__.lower()) self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) @@ -926,7 +926,8 @@ def test_prompt_cancellation(self): r.setblocking(False) f = self.loop.sock_recv(r, 1) ov = getattr(f, 'ov', None) - self.assertTrue(ov is None or ov.pending) + if ov is not None: + self.assertTrue(ov.pending) @tasks.coroutine def main(): @@ -949,7 +950,8 @@ def main(): self.assertLess(elapsed, 0.1) self.assertEqual(t.result(), 'cancelled') self.assertRaises(futures.CancelledError, f.result) - self.assertTrue(ov is None or not ov.pending) + if ov is not None: + self.assertFalse(ov.pending) self.loop._stop_serving(r) r.close() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 57fb0537..8f0d0815 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -779,7 +779,7 @@ def foo(): self.assertEqual(len(res), 2, res) self.assertEqual(res[0], (1, 'a')) self.assertEqual(res[1][0], 2) - self.assertTrue(isinstance(res[1][1], futures.TimeoutError)) + self.assertIsInstance(res[1][1], futures.TimeoutError) self.assertAlmostEqual(0.12, loop.time()) # move forward to close generator diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index ea1c08cf..af86be19 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -67,7 +67,7 @@ def test_add_signal_handler(self, m_signal): cb = lambda: True self.loop.add_signal_handler(signal.SIGHUP, cb) h = self.loop._signal_handlers.get(signal.SIGHUP) - self.assertTrue(isinstance(h, events.Handle)) + self.assertIsInstance(h, events.Handle) self.assertEqual(h._callback, cb) @unittest.mock.patch('asyncio.unix_events.signal') From a8ff6d044589d6be10246d078519516212aec2dc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 14 Nov 2013 16:10:08 -0800 Subject: [PATCH 0783/1502] Refactor waitpid mocks. Patch by Anthony Baire. --- tests/test_unix_events.py | 435 ++++++++++++++++---------------------- 1 file changed, 186 insertions(+), 249 deletions(-) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index af86be19..fdd90495 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1,5 +1,6 @@ """Tests for unix_events.py.""" +import collections import gc import errno import io @@ -705,8 +706,16 @@ def test_not_implemented(self): NotImplementedError, watcher._do_waitpid, f) +WaitPidMocks = collections.namedtuple("WaitPidMocks", + ("waitpid", + "WIFEXITED", + "WIFSIGNALED", + "WEXITSTATUS", + "WTERMSIG", + )) + + class ChildWatcherTestsMixin: - instance = None ignore_warnings = unittest.mock.patch.object(unix_events.logger, "warning") @@ -715,21 +724,12 @@ def setUp(self): self.running = False self.zombies = {} - assert ChildWatcherTestsMixin.instance is None - ChildWatcherTestsMixin.instance = self - with unittest.mock.patch.object( self.loop, "add_signal_handler") as self.m_add_signal_handler: self.watcher = self.create_watcher() self.watcher.attach_loop(self.loop) - def cleanup(): - ChildWatcherTestsMixin.instance = None - - self.addCleanup(cleanup) - - def waitpid(pid, flags): - self = ChildWatcherTestsMixin.instance + def waitpid(self, pid, flags): if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1: self.assertGreater(pid, 0) try: @@ -747,33 +747,43 @@ def waitpid(pid, flags): def add_zombie(self, pid, returncode): self.zombies[pid] = returncode + 32768 - def WIFEXITED(status): + def WIFEXITED(self, status): return status >= 32768 - def WIFSIGNALED(status): + def WIFSIGNALED(self, status): return 32700 < status < 32768 - def WEXITSTATUS(status): - self = ChildWatcherTestsMixin.instance - self.assertTrue(type(self).WIFEXITED(status)) + def WEXITSTATUS(self, status): + self.assertTrue(self.WIFEXITED(status)) return status - 32768 - def WTERMSIG(status): - self = ChildWatcherTestsMixin.instance - self.assertTrue(type(self).WIFSIGNALED(status)) + def WTERMSIG(self, status): + self.assertTrue(self.WIFSIGNALED(status)) return 32768 - status def test_create_watcher(self): self.m_add_signal_handler.assert_called_once_with( signal.SIGCHLD, self.watcher._sig_chld) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): + def waitpid_mocks(func): + def wrapped_func(self): + def patch(target, wrapper): + return unittest.mock.patch(target, wraps=wrapper, + new_callable=unittest.mock.Mock) + + with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ + patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ + patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \ + patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \ + patch('os.waitpid', self.waitpid) as m_waitpid: + func(self, WaitPidMocks(m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + )) + return wrapped_func + + @waitpid_mocks + def test_sigchld(self, m): # register a child callback = unittest.mock.Mock() @@ -782,33 +792,33 @@ def test_sigchld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.watcher.add_child_handler(42, callback, 9, 10, 14) self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child is running self.watcher._sig_chld() self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child terminates (returncode 12) self.running = False self.add_zombie(42, 12) self.watcher._sig_chld() - self.assertTrue(m_WIFEXITED.called) - self.assertTrue(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) callback.assert_called_once_with(42, 12, 9, 10, 14) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WEXITSTATUS.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() callback.reset_mock() # ensure that the child is effectively reaped @@ -817,29 +827,24 @@ def test_sigchld(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.watcher._sig_chld() self.assertFalse(callback.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WTERMSIG.called) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WEXITSTATUS.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() # sigchld called again self.zombies.clear() self.watcher._sig_chld() self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, - m_WEXITSTATUS, m_WTERMSIG): + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children(self, m): callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() @@ -850,10 +855,10 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # register child 2 with self.watcher: @@ -861,20 +866,20 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # childen are running self.watcher._sig_chld() self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child 1 terminates (signal 3) self.add_zombie(43, -3) @@ -882,13 +887,13 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, callback1.assert_called_once_with(43, -3, 7, 8) self.assertFalse(callback2.called) - self.assertTrue(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertTrue(m_WTERMSIG.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WTERMSIG.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() callback1.reset_mock() # child 2 still running @@ -896,10 +901,10 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child 2 terminates (code 108) self.add_zombie(44, 108) @@ -908,13 +913,13 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, callback2.assert_called_once_with(44, 108, 147, 18) self.assertFalse(callback1.called) - self.assertTrue(m_WIFEXITED.called) - self.assertTrue(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WEXITSTATUS.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() callback2.reset_mock() # ensure that the children are effectively reaped @@ -925,11 +930,11 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WTERMSIG.called) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WEXITSTATUS.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() # sigchld called again self.zombies.clear() @@ -937,19 +942,13 @@ def test_sigchld_two_children(self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_two_children_terminating_together( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children_terminating_together(self, m): callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() @@ -960,10 +959,10 @@ def test_sigchld_two_children_terminating_together( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # register child 2 with self.watcher: @@ -971,20 +970,20 @@ def test_sigchld_two_children_terminating_together( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # childen are running self.watcher._sig_chld() self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child 1 terminates (code 78) # child 2 terminates (signal 5) @@ -995,15 +994,15 @@ def test_sigchld_two_children_terminating_together( callback1.assert_called_once_with(45, 78, 17, 8) callback2.assert_called_once_with(46, -5, 1147, 18) - self.assertTrue(m_WIFSIGNALED.called) - self.assertTrue(m_WIFEXITED.called) - self.assertTrue(m_WEXITSTATUS.called) - self.assertTrue(m_WTERMSIG.called) - - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WTERMSIG.reset_mock() - m_WEXITSTATUS.reset_mock() + self.assertTrue(m.WIFSIGNALED.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + m.WEXITSTATUS.reset_mock() callback1.reset_mock() callback2.reset_mock() @@ -1015,16 +1014,10 @@ def test_sigchld_two_children_terminating_together( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_race_condition( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_race_condition(self, m): # register a child callback = unittest.mock.Mock() @@ -1045,14 +1038,8 @@ def test_sigchld_race_condition( self.assertFalse(callback.called) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_replace_handler( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): + @waitpid_mocks + def test_sigchld_replace_handler(self, m): callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() @@ -1063,10 +1050,10 @@ def test_sigchld_replace_handler( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # register the same child again with self.watcher: @@ -1074,10 +1061,10 @@ def test_sigchld_replace_handler( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child terminates (signal 8) self.running = False @@ -1086,13 +1073,13 @@ def test_sigchld_replace_handler( callback2.assert_called_once_with(51, -8, 21) self.assertFalse(callback1.called) - self.assertTrue(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertTrue(m_WTERMSIG.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) - m_WIFSIGNALED.reset_mock() - m_WIFEXITED.reset_mock() - m_WTERMSIG.reset_mock() + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() callback2.reset_mock() # ensure that the child is effectively reaped @@ -1102,15 +1089,10 @@ def test_sigchld_replace_handler( self.assertFalse(callback1.called) self.assertFalse(callback2.called) - self.assertFalse(m_WTERMSIG.called) - - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_remove_handler(self, m_waitpid, m_WIFEXITED, - m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_remove_handler(self, m): callback = unittest.mock.Mock() # register a child @@ -1119,19 +1101,19 @@ def test_sigchld_remove_handler(self, m_waitpid, m_WIFEXITED, self.watcher.add_child_handler(52, callback, 1984) self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # unregister the child self.watcher.remove_child_handler(52) self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child terminates (code 99) self.running = False @@ -1141,13 +1123,8 @@ def test_sigchld_remove_handler(self, m_waitpid, m_WIFEXITED, self.assertFalse(callback.called) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_unknown_status(self, m_waitpid, m_WIFEXITED, - m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + @waitpid_mocks + def test_sigchld_unknown_status(self, m): callback = unittest.mock.Mock() # register a child @@ -1156,10 +1133,10 @@ def test_sigchld_unknown_status(self, m_waitpid, m_WIFEXITED, self.watcher.add_child_handler(53, callback, -19) self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # terminate with unknown status self.zombies[53] = 1178 @@ -1167,14 +1144,14 @@ def test_sigchld_unknown_status(self, m_waitpid, m_WIFEXITED, self.watcher._sig_chld() callback.assert_called_once_with(53, 1178, -19) - self.assertTrue(m_WIFEXITED.called) - self.assertTrue(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) callback.reset_mock() - m_WIFEXITED.reset_mock() - m_WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WIFSIGNALED.reset_mock() # ensure that the child is effectively reaped self.add_zombie(53, 101) @@ -1183,13 +1160,8 @@ def test_sigchld_unknown_status(self, m_waitpid, m_WIFEXITED, self.assertFalse(callback.called) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_remove_child_handler(self, m_waitpid, m_WIFEXITED, - m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG): + @waitpid_mocks + def test_remove_child_handler(self, m): callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() callback3 = unittest.mock.Mock() @@ -1221,8 +1193,8 @@ def test_remove_child_handler(self, m_waitpid, m_WIFEXITED, self.assertFalse(callback2.called) callback3.assert_called_once_with(56, 2, 3) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_unhandled_exception(self, m_waitpid): + @waitpid_mocks + def test_sigchld_unhandled_exception(self, m): callback = unittest.mock.Mock() # register a child @@ -1231,7 +1203,7 @@ def test_sigchld_unhandled_exception(self, m_waitpid): self.watcher.add_child_handler(57, callback) # raise an exception - m_waitpid.side_effect = ValueError + m.waitpid.side_effect = ValueError with unittest.mock.patch.object(unix_events.logger, "exception") as m_exception: @@ -1239,15 +1211,8 @@ def test_sigchld_unhandled_exception(self, m_waitpid): self.assertEqual(self.watcher._sig_chld(), None) self.assertTrue(m_exception.called) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_child_reaped_elsewhere( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): - + @waitpid_mocks + def test_sigchld_child_reaped_elsewhere(self, m): # register a child callback = unittest.mock.Mock() @@ -1256,10 +1221,10 @@ def test_sigchld_child_reaped_elsewhere( self.watcher.add_child_handler(58, callback) self.assertFalse(callback.called) - self.assertFalse(m_WIFEXITED.called) - self.assertFalse(m_WIFSIGNALED.called) - self.assertFalse(m_WEXITSTATUS.called) - self.assertFalse(m_WTERMSIG.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) # child terminates self.running = False @@ -1268,13 +1233,13 @@ def test_sigchld_child_reaped_elsewhere( # waitpid is called elsewhere os.waitpid(58, os.WNOHANG) - m_waitpid.reset_mock() + m.waitpid.reset_mock() # sigchld with self.ignore_warnings: self.watcher._sig_chld() - callback.assert_called(m_waitpid) + callback.assert_called(m.waitpid) if isinstance(self.watcher, unix_events.FastChildWatcher): # here the FastChildWatche enters a deadlock # (there is no way to prevent it) @@ -1282,15 +1247,8 @@ def test_sigchld_child_reaped_elsewhere( else: callback.assert_called_once_with(58, 255) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_sigchld_unknown_pid_during_registration( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): - + @waitpid_mocks + def test_sigchld_unknown_pid_during_registration(self, m): # register two children callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() @@ -1310,15 +1268,8 @@ def test_sigchld_unknown_pid_during_registration( callback1.assert_called_once_with(591, 7) self.assertFalse(callback2.called) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_set_loop( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): - + @waitpid_mocks + def test_set_loop(self, m): # register a child callback = unittest.mock.Mock() @@ -1351,15 +1302,8 @@ def test_set_loop( callback.assert_called_once_with(60, 9) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_set_loop_race_condition( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): - + @waitpid_mocks + def test_set_loop_race_condition(self, m): # register 3 children callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() @@ -1418,15 +1362,8 @@ def test_set_loop_race_condition( self.assertFalse(callback2.called) callback3.assert_called_once_with(622, 19) - @unittest.mock.patch('os.WTERMSIG', wraps=WTERMSIG) - @unittest.mock.patch('os.WEXITSTATUS', wraps=WEXITSTATUS) - @unittest.mock.patch('os.WIFSIGNALED', wraps=WIFSIGNALED) - @unittest.mock.patch('os.WIFEXITED', wraps=WIFEXITED) - @unittest.mock.patch('os.waitpid', wraps=waitpid) - def test_close( - self, m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, - m_WTERMSIG): - + @waitpid_mocks + def test_close(self, m): # register two children callback1 = unittest.mock.Mock() callback2 = unittest.mock.Mock() From 090cc43507031ce7f3005cba51de634f6d6321b8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 Nov 2013 07:39:29 -0800 Subject: [PATCH 0784/1502] Increase timeout in test_popen() for buildbots. --- tests/test_windows_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index e013fbdd..fa9d66c0 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -119,7 +119,8 @@ def test_popen(self): overr.ReadFile(p.stderr.handle, 100) events = [ovin.event, ovout.event, overr.event] - res = _winapi.WaitForMultipleObjects(events, True, 2000) + # Super-long timeout for slow buildbots. + res = _winapi.WaitForMultipleObjects(events, True, 10000) self.assertEqual(res, _winapi.WAIT_OBJECT_0) self.assertFalse(ovout.pending) self.assertFalse(overr.pending) From 97aff9a0623262ab56f33657a7812a082a270f61 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 15 Nov 2013 16:39:13 -0800 Subject: [PATCH 0785/1502] Generalized error handling callback for DatastoreProtocol. --- asyncio/protocols.py | 12 +++++++----- asyncio/selector_events.py | 17 ++++++----------- examples/udp_echo.py | 8 ++++---- tests/test_base_events.py | 2 +- tests/test_events.py | 4 ++-- tests/test_selector_events.py | 33 +++++++++++++++++++++++---------- 6 files changed, 43 insertions(+), 33 deletions(-) diff --git a/asyncio/protocols.py b/asyncio/protocols.py index d3a86859..eb94fb6f 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -100,15 +100,18 @@ class DatagramProtocol(BaseProtocol): def datagram_received(self, data, addr): """Called when some datagram is received.""" - def connection_refused(self, exc): - """Connection is refused.""" + def error_received(self, exc): + """Called when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ class SubprocessProtocol(BaseProtocol): """ABC representing a protocol for subprocess calls.""" def pipe_data_received(self, fd, data): - """Called when subprocess write a data into stdout/stderr pipes. + """Called when the subprocess writes data into stdout/stderr pipe. fd is int file dascriptor. data is bytes object. @@ -122,5 +125,4 @@ def pipe_connection_lost(self, fd, exc): """ def process_exited(self): - """Called when subprocess has exited. - """ + """Called when subprocess has exited.""" diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 3bad1980..3efa4d2a 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -771,6 +771,8 @@ def _read_ready(self): data, addr = self._sock.recvfrom(self.max_size) except (BlockingIOError, InterruptedError): pass + except OSError as exc: + self._protocol.error_received(exc) except Exception as exc: self._fatal_error(exc) else: @@ -800,9 +802,8 @@ def sendto(self, data, addr=None): return except (BlockingIOError, InterruptedError): self._loop.add_writer(self._sock_fd, self._sendto_ready) - except ConnectionRefusedError as exc: - if self._address: - self._fatal_error(exc) + except OSError as exc: + self._protocol.error_received(exc) return except Exception as exc: self._fatal_error(exc) @@ -822,9 +823,8 @@ def _sendto_ready(self): except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break - except ConnectionRefusedError as exc: - if self._address: - self._fatal_error(exc) + except OSError as exc: + self._protocol.error_received(exc) return except Exception as exc: self._fatal_error(exc) @@ -835,8 +835,3 @@ def _sendto_ready(self): self._loop.remove_writer(self._sock_fd) if self._closing: self._call_connection_lost(None) - - def _force_close(self, exc): - if self._address and isinstance(exc, ConnectionRefusedError): - self._protocol.connection_refused(exc) - super()._force_close(exc) diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 8e95d292..e958f385 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -19,8 +19,8 @@ def datagram_received(self, data, addr): print('Data received:', data, addr) self.transport.sendto(data, addr) - def connection_refused(self, exc): - print('Connection refused:', exc) + def error_received(self, exc): + print('Error received:', exc) def connection_lost(self, exc): print('stop', exc) @@ -40,8 +40,8 @@ def datagram_received(self, data, addr): print('received "{}"'.format(data.decode())) self.transport.close() - def connection_refused(self, exc): - print('Connection refused:', exc) + def error_received(self, exc): + print('Error received:', exc) def connection_lost(self, exc): print('closing transport', exc) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 5178d543..ff537ab2 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -284,7 +284,7 @@ def datagram_received(self, data, addr): assert self.state == 'INITIALIZED', self.state self.nbytes += len(data) - def connection_refused(self, exc): + def error_received(self, exc): assert self.state == 'INITIALIZED', self.state def connection_lost(self, exc): diff --git a/tests/test_events.py b/tests/test_events.py index 2338546a..3a2dece0 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -78,7 +78,7 @@ def datagram_received(self, data, addr): assert self.state == 'INITIALIZED', self.state self.nbytes += len(data) - def connection_refused(self, exc): + def error_received(self, exc): assert self.state == 'INITIALIZED', self.state def connection_lost(self, exc): @@ -1557,7 +1557,7 @@ def test_empty(self): dp = protocols.DatagramProtocol() self.assertIsNone(dp.connection_made(f)) self.assertIsNone(dp.connection_lost(f)) - self.assertIsNone(dp.connection_refused(f)) + self.assertIsNone(dp.error_received(f)) self.assertIsNone(dp.datagram_received(f, f)) sp = protocols.SubprocessProtocol() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 04a7d0c5..4aef2fde 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1329,12 +1329,23 @@ def test_read_ready_err(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - err = self.sock.recvfrom.side_effect = OSError() + err = self.sock.recvfrom.side_effect = RuntimeError() transport._fatal_error = unittest.mock.Mock() transport._read_ready() transport._fatal_error.assert_called_with(err) + def test_read_ready_oserr(self): + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + + err = self.sock.recvfrom.side_effect = OSError() + transport._fatal_error = unittest.mock.Mock() + transport._read_ready() + + self.assertFalse(transport._fatal_error.called) + self.protocol.error_received.assert_called_with(err) + def test_sendto(self): data = b'data' transport = _SelectorDatagramTransport( @@ -1380,7 +1391,7 @@ def test_sendto_tryagain(self): @unittest.mock.patch('asyncio.selector_events.logger') def test_sendto_exception(self, m_log): data = b'data' - err = self.sock.sendto.side_effect = OSError() + err = self.sock.sendto.side_effect = RuntimeError() transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) @@ -1399,7 +1410,7 @@ def test_sendto_exception(self, m_log): transport.sendto(data) m_log.warning.assert_called_with('socket.send() raised exception.') - def test_sendto_connection_refused(self): + def test_sendto_error_received(self): data = b'data' self.sock.sendto.side_effect = ConnectionRefusedError @@ -1412,7 +1423,7 @@ def test_sendto_connection_refused(self): self.assertEqual(transport._conn_lost, 0) self.assertFalse(transport._fatal_error.called) - def test_sendto_connection_refused_connected(self): + def test_sendto_error_received_connected(self): data = b'data' self.sock.send.side_effect = ConnectionRefusedError @@ -1422,7 +1433,8 @@ def test_sendto_connection_refused_connected(self): transport._fatal_error = unittest.mock.Mock() transport.sendto(data) - self.assertTrue(transport._fatal_error.called) + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) def test_sendto_str(self): transport = _SelectorDatagramTransport( @@ -1495,7 +1507,7 @@ def test_sendto_ready_tryagain(self): list(transport._buffer)) def test_sendto_ready_exception(self): - err = self.sock.sendto.side_effect = OSError() + err = self.sock.sendto.side_effect = RuntimeError() transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) @@ -1505,7 +1517,7 @@ def test_sendto_ready_exception(self): transport._fatal_error.assert_called_with(err) - def test_sendto_ready_connection_refused(self): + def test_sendto_ready_error_received(self): self.sock.sendto.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( @@ -1516,7 +1528,7 @@ def test_sendto_ready_connection_refused(self): self.assertFalse(transport._fatal_error.called) - def test_sendto_ready_connection_refused_connection(self): + def test_sendto_ready_error_received_connection(self): self.sock.send.side_effect = ConnectionRefusedError transport = _SelectorDatagramTransport( @@ -1525,7 +1537,8 @@ def test_sendto_ready_connection_refused_connection(self): transport._buffer.append((b'data', ())) transport._sendto_ready() - self.assertTrue(transport._fatal_error.called) + self.assertFalse(transport._fatal_error.called) + self.assertTrue(self.protocol.error_received.called) @unittest.mock.patch('asyncio.log.logger.exception') def test_fatal_error_connected(self, m_exc): @@ -1533,7 +1546,7 @@ def test_fatal_error_connected(self, m_exc): self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) err = ConnectionRefusedError() transport._fatal_error(err) - self.protocol.connection_refused.assert_called_with(err) + self.assertFalse(self.protocol.error_received.called) m_exc.assert_called_with('Fatal error for %s', transport) From 7953750f3988f6e5afcc98352d501dfd62a1265f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 19 Nov 2013 11:32:45 -0800 Subject: [PATCH 0786/1502] Add streams.start_server(), by Gustavo Carneiro. --- asyncio/streams.py | 53 +++++++++++- examples/simple_tcp_server.py | 151 ++++++++++++++++++++++++++++++++++ tests/test_streams.py | 66 +++++++++++++++ 3 files changed, 268 insertions(+), 2 deletions(-) create mode 100644 examples/simple_tcp_server.py diff --git a/asyncio/streams.py b/asyncio/streams.py index e9953682..331d28d0 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -1,6 +1,8 @@ """Stream-related things.""" -__all__ = ['StreamReader', 'StreamReaderProtocol', 'open_connection'] +__all__ = ['StreamReader', 'StreamReaderProtocol', + 'open_connection', 'start_server', + ] import collections @@ -43,6 +45,42 @@ def open_connection(host=None, port=None, *, return reader, writer +@tasks.coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + class StreamReaderProtocol(protocols.Protocol): """Trivial helper class to adapt between Protocol and StreamReader. @@ -52,13 +90,24 @@ class StreamReaderProtocol(protocols.Protocol): call inappropriate methods of the protocol.) """ - def __init__(self, stream_reader): + def __init__(self, stream_reader, client_connected_cb=None, loop=None): self._stream_reader = stream_reader + self._stream_writer = None self._drain_waiter = None self._paused = False + self._client_connected_cb = client_connected_cb + self._loop = loop # May be None; we may never need it. def connection_made(self, transport): self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if tasks.iscoroutine(res): + tasks.Task(res, loop=self._loop) def connection_lost(self, exc): if exc is None: diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py new file mode 100644 index 00000000..c36710d2 --- /dev/null +++ b/examples/simple_tcp_server.py @@ -0,0 +1,151 @@ +""" +Example of a simple TCP server that is written in (mostly) coroutine +style and uses asyncio.streams.start_server() and +asyncio.streams.open_connection(). + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format(idx+1, msg) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + send("repeat 5 hello") + msg = yield from recv() + assert msg == 'begin' + while True: + msg = yield from recv() + if msg == 'end': + break + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + msg = loop.run_until_complete(client()) + server.stop(loop) + + +if __name__ == '__main__': + main() diff --git a/tests/test_streams.py b/tests/test_streams.py index 69e2246f..5516c158 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -359,6 +359,72 @@ def read_a_line(): test_utils.run_briefly(self.loop) self.assertIs(stream._waiter, None) + def test_start_server(self): + + class MyServer: + + def __init__(self, loop): + self.server = None + self.loop = loop + + @tasks.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client, + '127.0.0.1', 12345, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = tasks.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + streams.start_server(self.handle_client_callback, + '127.0.0.1', 12345, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @tasks.coroutine + def client(): + reader, writer = yield from streams.open_connection( + '127.0.0.1', 12345, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + server = MyServer(self.loop) + server.start() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + server = MyServer(self.loop) + server.start_callback() + msg = self.loop.run_until_complete(tasks.Task(client(), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main() From f94c1135f12b28308314e79a07d6168fb627ba32 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 21 Nov 2013 10:02:29 -0800 Subject: [PATCH 0787/1502] Allow and correctly implement Semaphore(0). --- asyncio/locks.py | 4 ++-- tests/test_locks.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index ac851e5c..dd9f0f8b 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -348,12 +348,12 @@ class Semaphore: def __init__(self, value=1, bound=False, *, loop=None): if value < 0: - raise ValueError("Semaphore initial value must be > 0") + raise ValueError("Semaphore initial value must be >= 0") self._value = value self._bound = bound self._bound_value = value self._waiters = collections.deque() - self._locked = False + self._locked = (value == 0) if loop is not None: self._loop = loop else: diff --git a/tests/test_locks.py b/tests/test_locks.py index 19ef877a..539c5c38 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -684,6 +684,10 @@ def test_ctor_noloop(self): finally: events.set_event_loop(None) + def test_initial_value_zero(self): + sem = locks.Semaphore(0, loop=self.loop) + self.assertTrue(sem.locked()) + def test_repr(self): sem = locks.Semaphore(loop=self.loop) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) From ab5706cb9779c85324e781a01e8b3937681b6448 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 22 Nov 2013 11:30:02 -0800 Subject: [PATCH 0788/1502] From CPython: relax a timeout. --- tests/test_windows_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index f5147de2..7ba33dac 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -109,7 +109,7 @@ def test_wait_for_handle(self): self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertFalse(f.result()) - self.assertTrue(0.18 < elapsed < 0.22, elapsed) + self.assertTrue(0.18 < elapsed < 0.5, elapsed) _overlapped.SetEvent(event) From 8307a7f592cb9a039c300f60681822a0b09aa073 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 22 Nov 2013 11:30:23 -0800 Subject: [PATCH 0789/1502] From CPython: use a single return in selectors.py. --- asyncio/selectors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 3971502e..261fac6c 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -140,14 +140,12 @@ def modify(self, fileobj, events, data=None): raise KeyError("{!r} is not registered".format(fileobj)) from None if events != key.events: self.unregister(fileobj) - return self.register(fileobj, events, data) + key = self.register(fileobj, events, data) elif data != key.data: # Use a shortcut to update the data. key = key._replace(data=data) self._fd_to_key[key.fd] = key - return key - else: - return key + return key @abstractmethod def select(self, timeout=None): From 67025996aa8981d8f27325c0f3b9bf6dc6c054c0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 22 Nov 2013 11:41:34 -0800 Subject: [PATCH 0790/1502] Pass cancellation from wrapping Future to wrapped Future. Fixes issue 88. By Sa?l Ibarra Corretg? (mostly). --- asyncio/futures.py | 11 ++++++++--- tests/test_futures.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index db278386..dd3e718d 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -301,6 +301,8 @@ def _copy_state(self, other): The other Future may be a concurrent.futures.Future. """ assert other.done() + if self.cancelled(): + return assert not self.done() if other.cancelled(): self.cancel() @@ -324,14 +326,17 @@ def wrap_future(fut, *, loop=None): """Wrap concurrent.futures.Future object.""" if isinstance(fut, Future): return fut - assert isinstance(fut, concurrent.futures.Future), \ 'concurrent.futures.Future is expected, got {!r}'.format(fut) - if loop is None: loop = events.get_event_loop() - new_future = Future(loop=loop) + + def _check_cancel_other(f): + if f.cancelled(): + fut.cancel() + + new_future.add_done_callback(_check_cancel_other) fut.add_done_callback( lambda future: loop.call_soon_threadsafe( new_future._copy_state, fut)) diff --git a/tests/test_futures.py b/tests/test_futures.py index ccea2ffd..e35fcf07 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -241,6 +241,24 @@ def run(arg): f2 = futures.wrap_future(f1) self.assertIs(m_events.get_event_loop.return_value, f2._loop) + def test_wrap_future_cancel(self): + f1 = concurrent.futures.Future() + f2 = futures.wrap_future(f1, loop=self.loop) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(f1.cancelled()) + self.assertTrue(f2.cancelled()) + + def test_wrap_future_cancel2(self): + f1 = concurrent.futures.Future() + f2 = futures.wrap_future(f1, loop=self.loop) + f1.set_result(42) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertFalse(f1.cancelled()) + self.assertEqual(f1.result(), 42) + self.assertTrue(f2.cancelled()) + class FutureDoneCallbackTests(unittest.TestCase): From 276df9b47f52a264ac23f2a04aa8083c5b45aa7c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 22 Nov 2013 20:32:19 -0800 Subject: [PATCH 0791/1502] Two new examples: print 'Hello World' every two seconds, using a callback and using a coroutine. Thanks to Terry Reedy who suggested this exercise. --- examples/hello_callback.py | 14 ++++++++++++++ examples/hello_coroutine.py | 15 +++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 examples/hello_callback.py create mode 100644 examples/hello_coroutine.py diff --git a/examples/hello_callback.py b/examples/hello_callback.py new file mode 100644 index 00000000..df889e55 --- /dev/null +++ b/examples/hello_callback.py @@ -0,0 +1,14 @@ +"""Print 'Hello World' every two seconds, using a callback.""" + +import asyncio + + +def print_and_repeat(loop): + print('Hello World') + loop.call_later(2, print_and_repeat, loop) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + print_and_repeat(loop) + loop.run_forever() diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py new file mode 100644 index 00000000..8ad682d2 --- /dev/null +++ b/examples/hello_coroutine.py @@ -0,0 +1,15 @@ +"""Print 'Hello World' every two seconds, using a coroutine.""" + +import asyncio + + +@asyncio.coroutine +def greet_every_two_seconds(): + while True: + print('Hello World') + yield from asyncio.sleep(2) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + loop.run_until_complete(greet_every_two_seconds()) From 67e2c490669cfe0de93426782ecaae8f36b2c355 Mon Sep 17 00:00:00 2001 From: Sa?l Ibarra Corretg? Date: Sat, 23 Nov 2013 19:42:37 +0100 Subject: [PATCH 0792/1502] Use socketpair() from test_utils in tests --- tests/test_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index 3a2dece0..a9c6385c 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -896,7 +896,7 @@ def factory(): proto = MyWritePipeProto(loop=self.loop) return proto - rsock, wsock = self.loop._socketpair() + rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) @tasks.coroutine From ce82b539f8d0f5d528c509821bb9286751297f33 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Nov 2013 11:48:31 -0800 Subject: [PATCH 0793/1502] Relax timing (from upstream CPython repo). --- tests/test_base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index ff537ab2..96f29750 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -181,7 +181,7 @@ def test__run_once(self): self.loop._run_once() t = self.loop._selector.select.call_args[0][0] - self.assertTrue(9.9 < t < 10.1, t) + self.assertTrue(9.5 < t < 10.5, t) self.assertEqual([h2], self.loop._scheduled) self.assertTrue(self.loop._process_events.called) From f4f0267afc050b76c0b4dad9bdc1048d2fde2325 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Nov 2013 11:48:54 -0800 Subject: [PATCH 0794/1502] Fix some docstrings (from upstream CPython repo). --- asyncio/transports.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/transports.py b/asyncio/transports.py index 8c6b1896..98f92247 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -16,7 +16,7 @@ def get_extra_info(self, name, default=None): return self._extra.get(name, default) def close(self): - """Closes the transport. + """Close the transport. Buffered data will be flushed asynchronously. No more data will be received. After all buffered data is flushed, the @@ -92,7 +92,7 @@ def writelines(self, list_of_data): self.write(data) def write_eof(self): - """Closes the write end after flushing buffered data. + """Close the write end after flushing buffered data. (This is like typing ^D into a UNIX program reading from stdin.) @@ -101,11 +101,11 @@ def write_eof(self): raise NotImplementedError def can_write_eof(self): - """Return True if this protocol supports write_eof(), False if not.""" + """Return True if this transport supports write_eof(), False if not.""" raise NotImplementedError def abort(self): - """Closes the transport immediately. + """Close the transport immediately. Buffered data will be lost. No more data will be received. The protocol's connection_lost() method will (eventually) be @@ -150,7 +150,7 @@ def sendto(self, data, addr=None): raise NotImplementedError def abort(self): - """Closes the transport immediately. + """Close the transport immediately. Buffered data will be lost. No more data will be received. The protocol's connection_lost() method will (eventually) be From 94197b0352a5efe94e5ca7de06e3b94534723197 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Nov 2013 15:05:58 -0800 Subject: [PATCH 0795/1502] Change bounded semaphore into a subclass, like threading.[Bounded]Semaphore. --- asyncio/locks.py | 36 +++++++++++++++++++----------------- tests/test_locks.py | 2 +- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index dd9f0f8b..4b0ce50b 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -336,22 +336,15 @@ class Semaphore: Semaphores also support the context manager protocol. - The first optional argument gives the initial value for the internal + The optional argument gives the initial value for the internal counter; it defaults to 1. If the value given is less than 0, ValueError is raised. - - The second optional argument determines if the semaphore can be released - more than initial internal counter value; it defaults to False. If the - value given is True and number of release() is more than number of - successful acquire() calls ValueError is raised. """ - def __init__(self, value=1, bound=False, *, loop=None): + def __init__(self, value=1, *, loop=None): if value < 0: raise ValueError("Semaphore initial value must be >= 0") self._value = value - self._bound = bound - self._bound_value = value self._waiters = collections.deque() self._locked = (value == 0) if loop is not None: @@ -402,17 +395,9 @@ def release(self): """Release a semaphore, incrementing the internal counter by one. When it was zero on entry and another coroutine is waiting for it to become larger than zero again, wake up that coroutine. - - If Semaphore is created with "bound" parameter equals true, then - release() method checks to make sure its current value doesn't exceed - its initial value. If it does, ValueError is raised. """ - if self._bound and self._value >= self._bound_value: - raise ValueError('Semaphore released too many times') - self._value += 1 self._locked = False - for waiter in self._waiters: if not waiter.done(): waiter.set_result(True) @@ -429,3 +414,20 @@ def __exit__(self, *args): def __iter__(self): yield from self.acquire() return self + + +class BoundedSemaphore(Semaphore): + """A bounded semaphore implementation. + + This raises ValueError in release() if it would increase the value + above the initial value. + """ + + def __init__(self, value=1, *, loop=None): + self._bound_value = value + super().__init__(value, loop=loop) + + def release(self): + if self._value >= self._bound_value: + raise ValueError('BoundedSemaphore released too many times') + super().release() diff --git a/tests/test_locks.py b/tests/test_locks.py index 539c5c38..df106e0d 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -805,7 +805,7 @@ def test_acquire_cancel(self): self.assertFalse(sem._waiters) def test_release_not_acquired(self): - sem = locks.Semaphore(bound=True, loop=self.loop) + sem = locks.BoundedSemaphore(loop=self.loop) self.assertRaises(ValueError, sem.release) From 440f941223ef8d78ab806e952c0af57544b1eb85 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Nov 2013 15:34:57 -0800 Subject: [PATCH 0796/1502] Add 'back' option to update script. --- update_stdlib.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/update_stdlib.sh b/update_stdlib.sh index 828de2c8..b025adfa 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -29,10 +29,11 @@ maybe_copy() fi echo ======== $SRC === $DST ======== diff -u $DST $SRC - echo -n "Copy $SRC? [y/N] " + echo -n "Copy $SRC? [y/N/back] " read X case $X in [yY]*) echo Copying $SRC; cp $SRC $DST;; + back) echo Copying TO $SRC; cp $DST $SRC;; *) echo Not copying $SRC;; esac } From 9a45889f5dc2ffa786d7d419faadd3191fa9e53f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 23 Nov 2013 15:38:18 -0800 Subject: [PATCH 0797/1502] Keep asyncio working with Python 3.3 too. --- asyncio/selector_events.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 3efa4d2a..0641459f 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -571,10 +571,15 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, # context; in that case the sslcontext passed is None. # The default is the same as used by urllib with # cadefault=True. - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + if hasattr(ssl, '_create_stdlib_context'): + sslcontext = ssl._create_stdlib_context( + cert_reqs=ssl.CERT_REQUIRED) + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED wrap_kwargs = { 'server_side': server_side, From cc786b7ceb69f68fe8172c2fcda8c3ec2c0b74e7 Mon Sep 17 00:00:00 2001 From: Richard Oudkerk Date: Sun, 24 Nov 2013 11:04:44 -0800 Subject: [PATCH 0798/1502] Use WaitForSingleObject() instead of trusting TimerOrWaitFired. --- asyncio/windows_events.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 64fe3861..b2ed2415 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -327,14 +327,21 @@ def wait_for_handle(self, handle, timeout=None): handle, self._iocp, ov.address, ms) f = _WaitHandleFuture(wh, loop=self._loop) - def finish(timed_out, _, ov): + def finish(trans, key, ov): if not f.cancelled(): try: _overlapped.UnregisterWait(wh) except OSError as e: if e.winerror != _overlapped.ERROR_IO_PENDING: raise - return not timed_out + # Note that this second wait means that we should only use + # this with handles types where a successful wait has no + # effect. So events or processes are all right, but locks + # or semaphores are not. Also note if the handle is + # signalled and then quickly reset, then we may return + # False even though we have not timed out. + return (_winapi.WaitForSingleObject(handle, 0) == + _winapi.WAIT_OBJECT_0) self._cache[ov.address] = (f, ov, None, finish) return f From 1d4bfc24fc5e4daad041978eb53455bf4e98ee33 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 24 Nov 2013 22:31:16 -0800 Subject: [PATCH 0799/1502] Add BoundedSemaphore to export list in locks.__all__. --- asyncio/locks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 4b0ce50b..2d458a9b 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -1,6 +1,6 @@ """Synchronization primitives.""" -__all__ = ['Lock', 'Event', 'Condition', 'Semaphore'] +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] import collections From 061ef82b9082f1d7dd1627b825e34959fe067267 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 24 Nov 2013 22:40:29 -0800 Subject: [PATCH 0800/1502] Fix docstring of get_nowait(). --- asyncio/queues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 536de1cb..3f5bf447 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -184,7 +184,7 @@ def get(self): def get_nowait(self): """Remove and return an item from the queue. - Return an item if one is immediately available, else raise Full. + Return an item if one is immediately available, else raise Empty. """ self._consume_done_putters() if self._putters: From 5c334cc9ccb7c2c9846d2395cf57381b7209fa0f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 25 Nov 2013 09:51:48 -0800 Subject: [PATCH 0801/1502] Change mock pipe to mock socket. Hope to fix CPython issue 19750. --- tests/test_unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index fdd90495..98cf4079 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -379,7 +379,7 @@ def setUp(self): fstat_patcher = unittest.mock.patch('os.fstat') m_fstat = fstat_patcher.start() st = unittest.mock.Mock() - st.st_mode = stat.S_IFIFO + st.st_mode = stat.S_IFSOCK m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) From 0fb81e1ce0f2da62d15b192d4430e369b7e60486 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 25 Nov 2013 10:09:21 -0800 Subject: [PATCH 0802/1502] Hopeful fix for CPython issue 19765. --- tests/test_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_events.py b/tests/test_events.py index a9c6385c..6abc7243 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -560,6 +560,7 @@ def factory(): client.connect(('127.0.0.1', port)) client.sendall(b'xxx') test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) self.assertIsInstance(proto, MyProto) self.assertEqual('INITIAL', proto.state) test_utils.run_briefly(self.loop) From 036bbfb768de845b3495b99d212fffbf98ba5571 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 25 Nov 2013 10:18:09 -0800 Subject: [PATCH 0803/1502] Set version to 0.2.1. Ready for PyPI. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b27f7860..77db68fb 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="asyncio", - version="0.1.1", + version="0.2.1", description="reference implementation of PEP 3156", long_description=open("README").read(), From 40ce7fb9dc2b9d060e2cc67ac85edaa3839b534a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 25 Nov 2013 10:22:06 -0800 Subject: [PATCH 0804/1502] Added tag 0.2.1 for changeset 278be97ece66 From afb73f6a7748127f1348d3c08791b116e3f92deb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 25 Nov 2013 15:04:47 -0800 Subject: [PATCH 0805/1502] Add StreamReaderProtocol to __all__. --- asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 331d28d0..50c4c5d1 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -1,6 +1,6 @@ """Stream-related things.""" -__all__ = ['StreamReader', 'StreamReaderProtocol', +__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server', ] From b56cf3cf9f01c0e70e2e53931a20120a976ef9fb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 10:21:04 -0800 Subject: [PATCH 0806/1502] Fix get_event_loop() to call set_event_loop() when setting the loop. By Anthony Baire. --- asyncio/events.py | 2 +- tests/test_events.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/asyncio/events.py b/asyncio/events.py index 36ae312b..d429686b 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -360,7 +360,7 @@ def get_event_loop(self): if (self._local._loop is None and not self._local._set_called and isinstance(threading.current_thread(), threading._MainThread)): - self._local._loop = self.new_event_loop() + self.set_event_loop(self.new_event_loop()) assert self._local._loop is not None, \ ('There is no current event loop in thread %r.' % threading.current_thread().name) diff --git a/tests/test_events.py b/tests/test_events.py index 6abc7243..76aa935e 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1599,6 +1599,22 @@ def test_get_event_loop(self): self.assertIs(loop, policy.get_event_loop()) loop.close() + def test_get_event_loop_calls_set_event_loop(self): + policy = self.create_policy() + + with unittest.mock.patch.object( + policy, "set_event_loop", + wraps=policy.set_event_loop) as m_set_event_loop: + + loop = policy.get_event_loop() + + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) + + loop.close() + def test_get_event_loop_after_set_none(self): policy = self.create_policy() policy.set_event_loop(None) From 229aac570dc4a6b809f70e0706ae4c7948525da2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 10:38:36 -0800 Subject: [PATCH 0807/1502] Fix amount of indentation -- CPython's precommit script complained. --- tests/test_events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 76aa935e..18411ecc 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1606,12 +1606,12 @@ def test_get_event_loop_calls_set_event_loop(self): policy, "set_event_loop", wraps=policy.set_event_loop) as m_set_event_loop: - loop = policy.get_event_loop() + loop = policy.get_event_loop() - # policy._local._loop must be set through .set_event_loop() - # (the unix DefaultEventLoopPolicy needs this call to attach - # the child watcher correctly) - m_set_event_loop.assert_called_with(loop) + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) loop.close() From 77f56d0c43cb694bbb943ce88c907025aa498fd4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 26 Nov 2013 15:50:31 -0800 Subject: [PATCH 0808/1502] Experimental bytearray buffer. --- asyncio/selector_events.py | 37 ++++++++------- tests/test_selector_events.py | 89 ++++++++++++++++++----------------- 2 files changed, 66 insertions(+), 60 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 0641459f..2b79621d 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -340,6 +340,8 @@ class _SelectorTransport(transports.Transport): max_size = 256 * 1024 # Buffer size passed to recv(). + _buffer_factory = bytearray # Constructs initial value for self._buffer. + def __init__(self, loop, sock, protocol, extra, server=None): super().__init__(extra) self._extra['socket'] = sock @@ -354,7 +356,7 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._sock_fd = sock.fileno() self._protocol = protocol self._server = server - self._buffer = collections.deque() + self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. self._protocol_paused = False @@ -438,7 +440,7 @@ def set_write_buffer_limits(self, high=None, low=None): self._low_water = low def get_write_buffer_size(self): - return sum(len(data) for data in self._buffer) + return len(self._buffer) class _SelectorSocketTransport(_SelectorTransport): @@ -516,25 +518,23 @@ def write(self, data): self._loop.add_writer(self._sock_fd, self._write_ready) # Add it to the buffer. - self._buffer.append(data) + self._buffer.extend(data) self._maybe_pause_protocol() def _write_ready(self): - data = b''.join(self._buffer) - assert data, 'Data should not be empty' + assert self._buffer, 'Data should not be empty' - self._buffer.clear() # Optimistically; may have to put it back later. try: - n = self._sock.send(data) + n = self._sock.send(self._buffer) except (BlockingIOError, InterruptedError): - self._buffer.append(data) # Still need to write this. + pass except Exception as exc: self._loop.remove_writer(self._sock_fd) + self._buffer.clear() self._fatal_error(exc) else: - data = data[n:] - if data: - self._buffer.append(data) # Still need to write this. + if n: + del self._buffer[:n] self._maybe_resume_protocol() # May append to buffer. if not self._buffer: self._loop.remove_writer(self._sock_fd) @@ -556,6 +556,8 @@ def can_write_eof(self): class _SelectorSslTransport(_SelectorTransport): + _buffer_factory = bytearray + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, server_side=False, server_hostname=None, extra=None, server=None): @@ -712,10 +714,8 @@ def _write_ready(self): self._loop.add_reader(self._sock_fd, self._read_ready) if self._buffer: - data = b''.join(self._buffer) - self._buffer.clear() try: - n = self._sock.send(data) + n = self._sock.send(self._buffer) except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): n = 0 @@ -725,11 +725,12 @@ def _write_ready(self): self._write_wants_read = True except Exception as exc: self._loop.remove_writer(self._sock_fd) + self._buffer.clear() self._fatal_error(exc) return - if n < len(data): - self._buffer.append(data[n:]) + if n: + del self._buffer[:n] self._maybe_resume_protocol() # May append to buffer. @@ -753,7 +754,7 @@ def write(self, data): self._loop.add_writer(self._sock_fd, self._write_ready) # Add it to the buffer. - self._buffer.append(data) + self._buffer.extend(data) self._maybe_pause_protocol() def can_write_eof(self): @@ -762,6 +763,8 @@ def can_write_eof(self): class _SelectorDatagramTransport(_SelectorTransport): + _buffer_factory = collections.deque + def __init__(self, loop, sock, protocol, address=None, extra=None): super().__init__(loop, sock, protocol, extra) self._address = address diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 4aef2fde..e2ee277b 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -32,6 +32,10 @@ def _make_self_pipe(self): self._internal_fds += 1 +def list_to_buffer(l=()): + return bytearray().join(l) + + class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): @@ -613,7 +617,7 @@ def test_close(self): def test_close_write_buffer(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - tr._buffer.append(b'data') + tr._buffer.extend(b'data') tr.close() self.assertFalse(self.loop.readers) @@ -622,13 +626,13 @@ def test_close_write_buffer(self): def test_force_close(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - tr._buffer.append(b'1') + tr._buffer.extend(b'1') self.loop.add_reader(7, unittest.mock.sentinel) self.loop.add_writer(7, unittest.mock.sentinel) tr._force_close(None) self.assertTrue(tr._closing) - self.assertEqual(tr._buffer, collections.deque()) + self.assertEqual(tr._buffer, list_to_buffer()) self.assertFalse(self.loop.readers) self.assertFalse(self.loop.writers) @@ -786,18 +790,18 @@ def test_write(self): def test_write_no_data(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer.append(b'data') + transport._buffer.extend(b'data') transport.write(b'') self.assertFalse(self.sock.send.called) - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_buffer(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer.append(b'data1') + transport._buffer.extend(b'data1') transport.write(b'data2') self.assertFalse(self.sock.send.called) - self.assertEqual(collections.deque([b'data1', b'data2']), + self.assertEqual(list_to_buffer([b'data1', b'data2']), transport._buffer) def test_write_partial(self): @@ -809,7 +813,7 @@ def test_write_partial(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'ta']), transport._buffer) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) def test_write_partial_none(self): data = b'data' @@ -821,7 +825,7 @@ def test_write_partial_none(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_tryagain(self): self.sock.send.side_effect = BlockingIOError @@ -832,7 +836,7 @@ def test_write_tryagain(self): transport.write(data) self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) @unittest.mock.patch('asyncio.selector_events.logger') def test_write_exception(self, m_log): @@ -875,11 +879,10 @@ def test_write_ready(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer.append(data) + transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.assertTrue(self.sock.send.called) - self.assertEqual(self.sock.send.call_args[0], (data,)) self.assertFalse(self.loop.writers) def test_write_ready_closing(self): @@ -889,10 +892,10 @@ def test_write_ready_closing(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._closing = True - transport._buffer.append(data) + transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() - self.sock.send.assert_called_with(data) + self.assertTrue(self.sock.send.called) self.assertFalse(self.loop.writers) self.sock.close.assert_called_with() self.protocol.connection_lost.assert_called_with(None) @@ -908,11 +911,11 @@ def test_write_ready_partial(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer.append(data) + transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'ta']), transport._buffer) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) def test_write_ready_partial_none(self): data = b'data' @@ -920,23 +923,23 @@ def test_write_ready_partial_none(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer.append(data) + transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_ready_tryagain(self): self.sock.send.side_effect = BlockingIOError transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._buffer = collections.deque([b'data1', b'data2']) + transport._buffer = list_to_buffer([b'data1', b'data2']) self.loop.add_writer(7, transport._write_ready) transport._write_ready() self.loop.assert_writer(7, transport._write_ready) - self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer) def test_write_ready_exception(self): err = self.sock.send.side_effect = OSError() @@ -944,7 +947,7 @@ def test_write_ready_exception(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() - transport._buffer.append(b'data') + transport._buffer.extend(b'data') transport._write_ready() transport._fatal_error.assert_called_with(err) @@ -956,7 +959,7 @@ def test_write_ready_exception_and_close(self, m_log): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport.close() - transport._buffer.append(b'data') + transport._buffer.extend(b'data') transport._write_ready() remove_writer.assert_called_with(self.sock_fd) @@ -976,12 +979,12 @@ def test_write_eof_buffer(self): self.sock.send.side_effect = BlockingIOError tr.write(b'data') tr.write_eof() - self.assertEqual(tr._buffer, collections.deque([b'data'])) + self.assertEqual(tr._buffer, list_to_buffer([b'data'])) self.assertTrue(tr._eof) self.assertFalse(self.sock.shutdown.called) self.sock.send.side_effect = lambda _: 4 tr._write_ready() - self.sock.send.assert_called_with(b'data') + self.assertTrue(self.sock.send.called) self.sock.shutdown.assert_called_with(socket.SHUT_WR) tr.close() @@ -1067,9 +1070,9 @@ def test_pause_resume_reading(self): def test_write_no_data(self): transport = self._make_one() - transport._buffer.append(b'data') + transport._buffer.extend(b'data') transport.write(b'') - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_str(self): transport = self._make_one() @@ -1087,7 +1090,7 @@ def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 transport.write(b'data') - self.assertEqual(transport._buffer, collections.deque()) + self.assertEqual(transport._buffer, list_to_buffer()) transport.write(b'data') transport.write(b'data') transport.write(b'data') @@ -1107,7 +1110,7 @@ def test_read_ready_write_wants_read(self): transport = self._make_one() transport._write_wants_read = True transport._write_ready = unittest.mock.Mock() - transport._buffer.append(b'data') + transport._buffer.extend(b'data') transport._read_ready() self.assertFalse(transport._write_wants_read) @@ -1168,31 +1171,31 @@ def test_read_ready_recv_exc(self): def test_write_ready_send(self): self.sslsock.send.return_value = 4 transport = self._make_one() - transport._buffer = collections.deque([b'data']) + transport._buffer = list_to_buffer([b'data']) transport._write_ready() - self.assertEqual(collections.deque(), transport._buffer) + self.assertEqual(list_to_buffer(), transport._buffer) self.assertTrue(self.sslsock.send.called) def test_write_ready_send_none(self): self.sslsock.send.return_value = 0 transport = self._make_one() - transport._buffer = collections.deque([b'data1', b'data2']) + transport._buffer = list_to_buffer([b'data1', b'data2']) transport._write_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual(collections.deque([b'data1data2']), transport._buffer) + self.assertEqual(list_to_buffer([b'data1data2']), transport._buffer) def test_write_ready_send_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() - transport._buffer = collections.deque([b'data1', b'data2']) + transport._buffer = list_to_buffer([b'data1', b'data2']) transport._write_ready() self.assertTrue(self.sslsock.send.called) - self.assertEqual(collections.deque([b'ta1data2']), transport._buffer) + self.assertEqual(list_to_buffer([b'ta1data2']), transport._buffer) def test_write_ready_send_closing_partial(self): self.sslsock.send.return_value = 2 transport = self._make_one() - transport._buffer = collections.deque([b'data1', b'data2']) + transport._buffer = list_to_buffer([b'data1', b'data2']) transport._write_ready() self.assertTrue(self.sslsock.send.called) self.assertFalse(self.sslsock.close.called) @@ -1201,7 +1204,7 @@ def test_write_ready_send_closing(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() - transport._buffer = collections.deque([b'data']) + transport._buffer = list_to_buffer([b'data']) transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) @@ -1210,26 +1213,26 @@ def test_write_ready_send_closing_empty_buffer(self): self.sslsock.send.return_value = 4 transport = self._make_one() transport.close() - transport._buffer = collections.deque() + transport._buffer = list_to_buffer() transport._write_ready() self.assertFalse(self.loop.writers) self.protocol.connection_lost.assert_called_with(None) def test_write_ready_send_retry(self): transport = self._make_one() - transport._buffer = collections.deque([b'data']) + transport._buffer = list_to_buffer([b'data']) self.sslsock.send.side_effect = ssl.SSLWantWriteError transport._write_ready() - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) self.sslsock.send.side_effect = BlockingIOError() transport._write_ready() - self.assertEqual(collections.deque([b'data']), transport._buffer) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_ready_send_read(self): transport = self._make_one() - transport._buffer = collections.deque([b'data']) + transport._buffer = list_to_buffer([b'data']) self.loop.remove_writer = unittest.mock.Mock() self.sslsock.send.side_effect = ssl.SSLWantReadError @@ -1242,11 +1245,11 @@ def test_write_ready_send_exc(self): err = self.sslsock.send.side_effect = OSError() transport = self._make_one() - transport._buffer = collections.deque([b'data']) + transport._buffer = list_to_buffer([b'data']) transport._fatal_error = unittest.mock.Mock() transport._write_ready() transport._fatal_error.assert_called_with(err) - self.assertEqual(collections.deque(), transport._buffer) + self.assertEqual(list_to_buffer(), transport._buffer) def test_write_ready_read_wants_write(self): self.loop.add_reader = unittest.mock.Mock() From b5a1ee9bfbb33a37ebea237734f78aff20a15320 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 12:10:55 -0800 Subject: [PATCH 0809/1502] Variant of simple_tcp_server.py to measure timing. By Gustavo Carneiro. --- examples/timing_tcp_server.py | 163 ++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 examples/timing_tcp_server.py diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py new file mode 100644 index 00000000..5290ed1c --- /dev/null +++ b/examples/timing_tcp_server.py @@ -0,0 +1,163 @@ +""" +A variant of simple_tcp_server.py that measures the time it takes to +send N messages for a range of N. (This was O(N**2) in a previous +version of Tulip.) + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import time +import random + +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format(idx+1, msg + 'x'*random.randint(10, 50)) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + Ns = list(range(100, 100000, 10000)) + times = [] + + for N in Ns: + t0 = time.time() + send("repeat {} hello world ".format(N)) + msg = yield from recv() + assert msg == 'begin' + while True: + msg = (yield from reader.readline()).decode("utf-8").rstrip() + if msg == 'end': + break + t1 = time.time() + dt = t1 - t0 + print("Time taken: {:.3f} seconds ({:.6f} per repetition)".format(dt, dt/N)) + times.append(dt) + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + msg = loop.run_until_complete(client()) + server.stop(loop) + + +if __name__ == '__main__': + main() From b3be39e67da2eff808b16ed38199e823e91b14d9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 12:47:48 -0800 Subject: [PATCH 0810/1502] Support bytearray/memoryview arguments to write(), sendto(). Fixes issue 27. --- asyncio/selector_events.py | 9 +-- tests/test_selector_events.py | 105 ++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 4 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 2b79621d..147ea877 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -490,7 +490,7 @@ def _read_ready(self): self.close() def write(self, data): - assert isinstance(data, bytes), repr(type(data)) + assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) assert not self._eof, 'Cannot call write() after write_eof()' if not data: return @@ -740,7 +740,7 @@ def _write_ready(self): self._call_connection_lost(None) def write(self, data): - assert isinstance(data, bytes), repr(type(data)) + assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) if not data: return @@ -787,7 +787,7 @@ def _read_ready(self): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - assert isinstance(data, bytes), repr(type(data)) + assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) if not data: return @@ -817,7 +817,8 @@ def sendto(self, data, addr=None): self._fatal_error(exc) return - self._buffer.append((data, addr)) + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) self._maybe_pause_protocol() def _sendto_ready(self): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index e2ee277b..3f3ae630 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -787,6 +787,25 @@ def test_write(self): transport.write(data) self.sock.send.assert_called_with(data) + def test_write_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + + def test_write_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = len(data) + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + self.sock.send.assert_called_with(data) + def test_write_no_data(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) @@ -815,6 +834,29 @@ def test_write_partial(self): self.loop.assert_writer(7, transport._write_ready) self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + def test_write_partial_bytearray(self): + data = bytearray(b'data') + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + + def test_write_partial_memoryview(self): + data = memoryview(b'data') + self.sock.send.return_value = 2 + + transport = _SelectorSocketTransport( + self.loop, self.sock, self.protocol) + transport.write(data) + + self.loop.assert_writer(7, transport._write_ready) + self.assertEqual(list_to_buffer([b'ta']), transport._buffer) + def test_write_partial_none(self): data = b'data' self.sock.send.return_value = 0 @@ -1068,6 +1110,25 @@ def test_pause_resume_reading(self): self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._read_ready) + def test_write(self): + transport = self._make_one() + transport.write(b'data') + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + + def test_write_bytearray(self): + transport = self._make_one() + data = bytearray(b'data') + transport.write(data) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. + self.assertIsNot(data, transport._buffer) # Hasn't been incorporated. + + def test_write_memoryview(self): + transport = self._make_one() + data = memoryview(b'data') + transport.write(data) + self.assertEqual(list_to_buffer([b'data']), transport._buffer) + def test_write_no_data(self): transport = self._make_one() transport._buffer.extend(b'data') @@ -1358,6 +1419,24 @@ def test_sendto(self): self.assertEqual( self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + def test_sendto_bytearray(self): + data = bytearray(b'data') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + + def test_sendto_memoryview(self): + data = memoryview(b'data') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport.sendto(data, ('0.0.0.0', 1234)) + self.assertTrue(self.sock.sendto.called) + self.assertEqual( + self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + def test_sendto_no_data(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) @@ -1378,6 +1457,32 @@ def test_sendto_buffer(self): (b'data2', ('0.0.0.0', 12345))], list(transport._buffer)) + def test_sendto_buffer_bytearray(self): + data2 = bytearray(b'data2') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + + def test_sendto_buffer_memoryview(self): + data2 = memoryview(b'data2') + transport = _SelectorDatagramTransport( + self.loop, self.sock, self.protocol) + transport._buffer.append((b'data1', ('0.0.0.0', 12345))) + transport.sendto(data2, ('0.0.0.0', 12345)) + self.assertFalse(self.sock.sendto.called) + self.assertEqual( + [(b'data1', ('0.0.0.0', 12345)), + (b'data2', ('0.0.0.0', 12345))], + list(transport._buffer)) + self.assertIsInstance(transport._buffer[1][0], bytes) + def test_sendto_tryagain(self): data = b'data' From 56ba01954c6880b5f148add1e16c4bce26697fc8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 12:58:15 -0800 Subject: [PATCH 0811/1502] Fold long lines. --- examples/timing_tcp_server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index 5290ed1c..6f0483eb 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -75,7 +75,8 @@ def _handle_client(self, client_reader, client_writer): msg = args[1] client_writer.write("begin\n".encode("utf-8")) for idx in range(times): - client_writer.write("{}. {}\n".format(idx+1, msg + 'x'*random.randint(10, 50)) + client_writer.write("{}. {}\n".format( + idx+1, msg + 'x'*random.randint(10, 50)) .encode("utf-8")) client_writer.write("end\n".encode("utf-8")) else: @@ -148,7 +149,8 @@ def recv(): break t1 = time.time() dt = t1 - t0 - print("Time taken: {:.3f} seconds ({:.6f} per repetition)".format(dt, dt/N)) + print("Time taken: {:.3f} seconds ({:.6f} per repetition)" + .format(dt, dt/N)) times.append(dt) writer.close() From 13abe0ecfd617f6a781d7b4fcabee196501633e3 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 27 Nov 2013 13:13:14 -0800 Subject: [PATCH 0812/1502] Replace some asserts with proper exceptions. --- asyncio/selector_events.py | 42 ++++++++++++++++++++++++----------- tests/test_selector_events.py | 9 ++++---- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 147ea877..93efddc9 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -435,7 +435,9 @@ def set_write_buffer_limits(self, high=None, low=None): high = 4*low if low is None: low = high // 4 - assert 0 <= low <= high, repr((low, high)) + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) self._high_water = high self._low_water = low @@ -457,13 +459,16 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.call_soon(waiter.set_result, None) def pause_reading(self): - assert not self._closing, 'Cannot pause_reading() when closing' - assert not self._paused, 'Already paused' + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') self._paused = True self._loop.remove_reader(self._sock_fd) def resume_reading(self): - assert self._paused, 'Not paused' + if not self._paused: + raise RuntimeError('Not paused') self._paused = False if self._closing: return @@ -490,8 +495,11 @@ def _read_ready(self): self.close() def write(self, data): - assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) - assert not self._eof, 'Cannot call write() after write_eof()' + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof: + raise RuntimeError('Cannot call write() after write_eof()') if not data: return @@ -663,13 +671,16 @@ def pause_reading(self): # accept more data for the buffer and eventually the app will # call resume_reading() again, and things will flow again. - assert not self._closing, 'Cannot pause_reading() when closing' - assert not self._paused, 'Already paused' + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') self._paused = True self._loop.remove_reader(self._sock_fd) def resume_reading(self): - assert self._paused, 'Not paused' + if not self._paused: + raise ('Not paused') self._paused = False if self._closing: return @@ -740,7 +751,9 @@ def _write_ready(self): self._call_connection_lost(None) def write(self, data): - assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) if not data: return @@ -787,12 +800,15 @@ def _read_ready(self): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - assert isinstance(data, (bytes, bytearray, memoryview)), repr(type(data)) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) if not data: return - if self._address: - assert addr in (None, self._address) + if self._address and addr not in (None, self._address): + raise ValueError('Invalid address: must be None or %s' % + (self._address,)) if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 3f3ae630..38aa7669 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -905,7 +905,7 @@ def test_write_exception(self, m_log): def test_write_str(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - self.assertRaises(AssertionError, transport.write, 'str') + self.assertRaises(TypeError, transport.write, 'str') def test_write_closing(self): transport = _SelectorSocketTransport( @@ -945,6 +945,7 @@ def test_write_ready_closing(self): def test_write_ready_no_data(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) + # This is an internal error. self.assertRaises(AssertionError, transport._write_ready) def test_write_ready_partial(self): @@ -1137,7 +1138,7 @@ def test_write_no_data(self): def test_write_str(self): transport = self._make_one() - self.assertRaises(AssertionError, transport.write, 'str') + self.assertRaises(TypeError, transport.write, 'str') def test_write_closing(self): transport = self._make_one() @@ -1547,13 +1548,13 @@ def test_sendto_error_received_connected(self): def test_sendto_str(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - self.assertRaises(AssertionError, transport.sendto, 'str', ()) + self.assertRaises(TypeError, transport.sendto, 'str', ()) def test_sendto_connected_addr(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) self.assertRaises( - AssertionError, transport.sendto, b'str', ('0.0.0.0', 2)) + ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) def test_sendto_closing(self): transport = _SelectorDatagramTransport( From aad199b0bc1925f78718c693ef6343880329c99b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 29 Nov 2013 09:27:15 -0800 Subject: [PATCH 0813/1502] Add 'shield' to __all__. Fixes issue 91. --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 2a21a4b9..999e9629 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -3,7 +3,7 @@ __all__ = ['coroutine', 'Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', - 'gather', + 'gather', 'shield', ] import collections From ad3119f0a6ea4ecefa4722585ffa1b966a8f4ba9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 29 Nov 2013 15:10:16 -0800 Subject: [PATCH 0814/1502] Remove outdated TODO and NOTES. I have stopeed updating these a long time ago. --- NOTES | 176 ---------------------------------------------------------- TODO | 163 ----------------------------------------------------- 2 files changed, 339 deletions(-) delete mode 100644 NOTES delete mode 100644 TODO diff --git a/NOTES b/NOTES deleted file mode 100644 index 3b94ba96..00000000 --- a/NOTES +++ /dev/null @@ -1,176 +0,0 @@ -Notes from PyCon 2013 sprints -============================= - -- Cancellation. If a task creates several subtasks, and then the - parent task fails, should the subtasks be cancelled? (How do we - even establish the parent/subtask relationship?) - -- Adam Sah suggests that there might be a need for scheduling - (especially when multiple frameworks share an event loop). He - points to lottery scheduling but also mentions that's just one of - the options. However, after posting on python-tulip, it appears - none of the other frameworks have scheduling, and nobody seems to - miss it. - -- Feedback from Bram Cohen (Bittorrent creator) about UDP. He doesn't - think connected UDP is worth supporting, it doesn't do anything - except tell the kernel about the default target address for - sendto(). Basically he says all UDP end points are servers. He - sent me his own UDP event loop so I might glean some tricks from it. - He says we should treat EINTR the same as EAGAIN and friends. (We - should use the exceptions dedicated to errno checking, BTW.) HE - said to make sure we use SO_REUSEADDR (I think we already do). He - said to set the max datagram sizes pretty large (anything larger - than the declared limit is dropped on the floor). He reminds us of - the importance of being able to pick a valid, unused port by binding - to port 0 and then using getsockname(). He has an idea where he's - like to be able to kill all registered callbacks (i.e. Handles) - belonging to a certain "context". I think this can be done at the - application level (you'd have to wrap everything that returns a - Handle and collect these handles in some set or other datastructure) - but if someone thinks it's interesting we could imagine having some - kind of notion of context part of the event loop state, - e.g. associated with a Task (see Cancellation point above). He - brought up uTP (Micro Transport Protocol), a reimplementation of TCP - over UDP with more refined congestion control. - -- Mumblings about UNIX domain sockets and IPv6 addresses being - 4-tuples. The former can be handled by passing in a socket. There - seem to be no real use cases for the latter that can't be dealt with - by passing in suitably esoteric strings for the hostname. - getaddrinfo() will produce the appropriate 4-tuple and connect() - will accept it. - -- Mumblings on the list about add vs. set. - - -Notes from the second Tulip/Twisted meet-up -=========================================== - -Rackspace, 12/11/2012 -Glyph, Brian Warner, David Reid, Duncan McGreggor, others - -Flow control ------------- - -- Pause/resume on transport manages data_received. - -- There's also an API to tell the transport whom to pause when the - write calls are overwhelming it: IConsumer.registerProducer(). - -- There's also something called pipes but it's built on top of the - old interface. - -- Twisted has variations on the basic flow control that I should - ignore. - -Half_close ----------- - -- This sends an EOF after writing some stuff. - -- Can't write any more. - -- Problem with TLS is known (the RFC sadly specifies this behavior). - -- It must be dynamimcally discoverable whether the transport supports - half_close, since the protocol may have to do something different to - make up for its missing (e.g. use chunked encoding). Twisted uses - an interface check for this and also hasattr(trans, 'halfClose') - but a flag (or flag method) is fine too. - -Constructing transport and protocol ------------------------------------ - -- There are good reasons for passing a function to the transport - construction helper that creates the protocol. (You need these - anyway for server-side protocols.) The sequence of events is - something like - - . open socket - . create transport (pass it a socket?) - . create protocol (pass it nothing) - . proto.make_connection(transport); this does: - . self.transport = transport - . self.connection_made(transport) - - But it seems okay to skip make_connection and setting .transport. - Note that make_connection() is a concrete method on the Protocol - implementation base class, while connection_made() is an abstract - method on IProtocol. - -Event Loop ----------- - -- We discussed the sequence of actions in the event loop. I think in the - end we're fine with what Tulip currently does. There are two choices: - - Tulip: - . run ready callbacks until there aren't any left - . poll, adding more callbacks to the ready list - . add now-ready delayed callbacks to the ready list - . go to top - - Tornado: - . run all currently ready callbacks (but not new ones added during this) - . (the rest is the same) - - The difference is that in the Tulip version, CPU bound callbacks - that keep adding more to the queue will starve I/O (and yielding to - other tasks won't actually cause I/O to happen unless you do - e.g. sleep(0.001)). OTOH this may be good because it means there's - less overhead if you frequently split operations in two. - -- I think Twisted does it Tornado style (in a convoluted way :-), but - it may not matter, and it's important to leave this vague so - implementations can do what's best for their platform. (E.g. if the - event loop is built into the OS there are different trade-offs.) - -System call cost ----------------- - -- System calls on MacOS are expensive, on Linux they are cheap. - -- Optimal buffer size ~16K. - -- Try joining small buffer pieces together, but expect to be tuning - this later. - -Futures -------- - -- Futures are the most robust API for async stuff, you can check - errors etc. So let's do this. - -- Just don't implement wait(). - -- For the basics, however, (recv/send, mostly), don't use Futures but use - basic callbacks, transport/protocol style. - -- make_connection() (by any name) can return a Future, it makes it - easier to check for errors. - -- This means revisiting the Tulip proactor branch (IOCP). - -- The semantics of add_done_callback() are fuzzy about in which thread - the callback will be called. (It may be the current thread or - another one.) We don't like that. But always inserting a - call_soon() indirection may be expensive? Glyph suggested changing - the add_done_callback() method name to something else to indicate - the changed promise. - -- Separately, I've been thinking about having two versions of - call_soon() -- a more heavy-weight one to be called from other - threads that also writes a byte to the self-pipe. - -Signals -------- - -- There was a side conversation about signals. A signal handler is - similar to another thread, so probably should use (the heavy-weight - version of) call_soon() to schedule the real callback and not do - anything else. - -- Glyph vaguely recalled some trickiness with the self-pipe. We - should be able to fix this afterwards if necessary, it shouldn't - affect the API design. diff --git a/TODO b/TODO deleted file mode 100644 index c6d4eead..00000000 --- a/TODO +++ /dev/null @@ -1,163 +0,0 @@ -# -*- Mode: text -*- - -TO DO LARGER TASKS - -- Need more examples. - -- Benchmarkable but more realistic HTTP server? - -- Example of using UDP. - -- Write up a tutorial for the scheduling API. - -- More systematic approach to logging. Logger objects? What about - heavy-duty logging, tracking essentially all task state changes? - -- Restructure directory, move demos and benchmarks to subdirectories. - - -TO DO LATER - -- When multiple tasks are accessing the same socket, they should - either get interleaved I/O or an immediate exception; it should not - compromise the integrity of the scheduler or the app or leave a task - hanging. - -- For epoll you probably want to check/(log?) EPOLLHUP and EPOLLERR errors. - -- Add the simplest API possible to run a generator with a timeout. - -- Ensure multiple tasks can do atomic writes to the same pipe (since - UNIX guarantees that short writes to pipes are atomic). - -- Ensure some easy way of distributing accepted connections across tasks. - -- Be wary of thread-local storage. There should be a standard API to - get the current Context (which holds current task, event loop, and - maybe more) and a standard meta-API to change how that standard API - works (i.e. without monkey-patching). - -- See how much of asyncore I've already replaced. - -- Could BufferedReader reuse the standard io module's readers??? - -- Support ZeroMQ "sockets" which are user objects. Though possibly - this can be supported by getting the underlying fd? See - http://mail.python.org/pipermail/python-ideas/2012-October/017532.html - OTOH see - https://github.com/zeromq/pyzmq/blob/master/zmq/eventloop/ioloop.py - -- Study goroutines (again). - -- Benchmarks: http://nichol.as/benchmark-of-python-web-servers - - -FROM OLDER LIST - -- Multiple readers/writers per socket? (At which level? pollster, - eventloop, or scheduler?) - -- Could poll() usefully be an iterator? - -- Do we need to support more epoll and/or kqueue modes/flags/options/etc.? - -- Optimize register/unregister calls away if they cancel each other out? - -- Add explicit wait queue to wait for Task's completion, instead of - callbacks? - -- Look at pyfdpdlib's ioloop.py: - http://code.google.com/p/pyftpdlib/source/browse/trunk/pyftpdlib/lib/ioloop.py - - -MISTAKES I MADE - -- Forgetting yield from. (E.g.: scheduler.sleep(1); listener.accept().) - -- Forgot to add bare yield at end of internal function, after block(). - -- Forgot to call add_done_callback(). - -- Forgot to pass an undoer to block(), bug only found when cancelled. - -- Subtle accounting mistake in a callback. - -- Used context.eventloop from a different thread, forgetting about TLS. - -- Nasty race: eventloop.ready may contain both an I/O callback and a - cancel callback. How to avoid? Keep the DelayedCall in ready. Is - that enough? - -- If a toplevel task raises an error it just stops and nothing is logged - unless you have debug logging on. This confused me. (Then again, - previously I logged whenever a task raised an error, and that was too - chatty...) - -- Forgot to set the connection socket returned by accept() in - nonblocking mode. - -- Nastiest so far (cost me about a day): A race condition in - call_in_thread() where the Future's done_callback (which was - task.unblock()) would run immediately at the time when - add_done_callback() was called, and this screwed over the task - state. Solution: wrap the callback in eventloop.call_later(). - Ironically, I had a comment stating there might be a race condition. - -- Another bug where I was calling unblock() for the current thread - immediately after calling block(), before yielding. - -- readexactly() wasn't checking for EOF, so could be looping. - (Worse, the first fix I attempted was wrong.) - -- Spent a day trying to understand why a tentative patch trying to - move the recv() implementation into the eventloop (or the pollster) - resulted in problems cancelling a recv() call. Ultimately the - problem is that the cancellation mechanism is part of the coroutine - scheduler, which simply throws an exception into a task when it next - runs, and there isn't anything to be interrupted in the eventloop; - but the eventloop still has a reader registered (which will never - fire because I suspended the server -- that's my test case :-). - Then, the eventloop keeps running until the last file descriptor is - unregistered. What contributed to this disaster? - * I didn't build the whole infrastructure, just played with recv() - * I don't have unittests - * I don't have good logging to see what is going - -- In sockets.py, in some SSL error handling code, used the wrong - variable (sock instead of sslsock). A linter would have found this. - -- In polling.py, in KqueuePollster.register_writer(), a copy/paste - error where I was testing for "if fd not in self.readers" instead of - writers. This only came out when I had both a reader and a writer - for the same fd. - -- Submitted some changes prematurely (forgot to pass the filename on - hg ci). - -- Forgot again that shutdown(SHUT_WR) on an ssl socket does not work - as I expected. I ran into this with the origininal sockets.py and - again in transport.py. - -- Having the same callback for both reading and writing has a problem: - it may be scheduled twice, and if the first call closes the socket, - the second runs into trouble. - - -MISTAKES I MADE IN TULIP V2 - -- Nice one: Future._schedule_callbacks() wasn't scheduling any callbacks. - Spot the bug in these four lines: - - def _schedule_callbacks(self): - callbacks = self._callbacks[:] - self._callbacks[:] = [] - for callback in self._callbacks: - self._event_loop.call_soon(callback, self) - - The good news is that I found it with a unittest (albeit not the - unittest intended to exercise this particular method :-( ). - -- In _make_self_pipe_or_sock(), called _pollster.register_reader() - instead of add_reader(), trying to optimize something but breaking - things instead (since the -- internal -- API of register_reader() - had changed). From cb72fe9527bb181617de8e25348ee9b9010e2203 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 30 Nov 2013 15:33:33 -0800 Subject: [PATCH 0815/1502] Use Interface instead of ABC. Fixes issue 19726. --- asyncio/events.py | 4 ++-- asyncio/protocols.py | 8 ++++---- asyncio/transports.py | 10 +++++----- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index d429686b..c10faa75 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -234,7 +234,7 @@ def connect_read_pipe(self, protocol_factory, pipe): protocol_factory should instantiate object with Protocol interface. pipe is file-like object already switched to nonblocking. Return pair (transport, protocol), where transport support - ReadTransport ABC""" + ReadTransport interface.""" # The reason to accept file-like object instead of just file descriptor # is: we need to own pipe and close it at transport finishing # Can got complicated errors if pass f.fileno(), @@ -247,7 +247,7 @@ def connect_write_pipe(self, protocol_factory, pipe): protocol_factory should instantiate object with BaseProtocol interface. Pipe is file-like object already switched to nonblocking. Return pair (transport, protocol), where transport support - WriteTransport ABC""" + WriteTransport interface.""" # The reason to accept file-like object instead of just file descriptor # is: we need to own pipe and close it at transport finishing # Can got complicated errors if pass f.fileno(), diff --git a/asyncio/protocols.py b/asyncio/protocols.py index eb94fb6f..1b8870c6 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -4,7 +4,7 @@ class BaseProtocol: - """ABC for base protocol class. + """Common base class for protocol interfaces. Usually user implements protocols that derived from BaseProtocol like Protocol or ProcessProtocol. @@ -59,7 +59,7 @@ def resume_writing(self): class Protocol(BaseProtocol): - """ABC representing a protocol. + """Interface for stream protocol. The user should implement this interface. They can inherit from this class but don't need to. The implementations here do @@ -95,7 +95,7 @@ def eof_received(self): class DatagramProtocol(BaseProtocol): - """ABC representing a datagram protocol.""" + """Interface for datagram protocol.""" def datagram_received(self, data, addr): """Called when some datagram is received.""" @@ -108,7 +108,7 @@ def error_received(self, exc): class SubprocessProtocol(BaseProtocol): - """ABC representing a protocol for subprocess calls.""" + """Interface for protocol for subprocess calls.""" def pipe_data_received(self, fd, data): """Called when the subprocess writes data into stdout/stderr pipe. diff --git a/asyncio/transports.py b/asyncio/transports.py index 98f92247..86b850e9 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -4,7 +4,7 @@ class BaseTransport: - """Base ABC for transports.""" + """Base class for transports.""" def __init__(self, extra=None): if extra is None: @@ -27,7 +27,7 @@ def close(self): class ReadTransport(BaseTransport): - """ABC for read-only transports.""" + """Interface for read-only transports.""" def pause_reading(self): """Pause the receiving end. @@ -47,7 +47,7 @@ def resume_reading(self): class WriteTransport(BaseTransport): - """ABC for write-only transports.""" + """Interface for write-only transports.""" def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -115,7 +115,7 @@ def abort(self): class Transport(ReadTransport, WriteTransport): - """ABC representing a bidirectional transport. + """Interface representing a bidirectional transport. There may be several implementations, but typically, the user does not implement new transports; rather, the platform provides some @@ -137,7 +137,7 @@ class Transport(ReadTransport, WriteTransport): class DatagramTransport(BaseTransport): - """ABC for datagram (UDP) transports.""" + """Interface for datagram (UDP) transports.""" def sendto(self, data, addr=None): """Send data to the transport. From 946c4a9f942cedb087ac93281401d06f81867189 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 15:45:28 -0800 Subject: [PATCH 0816/1502] Make writelines() join the lines before calling write(). Fixes issue 92. Also added some missing NotImplementedError tests. --- asyncio/transports.py | 7 +++---- tests/test_transports.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/asyncio/transports.py b/asyncio/transports.py index 86b850e9..de2facb0 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -85,11 +85,10 @@ def write(self, data): def writelines(self, list_of_data): """Write a list (or any iterable) of data bytes to the transport. - The default implementation just calls write() for each item in - the list/iterable. + The default implementation concatenates the arguments and + calls write() on the result. """ - for data in list_of_data: - self.write(data) + self.write(b''.join(list_of_data)) def write_eof(self): """Close the write end after flushing buffered data. diff --git a/tests/test_transports.py b/tests/test_transports.py index f96445c1..29393b52 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -24,12 +24,18 @@ def test_writelines(self): transport = transports.Transport() transport.write = unittest.mock.Mock() - transport.writelines(['line1', 'line2', 'line3']) - self.assertEqual(3, transport.write.call_count) + transport.writelines([b'line1', + bytearray(b'line2'), + memoryview(b'line3')]) + self.assertEqual(1, transport.write.call_count) + transport.write.assert_called_with(b'line1line2line3') def test_not_implemented(self): transport = transports.Transport() + self.assertRaises(NotImplementedError, + transport.set_write_buffer_limits) + self.assertRaises(NotImplementedError, transport.get_write_buffer_size) self.assertRaises(NotImplementedError, transport.write, 'data') self.assertRaises(NotImplementedError, transport.write_eof) self.assertRaises(NotImplementedError, transport.can_write_eof) From f180ac149078622309d012c484d7f8dac7d3737f Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 17:26:25 -0800 Subject: [PATCH 0817/1502] Incorporate selectors.py refactoring from CPython repo. --- asyncio/selectors.py | 112 ++++++++++++++++++++++++---------------- asyncio/test_utils.py | 14 +++++ tests/test_selectors.py | 2 +- 3 files changed, 83 insertions(+), 45 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 261fac6c..c533f13d 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -64,7 +64,7 @@ def __iter__(self): class BaseSelector(metaclass=ABCMeta): - """Base selector class. + """Selector abstract base class. A selector supports registering file objects to be monitored for specific I/O events. @@ -78,12 +78,7 @@ class BaseSelector(metaclass=ABCMeta): performant implementation on the current platform. """ - def __init__(self): - # this maps file descriptors to keys - self._fd_to_key = {} - # read-only mapping returned by get_map() - self._map = _SelectorMapping(self) - + @abstractmethod def register(self, fileobj, events, data=None): """Register a file object. @@ -95,18 +90,9 @@ def register(self, fileobj, events, data=None): Returns: SelectorKey instance """ - if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {!r}".format(events)) - - key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) - - if key.fd in self._fd_to_key: - raise KeyError("{!r} (FD {}) is already " - "registered".format(fileobj, key.fd)) - - self._fd_to_key[key.fd] = key - return key + raise NotImplementedError + @abstractmethod def unregister(self, fileobj): """Unregister a file object. @@ -116,11 +102,7 @@ def unregister(self, fileobj): Returns: SelectorKey instance """ - try: - key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) - except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None - return key + raise NotImplementedError def modify(self, fileobj, events, data=None): """Change a registered file object monitored events or attached data. @@ -133,19 +115,8 @@ def modify(self, fileobj, events, data=None): Returns: SelectorKey instance """ - # TODO: Subclasses can probably optimize this even further. - try: - key = self._fd_to_key[_fileobj_to_fd(fileobj)] - except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None - if events != key.events: - self.unregister(fileobj) - key = self.register(fileobj, events, data) - elif data != key.data: - # Use a shortcut to update the data. - key = key._replace(data=data) - self._fd_to_key[key.fd] = key - return key + self.unregister(fileobj) + return self.register(fileobj, events, data) @abstractmethod def select(self, timeout=None): @@ -164,14 +135,14 @@ def select(self, timeout=None): list of (key, events) for ready file objects `events` is a bitwise mask of EVENT_READ|EVENT_WRITE """ - raise NotImplementedError() + raise NotImplementedError def close(self): """Close the selector. This must be called to make sure that any underlying resource is freed. """ - self._fd_to_key.clear() + pass def get_key(self, fileobj): """Return the key associated to a registered file object. @@ -179,14 +150,16 @@ def get_key(self, fileobj): Returns: SelectorKey for this file object """ + mapping = self.get_map() try: - return self._fd_to_key[_fileobj_to_fd(fileobj)] + return mapping[fileobj] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None + @abstractmethod def get_map(self): """Return a mapping of file objects to selector keys.""" - return self._map + raise NotImplementedError def __enter__(self): return self @@ -194,6 +167,57 @@ def __enter__(self): def __exit__(self, *args): self.close() + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already " + "registered".format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[_fileobj_to_fd(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + + def get_map(self): + return self._map + def _key_from_fd(self, fd): """Return the key associated to a given file descriptor. @@ -209,7 +233,7 @@ def _key_from_fd(self, fd): return None -class SelectSelector(BaseSelector): +class SelectSelector(_BaseSelectorImpl): """Select-based selector.""" def __init__(self): @@ -262,7 +286,7 @@ def select(self, timeout=None): if hasattr(select, 'poll'): - class PollSelector(BaseSelector): + class PollSelector(_BaseSelectorImpl): """Poll-based selector.""" def __init__(self): @@ -306,7 +330,7 @@ def select(self, timeout=None): if hasattr(select, 'epoll'): - class EpollSelector(BaseSelector): + class EpollSelector(_BaseSelectorImpl): """Epoll-based selector.""" def __init__(self): @@ -358,7 +382,7 @@ def close(self): if hasattr(select, 'kqueue'): - class KqueueSelector(BaseSelector): + class KqueueSelector(_BaseSelectorImpl): """Kqueue-based selector.""" def __init__(self): diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index c278dd17..d7d84424 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -142,9 +142,23 @@ def make_test_protocol(base): class TestSelector(selectors.BaseSelector): + def __init__(self): + self.keys = {} + + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key + + def unregister(self, fileobj): + return self.keys.pop(fileobj) + def select(self, timeout): return [] + def get_map(self): + return self.keys + class TestLoop(base_events.BaseEventLoop): """Loop for unittests. diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 78012897..0519d75a 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -6,7 +6,7 @@ from asyncio import selectors -class FakeSelector(selectors.BaseSelector): +class FakeSelector(selectors._BaseSelectorImpl): """Trivial non-abstract subclass of BaseSelector.""" def select(self, timeout=None): From 971bfd402dc028933bcee1d21c03b4891ef01880 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 18:31:07 -0800 Subject: [PATCH 0818/1502] Make the new writelines() work for Python 3.3. --- asyncio/transports.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/asyncio/transports.py b/asyncio/transports.py index de2facb0..c2feb93d 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -1,5 +1,9 @@ """Abstract Transport class.""" +import sys + +PY34 = sys.version_info >= (3, 4) + __all__ = ['ReadTransport', 'WriteTransport', 'Transport'] @@ -88,6 +92,11 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ + if not PY34: + # In Python 3.3, bytes.join() doesn't handle memoryview. + list_of_data = ( + bytes(data) if isinstance(data, memoryview) else data + for data in list_of_data) self.write(b''.join(list_of_data)) def write_eof(self): From 4f73894bc28470c90a6fee47ca06e720078660e4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 18:35:00 -0800 Subject: [PATCH 0819/1502] Upstream tweaks to locks docs. --- asyncio/locks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 2d458a9b..6cd6779e 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -79,7 +79,7 @@ def __repr__(self): return '<{} [{}]>'.format(res[1:-1], extra) def locked(self): - """Return true if lock is acquired.""" + """Return True if lock is acquired.""" return self._locked @tasks.coroutine @@ -138,7 +138,7 @@ def __iter__(self): class Event: - """An Event implementation, our equivalent to threading.Event. + """An Event implementation, asynchronous equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set to true with the set() method and reset to false with the clear() method. @@ -162,7 +162,7 @@ def __repr__(self): return '<{} [{}]>'.format(res[1:-1], extra) def is_set(self): - """Return true if and only if the internal flag is true.""" + """Return True if and only if the internal flag is true.""" return self._value def set(self): @@ -204,7 +204,7 @@ def wait(self): class Condition: - """A Condition implementation, our equivalent to threading.Condition. + """A Condition implementation, asynchronous equivalent to threading.Condition. This class implements condition variable objects. A condition variable allows one or more coroutines to wait until they are notified by another From a0a520dfcf8b20eea51bc19d10decc63d27503de Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 11:04:49 -0800 Subject: [PATCH 0820/1502] Accept bytearray and memoryview, and replace asserts with raises. --- asyncio/proactor_events.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index ce226b9b..22e20dc9 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -24,7 +24,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._sock = sock self._protocol = protocol self._server = server - self._buffer = [] + self._buffer = [] # TODO: Use bytearray like selector_events.py. self._read_fut = None self._write_fut = None self._conn_lost = 0 @@ -95,12 +95,15 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.call_soon(self._loop_reading) def pause_reading(self): - assert not self._closing, 'Cannot pause_reading() when closing' - assert not self._paused, 'Already paused' + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') self._paused = True def resume_reading(self): - assert self._paused, 'Not paused' + if not self._paused: + raise RuntimeError('Not paused') self._paused = False if self._closing: return @@ -155,9 +158,11 @@ class _ProactorWritePipeTransport(_ProactorBasePipeTransport, """Transport for write pipes.""" def write(self, data): - assert isinstance(data, bytes), repr(data) + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) if self._eof_written: - raise IOError('write_eof() already called') + raise RuntimeError('write_eof() already called') if not data: return @@ -167,7 +172,7 @@ def write(self, data): logger.warning('socket.send() raised exception.') self._conn_lost += 1 return - self._buffer.append(data) + self._buffer.append(bytes(data)) if self._write_fut is None: self._loop_writing() @@ -330,7 +335,8 @@ def _write_to_self(self): self._csock.send(b'x') def _start_serving(self, protocol_factory, sock, ssl=None, server=None): - assert not ssl, 'IocpEventLoop is incompatible with SSL.' + if ssl: + raise ValueError('IocpEventLoop is incompatible with SSL.') def loop(f=None): try: From 5180e423552c25381aef46116c983dd92eff98e4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 2 Dec 2013 16:04:02 -0800 Subject: [PATCH 0821/1502] Change proactor buffer management. --- asyncio/proactor_events.py | 31 +++++++++++++++++++++++-------- tests/test_proactor_events.py | 14 +++++++------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 22e20dc9..43de0a9b 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -24,7 +24,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._sock = sock self._protocol = protocol self._server = server - self._buffer = [] # TODO: Use bytearray like selector_events.py. + self._buffer = None # None or bytearray. self._read_fut = None self._write_fut = None self._conn_lost = 0 @@ -63,7 +63,7 @@ def _force_close(self, exc): if self._read_fut: self._read_fut.cancel() self._write_fut = self._read_fut = None - self._buffer = [] + self._buffer = None self._loop.call_soon(self._call_connection_lost, exc) def _call_connection_lost(self, exc): @@ -172,18 +172,33 @@ def write(self, data): logger.warning('socket.send() raised exception.') self._conn_lost += 1 return - self._buffer.append(bytes(data)) - if self._write_fut is None: - self._loop_writing() - def _loop_writing(self, f=None): + # Observable states: + # 1. IDLE: _write_fut and _buffer both None + # 2. WRITING: _write_fut set; _buffer None + # 3. BACKED UP: _write_fut set; _buffer a bytearray + # We always copy the data, so the caller can't modify it + # while we're still waiting for the I/O to happen. + if self._write_fut is None: # IDLE -> WRITING + assert self._buffer is None + # Pass a copy, except if it's already immutable. + self._loop_writing(data=bytes(data)) + elif not self._buffer: # WRITING -> BACKED UP + # Make a mutable copy which we can extend. + self._buffer = bytearray(data) + else: # BACKED UP + # Append to buffer (also copies). + self._buffer.extend(data) + + def _loop_writing(self, f=None, data=None): try: assert f is self._write_fut self._write_fut = None if f: f.result() - data = b''.join(self._buffer) - self._buffer = [] + if data is None: + data = self._buffer + self._buffer = None if not data: if self._closing: self._loop.call_soon(self._call_connection_lost, None) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 5a2a51c4..9964f425 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -111,8 +111,8 @@ def test_write(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._loop_writing = unittest.mock.Mock() tr.write(b'data') - self.assertEqual(tr._buffer, [b'data']) - self.assertTrue(tr._loop_writing.called) + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') def test_write_no_data(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -124,12 +124,12 @@ def test_write_more(self): tr._write_fut = unittest.mock.Mock() tr._loop_writing = unittest.mock.Mock() tr.write(b'data') - self.assertEqual(tr._buffer, [b'data']) + self.assertEqual(tr._buffer, b'data') self.assertFalse(tr._loop_writing.called) def test_loop_writing(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._buffer = [b'da', b'ta'] + tr._buffer = bytearray(b'data') tr._loop_writing() self.loop._proactor.send.assert_called_with(self.sock, b'data') self.loop._proactor.send.return_value.add_done_callback.\ @@ -150,7 +150,7 @@ def test_loop_writing_err(self, m_log): tr.write(b'data') tr.write(b'data') tr.write(b'data') - self.assertEqual(tr._buffer, []) + self.assertEqual(tr._buffer, None) m_log.warning.assert_called_with('socket.send() raised exception.') def test_loop_writing_stop(self): @@ -226,7 +226,7 @@ def test_force_close(self): write_fut.cancel.assert_called_with() test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) - self.assertEqual([], tr._buffer) + self.assertEqual(None, tr._buffer) self.assertEqual(tr._conn_lost, 1) def test_force_close_idempotent(self): @@ -243,7 +243,7 @@ def test_fatal_error_2(self): test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) - self.assertEqual([], tr._buffer) + self.assertEqual(None, tr._buffer) def test_call_connection_lost(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) From 9c0e2f8dc0886084596d79be1a96160d03438dce Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 3 Dec 2013 11:33:56 -0800 Subject: [PATCH 0822/1502] Use try/finally to close loop in examples. Add --iocp to fetch3.py. --- examples/child_process.py | 6 ++++-- examples/fetch0.py | 5 ++++- examples/fetch1.py | 5 ++++- examples/fetch2.py | 5 ++++- examples/fetch3.py | 14 +++++++++++--- examples/sink.py | 6 ++++-- examples/source.py | 6 ++++-- examples/source1.py | 6 ++++-- 8 files changed, 39 insertions(+), 14 deletions(-) diff --git a/examples/child_process.py b/examples/child_process.py index 8a7ed0e6..4410414d 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -124,5 +124,7 @@ def writeall(fd, buf): asyncio.set_event_loop(loop) else: loop = asyncio.get_event_loop() - loop.run_until_complete(main(loop)) - loop.close() + try: + loop.run_until_complete(main(loop)) + finally: + loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py index ac4d5d95..180fcf26 100644 --- a/examples/fetch0.py +++ b/examples/fetch0.py @@ -24,7 +24,10 @@ def fetch(): def main(): loop = get_event_loop() - body = loop.run_until_complete(fetch()) + try: + body = loop.run_until_complete(fetch()) + finally: + loop.close() print(body.decode('latin-1'), end='') diff --git a/examples/fetch1.py b/examples/fetch1.py index 6d99262b..8dbb6e47 100644 --- a/examples/fetch1.py +++ b/examples/fetch1.py @@ -67,7 +67,10 @@ def fetch(url, verbose=True): def main(): loop = get_event_loop() - body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() print(body.decode('latin-1'), end='') diff --git a/examples/fetch2.py b/examples/fetch2.py index 0899123f..7617b59b 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -130,7 +130,10 @@ def fetch(url, verbose=True): def main(): loop = get_event_loop() - body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() sys.stdout.buffer.write(body) diff --git a/examples/fetch3.py b/examples/fetch3.py index fa9ebb01..780222ba 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -1,7 +1,7 @@ """Fetch one URL and write its content to stdout. This version adds a primitive connection pool, redirect following and -chunked transfer-encoding. +chunked transfer-encoding. It also supports a --iocp flag. """ import sys @@ -209,8 +209,16 @@ def fetch(url, verbose=True, max_redirect=10): def main(): - loop = get_event_loop() - body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + if '--iocp' in sys.argv: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() sys.stdout.buffer.write(body) diff --git a/examples/sink.py b/examples/sink.py index 4b223fdd..d4866e2f 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -84,8 +84,10 @@ def main(): set_event_loop(loop) else: loop = get_event_loop() - loop.run_until_complete(start(loop, args.host, args.port)) - loop.close() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() if __name__ == '__main__': diff --git a/examples/source.py b/examples/source.py index 9aff8c1d..7fd11fb0 100644 --- a/examples/source.py +++ b/examples/source.py @@ -90,8 +90,10 @@ def main(): set_event_loop(loop) else: loop = get_event_loop() - loop.run_until_complete(start(loop, args.host, args.port)) - loop.close() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() if __name__ == '__main__': diff --git a/examples/source1.py b/examples/source1.py index b8f89790..6802e963 100644 --- a/examples/source1.py +++ b/examples/source1.py @@ -88,8 +88,10 @@ def main(): set_event_loop(loop) else: loop = get_event_loop() - loop.run_until_complete(start(loop, args)) - loop.close() + try: + loop.run_until_complete(start(loop, args)) + finally: + loop.close() if __name__ == '__main__': From 23caf0680f0075ea67bb13698f23feffa40abb3c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 3 Dec 2013 11:35:48 -0800 Subject: [PATCH 0823/1502] Support write flow control in proactor transport. --- asyncio/proactor_events.py | 67 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 43de0a9b..979bc25f 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -30,6 +30,8 @@ def __init__(self, loop, sock, protocol, waiter=None, self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False + self._protocol_paused = False + self.set_write_buffer_limits() if self._server is not None: self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) @@ -82,6 +84,53 @@ def _call_connection_lost(self, exc): server.detach(self) self._server = None + # XXX The next four methods are nearly identical to corresponding + # ones in _SelectorTransport. Maybe refactor buffer management to + # share the implementations? (Also these are really only needed + # by _ProactorWritePipeTransport but since _buffer is defined on + # the base class I am putting it here for now.) + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception: + logger.exception('pause_writing() failed') + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception: + logger.exception('resume_writing() failed') + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + # NOTE: This doesn't take into account data already passed to + # send() even if send() hasn't finished yet. + if not self._buffer: + return 0 + return len(self._buffer) + class _ProactorReadPipeTransport(_ProactorBasePipeTransport, transports.ReadTransport): @@ -183,12 +232,18 @@ def write(self, data): assert self._buffer is None # Pass a copy, except if it's already immutable. self._loop_writing(data=bytes(data)) + # XXX Should we pause the protocol at this point + # if len(data) > self._high_water? (That would + # require keeping track of the number of bytes passed + # to a send() that hasn't finished yet.) elif not self._buffer: # WRITING -> BACKED UP # Make a mutable copy which we can extend. self._buffer = bytearray(data) + self._maybe_pause_protocol() else: # BACKED UP # Append to buffer (also copies). self._buffer.extend(data) + self._maybe_pause_protocol() def _loop_writing(self, f=None, data=None): try: @@ -204,9 +259,15 @@ def _loop_writing(self, f=None, data=None): self._loop.call_soon(self._call_connection_lost, None) if self._eof_written: self._sock.shutdown(socket.SHUT_WR) - return - self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) + else: + self._write_fut = self._loop._proactor.send(self._sock, data) + self._write_fut.add_done_callback(self._loop_writing) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: From 482ebd7a955c9ade3ce9d9e05ac26c33d5eaef3a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 4 Dec 2013 08:04:05 -0800 Subject: [PATCH 0824/1502] Fix docstring typo. --- asyncio/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/events.py b/asyncio/events.py index c10faa75..62400195 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -99,7 +99,7 @@ def __ne__(self, other): class AbstractServer: - """Abstract server returned by create_service().""" + """Abstract server returned by create_server().""" def close(self): """Stop serving. This leaves existing connections open.""" From 1081e2ec9c79906d6d94617bfb372e3e38fcb813 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 5 Dec 2013 14:22:58 -0800 Subject: [PATCH 0825/1502] Set SA_RESTART to limit EINTR occurrences. (from CPython repo, by C.F. Natali.) --- asyncio/unix_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index b611efd1..eb3fb9f9 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -74,6 +74,8 @@ def add_signal_handler(self, sig, callback, *args): try: signal.signal(sig, self._handle_signal) + # Set SA_RESTART to limit EINTR occurrences. + signal.siginterrupt(sig, False) except OSError as exc: del self._signal_handlers[sig] if not self._signal_handlers: From 1fa03f4040a65780d5912b45c08c7e96b908f747 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 5 Dec 2013 15:34:27 -0800 Subject: [PATCH 0826/1502] Add Task.current_task([loop]). By Vladimir Kryachko. --- asyncio/tasks.py | 20 ++++++++++++++++++++ tests/test_tasks.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 999e9629..cd9718f5 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -122,6 +122,22 @@ class Task(futures.Future): # Weak set containing all tasks alive. _all_tasks = weakref.WeakSet() + # Dictionary containing tasks that are currently active in + # all running event loops. {EventLoop: Task} + _current_tasks = {} + + @classmethod + def current_task(cls, loop=None): + """Return the currently running task in an event loop or None. + + By default the current task for the current event loop is returned. + + None is returned when called not in the context of a Task. + """ + if loop is None: + loop = events.get_event_loop() + return cls._current_tasks.get(loop) + @classmethod def all_tasks(cls, loop=None): """Return a set of all tasks for an event loop. @@ -252,6 +268,8 @@ def _step(self, value=None, exc=None): self._must_cancel = False coro = self._coro self._fut_waiter = None + + self.__class__._current_tasks[self._loop] = self # Call either coro.throw(exc) or coro.send(value). try: if exc is not None: @@ -302,6 +320,8 @@ def _step(self, value=None, exc=None): self._step, None, RuntimeError( 'Task got bad yield: {!r}'.format(result))) + finally: + self.__class__._current_tasks.pop(self._loop) self = None def _wakeup(self, future): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 8f0d0815..5470da15 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1113,6 +1113,42 @@ def coro(): self.assertEqual(res, 'test') self.assertIsNone(t2.result()) + def test_current_task(self): + self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + @tasks.coroutine + def coro(loop): + self.assertTrue(tasks.Task.current_task(loop=loop) is task) + + task = tasks.Task(coro(self.loop), loop=self.loop) + self.loop.run_until_complete(task) + self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + + def test_current_task_with_interleaving_tasks(self): + self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + + fut1 = futures.Future(loop=self.loop) + fut2 = futures.Future(loop=self.loop) + + @tasks.coroutine + def coro1(loop): + self.assertTrue(tasks.Task.current_task(loop=loop) is task1) + yield from fut1 + self.assertTrue(tasks.Task.current_task(loop=loop) is task1) + fut2.set_result(True) + + @tasks.coroutine + def coro2(loop): + self.assertTrue(tasks.Task.current_task(loop=loop) is task2) + fut1.set_result(True) + yield from fut2 + self.assertTrue(tasks.Task.current_task(loop=loop) is task2) + + task1 = tasks.Task(coro1(self.loop), loop=self.loop) + task2 = tasks.Task(coro2(self.loop), loop=self.loop) + + self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop)) + self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + # Some thorough tests for cancellation propagation through # coroutines, tasks and wait(). From cfabd0335542a3bfc8e18b83d2e971c71c88ba97 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 6 Dec 2013 12:59:35 -0800 Subject: [PATCH 0827/1502] SSL hostname checking changes from CPython repo by Christian Heimes. --- asyncio/selector_events.py | 25 ++++--- asyncio/test_utils.py | 6 +- examples/sink.py | 4 +- tests/keycert3.pem | 73 +++++++++++++++++++ tests/pycacert.pem | 78 +++++++++++++++++++++ tests/ssl_cert.pem | 15 ++++ tests/ssl_key.pem | 16 +++++ tests/test_events.py | 140 ++++++++++++++++++++++++++++++++----- update_stdlib.sh | 2 +- 9 files changed, 325 insertions(+), 34 deletions(-) create mode 100644 tests/keycert3.pem create mode 100644 tests/pycacert.pem create mode 100644 tests/ssl_cert.pem create mode 100644 tests/ssl_key.pem diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 93efddc9..19caf79d 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -583,7 +583,8 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, # cadefault=True. if hasattr(ssl, '_create_stdlib_context'): sslcontext = ssl._create_stdlib_context( - cert_reqs=ssl.CERT_REQUIRED) + cert_reqs=ssl.CERT_REQUIRED, + check_hostname=bool(server_hostname)) else: # Fallback for Python 3.3. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) @@ -639,17 +640,19 @@ def _on_handshake(self): self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) - # Verify hostname if requested. peercert = self._sock.getpeercert() - if (self._server_hostname and - self._sslcontext.verify_mode != ssl.CERT_NONE): - try: - ssl.match_hostname(peercert, self._server_hostname) - except Exception as exc: - self._sock.close() - if self._waiter is not None: - self._waiter.set_exception(exc) - return + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return # Add extra info that becomes available after handshake. self._extra.update(peercert=peercert, diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index d7d84424..131a5460 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -88,15 +88,15 @@ def handle_error(self, request, client_address): class SSLWSGIServer(SilentWSGIServer): def finish_request(self, request, client_address): # The relative location of our test directory (which - # contains the sample key and certificate files) differs + # contains the ssl key and certificate files) differs # between the stdlib and stand-alone Tulip/asyncio. # Prefer our own if we can find it. here = os.path.join(os.path.dirname(__file__), '..', 'tests') if not os.path.isdir(here): here = os.path.join(os.path.dirname(os.__file__), 'test', 'test_asyncio') - keyfile = os.path.join(here, 'sample.key') - certfile = os.path.join(here, 'sample.crt') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') ssock = ssl.wrap_socket(request, keyfile=keyfile, certfile=certfile, diff --git a/examples/sink.py b/examples/sink.py index d4866e2f..d362cbb2 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -66,8 +66,8 @@ def start(loop, host, port): sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslctx.options |= ssl.OP_NO_SSLv2 sslctx.load_cert_chain( - certfile=os.path.join(here, 'sample.crt'), - keyfile=os.path.join(here, 'sample.key')) + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) server = yield from loop.create_server(Service, host, port, ssl=sslctx) dprint('serving TLS' if sslctx else 'serving', diff --git a/tests/keycert3.pem b/tests/keycert3.pem new file mode 100644 index 00000000..5bfa62c4 --- /dev/null +++ b/tests/keycert3.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP +jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM +9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ +aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe +yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j +y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+ +AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW +5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL +9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9 +1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT +DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh +1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m +JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3 +RnJdHOMXWem7/w== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443281 (0xb09264b1f2da21d1) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d: + 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb: + c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99: + 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c: + f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93: + 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23: + f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5: + af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6: + 21:82:a5:3c:88:e5:be:1b:b1 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a: + e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93: + f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13: + e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92: + d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59: + 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8: + ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1: + 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75: + 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96: + 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48: + 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a: + f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6: + 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41: + a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb: + fc:a9:94:71 +-----BEGIN CERTIFICATE----- +MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv +c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C +tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola +N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1 +TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR +iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG +xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo +5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv +mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF +YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh +2EJ36/yplHE= +-----END CERTIFICATE----- diff --git a/tests/pycacert.pem b/tests/pycacert.pem new file mode 100644 index 00000000..09b1f3e0 --- /dev/null +++ b/tests/pycacert.pem @@ -0,0 +1,78 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 12723342612721443280 (0xb09264b1f2da21d0) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Jan 2 19:47:07 2023 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2: + 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4: + e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f: + e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f: + 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf: + 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d: + a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3: + e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4: + 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf: + 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c: + e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6: + c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a: + cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01: + 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87: + 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f: + 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14: + e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4: + c5:4d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + X509v3 Authority Key Identifier: + keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6: + 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d: + a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95: + 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17: + 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c: + 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4: + fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7: + 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24: + 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33: + 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61: + ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f: + 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64: + b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb: + 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3: + 5e:58:c8:9e +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx +OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV +q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/ +AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA +Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni +0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx +6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w +HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2 +2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB +AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4 +QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1 +Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O +JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR +f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf +9mmvtk57HVjsO6lTo15YyJ4= +-----END CERTIFICATE----- diff --git a/tests/ssl_cert.pem b/tests/ssl_cert.pem new file mode 100644 index 00000000..47a7d7e3 --- /dev/null +++ b/tests/ssl_cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV +BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u +IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw +MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH +Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k +YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7 +6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt +pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd +BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G +lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1 +CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX +-----END CERTIFICATE----- diff --git a/tests/ssl_key.pem b/tests/ssl_key.pem new file mode 100644 index 00000000..3fd3bbd5 --- /dev/null +++ b/tests/ssl_key.pem @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm +LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0 +ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP +USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt +CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq +SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK +UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y +BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ +ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5 +oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik +eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F +0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS +x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/ +SPIXQuT8RMPDVNQ= +-----END PRIVATE KEY----- diff --git a/tests/test_events.py b/tests/test_events.py index 18411ecc..d3f32d75 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -17,7 +17,7 @@ import errno import unittest import unittest.mock -from test.support import find_unused_port, IPV6_ENABLED +from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR from asyncio import futures @@ -30,10 +30,27 @@ from asyncio import locks +def data_file(filename): + if hasattr(support, 'TEST_HOME_DIR'): + fullname = os.path.join(support.TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') + + class MyProto(protocols.Protocol): done = None def __init__(self, loop=None): + self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: @@ -523,7 +540,7 @@ def test_create_ssl_connection(self): def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: - port = find_unused_port() + port = support.find_unused_port() f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address, local_addr=(httpd.address[0], port)) @@ -587,6 +604,20 @@ def factory(): # close server server.close() + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + + f = self.loop.create_server( + factory, '127.0.0.1', 0, ssl=sslcontext) + + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): proto = None @@ -602,19 +633,7 @@ def factory(): proto = MyProto(loop=self.loop) return proto - here = os.path.dirname(__file__) - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.load_cert_chain( - certfile=os.path.join(here, 'sample.crt'), - keyfile=os.path.join(here, 'sample.key')) - - f = self.loop.create_server( - factory, '127.0.0.1', 0, ssl=sslcontext) - - server = self.loop.run_until_complete(f) - sock = server.sockets[0] - host, port = sock.getsockname() - self.assertEqual(host, '127.0.0.1') + server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY) f_c = self.loop.create_connection(ClientMyProto, host, port, ssl=test_utils.dummy_ssl_context()) @@ -646,6 +665,93 @@ def factory(): # stop serving server.close() + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verify_failed(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_match_failed(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # incorrect server_hostname + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with self.assertRaisesRegex(ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verified(self): + proto = None + + def factory(): + nonlocal proto + proto = MyProto(loop=self.loop) + return proto + + server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + def test_create_server_sock(self): proto = futures.Future(loop=self.loop) @@ -688,7 +794,7 @@ def test_create_server_addr_in_use(self): server.close() - @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_server_dual_stack(self): f_proto = futures.Future(loop=self.loop) @@ -700,7 +806,7 @@ def connection_made(self, transport): try_count = 0 while True: try: - port = find_unused_port() + port = support.find_unused_port() f = self.loop.create_server(TestMyProto, host=None, port=port) server = self.loop.run_until_complete(f) except OSError as ex: diff --git a/update_stdlib.sh b/update_stdlib.sh index b025adfa..9e054659 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -48,7 +48,7 @@ do fi done -for i in `(cd tests && ls *.py sample.???)` +for i in `(cd tests && ls *.py *.pem)` do if [ $i == test_selectors.py ] then From b5decab18a6b45f65c62c99080352c41910e749c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 6 Dec 2013 17:31:14 -0800 Subject: [PATCH 0828/1502] Skip SSL tests with IOCP event loop. --- tests/test_events.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index d3f32d75..1c2560c0 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1393,10 +1393,19 @@ def create_event_loop(self): return windows_events.ProactorEventLoop() def test_create_ssl_connection(self): - raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") def test_create_server_ssl(self): - raise unittest.SkipTest("IocpEventLoop imcompatible with SSL") + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") From 33c76df54ae0ba0019eeba56f828c3615b0204f9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 7 Dec 2013 15:58:30 -0800 Subject: [PATCH 0829/1502] Remove duplicate import. --- asyncio/test_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 131a5460..4c658da3 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -3,7 +3,6 @@ import collections import contextlib import io -import unittest.mock import os import sys import threading From 732a8a25a5a04b33d74dad4af2adc0e63fbf39e1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 7 Dec 2013 15:58:49 -0800 Subject: [PATCH 0830/1502] import upstream selectors.py changes. --- asyncio/selectors.py | 74 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 9 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index c533f13d..a44d5e96 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -25,6 +25,9 @@ def _fileobj_to_fd(fileobj): Returns: corresponding file descriptor + + Raises: + ValueError if the object is invalid """ if isinstance(fileobj, int): fd = fileobj @@ -55,7 +58,8 @@ def __len__(self): def __getitem__(self, fileobj): try: - return self._selector._fd_to_key[_fileobj_to_fd(fileobj)] + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None @@ -89,6 +93,15 @@ def register(self, fileobj, events, data=None): Returns: SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised """ raise NotImplementedError @@ -101,6 +114,13 @@ def unregister(self, fileobj): Returns: SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) """ raise NotImplementedError @@ -114,6 +134,9 @@ def modify(self, fileobj, events, data=None): Returns: SelectorKey instance + + Raises: + Anything that unregister() or register() raises """ self.unregister(fileobj) return self.register(fileobj, events, data) @@ -177,22 +200,41 @@ def __init__(self): # read-only mapping returned by get_map() self._map = _SelectorMapping(self) + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + def register(self, fileobj, events, data=None): if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): raise ValueError("Invalid events: {!r}".format(events)) - key = SelectorKey(fileobj, _fileobj_to_fd(fileobj), events, data) + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) if key.fd in self._fd_to_key: - raise KeyError("{!r} (FD {}) is already " - "registered".format(fileobj, key.fd)) + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) self._fd_to_key[key.fd] = key return key def unregister(self, fileobj): try: - key = self._fd_to_key.pop(_fileobj_to_fd(fileobj)) + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None return key @@ -200,7 +242,7 @@ def unregister(self, fileobj): def modify(self, fileobj, events, data=None): # TODO: Subclasses can probably optimize this even further. try: - key = self._fd_to_key[_fileobj_to_fd(fileobj)] + key = self._fd_to_key[self._fileobj_lookup(fileobj)] except KeyError: raise KeyError("{!r} is not registered".format(fileobj)) from None if events != key.events: @@ -352,7 +394,12 @@ def register(self, fileobj, events, data=None): def unregister(self, fileobj): key = super().unregister(fileobj) - self._epoll.unregister(key.fd) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass return key def select(self, timeout=None): @@ -409,11 +456,20 @@ def unregister(self, fileobj): if key.events & EVENT_READ: kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass if key.events & EVENT_WRITE: kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, select.KQ_EV_DELETE) - self._kqueue.control([kev], 0, 0) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass return key def select(self, timeout=None): From cf1d8d4569a8d0f6ff8b15a46fb5975e7c6b238b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Dec 2013 14:41:39 -0800 Subject: [PATCH 0831/1502] Relax some more timeouts. --- tests/test_windows_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 7ba33dac..b32477f9 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -120,7 +120,7 @@ def test_wait_for_handle(self): self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertTrue(f.result()) - self.assertTrue(0 <= elapsed < 0.02, elapsed) + self.assertTrue(0 <= elapsed < 0.1, elapsed) _overlapped.ResetEvent(event) @@ -132,7 +132,7 @@ def test_wait_for_handle(self): with self.assertRaises(futures.CancelledError): self.loop.run_until_complete(f) elapsed = self.loop.time() - start - self.assertTrue(0 <= elapsed < 0.02, elapsed) + self.assertTrue(0 <= elapsed < 0.1, elapsed) if __name__ == '__main__': From d2757b5d253657c2fe52f2d2cb053b4ed0df5811 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Dec 2013 14:43:13 -0800 Subject: [PATCH 0832/1502] Skip some tests if SNI not supported. --- tests/test_events.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index 1c2560c0..c6db4d1b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -10,6 +10,9 @@ import ssl except ImportError: ssl = None + HAS_SNI = False +else: + from ssl import HAS_SNI import subprocess import sys import threading @@ -666,6 +669,7 @@ def factory(): server.close() @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_verify_failed(self): proto = None @@ -694,6 +698,7 @@ def factory(): server.close() @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_match_failed(self): proto = None @@ -724,6 +729,7 @@ def factory(): server.close() @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_verified(self): proto = None @@ -899,7 +905,7 @@ def datagram_received(self, data, addr): def test_internal_fds(self): loop = self.create_event_loop() if not isinstance(loop, selector_events.BaseSelectorEventLoop): - return + self.skipTest('loop is not a BaseSelectorEventLoop') self.assertEqual(1, loop._internal_fds) loop.close() From ea61917b9d3818f43a20f14671dc293eb0bfe63b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Dec 2013 14:43:41 -0800 Subject: [PATCH 0833/1502] Drop Tulip reference. --- asyncio/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 4c658da3..ccb44541 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -88,7 +88,7 @@ class SSLWSGIServer(SilentWSGIServer): def finish_request(self, request, client_address): # The relative location of our test directory (which # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone Tulip/asyncio. + # between the stdlib and stand-alone asyncio. # Prefer our own if we can find it. here = os.path.join(os.path.dirname(__file__), '..', 'tests') if not os.path.isdir(here): From c4daf897cceabd7cb6ad9129f59b0bcc9066dd92 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Dec 2013 14:43:57 -0800 Subject: [PATCH 0834/1502] Drop Tulip reference. --- asyncio/queues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 3f5bf447..e900278f 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -26,7 +26,7 @@ class Queue: queue reaches maxsize, until an item is removed by get(). Unlike the standard library Queue, you can reliably know this Queue's size - with qsize(), since your single-threaded Tulip application won't be + with qsize(), since your single-threaded asyncio application won't be interrupted between calling qsize() and doing an operation on the Queue. """ From ed810e64adab043a395caa2c76e946aee833f8d6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 19 Dec 2013 12:45:38 -0800 Subject: [PATCH 0835/1502] Shorten lines. --- asyncio/locks.py | 4 ++-- tests/test_events.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 6cd6779e..9e852924 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -138,7 +138,7 @@ def __iter__(self): class Event: - """An Event implementation, asynchronous equivalent to threading.Event. + """Asynchronous equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set to true with the set() method and reset to false with the clear() method. @@ -204,7 +204,7 @@ def wait(self): class Condition: - """A Condition implementation, asynchronous equivalent to threading.Condition. + """Asynchronous equivalent to threading.Condition. This class implements condition variable objects. A condition variable allows one or more coroutines to wait until they are notified by another diff --git a/tests/test_events.py b/tests/test_events.py index c6db4d1b..9545dd13 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -720,7 +720,8 @@ def factory(): # incorrect server_hostname f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) - with self.assertRaisesRegex(ssl.CertificateError, + with self.assertRaisesRegex( + ssl.CertificateError, "hostname '127.0.0.1' doesn't match 'localhost'"): self.loop.run_until_complete(f_c) From 53f45b72a12ffcf523c2061dd7579af196749b59 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 19 Dec 2013 13:48:20 -0800 Subject: [PATCH 0836/1502] Victor Stinner (CPython issue 19967): Future can have __del__ thanks to PEP 442. --- asyncio/futures.py | 30 +++++++++++++++++++++++------- tests/test_tasks.py | 6 +++++- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index dd3e718d..0188f52d 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -7,6 +7,7 @@ import concurrent.futures._base import logging +import sys import traceback from . import events @@ -17,6 +18,8 @@ _CANCELLED = 'CANCELLED' _FINISHED = 'FINISHED' +_PY34 = sys.version_info >= (3, 4) + # TODO: Do we really want to depend on concurrent.futures internals? Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError @@ -128,7 +131,8 @@ class Future: _blocking = False # proper use of future (yield vs yield from) - _tb_logger = None + _traceback = None # Used for Python 3.4 and later + _tb_logger = None # Used for Python 3.3 only def __init__(self, *, loop=None): """Initialize the future. @@ -162,6 +166,12 @@ def __repr__(self): res += '<{}>'.format(self._state) return res + if _PY34: + def __del__(self): + if self._traceback is not None: + logger.error('Future/Task exception was never retrieved:\n%s', + ''.join(self._traceback)) + def cancel(self): """Cancel the future and schedule callbacks. @@ -214,9 +224,10 @@ def result(self): raise CancelledError if self._state != _FINISHED: raise InvalidStateError('Result is not ready.') + self._traceback = None if self._tb_logger is not None: self._tb_logger.clear() - self._tb_logger = None + self._tb_logger = None if self._exception is not None: raise self._exception return self._result @@ -233,9 +244,10 @@ def exception(self): raise CancelledError if self._state != _FINISHED: raise InvalidStateError('Exception is not set.') + self._traceback = None if self._tb_logger is not None: self._tb_logger.clear() - self._tb_logger = None + self._tb_logger = None return self._exception def add_done_callback(self, fn): @@ -286,12 +298,16 @@ def set_exception(self, exception): if self._state != _PENDING: raise InvalidStateError('{}: {!r}'.format(self._state, self)) self._exception = exception - self._tb_logger = _TracebackLogger(exception) self._state = _FINISHED self._schedule_callbacks() - # Arrange for the logger to be activated after all callbacks - # have had a chance to call result() or exception(). - self._loop.call_soon(self._tb_logger.activate) + if _PY34: + self._traceback = traceback.format_exception( + exception.__class__, exception, exception.__traceback__) + else: + self._tb_logger = _TracebackLogger(exception) + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) # Truly internal methods. diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 5470da15..79a25d29 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1115,6 +1115,7 @@ def coro(): def test_current_task(self): self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + @tasks.coroutine def coro(loop): self.assertTrue(tasks.Task.current_task(loop=loop) is task) @@ -1146,7 +1147,8 @@ def coro2(loop): task1 = tasks.Task(coro1(self.loop), loop=self.loop) task2 = tasks.Task(coro2(self.loop), loop=self.loop) - self.loop.run_until_complete(tasks.wait((task1, task2), loop=self.loop)) + self.loop.run_until_complete(tasks.wait((task1, task2), + loop=self.loop)) self.assertIsNone(tasks.Task.current_task(loop=self.loop)) # Some thorough tests for cancellation propagation through @@ -1352,6 +1354,7 @@ def test_one_exception(self): c.set_result(3) d.cancel() e.set_exception(RuntimeError()) + e.exception() def test_return_exceptions(self): a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] @@ -1431,6 +1434,7 @@ def test_one_cancellation(self): c.set_result(3) d.cancel() e.set_exception(RuntimeError()) + e.exception() def test_result_exception_one_cancellation(self): a, b, c, d, e, f = [futures.Future(loop=self.one_loop) From 7b6419ef5f57f24549b5bb62bc44280ccabe3ad8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 20 Dec 2013 14:37:12 -0800 Subject: [PATCH 0837/1502] Export all abstract protocol and transport classes. CPython issue #20029. --- asyncio/protocols.py | 3 ++- asyncio/transports.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/protocols.py b/asyncio/protocols.py index 1b8870c6..3c4f3f4a 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -1,6 +1,7 @@ """Abstract Protocol class.""" -__all__ = ['Protocol', 'DatagramProtocol'] +__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol'] class BaseProtocol: diff --git a/asyncio/transports.py b/asyncio/transports.py index c2feb93d..2d2469ee 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -4,7 +4,9 @@ PY34 = sys.version_info >= (3, 4) -__all__ = ['ReadTransport', 'WriteTransport', 'Transport'] +__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', + ] class BaseTransport: From d05f1619347f9c9e43437e92635ce13abd974c25 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 20 Dec 2013 15:41:48 -0800 Subject: [PATCH 0838/1502] Performance improvement to CPython issue 19967 (Victor Stinner again). --- asyncio/futures.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index 0188f52d..9ee13e3e 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -131,8 +131,8 @@ class Future: _blocking = False # proper use of future (yield vs yield from) - _traceback = None # Used for Python 3.4 and later - _tb_logger = None # Used for Python 3.3 only + _log_traceback = False # Used for Python 3.4 and later + _tb_logger = None # Used for Python 3.3 only def __init__(self, *, loop=None): """Initialize the future. @@ -168,9 +168,13 @@ def __repr__(self): if _PY34: def __del__(self): - if self._traceback is not None: - logger.error('Future/Task exception was never retrieved:\n%s', - ''.join(self._traceback)) + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + logger.error('Future/Task exception was never retrieved:', + exc_info=(exc.__class__, exc, exc.__traceback__)) def cancel(self): """Cancel the future and schedule callbacks. @@ -224,10 +228,10 @@ def result(self): raise CancelledError if self._state != _FINISHED: raise InvalidStateError('Result is not ready.') - self._traceback = None + self._log_traceback = False if self._tb_logger is not None: self._tb_logger.clear() - self._tb_logger = None + self._tb_logger = None if self._exception is not None: raise self._exception return self._result @@ -244,10 +248,10 @@ def exception(self): raise CancelledError if self._state != _FINISHED: raise InvalidStateError('Exception is not set.') - self._traceback = None + self._log_traceback = False if self._tb_logger is not None: self._tb_logger.clear() - self._tb_logger = None + self._tb_logger = None return self._exception def add_done_callback(self, fn): @@ -301,8 +305,7 @@ def set_exception(self, exception): self._state = _FINISHED self._schedule_callbacks() if _PY34: - self._traceback = traceback.format_exception( - exception.__class__, exception, exception.__traceback__) + self._log_traceback = True else: self._tb_logger = _TracebackLogger(exception) # Arrange for the logger to be activated after all callbacks From 4f5219fdf735f543105749235b5405cb30b9cf0a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 20 Dec 2013 20:36:23 -0800 Subject: [PATCH 0839/1502] Fix space in log message about poll time. --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index f2d117bd..a8850656 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -613,7 +613,7 @@ def _run_once(self): t0 = self.time() event_list = self._selector.select(timeout) t1 = self.time() - argstr = '' if timeout is None else '{:.3f}'.format(timeout) + argstr = '' if timeout is None else ' {:.3f}'.format(timeout) if t1-t0 >= 1: level = logging.INFO else: From 3e44506333c460fdce06e201fa8fd3c1faf0f1f9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 20 Dec 2013 20:50:31 -0800 Subject: [PATCH 0840/1502] A new, larger example: a cache server with a shareable, reconnecting client. --- examples/cacheclt.py | 211 ++++++++++++++++++++++++++++++++++++ examples/cachemux.py | 89 ++++++++++++++++ examples/cachesvr.py | 247 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 547 insertions(+) create mode 100644 examples/cacheclt.py create mode 100644 examples/cachemux.py create mode 100644 examples/cachesvr.py diff --git a/examples/cacheclt.py b/examples/cacheclt.py new file mode 100644 index 00000000..cd8da070 --- /dev/null +++ b/examples/cacheclt.py @@ -0,0 +1,211 @@ +"""Client for cache server. + +See cachesvr.py for protocol description. +""" + +import argparse +import asyncio +from asyncio import test_utils +import json +import logging +import sys + +ARGS = argparse.ArgumentParser(description='Cache client example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--max_backoff', action='store', dest='max_backoff', + default=5, type=float, help='Max backoff on reconnect') +ARGS.add_argument( + '--ntasks', action='store', dest='ntasks', + default=10, type=int, help='Number of tester tasks') +ARGS.add_argument( + '--ntries', action='store', dest='ntries', + default=5, type=int, help='Number of request tries before giving up') + + +args = ARGS.parse_args() + + +class CacheClient: + """Multiplexing cache client. + + This wraps a single connection to the cache client. The + connection is automatically re-opened when an error occurs. + + Multiple tasks may share this object; the requests will be + serialized. + + The public API is get(), set(), delete() (all are coroutines). + """ + + def __init__(self, host, port, sslctx=None, loop=None): + self.host = host + self.port = port + self.sslctx = sslctx + self.loop = loop + self.todo = set() + self.initialized = False + self.task = asyncio.Task(self.activity(), loop=self.loop) + + @asyncio.coroutine + def get(self, key): + resp = yield from self.request('get', key) + if resp is None: + return None + return resp.get('value') + + @asyncio.coroutine + def set(self, key, value): + resp = yield from self.request('set', key, value) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def delete(self, key): + resp = yield from self.request('delete', key) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def request(self, type, key, value=None): + assert not self.task.done() + data = {'type': type, 'key': key} + if value is not None: + data['value'] = value + payload = json.dumps(data).encode('utf8') + waiter = asyncio.Future(loop=self.loop) + if self.initialized: + try: + yield from self.send(payload, waiter) + except IOError: + self.todo.add((payload, waiter)) + else: + self.todo.add((payload, waiter)) + return (yield from waiter) + + @asyncio.coroutine + def activity(self): + backoff = 0 + while True: + try: + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.sslctx, loop=self.loop) + except Exception as exc: + backoff = min(args.max_backoff, backoff + (backoff//2) + 1) + logging.info('Error connecting: %r; sleep %s', exc, backoff) + yield from asyncio.sleep(backoff, loop=self.loop) + continue + backoff = 0 + self.next_id = 0 + self.pending = {} + self. initialized = True + try: + while self.todo: + payload, waiter = self.todo.pop() + if not waiter.done(): + yield from self.send(payload, waiter) + while True: + resp_id, resp = yield from self.process() + if resp_id in self.pending: + payload, waiter = self.pending.pop(resp_id) + if not waiter.done(): + waiter.set_result(resp) + except Exception as exc: + self.initialized = False + self.writer.close() + while self.pending: + req_id, pair = self.pending.popitem() + payload, waiter = pair + if not waiter.done(): + self.todo.add(pair) + logging.info('Error processing: %r', exc) + + @asyncio.coroutine + def send(self, payload, waiter): + self.next_id += 1 + req_id = self.next_id + frame = 'request %d %d\n' % (req_id, len(payload)) + self.writer.write(frame.encode('ascii')) + self.writer.write(payload) + self.pending[req_id] = payload, waiter + yield from self.writer.drain() + + @asyncio.coroutine + def process(self): + frame = yield from self.reader.readline() + if not frame: + raise EOFError() + head, tail = frame.split(None, 1) + if head == b'error': + raise IOError('OOB error: %r' % tail) + if head != b'response': + raise IOError('Bad frame: %r' % frame) + resp_id, resp_size = map(int, tail.split()) + data = yield from self.reader.readexactly(resp_size) + if len(data) != resp_size: + raise EOFError() + resp = json.loads(data.decode('utf8')) + return resp_id, resp + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop) + loop.run_until_complete( + asyncio.gather( + *[testing(i, cache, loop) for i in range(args.ntasks)], + loop=loop)) + + +@asyncio.coroutine +def testing(label, cache, loop): + + def w(g): + return asyncio.wait_for(g, args.timeout, loop=loop) + + key = 'foo-%s' % label + while True: + logging.info('%s %s', label, '-'*20) + try: + ret = yield from w(cache.set(key, 'hello-%s-world' % label)) + logging.info('%s set %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get %s', label, ret) + ret = yield from w(cache.delete(key)) + logging.info('%s del %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get2 %s', label, ret) + except asyncio.TimeoutError: + logging.warn('%s Timeout', label) + except Exception as exc: + logging.exception('%s Client exception: %r', label, exc) + break + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/cachemux.py b/examples/cachemux.py new file mode 100644 index 00000000..aaea0bef --- /dev/null +++ b/examples/cachemux.py @@ -0,0 +1,89 @@ +import asyncio +import json +import logging + + +class Mux: + """Demultiplexer for requests and responses. + + This also handles retries. + """ + + def __init__(self, host, port, *, sslctx=None, loop=None): + self.host = host + self.port = port + self.sslctx = sslctx + self.loop = loop + self.todo = set() + self.initialized = False + asyncio.Task(self.activity()) + + @asyncio.coroutine + def request(self, type, key, value=None): + data = {'type': type, 'key': key} + if value is not None: + data['value'] = value + payload = json.dumps(data).encode('utf8') + waiter = asyncio.Future(loop=self.loop) + if self.initialized: + yield from self.send(payload, waiter) + else: + self.todo.add((payload, waiter)) + return (yield from waiter) + + @asyncio.coroutine + def activity(self): + while True: + try: + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.post, ssl=loop.sslctx, loop=self.loop) + except IOError as exc: + logging.info('I/O error connecting: %r', exc) + yield from asyncio.sleep(1, loop=self.loop) + continue + self.next_id = 0 + self.pending = {} + self. initialized = True + try: + while self.todo: + payload, waiter = self.todo.pop() + yield from self.send(payload, waiter) + while True: + resp_id, resp = self.process() + if resp_id in self.pending: + payload, waiter = self.pending.pop(resp_id) + if not waiter.done(): + waiter.set_result(resp) + except IOError as exc: + self.initialized = False + self.writer.close() + while self.pending: + req_id, (payload, waiter) = self.pending.popitem() + if not waiter.done(): + self.todo.add(pair) + logging.info('I/O error processing: %r', exc) + + @asyncio.coroutine + def send(self, payload, waiter): + self.next_id += 1 + req_id = self.next_id + frame = 'request %d %d\n' % (req_id, len(payload)) + self.writer.write(frame.encode('ascii')) + self.writer.write(payload) + self.pending[req_id] = payload, waiter + yield from self.writer.drain() + + @asyncio.coroutine + def process(self): + frame = yield from self.reader.readline() + if not frame: + raise IOError('EOF') + head, tail = frame.split(None, 1) + if head == b'error': + raise IOError('OOB error: %r' % tail) + if head != b'response': + raise IOError('Bad frame: %r' % frame) + resp_id, resp_size = map(int, tail.split()) + data = yield from self.reader.readexactly(resp_size) + resp = json.loads(data.decode('utf8')) + return resp_id, resp diff --git a/examples/cachesvr.py b/examples/cachesvr.py new file mode 100644 index 00000000..9c7bda91 --- /dev/null +++ b/examples/cachesvr.py @@ -0,0 +1,247 @@ +"""A simple memcache-like server. + +The basic data structure maintained is a single in-memory dictionary +mapping string keys to string values, with operations get, set and +delete. (Both keys and values may contain Unicode.) + +This is a TCP server listening on port 54321. There is no +authentication. + +Requests provide an operation and return a response. A connection may +be used for multiple requests. The connection is closed when a client +sends a bad request. + +If a client is idle for over 5 seconds (i.e., it does not send another +request, or fails to read the whole response, within this time), it is +disconnected. + +Framing of requests and responses within a connection uses a +line-based protocol. The first line of a request is the frame header +and contains three whitespace-delimited token followed by LF or CRLF: + +- the keyword 'request' +- a decimal request ID; the first request is '1', the second '2', etc. +- a decimal byte count giving the size of the rest of the request + +Note that the requests ID *must* be consecutive and start at '1' for +each connection. + +Response frames look the same except the keyword is 'response'. The +response ID matches the request ID. There should be exactly one +response to each request and responses should be seen in the same +order as the requests. + +After the frame, individual requests and responses are JSON encoded. + +If the frame header or the JSON request body cannot be parsed, an +unframed error message (always starting with 'error') is written back +and the connection is closed. + +JSON-encoded requests can be: + +- {"type": "get", "key": } +- {"type": "set", "key": , "value": } +- {"type": "delete", "key": } + +Responses are also JSON-encoded: + +- {"status": "ok", "value": } # Successful get request +- {"status": "ok"} # Successful set or delete request +- {"status": "notfound"} # Key not found for get or delete request + +If the request is valid JSON but cannot be handled (e.g., the type or +key field is absent or invalid), an error response of the following +form is returned, but the connection is not closed: + +- {"error": } +""" + +import argparse +import asyncio +import json +import logging +import os +import random +import sys + +ARGS = argparse.ArgumentParser(description='Cache server example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--random_failure_percent', action='store', dest='fail_percent', + default=0, type=float, help='Fail randomly N percent of the time') +ARGS.add_argument( + '--random_failure_sleep', action='store', dest='fail_sleep', + default=0, type=float, help='Sleep time when randomly failing') +ARGS.add_argument( + '--random_response_sleep', action='store', dest='resp_sleep', + default=0, type=float, help='Sleep time before responding') + +args = ARGS.parse_args() + + +class Cache: + + def __init__(self, loop): + self.loop = loop + self.table = {} + + @asyncio.coroutine + def handle_client(self, reader, writer): + # Wrapper to log stuff and close writer (i.e., transport). + peer = writer.get_extra_info('socket').getpeername() + logging.info('got a connection from %s', peer) + try: + yield from self.frame_parser(reader, writer) + except Exception as exc: + logging.error('error %r from %s', exc, peer) + else: + logging.info('end connection from %s', peer) + finally: + writer.close() + + @asyncio.coroutine + def frame_parser(self, reader, writer): + # This takes care of the framing. + last_request_id = 0 + while True: + # Read the frame header, parse it, read the data. + # NOTE: The readline() and readexactly() calls will hang + # if the client doesn't send enough data but doesn't + # disconnect either. We add a timeout to each. (But the + # timeout should really be implemented by StreamReader.) + framing_b = yield from asyncio.wait_for( + reader.readline(), + timeout=args.timeout, loop=self.loop) + if random.random()*100 < args.fail_percent: + logging.warn('Inserting random failure') + yield from asyncio.sleep(args.fail_sleep*random.random(), + loop=self.loop) + writer.write(b'error random failure\r\n') + break + logging.debug('framing_b = %r', framing_b) + if not framing_b: + break # Clean close. + try: + frame_keyword, request_id_b, byte_count_b = framing_b.split() + except ValueError: + writer.write(b'error unparseable frame\r\n') + break + if frame_keyword != b'request': + writer.write(b'error frame does not start with request\r\n') + break + try: + request_id, byte_count = int(request_id_b), int(byte_count_b) + except ValueError: + writer.write(b'error unparsable frame parameters\r\n') + break + if request_id != last_request_id + 1 or byte_count < 2: + writer.write(b'error invalid frame parameters\r\n') + break + last_request_id = request_id + request_b = yield from asyncio.wait_for( + reader.readexactly(byte_count), + timeout=args.timeout, loop=self.loop) + try: + request = json.loads(request_b.decode('utf8')) + except ValueError: + writer.write(b'error unparsable json\r\n') + break + response = self.handle_request(request) # Not a coroutine. + if response is None: + writer.write(b'error unhandlable request\r\n') + break + response_b = json.dumps(response).encode('utf8') + b'\r\n' + byte_count = len(response_b) + framing_s = 'response {} {}\r\n'.format(request_id, byte_count) + writer.write(framing_s.encode('ascii')) + yield from asyncio.sleep(args.resp_sleep*random.random(), + loop=self.loop) + writer.write(response_b) + + def handle_request(self, request): + # This parses one request and farms it out to a specific handler. + # Return None for all errors. + if not isinstance(request, dict): + return {'error': 'request is not a dict'} + request_type = request.get('type') + if request_type is None: + return {'error': 'no type in request'} + if request_type not in {'get', 'set', 'delete'}: + return {'error': 'unknown request type'} + key = request.get('key') + if not isinstance(key, str): + return {'error': 'key is not a string'} + if request_type == 'get': + return self.handle_get(key) + if request_type == 'set': + value = request.get('value') + if not isinstance(value, str): + return {'error': 'value is not a string'} + return self.handle_set(key, value) + if request_type == 'delete': + return self.handle_delete(key) + assert False, 'bad request type' # Should have been caught above. + + def handle_get(self, key): + value = self.table.get(key) + if value is None: + return {'status': 'notfound'} + else: + return {'status': 'ok', 'value': value} + + def handle_set(self, key, value): + self.table[key] = value + return {'status': 'ok'} + + def handle_delete(self, key): + if key not in self.table: + return {'status': 'notfound'} + else: + del self.table[key] + return {'status': 'ok'} + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) + cache = Cache(loop) + task = asyncio.streams.start_server(cache.handle_client, + args.host, args.port, + ssl=sslctx, loop=loop) + svr = loop.run_until_complete(task) + for sock in svr.sockets: + logging.info('socket %s', sock.getsockname()) + loop.run_forever() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() From 11385c78e2c3896f4822f617efd81830d7cb0d01 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 20 Dec 2013 21:46:00 -0800 Subject: [PATCH 0841/1502] Oops. Added a file by accident. --- examples/cachemux.py | 89 -------------------------------------------- 1 file changed, 89 deletions(-) delete mode 100644 examples/cachemux.py diff --git a/examples/cachemux.py b/examples/cachemux.py deleted file mode 100644 index aaea0bef..00000000 --- a/examples/cachemux.py +++ /dev/null @@ -1,89 +0,0 @@ -import asyncio -import json -import logging - - -class Mux: - """Demultiplexer for requests and responses. - - This also handles retries. - """ - - def __init__(self, host, port, *, sslctx=None, loop=None): - self.host = host - self.port = port - self.sslctx = sslctx - self.loop = loop - self.todo = set() - self.initialized = False - asyncio.Task(self.activity()) - - @asyncio.coroutine - def request(self, type, key, value=None): - data = {'type': type, 'key': key} - if value is not None: - data['value'] = value - payload = json.dumps(data).encode('utf8') - waiter = asyncio.Future(loop=self.loop) - if self.initialized: - yield from self.send(payload, waiter) - else: - self.todo.add((payload, waiter)) - return (yield from waiter) - - @asyncio.coroutine - def activity(self): - while True: - try: - self.reader, self.writer = yield from asyncio.open_connection( - self.host, self.post, ssl=loop.sslctx, loop=self.loop) - except IOError as exc: - logging.info('I/O error connecting: %r', exc) - yield from asyncio.sleep(1, loop=self.loop) - continue - self.next_id = 0 - self.pending = {} - self. initialized = True - try: - while self.todo: - payload, waiter = self.todo.pop() - yield from self.send(payload, waiter) - while True: - resp_id, resp = self.process() - if resp_id in self.pending: - payload, waiter = self.pending.pop(resp_id) - if not waiter.done(): - waiter.set_result(resp) - except IOError as exc: - self.initialized = False - self.writer.close() - while self.pending: - req_id, (payload, waiter) = self.pending.popitem() - if not waiter.done(): - self.todo.add(pair) - logging.info('I/O error processing: %r', exc) - - @asyncio.coroutine - def send(self, payload, waiter): - self.next_id += 1 - req_id = self.next_id - frame = 'request %d %d\n' % (req_id, len(payload)) - self.writer.write(frame.encode('ascii')) - self.writer.write(payload) - self.pending[req_id] = payload, waiter - yield from self.writer.drain() - - @asyncio.coroutine - def process(self): - frame = yield from self.reader.readline() - if not frame: - raise IOError('EOF') - head, tail = frame.split(None, 1) - if head == b'error': - raise IOError('OOB error: %r' % tail) - if head != b'response': - raise IOError('Bad frame: %r' % frame) - resp_id, resp_size = map(int, tail.split()) - data = yield from self.reader.readexactly(resp_size) - resp = json.loads(data.decode('utf8')) - return resp_id, resp From f62f542e4779bd8630f7dec277f5caa416de63de Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 28 Dec 2013 08:05:10 -1000 Subject: [PATCH 0842/1502] Export iscoroutine[function]. --- asyncio/tasks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index cd9718f5..406bcb93 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -1,6 +1,7 @@ """Support for tasks, coroutines and the scheduler.""" __all__ = ['coroutine', 'Task', + 'iscoroutinefunction', 'iscoroutine', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', 'gather', 'shield', From e282bc0a16b474f6843d94fd4b1621cb9868aa56 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 3 Jan 2014 09:10:01 -1000 Subject: [PATCH 0843/1502] Make PY34 symbol private (rename it to _PY34). --- asyncio/transports.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/transports.py b/asyncio/transports.py index 2d2469ee..67ae7fda 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -2,7 +2,7 @@ import sys -PY34 = sys.version_info >= (3, 4) +_PY34 = sys.version_info >= (3, 4) __all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', 'Transport', 'DatagramTransport', 'SubprocessTransport', @@ -94,7 +94,7 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ - if not PY34: + if not _PY34: # In Python 3.3, bytes.join() doesn't handle memoryview. list_of_data = ( bytes(data) if isinstance(data, memoryview) else data From 4646a1eb7ef028f36732d5857ffe639d8d81e57c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 6 Jan 2014 16:05:16 -0800 Subject: [PATCH 0844/1502] Avoid pause deadlock in readexactly(). Fixes issue 99. --- asyncio/streams.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 50c4c5d1..93a21d1a 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -220,6 +220,7 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): if loop is None: loop = events.get_event_loop() self._loop = loop + # TODO: Use a bytearray for a buffer, like the transport. self._buffer = collections.deque() # Deque of bytes objects. self._byte_count = 0 # Bytes in buffer. self._eof = False # Whether we're done. @@ -384,15 +385,23 @@ def readexactly(self, n): if self._exception is not None: raise self._exception - if n <= 0: - return b'' + # There used to be "optimized" code here. It created its own + # Future and waited until self._buffer had at least the n + # bytes, then called read(n). Unfortunately, this could pause + # the transport if the argument was larger than the pause + # limit (which is twice self._limit). So now we just read() + # into a local buffer. + + blocks = [] + while n > 0: + block = yield from self.read(n) + if not block: + break + blocks.append(block) + n -= len(block) - while self._byte_count < n and not self._eof: - assert not self._waiter - self._waiter = futures.Future(loop=self._loop) - try: - yield from self._waiter - finally: - self._waiter = None + # TODO: Raise EOFError if we break before n == 0? (That would + # be a change in specification, but I've always had to add an + # explicit size check to the caller.) - return (yield from self.read(n)) + return b''.join(blocks) From 6115749c7c485785573119fa7c1402643d5f56e9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 7 Jan 2014 13:47:20 -0800 Subject: [PATCH 0845/1502] Don't special-case GeneratorExit in Condition.wait(). I can't remember why I added that code, there are no tests for it, and it causes a spurious "Exception ignored in " error message if code like this is interrupted during the wait(): @coroutine def gen(): with (yield from cond): while : yield from cond.wait() --- asyncio/locks.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 9e852924..9fdb9374 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -251,7 +251,6 @@ def wait(self): if not self.locked(): raise RuntimeError('cannot wait on un-acquired lock') - keep_lock = True self.release() try: fut = futures.Future(loop=self._loop) @@ -262,12 +261,8 @@ def wait(self): finally: self._waiters.remove(fut) - except GeneratorExit: - keep_lock = False # Prevent yield in finally clause. - raise finally: - if keep_lock: - yield from self.acquire() + yield from self.acquire() @tasks.coroutine def wait_for(self, predicate): From 995f8a9b5258fcb19eb784110e69a1962851ce2a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 7 Jan 2014 16:35:59 -0800 Subject: [PATCH 0846/1502] Fix typo (_writer instead of _reader). --- asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 93a21d1a..13828f20 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -199,7 +199,7 @@ def drain(self): drained and the protocol is resumed. """ if self._reader._exception is not None: - raise self._writer._exception + raise self._exception._exception if self._transport._conn_lost: # Uses private variable. raise ConnectionResetError('Connection lost') if not self._protocol._paused: From 3618cb034e2e6620d39f782e19812932b48ba44c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 7 Jan 2014 17:02:15 -0800 Subject: [PATCH 0847/1502] Fix the fix I just committed. :-( --- asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 13828f20..7eda5f66 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -199,7 +199,7 @@ def drain(self): drained and the protocol is resumed. """ if self._reader._exception is not None: - raise self._exception._exception + raise self._reader._exception if self._transport._conn_lost: # Uses private variable. raise ConnectionResetError('Connection lost') if not self._protocol._paused: From 7b8f9ca8386968974ac8bdfc09b015ac2696e0c0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 8 Jan 2014 11:06:43 -0800 Subject: [PATCH 0848/1502] Fix bug in chunked reader (must read final CRLF). --- examples/fetch3.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/fetch3.py b/examples/fetch3.py index 780222ba..bdca3aee 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -167,19 +167,19 @@ def read(self): if nbytes is None: if self.get_header('transfer-encoding', '').lower() == 'chunked': blocks = [] - while True: + size = -1 + while size: size_header = yield from self.reader.readline() if not size_header: break parts = size_header.split(b';') size = int(parts[0], 16) - if not size: - break - block = yield from self.reader.readexactly(size) - assert len(block) == size, (len(block), size) - blocks.append(block) + if size: + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) crlf = yield from self.reader.readline() - assert crlf == b'\r\n' + assert crlf == b'\r\n', repr(crlf) body = b''.join(blocks) else: body = yield from self.reader.read() From 419b2d7f6e67f5ef78530b3bf25a98cb6240f684 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 8 Jan 2014 11:15:57 -0800 Subject: [PATCH 0849/1502] A new crawler example. --- examples/crawl.py | 668 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 668 insertions(+) create mode 100644 examples/crawl.py diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100644 index 00000000..18fbe7be --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,668 @@ +#!/usr/bin/env python3.4 + +"""A simple web crawler.""" + +# TODO: +# - Less verbose logging. +# - Support gzip encoding. +# - Seems sometimes getaddrinfo() raises gaierror? + +import argparse +import asyncio +import asyncio.locks +import cgi +from http.client import BadStatusLine +import logging +import re +import signal +import sys +import time +import urllib.parse + + +ARGS = argparse.ArgumentParser(description="Web crawler") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--select', action='store_true', dest='select', + default=False, help='Use Select event loop instead of default') +ARGS.add_argument( + 'roots', nargs='*', + default=['http://python.org'], help='Root URL (may be repeated)') +ARGS.add_argument( + '--max_redirect', action='store', type=int, metavar='N', + default=10, help='Limit redirection (for 301, 302 etc.)') +ARGS.add_argument( + '--max_tries', action='store', type=int, metavar='N', + default=4, help='Limit retries on network errors') +ARGS.add_argument( + '--max_tasks', action='store', type=int, metavar='N', + default=5, help='Limit on concurrent connections') +ARGS.add_argument( + '--exclude', action='store', metavar='REGEX', + help='Exclude matching URLs') +ARGS.add_argument( + '-v', '--verbose', action='count', dest='verbose', + default=1, help='Verbose logging (repeat for more verbose)') +ARGS.add_argument( + '-q', '--quiet', action='store_const', const=0, dest='verbose', + help='Quiet logging (opposite of --verbose)') + + +ESCAPES = [('quot', '"'), + ('gt', '>'), + ('lt', '<'), + ('amp', '&') # Must be last. + ] + +def unescape(url): + """Turn & into &, and so on. + + This is the inverse of cgi.escape(). + """ + for name, char in ESCAPES: + url = url.replace('&' + name + ';', char) + return url + + +def fix_url(url): + """Prefix a schema-less URL with http://.""" + if '://' not in url: + url = 'http://' + url + return url + + +class VPrinter: + """Mix-in class defining vprint() which is like print() if verbose > 0. + + The output goes to stderr. Only positional arguments are + supported. There are also methods vvprint(), vvvprint() + etc. which print() only if verbose > larger values. + + The verbose instance variable is public. + + TODO: This should probably be a shared object rather than a mix-in class. + """ + + def __init__(self, verbose): + self.verbose = verbose + + def _nvprint(self, n, args): + if self.verbose >= n: + print(*args, file=sys.stderr, flush=True) + + def nvprint(self, n, *args): + self._nvprint(n, args) + + def vprint(self, *args): + self._nvprint(1, args) + + def vvprint(self, *args): + self._nvprint(2, args) + + def vvvprint(self, *args): + self._nvprint(3, args) + + +class ConnectionPool(VPrinter): + """A connection pool. + + To open a connection, use reserve(). To recycle it, use unreserve(). + + The pool is mostly just a mapping from (host, port, ssl) to + (reader, writer). The currently active connections are *not* in + the mapping; reserve() takes the connection out of the mapping, + and unreserve()' puts it back in. It is up to the caller to only + call unreserve() for reusable connections. (That logic is + implemented in the Request class.) + """ + + def __init__(self, verbose=0): + VPrinter.__init__(self, verbose) + self.loop = asyncio.get_event_loop() + self.connections = {} # {(host, port, ssl): (reader, writer)} + + def close(self): + """Close all connections available for reuse.""" + for _, writer in self.connections.values(): + writer.close() + + @asyncio.coroutine + def reserve(self, host, port, ssl): + """Create or reuse a connection.""" + port = port or (443 if ssl else 80) + try: + ipaddrs = yield from self.loop.getaddrinfo(host, port) + except Exception as exc: + self.vprint('Exception %r for (%r, %r)' % (exc, host, port)) + raise + self.vprint('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs))) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + pair = self.connections.pop(key, None) + if pair: + self.vprint('* Reusing pooled connection', key) + reader, writer = pair + return key, reader, writer + reader, writer = yield from asyncio.open_connection(host, port, + ssl=ssl) + host, port, *_ = writer.get_extra_info('peername') + key = host, port, ssl + self.vprint('* New connection', key) + return key, reader, writer + + def unreserve(self, key, reader, writer): + """Make a connection available for reuse.""" + self.connections[key] = (reader, writer) + + +class Request(VPrinter): + """HTTP request. + + Use connect() to open a connection; send_request() to send the + request; get_response() to receive the response headers. + """ + + def __init__(self, url, pool, verbose=0): + VPrinter.__init__(self, verbose) + self.url = url + self.pool = pool + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @asyncio.coroutine + def connect(self): + """Open a connection to the server.""" + self.vprint('* Connecting to %s:%s using %s for %s' % + (self.hostname, self.port, + 'ssl' if self.ssl else 'tcp', + self.url)) + self.key, self.reader, self.writer = \ + yield from self.pool.reserve(self.hostname, self.port, self.ssl) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('peername'),)) + + def recycle_connection(self): + """Recycle the connection to the pool. + + This should only be called when a properly formatted HTTP + response has been received. + """ + self.pool.unreserve(self.key, self.reader, self.writer) + self.key = self.reader = self.writer = None + + @asyncio.coroutine + def putline(self, line): + """Write a line to the connection. + + Used for the request line and headers. + """ + self.vvprint('>', line) + self.writer.write(line.encode('latin-1') + b'\r\n') + + @asyncio.coroutine + def send_request(self): + """Send the request.""" + request_line = '%s %s %s' % (self.method, self.full_path, + self.http_version) + yield from self.putline(request_line) + # TODO: What if a header is already set? + self.headers.append(('User-Agent', 'asyncio-example-crawl/0.0')) + self.headers.append(('Host', self.netloc)) + self.headers.append(('Accept', '*/*')) + ##self.headers.append(('Accept-Encoding', 'gzip')) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @asyncio.coroutine + def get_response(self): + """Receive the response.""" + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response(VPrinter): + """HTTP response. + + Call read_headers() to receive the request headers. Then check + the status attribute and call get_header() to inspect the headers. + Finally call read() to receive the body. + """ + + def __init__(self, reader, verbose=0): + VPrinter.__init__(self, verbose) + self.reader = reader + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @asyncio.coroutine + def getline(self): + """Read one line from the connection.""" + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.vvprint('<', line) + return line + + @asyncio.coroutine + def read_headers(self): + """Read the response status and the request headers.""" + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + self.vprint('bad status_line', repr(status_line)) + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=''): + """Inspect the status and return the redirect url if appropriate.""" + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=''): + """Get one header value, using a case insensitive header name.""" + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @asyncio.coroutine + def readexactly(self, nbytes): + """Wrapper for readexactly() that raise EOFError if not enough data. + + This also logs (at the vvv level) while it is reading. + """ + blocks = [] + nread = 0 + while nread < nbytes: + self.vvvprint('reading block', len(blocks), + 'with', nbytes - nread, 'bytes remaining') + block = yield from self.reader.read(nbytes-nread) + self.vvvprint('read', len(block), 'bytes') + if not block: + raise EOFError('EOF with %d more bytes expected' % + (nbytes - nread)) + blocks.append(block) + nread += len(block) + return b''.join(blocks) + + @asyncio.coroutine + def read(self): + """Read the response body. + + This honors Content-Length and Transfer-Encoding: chunked. + """ + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding').lower() == 'chunked': + self.vvprint('parsing chunked response') + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + self.vprint('premature end of chunked response') + break + self.vvvprint('size_header =', repr(size_header)) + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + self.vvvprint('reading chunk of', size, 'bytes') + block = yield from self.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + if not size: + break + body = b''.join(blocks) + self.vprint('chunked response had',len(body), + 'bytes in', len(blocks), 'blocks') + else: + self.vvvprint('reading until EOF') + body = yield from self.reader.read() + # TODO: Should make sure not to recycle the connection + # in this case. + else: + body = yield from self.readexactly(nbytes) + return body + + +class Fetcher(VPrinter): + """Logic and state for one URL. + + When found in crawler.busy, this represents a URL to be fetched or + in the process of being fetched; when found in crawler.done, this + holds the results from fetching it. + + This is usually associated with a task. This references the + crawler for the connection pool and to add more URLs to its todo + list. + + Call fetch() to do the fetching, then report() to print the results. + """ + + def __init__(self, url, crawler, max_redirect=10, max_tries=4, verbose=0): + VPrinter.__init__(self, verbose) + self.url = url + self.crawler = crawler + # We don't loop resolving redirects here -- we just use this + # to decide whether to add the redirect URL to crawler.todo. + self.max_redirect = max_redirect + # But we do loop to retry on errors a few times. + self.max_tries = max_tries + # Everything we collect from the response goes here. + self.task = None + self.exceptions = [] + self.tries = 0 + self.request = None + self.response = None + self.body = None + self.next_url = None + self.ctype = None + self.pdict = None + self.encoding = None + self.urls = None + self.new_urls = None + + @asyncio.coroutine + def fetch(self): + """Attempt to fetch the contents of the URL. + + If successful, and the data is HTML, extract further links and + add them to the crawler. Redirects are also added back there. + """ + while self.tries < self.max_tries: + self.tries += 1 + try: + self.request = Request(self.url, self.crawler.pool, + self.verbose) + yield from self.request.connect() + yield from self.request.send_request() + self.response = yield from self.request.get_response() + self.body = yield from self.response.read() + h_conn = self.response.get_header('connection').lower() + h_t_enc = self.response.get_header('transfer-encoding').lower() + if h_conn != 'close': + self.request.recycle_connection() + if self.tries > 1: + self.vprint('try', self.tries, 'for', self.url, 'success') + break + except (BadStatusLine, OSError) as exc: + self.exceptions.append(exc) + self.vprint('try', self.tries, 'for', self.url, + 'raised', repr(exc)) + # Don't reuse the connection in this case. + else: + # We never broke out of the while loop, i.e. all tries failed. + self.vprint('no success for', self.url, + 'in', self.max_tries, 'tries') + return + next_url = self.response.get_redirect_url() + if next_url: + self.next_url = urllib.parse.urljoin(self.url, next_url) + if self.max_redirect > 0: + self.vprint('redirect to', self.next_url, 'from', self.url) + self.crawler.add_url(self.next_url, self.max_redirect-1) + else: + self.vprint('redirect limit reached for', self.next_url, + 'from', self.url) + else: + if self.response.status == 200: + self.ctype = self.response.get_header('content-type') + self.pdict = {} + if self.ctype: + self.ctype, self.pdict = cgi.parse_header(self.ctype) + self.encoding = self.pdict.get('charset', 'utf-8') + if self.ctype == 'text/html': + body = self.body.decode(self.encoding, 'replace') + self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + body)) + if self.urls: + self.vprint('got', len(self.urls), + 'new urls from', self.url) + self.new_urls = set() + for url in self.urls: + url = unescape(url) + url = urllib.parse.urljoin(self.url, url) + url, frag = urllib.parse.urldefrag(url) + if self.crawler.add_url(url): + self.new_urls.add(url) + + def report(self, file=None): + """Print a report on the state for this URL.""" + if self.task is not None: + if not self.task.done(): + print(self.url, 'pending', file=file) + return + elif self.task.cancelled(): + print(self.url, 'cancelled', file=file) + return + elif self.task.exception(): + print(self.url, self.task.exception(), file=file) + return + if len(self.exceptions) == self.tries: + print(self.url, 'error', self.exceptions[-1], file=file) + elif self.next_url: + print(self.url, self.response.status, 'redirect', self.next_url, + file=file) + elif self.ctype == 'text/html': + print(self.url, self.response.status, + self.ctype, self.encoding, + len(self.body or b''), + '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), + file=file) + else: + print(self.url, self.response.status, + self.ctype, self.encoding, + len(self.body or b''), self.ctype, + file=file) + + +class Crawler(VPrinter): + """Crawl a set of URLs. + + This manages three disjoint sets of URLs (todo, busy, done). The + data structures actually store dicts -- the values in todo give + the redirect limit, while the values in busy and done are Fetcher + instances. + """ + def __init__(self, roots, + max_redirect=10, max_tries=4, max_tasks=10, exclude=None, + verbose=0): + VPrinter.__init__(self, verbose) + self.roots = roots + self.max_redirect = max_redirect + self.max_tries = max_tries + self.max_tasks = max_tasks + self.exclude = exclude + self.todo = {} + self.busy = {} + self.done = {} + self.pool = ConnectionPool(self.verbose) + self.root_domains = set() + for root in roots: + parts = urllib.parse.urlparse(root) + host, port = urllib.parse.splitport(parts.netloc) + self.root_domains.add(host.lower()) + for root in roots: + self.add_url(root) + self.governor = asyncio.locks.Semaphore(max_tasks) + self.termination = asyncio.locks.Condition() + self.t0 = time.time() + self.t1 = None + + def close(self): + """Close resources (currently only the pool).""" + self.pool.close() + + def host_okay(self, host): + """Check if a host should be crawled. + + It must match one of the root URLs, but leading 'www.' on the + domain is ignored. However, other subdomains (e.g. 'mail.', + 'docs.') are not crawled unless mentioned in a root URL. + """ + host = host.lower() + if host in self.root_domains: + return True + if host.startswith('www.'): + if host[4:] in self.root_domains: + return True + else: + if 'www.' + host in self.root_domains: + return True + return False + + def add_url(self, url, max_redirect=None): + """Add a URL to the todo list if not seen before.""" + if self.exclude and re.search(self.exclude, url): + return False + parts = urllib.parse.urlparse(url) + if parts.scheme not in ('http', 'https'): + self.vvprint('skipping non-http scheme in', url) + return False + host, port = urllib.parse.splitport(parts.netloc) + if not self.host_okay(host): + self.vvprint('skipping non-root host in', url) + return False + if max_redirect is None: + max_redirect = self.max_redirect + if url in self.todo or url in self.busy or url in self.done: + return False + self.vprint('adding', url, max_redirect) + self.todo[url] = max_redirect + return True + + @asyncio.coroutine + def crawl(self): + """Run the crawler until all finished.""" + with (yield from self.termination): + while self.todo or self.busy: + if self.todo: + url, max_redirect = self.todo.popitem() + fetcher = Fetcher(url, + crawler=self, + max_redirect=max_redirect, + max_tries=self.max_tries, + verbose=self.verbose) + self.busy[url] = fetcher + fetcher.task = asyncio.Task(self.fetch(fetcher)) + else: + yield from self.termination.wait() + self.t1 = time.time() + + @asyncio.coroutine + def fetch(self, fetcher): + """Call the Fetcher's fetch(), with a limit on concurrency. + + Once this returns, move the fetcher from busy to done. + """ + url = fetcher.url + with (yield from self.governor): + try: + yield from fetcher.fetch() # Fetcher gonna fetch. + except Exception: + # Force GC of the task, so the error is logged. + fetcher.task = None + raise + with (yield from self.termination): + self.done[url] = fetcher + del self.busy[url] + self.termination.notify() + + def report(self, file=None): + """Print a report on all completed URLs.""" + if self.t1 is None: + self.t1 = time.time() + dt = self.t1 - self.t0 + if dt and self.max_tasks: + speed = len(self.done) / dt / self.max_tasks + else: + speed = 0 + if self.verbose > 0: + print('*** Report ***', file=file) + try: + for url, fetcher in sorted(self.done.items()): + fetcher.report(file=file) + except KeyboardInterrupt: + print('\nInterrupted', file=file) + print('Crawled', len(self.done), + 'urls in %.3f secs' % dt, + '(max_tasks=%d)' % self.max_tasks, + '(%.3f urls/sec/task)' % speed, + file=file) + + +def main(): + """Main program. + + Parse arguments, set up event loop, run crawler, print report. + """ + args = ARGS.parse_args() + + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + asyncio.set_event_loop(loop) + elif args.select: + loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + + roots = {fix_url(root) for root in args.roots} + + crawler = Crawler(roots, + max_redirect=args.max_redirect, + max_tries=args.max_tries, + max_tasks=args.max_tasks, + exclude=args.exclude, + verbose=args.verbose) + try: + loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. + except KeyboardInterrupt: + sys.stderr.flush() + print('\nInterrupted\n') + finally: + crawler.report() + crawler.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() From 166835ac82fba06e0827dd3181a7790b39e10b8c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 8 Jan 2014 17:29:18 -0800 Subject: [PATCH 0850/1502] Ignore now-closed connections in pool. --- examples/fetch3.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/fetch3.py b/examples/fetch3.py index bdca3aee..9419afd2 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -34,6 +34,10 @@ def open_connection(self, host, port, ssl): key = h, p, ssl conn = self.connections.get(key) if conn: + reader, writer = conn + if reader._eof: + self.connections.pop(key) + continue if self.verbose: print('* Reusing pooled connection', key, file=sys.stderr) return conn From 9660a2d8c466a63796ae44133dec92ae5d7b8fed Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 8 Jan 2014 20:52:49 -0800 Subject: [PATCH 0851/1502] Fix serious leak in connection pool (still a minor one left). --- examples/crawl.py | 64 ++++++++++++++++++++++++++++++++++------------- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 18fbe7be..86c41ead 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -5,7 +5,8 @@ # TODO: # - Less verbose logging. # - Support gzip encoding. -# - Seems sometimes getaddrinfo() raises gaierror? +# - Close connection if HTTP/1.0 response. +# - Expire connections in pool if too many. import argparse import asyncio @@ -110,23 +111,24 @@ class ConnectionPool(VPrinter): To open a connection, use reserve(). To recycle it, use unreserve(). - The pool is mostly just a mapping from (host, port, ssl) to - (reader, writer). The currently active connections are *not* in - the mapping; reserve() takes the connection out of the mapping, - and unreserve()' puts it back in. It is up to the caller to only - call unreserve() for reusable connections. (That logic is + The pool is mostly just a mapping from (host, port, ssl) tuples to + lists of (reader, writer) pairs. The currently active connections + are *not* in the data structure; reserve() takes the connection + out, and unreserve()' puts it back in. It is up to the caller to + only call unreserve() for reusable connections. (That logic is implemented in the Request class.) """ def __init__(self, verbose=0): VPrinter.__init__(self, verbose) self.loop = asyncio.get_event_loop() - self.connections = {} # {(host, port, ssl): (reader, writer)} + self.connections = {} # {(host, port, ssl): [(reader, writer)]} def close(self): """Close all connections available for reuse.""" - for _, writer in self.connections.values(): - writer.close() + for pairs in self.connections.values(): + for _, writer in pairs: + writer.close() @asyncio.coroutine def reserve(self, host, port, ssl): @@ -141,21 +143,37 @@ def reserve(self, host, port, ssl): (host, ', '.join(ip[4][0] for ip in ipaddrs))) for _, _, _, _, (h, p, *_) in ipaddrs: key = h, p, ssl - pair = self.connections.pop(key, None) - if pair: - self.vprint('* Reusing pooled connection', key) + pair = None + pairs = self.connections.get(key) + while pairs: + pair = pairs.pop(0) + if not pairs: + del self.connections[key] reader, writer = pair - return key, reader, writer + if reader._eof: + self.vprint('(cached connection closed for %s)' % repr(key)) + else: + self.vprint('* Reusing pooled connection', key, 'FD =', writer._transport._sock.fileno()) + return key, reader, writer reader, writer = yield from asyncio.open_connection(host, port, ssl=ssl) - host, port, *_ = writer.get_extra_info('peername') + peername = writer.get_extra_info('peername') + if peername: + host, port, *_ = peername + else: + self.vprint('NO PEERNAME???', host, port, ssl) key = host, port, ssl - self.vprint('* New connection', key) + self.vprint('* New connection', key, 'FD =', writer._transport._sock.fileno()) return key, reader, writer def unreserve(self, key, reader, writer): """Make a connection available for reuse.""" - self.connections[key] = (reader, writer) + if reader._eof: + return + pairs = self.connections.get(key) + if pairs is None: + self.connections[key] = pairs = [] + pairs.append((reader, writer)) class Request(VPrinter): @@ -185,6 +203,7 @@ def __init__(self, url, pool, verbose=0): self.http_version = 'HTTP/1.1' self.method = 'GET' self.headers = [] + self.key = None self.reader = None self.writer = None @@ -209,6 +228,11 @@ def recycle_connection(self): self.pool.unreserve(self.key, self.reader, self.writer) self.key = self.reader = self.writer = None + def close(self): + if self.writer is not None: + self.writer.close() + self.key = self.reader = self.writer = None + @asyncio.coroutine def putline(self, line): """Write a line to the connection. @@ -408,6 +432,7 @@ def fetch(self): """ while self.tries < self.max_tries: self.tries += 1 + self.request = None try: self.request = Request(self.url, self.crawler.pool, self.verbose) @@ -419,6 +444,7 @@ def fetch(self): h_t_enc = self.response.get_header('transfer-encoding').lower() if h_conn != 'close': self.request.recycle_connection() + self.request = None if self.tries > 1: self.vprint('try', self.tries, 'for', self.url, 'success') break @@ -426,7 +452,11 @@ def fetch(self): self.exceptions.append(exc) self.vprint('try', self.tries, 'for', self.url, 'raised', repr(exc)) + ##import pdb; pdb.set_trace() # Don't reuse the connection in this case. + finally: + if self.request is not None: + self.request.close() else: # We never broke out of the while loop, i.e. all tries failed. self.vprint('no success for', self.url, @@ -454,7 +484,7 @@ def fetch(self): body)) if self.urls: self.vprint('got', len(self.urls), - 'new urls from', self.url) + 'distinct urls from', self.url) self.new_urls = set() for url in self.urls: url = unescape(url) From a1dae73ff04ac0bd36279970925fa26bd0fc6fe4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 9 Jan 2014 10:58:35 -0800 Subject: [PATCH 0852/1502] Tiny tweaks to code/docs. --- asyncio/streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 7eda5f66..f01f8629 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -38,7 +38,7 @@ def open_connection(host=None, port=None, *, if loop is None: loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) - protocol = StreamReaderProtocol(reader) + protocol = StreamReaderProtocol(reader, loop=loop) transport, _ = yield from loop.create_connection( lambda: protocol, host, port, **kwds) writer = StreamWriter(transport, protocol, reader, loop) @@ -151,7 +151,7 @@ class StreamWriter: This exposes write(), writelines(), [can_]write_eof(), get_extra_info() and close(). It adds drain() which returns an optional Future on which you can wait for flow control. It also - adds a transport attribute which references the Transport + adds a transport property which references the Transport directly. """ From 0dd16df075e66e8c7751fc1230b6c59918711e11 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 9 Jan 2014 11:01:09 -0800 Subject: [PATCH 0853/1502] Connection pool limits; strict/lenient host matching. --- examples/crawl.py | 142 +++++++++++++++++++++++++++++++++++++++------- 1 file changed, 120 insertions(+), 22 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 86c41ead..db33a175 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -6,7 +6,10 @@ # - Less verbose logging. # - Support gzip encoding. # - Close connection if HTTP/1.0 response. -# - Expire connections in pool if too many. +# - Add timeouts. +# - Improve class structure (e.g. add a Connection class). +# - Skip reading large non-text/html files? +# - Use ETag and If-Modified-Since? import argparse import asyncio @@ -30,19 +33,29 @@ default=False, help='Use Select event loop instead of default') ARGS.add_argument( 'roots', nargs='*', - default=['http://python.org'], help='Root URL (may be repeated)') + default=[], help='Root URL (may be repeated)') ARGS.add_argument( '--max_redirect', action='store', type=int, metavar='N', - default=10, help='Limit redirection (for 301, 302 etc.)') + default=10, help='Limit redirection chains (for 301, 302 etc.)') ARGS.add_argument( '--max_tries', action='store', type=int, metavar='N', default=4, help='Limit retries on network errors') ARGS.add_argument( '--max_tasks', action='store', type=int, metavar='N', - default=5, help='Limit on concurrent connections') + default=5, help='Limit concurrent connections') +ARGS.add_argument( + '--max_pool', action='store', type=int, metavar='N', + default=10, help='Limit connection pool size') ARGS.add_argument( '--exclude', action='store', metavar='REGEX', help='Exclude matching URLs') +ARGS.add_argument( + '--strict', action='store_true', + default=True, help='Strict host matching (default)') +ARGS.add_argument( + '--lenient', action='store_false', dest='strict', + default=False, help='Lenient host matching') + ARGS.add_argument( '-v', '--verbose', action='count', dest='verbose', default=1, help='Verbose logging (repeat for more verbose)') @@ -117,18 +130,25 @@ class ConnectionPool(VPrinter): out, and unreserve()' puts it back in. It is up to the caller to only call unreserve() for reusable connections. (That logic is implemented in the Request class.) + + There are limits to both the overal pool and the per-key pool. """ - def __init__(self, verbose=0): + def __init__(self, max_pool=10, max_tasks=5, verbose=0): VPrinter.__init__(self, verbose) + self.max_pool = max_pool # Overall limit. + self.max_tasks = max_tasks # Per-key limit. self.loop = asyncio.get_event_loop() self.connections = {} # {(host, port, ssl): [(reader, writer)]} + self.queue = [] # [(key, pair)] def close(self): """Close all connections available for reuse.""" for pairs in self.connections.values(): for _, writer in pairs: writer.close() + self.connections.clear() + self.queue.clear() @asyncio.coroutine def reserve(self, host, port, ssl): @@ -141,20 +161,28 @@ def reserve(self, host, port, ssl): raise self.vprint('* %s resolves to %s' % (host, ', '.join(ip[4][0] for ip in ipaddrs))) + + # Look for a reusable connection. for _, _, _, _, (h, p, *_) in ipaddrs: key = h, p, ssl pair = None pairs = self.connections.get(key) while pairs: pair = pairs.pop(0) + self.queue.remove((key, pair)) if not pairs: del self.connections[key] reader, writer = pair if reader._eof: - self.vprint('(cached connection closed for %s)' % repr(key)) + self.vprint('(cached connection closed for %s)' % + repr(key)) + writer.close() # Just in case. else: - self.vprint('* Reusing pooled connection', key, 'FD =', writer._transport._sock.fileno()) + self.vprint('* Reusing pooled connection', key, + 'FD =', writer._transport._sock.fileno()) return key, reader, writer + + # Create a new connection. reader, writer = yield from asyncio.open_connection(host, port, ssl=ssl) peername = writer.get_extra_info('peername') @@ -163,17 +191,40 @@ def reserve(self, host, port, ssl): else: self.vprint('NO PEERNAME???', host, port, ssl) key = host, port, ssl - self.vprint('* New connection', key, 'FD =', writer._transport._sock.fileno()) + self.vprint('* New connection', key, + 'FD =', writer._transport._sock.fileno()) return key, reader, writer def unreserve(self, key, reader, writer): - """Make a connection available for reuse.""" + """Make a connection available for reuse. + + This also prunes the pool if it exceeds the size limits. + """ if reader._eof: + writer.close() return - pairs = self.connections.get(key) - if pairs is None: - self.connections[key] = pairs = [] - pairs.append((reader, writer)) + pair = reader, writer + pairs = self.connections.setdefault(key, []) + pairs.append(pair) + self.queue.append((key, pair)) + + # Close oldest connection(s) for this key if limit reached. + while len(pairs) > self.max_tasks: + pair = pairs.pop(0) + self.vprint('closing oldest connection for', key) + self.queue.remove((key, pair)) + reader, writer = pair + writer.close() + + # Close oldest overall connection(s) if limit reached. + while len(self.queue) > self.max_pool: + key, pair = self.queue.pop(0) + self.vprint('closing olderst connection', key) + pairs = self.connections.get(key) + p = pairs.pop(0) + assert pair == p, (key, pair, p, pairs) + reader, writer = pair + writer.close() class Request(VPrinter): @@ -531,24 +582,44 @@ class Crawler(VPrinter): the redirect limit, while the values in busy and done are Fetcher instances. """ - def __init__(self, roots, - max_redirect=10, max_tries=4, max_tasks=10, exclude=None, + def __init__(self, + roots, exclude=None, strict=True, # What to crawl. + max_redirect=10, max_tries=4, # Per-url limits. + max_tasks=10, max_pool=10, # Global limits. verbose=0): VPrinter.__init__(self, verbose) self.roots = roots + self.exclude = exclude + self.strict = strict self.max_redirect = max_redirect self.max_tries = max_tries self.max_tasks = max_tasks - self.exclude = exclude + self.max_pool = max_pool self.todo = {} self.busy = {} self.done = {} - self.pool = ConnectionPool(self.verbose) + self.pool = ConnectionPool(max_pool, max_tasks, self.verbose) self.root_domains = set() for root in roots: parts = urllib.parse.urlparse(root) host, port = urllib.parse.splitport(parts.netloc) - self.root_domains.add(host.lower()) + if not host: + continue + if re.match(r'\A[\d\.]*\Z', host): + self.root_domains.add(host) + else: + host = host.lower() + if self.strict: + self.root_domains.add(host) + if host.startswith('www.'): + self.root_domains.add(host[4:]) + else: + self.root_domains.add('www.' + host) + else: + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + self.root_domains.add(host) for root in roots: self.add_url(root) self.governor = asyncio.locks.Semaphore(max_tasks) @@ -563,13 +634,25 @@ def close(self): def host_okay(self, host): """Check if a host should be crawled. - It must match one of the root URLs, but leading 'www.' on the - domain is ignored. However, other subdomains (e.g. 'mail.', - 'docs.') are not crawled unless mentioned in a root URL. + A literal match (after lowercasing) is always good. For hosts + that don't look like IP addresses, some approximate matches + are okay depending on the strict flag. """ host = host.lower() if host in self.root_domains: return True + if re.match(r'\A[\d\.]*\Z', host): + return False + if self.strict: + return self._host_okay_strictish(host) + else: + return self._host_okay_lenient(host) + + def _host_okay_strictish(self, host): + """Check if a host should be crawled, strict-ish version. + + This checks for equality modulo an initial 'www.' component. + """ if host.startswith('www.'): if host[4:] in self.root_domains: return True @@ -578,6 +661,16 @@ def host_okay(self, host): return True return False + def _host_okay_lenient(self, host): + """Check if a host should be crawled, lenient version. + + This compares the last two components of the host. + """ + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + return host in self.root_domains + def add_url(self, url, max_redirect=None): """Add a URL to the todo list if not seen before.""" if self.exclude and re.search(self.exclude, url): @@ -664,6 +757,9 @@ def main(): Parse arguments, set up event loop, run crawler, print report. """ args = ARGS.parse_args() + if not args.roots: + print('Use --help for command line help') + return if args.iocp: from asyncio.windows_events import ProactorEventLoop @@ -678,10 +774,12 @@ def main(): roots = {fix_url(root) for root in args.roots} crawler = Crawler(roots, + exclude=args.exclude, + strict=args.strict, max_redirect=args.max_redirect, max_tries=args.max_tries, max_tasks=args.max_tasks, - exclude=args.exclude, + max_pool=args.max_pool, verbose=args.verbose) try: loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. From 8c3a4665602af4f9b7b621df37946e31340d84e7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 9 Jan 2014 11:18:54 -0800 Subject: [PATCH 0854/1502] Change max tasks/pool default to 100. Add TODOs. PEP8 tweaks. --- examples/crawl.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index db33a175..562aa2df 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -3,13 +3,17 @@ """A simple web crawler.""" # TODO: -# - Less verbose logging. +# - Make VPrinter a sub-object, not a base class. +# - More organized logging (with task ID?). Use logging module. +# - Nicer reporting, e.g. total bytes, total html bytes, +# success/redirect/error counts, time of day (local+UTC). # - Support gzip encoding. # - Close connection if HTTP/1.0 response. -# - Add timeouts. +# - Add timeouts. (E.g. when switching networks, all seems to hang.) # - Improve class structure (e.g. add a Connection class). # - Skip reading large non-text/html files? # - Use ETag and If-Modified-Since? +# - Handle out of file descriptors directly? (How?) import argparse import asyncio @@ -42,10 +46,10 @@ default=4, help='Limit retries on network errors') ARGS.add_argument( '--max_tasks', action='store', type=int, metavar='N', - default=5, help='Limit concurrent connections') + default=100, help='Limit concurrent connections') ARGS.add_argument( '--max_pool', action='store', type=int, metavar='N', - default=10, help='Limit connection pool size') + default=100, help='Limit connection pool size') ARGS.add_argument( '--exclude', action='store', metavar='REGEX', help='Exclude matching URLs') @@ -70,6 +74,7 @@ ('amp', '&') # Must be last. ] + def unescape(url): """Turn & into &, and so on. @@ -425,7 +430,7 @@ def read(self): if not size: break body = b''.join(blocks) - self.vprint('chunked response had',len(body), + self.vprint('chunked response had', len(body), 'bytes in', len(blocks), 'blocks') else: self.vvvprint('reading until EOF') From f9d39fe8dc10f5f8e15d4ee06f84da93a68b5ed9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 9 Jan 2014 12:49:54 -0800 Subject: [PATCH 0855/1502] Better reporting. --- examples/crawl.py | 73 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 562aa2df..044c7d31 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -5,12 +5,12 @@ # TODO: # - Make VPrinter a sub-object, not a base class. # - More organized logging (with task ID?). Use logging module. -# - Nicer reporting, e.g. total bytes, total html bytes, -# success/redirect/error counts, time of day (local+UTC). +# - KeyboardInterrupt in HTML parsing may hang or report unretrieved error. # - Support gzip encoding. # - Close connection if HTTP/1.0 response. # - Add timeouts. (E.g. when switching networks, all seems to hang.) # - Improve class structure (e.g. add a Connection class). +# - Add arguments to specify TLS settings (e.g. cert/key files). # - Skip reading large non-text/html files? # - Use ETag and If-Modified-Since? # - Handle out of file descriptors directly? (How?) @@ -224,7 +224,7 @@ def unreserve(self, key, reader, writer): # Close oldest overall connection(s) if limit reached. while len(self.queue) > self.max_pool: key, pair = self.queue.pop(0) - self.vprint('closing olderst connection', key) + self.vprint('closing oldest connection', key) pairs = self.connections.get(key) p = pairs.pop(0) assert pair == p, (key, pair, p, pairs) @@ -549,36 +549,73 @@ def fetch(self): if self.crawler.add_url(url): self.new_urls.add(url) - def report(self, file=None): - """Print a report on the state for this URL.""" + def report(self, stats, file=None): + """Print a report on the state for this URL. + + Also update the Stats instance. + """ if self.task is not None: if not self.task.done(): + stats.add('pending') print(self.url, 'pending', file=file) return elif self.task.cancelled(): + stats.add('cancelled') print(self.url, 'cancelled', file=file) return elif self.task.exception(): - print(self.url, self.task.exception(), file=file) + stats.add('exception') + exc = self.task.exception() + stats.add('exception_' + exc.__class__.__name__) + print(self.url, exc, file=file) return if len(self.exceptions) == self.tries: - print(self.url, 'error', self.exceptions[-1], file=file) + stats.add('fail') + exc = self.exceptions[-1] + stats.add('fail_' + str(exc.__class__.__name__)) + print(self.url, 'error', exc, file=file) elif self.next_url: + stats.add('redirect') print(self.url, self.response.status, 'redirect', self.next_url, file=file) elif self.ctype == 'text/html': + stats.add('html') + size = len(self.body or b'') + stats.add('html_bytes', size) print(self.url, self.response.status, self.ctype, self.encoding, - len(self.body or b''), + size, '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), file=file) else: + size = len(self.body or b'') + if self.response.status == 200: + stats.add('other') + stats.add('other_bytes', size) + else: + stats.add('error') + stats.add('error_bytes', size) + stats.add('status_%s' % self.response.status) print(self.url, self.response.status, self.ctype, self.encoding, - len(self.body or b''), self.ctype, + size, file=file) +class Stats: + """Record stats of various sorts.""" + + def __init__(self): + self.stats = {} + + def add(self, key, count=1): + self.stats[key] = self.stats.get(key, 0) + count + + def report(self, file=None): + for key, count in sorted(self.stats.items()): + print(' %-20s %10d' % (key, count), file=file) + + class Crawler(VPrinter): """Crawl a set of URLs. @@ -742,18 +779,30 @@ def report(self, file=None): speed = len(self.done) / dt / self.max_tasks else: speed = 0 + stats = Stats() if self.verbose > 0: print('*** Report ***', file=file) try: - for url, fetcher in sorted(self.done.items()): - fetcher.report(file=file) + show = [] + show.extend(self.done.items()) + show.extend(self.busy.items()) + show.sort() + for url, fetcher in show: + fetcher.report(stats, file=file) except KeyboardInterrupt: print('\nInterrupted', file=file) - print('Crawled', len(self.done), + print('Finished', len(self.done), 'urls in %.3f secs' % dt, '(max_tasks=%d)' % self.max_tasks, '(%.3f urls/sec/task)' % speed, file=file) + stats.report(file=file) + if self.todo: + print('Todo:', len(self.todo), file=file) + if self.busy: + print('Busy:', len(self.busy), file=file) + print('Done:', len(self.done), file=file) + print('Date:', time.ctime(), 'local time', file=file) def main(): From 1d4bc2373439c46809c7708e2d1b391fa598491e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 9 Jan 2014 14:45:48 -0800 Subject: [PATCH 0856/1502] Refactor Logger: No more VPrinter base class. --- examples/crawl.py | 189 +++++++++++++++++++++------------------------- 1 file changed, 85 insertions(+), 104 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 044c7d31..6716815d 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -3,8 +3,8 @@ """A simple web crawler.""" # TODO: -# - Make VPrinter a sub-object, not a base class. -# - More organized logging (with task ID?). Use logging module. +# - More organized logging (with task ID or URL?). +# - Use logging module for Logger. # - KeyboardInterrupt in HTML parsing may hang or report unretrieved error. # - Support gzip encoding. # - Close connection if HTTP/1.0 response. @@ -59,13 +59,12 @@ ARGS.add_argument( '--lenient', action='store_false', dest='strict', default=False, help='Lenient host matching') - ARGS.add_argument( - '-v', '--verbose', action='count', dest='verbose', + '-v', '--verbose', action='count', dest='level', default=1, help='Verbose logging (repeat for more verbose)') ARGS.add_argument( - '-q', '--quiet', action='store_const', const=0, dest='verbose', - help='Quiet logging (opposite of --verbose)') + '-q', '--quiet', action='store_const', const=0, dest='level', + default=1, help='Quiet logging (opposite of --verbose)') ESCAPES = [('quot', '"'), @@ -92,39 +91,23 @@ def fix_url(url): return url -class VPrinter: - """Mix-in class defining vprint() which is like print() if verbose > 0. - - The output goes to stderr. Only positional arguments are - supported. There are also methods vvprint(), vvvprint() - etc. which print() only if verbose > larger values. - - The verbose instance variable is public. +class Logger: - TODO: This should probably be a shared object rather than a mix-in class. - """ - - def __init__(self, verbose): - self.verbose = verbose + def __init__(self, level): + self.level = level - def _nvprint(self, n, args): - if self.verbose >= n: + def _log(self, n, args): + if self.level >= n: print(*args, file=sys.stderr, flush=True) - def nvprint(self, n, *args): - self._nvprint(n, args) - - def vprint(self, *args): - self._nvprint(1, args) + def log(self, n, *args): + self._log(n, args) - def vvprint(self, *args): - self._nvprint(2, args) + def __call__(self, n, *args): + self._log(n, args) - def vvvprint(self, *args): - self._nvprint(3, args) - -class ConnectionPool(VPrinter): +class ConnectionPool: """A connection pool. To open a connection, use reserve(). To recycle it, use unreserve(). @@ -139,8 +122,8 @@ class ConnectionPool(VPrinter): There are limits to both the overal pool and the per-key pool. """ - def __init__(self, max_pool=10, max_tasks=5, verbose=0): - VPrinter.__init__(self, verbose) + def __init__(self, log, max_pool=10, max_tasks=5): + self.log = log self.max_pool = max_pool # Overall limit. self.max_tasks = max_tasks # Per-key limit. self.loop = asyncio.get_event_loop() @@ -162,9 +145,9 @@ def reserve(self, host, port, ssl): try: ipaddrs = yield from self.loop.getaddrinfo(host, port) except Exception as exc: - self.vprint('Exception %r for (%r, %r)' % (exc, host, port)) + self.log(0, 'Exception %r for (%r, %r)' % (exc, host, port)) raise - self.vprint('* %s resolves to %s' % + self.log(1, '* %s resolves to %s' % (host, ', '.join(ip[4][0] for ip in ipaddrs))) # Look for a reusable connection. @@ -179,11 +162,11 @@ def reserve(self, host, port, ssl): del self.connections[key] reader, writer = pair if reader._eof: - self.vprint('(cached connection closed for %s)' % + self.log(1, '(cached connection closed for %s)' % repr(key)) writer.close() # Just in case. else: - self.vprint('* Reusing pooled connection', key, + self.log(1, '* Reusing pooled connection', key, 'FD =', writer._transport._sock.fileno()) return key, reader, writer @@ -194,9 +177,9 @@ def reserve(self, host, port, ssl): if peername: host, port, *_ = peername else: - self.vprint('NO PEERNAME???', host, port, ssl) + self.log(1, 'NO PEERNAME???', host, port, ssl) key = host, port, ssl - self.vprint('* New connection', key, + self.log(1, '* New connection', key, 'FD =', writer._transport._sock.fileno()) return key, reader, writer @@ -216,7 +199,7 @@ def unreserve(self, key, reader, writer): # Close oldest connection(s) for this key if limit reached. while len(pairs) > self.max_tasks: pair = pairs.pop(0) - self.vprint('closing oldest connection for', key) + self.log(1, 'closing oldest connection for', key) self.queue.remove((key, pair)) reader, writer = pair writer.close() @@ -224,7 +207,7 @@ def unreserve(self, key, reader, writer): # Close oldest overall connection(s) if limit reached. while len(self.queue) > self.max_pool: key, pair = self.queue.pop(0) - self.vprint('closing oldest connection', key) + self.log(1, 'closing oldest connection', key) pairs = self.connections.get(key) p = pairs.pop(0) assert pair == p, (key, pair, p, pairs) @@ -232,15 +215,15 @@ def unreserve(self, key, reader, writer): writer.close() -class Request(VPrinter): +class Request: """HTTP request. Use connect() to open a connection; send_request() to send the request; get_response() to receive the response headers. """ - def __init__(self, url, pool, verbose=0): - VPrinter.__init__(self, verbose) + def __init__(self, log, url, pool): + self.log = log self.url = url self.pool = pool self.parts = urllib.parse.urlparse(self.url) @@ -266,13 +249,13 @@ def __init__(self, url, pool, verbose=0): @asyncio.coroutine def connect(self): """Open a connection to the server.""" - self.vprint('* Connecting to %s:%s using %s for %s' % + self.log(1, '* Connecting to %s:%s using %s for %s' % (self.hostname, self.port, 'ssl' if self.ssl else 'tcp', self.url)) self.key, self.reader, self.writer = \ yield from self.pool.reserve(self.hostname, self.port, self.ssl) - self.vprint('* Connected to %s' % + self.log(1, '* Connected to %s' % (self.writer.get_extra_info('peername'),)) def recycle_connection(self): @@ -295,7 +278,7 @@ def putline(self, line): Used for the request line and headers. """ - self.vvprint('>', line) + self.log(2, '>', line) self.writer.write(line.encode('latin-1') + b'\r\n') @asyncio.coroutine @@ -317,12 +300,12 @@ def send_request(self): @asyncio.coroutine def get_response(self): """Receive the response.""" - response = Response(self.reader, self.verbose) + response = Response(self.log, self.reader) yield from response.read_headers() return response -class Response(VPrinter): +class Response: """HTTP response. Call read_headers() to receive the request headers. Then check @@ -330,8 +313,8 @@ class Response(VPrinter): Finally call read() to receive the body. """ - def __init__(self, reader, verbose=0): - VPrinter.__init__(self, verbose) + def __init__(self, log, reader): + self.log = log self.reader = reader self.http_version = None # 'HTTP/1.1' self.status = None # 200 @@ -342,7 +325,7 @@ def __init__(self, reader, verbose=0): def getline(self): """Read one line from the connection.""" line = (yield from self.reader.readline()).decode('latin-1').rstrip() - self.vvprint('<', line) + self.log(2, '<', line) return line @asyncio.coroutine @@ -351,7 +334,7 @@ def read_headers(self): status_line = yield from self.getline() status_parts = status_line.split(None, 2) if len(status_parts) != 3: - self.vprint('bad status_line', repr(status_line)) + self.log(0, 'bad status_line', repr(status_line)) raise BadStatusLine(status_line) self.http_version, status, self.reason = status_parts self.status = int(status) @@ -386,10 +369,10 @@ def readexactly(self, nbytes): blocks = [] nread = 0 while nread < nbytes: - self.vvvprint('reading block', len(blocks), - 'with', nbytes - nread, 'bytes remaining') + self.log(3, 'reading block', len(blocks), + 'with', nbytes - nread, 'bytes remaining') block = yield from self.reader.read(nbytes-nread) - self.vvvprint('read', len(block), 'bytes') + self.log(3, 'read', len(block), 'bytes') if not block: raise EOFError('EOF with %d more bytes expected' % (nbytes - nread)) @@ -410,18 +393,18 @@ def read(self): break if nbytes is None: if self.get_header('transfer-encoding').lower() == 'chunked': - self.vvprint('parsing chunked response') + self.log(2, 'parsing chunked response') blocks = [] while True: size_header = yield from self.reader.readline() if not size_header: - self.vprint('premature end of chunked response') + self.log(0, 'premature end of chunked response') break - self.vvvprint('size_header =', repr(size_header)) + self.log(3, 'size_header =', repr(size_header)) parts = size_header.split(b';') size = int(parts[0], 16) if size: - self.vvvprint('reading chunk of', size, 'bytes') + self.log(3, 'reading chunk of', size, 'bytes') block = yield from self.readexactly(size) assert len(block) == size, (len(block), size) blocks.append(block) @@ -430,10 +413,10 @@ def read(self): if not size: break body = b''.join(blocks) - self.vprint('chunked response had', len(body), + self.log(1, 'chunked response had', len(body), 'bytes in', len(blocks), 'blocks') else: - self.vvvprint('reading until EOF') + self.log(3, 'reading until EOF') body = yield from self.reader.read() # TODO: Should make sure not to recycle the connection # in this case. @@ -442,7 +425,7 @@ def read(self): return body -class Fetcher(VPrinter): +class Fetcher: """Logic and state for one URL. When found in crawler.busy, this represents a URL to be fetched or @@ -456,8 +439,8 @@ class Fetcher(VPrinter): Call fetch() to do the fetching, then report() to print the results. """ - def __init__(self, url, crawler, max_redirect=10, max_tries=4, verbose=0): - VPrinter.__init__(self, verbose) + def __init__(self, log, url, crawler, max_redirect=10, max_tries=4): + self.log = log self.url = url self.crawler = crawler # We don't loop resolving redirects here -- we just use this @@ -490,8 +473,7 @@ def fetch(self): self.tries += 1 self.request = None try: - self.request = Request(self.url, self.crawler.pool, - self.verbose) + self.request = Request(self.log, self.url, self.crawler.pool) yield from self.request.connect() yield from self.request.send_request() self.response = yield from self.request.get_response() @@ -502,11 +484,11 @@ def fetch(self): self.request.recycle_connection() self.request = None if self.tries > 1: - self.vprint('try', self.tries, 'for', self.url, 'success') + self.log(1, 'try', self.tries, 'for', self.url, 'success') break except (BadStatusLine, OSError) as exc: self.exceptions.append(exc) - self.vprint('try', self.tries, 'for', self.url, + self.log(1, 'try', self.tries, 'for', self.url, 'raised', repr(exc)) ##import pdb; pdb.set_trace() # Don't reuse the connection in this case. @@ -515,17 +497,17 @@ def fetch(self): self.request.close() else: # We never broke out of the while loop, i.e. all tries failed. - self.vprint('no success for', self.url, + self.log(0, 'no success for', self.url, 'in', self.max_tries, 'tries') return next_url = self.response.get_redirect_url() if next_url: self.next_url = urllib.parse.urljoin(self.url, next_url) if self.max_redirect > 0: - self.vprint('redirect to', self.next_url, 'from', self.url) + self.log(1, 'redirect to', self.next_url, 'from', self.url) self.crawler.add_url(self.next_url, self.max_redirect-1) else: - self.vprint('redirect limit reached for', self.next_url, + self.log(0, 'redirect limit reached for', self.next_url, 'from', self.url) else: if self.response.status == 200: @@ -536,10 +518,11 @@ def fetch(self): self.encoding = self.pdict.get('charset', 'utf-8') if self.ctype == 'text/html': body = self.body.decode(self.encoding, 'replace') + # Replace href with (?:href|src) to follow image links. self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', body)) if self.urls: - self.vprint('got', len(self.urls), + self.log(1, 'got', len(self.urls), 'distinct urls from', self.url) self.new_urls = set() for url in self.urls: @@ -616,7 +599,7 @@ def report(self, file=None): print(' %-20s %10d' % (key, count), file=file) -class Crawler(VPrinter): +class Crawler: """Crawl a set of URLs. This manages three disjoint sets of URLs (todo, busy, done). The @@ -624,12 +607,12 @@ class Crawler(VPrinter): the redirect limit, while the values in busy and done are Fetcher instances. """ - def __init__(self, + def __init__(self, log, roots, exclude=None, strict=True, # What to crawl. max_redirect=10, max_tries=4, # Per-url limits. max_tasks=10, max_pool=10, # Global limits. - verbose=0): - VPrinter.__init__(self, verbose) + ): + self.log = log self.roots = roots self.exclude = exclude self.strict = strict @@ -640,7 +623,7 @@ def __init__(self, self.todo = {} self.busy = {} self.done = {} - self.pool = ConnectionPool(max_pool, max_tasks, self.verbose) + self.pool = ConnectionPool(self.log, max_pool, max_tasks) self.root_domains = set() for root in roots: parts = urllib.parse.urlparse(root) @@ -719,17 +702,17 @@ def add_url(self, url, max_redirect=None): return False parts = urllib.parse.urlparse(url) if parts.scheme not in ('http', 'https'): - self.vvprint('skipping non-http scheme in', url) + self.log(2, 'skipping non-http scheme in', url) return False host, port = urllib.parse.splitport(parts.netloc) if not self.host_okay(host): - self.vvprint('skipping non-root host in', url) + self.log(2, 'skipping non-root host in', url) return False if max_redirect is None: max_redirect = self.max_redirect if url in self.todo or url in self.busy or url in self.done: return False - self.vprint('adding', url, max_redirect) + self.log(1, 'adding', url, max_redirect) self.todo[url] = max_redirect return True @@ -740,11 +723,11 @@ def crawl(self): while self.todo or self.busy: if self.todo: url, max_redirect = self.todo.popitem() - fetcher = Fetcher(url, + fetcher = Fetcher(self.log, url, crawler=self, max_redirect=max_redirect, max_tries=self.max_tries, - verbose=self.verbose) + ) self.busy[url] = fetcher fetcher.task = asyncio.Task(self.fetch(fetcher)) else: @@ -761,10 +744,9 @@ def fetch(self, fetcher): with (yield from self.governor): try: yield from fetcher.fetch() # Fetcher gonna fetch. - except Exception: + finally: # Force GC of the task, so the error is logged. fetcher.task = None - raise with (yield from self.termination): self.done[url] = fetcher del self.busy[url] @@ -780,27 +762,24 @@ def report(self, file=None): else: speed = 0 stats = Stats() - if self.verbose > 0: - print('*** Report ***', file=file) - try: - show = [] - show.extend(self.done.items()) - show.extend(self.busy.items()) - show.sort() - for url, fetcher in show: - fetcher.report(stats, file=file) - except KeyboardInterrupt: - print('\nInterrupted', file=file) + print('*** Report ***', file=file) + try: + show = [] + show.extend(self.done.items()) + show.extend(self.busy.items()) + show.sort() + for url, fetcher in show: + fetcher.report(stats, file=file) + except KeyboardInterrupt: + print('\nInterrupted', file=file) print('Finished', len(self.done), 'urls in %.3f secs' % dt, '(max_tasks=%d)' % self.max_tasks, '(%.3f urls/sec/task)' % speed, file=file) stats.report(file=file) - if self.todo: - print('Todo:', len(self.todo), file=file) - if self.busy: - print('Busy:', len(self.busy), file=file) + print('Todo:', len(self.todo), file=file) + print('Busy:', len(self.busy), file=file) print('Done:', len(self.done), file=file) print('Date:', time.ctime(), 'local time', file=file) @@ -815,6 +794,8 @@ def main(): print('Use --help for command line help') return + log = Logger(args.level) + if args.iocp: from asyncio.windows_events import ProactorEventLoop loop = ProactorEventLoop() @@ -827,14 +808,14 @@ def main(): roots = {fix_url(root) for root in args.roots} - crawler = Crawler(roots, - exclude=args.exclude, + crawler = Crawler(log, + roots, exclude=args.exclude, strict=args.strict, max_redirect=args.max_redirect, max_tries=args.max_tries, max_tasks=args.max_tasks, max_pool=args.max_pool, - verbose=args.verbose) + ) try: loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. except KeyboardInterrupt: From bea90c5ac64c423d95a912fe7a058f8b8d23f93b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 10 Jan 2014 07:40:49 -0800 Subject: [PATCH 0857/1502] Fix race in subprocess transport, by Victor Stinner. Fixes issue 103. --- asyncio/unix_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index eb3fb9f9..80a98f80 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -160,9 +160,10 @@ def _make_subprocess_transport(self, protocol, args, shell, transp = _UnixSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs) + yield from transp._post_init() watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) - yield from transp._post_init() + return transp def _child_watcher_callback(self, pid, returncode, transp): From 3e6c6264b5ed42288827217ba4cf51eea187e2a5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 10 Jan 2014 07:47:40 -0800 Subject: [PATCH 0858/1502] Minimal pty support, by Jonathan Slenders. --- asyncio/unix_events.py | 7 +++++-- tests/test_events.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 80a98f80..24da3274 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -190,7 +190,9 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._pipe = pipe self._fileno = pipe.fileno() mode = os.fstat(self._fileno).st_mode - if not (stat.S_ISFIFO(mode) or stat.S_ISSOCK(mode)): + if not (stat.S_ISFIFO(mode) or + stat.S_ISSOCK(mode) or + stat.S_ISCHR(mode)): raise ValueError("Pipe transport is for pipes/sockets only.") _set_nonblocking(self._fileno) self._protocol = protocol @@ -228,7 +230,8 @@ def close(self): def _fatal_error(self, exc): # should be called by exception handler only - logger.exception('Fatal error for %s', self) + if not (isinstance(exc, OSError) and exc.errno == errno.EIO): + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc): diff --git a/tests/test_events.py b/tests/test_events.py index 9545dd13..2e1dfebf 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -132,6 +132,8 @@ def eof_received(self): self.state.append('EOF') def connection_lost(self, exc): + if 'EOF' not in self.state: + self.state.append('EOF') # It is okay if EOF is missed. assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state self.state.append('CLOSED') if self.done: @@ -953,6 +955,46 @@ def connect(): # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pty_output(self): + proto = None + + def factory(): + nonlocal proto + proto = MyReadPipeProto(loop=self.loop) + return proto + + master, slave = os.openpty() + master_read_obj = io.open(master, 'rb', 0) + + @tasks.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(factory, + master_read_obj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(slave, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes) + self.assertEqual(1, proto.nbytes) + + os.write(slave, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(slave) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe(self): From 658da737017c857444ba67f05f3455e2e30aa3e5 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 10 Jan 2014 11:23:42 -0800 Subject: [PATCH 0859/1502] Refactor: introduce Connection class. --- examples/crawl.py | 177 ++++++++++++++++++++++++++-------------------- 1 file changed, 99 insertions(+), 78 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 6716815d..3a97966a 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -9,7 +9,6 @@ # - Support gzip encoding. # - Close connection if HTTP/1.0 response. # - Add timeouts. (E.g. when switching networks, all seems to hang.) -# - Improve class structure (e.g. add a Connection class). # - Add arguments to specify TLS settings (e.g. cert/key files). # - Skip reading large non-text/html files? # - Use ETag and If-Modified-Since? @@ -113,11 +112,10 @@ class ConnectionPool: To open a connection, use reserve(). To recycle it, use unreserve(). The pool is mostly just a mapping from (host, port, ssl) tuples to - lists of (reader, writer) pairs. The currently active connections - are *not* in the data structure; reserve() takes the connection - out, and unreserve()' puts it back in. It is up to the caller to - only call unreserve() for reusable connections. (That logic is - implemented in the Request class.) + lists of Connections. The currently active connections are *not* + in the data structure; get_connection() takes the connection out, + and recycle_connection() puts it back in. To recycle a + connection, call conn.close(recycle=True). There are limits to both the overal pool and the per-key pool. """ @@ -127,19 +125,19 @@ def __init__(self, log, max_pool=10, max_tasks=5): self.max_pool = max_pool # Overall limit. self.max_tasks = max_tasks # Per-key limit. self.loop = asyncio.get_event_loop() - self.connections = {} # {(host, port, ssl): [(reader, writer)]} - self.queue = [] # [(key, pair)] + self.connections = {} # {(host, port, ssl): [Connection, ...], ...} + self.queue = [] # [Connection, ...] def close(self): """Close all connections available for reuse.""" - for pairs in self.connections.values(): - for _, writer in pairs: - writer.close() + for conns in self.connections.values(): + for conn in conns: + conn.close() self.connections.clear() self.queue.clear() @asyncio.coroutine - def reserve(self, host, port, ssl): + def get_connection(self, host, port, ssl): """Create or reuse a connection.""" port = port or (443 if ssl else 80) try: @@ -153,66 +151,101 @@ def reserve(self, host, port, ssl): # Look for a reusable connection. for _, _, _, _, (h, p, *_) in ipaddrs: key = h, p, ssl - pair = None - pairs = self.connections.get(key) - while pairs: - pair = pairs.pop(0) - self.queue.remove((key, pair)) - if not pairs: + conn = None + conns = self.connections.get(key) + while conns: + conn = conns.pop(0) + self.queue.remove(conn) + if not conns: del self.connections[key] - reader, writer = pair - if reader._eof: + if conn.stale(): self.log(1, '(cached connection closed for %s)' % repr(key)) - writer.close() # Just in case. + conn.close() # Just in case. else: self.log(1, '* Reusing pooled connection', key, - 'FD =', writer._transport._sock.fileno()) - return key, reader, writer + 'FD =', conn.fileno()) + return conn # Create a new connection. - reader, writer = yield from asyncio.open_connection(host, port, - ssl=ssl) - peername = writer.get_extra_info('peername') - if peername: - host, port, *_ = peername - else: - self.log(1, 'NO PEERNAME???', host, port, ssl) - key = host, port, ssl - self.log(1, '* New connection', key, - 'FD =', writer._transport._sock.fileno()) - return key, reader, writer + conn = Connection(self.log, self, host, port, ssl) + yield from conn.connect() + self.log(1, '* New connection', conn.key, 'FD =', conn.fileno()) + return conn - def unreserve(self, key, reader, writer): + def recycle_connection(self, conn): """Make a connection available for reuse. This also prunes the pool if it exceeds the size limits. """ - if reader._eof: - writer.close() + if conn.stale(): + conn.close() return - pair = reader, writer - pairs = self.connections.setdefault(key, []) - pairs.append(pair) - self.queue.append((key, pair)) + conns = self.connections.setdefault(conn.key, []) + conns.append(conn) + self.queue.append(conn) + + # TODO: Remove closed connections first. # Close oldest connection(s) for this key if limit reached. - while len(pairs) > self.max_tasks: - pair = pairs.pop(0) - self.log(1, 'closing oldest connection for', key) - self.queue.remove((key, pair)) - reader, writer = pair - writer.close() + while len(conns) > self.max_tasks: + conn = conns.pop(0) + self.log(1, 'closing oldest connection for', conn.key) + self.queue.remove(conn) + conn.close() # Close oldest overall connection(s) if limit reached. while len(self.queue) > self.max_pool: - key, pair = self.queue.pop(0) - self.log(1, 'closing oldest connection', key) - pairs = self.connections.get(key) - p = pairs.pop(0) - assert pair == p, (key, pair, p, pairs) - reader, writer = pair - writer.close() + conn = self.queue.pop(0) + self.log(1, 'closing oldest connection', conn.key) + conns = self.connections.get(conn.key) + c = conns.pop(0) + assert conn == c, (conn.key, conn, c, conns) + conn.close() + + +class Connection: + + def __init__(self, log, pool, host, port, ssl): + self.log = log + self.pool = pool + self.host = host + self.port = port + self.ssl = ssl + self.reader = None + self.writer = None + self.key = None + + def stale(self): + return self.reader is None or self.reader._eof + + def fileno(self): + writer = self.writer + if writer is not None: + transport = writer._transport + if transport is not None: + sock = transport._sock + if sock is not None: + return sock.fileno() + return None + + @asyncio.coroutine + def connect(self): + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.ssl) + peername = self.writer.get_extra_info('peername') + if peername: + self.host, self.port = peername[:2] + else: + self.log(1, 'NO PEERNAME???', self.host, self.port, self.ssl) + self.key = self.host, self.port, self.ssl + + def close(self, recycle=False): + if recycle and not self.stale(): + self.pool.recycle_connection(self) + else: + self.writer.close() + self.pool = self.reader = self.writer = None class Request: @@ -242,9 +275,7 @@ def __init__(self, log, url, pool): self.http_version = 'HTTP/1.1' self.method = 'GET' self.headers = [] - self.key = None - self.reader = None - self.writer = None + self.conn = None @asyncio.coroutine def connect(self): @@ -253,24 +284,14 @@ def connect(self): (self.hostname, self.port, 'ssl' if self.ssl else 'tcp', self.url)) - self.key, self.reader, self.writer = \ - yield from self.pool.reserve(self.hostname, self.port, self.ssl) - self.log(1, '* Connected to %s' % - (self.writer.get_extra_info('peername'),)) + self.conn = yield from self.pool.get_connection(self.hostname, + self.port, self.ssl) - def recycle_connection(self): - """Recycle the connection to the pool. - - This should only be called when a properly formatted HTTP - response has been received. - """ - self.pool.unreserve(self.key, self.reader, self.writer) - self.key = self.reader = self.writer = None - - def close(self): - if self.writer is not None: - self.writer.close() - self.key = self.reader = self.writer = None + def close(self, recycle=False): + """Close the connection, recycle if requested.""" + if self.conn is not None: + self.conn.close(recycle) + self.conn = None @asyncio.coroutine def putline(self, line): @@ -279,7 +300,7 @@ def putline(self, line): Used for the request line and headers. """ self.log(2, '>', line) - self.writer.write(line.encode('latin-1') + b'\r\n') + self.conn.writer.write(line.encode('latin-1') + b'\r\n') @asyncio.coroutine def send_request(self): @@ -300,7 +321,7 @@ def send_request(self): @asyncio.coroutine def get_response(self): """Receive the response.""" - response = Response(self.log, self.reader) + response = Response(self.log, self.conn.reader) yield from response.read_headers() return response @@ -481,7 +502,7 @@ def fetch(self): h_conn = self.response.get_header('connection').lower() h_t_enc = self.response.get_header('transfer-encoding').lower() if h_conn != 'close': - self.request.recycle_connection() + self.request.close(recycle=True) self.request = None if self.tries > 1: self.log(1, 'try', self.tries, 'for', self.url, 'success') @@ -596,7 +617,7 @@ def add(self, key, count=1): def report(self, file=None): for key, count in sorted(self.stats.items()): - print(' %-20s %10d' % (key, count), file=file) + print('%10d' % count, key, file=file) class Crawler: From cba7de9b4d7072b591eb09499a5efd16ace42cae Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 10 Jan 2014 23:19:38 +0100 Subject: [PATCH 0860/1502] Cleanup properly proactor event loop * store the "self reading" future when the "self pipe" is closed (when the event loop is closed) * store "accept" futures to cancel them when we stop serving * close the "accept socket" if the "accept future" is cancelled Fix many warnings which can be seen when unit tests are run in verbose mode. --- asyncio/proactor_events.py | 10 ++++++++++ asyncio/windows_events.py | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 979bc25f..ba5169e9 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -330,6 +330,8 @@ def __init__(self, proactor): logger.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias + self._self_reading_future = None + self._accept_futures = {} # socket file descriptor => Future proactor.set_loop(self) self._make_self_pipe() @@ -365,6 +367,7 @@ def close(self): self._proactor = None self._selector = None super().close() + self._accept_futures.clear() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) @@ -382,6 +385,9 @@ def _socketpair(self): raise NotImplementedError def _close_self_pipe(self): + if self._self_reading_future is not None: + self._self_reading_future.cancel() + self._self_reading_future = None self._ssock.close() self._ssock = None self._csock.close() @@ -405,6 +411,7 @@ def _loop_self_reading(self, f=None): self.close() raise else: + self._self_reading_future = f f.add_done_callback(self._loop_self_reading) def _write_to_self(self): @@ -430,6 +437,7 @@ def loop(f=None): except futures.CancelledError: sock.close() else: + self._accept_futures[sock.fileno()] = f f.add_done_callback(loop) self.call_soon(loop) @@ -438,5 +446,7 @@ def _process_events(self, event_list): pass # XXX hard work currently done in poll def _stop_serving(self, sock): + for future in self._accept_futures.values(): + future.cancel() self._proactor._stop_serving(sock) sock.close() diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index b2ed2415..2e9ec697 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -168,9 +168,6 @@ def loop(f=None): self.call_soon(loop) return [server] - def _stop_serving(self, server): - server.close() - @tasks.coroutine def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, @@ -260,7 +257,19 @@ def finish_accept(trans, key, ov): conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() - return self._register(ov, listener, finish_accept) + @tasks.coroutine + def accept_coro(future, conn): + # Coroutine closing the accept socket if the future is cancelled + try: + yield from future + except futures.CancelledError: + conn.close() + raise + + future = self._register(ov, listener, finish_accept) + coro = accept_coro(future, conn) + tasks.async(coro, loop=self._loop) + return future def connect(self, conn, address): self._register_with_iocp(conn) From 0172550fe357f264d0d755732d97d1aa4a919e0c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jan 2014 00:15:13 +0100 Subject: [PATCH 0861/1502] Fix ResourceWarning in test_windows_events: close the write end of the socket pair --- tests/test_windows_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index b32477f9..17c204a7 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -49,6 +49,7 @@ def test_close(self): trans.close() self.loop.run_until_complete(f) self.assertEqual(f.result(), b'') + b.close() def test_double_bind(self): ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() From f3202181dee1323e163b9c770509cd2e7f6d1683 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 10 Jan 2014 17:04:32 -0800 Subject: [PATCH 0862/1502] Prune stale collections before closing oldest ones. Improve cloose logging. --- examples/crawl.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 3a97966a..f0c8e3d1 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -159,8 +159,7 @@ def get_connection(self, host, port, ssl): if not conns: del self.connections[key] if conn.stale(): - self.log(1, '(cached connection closed for %s)' % - repr(key)) + self.log(1, 'closing stale connection for', key) conn.close() # Just in case. else: self.log(1, '* Reusing pooled connection', key, @@ -181,26 +180,55 @@ def recycle_connection(self, conn): if conn.stale(): conn.close() return - conns = self.connections.setdefault(conn.key, []) + + key = conn.key + conns = self.connections.setdefault(key, []) conns.append(conn) self.queue.append(conn) - # TODO: Remove closed connections first. + if len(conns) <= self.max_tasks and len(self.queue) <= self.max_pool: + return + + # Prune the queue. + + # Close stale connections for this key first. + stale = [conn for conn in conns if conn.stale()] + if stale: + for conn in stale: + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + if not conns: + del self.connections[key] # Close oldest connection(s) for this key if limit reached. while len(conns) > self.max_tasks: conn = conns.pop(0) - self.log(1, 'closing oldest connection for', conn.key) self.queue.remove(conn) + self.log(1, 'closing oldest connection for', key) conn.close() + if len(self.queue) <= self.max_pool: + return + + # Close overall stale connections. + stale = [conn for conn in self.queue if conn.stale()] + if stale: + for conn in stale: + conns = self.connections.get(conn.key) + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + # Close oldest overall connection(s) if limit reached. while len(self.queue) > self.max_pool: conn = self.queue.pop(0) - self.log(1, 'closing oldest connection', conn.key) conns = self.connections.get(conn.key) c = conns.pop(0) assert conn == c, (conn.key, conn, c, conns) + self.log(1, 'closing overall oldest connection for', conn.key) conn.close() @@ -290,6 +318,8 @@ def connect(self): def close(self, recycle=False): """Close the connection, recycle if requested.""" if self.conn is not None: + if not recycle: + self.log(1, 'closing connection for', self.conn.key) self.conn.close(recycle) self.conn = None From 01a091969d5c74a983c3ce663aaa3a78802552f1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 16 Jan 2014 01:36:55 +0100 Subject: [PATCH 0863/1502] Issue #104: Fix a typo in CoroWrapper __slot__ => __slots__ --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 406bcb93..36404687 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -34,7 +34,7 @@ class CoroWrapper: """Wrapper for coroutine in _DEBUG mode.""" - __slot__ = ['gen', 'func'] + __slots__ = ['gen', 'func'] def __init__(self, gen, func): assert inspect.isgenerator(gen), gen From 5bd28ed6c45ad3d3383b4ee90c16ae0d4e709367 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 16 Jan 2014 01:56:31 +0100 Subject: [PATCH 0864/1502] Fix CoroWrapper (fix my previous commit) Add __name__ and __doc__ to __slots__ --- asyncio/tasks.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 36404687..ec04d2f6 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -32,9 +32,7 @@ class CoroWrapper: - """Wrapper for coroutine in _DEBUG mode.""" - - __slots__ = ['gen', 'func'] + __slots__ = ['gen', 'func', '__name__', '__doc__'] def __init__(self, gen, func): assert inspect.isgenerator(gen), gen From c7f52cd60302bd261595f8c52934137ab8cf8c43 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 16 Jan 2014 11:04:07 -0800 Subject: [PATCH 0865/1502] Reincarnate CoroWrapper's docstring as a comment. --- asyncio/tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index ec04d2f6..42413dc0 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -32,6 +32,8 @@ class CoroWrapper: + # Wrapper for coroutine in _DEBUG mode. + __slots__ = ['gen', 'func', '__name__', '__doc__'] def __init__(self, gen, func): From 2f1b2d72ffe34419a091fdd72fac70ec5be54752 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 21 Jan 2014 00:04:14 +0100 Subject: [PATCH 0866/1502] Optimize BaseEventLoop._run_once() Logger.log() is "slow", logger.isEnabledFor() is faster and the logger is disabled in most cases. A microbenchmark executing 100,000 dummy tasks is 22% faster with this change. See the CPython issue: http://bugs.python.org/issue20275 --- asyncio/base_events.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index a8850656..07d49c5e 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -610,15 +610,18 @@ def _run_once(self): timeout = min(timeout, deadline) # TODO: Instrumentation only in debug mode? - t0 = self.time() - event_list = self._selector.select(timeout) - t1 = self.time() - argstr = '' if timeout is None else ' {:.3f}'.format(timeout) - if t1-t0 >= 1: - level = logging.INFO + if logger.isEnabledFor(logging.INFO): + t0 = self.time() + event_list = self._selector.select(timeout) + t1 = self.time() + argstr = '' if timeout is None else ' {:.3f}'.format(timeout) + if t1-t0 >= 1: + level = logging.INFO + else: + level = logging.DEBUG + logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) else: - level = logging.DEBUG - logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + event_list = self._selector.select(timeout) self._process_events(event_list) # Handle 'later' callbacks that are ready. From 990d81c60faae7a47b1f61d09cc2e0c5ca030804 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 21 Jan 2014 02:23:33 +0100 Subject: [PATCH 0867/1502] Fix timeout rounding issues in selectors Round timeouts away from zero to wait *at least* 'timeout' seconds in PollSelector and EpollSelector. The change in EpollSelector works around a Python bug in select.epoll.poll(): http://bugs.python.org/issue20311 --- asyncio/__init__.py | 10 +++------- asyncio/selectors.py | 21 +++++++++++++++++++-- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 0d288d5a..9c14515d 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -2,13 +2,9 @@ import sys -# The selectors module is in the stdlib in Python 3.4 but not in 3.3. -# Do this first, so the other submodules can use "from . import selectors". -# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. -try: - from . import selectors -except ImportError: - import selectors # Will also be exported. +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer +# and contains a workaround for a timeout rounding issue. +from . import selectors if sys.platform == 'win32': # Similar thing for _overlapped. diff --git a/asyncio/selectors.py b/asyncio/selectors.py index a44d5e96..ff139154 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod from collections import namedtuple, Mapping import functools +import math import select import sys @@ -351,7 +352,14 @@ def unregister(self, fileobj): return key def select(self, timeout=None): - timeout = None if timeout is None else max(int(1000 * timeout), 0) + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = int(math.ceil(timeout * 1e3)) ready = [] try: fd_event_list = self._poll.poll(timeout) @@ -403,7 +411,16 @@ def unregister(self, fileobj): return key def select(self, timeout=None): - timeout = -1 if timeout is None else max(timeout, 0) + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. Workaround for + # the following Python bug: + # http://bugs.python.org/issue20311 + timeout = math.ceil(timeout * 1e3) * 1e-3 max_ev = len(self._fd_to_key) ready = [] try: From f9b660ef7a75179af5d07b904c5f007d365bfde5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 21 Jan 2014 17:48:38 +0100 Subject: [PATCH 0868/1502] Restore asyncio/__init__.py to have the same file in Python 3.4 and Tulip --- asyncio/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 9c14515d..0d288d5a 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -2,9 +2,13 @@ import sys -# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer -# and contains a workaround for a timeout rounding issue. -from . import selectors +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. +try: + from . import selectors +except ImportError: + import selectors # Will also be exported. if sys.platform == 'win32': # Similar thing for _overlapped. From f9c4e913be8fb0d4644dd52c517a13de492959f2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 22 Jan 2014 12:21:45 +0100 Subject: [PATCH 0869/1502] Cleanup logging in BaseEventLoop._run_once() logger.log() is now responsible to format the timeout. It might be faster if the log is disabled for DEBUG level, but it's also more readable and fix an issue with Python 2.6 in the Trollius project. --- asyncio/base_events.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 07d49c5e..72201aa5 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -614,12 +614,15 @@ def _run_once(self): t0 = self.time() event_list = self._selector.select(timeout) t1 = self.time() - argstr = '' if timeout is None else ' {:.3f}'.format(timeout) if t1-t0 >= 1: level = logging.INFO else: level = logging.DEBUG - logger.log(level, 'poll%s took %.3f seconds', argstr, t1-t0) + if timeout is not None: + logger.log(level, 'poll %.3f took %.3f seconds', + timeout, t1-t0) + else: + logger.log(level, 'poll took %.3f seconds', t1-t0) else: event_list = self._selector.select(timeout) self._process_events(event_list) From 42084899506a0a088456dcfc6ee550108852f0ca Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 22 Jan 2014 12:25:36 +0100 Subject: [PATCH 0870/1502] selectors: don't mention the Python bug Copy selectors.py from CPython to have exactly the same file in Tulip. --- asyncio/selectors.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index ff139154..1bdf972c 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -417,9 +417,7 @@ def select(self, timeout=None): timeout = 0 else: # epoll_wait() has a resolution of 1 millisecond, round away - # from zero to wait *at least* timeout seconds. Workaround for - # the following Python bug: - # http://bugs.python.org/issue20311 + # from zero to wait *at least* timeout seconds. timeout = math.ceil(timeout * 1e3) * 1e-3 max_ev = len(self._fd_to_key) ready = [] From bc503589a2f315caf05edc223de6ad36a877a297 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 22 Jan 2014 22:56:35 +0100 Subject: [PATCH 0871/1502] Create a list of authors and contributors based on the Mercurial history --- AUTHORS | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 AUTHORS diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..c6f0b8b4 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,31 @@ +Authors: + +- Guido van Rossum + +Contributors: + +- A. Jesse Jiryu Davis +- Aaron Griffith +- Andrew Svetlov +- Anthony Baire +- Antoine Pitrou +- Arnaud Faure +- Aymeric Augustin +- Brett Cannon +- Charles-François Natali +- Christian Heimes +- Eli Bendersky +- Geert Jansen +- Giampaolo Rodola' +- Jeff Quast +- Jonathan Slenders +- Nikolay Kim +- Richard Oudkerk +- Saúl Ibarra Corretgé +- Serhiy Storchaka +- Sonald Stufft +- Vajrasky Kok +- Victor Stinner +- Vladimir Kryachko +- Yury Selivanov + From c5765c745249c146ad2b2ba2d56c2a9a30a65f88 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 22 Jan 2014 23:00:47 +0100 Subject: [PATCH 0872/1502] wait_for() now cancels the future on timeout. Patch written by Gustavo Carneiro. --- AUTHORS | 1 + asyncio/tasks.py | 6 ++++-- tests/test_tasks.py | 29 ++++++++++++++--------------- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/AUTHORS b/AUTHORS index c6f0b8b4..263e8dc2 100644 --- a/AUTHORS +++ b/AUTHORS @@ -17,6 +17,7 @@ Contributors: - Eli Bendersky - Geert Jansen - Giampaolo Rodola' +- Gustavo Carneiro - Jeff Quast - Jonathan Slenders - Nikolay Kim diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 42413dc0..b52933fc 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -382,8 +382,9 @@ def wait_for(fut, timeout, *, loop=None): Coroutine will be wrapped in Task. - Returns result of the Future or coroutine. Raises TimeoutError when - timeout occurs. + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). Usage: @@ -405,6 +406,7 @@ def wait_for(fut, timeout, *, loop=None): return fut.result() else: fut.remove_done_callback(cb) + fut.cancel() raise futures.TimeoutError() finally: timeout_handle.cancel() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 79a25d29..3d08ad8b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -355,30 +355,32 @@ def gen(): when = yield 0 self.assertAlmostEqual(0.1, when) when = yield 0.1 - self.assertAlmostEqual(0.4, when) - yield 0.1 loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) + foo_running = None + @tasks.coroutine def foo(): - yield from tasks.sleep(0.2, loop=loop) + nonlocal foo_running + foo_running = True + try: + yield from tasks.sleep(0.2, loop=loop) + finally: + foo_running = False return 'done' fut = tasks.Task(foo(), loop=loop) with self.assertRaises(futures.TimeoutError): loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) - - self.assertFalse(fut.done()) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(foo_running, False) - # wait for result - res = loop.run_until_complete( - tasks.wait_for(fut, 0.3, loop=loop)) - self.assertEqual(res, 'done') - self.assertAlmostEqual(0.2, loop.time()) def test_wait_for_with_global_loop(self): @@ -406,11 +408,8 @@ def foo(): events.set_event_loop(None) self.assertAlmostEqual(0.01, loop.time()) - self.assertFalse(fut.done()) - - # move forward to close generator - loop.advance_time(10) - loop.run_until_complete(fut) + self.assertTrue(fut.done()) + self.assertTrue(fut.cancelled()) def test_wait(self): From 785f302127d70bbfc9b9e96a04fc871382feefff Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 23 Jan 2014 10:21:32 +0100 Subject: [PATCH 0873/1502] Fix open_connection() docstring, writer is a StreamWriter --- asyncio/streams.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index f01f8629..cd0c4ffb 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -21,7 +21,7 @@ def open_connection(host=None, port=None, *, """A wrapper for create_connection() returning a (reader, writer) pair. The reader returned is a StreamReader instance; the writer is a - Transport. + StreamWriter instance. The arguments are all the usual arguments to create_connection() except protocol_factory; most common are positional host and port, From 469c4c8aa96ca8b4c6ef80122edd91b43920de3b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 23 Jan 2014 17:06:37 +0100 Subject: [PATCH 0874/1502] Issue #110: StreamReader.read() and StreamReader.readline() now raise a RuntimeError, instead of using an assertion, if another coroutine is already waiting for incoming data --- asyncio/streams.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index cd0c4ffb..b53080ef 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -284,6 +284,16 @@ def feed_data(self, data): else: self._paused = True + def _create_waiter(self, func_name): + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError('%s() called while another coroutine is ' + 'already waiting for incoming data' % func_name) + return futures.Future(loop=self._loop) + @tasks.coroutine def readline(self): if self._exception is not None: @@ -318,8 +328,7 @@ def readline(self): break if not_enough: - assert self._waiter is None - self._waiter = futures.Future(loop=self._loop) + self._waiter = self._create_waiter('readline') try: yield from self._waiter finally: @@ -341,16 +350,14 @@ def read(self, n=-1): if n < 0: while not self._eof: - assert not self._waiter - self._waiter = futures.Future(loop=self._loop) + self._waiter = self._create_waiter('read') try: yield from self._waiter finally: self._waiter = None else: if not self._byte_count and not self._eof: - assert not self._waiter - self._waiter = futures.Future(loop=self._loop) + self._waiter = self._create_waiter('read') try: yield from self._waiter finally: From 1c523492f42522f6929e443363576e5b95e4290c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 23 Jan 2014 17:20:05 +0100 Subject: [PATCH 0875/1502] Skip test_events.test_read_pty_output() on Mac OS X older than 10.9 (Maverick) --- tests/test_events.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 2e1dfebf..21036b5f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -957,6 +957,9 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) def test_read_pty_output(self): proto = None From 4a56da54b8f8fe31df33a40d71003f7d490a85fb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 14:46:00 +0100 Subject: [PATCH 0876/1502] Unit tests use the main asyncio module instead of submodules like events --- tests/test_base_events.py | 95 +++--- tests/test_events.py | 198 ++++++----- tests/test_futures.py | 99 +++--- tests/test_locks.py | 243 +++++++------- tests/test_queues.py | 95 +++--- tests/test_selector_events.py | 64 ++-- tests/test_streams.py | 110 +++---- tests/test_tasks.py | 596 +++++++++++++++++----------------- tests/test_transports.py | 14 +- tests/test_unix_events.py | 51 ++- tests/test_windows_events.py | 25 +- 11 files changed, 778 insertions(+), 812 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 96f29750..9c2bda1f 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -8,12 +8,9 @@ import unittest.mock from test.support import find_unused_port, IPV6_ENABLED +import asyncio from asyncio import base_events from asyncio import constants -from asyncio import events -from asyncio import futures -from asyncio import protocols -from asyncio import tasks from asyncio import test_utils @@ -22,7 +19,7 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = unittest.mock.Mock() - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_not_implemented(self): m = unittest.mock.Mock() @@ -51,20 +48,20 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, next, iter(gen)) def test__add_callback_handle(self): - h = events.Handle(lambda: False, ()) + h = asyncio.Handle(lambda: False, ()) self.loop._add_callback(h) self.assertFalse(self.loop._scheduled) self.assertIn(h, self.loop._ready) def test__add_callback_timer(self): - h = events.TimerHandle(time.monotonic()+10, lambda: False, ()) + h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, ()) self.loop._add_callback(h) self.assertIn(h, self.loop._scheduled) def test__add_callback_cancelled_handle(self): - h = events.Handle(lambda: False, ()) + h = asyncio.Handle(lambda: False, ()) h.cancel() self.loop._add_callback(h) @@ -90,7 +87,7 @@ def cb(): h = self.loop.call_soon(cb) self.assertEqual(h._callback, cb) - self.assertIsInstance(h, events.Handle) + self.assertIsInstance(h, asyncio.Handle) self.assertIn(h, self.loop._ready) def test_call_later(self): @@ -98,7 +95,7 @@ def cb(): pass h = self.loop.call_later(10.0, cb) - self.assertIsInstance(h, events.TimerHandle) + self.assertIsInstance(h, asyncio.TimerHandle) self.assertIn(h, self.loop._scheduled) self.assertNotIn(h, self.loop._ready) @@ -132,27 +129,27 @@ def cb(): self.assertRaises( AssertionError, self.loop.run_in_executor, - None, events.Handle(cb, ()), ('',)) + None, asyncio.Handle(cb, ()), ('',)) self.assertRaises( AssertionError, self.loop.run_in_executor, - None, events.TimerHandle(10, cb, ())) + None, asyncio.TimerHandle(10, cb, ())) def test_run_once_in_executor_cancelled(self): def cb(): pass - h = events.Handle(cb, ()) + h = asyncio.Handle(cb, ()) h.cancel() f = self.loop.run_in_executor(None, h) - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.assertTrue(f.done()) self.assertIsNone(f.result()) def test_run_once_in_executor_plain(self): def cb(): pass - h = events.Handle(cb, ()) - f = futures.Future(loop=self.loop) + h = asyncio.Handle(cb, ()) + f = asyncio.Future(loop=self.loop) executor = unittest.mock.Mock() executor.submit.return_value = f @@ -170,8 +167,8 @@ def cb(): f.cancel() # Don't complain about abandoned Future. def test__run_once(self): - h1 = events.TimerHandle(time.monotonic() + 5.0, lambda: True, ()) - h2 = events.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, ()) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) h1.cancel() @@ -202,14 +199,14 @@ def monotonic(): m_logging.DEBUG = logging.DEBUG self.loop._scheduled.append( - events.TimerHandle(11.0, lambda: True, ())) + asyncio.TimerHandle(11.0, lambda: True, ())) self.loop._process_events = unittest.mock.Mock() self.loop._run_once() self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.loop._scheduled = [events.TimerHandle(11.0, lambda:True, ())] + self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] self.loop._run_once() self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) @@ -222,7 +219,7 @@ def cb(loop): processed = True handle = loop.call_soon(lambda: True) - h = events.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) self.loop._process_events = unittest.mock.Mock() self.loop._scheduled.append(h) @@ -236,14 +233,14 @@ def test_run_until_complete_type_error(self): TypeError, self.loop.run_until_complete, 'blah') -class MyProto(protocols.Protocol): +class MyProto(asyncio.Protocol): done = None def __init__(self, create_future=False): self.state = 'INITIAL' self.nbytes = 0 if create_future: - self.done = futures.Future() + self.done = asyncio.Future() def connection_made(self, transport): self.transport = transport @@ -266,14 +263,14 @@ def connection_lost(self, exc): self.done.set_result(None) -class MyDatagramProto(protocols.DatagramProtocol): +class MyDatagramProto(asyncio.DatagramProtocol): done = None def __init__(self, create_future=False): self.state = 'INITIAL' self.nbytes = 0 if create_future: - self.done = futures.Future() + self.done = asyncio.Future() def connection_made(self, transport): self.transport = transport @@ -297,8 +294,8 @@ def connection_lost(self, exc): class BaseEventLoopWithSelectorTests(unittest.TestCase): def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() @@ -306,17 +303,17 @@ def tearDown(self): @unittest.mock.patch('asyncio.base_events.socket') def test_create_connection_multiple_errors(self, m_socket): - class MyProto(protocols.Protocol): + class MyProto(asyncio.Protocol): pass - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80)), (2, 1, 6, '', ('107.6.106.82', 80))] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) idx = -1 errors = ['err1', 'err2'] @@ -346,12 +343,12 @@ def test_create_connection_no_host_port_sock(self): self.assertRaises(ValueError, self.loop.run_until_complete, coro) def test_create_connection_no_getaddrinfo(self): - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): yield from [] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) @@ -359,13 +356,13 @@ def getaddrinfo_task(*args, **kwds): OSError, self.loop.run_until_complete, coro) def test_create_connection_connect_err(self): - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): yield from [] return [(2, 1, 6, '', ('107.6.106.82', 80))] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() @@ -376,13 +373,13 @@ def getaddrinfo_task(*args, **kwds): OSError, self.loop.run_until_complete, coro) def test_create_connection_multiple(self): - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() @@ -404,13 +401,13 @@ def bind(addr): m_socket.socket.return_value.bind = bind - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): return [(2, 1, 6, '', ('0.0.0.1', 80)), (2, 1, 6, '', ('0.0.0.2', 80))] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = unittest.mock.Mock() @@ -426,7 +423,7 @@ def getaddrinfo_task(*args, **kwds): self.assertTrue(m_socket.socket.return_value.close.called) def test_create_connection_no_local_addr(self): - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(host, *args, **kw): if host == 'example.com': return [(2, 1, 6, '', ('107.6.106.82', 80)), @@ -435,7 +432,7 @@ def getaddrinfo(host, *args, **kw): return [] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection( @@ -448,7 +445,7 @@ def test_create_connection_ssl_server_hostname_default(self): self.loop.getaddrinfo = unittest.mock.Mock() def mock_getaddrinfo(*args, **kwds): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.set_result([(socket.AF_INET, socket.SOCK_STREAM, socket.SOL_TCP, '', ('1.2.3.4', 80))]) return f @@ -527,14 +524,14 @@ def test_create_server_empty_host(self): # if host is empty string use None instead host = object() - @tasks.coroutine + @asyncio.coroutine def getaddrinfo(*args, **kw): nonlocal host host = args[0] yield from [] def getaddrinfo_task(*args, **kwds): - return tasks.Task(getaddrinfo(*args, **kwds), loop=self.loop) + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task fut = self.loop.create_server(MyProto, '', 0) @@ -596,7 +593,7 @@ def test_create_datagram_endpoint_connect_err(self): self.loop.sock_connect.side_effect = OSError coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( OSError, self.loop.run_until_complete, coro) @@ -606,19 +603,19 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): m_socket.socket.side_effect = OSError coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, family=socket.AF_INET) + asyncio.DatagramProtocol, family=socket.AF_INET) self.assertRaises( OSError, self.loop.run_until_complete, coro) coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, local_addr=('127.0.0.1', 0)) + asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) self.assertRaises( OSError, self.loop.run_until_complete, coro) @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) self.assertRaises( ValueError, self.loop.run_until_complete, coro) @@ -628,7 +625,7 @@ def test_create_datagram_endpoint_setblk_err(self, m_socket): m_socket.socket.return_value.setblocking.side_effect = OSError coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol, family=socket.AF_INET) + asyncio.DatagramProtocol, family=socket.AF_INET) self.assertRaises( OSError, self.loop.run_until_complete, coro) self.assertTrue( @@ -636,7 +633,7 @@ def test_create_datagram_endpoint_setblk_err(self, m_socket): def test_create_datagram_endpoint_noaddr_nofamily(self): coro = self.loop.create_datagram_endpoint( - protocols.DatagramProtocol) + asyncio.DatagramProtocol) self.assertRaises(ValueError, self.loop.run_until_complete, coro) @unittest.mock.patch('asyncio.base_events.socket') diff --git a/tests/test_events.py b/tests/test_events.py index 21036b5f..e49c4be5 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -23,14 +23,10 @@ from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR -from asyncio import futures +import asyncio from asyncio import events -from asyncio import transports -from asyncio import protocols from asyncio import selector_events -from asyncio import tasks from asyncio import test_utils -from asyncio import locks def data_file(filename): @@ -49,7 +45,7 @@ def data_file(filename): SIGNING_CA = data_file('pycacert.pem') -class MyProto(protocols.Protocol): +class MyProto(asyncio.Protocol): done = None def __init__(self, loop=None): @@ -57,7 +53,7 @@ def __init__(self, loop=None): self.state = 'INITIAL' self.nbytes = 0 if loop is not None: - self.done = futures.Future(loop=loop) + self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -80,14 +76,14 @@ def connection_lost(self, exc): self.done.set_result(None) -class MyDatagramProto(protocols.DatagramProtocol): +class MyDatagramProto(asyncio.DatagramProtocol): done = None def __init__(self, loop=None): self.state = 'INITIAL' self.nbytes = 0 if loop is not None: - self.done = futures.Future(loop=loop) + self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -108,7 +104,7 @@ def connection_lost(self, exc): self.done.set_result(None) -class MyReadPipeProto(protocols.Protocol): +class MyReadPipeProto(asyncio.Protocol): done = None def __init__(self, loop=None): @@ -116,7 +112,7 @@ def __init__(self, loop=None): self.nbytes = 0 self.transport = None if loop is not None: - self.done = futures.Future(loop=loop) + self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -140,14 +136,14 @@ def connection_lost(self, exc): self.done.set_result(None) -class MyWritePipeProto(protocols.BaseProtocol): +class MyWritePipeProto(asyncio.BaseProtocol): done = None def __init__(self, loop=None): self.state = 'INITIAL' self.transport = None if loop is not None: - self.done = futures.Future(loop=loop) + self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -161,18 +157,18 @@ def connection_lost(self, exc): self.done.set_result(None) -class MySubprocessProtocol(protocols.SubprocessProtocol): +class MySubprocessProtocol(asyncio.SubprocessProtocol): def __init__(self, loop): self.state = 'INITIAL' self.transport = None - self.connected = futures.Future(loop=loop) - self.completed = futures.Future(loop=loop) - self.disconnects = {fd: futures.Future(loop=loop) for fd in range(3)} + self.connected = asyncio.Future(loop=loop) + self.completed = asyncio.Future(loop=loop) + self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)} self.data = {1: b'', 2: b''} self.returncode = None - self.got_data = {1: locks.Event(loop=loop), - 2: locks.Event(loop=loop)} + self.got_data = {1: asyncio.Event(loop=loop), + 2: asyncio.Event(loop=loop)} def connection_made(self, transport): self.transport = transport @@ -207,7 +203,7 @@ class EventLoopTestsMixin: def setUp(self): super().setUp() self.loop = self.create_event_loop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): # just in case if we have transport close callbacks @@ -218,11 +214,11 @@ def tearDown(self): super().tearDown() def test_run_until_complete_nesting(self): - @tasks.coroutine + @asyncio.coroutine def coro1(): yield - @tasks.coroutine + @asyncio.coroutine def coro2(): self.assertTrue(self.loop.is_running()) self.loop.run_until_complete(coro1()) @@ -235,15 +231,15 @@ def coro2(): def test_run_until_complete(self): t0 = self.loop.time() - self.loop.run_until_complete(tasks.sleep(0.1, loop=self.loop)) + self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) t1 = self.loop.time() self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) def test_run_until_complete_stopped(self): - @tasks.coroutine + @asyncio.coroutine def cb(): self.loop.stop() - yield from tasks.sleep(0.1, loop=self.loop) + yield from asyncio.sleep(0.1, loop=self.loop) task = cb() self.assertRaises(RuntimeError, self.loop.run_until_complete, task) @@ -494,8 +490,8 @@ def test_create_connection(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, transports.Transport) - self.assertIsInstance(pr, protocols.Protocol) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -522,8 +518,8 @@ def test_create_connection_sock(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), sock=sock) tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, transports.Transport) - self.assertIsInstance(pr, protocols.Protocol) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -535,8 +531,8 @@ def test_create_ssl_connection(self): lambda: MyProto(loop=self.loop), *httpd.address, ssl=test_utils.dummy_ssl_context()) tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, transports.Transport) - self.assertIsInstance(pr, protocols.Protocol) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) self.assertTrue('ssl' in tr.__class__.__name__.lower()) self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) @@ -762,7 +758,7 @@ def factory(): server.close() def test_create_server_sock(self): - proto = futures.Future(loop=self.loop) + proto = asyncio.Future(loop=self.loop) class TestMyProto(MyProto): def connection_made(self, transport): @@ -805,7 +801,7 @@ def test_create_server_addr_in_use(self): @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_server_dual_stack(self): - f_proto = futures.Future(loop=self.loop) + f_proto = asyncio.Future(loop=self.loop) class TestMyProto(MyProto): def connection_made(self, transport): @@ -834,7 +830,7 @@ def connection_made(self, transport): proto.transport.close() client.close() - f_proto = futures.Future(loop=self.loop) + f_proto = asyncio.Future(loop=self.loop) client = socket.socket(socket.AF_INET6) client.connect(('::1', port)) client.send(b'xxx') @@ -929,7 +925,7 @@ def factory(): rpipe, wpipe = os.pipe() pipeobj = io.open(rpipe, 'rb', 1024) - @tasks.coroutine + @asyncio.coroutine def connect(): t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) self.assertIs(p, proto) @@ -971,7 +967,7 @@ def factory(): master, slave = os.openpty() master_read_obj = io.open(master, 'rb', 0) - @tasks.coroutine + @asyncio.coroutine def connect(): t, p = yield from self.loop.connect_read_pipe(factory, master_read_obj) @@ -1012,7 +1008,7 @@ def factory(): rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal transport t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) @@ -1058,7 +1054,7 @@ def factory(): rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal transport t, p = yield from self.loop.connect_write_pipe(factory, @@ -1088,12 +1084,12 @@ def test_prompt_cancellation(self): if ov is not None: self.assertTrue(ov.pending) - @tasks.coroutine + @asyncio.coroutine def main(): try: self.loop.call_soon(f.cancel) yield from f - except futures.CancelledError: + except asyncio.CancelledError: res = 'cancelled' else: res = None @@ -1102,13 +1098,13 @@ def main(): return res start = time.monotonic() - t = tasks.Task(main(), loop=self.loop) + t = asyncio.Task(main(), loop=self.loop) self.loop.run_forever() elapsed = time.monotonic() - start self.assertLess(elapsed, 0.1) self.assertEqual(t.result(), 'cancelled') - self.assertRaises(futures.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.result) if ov is not None: self.assertFalse(ov.pending) self.loop._stop_serving(r) @@ -1139,7 +1135,7 @@ def test_subprocess_exec(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1165,7 +1161,7 @@ def test_subprocess_interactive(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1197,7 +1193,7 @@ def test_subprocess_shell(self): proto = None transp = None - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_shell( @@ -1218,7 +1214,7 @@ def connect(): def test_subprocess_exitcode(self): proto = None - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto transp, proto = yield from self.loop.subprocess_shell( @@ -1234,7 +1230,7 @@ def test_subprocess_close_after_finish(self): proto = None transp = None - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_shell( @@ -1256,7 +1252,7 @@ def test_subprocess_kill(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1277,7 +1273,7 @@ def test_subprocess_terminate(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1299,7 +1295,7 @@ def test_subprocess_send_signal(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1320,7 +1316,7 @@ def test_subprocess_stderr(self): prog = os.path.join(os.path.dirname(__file__), 'echo2.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1347,7 +1343,7 @@ def test_subprocess_stderr_redirect_to_stdout(self): prog = os.path.join(os.path.dirname(__file__), 'echo2.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1377,7 +1373,7 @@ def test_subprocess_close_client_stream(self): prog = os.path.join(os.path.dirname(__file__), 'echo3.py') - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto, transp transp, proto = yield from self.loop.subprocess_exec( @@ -1414,7 +1410,7 @@ def test_subprocess_wait_no_same_group(self): proto = None transp = None - @tasks.coroutine + @asyncio.coroutine def connect(): nonlocal proto # start the new process in a new session @@ -1430,19 +1426,18 @@ def connect(): if sys.platform == 'win32': - from asyncio import windows_events class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return windows_events.SelectorEventLoop() + return asyncio.SelectorEventLoop() class ProactorEventLoopTests(EventLoopTestsMixin, SubprocessTestsMixin, unittest.TestCase): def create_event_loop(self): - return windows_events.ProactorEventLoop() + return asyncio.ProactorEventLoop() def test_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop incompatible with SSL") @@ -1476,17 +1471,16 @@ def test_create_datagram_endpoint(self): "IocpEventLoop does not have create_datagram_endpoint()") else: from asyncio import selectors - from asyncio import unix_events class UnixEventLoopTestsMixin(EventLoopTestsMixin): def setUp(self): super().setUp() - watcher = unix_events.SafeChildWatcher() + watcher = asyncio.SafeChildWatcher() watcher.attach_loop(self.loop) - events.set_child_watcher(watcher) + asyncio.set_child_watcher(watcher) def tearDown(self): - events.set_child_watcher(None) + asyncio.set_child_watcher(None) super().tearDown() if hasattr(selectors, 'KqueueSelector'): @@ -1495,7 +1489,7 @@ class KqueueEventLoopTests(UnixEventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.SelectorEventLoop( + return asyncio.SelectorEventLoop( selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): @@ -1504,7 +1498,7 @@ class EPollEventLoopTests(UnixEventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.SelectorEventLoop(selectors.EpollSelector()) + return asyncio.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(UnixEventLoopTestsMixin, @@ -1512,7 +1506,7 @@ class PollEventLoopTests(UnixEventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.SelectorEventLoop(selectors.PollSelector()) + return asyncio.SelectorEventLoop(selectors.PollSelector()) # Should always exist. class SelectEventLoopTests(UnixEventLoopTestsMixin, @@ -1520,7 +1514,7 @@ class SelectEventLoopTests(UnixEventLoopTestsMixin, unittest.TestCase): def create_event_loop(self): - return unix_events.SelectorEventLoop(selectors.SelectSelector()) + return asyncio.SelectorEventLoop(selectors.SelectSelector()) class HandleTests(unittest.TestCase): @@ -1530,7 +1524,7 @@ def callback(*args): return args args = () - h = events.Handle(callback, args) + h = asyncio.Handle(callback, args) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1553,16 +1547,16 @@ def callback(*args): def test_make_handle(self): def callback(*args): return args - h1 = events.Handle(callback, ()) + h1 = asyncio.Handle(callback, ()) self.assertRaises( - AssertionError, events.make_handle, h1, ()) + AssertionError, asyncio.events.make_handle, h1, ()) @unittest.mock.patch('asyncio.events.logger') def test_callback_with_exception(self, log): def callback(): raise ValueError() - h = events.Handle(callback, ()) + h = asyncio.Handle(callback, ()) h._run() self.assertTrue(log.exception.called) @@ -1571,7 +1565,7 @@ class TimerTests(unittest.TestCase): def test_hash(self): when = time.monotonic() - h = events.TimerHandle(when, lambda: False, ()) + h = asyncio.TimerHandle(when, lambda: False, ()) self.assertEqual(hash(h), hash(when)) def test_timer(self): @@ -1580,7 +1574,7 @@ def callback(*args): args = () when = time.monotonic() - h = events.TimerHandle(when, callback, args) + h = asyncio.TimerHandle(when, callback, args) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1595,7 +1589,7 @@ def callback(*args): self.assertTrue(r.endswith('())'), r) self.assertRaises(AssertionError, - events.TimerHandle, None, callback, args) + asyncio.TimerHandle, None, callback, args) def test_timer_comparison(self): def callback(*args): @@ -1603,8 +1597,8 @@ def callback(*args): when = time.monotonic() - h1 = events.TimerHandle(when, callback, ()) - h2 = events.TimerHandle(when, callback, ()) + h1 = asyncio.TimerHandle(when, callback, ()) + h2 = asyncio.TimerHandle(when, callback, ()) # TODO: Use assertLess etc. self.assertFalse(h1 < h2) self.assertFalse(h2 < h1) @@ -1620,8 +1614,8 @@ def callback(*args): h2.cancel() self.assertFalse(h1 == h2) - h1 = events.TimerHandle(when, callback, ()) - h2 = events.TimerHandle(when + 10.0, callback, ()) + h1 = asyncio.TimerHandle(when, callback, ()) + h2 = asyncio.TimerHandle(when + 10.0, callback, ()) self.assertTrue(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) @@ -1633,7 +1627,7 @@ def callback(*args): self.assertFalse(h1 == h2) self.assertTrue(h1 != h2) - h3 = events.Handle(callback, ()) + h3 = asyncio.Handle(callback, ()) self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3)) @@ -1642,7 +1636,7 @@ class AbstractEventLoopTests(unittest.TestCase): def test_not_implemented(self): f = unittest.mock.Mock() - loop = events.AbstractEventLoop() + loop = asyncio.AbstractEventLoop() self.assertRaises( NotImplementedError, loop.run_forever) self.assertRaises( @@ -1716,19 +1710,19 @@ class ProtocolsAbsTests(unittest.TestCase): def test_empty(self): f = unittest.mock.Mock() - p = protocols.Protocol() + p = asyncio.Protocol() self.assertIsNone(p.connection_made(f)) self.assertIsNone(p.connection_lost(f)) self.assertIsNone(p.data_received(f)) self.assertIsNone(p.eof_received()) - dp = protocols.DatagramProtocol() + dp = asyncio.DatagramProtocol() self.assertIsNone(dp.connection_made(f)) self.assertIsNone(dp.connection_lost(f)) self.assertIsNone(dp.error_received(f)) self.assertIsNone(dp.datagram_received(f, f)) - sp = protocols.SubprocessProtocol() + sp = asyncio.SubprocessProtocol() self.assertIsNone(sp.connection_made(f)) self.assertIsNone(sp.connection_lost(f)) self.assertIsNone(sp.pipe_data_received(1, f)) @@ -1738,16 +1732,8 @@ def test_empty(self): class PolicyTests(unittest.TestCase): - def create_policy(self): - if sys.platform == "win32": - from asyncio import windows_events - return windows_events.DefaultEventLoopPolicy() - else: - from asyncio import unix_events - return unix_events.DefaultEventLoopPolicy() - def test_event_loop_policy(self): - policy = events.AbstractEventLoopPolicy() + policy = asyncio.AbstractEventLoopPolicy() self.assertRaises(NotImplementedError, policy.get_event_loop) self.assertRaises(NotImplementedError, policy.set_event_loop, object()) self.assertRaises(NotImplementedError, policy.new_event_loop) @@ -1756,18 +1742,18 @@ def test_event_loop_policy(self): object()) def test_get_event_loop(self): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() self.assertIsNone(policy._local._loop) loop = policy.get_event_loop() - self.assertIsInstance(loop, events.AbstractEventLoop) + self.assertIsInstance(loop, asyncio.AbstractEventLoop) self.assertIs(policy._local._loop, loop) self.assertIs(loop, policy.get_event_loop()) loop.close() def test_get_event_loop_calls_set_event_loop(self): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() with unittest.mock.patch.object( policy, "set_event_loop", @@ -1783,7 +1769,7 @@ def test_get_event_loop_calls_set_event_loop(self): loop.close() def test_get_event_loop_after_set_none(self): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() policy.set_event_loop(None) self.assertRaises(AssertionError, policy.get_event_loop) @@ -1791,7 +1777,7 @@ def test_get_event_loop_after_set_none(self): def test_get_event_loop_thread(self, m_current_thread): def f(): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() self.assertRaises(AssertionError, policy.get_event_loop) th = threading.Thread(target=f) @@ -1799,14 +1785,14 @@ def f(): th.join() def test_new_event_loop(self): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() loop = policy.new_event_loop() - self.assertIsInstance(loop, events.AbstractEventLoop) + self.assertIsInstance(loop, asyncio.AbstractEventLoop) loop.close() def test_set_event_loop(self): - policy = self.create_policy() + policy = asyncio.DefaultEventLoopPolicy() old_loop = policy.get_event_loop() self.assertRaises(AssertionError, policy.set_event_loop, object()) @@ -1819,19 +1805,19 @@ def test_set_event_loop(self): old_loop.close() def test_get_event_loop_policy(self): - policy = events.get_event_loop_policy() - self.assertIsInstance(policy, events.AbstractEventLoopPolicy) - self.assertIs(policy, events.get_event_loop_policy()) + policy = asyncio.get_event_loop_policy() + self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy) + self.assertIs(policy, asyncio.get_event_loop_policy()) def test_set_event_loop_policy(self): self.assertRaises( - AssertionError, events.set_event_loop_policy, object()) + AssertionError, asyncio.set_event_loop_policy, object()) - old_policy = events.get_event_loop_policy() + old_policy = asyncio.get_event_loop_policy() - policy = self.create_policy() - events.set_event_loop_policy(policy) - self.assertIs(policy, events.get_event_loop_policy()) + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + self.assertIs(policy, asyncio.get_event_loop_policy()) self.assertIsNot(policy, old_policy) diff --git a/tests/test_futures.py b/tests/test_futures.py index e35fcf07..d3a74125 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -5,8 +5,7 @@ import unittest import unittest.mock -from asyncio import events -from asyncio import futures +import asyncio from asyncio import test_utils @@ -18,13 +17,13 @@ class FutureTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() def test_initial_state(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) self.assertFalse(f.cancelled()) self.assertFalse(f.done()) f.cancel() @@ -32,56 +31,56 @@ def test_initial_state(self): def test_init_constructor_default_loop(self): try: - events.set_event_loop(self.loop) - f = futures.Future() + asyncio.set_event_loop(self.loop) + f = asyncio.Future() self.assertIs(f._loop, self.loop) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_constructor_positional(self): # Make sure Future does't accept a positional argument - self.assertRaises(TypeError, futures.Future, 42) + self.assertRaises(TypeError, asyncio.Future, 42) def test_cancel(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) self.assertTrue(f.cancel()) self.assertTrue(f.cancelled()) self.assertTrue(f.done()) - self.assertRaises(futures.CancelledError, f.result) - self.assertRaises(futures.CancelledError, f.exception) - self.assertRaises(futures.InvalidStateError, f.set_result, None) - self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertRaises(asyncio.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.exception) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) def test_result(self): - f = futures.Future(loop=self.loop) - self.assertRaises(futures.InvalidStateError, f.result) + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.result) f.set_result(42) self.assertFalse(f.cancelled()) self.assertTrue(f.done()) self.assertEqual(f.result(), 42) self.assertEqual(f.exception(), None) - self.assertRaises(futures.InvalidStateError, f.set_result, None) - self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) def test_exception(self): exc = RuntimeError() - f = futures.Future(loop=self.loop) - self.assertRaises(futures.InvalidStateError, f.exception) + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.exception) f.set_exception(exc) self.assertFalse(f.cancelled()) self.assertTrue(f.done()) self.assertRaises(RuntimeError, f.result) self.assertEqual(f.exception(), exc) - self.assertRaises(futures.InvalidStateError, f.set_result, None) - self.assertRaises(futures.InvalidStateError, f.set_exception, None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) def test_yield_from_twice(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) def fixture(): yield 'A' @@ -99,32 +98,32 @@ def fixture(): self.assertEqual(next(g), ('C', 42)) # yield 'C', y. def test_repr(self): - f_pending = futures.Future(loop=self.loop) + f_pending = asyncio.Future(loop=self.loop) self.assertEqual(repr(f_pending), 'Future') f_pending.cancel() - f_cancelled = futures.Future(loop=self.loop) + f_cancelled = asyncio.Future(loop=self.loop) f_cancelled.cancel() self.assertEqual(repr(f_cancelled), 'Future') - f_result = futures.Future(loop=self.loop) + f_result = asyncio.Future(loop=self.loop) f_result.set_result(4) self.assertEqual(repr(f_result), 'Future') self.assertEqual(f_result.result(), 4) exc = RuntimeError() - f_exception = futures.Future(loop=self.loop) + f_exception = asyncio.Future(loop=self.loop) f_exception.set_exception(exc) self.assertEqual(repr(f_exception), 'Future') self.assertIs(f_exception.exception(), exc) - f_few_callbacks = futures.Future(loop=self.loop) + f_few_callbacks = asyncio.Future(loop=self.loop) f_few_callbacks.add_done_callback(_fakefunc) self.assertIn('Future')) self.assertTrue(RGX_REPR.match(repr(lock))) - @tasks.coroutine + @asyncio.coroutine def acquire_lock(): yield from lock @@ -59,9 +56,9 @@ def acquire_lock(): self.assertTrue(RGX_REPR.match(repr(lock))) def test_lock(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def acquire_lock(): return (yield from lock) @@ -74,31 +71,31 @@ def acquire_lock(): self.assertFalse(lock.locked()) def test_acquire(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) result = [] self.assertTrue(self.loop.run_until_complete(lock.acquire())) - @tasks.coroutine + @asyncio.coroutine def c1(result): if (yield from lock.acquire()): result.append(1) return True - @tasks.coroutine + @asyncio.coroutine def c2(result): if (yield from lock.acquire()): result.append(2) return True - @tasks.coroutine + @asyncio.coroutine def c3(result): if (yield from lock.acquire()): result.append(3) return True - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -110,7 +107,7 @@ def c3(result): test_utils.run_briefly(self.loop) self.assertEqual([1], result) - t3 = tasks.Task(c3(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) lock.release() test_utils.run_briefly(self.loop) @@ -128,13 +125,13 @@ def c3(result): self.assertTrue(t3.result()) def test_acquire_cancel(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) self.assertTrue(self.loop.run_until_complete(lock.acquire())) - task = tasks.Task(lock.acquire(), loop=self.loop) + task = asyncio.Task(lock.acquire(), loop=self.loop) self.loop.call_soon(task.cancel) self.assertRaises( - futures.CancelledError, + asyncio.CancelledError, self.loop.run_until_complete, task) self.assertFalse(lock._waiters) @@ -153,9 +150,9 @@ def test_cancel_race(self): # B's waiter; instead, it should move on to C's waiter. # Setup: A has the lock, b and c are waiting. - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def lockit(name, blocker): yield from lock.acquire() try: @@ -164,14 +161,14 @@ def lockit(name, blocker): finally: lock.release() - fa = futures.Future(loop=self.loop) - ta = tasks.Task(lockit('A', fa), loop=self.loop) + fa = asyncio.Future(loop=self.loop) + ta = asyncio.Task(lockit('A', fa), loop=self.loop) test_utils.run_briefly(self.loop) self.assertTrue(lock.locked()) - tb = tasks.Task(lockit('B', None), loop=self.loop) + tb = asyncio.Task(lockit('B', None), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual(len(lock._waiters), 1) - tc = tasks.Task(lockit('C', None), loop=self.loop) + tc = asyncio.Task(lockit('C', None), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual(len(lock._waiters), 2) @@ -187,12 +184,12 @@ def lockit(name, blocker): self.assertTrue(tc.done()) def test_release_not_acquired(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) self.assertRaises(RuntimeError, lock.release) def test_release_no_waiters(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) self.loop.run_until_complete(lock.acquire()) self.assertTrue(lock.locked()) @@ -200,9 +197,9 @@ def test_release_no_waiters(self): self.assertFalse(lock.locked()) def test_context_manager(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def acquire_lock(): return (yield from lock) @@ -212,7 +209,7 @@ def acquire_lock(): self.assertFalse(lock.locked()) def test_context_manager_no_yield(self): - lock = locks.Lock(loop=self.loop) + lock = asyncio.Lock(loop=self.loop) try: with lock: @@ -227,29 +224,29 @@ class EventTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() def test_ctor_loop(self): loop = unittest.mock.Mock() - ev = locks.Event(loop=loop) + ev = asyncio.Event(loop=loop) self.assertIs(ev._loop, loop) - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) self.assertIs(ev._loop, self.loop) def test_ctor_noloop(self): try: - events.set_event_loop(self.loop) - ev = locks.Event() + asyncio.set_event_loop(self.loop) + ev = asyncio.Event() self.assertIs(ev._loop, self.loop) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_repr(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) self.assertTrue(repr(ev).endswith('[unset]>')) match = RGX_REPR.match(repr(ev)) self.assertEqual(match.group('extras'), 'unset') @@ -263,33 +260,33 @@ def test_repr(self): self.assertTrue(RGX_REPR.match(repr(ev))) def test_wait(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) self.assertFalse(ev.is_set()) result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): if (yield from ev.wait()): result.append(1) - @tasks.coroutine + @asyncio.coroutine def c2(result): if (yield from ev.wait()): result.append(2) - @tasks.coroutine + @asyncio.coroutine def c3(result): if (yield from ev.wait()): result.append(3) - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) - t3 = tasks.Task(c3(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) ev.set() test_utils.run_briefly(self.loop) @@ -303,24 +300,24 @@ def c3(result): self.assertIsNone(t3.result()) def test_wait_on_set(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) ev.set() res = self.loop.run_until_complete(ev.wait()) self.assertTrue(res) def test_wait_cancel(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) - wait = tasks.Task(ev.wait(), loop=self.loop) + wait = asyncio.Task(ev.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) self.assertRaises( - futures.CancelledError, + asyncio.CancelledError, self.loop.run_until_complete, wait) self.assertFalse(ev._waiters) def test_clear(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) self.assertFalse(ev.is_set()) ev.set() @@ -330,16 +327,16 @@ def test_clear(self): self.assertFalse(ev.is_set()) def test_clear_with_waiters(self): - ev = locks.Event(loop=self.loop) + ev = asyncio.Event(loop=self.loop) result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): if (yield from ev.wait()): result.append(1) return True - t = tasks.Task(c1(result), loop=self.loop) + t = asyncio.Task(c1(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -363,55 +360,55 @@ class ConditionTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() def test_ctor_loop(self): loop = unittest.mock.Mock() - cond = locks.Condition(loop=loop) + cond = asyncio.Condition(loop=loop) self.assertIs(cond._loop, loop) - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.assertIs(cond._loop, self.loop) def test_ctor_noloop(self): try: - events.set_event_loop(self.loop) - cond = locks.Condition() + asyncio.set_event_loop(self.loop) + cond = asyncio.Condition() self.assertIs(cond._loop, self.loop) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_wait(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): yield from cond.acquire() if (yield from cond.wait()): result.append(1) return True - @tasks.coroutine + @asyncio.coroutine def c2(result): yield from cond.acquire() if (yield from cond.wait()): result.append(2) return True - @tasks.coroutine + @asyncio.coroutine def c3(result): yield from cond.acquire() if (yield from cond.wait()): result.append(3) return True - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) - t3 = tasks.Task(c3(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -451,25 +448,25 @@ def c3(result): self.assertTrue(t3.result()) def test_wait_cancel(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.loop.run_until_complete(cond.acquire()) - wait = tasks.Task(cond.wait(), loop=self.loop) + wait = asyncio.Task(cond.wait(), loop=self.loop) self.loop.call_soon(wait.cancel) self.assertRaises( - futures.CancelledError, + asyncio.CancelledError, self.loop.run_until_complete, wait) self.assertFalse(cond._waiters) self.assertTrue(cond.locked()) def test_wait_unacquired(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.assertRaises( RuntimeError, self.loop.run_until_complete, cond.wait()) def test_wait_for(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) presult = False def predicate(): @@ -477,7 +474,7 @@ def predicate(): result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): yield from cond.acquire() if (yield from cond.wait_for(predicate)): @@ -485,7 +482,7 @@ def c1(result): cond.release() return True - t = tasks.Task(c1(result), loop=self.loop) + t = asyncio.Task(c1(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -507,7 +504,7 @@ def c1(result): self.assertTrue(t.result()) def test_wait_for_unacquired(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) # predicate can return true immediately res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) @@ -519,10 +516,10 @@ def test_wait_for_unacquired(self): cond.wait_for(lambda: False)) def test_notify(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): yield from cond.acquire() if (yield from cond.wait()): @@ -530,7 +527,7 @@ def c1(result): cond.release() return True - @tasks.coroutine + @asyncio.coroutine def c2(result): yield from cond.acquire() if (yield from cond.wait()): @@ -538,7 +535,7 @@ def c2(result): cond.release() return True - @tasks.coroutine + @asyncio.coroutine def c3(result): yield from cond.acquire() if (yield from cond.wait()): @@ -546,9 +543,9 @@ def c3(result): cond.release() return True - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) - t3 = tasks.Task(c3(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -574,11 +571,11 @@ def c3(result): self.assertTrue(t3.result()) def test_notify_all(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) result = [] - @tasks.coroutine + @asyncio.coroutine def c1(result): yield from cond.acquire() if (yield from cond.wait()): @@ -586,7 +583,7 @@ def c1(result): cond.release() return True - @tasks.coroutine + @asyncio.coroutine def c2(result): yield from cond.acquire() if (yield from cond.wait()): @@ -594,8 +591,8 @@ def c2(result): cond.release() return True - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([], result) @@ -612,15 +609,15 @@ def c2(result): self.assertTrue(t2.result()) def test_notify_unacquired(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.assertRaises(RuntimeError, cond.notify) def test_notify_all_unacquired(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.assertRaises(RuntimeError, cond.notify_all) def test_repr(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) self.assertTrue('unlocked' in repr(cond)) self.assertTrue(RGX_REPR.match(repr(cond))) @@ -636,9 +633,9 @@ def test_repr(self): self.assertTrue(RGX_REPR.match(repr(cond))) def test_context_manager(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def acquire_cond(): return (yield from cond) @@ -648,7 +645,7 @@ def acquire_cond(): self.assertFalse(cond.locked()) def test_context_manager_no_yield(self): - cond = locks.Condition(loop=self.loop) + cond = asyncio.Condition(loop=self.loop) try: with cond: @@ -663,33 +660,33 @@ class SemaphoreTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() def test_ctor_loop(self): loop = unittest.mock.Mock() - sem = locks.Semaphore(loop=loop) + sem = asyncio.Semaphore(loop=loop) self.assertIs(sem._loop, loop) - sem = locks.Semaphore(loop=self.loop) + sem = asyncio.Semaphore(loop=self.loop) self.assertIs(sem._loop, self.loop) def test_ctor_noloop(self): try: - events.set_event_loop(self.loop) - sem = locks.Semaphore() + asyncio.set_event_loop(self.loop) + sem = asyncio.Semaphore() self.assertIs(sem._loop, self.loop) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_initial_value_zero(self): - sem = locks.Semaphore(0, loop=self.loop) + sem = asyncio.Semaphore(0, loop=self.loop) self.assertTrue(sem.locked()) def test_repr(self): - sem = locks.Semaphore(loop=self.loop) + sem = asyncio.Semaphore(loop=self.loop) self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) self.assertTrue(RGX_REPR.match(repr(sem))) @@ -707,10 +704,10 @@ def test_repr(self): self.assertTrue(RGX_REPR.match(repr(sem))) def test_semaphore(self): - sem = locks.Semaphore(loop=self.loop) + sem = asyncio.Semaphore(loop=self.loop) self.assertEqual(1, sem._value) - @tasks.coroutine + @asyncio.coroutine def acquire_lock(): return (yield from sem) @@ -725,43 +722,43 @@ def acquire_lock(): self.assertEqual(1, sem._value) def test_semaphore_value(self): - self.assertRaises(ValueError, locks.Semaphore, -1) + self.assertRaises(ValueError, asyncio.Semaphore, -1) def test_acquire(self): - sem = locks.Semaphore(3, loop=self.loop) + sem = asyncio.Semaphore(3, loop=self.loop) result = [] self.assertTrue(self.loop.run_until_complete(sem.acquire())) self.assertTrue(self.loop.run_until_complete(sem.acquire())) self.assertFalse(sem.locked()) - @tasks.coroutine + @asyncio.coroutine def c1(result): yield from sem.acquire() result.append(1) return True - @tasks.coroutine + @asyncio.coroutine def c2(result): yield from sem.acquire() result.append(2) return True - @tasks.coroutine + @asyncio.coroutine def c3(result): yield from sem.acquire() result.append(3) return True - @tasks.coroutine + @asyncio.coroutine def c4(result): yield from sem.acquire() result.append(4) return True - t1 = tasks.Task(c1(result), loop=self.loop) - t2 = tasks.Task(c2(result), loop=self.loop) - t3 = tasks.Task(c3(result), loop=self.loop) + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual([1], result) @@ -769,7 +766,7 @@ def c4(result): self.assertEqual(2, len(sem._waiters)) self.assertEqual(0, sem._value) - t4 = tasks.Task(c4(result), loop=self.loop) + t4 = asyncio.Task(c4(result), loop=self.loop) sem.release() sem.release() @@ -794,23 +791,23 @@ def c4(result): sem.release() def test_acquire_cancel(self): - sem = locks.Semaphore(loop=self.loop) + sem = asyncio.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) - acquire = tasks.Task(sem.acquire(), loop=self.loop) + acquire = asyncio.Task(sem.acquire(), loop=self.loop) self.loop.call_soon(acquire.cancel) self.assertRaises( - futures.CancelledError, + asyncio.CancelledError, self.loop.run_until_complete, acquire) self.assertFalse(sem._waiters) def test_release_not_acquired(self): - sem = locks.BoundedSemaphore(loop=self.loop) + sem = asyncio.BoundedSemaphore(loop=self.loop) self.assertRaises(ValueError, sem.release) def test_release_no_waiters(self): - sem = locks.Semaphore(loop=self.loop) + sem = asyncio.Semaphore(loop=self.loop) self.loop.run_until_complete(sem.acquire()) self.assertTrue(sem.locked()) @@ -818,9 +815,9 @@ def test_release_no_waiters(self): self.assertFalse(sem.locked()) def test_context_manager(self): - sem = locks.Semaphore(2, loop=self.loop) + sem = asyncio.Semaphore(2, loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def acquire_lock(): return (yield from sem) diff --git a/tests/test_queues.py b/tests/test_queues.py index 8af4ee7f..ccb89d72 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -3,11 +3,8 @@ import unittest import unittest.mock -from asyncio import events -from asyncio import futures -from asyncio import locks +import asyncio from asyncio import queues -from asyncio import tasks from asyncio import test_utils @@ -15,7 +12,7 @@ class _QueueTestBase(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() @@ -44,27 +41,27 @@ def gen(): id_is_present = hex(id(q)) in fn(q) self.assertEqual(expect_id, id_is_present) - @tasks.coroutine + @asyncio.coroutine def add_getter(): q = queues.Queue(loop=loop) # Start a task that waits to get. - tasks.Task(q.get(), loop=loop) + asyncio.Task(q.get(), loop=loop) # Let it start waiting. - yield from tasks.sleep(0.1, loop=loop) + yield from asyncio.sleep(0.1, loop=loop) self.assertTrue('_getters[1]' in fn(q)) # resume q.get coroutine to finish generator q.put_nowait(0) loop.run_until_complete(add_getter()) - @tasks.coroutine + @asyncio.coroutine def add_putter(): q = queues.Queue(maxsize=1, loop=loop) q.put_nowait(1) # Start a task that waits to put. - tasks.Task(q.put(2), loop=loop) + asyncio.Task(q.put(2), loop=loop) # Let it start waiting. - yield from tasks.sleep(0.1, loop=loop) + yield from asyncio.sleep(0.1, loop=loop) self.assertTrue('_putters[1]' in fn(q)) # resume q.put coroutine to finish generator q.get_nowait() @@ -85,11 +82,11 @@ def test_ctor_loop(self): def test_ctor_noloop(self): try: - events.set_event_loop(self.loop) + asyncio.set_event_loop(self.loop) q = queues.Queue() self.assertIs(q._loop, self.loop) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) def test_repr(self): self._test_repr_or_str(repr, True) @@ -137,24 +134,24 @@ def gen(): self.assertEqual(2, q.maxsize) have_been_put = [] - @tasks.coroutine + @asyncio.coroutine def putter(): for i in range(3): yield from q.put(i) have_been_put.append(i) return True - @tasks.coroutine + @asyncio.coroutine def test(): - t = tasks.Task(putter(), loop=loop) - yield from tasks.sleep(0.01, loop=loop) + t = asyncio.Task(putter(), loop=loop) + yield from asyncio.sleep(0.01, loop=loop) # The putter is blocked after putting two items. self.assertEqual([0, 1], have_been_put) self.assertEqual(0, q.get_nowait()) # Let the putter resume and put last item. - yield from tasks.sleep(0.01, loop=loop) + yield from asyncio.sleep(0.01, loop=loop) self.assertEqual([0, 1, 2], have_been_put) self.assertEqual(1, q.get_nowait()) self.assertEqual(2, q.get_nowait()) @@ -172,7 +169,7 @@ def test_blocking_get(self): q = queues.Queue(loop=self.loop) q.put_nowait(1) - @tasks.coroutine + @asyncio.coroutine def queue_get(): return (yield from q.get()) @@ -183,7 +180,7 @@ def test_get_with_putters(self): q = queues.Queue(1, loop=self.loop) q.put_nowait(1) - waiter = futures.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) q._putters.append((2, waiter)) res = self.loop.run_until_complete(q.get()) @@ -202,10 +199,10 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(loop=loop) - started = locks.Event(loop=loop) + started = asyncio.Event(loop=loop) finished = False - @tasks.coroutine + @asyncio.coroutine def queue_get(): nonlocal finished started.set() @@ -213,10 +210,10 @@ def queue_get(): finished = True return res - @tasks.coroutine + @asyncio.coroutine def queue_put(): loop.call_later(0.01, q.put_nowait, 1) - queue_get_task = tasks.Task(queue_get(), loop=loop) + queue_get_task = asyncio.Task(queue_get(), loop=loop) yield from started.wait() self.assertFalse(finished) res = yield from queue_get_task @@ -250,14 +247,14 @@ def gen(): q = queues.Queue(loop=loop) - @tasks.coroutine + @asyncio.coroutine def queue_get(): - return (yield from tasks.wait_for(q.get(), 0.051, loop=loop)) + return (yield from asyncio.wait_for(q.get(), 0.051, loop=loop)) - @tasks.coroutine + @asyncio.coroutine def test(): - get_task = tasks.Task(queue_get(), loop=loop) - yield from tasks.sleep(0.01, loop=loop) # let the task start + get_task = asyncio.Task(queue_get(), loop=loop) + yield from asyncio.sleep(0.01, loop=loop) # let the task start q.put_nowait(1) return (yield from get_task) @@ -267,8 +264,8 @@ def test(): def test_get_cancelled_race(self): q = queues.Queue(loop=self.loop) - t1 = tasks.Task(q.get(), loop=self.loop) - t2 = tasks.Task(q.get(), loop=self.loop) + t1 = asyncio.Task(q.get(), loop=self.loop) + t2 = asyncio.Task(q.get(), loop=self.loop) test_utils.run_briefly(self.loop) t1.cancel() @@ -280,8 +277,8 @@ def test_get_cancelled_race(self): def test_get_with_waiting_putters(self): q = queues.Queue(loop=self.loop, maxsize=1) - tasks.Task(q.put('a'), loop=self.loop) - tasks.Task(q.put('b'), loop=self.loop) + asyncio.Task(q.put('a'), loop=self.loop) + asyncio.Task(q.put('b'), loop=self.loop) test_utils.run_briefly(self.loop) self.assertEqual(self.loop.run_until_complete(q.get()), 'a') self.assertEqual(self.loop.run_until_complete(q.get()), 'b') @@ -292,7 +289,7 @@ class QueuePutTests(_QueueTestBase): def test_blocking_put(self): q = queues.Queue(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def queue_put(): # No maxsize, won't block. yield from q.put(1) @@ -310,10 +307,10 @@ def gen(): self.addCleanup(loop.close) q = queues.Queue(maxsize=1, loop=loop) - started = locks.Event(loop=loop) + started = asyncio.Event(loop=loop) finished = False - @tasks.coroutine + @asyncio.coroutine def queue_put(): nonlocal finished started.set() @@ -321,10 +318,10 @@ def queue_put(): yield from q.put(2) finished = True - @tasks.coroutine + @asyncio.coroutine def queue_get(): loop.call_later(0.01, q.get_nowait) - queue_put_task = tasks.Task(queue_put(), loop=loop) + queue_put_task = asyncio.Task(queue_put(), loop=loop) yield from started.wait() self.assertFalse(finished) yield from queue_put_task @@ -346,16 +343,16 @@ def test_nonblocking_put_exception(self): def test_put_cancelled(self): q = queues.Queue(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def queue_put(): yield from q.put(1) return True - @tasks.coroutine + @asyncio.coroutine def test(): return (yield from q.get()) - t = tasks.Task(queue_put(), loop=self.loop) + t = asyncio.Task(queue_put(), loop=self.loop) self.assertEqual(1, self.loop.run_until_complete(test())) self.assertTrue(t.done()) self.assertTrue(t.result()) @@ -363,9 +360,9 @@ def test(): def test_put_cancelled_race(self): q = queues.Queue(loop=self.loop, maxsize=1) - tasks.Task(q.put('a'), loop=self.loop) - tasks.Task(q.put('c'), loop=self.loop) - t = tasks.Task(q.put('b'), loop=self.loop) + asyncio.Task(q.put('a'), loop=self.loop) + asyncio.Task(q.put('c'), loop=self.loop) + t = asyncio.Task(q.put('b'), loop=self.loop) test_utils.run_briefly(self.loop) t.cancel() @@ -376,7 +373,7 @@ def test_put_cancelled_race(self): def test_put_with_waiting_getters(self): q = queues.Queue(loop=self.loop) - t = tasks.Task(q.get(), loop=self.loop) + t = asyncio.Task(q.get(), loop=self.loop) test_utils.run_briefly(self.loop) self.loop.run_until_complete(q.put('a')) self.assertEqual(self.loop.run_until_complete(t), 'a') @@ -421,7 +418,7 @@ def test_task_done(self): # Join the queue and assert all items have been processed. running = True - @tasks.coroutine + @asyncio.coroutine def worker(): nonlocal accumulator @@ -430,10 +427,10 @@ def worker(): accumulator += item q.task_done() - @tasks.coroutine + @asyncio.coroutine def test(): for _ in range(2): - tasks.Task(worker(), loop=self.loop) + asyncio.Task(worker(), loop=self.loop) yield from q.join() @@ -451,7 +448,7 @@ def test_join_empty_queue(self): # Test that a queue join()s successfully, and before anything else # (done twice for insurance). - @tasks.coroutine + @asyncio.coroutine def join(): yield from q.join() yield from q.join() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 38aa7669..908ee5b3 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -13,7 +13,7 @@ except ImportError: ssl = None -from asyncio import futures +import asyncio from asyncio import selectors from asyncio import test_utils from asyncio.protocols import DatagramProtocol, Protocol @@ -125,13 +125,13 @@ def test_sock_recv(self): self.loop._sock_recv = unittest.mock.Mock() f = self.loop.sock_recv(sock, 1024) - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.loop._sock_recv.assert_called_with(f, False, sock, 1024) def test__sock_recv_canceled_fut(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop._sock_recv(f, False, sock, 1024) @@ -141,7 +141,7 @@ def test__sock_recv_unregister(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop.remove_reader = unittest.mock.Mock() @@ -149,7 +149,7 @@ def test__sock_recv_unregister(self): self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_recv_tryagain(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.recv.side_effect = BlockingIOError @@ -160,7 +160,7 @@ def test__sock_recv_tryagain(self): self.loop.add_reader.call_args[0]) def test__sock_recv_exception(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 err = sock.recv.side_effect = OSError() @@ -173,7 +173,7 @@ def test_sock_sendall(self): self.loop._sock_sendall = unittest.mock.Mock() f = self.loop.sock_sendall(sock, b'data') - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.assertEqual( (f, False, sock, b'data'), self.loop._sock_sendall.call_args[0]) @@ -183,7 +183,7 @@ def test_sock_sendall_nodata(self): self.loop._sock_sendall = unittest.mock.Mock() f = self.loop.sock_sendall(sock, b'') - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.assertTrue(f.done()) self.assertIsNone(f.result()) self.assertFalse(self.loop._sock_sendall.called) @@ -191,7 +191,7 @@ def test_sock_sendall_nodata(self): def test__sock_sendall_canceled_fut(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop._sock_sendall(f, False, sock, b'data') @@ -201,7 +201,7 @@ def test__sock_sendall_unregister(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop.remove_writer = unittest.mock.Mock() @@ -209,7 +209,7 @@ def test__sock_sendall_unregister(self): self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_sendall_tryagain(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.send.side_effect = BlockingIOError @@ -221,7 +221,7 @@ def test__sock_sendall_tryagain(self): self.loop.add_writer.call_args[0]) def test__sock_sendall_interrupted(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.send.side_effect = InterruptedError @@ -233,7 +233,7 @@ def test__sock_sendall_interrupted(self): self.loop.add_writer.call_args[0]) def test__sock_sendall_exception(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 err = sock.send.side_effect = OSError() @@ -244,7 +244,7 @@ def test__sock_sendall_exception(self): def test__sock_sendall(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 sock.send.return_value = 4 @@ -255,7 +255,7 @@ def test__sock_sendall(self): def test__sock_sendall_partial(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 sock.send.return_value = 2 @@ -269,7 +269,7 @@ def test__sock_sendall_partial(self): def test__sock_sendall_none(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 sock.send.return_value = 0 @@ -285,13 +285,13 @@ def test_sock_connect(self): self.loop._sock_connect = unittest.mock.Mock() f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.assertEqual( (f, False, sock, ('127.0.0.1', 8080)), self.loop._sock_connect.call_args[0]) def test__sock_connect(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 @@ -304,7 +304,7 @@ def test__sock_connect(self): def test__sock_connect_canceled_fut(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) @@ -314,7 +314,7 @@ def test__sock_connect_unregister(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop.remove_writer = unittest.mock.Mock() @@ -322,7 +322,7 @@ def test__sock_connect_unregister(self): self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_connect_tryagain(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.EAGAIN @@ -337,7 +337,7 @@ def test__sock_connect_tryagain(self): self.loop.add_writer.call_args[0]) def test__sock_connect_exception(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.ENOTCONN @@ -351,12 +351,12 @@ def test_sock_accept(self): self.loop._sock_accept = unittest.mock.Mock() f = self.loop.sock_accept(sock) - self.assertIsInstance(f, futures.Future) + self.assertIsInstance(f, asyncio.Future) self.assertEqual( (f, False, sock), self.loop._sock_accept.call_args[0]) def test__sock_accept(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) conn = unittest.mock.Mock() @@ -372,7 +372,7 @@ def test__sock_accept(self): def test__sock_accept_canceled_fut(self): sock = unittest.mock.Mock() - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop._sock_accept(f, False, sock) @@ -382,7 +382,7 @@ def test__sock_accept_unregister(self): sock = unittest.mock.Mock() sock.fileno.return_value = 10 - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) f.cancel() self.loop.remove_reader = unittest.mock.Mock() @@ -390,7 +390,7 @@ def test__sock_accept_unregister(self): self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_accept_tryagain(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 sock.accept.side_effect = BlockingIOError @@ -402,7 +402,7 @@ def test__sock_accept_tryagain(self): self.loop.add_reader.call_args[0]) def test__sock_accept_exception(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) sock = unittest.mock.Mock() sock.fileno.return_value = 10 err = sock.accept.side_effect = OSError() @@ -684,7 +684,7 @@ def test_ctor(self): self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) _SelectorSocketTransport( self.loop, self.sock, self.protocol, fut) @@ -1055,7 +1055,7 @@ def _make_one(self, create_waiter=None): return transport def test_on_handshake(self): - waiter = futures.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) tr = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext, waiter=waiter) @@ -1083,7 +1083,7 @@ def test_on_handshake_exc(self): self.sslsock.do_handshake.side_effect = exc transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._waiter = futures.Future(loop=self.loop) + transport._waiter = asyncio.Future(loop=self.loop) transport._on_handshake() self.assertTrue(self.sslsock.close.called) self.assertTrue(transport._waiter.done()) @@ -1092,7 +1092,7 @@ def test_on_handshake_exc(self): def test_on_handshake_base_exc(self): transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._waiter = futures.Future(loop=self.loop) + transport._waiter = asyncio.Future(loop=self.loop) exc = BaseException() self.sslsock.do_handshake.side_effect = exc self.assertRaises(BaseException, transport._on_handshake) diff --git a/tests/test_streams.py b/tests/test_streams.py index 5516c158..cd6dc1e4 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -8,9 +8,7 @@ except ImportError: ssl = None -from asyncio import events -from asyncio import streams -from asyncio import tasks +import asyncio from asyncio import test_utils @@ -19,8 +17,8 @@ class StreamReaderTests(unittest.TestCase): DATA = b'line1\nline2\nline3\n' def setUp(self): - self.loop = events.new_event_loop() - events.set_event_loop(None) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(None) def tearDown(self): # just in case if we have transport close callbacks @@ -31,12 +29,12 @@ def tearDown(self): @unittest.mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): - stream = streams.StreamReader() + stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) def test_open_connection(self): with test_utils.run_test_server() as httpd: - f = streams.open_connection(*httpd.address, loop=self.loop) + f = asyncio.open_connection(*httpd.address, loop=self.loop) reader, writer = self.loop.run_until_complete(f) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.readline() @@ -52,12 +50,12 @@ def test_open_connection(self): def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: try: - events.set_event_loop(self.loop) - f = streams.open_connection(*httpd.address, + asyncio.set_event_loop(self.loop) + f = asyncio.open_connection(*httpd.address, ssl=test_utils.dummy_ssl_context()) reader, writer = self.loop.run_until_complete(f) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) writer.write(b'GET / HTTP/1.0\r\n\r\n') f = reader.read() data = self.loop.run_until_complete(f) @@ -67,7 +65,7 @@ def test_open_connection_no_loop_ssl(self): def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - f = streams.open_connection(*httpd.address, loop=self.loop) + f = asyncio.open_connection(*httpd.address, loop=self.loop) reader, writer = self.loop.run_until_complete(f) writer._protocol.connection_lost(ZeroDivisionError()) f = reader.read() @@ -78,20 +76,20 @@ def test_open_connection_error(self): test_utils.run_briefly(self.loop) def test_feed_empty_data(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'') self.assertEqual(0, stream._byte_count) def test_feed_data_byte_count(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) self.assertEqual(len(self.DATA), stream._byte_count) def test_read_zero(self): # Read zero bytes. - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.read(0)) @@ -100,8 +98,8 @@ def test_read_zero(self): def test_read(self): # Read bytes. - stream = streams.StreamReader(loop=self.loop) - read_task = tasks.Task(stream.read(30), loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(30), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -113,7 +111,7 @@ def cb(): def test_read_line_breaks(self): # Read bytes without line breaks. - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'line1') stream.feed_data(b'line2') @@ -124,8 +122,8 @@ def test_read_line_breaks(self): def test_read_eof(self): # Read bytes, stop at eof. - stream = streams.StreamReader(loop=self.loop) - read_task = tasks.Task(stream.read(1024), loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(1024), loop=self.loop) def cb(): stream.feed_eof() @@ -137,8 +135,8 @@ def cb(): def test_read_until_eof(self): # Read all bytes until eof. - stream = streams.StreamReader(loop=self.loop) - read_task = tasks.Task(stream.read(-1), loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) + read_task = asyncio.Task(stream.read(-1), loop=self.loop) def cb(): stream.feed_data(b'chunk1\n') @@ -152,7 +150,7 @@ def cb(): self.assertFalse(stream._byte_count) def test_read_exception(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.read(2)) @@ -164,9 +162,9 @@ def test_read_exception(self): def test_readline(self): # Read one line. - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'chunk1 ') - read_task = tasks.Task(stream.readline(), loop=self.loop) + read_task = asyncio.Task(stream.readline(), loop=self.loop) def cb(): stream.feed_data(b'chunk2 ') @@ -179,7 +177,7 @@ def cb(): self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) def test_readline_limit_with_existing_data(self): - stream = streams.StreamReader(3, loop=self.loop) + stream = asyncio.StreamReader(3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') @@ -187,7 +185,7 @@ def test_readline_limit_with_existing_data(self): ValueError, self.loop.run_until_complete, stream.readline()) self.assertEqual([b'line2\n'], list(stream._buffer)) - stream = streams.StreamReader(3, loop=self.loop) + stream = asyncio.StreamReader(3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') @@ -198,7 +196,7 @@ def test_readline_limit_with_existing_data(self): self.assertEqual(2, stream._byte_count) def test_readline_limit(self): - stream = streams.StreamReader(7, loop=self.loop) + stream = asyncio.StreamReader(7, loop=self.loop) def cb(): stream.feed_data(b'chunk1') @@ -213,7 +211,7 @@ def cb(): self.assertEqual(7, stream._byte_count) def test_readline_line_byte_count(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -223,7 +221,7 @@ def test_readline_line_byte_count(self): self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) def test_readline_eof(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'some data') stream.feed_eof() @@ -231,14 +229,14 @@ def test_readline_eof(self): self.assertEqual(b'some data', line) def test_readline_empty_eof(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_eof() line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'', line) def test_readline_read_byte_count(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) self.loop.run_until_complete(stream.readline()) @@ -251,7 +249,7 @@ def test_readline_read_byte_count(self): stream._byte_count) def test_readline_exception(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readline()) @@ -263,7 +261,7 @@ def test_readline_exception(self): def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) data = self.loop.run_until_complete(stream.readexactly(0)) @@ -276,10 +274,10 @@ def test_readexactly_zero_or_less(self): def test_readexactly(self): # Read exact number of bytes. - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) n = 2 * len(self.DATA) - read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -293,9 +291,9 @@ def cb(): def test_readexactly_eof(self): # Read exact number of bytes (eof). - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) n = 2 * len(self.DATA) - read_task = tasks.Task(stream.readexactly(n), loop=self.loop) + read_task = asyncio.Task(stream.readexactly(n), loop=self.loop) def cb(): stream.feed_data(self.DATA) @@ -307,7 +305,7 @@ def cb(): self.assertFalse(stream._byte_count) def test_readexactly_exception(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'line\n') data = self.loop.run_until_complete(stream.readexactly(2)) @@ -318,7 +316,7 @@ def test_readexactly_exception(self): ValueError, self.loop.run_until_complete, stream.readexactly(2)) def test_exception(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) self.assertIsNone(stream.exception()) exc = ValueError() @@ -326,31 +324,31 @@ def test_exception(self): self.assertIs(stream.exception(), exc) def test_exception_waiter(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def set_err(): stream.set_exception(ValueError()) - @tasks.coroutine + @asyncio.coroutine def readline(): yield from stream.readline() - t1 = tasks.Task(stream.readline(), loop=self.loop) - t2 = tasks.Task(set_err(), loop=self.loop) + t1 = asyncio.Task(stream.readline(), loop=self.loop) + t2 = asyncio.Task(set_err(), loop=self.loop) - self.loop.run_until_complete(tasks.wait([t1, t2], loop=self.loop)) + self.loop.run_until_complete(asyncio.wait([t1, t2], loop=self.loop)) self.assertRaises(ValueError, t1.result) def test_exception_cancel(self): - stream = streams.StreamReader(loop=self.loop) + stream = asyncio.StreamReader(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def read_a_line(): yield from stream.readline() - t = tasks.Task(read_a_line(), loop=self.loop) + t = asyncio.Task(read_a_line(), loop=self.loop) test_utils.run_briefly(self.loop) t.cancel() test_utils.run_briefly(self.loop) @@ -367,19 +365,19 @@ def __init__(self, loop): self.server = None self.loop = loop - @tasks.coroutine + @asyncio.coroutine def handle_client(self, client_reader, client_writer): data = yield from client_reader.readline() client_writer.write(data) def start(self): self.server = self.loop.run_until_complete( - streams.start_server(self.handle_client, + asyncio.start_server(self.handle_client, '127.0.0.1', 12345, loop=self.loop)) def handle_client_callback(self, client_reader, client_writer): - task = tasks.Task(client_reader.readline(), loop=self.loop) + task = asyncio.Task(client_reader.readline(), loop=self.loop) def done(task): client_writer.write(task.result()) @@ -388,7 +386,7 @@ def done(task): def start_callback(self): self.server = self.loop.run_until_complete( - streams.start_server(self.handle_client_callback, + asyncio.start_server(self.handle_client_callback, '127.0.0.1', 12345, loop=self.loop)) @@ -398,9 +396,9 @@ def stop(self): self.loop.run_until_complete(self.server.wait_closed()) self.server = None - @tasks.coroutine + @asyncio.coroutine def client(): - reader, writer = yield from streams.open_connection( + reader, writer = yield from asyncio.open_connection( '127.0.0.1', 12345, loop=self.loop) # send a line writer.write(b"hello world!\n") @@ -412,7 +410,7 @@ def client(): # test the server variant with a coroutine as client handler server = MyServer(self.loop) server.start() - msg = self.loop.run_until_complete(tasks.Task(client(), + msg = self.loop.run_until_complete(asyncio.Task(client(), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") @@ -420,7 +418,7 @@ def client(): # test the server variant with a callback as client handler server = MyServer(self.loop) server.start_callback() - msg = self.loop.run_until_complete(tasks.Task(client(), + msg = self.loop.run_until_complete(asyncio.Task(client(), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 3d08ad8b..dbf130c1 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -5,9 +5,7 @@ import unittest.mock from unittest.mock import Mock -from asyncio import events -from asyncio import futures -from asyncio import tasks +import asyncio from asyncio import test_utils @@ -24,115 +22,115 @@ class TaskTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - events.set_event_loop(None) + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() gc.collect() def test_task_class(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): return 'ok' - t = tasks.Task(notmuch(), loop=self.loop) + t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t._loop, self.loop) - loop = events.new_event_loop() - t = tasks.Task(notmuch(), loop=loop) + loop = asyncio.new_event_loop() + t = asyncio.Task(notmuch(), loop=loop) self.assertIs(t._loop, loop) loop.close() def test_async_coroutine(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): return 'ok' - t = tasks.async(notmuch(), loop=self.loop) + t = asyncio.async(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t._loop, self.loop) - loop = events.new_event_loop() - t = tasks.async(notmuch(), loop=loop) + loop = asyncio.new_event_loop() + t = asyncio.async(notmuch(), loop=loop) self.assertIs(t._loop, loop) loop.close() def test_async_future(self): - f_orig = futures.Future(loop=self.loop) + f_orig = asyncio.Future(loop=self.loop) f_orig.set_result('ko') - f = tasks.async(f_orig) + f = asyncio.async(f_orig) self.loop.run_until_complete(f) self.assertTrue(f.done()) self.assertEqual(f.result(), 'ko') self.assertIs(f, f_orig) - loop = events.new_event_loop() + loop = asyncio.new_event_loop() with self.assertRaises(ValueError): - f = tasks.async(f_orig, loop=loop) + f = asyncio.async(f_orig, loop=loop) loop.close() - f = tasks.async(f_orig, loop=self.loop) + f = asyncio.async(f_orig, loop=self.loop) self.assertIs(f, f_orig) def test_async_task(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): return 'ok' - t_orig = tasks.Task(notmuch(), loop=self.loop) - t = tasks.async(t_orig) + t_orig = asyncio.Task(notmuch(), loop=self.loop) + t = asyncio.async(t_orig) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') self.assertIs(t, t_orig) - loop = events.new_event_loop() + loop = asyncio.new_event_loop() with self.assertRaises(ValueError): - t = tasks.async(t_orig, loop=loop) + t = asyncio.async(t_orig, loop=loop) loop.close() - t = tasks.async(t_orig, loop=self.loop) + t = asyncio.async(t_orig, loop=self.loop) self.assertIs(t, t_orig) def test_async_neither(self): with self.assertRaises(TypeError): - tasks.async('ok') + asyncio.async('ok') def test_task_repr(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): yield from [] return 'abc' - t = tasks.Task(notmuch(), loop=self.loop) + t = asyncio.Task(notmuch(), loop=self.loop) t.add_done_callback(Dummy()) self.assertEqual(repr(t), 'Task()') t.cancel() # Does not take immediate effect! self.assertEqual(repr(t), 'Task()') - self.assertRaises(futures.CancelledError, + self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) self.assertEqual(repr(t), 'Task()') - t = tasks.Task(notmuch(), loop=self.loop) + t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertEqual(repr(t), "Task()") def test_task_repr_custom(self): - @tasks.coroutine + @asyncio.coroutine def coro(): pass - class T(futures.Future): + class T(asyncio.Future): def __repr__(self): return 'T[]' - class MyTask(tasks.Task, T): + class MyTask(asyncio.Task, T): def __repr__(self): return super().__repr__() @@ -142,17 +140,17 @@ def __repr__(self): gen.close() def test_task_basics(self): - @tasks.coroutine + @asyncio.coroutine def outer(): a = yield from inner1() b = yield from inner2() return a+b - @tasks.coroutine + @asyncio.coroutine def inner1(): return 42 - @tasks.coroutine + @asyncio.coroutine def inner2(): return 1000 @@ -169,66 +167,66 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - @tasks.coroutine + @asyncio.coroutine def task(): - yield from tasks.sleep(10.0, loop=loop) + yield from asyncio.sleep(10.0, loop=loop) return 12 - t = tasks.Task(task(), loop=loop) + t = asyncio.Task(task(), loop=loop) loop.call_soon(t.cancel) - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): loop.run_until_complete(t) self.assertTrue(t.done()) self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) def test_cancel_yield(self): - @tasks.coroutine + @asyncio.coroutine def task(): yield yield return 12 - t = tasks.Task(task(), loop=self.loop) + t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) # start coro t.cancel() self.assertRaises( - futures.CancelledError, self.loop.run_until_complete, t) + asyncio.CancelledError, self.loop.run_until_complete, t) self.assertTrue(t.done()) self.assertTrue(t.cancelled()) self.assertFalse(t.cancel()) def test_cancel_inner_future(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def task(): yield from f return 12 - t = tasks.Task(task(), loop=self.loop) + t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) # start task f.cancel() - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(t) self.assertTrue(f.cancelled()) self.assertTrue(t.cancelled()) def test_cancel_both_task_and_inner_future(self): - f = futures.Future(loop=self.loop) + f = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def task(): yield from f return 12 - t = tasks.Task(task(), loop=self.loop) + t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() t.cancel() - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(t) self.assertTrue(t.done()) @@ -236,18 +234,18 @@ def task(): self.assertTrue(t.cancelled()) def test_cancel_task_catching(self): - fut1 = futures.Future(loop=self.loop) - fut2 = futures.Future(loop=self.loop) + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def task(): yield from fut1 try: yield from fut2 - except futures.CancelledError: + except asyncio.CancelledError: return 42 - t = tasks.Task(task(), loop=self.loop) + t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(t._fut_waiter, fut1) # White-box test. fut1.set_result(None) @@ -260,21 +258,21 @@ def task(): self.assertFalse(t.cancelled()) def test_cancel_task_ignoring(self): - fut1 = futures.Future(loop=self.loop) - fut2 = futures.Future(loop=self.loop) - fut3 = futures.Future(loop=self.loop) + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + fut3 = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def task(): yield from fut1 try: yield from fut2 - except futures.CancelledError: + except asyncio.CancelledError: pass res = yield from fut3 return res - t = tasks.Task(task(), loop=self.loop) + t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(t._fut_waiter, fut1) # White-box test. fut1.set_result(None) @@ -291,20 +289,20 @@ def task(): self.assertFalse(t.cancelled()) def test_cancel_current_task(self): - loop = events.new_event_loop() + loop = asyncio.new_event_loop() self.addCleanup(loop.close) - @tasks.coroutine + @asyncio.coroutine def task(): t.cancel() self.assertTrue(t._must_cancel) # White-box test. # The sleep should be cancelled immediately. - yield from tasks.sleep(100, loop=loop) + yield from asyncio.sleep(100, loop=loop) return 12 - t = tasks.Task(task(), loop=loop) + t = asyncio.Task(task(), loop=loop) self.assertRaises( - futures.CancelledError, loop.run_until_complete, t) + asyncio.CancelledError, loop.run_until_complete, t) self.assertTrue(t.done()) self.assertFalse(t._must_cancel) # White-box test. self.assertFalse(t.cancel()) @@ -326,17 +324,17 @@ def gen(): x = 0 waiters = [] - @tasks.coroutine + @asyncio.coroutine def task(): nonlocal x while x < 10: - waiters.append(tasks.sleep(0.1, loop=loop)) + waiters.append(asyncio.sleep(0.1, loop=loop)) yield from waiters[-1] x += 1 if x == 2: loop.stop() - t = tasks.Task(task(), loop=loop) + t = asyncio.Task(task(), loop=loop) self.assertRaises( RuntimeError, loop.run_until_complete, t) self.assertFalse(t.done()) @@ -361,20 +359,20 @@ def gen(): foo_running = None - @tasks.coroutine + @asyncio.coroutine def foo(): nonlocal foo_running foo_running = True try: - yield from tasks.sleep(0.2, loop=loop) + yield from asyncio.sleep(0.2, loop=loop) finally: foo_running = False return 'done' - fut = tasks.Task(foo(), loop=loop) + fut = asyncio.Task(foo(), loop=loop) - with self.assertRaises(futures.TimeoutError): - loop.run_until_complete(tasks.wait_for(fut, 0.1, loop=loop)) + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop)) self.assertTrue(fut.done()) # it should have been cancelled due to the timeout self.assertTrue(fut.cancelled()) @@ -394,18 +392,18 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - @tasks.coroutine + @asyncio.coroutine def foo(): - yield from tasks.sleep(0.2, loop=loop) + yield from asyncio.sleep(0.2, loop=loop) return 'done' - events.set_event_loop(loop) + asyncio.set_event_loop(loop) try: - fut = tasks.Task(foo(), loop=loop) - with self.assertRaises(futures.TimeoutError): - loop.run_until_complete(tasks.wait_for(fut, 0.01)) + fut = asyncio.Task(foo(), loop=loop) + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.01)) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) self.assertAlmostEqual(0.01, loop.time()) self.assertTrue(fut.done()) @@ -423,22 +421,22 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) - b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def foo(): - done, pending = yield from tasks.wait([b, a], loop=loop) + done, pending = yield from asyncio.wait([b, a], loop=loop) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) return 42 - res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertEqual(res, 42) self.assertAlmostEqual(0.15, loop.time()) # Doing it again should take no time and exercise a different path. - res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) self.assertEqual(res, 42) @@ -454,33 +452,33 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(0.01, loop=loop), loop=loop) - b = tasks.Task(tasks.sleep(0.015, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def foo(): - done, pending = yield from tasks.wait([b, a]) + done, pending = yield from asyncio.wait([b, a]) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) return 42 - events.set_event_loop(loop) + asyncio.set_event_loop(loop) try: res = loop.run_until_complete( - tasks.Task(foo(), loop=loop)) + asyncio.Task(foo(), loop=loop)) finally: - events.set_event_loop(None) + asyncio.set_event_loop(None) self.assertEqual(res, 42) def test_wait_errors(self): self.assertRaises( ValueError, self.loop.run_until_complete, - tasks.wait(set(), loop=self.loop)) + asyncio.wait(set(), loop=self.loop)) self.assertRaises( ValueError, self.loop.run_until_complete, - tasks.wait([tasks.sleep(10.0, loop=self.loop)], + asyncio.wait([asyncio.sleep(10.0, loop=self.loop)], return_when=-1, loop=self.loop)) def test_wait_first_completed(self): @@ -495,10 +493,10 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) - b = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) - task = tasks.Task( - tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, loop=loop), loop=loop) @@ -512,25 +510,25 @@ def gen(): # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_wait_really_done(self): # there is possibility that some tasks in the pending list # became done but their callbacks haven't all been called yet - @tasks.coroutine + @asyncio.coroutine def coro1(): yield - @tasks.coroutine + @asyncio.coroutine def coro2(): yield yield - a = tasks.Task(coro1(), loop=self.loop) - b = tasks.Task(coro2(), loop=self.loop) - task = tasks.Task( - tasks.wait([b, a], return_when=tasks.FIRST_COMPLETED, + a = asyncio.Task(coro1(), loop=self.loop) + b = asyncio.Task(coro2(), loop=self.loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, loop=self.loop), loop=self.loop) @@ -552,15 +550,15 @@ def gen(): self.addCleanup(loop.close) # first_exception, task already has exception - a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def exc(): raise ZeroDivisionError('err') - b = tasks.Task(exc(), loop=loop) - task = tasks.Task( - tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + b = asyncio.Task(exc(), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, loop=loop), loop=loop) @@ -571,7 +569,7 @@ def exc(): # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_wait_first_exception_in_wait(self): @@ -586,15 +584,15 @@ def gen(): self.addCleanup(loop.close) # first_exception, exception during waiting - a = tasks.Task(tasks.sleep(10.0, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def exc(): - yield from tasks.sleep(0.01, loop=loop) + yield from asyncio.sleep(0.01, loop=loop) raise ZeroDivisionError('err') - b = tasks.Task(exc(), loop=loop) - task = tasks.wait([b, a], return_when=tasks.FIRST_EXCEPTION, + b = asyncio.Task(exc(), loop=loop) + task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, loop=loop) done, pending = loop.run_until_complete(task) @@ -604,7 +602,7 @@ def exc(): # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_wait_with_exception(self): @@ -618,27 +616,27 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def sleeper(): - yield from tasks.sleep(0.15, loop=loop) + yield from asyncio.sleep(0.15, loop=loop) raise ZeroDivisionError('really') - b = tasks.Task(sleeper(), loop=loop) + b = asyncio.Task(sleeper(), loop=loop) - @tasks.coroutine + @asyncio.coroutine def foo(): - done, pending = yield from tasks.wait([b, a], loop=loop) + done, pending = yield from asyncio.wait([b, a], loop=loop) self.assertEqual(len(done), 2) self.assertEqual(pending, set()) errors = set(f for f in done if f.exception() is not None) self.assertEqual(len(errors), 1) - loop.run_until_complete(tasks.Task(foo(), loop=loop)) + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) - loop.run_until_complete(tasks.Task(foo(), loop=loop)) + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) def test_wait_with_timeout(self): @@ -655,22 +653,22 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) - b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) - @tasks.coroutine + @asyncio.coroutine def foo(): - done, pending = yield from tasks.wait([b, a], timeout=0.11, + done, pending = yield from asyncio.wait([b, a], timeout=0.11, loop=loop) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) - loop.run_until_complete(tasks.Task(foo(), loop=loop)) + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.11, loop.time()) # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_wait_concurrent_complete(self): @@ -686,11 +684,11 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.Task(tasks.sleep(0.1, loop=loop), loop=loop) - b = tasks.Task(tasks.sleep(0.15, loop=loop), loop=loop) + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) done, pending = loop.run_until_complete( - tasks.wait([b, a], timeout=0.1, loop=loop)) + asyncio.wait([b, a], timeout=0.1, loop=loop)) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) @@ -698,7 +696,7 @@ def gen(): # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_as_completed(self): @@ -713,10 +711,10 @@ def gen(): completed = set() time_shifted = False - @tasks.coroutine + @asyncio.coroutine def sleeper(dt, x): nonlocal time_shifted - yield from tasks.sleep(dt, loop=loop) + yield from asyncio.sleep(dt, loop=loop) completed.add(x) if not time_shifted and 'a' in completed and 'b' in completed: time_shifted = True @@ -727,21 +725,21 @@ def sleeper(dt, x): b = sleeper(0.01, 'b') c = sleeper(0.15, 'c') - @tasks.coroutine + @asyncio.coroutine def foo(): values = [] - for f in tasks.as_completed([b, c, a], loop=loop): + for f in asyncio.as_completed([b, c, a], loop=loop): values.append((yield from f)) return values - res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) self.assertTrue('a' in res[:2]) self.assertTrue('b' in res[:2]) self.assertEqual(res[2], 'c') # Doing it again should take no time and exercise a different path. - res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) def test_as_completed_with_timeout(self): @@ -760,30 +758,30 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.sleep(0.1, 'a', loop=loop) - b = tasks.sleep(0.15, 'b', loop=loop) + a = asyncio.sleep(0.1, 'a', loop=loop) + b = asyncio.sleep(0.15, 'b', loop=loop) - @tasks.coroutine + @asyncio.coroutine def foo(): values = [] - for f in tasks.as_completed([a, b], timeout=0.12, loop=loop): + for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop): try: v = yield from f values.append((1, v)) - except futures.TimeoutError as exc: + except asyncio.TimeoutError as exc: values.append((2, exc)) return values - res = loop.run_until_complete(tasks.Task(foo(), loop=loop)) + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertEqual(len(res), 2, res) self.assertEqual(res[0], (1, 'a')) self.assertEqual(res[1][0], 2) - self.assertIsInstance(res[1][1], futures.TimeoutError) + self.assertIsInstance(res[1][1], asyncio.TimeoutError) self.assertAlmostEqual(0.12, loop.time()) # move forward to close generator loop.advance_time(10) - loop.run_until_complete(tasks.wait([a, b], loop=loop)) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) def test_as_completed_reverse_wait(self): @@ -795,10 +793,10 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.sleep(0.05, 'a', loop=loop) - b = tasks.sleep(0.10, 'b', loop=loop) + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.10, 'b', loop=loop) fs = {a, b} - futs = list(tasks.as_completed(fs, loop=loop)) + futs = list(asyncio.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) x = loop.run_until_complete(futs[1]) @@ -821,12 +819,12 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - a = tasks.sleep(0.05, 'a', loop=loop) - b = tasks.sleep(0.05, 'b', loop=loop) + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.05, 'b', loop=loop) fs = {a, b} - futs = list(tasks.as_completed(fs, loop=loop)) + futs = list(asyncio.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) - waiter = tasks.wait(futs, loop=loop) + waiter = asyncio.wait(futs, loop=loop) done, pending = loop.run_until_complete(waiter) self.assertEqual(set(f.result() for f in done), {'a', 'b'}) @@ -842,13 +840,13 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - @tasks.coroutine + @asyncio.coroutine def sleeper(dt, arg): - yield from tasks.sleep(dt/2, loop=loop) - res = yield from tasks.sleep(dt/2, arg, loop=loop) + yield from asyncio.sleep(dt/2, loop=loop) + res = yield from asyncio.sleep(dt/2, arg, loop=loop) return res - t = tasks.Task(sleeper(0.1, 'yeah'), loop=loop) + t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop) loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'yeah') @@ -864,7 +862,7 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - t = tasks.Task(tasks.sleep(10.0, 'yeah', loop=loop), + t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), loop=loop) handle = None @@ -898,19 +896,19 @@ def gen(): sleepfut = None - @tasks.coroutine + @asyncio.coroutine def sleep(dt): nonlocal sleepfut - sleepfut = tasks.sleep(dt, loop=loop) + sleepfut = asyncio.sleep(dt, loop=loop) yield from sleepfut - @tasks.coroutine + @asyncio.coroutine def doit(): - sleeper = tasks.Task(sleep(5000), loop=loop) + sleeper = asyncio.Task(sleep(5000), loop=loop) loop.call_later(0.1, sleeper.cancel) try: yield from sleeper - except futures.CancelledError: + except asyncio.CancelledError: return 'cancelled' else: return 'slept in' @@ -920,37 +918,37 @@ def doit(): self.assertAlmostEqual(0.1, loop.time()) def test_task_cancel_waiter_future(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def coro(): yield from fut - task = tasks.Task(coro(), loop=self.loop) + task = asyncio.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertIs(task._fut_waiter, fut) task.cancel() test_utils.run_briefly(self.loop) self.assertRaises( - futures.CancelledError, self.loop.run_until_complete, task) + asyncio.CancelledError, self.loop.run_until_complete, task) self.assertIsNone(task._fut_waiter) self.assertTrue(fut.cancelled()) def test_step_in_completed_task(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): return 'ko' gen = notmuch() - task = tasks.Task(gen, loop=self.loop) + task = asyncio.Task(gen, loop=self.loop) task.set_result('ok') self.assertRaises(AssertionError, task._step) gen.close() def test_step_result(self): - @tasks.coroutine + @asyncio.coroutine def notmuch(): yield None yield 1 @@ -962,7 +960,7 @@ def notmuch(): def test_step_result_future(self): # If coroutine returns future, task waits on this future. - class Fut(futures.Future): + class Fut(asyncio.Future): def __init__(self, *args, **kwds): self.cb_added = False super().__init__(*args, **kwds) @@ -974,12 +972,12 @@ def add_done_callback(self, fn): fut = Fut(loop=self.loop) result = None - @tasks.coroutine + @asyncio.coroutine def wait_for_future(): nonlocal result result = yield from fut - t = tasks.Task(wait_for_future(), loop=self.loop) + t = asyncio.Task(wait_for_future(), loop=self.loop) test_utils.run_briefly(self.loop) self.assertTrue(fut.cb_added) @@ -991,11 +989,11 @@ def wait_for_future(): self.assertIsNone(t.result()) def test_step_with_baseexception(self): - @tasks.coroutine + @asyncio.coroutine def notmutch(): raise BaseException() - task = tasks.Task(notmutch(), loop=self.loop) + task = asyncio.Task(notmutch(), loop=self.loop) self.assertRaises(BaseException, task._step) self.assertTrue(task.done()) @@ -1011,20 +1009,20 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - @tasks.coroutine + @asyncio.coroutine def sleeper(): - yield from tasks.sleep(10, loop=loop) + yield from asyncio.sleep(10, loop=loop) base_exc = BaseException() - @tasks.coroutine + @asyncio.coroutine def notmutch(): try: yield from sleeper() - except futures.CancelledError: + except asyncio.CancelledError: raise base_exc - task = tasks.Task(notmutch(), loop=loop) + task = asyncio.Task(notmutch(), loop=loop) test_utils.run_briefly(loop) task.cancel() @@ -1040,21 +1038,21 @@ def test_iscoroutinefunction(self): def fn(): pass - self.assertFalse(tasks.iscoroutinefunction(fn)) + self.assertFalse(asyncio.iscoroutinefunction(fn)) def fn1(): yield - self.assertFalse(tasks.iscoroutinefunction(fn1)) + self.assertFalse(asyncio.iscoroutinefunction(fn1)) - @tasks.coroutine + @asyncio.coroutine def fn2(): yield - self.assertTrue(tasks.iscoroutinefunction(fn2)) + self.assertTrue(asyncio.iscoroutinefunction(fn2)) def test_yield_vs_yield_from(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def wait_for_future(): yield fut @@ -1065,11 +1063,11 @@ def wait_for_future(): self.assertFalse(fut.done()) def test_yield_vs_yield_from_generator(self): - @tasks.coroutine + @asyncio.coroutine def coro(): yield - @tasks.coroutine + @asyncio.coroutine def wait_for_future(): gen = coro() try: @@ -1083,72 +1081,72 @@ def wait_for_future(): self.loop.run_until_complete, task) def test_coroutine_non_gen_function(self): - @tasks.coroutine + @asyncio.coroutine def func(): return 'test' - self.assertTrue(tasks.iscoroutinefunction(func)) + self.assertTrue(asyncio.iscoroutinefunction(func)) coro = func() - self.assertTrue(tasks.iscoroutine(coro)) + self.assertTrue(asyncio.iscoroutine(coro)) res = self.loop.run_until_complete(coro) self.assertEqual(res, 'test') def test_coroutine_non_gen_function_return_future(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def func(): return fut - @tasks.coroutine + @asyncio.coroutine def coro(): fut.set_result('test') - t1 = tasks.Task(func(), loop=self.loop) - t2 = tasks.Task(coro(), loop=self.loop) + t1 = asyncio.Task(func(), loop=self.loop) + t2 = asyncio.Task(coro(), loop=self.loop) res = self.loop.run_until_complete(t1) self.assertEqual(res, 'test') self.assertIsNone(t2.result()) def test_current_task(self): - self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) - @tasks.coroutine + @asyncio.coroutine def coro(loop): - self.assertTrue(tasks.Task.current_task(loop=loop) is task) + self.assertTrue(asyncio.Task.current_task(loop=loop) is task) - task = tasks.Task(coro(self.loop), loop=self.loop) + task = asyncio.Task(coro(self.loop), loop=self.loop) self.loop.run_until_complete(task) - self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) def test_current_task_with_interleaving_tasks(self): - self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) - fut1 = futures.Future(loop=self.loop) - fut2 = futures.Future(loop=self.loop) + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def coro1(loop): - self.assertTrue(tasks.Task.current_task(loop=loop) is task1) + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) yield from fut1 - self.assertTrue(tasks.Task.current_task(loop=loop) is task1) + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) fut2.set_result(True) - @tasks.coroutine + @asyncio.coroutine def coro2(loop): - self.assertTrue(tasks.Task.current_task(loop=loop) is task2) + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) fut1.set_result(True) yield from fut2 - self.assertTrue(tasks.Task.current_task(loop=loop) is task2) + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) - task1 = tasks.Task(coro1(self.loop), loop=self.loop) - task2 = tasks.Task(coro2(self.loop), loop=self.loop) + task1 = asyncio.Task(coro1(self.loop), loop=self.loop) + task2 = asyncio.Task(coro2(self.loop), loop=self.loop) - self.loop.run_until_complete(tasks.wait((task1, task2), + self.loop.run_until_complete(asyncio.wait((task1, task2), loop=self.loop)) - self.assertIsNone(tasks.Task.current_task(loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) # Some thorough tests for cancellation propagation through # coroutines, tasks and wait(). @@ -1156,30 +1154,30 @@ def coro2(loop): def test_yield_future_passes_cancel(self): # Cancelling outer() cancels inner() cancels waiter. proof = 0 - waiter = futures.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def inner(): nonlocal proof try: yield from waiter - except futures.CancelledError: + except asyncio.CancelledError: proof += 1 raise else: self.fail('got past sleep() in inner()') - @tasks.coroutine + @asyncio.coroutine def outer(): nonlocal proof try: yield from inner() - except futures.CancelledError: + except asyncio.CancelledError: proof += 100 # Expect this path. else: proof += 10 - f = tasks.async(outer(), loop=self.loop) + f = asyncio.async(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() self.loop.run_until_complete(f) @@ -1190,39 +1188,39 @@ def test_yield_wait_does_not_shield_cancel(self): # Cancelling outer() makes wait() return early, leaves inner() # running. proof = 0 - waiter = futures.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def inner(): nonlocal proof yield from waiter proof += 1 - @tasks.coroutine + @asyncio.coroutine def outer(): nonlocal proof - d, p = yield from tasks.wait([inner()], loop=self.loop) + d, p = yield from asyncio.wait([inner()], loop=self.loop) proof += 100 - f = tasks.async(outer(), loop=self.loop) + f = asyncio.async(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() self.assertRaises( - futures.CancelledError, self.loop.run_until_complete, f) + asyncio.CancelledError, self.loop.run_until_complete, f) waiter.set_result(None) test_utils.run_briefly(self.loop) self.assertEqual(proof, 1) def test_shield_result(self): - inner = futures.Future(loop=self.loop) - outer = tasks.shield(inner) + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) inner.set_result(42) res = self.loop.run_until_complete(outer) self.assertEqual(res, 42) def test_shield_exception(self): - inner = futures.Future(loop=self.loop) - outer = tasks.shield(inner) + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) test_utils.run_briefly(self.loop) exc = RuntimeError('expected') inner.set_exception(exc) @@ -1230,50 +1228,50 @@ def test_shield_exception(self): self.assertIs(outer.exception(), exc) def test_shield_cancel(self): - inner = futures.Future(loop=self.loop) - outer = tasks.shield(inner) + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) test_utils.run_briefly(self.loop) inner.cancel() test_utils.run_briefly(self.loop) self.assertTrue(outer.cancelled()) def test_shield_shortcut(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) fut.set_result(42) - res = self.loop.run_until_complete(tasks.shield(fut)) + res = self.loop.run_until_complete(asyncio.shield(fut)) self.assertEqual(res, 42) def test_shield_effect(self): # Cancelling outer() does not affect inner(). proof = 0 - waiter = futures.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) - @tasks.coroutine + @asyncio.coroutine def inner(): nonlocal proof yield from waiter proof += 1 - @tasks.coroutine + @asyncio.coroutine def outer(): nonlocal proof - yield from tasks.shield(inner(), loop=self.loop) + yield from asyncio.shield(inner(), loop=self.loop) proof += 100 - f = tasks.async(outer(), loop=self.loop) + f = asyncio.async(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(f) waiter.set_result(None) test_utils.run_briefly(self.loop) self.assertEqual(proof, 1) def test_shield_gather(self): - child1 = futures.Future(loop=self.loop) - child2 = futures.Future(loop=self.loop) - parent = tasks.gather(child1, child2, loop=self.loop) - outer = tasks.shield(parent, loop=self.loop) + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + parent = asyncio.gather(child1, child2, loop=self.loop) + outer = asyncio.shield(parent, loop=self.loop) test_utils.run_briefly(self.loop) outer.cancel() test_utils.run_briefly(self.loop) @@ -1284,16 +1282,16 @@ def test_shield_gather(self): self.assertEqual(parent.result(), [1, 2]) def test_gather_shield(self): - child1 = futures.Future(loop=self.loop) - child2 = futures.Future(loop=self.loop) - inner1 = tasks.shield(child1, loop=self.loop) - inner2 = tasks.shield(child2, loop=self.loop) - parent = tasks.gather(inner1, inner2, loop=self.loop) + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + inner1 = asyncio.shield(child1, loop=self.loop) + inner2 = asyncio.shield(child2, loop=self.loop) + parent = asyncio.gather(inner1, inner2, loop=self.loop) test_utils.run_briefly(self.loop) parent.cancel() # This should cancel inner1 and inner2 but bot child1 and child2. test_utils.run_briefly(self.loop) - self.assertIsInstance(parent.exception(), futures.CancelledError) + self.assertIsInstance(parent.exception(), asyncio.CancelledError) self.assertTrue(inner1.cancelled()) self.assertTrue(inner2.cancelled()) child1.set_result(1) @@ -1316,8 +1314,8 @@ def _run_loop(self, loop): test_utils.run_briefly(loop) def _check_success(self, **kwargs): - a, b, c = [futures.Future(loop=self.one_loop) for i in range(3)] - fut = tasks.gather(*self.wrap_futures(a, b, c), **kwargs) + a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] + fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) cb = Mock() fut.add_done_callback(cb) b.set_result(1) @@ -1338,8 +1336,8 @@ def test_result_exception_success(self): self._check_success(return_exceptions=True) def test_one_exception(self): - a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] - fut = tasks.gather(*self.wrap_futures(a, b, c, d, e)) + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) cb = Mock() fut.add_done_callback(cb) exc = ZeroDivisionError() @@ -1356,8 +1354,8 @@ def test_one_exception(self): e.exception() def test_return_exceptions(self): - a, b, c, d = [futures.Future(loop=self.one_loop) for i in range(4)] - fut = tasks.gather(*self.wrap_futures(a, b, c, d), + a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d), return_exceptions=True) cb = Mock() fut.add_done_callback(cb) @@ -1381,15 +1379,15 @@ def wrap_futures(self, *futures): return futures def _check_empty_sequence(self, seq_or_iter): - events.set_event_loop(self.one_loop) - self.addCleanup(events.set_event_loop, None) - fut = tasks.gather(*seq_or_iter) - self.assertIsInstance(fut, futures.Future) + asyncio.set_event_loop(self.one_loop) + self.addCleanup(asyncio.set_event_loop, None) + fut = asyncio.gather(*seq_or_iter) + self.assertIsInstance(fut, asyncio.Future) self.assertIs(fut._loop, self.one_loop) self._run_loop(self.one_loop) self.assertTrue(fut.done()) self.assertEqual(fut.result(), []) - fut = tasks.gather(*seq_or_iter, loop=self.other_loop) + fut = asyncio.gather(*seq_or_iter, loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) def test_constructor_empty_sequence(self): @@ -1399,27 +1397,27 @@ def test_constructor_empty_sequence(self): self._check_empty_sequence(iter("")) def test_constructor_heterogenous_futures(self): - fut1 = futures.Future(loop=self.one_loop) - fut2 = futures.Future(loop=self.other_loop) + fut1 = asyncio.Future(loop=self.one_loop) + fut2 = asyncio.Future(loop=self.other_loop) with self.assertRaises(ValueError): - tasks.gather(fut1, fut2) + asyncio.gather(fut1, fut2) with self.assertRaises(ValueError): - tasks.gather(fut1, loop=self.other_loop) + asyncio.gather(fut1, loop=self.other_loop) def test_constructor_homogenous_futures(self): - children = [futures.Future(loop=self.other_loop) for i in range(3)] - fut = tasks.gather(*children) + children = [asyncio.Future(loop=self.other_loop) for i in range(3)] + fut = asyncio.gather(*children) self.assertIs(fut._loop, self.other_loop) self._run_loop(self.other_loop) self.assertFalse(fut.done()) - fut = tasks.gather(*children, loop=self.other_loop) + fut = asyncio.gather(*children, loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) self._run_loop(self.other_loop) self.assertFalse(fut.done()) def test_one_cancellation(self): - a, b, c, d, e = [futures.Future(loop=self.one_loop) for i in range(5)] - fut = tasks.gather(a, b, c, d, e) + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(a, b, c, d, e) cb = Mock() fut.add_done_callback(cb) a.set_result(1) @@ -1428,7 +1426,7 @@ def test_one_cancellation(self): self.assertTrue(fut.done()) cb.assert_called_once_with(fut) self.assertFalse(fut.cancelled()) - self.assertIsInstance(fut.exception(), futures.CancelledError) + self.assertIsInstance(fut.exception(), asyncio.CancelledError) # Does nothing c.set_result(3) d.cancel() @@ -1436,9 +1434,9 @@ def test_one_cancellation(self): e.exception() def test_result_exception_one_cancellation(self): - a, b, c, d, e, f = [futures.Future(loop=self.one_loop) + a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) for i in range(6)] - fut = tasks.gather(a, b, c, d, e, f, return_exceptions=True) + fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) cb = Mock() fut.add_done_callback(cb) a.set_result(1) @@ -1452,8 +1450,8 @@ def test_result_exception_one_cancellation(self): rte = RuntimeError() f.set_exception(rte) res = self.one_loop.run_until_complete(fut) - self.assertIsInstance(res[2], futures.CancelledError) - self.assertIsInstance(res[4], futures.CancelledError) + self.assertIsInstance(res[2], asyncio.CancelledError) + self.assertIsInstance(res[4], asyncio.CancelledError) res[2] = res[4] = None self.assertEqual(res, [1, zde, None, 3, None, rte]) cb.assert_called_once_with(fut) @@ -1463,34 +1461,34 @@ class CoroutineGatherTests(GatherTestsBase, unittest.TestCase): def setUp(self): super().setUp() - events.set_event_loop(self.one_loop) + asyncio.set_event_loop(self.one_loop) def tearDown(self): - events.set_event_loop(None) + asyncio.set_event_loop(None) super().tearDown() def wrap_futures(self, *futures): coros = [] for fut in futures: - @tasks.coroutine + @asyncio.coroutine def coro(fut=fut): return (yield from fut) coros.append(coro()) return coros def test_constructor_loop_selection(self): - @tasks.coroutine + @asyncio.coroutine def coro(): return 'abc' gen1 = coro() gen2 = coro() - fut = tasks.gather(gen1, gen2) + fut = asyncio.gather(gen1, gen2) self.assertIs(fut._loop, self.one_loop) gen1.close() gen2.close() gen3 = coro() gen4 = coro() - fut = tasks.gather(gen3, gen4, loop=self.other_loop) + fut = asyncio.gather(gen3, gen4, loop=self.other_loop) self.assertIs(fut._loop, self.other_loop) gen3.close() gen4.close() @@ -1498,29 +1496,29 @@ def coro(): def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. proof = 0 - waiter = futures.Future(loop=self.one_loop) + waiter = asyncio.Future(loop=self.one_loop) - @tasks.coroutine + @asyncio.coroutine def inner(): nonlocal proof yield from waiter proof += 1 - child1 = tasks.async(inner(), loop=self.one_loop) - child2 = tasks.async(inner(), loop=self.one_loop) + child1 = asyncio.async(inner(), loop=self.one_loop) + child2 = asyncio.async(inner(), loop=self.one_loop) gatherer = None - @tasks.coroutine + @asyncio.coroutine def outer(): nonlocal proof, gatherer - gatherer = tasks.gather(child1, child2, loop=self.one_loop) + gatherer = asyncio.gather(child1, child2, loop=self.one_loop) yield from gatherer proof += 100 - f = tasks.async(outer(), loop=self.one_loop) + f = asyncio.async(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) self.assertTrue(f.cancel()) - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): self.one_loop.run_until_complete(f) self.assertFalse(gatherer.cancel()) self.assertTrue(waiter.cancelled()) @@ -1532,19 +1530,19 @@ def outer(): def test_exception_marking(self): # Test for the first line marked "Mark exception retrieved." - @tasks.coroutine + @asyncio.coroutine def inner(f): yield from f raise RuntimeError('should not be ignored') - a = futures.Future(loop=self.one_loop) - b = futures.Future(loop=self.one_loop) + a = asyncio.Future(loop=self.one_loop) + b = asyncio.Future(loop=self.one_loop) - @tasks.coroutine + @asyncio.coroutine def outer(): - yield from tasks.gather(inner(a), inner(b), loop=self.one_loop) + yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) - f = tasks.async(outer(), loop=self.one_loop) + f = asyncio.async(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) a.set_result(None) test_utils.run_briefly(self.one_loop) diff --git a/tests/test_transports.py b/tests/test_transports.py index 29393b52..d16db807 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -3,17 +3,17 @@ import unittest import unittest.mock -from asyncio import transports +import asyncio class TransportTests(unittest.TestCase): def test_ctor_extra_is_none(self): - transport = transports.Transport() + transport = asyncio.Transport() self.assertEqual(transport._extra, {}) def test_get_extra_info(self): - transport = transports.Transport({'extra': 'info'}) + transport = asyncio.Transport({'extra': 'info'}) self.assertEqual('info', transport.get_extra_info('extra')) self.assertIsNone(transport.get_extra_info('unknown')) @@ -21,7 +21,7 @@ def test_get_extra_info(self): self.assertIs(default, transport.get_extra_info('unknown', default)) def test_writelines(self): - transport = transports.Transport() + transport = asyncio.Transport() transport.write = unittest.mock.Mock() transport.writelines([b'line1', @@ -31,7 +31,7 @@ def test_writelines(self): transport.write.assert_called_with(b'line1line2line3') def test_not_implemented(self): - transport = transports.Transport() + transport = asyncio.Transport() self.assertRaises(NotImplementedError, transport.set_write_buffer_limits) @@ -45,13 +45,13 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, transport.abort) def test_dgram_not_implemented(self): - transport = transports.DatagramTransport() + transport = asyncio.DatagramTransport() self.assertRaises(NotImplementedError, transport.sendto, 'data') self.assertRaises(NotImplementedError, transport.abort) def test_subprocess_transport_not_implemented(self): - transport = transports.SubprocessTransport() + transport = asyncio.SubprocessTransport() self.assertRaises(NotImplementedError, transport.get_pid) self.assertRaises(NotImplementedError, transport.get_returncode) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 98cf4079..9461ec8b 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -17,9 +17,8 @@ raise unittest.SkipTest('UNIX only') -from asyncio import events -from asyncio import futures -from asyncio import protocols +import asyncio +from asyncio import log from asyncio import test_utils from asyncio import unix_events @@ -28,8 +27,8 @@ class SelectorEventLoopTests(unittest.TestCase): def setUp(self): - self.loop = unix_events.SelectorEventLoop() - events.set_event_loop(None) + self.loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(None) def tearDown(self): self.loop.close() @@ -44,7 +43,7 @@ def test_handle_signal_no_handler(self): self.loop._handle_signal(signal.NSIG + 1, ()) def test_handle_signal_cancelled_handler(self): - h = events.Handle(unittest.mock.Mock(), ()) + h = asyncio.Handle(unittest.mock.Mock(), ()) h.cancel() self.loop._signal_handlers[signal.NSIG + 1] = h self.loop.remove_signal_handler = unittest.mock.Mock() @@ -68,7 +67,7 @@ def test_add_signal_handler(self, m_signal): cb = lambda: True self.loop.add_signal_handler(signal.SIGHUP, cb) h = self.loop._signal_handlers.get(signal.SIGHUP) - self.assertIsInstance(h, events.Handle) + self.assertIsInstance(h, asyncio.Handle) self.assertEqual(h._callback, cb) @unittest.mock.patch('asyncio.unix_events.signal') @@ -205,7 +204,7 @@ class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(protocols.Protocol) + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 @@ -228,7 +227,7 @@ def test_ctor(self): self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol, fut) test_utils.run_briefly(self.loop) @@ -368,7 +367,7 @@ class UnixWritePipeTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(protocols.BaseProtocol) + self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 @@ -391,7 +390,7 @@ def test_ctor(self): self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = futures.Future(loop=self.loop) + fut = asyncio.Future(loop=self.loop) tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol, fut) self.loop.assert_reader(5, tr._read_ready) @@ -682,7 +681,7 @@ class AbstractChildWatcherTests(unittest.TestCase): def test_not_implemented(self): f = unittest.mock.Mock() - watcher = unix_events.AbstractChildWatcher() + watcher = asyncio.AbstractChildWatcher() self.assertRaises( NotImplementedError, watcher.add_child_handler, f, f) self.assertRaises( @@ -717,7 +716,7 @@ def test_not_implemented(self): class ChildWatcherTestsMixin: - ignore_warnings = unittest.mock.patch.object(unix_events.logger, "warning") + ignore_warnings = unittest.mock.patch.object(log.logger, "warning") def setUp(self): self.loop = test_utils.TestLoop() @@ -730,7 +729,7 @@ def setUp(self): self.watcher.attach_loop(self.loop) def waitpid(self, pid, flags): - if isinstance(self.watcher, unix_events.SafeChildWatcher) or pid != -1: + if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1: self.assertGreater(pid, 0) try: if pid < 0: @@ -1205,7 +1204,7 @@ def test_sigchld_unhandled_exception(self, m): # raise an exception m.waitpid.side_effect = ValueError - with unittest.mock.patch.object(unix_events.logger, + with unittest.mock.patch.object(log.logger, "exception") as m_exception: self.assertEqual(self.watcher._sig_chld(), None) @@ -1240,7 +1239,7 @@ def test_sigchld_child_reaped_elsewhere(self, m): self.watcher._sig_chld() callback.assert_called(m.waitpid) - if isinstance(self.watcher, unix_events.FastChildWatcher): + if isinstance(self.watcher, asyncio.FastChildWatcher): # here the FastChildWatche enters a deadlock # (there is no way to prevent it) self.assertFalse(callback.called) @@ -1380,7 +1379,7 @@ def test_close(self, m): self.watcher.add_child_handler(64, callback1) self.assertEqual(len(self.watcher._callbacks), 1) - if isinstance(self.watcher, unix_events.FastChildWatcher): + if isinstance(self.watcher, asyncio.FastChildWatcher): self.assertEqual(len(self.watcher._zombies), 1) with unittest.mock.patch.object( @@ -1392,31 +1391,31 @@ def test_close(self, m): m_remove_signal_handler.assert_called_once_with( signal.SIGCHLD) self.assertFalse(self.watcher._callbacks) - if isinstance(self.watcher, unix_events.FastChildWatcher): + if isinstance(self.watcher, asyncio.FastChildWatcher): self.assertFalse(self.watcher._zombies) class SafeChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): def create_watcher(self): - return unix_events.SafeChildWatcher() + return asyncio.SafeChildWatcher() class FastChildWatcherTests (ChildWatcherTestsMixin, unittest.TestCase): def create_watcher(self): - return unix_events.FastChildWatcher() + return asyncio.FastChildWatcher() class PolicyTests(unittest.TestCase): def create_policy(self): - return unix_events.DefaultEventLoopPolicy() + return asyncio.DefaultEventLoopPolicy() def test_get_child_watcher(self): policy = self.create_policy() self.assertIsNone(policy._watcher) watcher = policy.get_child_watcher() - self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) self.assertIs(policy._watcher, watcher) @@ -1425,7 +1424,7 @@ def test_get_child_watcher(self): def test_get_child_watcher_after_set(self): policy = self.create_policy() - watcher = unix_events.FastChildWatcher() + watcher = asyncio.FastChildWatcher() policy.set_child_watcher(watcher) self.assertIs(policy._watcher, watcher) @@ -1438,7 +1437,7 @@ def test_get_child_watcher_with_mainloop_existing(self): self.assertIsNone(policy._watcher) watcher = policy.get_child_watcher() - self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) self.assertIs(watcher._loop, loop) loop.close() @@ -1449,10 +1448,10 @@ def f(): policy.set_event_loop(policy.new_event_loop()) self.assertIsInstance(policy.get_event_loop(), - events.AbstractEventLoop) + asyncio.AbstractEventLoop) watcher = policy.get_child_watcher() - self.assertIsInstance(watcher, unix_events.SafeChildWatcher) + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) self.assertIsNone(watcher._loop) policy.get_event_loop().close() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 17c204a7..3c271ebe 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -8,17 +8,12 @@ import _winapi import asyncio - -from asyncio import windows_events -from asyncio import futures -from asyncio import protocols -from asyncio import streams -from asyncio import transports from asyncio import test_utils from asyncio import _overlapped +from asyncio import windows_events -class UpperProto(protocols.Protocol): +class UpperProto(asyncio.Protocol): def __init__(self): self.buf = [] @@ -35,7 +30,7 @@ def data_received(self, data): class ProactorTests(unittest.TestCase): def setUp(self): - self.loop = windows_events.ProactorEventLoop() + self.loop = asyncio.ProactorEventLoop() asyncio.set_event_loop(None) def tearDown(self): @@ -44,7 +39,7 @@ def tearDown(self): def test_close(self): a, b = self.loop._socketpair() - trans = self.loop._make_socket_transport(a, protocols.Protocol()) + trans = self.loop._make_socket_transport(a, asyncio.Protocol()) f = asyncio.async(self.loop.sock_recv(b, 100)) trans.close() self.loop.run_until_complete(f) @@ -67,7 +62,7 @@ def _test_pipe(self): with self.assertRaises(FileNotFoundError): yield from self.loop.create_pipe_connection( - protocols.Protocol, ADDRESS) + asyncio.Protocol, ADDRESS) [server] = yield from self.loop.start_serving_pipe( UpperProto, ADDRESS) @@ -75,11 +70,11 @@ def _test_pipe(self): clients = [] for i in range(5): - stream_reader = streams.StreamReader(loop=self.loop) - protocol = streams.StreamReaderProtocol(stream_reader) + stream_reader = asyncio.StreamReader(loop=self.loop) + protocol = asyncio.StreamReaderProtocol(stream_reader) trans, proto = yield from self.loop.create_pipe_connection( lambda: protocol, ADDRESS) - self.assertIsInstance(trans, transports.Transport) + self.assertIsInstance(trans, asyncio.Transport) self.assertEqual(protocol, proto) clients.append((stream_reader, trans)) @@ -95,7 +90,7 @@ def _test_pipe(self): with self.assertRaises(FileNotFoundError): yield from self.loop.create_pipe_connection( - protocols.Protocol, ADDRESS) + asyncio.Protocol, ADDRESS) return 'done' @@ -130,7 +125,7 @@ def test_wait_for_handle(self): f = self.loop._proactor.wait_for_handle(event, 10) f.cancel() start = self.loop.time() - with self.assertRaises(futures.CancelledError): + with self.assertRaises(asyncio.CancelledError): self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertTrue(0 <= elapsed < 0.1, elapsed) From 7fd0da8fa6c085e1dc613153016d591e1ecd67e9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 15:01:27 +0100 Subject: [PATCH 0877/1502] Unit tests: pick symbols from the asyncio module --- tests/test_selector_events.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 908ee5b3..4bc35952 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -16,7 +16,6 @@ import asyncio from asyncio import selectors from asyncio import test_utils -from asyncio.protocols import DatagramProtocol, Protocol from asyncio.selector_events import BaseSelectorEventLoop from asyncio.selector_events import _SelectorTransport from asyncio.selector_events import _SelectorSslTransport @@ -585,7 +584,7 @@ class SelectorTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(Protocol) + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 @@ -672,7 +671,7 @@ class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(Protocol) + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock_fd = self.sock.fileno.return_value = 7 @@ -1037,7 +1036,7 @@ class SelectorSslTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(Protocol) + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = unittest.mock.Mock(socket.socket) self.sock.fileno.return_value = 7 self.sslsock = unittest.mock.Mock() @@ -1366,7 +1365,7 @@ class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.protocol = test_utils.make_test_protocol(DatagramProtocol) + self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) self.sock = unittest.mock.Mock(spec_set=socket.socket) self.sock.fileno.return_value = 7 From 93b6ec9295a6e2b2731b410d8821faf5c99c681f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 17:00:20 +0100 Subject: [PATCH 0878/1502] _UnixWritePipeTransport now also supports character devices, as _UnixReadPipeTransport. Patch written by Jonathan Slenders. --- AUTHORS | 2 +- asyncio/unix_events.py | 8 ++++--- tests/test_events.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/AUTHORS b/AUTHORS index 263e8dc2..d7afcb79 100644 --- a/AUTHORS +++ b/AUTHORS @@ -19,7 +19,7 @@ Contributors: - Giampaolo Rodola' - Gustavo Carneiro - Jeff Quast -- Jonathan Slenders +- Jonathan Slenders - Nikolay Kim - Richard Oudkerk - Saúl Ibarra Corretgé diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 24da3274..bb0d80d6 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -259,9 +259,11 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._fileno = pipe.fileno() mode = os.fstat(self._fileno).st_mode is_socket = stat.S_ISSOCK(mode) - is_pipe = stat.S_ISFIFO(mode) - if not (is_socket or is_pipe): - raise ValueError("Pipe transport is for pipes/sockets only.") + if not (is_socket or + stat.S_ISFIFO(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is only for " + "pipes, sockets and character devices") _set_nonblocking(self._fileno) self._protocol = protocol self._buffer = [] diff --git a/tests/test_events.py b/tests/test_events.py index e49c4be5..ae0c372a 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1076,6 +1076,56 @@ def connect(): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + proto = None + transport = None + + def factory(): + nonlocal proto + proto = MyWritePipeProto(loop=self.loop) + return proto + + master, slave = os.openpty() + slave_write_obj = io.open(slave, 'wb', 0) + + @asyncio.coroutine + def connect(): + nonlocal transport + t, p = yield from self.loop.connect_write_pipe(factory, + slave_write_obj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual('CONNECTED', proto.state) + transport = t + + self.loop.run_until_complete(connect()) + + transport.write(b'1') + test_utils.run_briefly(self.loop) + data = os.read(master, 1024) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_briefly(self.loop) + data = os.read(master, 1024) + self.assertEqual(b'2345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(master) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + def test_prompt_cancellation(self): r, w = test_utils.socketpair() r.setblocking(False) From 84809364d1c7e49c2b4cb06002b45d68bc30d998 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 17:15:12 +0100 Subject: [PATCH 0879/1502] Only skip PTY tests with the kqueue selector PTY tests pass with select and poll selectors on Mac OS 10.9. --- tests/test_events.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index ae0c372a..7bd411c1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -953,9 +953,6 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") - # kqueue doesn't support character devices (PTY) on Mac OS X older - # than 10.9 (Maverick) - @support.requires_mac_ver(10, 9) def test_read_pty_output(self): proto = None @@ -1078,9 +1075,6 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") - # kqueue doesn't support character devices (PTY) on Mac OS X older - # than 10.9 (Maverick) - @support.requires_mac_ver(10, 9) def test_write_pty(self): proto = None transport = None @@ -1542,6 +1536,18 @@ def create_event_loop(self): return asyncio.SelectorEventLoop( selectors.KqueueSelector()) + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_read_pty_output(self): + super().test_read_pty_output() + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + super().test_write_pty() + if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, From 5bde189ea17b93f5c82db4ad8fb6d0cbf24dfb1c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 17:47:16 +0100 Subject: [PATCH 0880/1502] Issue #111: StreamReader.readexactly() now raises an IncompleteReadError if the end of stream is reached before we received enough bytes, instead of returning less bytes than requested. --- asyncio/streams.py | 22 ++++++++++++++++------ tests/test_streams.py | 8 ++++++-- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index b53080ef..10d3591f 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -1,7 +1,7 @@ """Stream-related things.""" __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', + 'open_connection', 'start_server', 'IncompleteReadError', ] import collections @@ -14,6 +14,19 @@ _DEFAULT_LIMIT = 2**16 +class IncompleteReadError(EOFError): + """ + Incomplete read error. Attributes: + + - partial: read bytes string before the end of stream was reached + - expected: total number of expected bytes + """ + def __init__(self, partial, expected): + EOFError.__init__(self, "%s bytes read on a total of %s expected bytes" + % (len(partial), expected)) + self.partial = partial + self.expected = expected + @tasks.coroutine def open_connection(host=None, port=None, *, @@ -403,12 +416,9 @@ def readexactly(self, n): while n > 0: block = yield from self.read(n) if not block: - break + partial = b''.join(blocks) + raise IncompleteReadError(partial, len(partial) + n) blocks.append(block) n -= len(block) - # TODO: Raise EOFError if we break before n == 0? (That would - # be a change in specification, but I've always had to add an - # explicit size check to the caller.) - return b''.join(blocks) diff --git a/tests/test_streams.py b/tests/test_streams.py index cd6dc1e4..2e4f99f8 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -300,8 +300,12 @@ def cb(): stream.feed_eof() self.loop.call_soon(cb) - data = self.loop.run_until_complete(read_task) - self.assertEqual(self.DATA, data) + with self.assertRaises(asyncio.IncompleteReadError) as cm: + self.loop.run_until_complete(read_task) + self.assertEqual(cm.exception.partial, self.DATA) + self.assertEqual(cm.exception.expected, n) + self.assertEqual(str(cm.exception), + '18 bytes read on a total of 36 expected bytes') self.assertFalse(stream._byte_count) def test_readexactly_exception(self): From 8677d4fc93ab5b3577bf832bb13500841e82e831 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jan 2014 17:52:17 +0100 Subject: [PATCH 0881/1502] AUTHORS: remove Authors/Contributors titles, only mention that Guido is the author of the project --- AUTHORS | 57 ++++++++++++++++++++++++++------------------------------- 1 file changed, 26 insertions(+), 31 deletions(-) diff --git a/AUTHORS b/AUTHORS index d7afcb79..a5892b3d 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,32 +1,27 @@ -Authors: - -- Guido van Rossum - -Contributors: - -- A. Jesse Jiryu Davis -- Aaron Griffith -- Andrew Svetlov -- Anthony Baire -- Antoine Pitrou -- Arnaud Faure -- Aymeric Augustin -- Brett Cannon -- Charles-François Natali -- Christian Heimes -- Eli Bendersky -- Geert Jansen -- Giampaolo Rodola' -- Gustavo Carneiro -- Jeff Quast -- Jonathan Slenders -- Nikolay Kim -- Richard Oudkerk -- Saúl Ibarra Corretgé -- Serhiy Storchaka -- Sonald Stufft -- Vajrasky Kok -- Victor Stinner -- Vladimir Kryachko -- Yury Selivanov +A. Jesse Jiryu Davis +Aaron Griffith +Andrew Svetlov +Anthony Baire +Antoine Pitrou +Arnaud Faure +Aymeric Augustin +Brett Cannon +Charles-François Natali +Christian Heimes +Eli Bendersky +Geert Jansen +Giampaolo Rodola' +Guido van Rossum : creator of the Tulip project and author of the PEP 3156 +Gustavo Carneiro +Jeff Quast +Jonathan Slenders +Nikolay Kim +Richard Oudkerk +Saúl Ibarra Corretgé +Serhiy Storchaka +Sonald Stufft +Vajrasky Kok +Victor Stinner +Vladimir Kryachko +Yury Selivanov From 3c84a8eb79372c3b656dbb5a022cfca94e0990cc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 14:45:44 +0100 Subject: [PATCH 0882/1502] Revert changes in selectors for PollSelector and EpollSelector: round again the timeout towards zero, instead of rounding away from zero --- asyncio/selectors.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 1bdf972c..cd8b29e4 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -8,7 +8,6 @@ from abc import ABCMeta, abstractmethod from collections import namedtuple, Mapping import functools -import math import select import sys @@ -357,9 +356,8 @@ def select(self, timeout=None): elif timeout <= 0: timeout = 0 else: - # poll() has a resolution of 1 millisecond, round away from - # zero to wait *at least* timeout seconds. - timeout = int(math.ceil(timeout * 1e3)) + # Round towards zero + timeout = int(timeout * 1000) ready = [] try: fd_event_list = self._poll.poll(timeout) @@ -415,10 +413,6 @@ def select(self, timeout=None): timeout = -1 elif timeout <= 0: timeout = 0 - else: - # epoll_wait() has a resolution of 1 millisecond, round away - # from zero to wait *at least* timeout seconds. - timeout = math.ceil(timeout * 1e3) * 1e-3 max_ev = len(self._fd_to_key) ready = [] try: From 3d3bebf7545d4d0b371c016b0b339617cb905b65 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 15:06:20 +0100 Subject: [PATCH 0883/1502] selectors: add a resolution attribute to BaseSelector --- asyncio/selectors.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index cd8b29e4..b1b530af 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -5,7 +5,7 @@ """ -from abc import ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod, abstractproperty from collections import namedtuple, Mapping import functools import select @@ -82,6 +82,11 @@ class BaseSelector(metaclass=ABCMeta): performant implementation on the current platform. """ + @abstractproperty + def resolution(self): + """Resolution of the selector in seconds""" + return None + @abstractmethod def register(self, fileobj, events, data=None): """Register a file object. @@ -283,6 +288,10 @@ def __init__(self): self._readers = set() self._writers = set() + @property + def resolution(self): + return 1e-6 + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: @@ -335,6 +344,10 @@ def __init__(self): super().__init__() self._poll = select.poll() + @property + def resolution(self): + return 1e-3 + def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) poll_events = 0 @@ -385,6 +398,10 @@ def __init__(self): super().__init__() self._epoll = select.epoll() + @property + def resolution(self): + return 1e-3 + def fileno(self): return self._epoll.fileno() @@ -445,6 +462,10 @@ def __init__(self): super().__init__() self._kqueue = select.kqueue() + @property + def resolution(self): + return 1e-9 + def fileno(self): return self._kqueue.fileno() From ce2f7a2aef906d1ebaf04215e5cb1694fc72a600 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 15:09:38 +0100 Subject: [PATCH 0884/1502] Add a granularity attribute to BaseEventLoop: maximum between the resolution of the BaseEventLoop.time() method and the resolution of the selector. The granuarility is used in the scheduler to round time and deadline. --- asyncio/base_events.py | 6 ++++++ asyncio/selector_events.py | 1 + asyncio/test_utils.py | 4 ++++ tests/test_events.py | 23 +++++++++++++++++++++++ tests/test_selector_events.py | 4 +++- tests/test_selectors.py | 4 ++++ 6 files changed, 41 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 72201aa5..d082bccf 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -18,6 +18,7 @@ import concurrent.futures import heapq import logging +import math import socket import subprocess import time @@ -96,6 +97,7 @@ def __init__(self): self._default_executor = None self._internal_fds = 0 self._running = False + self.granularity = time.get_clock_info('monotonic').resolution def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -603,6 +605,8 @@ def _run_once(self): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0]._when + # round deadline aways from zero + when = math.ceil(when / self.granularity) * self.granularity deadline = max(0, when - self.time()) if timeout is None: timeout = deadline @@ -629,6 +633,8 @@ def _run_once(self): # Handle 'later' callbacks that are ready. now = self.time() + # round current time aways from zero + now = math.ceil(now / self.granularity) * self.granularity while self._scheduled: handle = self._scheduled[0] if handle._when > now: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 19caf79d..d2b3ccc2 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -34,6 +34,7 @@ def __init__(self, selector=None): selector = selectors.DefaultSelector() logger.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector + self.granularity = max(selector.resolution, self.granularity) self._make_self_pipe() def _make_socket_transport(self, sock, protocol, waiter=None, *, diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index ccb44541..42b9cd75 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -144,6 +144,10 @@ class TestSelector(selectors.BaseSelector): def __init__(self): self.keys = {} + @property + def resolution(self): + return 1e-3 + def register(self, fileobj, events, data=None): key = selectors.SelectorKey(fileobj, 0, events, data) self.keys[fileobj] = key diff --git a/tests/test_events.py b/tests/test_events.py index 7bd411c1..bded1a3b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1156,6 +1156,29 @@ def main(): r.close() w.close() + def test_timeout_rounding(self): + def _run_once(): + self.loop._run_once_counter += 1 + orig_run_once() + + orig_run_once = self.loop._run_once + self.loop._run_once_counter = 0 + self.loop._run_once = _run_once + calls = [] + + @asyncio.coroutine + def wait(): + loop = self.loop + calls.append(loop._run_once_counter) + yield from asyncio.sleep(loop.granularity * 10, loop=loop) + calls.append(loop._run_once_counter) + yield from asyncio.sleep(loop.granularity / 10, loop=loop) + calls.append(loop._run_once_counter) + + self.loop.run_until_complete(wait()) + calls.append(self.loop._run_once_counter) + self.assertEqual(calls, [1, 3, 5, 6]) + class SubprocessTestsMixin: diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 4bc35952..4c81e75d 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -38,7 +38,9 @@ def list_to_buffer(l=()): class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): - self.loop = TestBaseSelectorEventLoop(unittest.mock.Mock()) + selector = unittest.mock.Mock() + selector.resolution = 1e-3 + self.loop = TestBaseSelectorEventLoop(selector) def test_make_socket_transport(self): m = unittest.mock.Mock() diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 0519d75a..19098dde 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -9,6 +9,10 @@ class FakeSelector(selectors._BaseSelectorImpl): """Trivial non-abstract subclass of BaseSelector.""" + @property + def resolution(self): + return 1e-3 + def select(self, timeout=None): raise NotImplementedError From 0b99e50e582f9c9a0d85850fc3a382716f214ab7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 15:30:37 +0100 Subject: [PATCH 0885/1502] Export more symbols: BaseEventLoop, BaseProactorEventLoop, BaseSelectorEventLoop, Queue and Queue sublasses, Empty, Full --- asyncio/__init__.py | 18 ++++++--- asyncio/proactor_events.py | 2 + asyncio/selector_events.py | 2 + tests/test_base_events.py | 3 +- tests/test_events.py | 3 +- tests/test_proactor_events.py | 11 +++--- tests/test_queues.py | 71 +++++++++++++++++------------------ tests/test_selector_events.py | 3 +- 8 files changed, 60 insertions(+), 53 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 0d288d5a..95235dcf 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -18,13 +18,17 @@ import _overlapped # Will also be exported. # This relies on each of the submodules having an __all__ variable. -from .futures import * +from .base_events import * from .events import * +from .futures import * from .locks import * -from .transports import * +from .proactor_events import * from .protocols import * +from .queues import * +from .selector_events import * from .streams import * from .tasks import * +from .transports import * if sys.platform == 'win32': # pragma: no cover from .windows_events import * @@ -32,10 +36,14 @@ from .unix_events import * # pragma: no cover -__all__ = (futures.__all__ + +__all__ = (base_events.__all__ + events.__all__ + + futures.__all__ + locks.__all__ + - transports.__all__ + + proactor_events.__all__ + protocols.__all__ + + queues.__all__ + + selector_events.__all__ + streams.__all__ + - tasks.__all__) + tasks.__all__ + + transports.__all__) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index ba5169e9..3b44f248 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -4,6 +4,8 @@ proactor is only implemented on Windows with IOCP. """ +__all__ = ['BaseProactorEventLoop'] + import socket from . import base_events diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d2b3ccc2..900eec01 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -4,6 +4,8 @@ also includes support for signal handling, see the unix_events sub-module. """ +__all__ = ['BaseSelectorEventLoop'] + import collections import errno import socket diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9c2bda1f..8d0796ab 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -9,7 +9,6 @@ from test.support import find_unused_port, IPV6_ENABLED import asyncio -from asyncio import base_events from asyncio import constants from asyncio import test_utils @@ -17,7 +16,7 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): - self.loop = base_events.BaseEventLoop() + self.loop = asyncio.BaseEventLoop() self.loop._selector = unittest.mock.Mock() asyncio.set_event_loop(None) diff --git a/tests/test_events.py b/tests/test_events.py index bded1a3b..44879ffa 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,7 +25,6 @@ import asyncio from asyncio import events -from asyncio import selector_events from asyncio import test_utils @@ -903,7 +902,7 @@ def datagram_received(self, data, addr): def test_internal_fds(self): loop = self.create_event_loop() - if not isinstance(loop, selector_events.BaseSelectorEventLoop): + if not isinstance(loop, asyncio.BaseSelectorEventLoop): self.skipTest('loop is not a BaseSelectorEventLoop') self.assertEqual(1, loop._internal_fds) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 9964f425..1c628000 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -5,7 +5,6 @@ import unittest.mock import asyncio -from asyncio.proactor_events import BaseProactorEventLoop from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport @@ -345,18 +344,18 @@ def setUp(self): self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() - class EventLoop(BaseProactorEventLoop): + class EventLoop(asyncio.BaseProactorEventLoop): def _socketpair(s): return (self.ssock, self.csock) self.loop = EventLoop(self.proactor) - @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') - @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, '_socketpair') def test_ctor(self, socketpair, call_soon): ssock, csock = socketpair.return_value = ( unittest.mock.Mock(), unittest.mock.Mock()) - loop = BaseProactorEventLoop(self.proactor) + loop = asyncio.BaseProactorEventLoop(self.proactor) self.assertIs(loop._ssock, ssock) self.assertIs(loop._csock, csock) self.assertEqual(loop._internal_fds, 1) @@ -399,7 +398,7 @@ def test_sock_accept(self): def test_socketpair(self): self.assertRaises( - NotImplementedError, BaseProactorEventLoop, self.proactor) + NotImplementedError, asyncio.BaseProactorEventLoop, self.proactor) def test_make_socket_transport(self): tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) diff --git a/tests/test_queues.py b/tests/test_queues.py index ccb89d72..a06ed503 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -4,7 +4,6 @@ import unittest.mock import asyncio -from asyncio import queues from asyncio import test_utils @@ -36,14 +35,14 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - q = queues.Queue(loop=loop) + q = asyncio.Queue(loop=loop) self.assertTrue(fn(q).startswith(' Date: Sat, 25 Jan 2014 15:37:15 +0100 Subject: [PATCH 0886/1502] Strip trailing space --- asyncio/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index bb0d80d6..7a6546d1 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -260,7 +260,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): mode = os.fstat(self._fileno).st_mode is_socket = stat.S_ISSOCK(mode) if not (is_socket or - stat.S_ISFIFO(mode) or + stat.S_ISFIFO(mode) or stat.S_ISCHR(mode)): raise ValueError("Pipe transport is only for " "pipes, sockets and character devices") From 6afc0420583430c723429f9ef91b6b7f3034835d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 22:02:14 +0100 Subject: [PATCH 0887/1502] Don't export BaseEventLoop, BaseSelectorEventLoop nor BaseProactorEventLoop Import them from submodules if you really need them. --- asyncio/__init__.py | 8 +------- tests/test_base_events.py | 3 ++- tests/test_events.py | 3 ++- tests/test_proactor_events.py | 11 ++++++----- tests/test_selector_events.py | 3 ++- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 95235dcf..eb22c385 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -18,14 +18,11 @@ import _overlapped # Will also be exported. # This relies on each of the submodules having an __all__ variable. -from .base_events import * from .events import * from .futures import * from .locks import * -from .proactor_events import * from .protocols import * from .queues import * -from .selector_events import * from .streams import * from .tasks import * from .transports import * @@ -36,14 +33,11 @@ from .unix_events import * # pragma: no cover -__all__ = (base_events.__all__ + - events.__all__ + +__all__ = (events.__all__ + futures.__all__ + locks.__all__ + - proactor_events.__all__ + protocols.__all__ + queues.__all__ + - selector_events.__all__ + streams.__all__ + tasks.__all__ + transports.__all__) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 8d0796ab..9c2bda1f 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -9,6 +9,7 @@ from test.support import find_unused_port, IPV6_ENABLED import asyncio +from asyncio import base_events from asyncio import constants from asyncio import test_utils @@ -16,7 +17,7 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): - self.loop = asyncio.BaseEventLoop() + self.loop = base_events.BaseEventLoop() self.loop._selector = unittest.mock.Mock() asyncio.set_event_loop(None) diff --git a/tests/test_events.py b/tests/test_events.py index 44879ffa..bded1a3b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,6 +25,7 @@ import asyncio from asyncio import events +from asyncio import selector_events from asyncio import test_utils @@ -902,7 +903,7 @@ def datagram_received(self, data, addr): def test_internal_fds(self): loop = self.create_event_loop() - if not isinstance(loop, asyncio.BaseSelectorEventLoop): + if not isinstance(loop, selector_events.BaseSelectorEventLoop): self.skipTest('loop is not a BaseSelectorEventLoop') self.assertEqual(1, loop._internal_fds) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 1c628000..9964f425 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -5,6 +5,7 @@ import unittest.mock import asyncio +from asyncio.proactor_events import BaseProactorEventLoop from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport @@ -344,18 +345,18 @@ def setUp(self): self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() - class EventLoop(asyncio.BaseProactorEventLoop): + class EventLoop(BaseProactorEventLoop): def _socketpair(s): return (self.ssock, self.csock) self.loop = EventLoop(self.proactor) - @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, 'call_soon') - @unittest.mock.patch.object(asyncio.BaseProactorEventLoop, '_socketpair') + @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') + @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') def test_ctor(self, socketpair, call_soon): ssock, csock = socketpair.return_value = ( unittest.mock.Mock(), unittest.mock.Mock()) - loop = asyncio.BaseProactorEventLoop(self.proactor) + loop = BaseProactorEventLoop(self.proactor) self.assertIs(loop._ssock, ssock) self.assertIs(loop._csock, csock) self.assertEqual(loop._internal_fds, 1) @@ -398,7 +399,7 @@ def test_sock_accept(self): def test_socketpair(self): self.assertRaises( - NotImplementedError, asyncio.BaseProactorEventLoop, self.proactor) + NotImplementedError, BaseProactorEventLoop, self.proactor) def test_make_socket_transport(self): tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 5a074aed..4c81e75d 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -16,13 +16,14 @@ import asyncio from asyncio import selectors from asyncio import test_utils +from asyncio.selector_events import BaseSelectorEventLoop from asyncio.selector_events import _SelectorTransport from asyncio.selector_events import _SelectorSslTransport from asyncio.selector_events import _SelectorSocketTransport from asyncio.selector_events import _SelectorDatagramTransport -class TestBaseSelectorEventLoop(asyncio.BaseSelectorEventLoop): +class TestBaseSelectorEventLoop(BaseSelectorEventLoop): def _make_self_pipe(self): self._ssock = unittest.mock.Mock() From ed2154902e163264ad83f8135420797070c605b2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 25 Jan 2014 23:59:41 +0100 Subject: [PATCH 0888/1502] Simplify BaseEventLoop._run_once(): avoid math.ceil(), use simple arithmetic instead --- asyncio/base_events.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index d082bccf..6b5116c7 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -18,7 +18,6 @@ import concurrent.futures import heapq import logging -import math import socket import subprocess import time @@ -605,8 +604,6 @@ def _run_once(self): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0]._when - # round deadline aways from zero - when = math.ceil(when / self.granularity) * self.granularity deadline = max(0, when - self.time()) if timeout is None: timeout = deadline @@ -632,9 +629,7 @@ def _run_once(self): self._process_events(event_list) # Handle 'later' callbacks that are ready. - now = self.time() - # round current time aways from zero - now = math.ceil(now / self.granularity) * self.granularity + now = self.time() + self.granularity while self._scheduled: handle = self._scheduled[0] if handle._when > now: From c1a64b5e2582f7463dec47b4c410f4908dc0237a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 26 Jan 2014 00:00:59 +0100 Subject: [PATCH 0889/1502] Make the new granularity attribute private --- asyncio/base_events.py | 4 ++-- asyncio/selector_events.py | 2 +- tests/test_events.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 6b5116c7..5694f296 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -96,7 +96,7 @@ def __init__(self): self._default_executor = None self._internal_fds = 0 self._running = False - self.granularity = time.get_clock_info('monotonic').resolution + self._granularity = time.get_clock_info('monotonic').resolution def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -629,7 +629,7 @@ def _run_once(self): self._process_events(event_list) # Handle 'later' callbacks that are ready. - now = self.time() + self.granularity + now = self.time() + self._granularity while self._scheduled: handle = self._scheduled[0] if handle._when > now: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 900eec01..94408f82 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -36,7 +36,7 @@ def __init__(self, selector=None): selector = selectors.DefaultSelector() logger.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector - self.granularity = max(selector.resolution, self.granularity) + self._granularity = max(selector.resolution, self._granularity) self._make_self_pipe() def _make_socket_transport(self, sock, protocol, waiter=None, *, diff --git a/tests/test_events.py b/tests/test_events.py index bded1a3b..fe5b2246 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1170,9 +1170,9 @@ def _run_once(): def wait(): loop = self.loop calls.append(loop._run_once_counter) - yield from asyncio.sleep(loop.granularity * 10, loop=loop) + yield from asyncio.sleep(loop._granularity * 10, loop=loop) calls.append(loop._run_once_counter) - yield from asyncio.sleep(loop.granularity / 10, loop=loop) + yield from asyncio.sleep(loop._granularity / 10, loop=loop) calls.append(loop._run_once_counter) self.loop.run_until_complete(wait()) From aa84fa66df20dd65c76b913a2571da752ad12336 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 25 Jan 2014 16:31:15 -0800 Subject: [PATCH 0890/1502] Fix race in FastChildWatcher (by its original author, Anthony Baire). --- asyncio/unix_events.py | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 7a6546d1..24186420 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -641,22 +641,16 @@ def __exit__(self, a, b, c): def add_child_handler(self, pid, callback, *args): assert self._forks, "Must use the context manager" + with self._lock: + try: + returncode = self._zombies.pop(pid) + except KeyError: + # The child is running. + self._callbacks[pid] = callback, args + return - self._callbacks[pid] = callback, args - - try: - # Ensure that the child is not already terminated. - # (raise KeyError if still alive) - returncode = self._zombies.pop(pid) - - # Child is dead, therefore we can fire the callback immediately. - # First we remove it from the dict. - # (raise KeyError if .remove_child_handler() was called in-between) - del self._callbacks[pid] - except KeyError: - pass - else: - callback(pid, returncode, *args) + # The child is dead already. We can fire the callback. + callback(pid, returncode, *args) def remove_child_handler(self, pid): try: @@ -681,16 +675,18 @@ def _do_waitpid_all(self): returncode = self._compute_returncode(status) - try: - callback, args = self._callbacks.pop(pid) - except KeyError: - # unknown child - with self._lock: + with self._lock: + try: + callback, args = self._callbacks.pop(pid) + except KeyError: + # unknown child if self._forks: # It may not be registered yet. self._zombies[pid] = returncode continue + callback = None + if callback is None: logger.warning( "Caught subprocess termination from unknown pid: " "%d -> %d", pid, returncode) From f82d93b460d3f32c39378dc4279a6e6c59799f94 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 25 Jan 2014 16:48:34 -0800 Subject: [PATCH 0891/1502] Locks refactor: use a separate context manager; remove Semaphore._locked. --- asyncio/locks.py | 82 +++++++++++++++++++++++++++++++++------------ tests/test_locks.py | 35 +++++++++++++++++++ 2 files changed, 95 insertions(+), 22 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 9fdb9374..99c71454 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -9,6 +9,36 @@ from . import tasks +class _ContextManager: + """Context manager. + + This enables the following idiom for acquiring and releasing a + lock around a block: + + with (yield from lock): + + + while failing loudly when accidentally using: + + with lock: + + """ + + def __init__(self, lock): + self._lock = lock + + def __enter__(self): + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + def __exit__(self, *args): + try: + self._lock.release() + finally: + self._lock = None # Crudely prevent reuse. + + class Lock: """Primitive lock objects. @@ -124,17 +154,29 @@ def release(self): raise RuntimeError('Lock is not acquired.') def __enter__(self): - if not self._locked: - raise RuntimeError( - '"yield from" should be used as context manager expression') - return True + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - self.release() + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # + # finally: + # lock.release() yield from self.acquire() - return self + return _ContextManager(self) class Event: @@ -311,14 +353,16 @@ def notify_all(self): self.notify(len(self._waiters)) def __enter__(self): - return self._lock.__enter__() + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - return self._lock.__exit__(*args) + pass def __iter__(self): + # See comment in Lock.__iter__(). yield from self.acquire() - return self + return _ContextManager(self) class Semaphore: @@ -341,7 +385,6 @@ def __init__(self, value=1, *, loop=None): raise ValueError("Semaphore initial value must be >= 0") self._value = value self._waiters = collections.deque() - self._locked = (value == 0) if loop is not None: self._loop = loop else: @@ -349,7 +392,7 @@ def __init__(self, value=1, *, loop=None): def __repr__(self): res = super().__repr__() - extra = 'locked' if self._locked else 'unlocked,value:{}'.format( + extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( self._value) if self._waiters: extra = '{},waiters:{}'.format(extra, len(self._waiters)) @@ -357,7 +400,7 @@ def __repr__(self): def locked(self): """Returns True if semaphore can not be acquired immediately.""" - return self._locked + return self._value == 0 @tasks.coroutine def acquire(self): @@ -371,8 +414,6 @@ def acquire(self): """ if not self._waiters and self._value > 0: self._value -= 1 - if self._value == 0: - self._locked = True return True fut = futures.Future(loop=self._loop) @@ -380,8 +421,6 @@ def acquire(self): try: yield from fut self._value -= 1 - if self._value == 0: - self._locked = True return True finally: self._waiters.remove(fut) @@ -392,23 +431,22 @@ def release(self): become larger than zero again, wake up that coroutine. """ self._value += 1 - self._locked = False for waiter in self._waiters: if not waiter.done(): waiter.set_result(True) break def __enter__(self): - # TODO: This is questionable. How do we know the user actually - # wrote "with (yield from sema)" instead of "with sema"? - return True + raise RuntimeError( + '"yield from" should be used as context manager expression') def __exit__(self, *args): - self.release() + pass def __iter__(self): + # See comment in Lock.__iter__(). yield from self.acquire() - return self + return _ContextManager(self) class BoundedSemaphore(Semaphore): diff --git a/tests/test_locks.py b/tests/test_locks.py index 5d0e09e0..0975f497 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -208,6 +208,24 @@ def acquire_lock(): self.assertFalse(lock.locked()) + def test_context_manager_cant_reuse(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + # This spells "yield from lock" outside a generator. + cm = self.loop.run_until_complete(acquire_lock()) + with cm: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + with self.assertRaises(AttributeError): + with cm: + pass + def test_context_manager_no_yield(self): lock = asyncio.Lock(loop=self.loop) @@ -219,6 +237,8 @@ def test_context_manager_no_yield(self): str(err), '"yield from" should be used as context manager expression') + self.assertFalse(lock.locked()) + class EventTests(unittest.TestCase): @@ -655,6 +675,8 @@ def test_context_manager_no_yield(self): str(err), '"yield from" should be used as context manager expression') + self.assertFalse(cond.locked()) + class SemaphoreTests(unittest.TestCase): @@ -830,6 +852,19 @@ def acquire_lock(): self.assertEqual(2, sem._value) + def test_context_manager_no_yield(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + try: + with sem: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertEqual(2, sem._value) + if __name__ == '__main__': unittest.main() From 3d6d6a6688d498589df5b9fb018335a272adddef Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 25 Jan 2014 16:52:40 -0800 Subject: [PATCH 0892/1502] Fix whitespace. --- asyncio/locks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 99c71454..29c4434a 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -37,7 +37,7 @@ def __exit__(self, *args): self._lock.release() finally: self._lock = None # Crudely prevent reuse. - + class Lock: """Primitive lock objects. From 0229874327d4b9335da58595802503f1015e2ed0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 25 Jan 2014 17:23:38 -0800 Subject: [PATCH 0893/1502] Rename {Empty,Full} to {QueueEmpty,QueueFull} and no longer get them from queue.py. --- asyncio/queues.py | 28 +++++++++++++++++++--------- tests/test_queues.py | 4 ++-- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index e900278f..bd62c606 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -1,11 +1,10 @@ """Queues""" __all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', - 'Full', 'Empty'] + 'QueueFull', 'QueueEmpty'] import collections import heapq -import queue from . import events from . import futures @@ -13,9 +12,20 @@ from .tasks import coroutine -# Re-export queue.Full and .Empty exceptions. -Full = queue.Full -Empty = queue.Empty +class QueueEmpty(Exception): + 'Exception raised by Queue.get(block=0)/get_nowait().' + pass + + +class QueueFull(Exception): + 'Exception raised by Queue.put(block=0)/put_nowait().' + pass + + +# Un-exported aliases for temporary backward compatibility. +# Will disappear soon. +Full = QueueFull +Empty = QueueEmpty class Queue: @@ -134,7 +144,7 @@ def put(self, item): def put_nowait(self, item): """Put an item into the queue without blocking. - If no free slot is immediately available, raise Full. + If no free slot is immediately available, raise QueueFull. """ self._consume_done_getters() if self._getters: @@ -149,7 +159,7 @@ def put_nowait(self, item): getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize == self.qsize(): - raise Full + raise QueueFull else: self._put(item) @@ -184,7 +194,7 @@ def get(self): def get_nowait(self): """Remove and return an item from the queue. - Return an item if one is immediately available, else raise Empty. + Return an item if one is immediately available, else raise QueueEmpty. """ self._consume_done_putters() if self._putters: @@ -199,7 +209,7 @@ def get_nowait(self): elif self.qsize(): return self._get() else: - raise Empty + raise QueueEmpty class PriorityQueue(Queue): diff --git a/tests/test_queues.py b/tests/test_queues.py index a06ed503..fc2bf460 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -230,7 +230,7 @@ def test_nonblocking_get(self): def test_nonblocking_get_exception(self): q = asyncio.Queue(loop=self.loop) - self.assertRaises(asyncio.Empty, q.get_nowait) + self.assertRaises(asyncio.QueueEmpty, q.get_nowait) def test_get_cancelled(self): @@ -337,7 +337,7 @@ def test_nonblocking_put(self): def test_nonblocking_put_exception(self): q = asyncio.Queue(maxsize=1, loop=self.loop) q.put_nowait(1) - self.assertRaises(asyncio.Full, q.put_nowait, 2) + self.assertRaises(asyncio.QueueFull, q.put_nowait, 2) def test_put_cancelled(self): q = asyncio.Queue(loop=self.loop) From c6983c4b41ba5d156bf78d6375f425eaa113b19b Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 26 Jan 2014 18:25:06 +0200 Subject: [PATCH 0894/1502] Code cleanup: remove unused function --- asyncio/base_subprocess.py | 1 - asyncio/unix_events.py | 3 --- asyncio/windows_events.py | 3 --- 3 files changed, 7 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index d15fb159..c5efda79 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -114,7 +114,6 @@ def _process_exited(self, returncode): assert returncode is not None, returncode assert self._returncode is None, self._returncode self._returncode = returncode - self._loop._subprocess_closed(self) self._call(self._protocol.process_exited) self._try_finish() diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 24186420..219c88a0 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -169,9 +169,6 @@ def _make_subprocess_transport(self, protocol, args, shell, def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) - def _subprocess_closed(self, transp): - pass - def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 2e9ec697..3c21e43b 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -178,9 +178,6 @@ def _make_subprocess_transport(self, protocol, args, shell, yield from transp._post_init() return transp - def _subprocess_closed(self, transport): - pass - class IocpProactor: """Proactor implementation using IOCP.""" From a5a0d8bd9d36b199d5deeb3eee9f87b1f259b2eb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 26 Jan 2014 16:05:07 -0800 Subject: [PATCH 0895/1502] The standard readexactly() now raises on a short read, so kill the custom wrapper. --- examples/crawl.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index f0c8e3d1..b7333ce1 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -411,26 +411,6 @@ def get_header(self, key, default=''): return v return default - @asyncio.coroutine - def readexactly(self, nbytes): - """Wrapper for readexactly() that raise EOFError if not enough data. - - This also logs (at the vvv level) while it is reading. - """ - blocks = [] - nread = 0 - while nread < nbytes: - self.log(3, 'reading block', len(blocks), - 'with', nbytes - nread, 'bytes remaining') - block = yield from self.reader.read(nbytes-nread) - self.log(3, 'read', len(block), 'bytes') - if not block: - raise EOFError('EOF with %d more bytes expected' % - (nbytes - nread)) - blocks.append(block) - nread += len(block) - return b''.join(blocks) - @asyncio.coroutine def read(self): """Read the response body. @@ -456,7 +436,7 @@ def read(self): size = int(parts[0], 16) if size: self.log(3, 'reading chunk of', size, 'bytes') - block = yield from self.readexactly(size) + block = yield from self.reader.readexactly(size) assert len(block) == size, (len(block), size) blocks.append(block) crlf = yield from self.reader.readline() @@ -472,7 +452,7 @@ def read(self): # TODO: Should make sure not to recycle the connection # in this case. else: - body = yield from self.readexactly(nbytes) + body = yield from self.reader.readexactly(nbytes) return body From 844462dc41b1caf6aa9b002a7d9bcf3246ae1514 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 27 Jan 2014 15:43:28 -0800 Subject: [PATCH 0896/1502] Remove temporary aliases Full/Empty. --- asyncio/queues.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index bd62c606..6283db32 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -22,12 +22,6 @@ class QueueFull(Exception): pass -# Un-exported aliases for temporary backward compatibility. -# Will disappear soon. -Full = QueueFull -Empty = QueueEmpty - - class Queue: """A queue, useful for coordinating producer and consumer coroutines. From f2353232f0c809326abf61abeedc607c846f672c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 09:18:03 -0800 Subject: [PATCH 0897/1502] Close loop in crawl.py example (mostly for IOCP). --- examples/crawl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/crawl.py b/examples/crawl.py index b7333ce1..0e99c82e 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -855,6 +855,7 @@ def main(): finally: crawler.report() crawler.close() + loop.close() if __name__ == '__main__': From 250007fd5fa2d164a3744026281757aecb98b54e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 28 Jan 2014 02:09:42 +0100 Subject: [PATCH 0898/1502] _fatal_error() of _UnixWritePipeTransport and _ProactorBasePipeTransport don't log BrokenPipeError nor ConnectionResetError Same behaviour than _SelectorTransport._fatal_error() --- asyncio/proactor_events.py | 3 ++- asyncio/unix_events.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 3b44f248..d2553eb7 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -54,7 +54,8 @@ def close(self): self._read_fut.cancel() def _fatal_error(self, exc): - logger.exception('Fatal error for %s', self) + if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + logger.exception('Fatal error for %s', self) self._force_close(exc) def _force_close(self, exc): diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 219c88a0..a1aff3f1 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -363,7 +363,8 @@ def abort(self): def _fatal_error(self, exc): # should be called by exception handler only - logger.exception('Fatal error for %s', self) + if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + logger.exception('Fatal error for %s', self) self._close(exc) def _close(self, exc=None): From 5bd0dfdadb676f47990e79089bf24ca756c301cb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 27 Jan 2014 18:22:40 -0800 Subject: [PATCH 0899/1502] Refactoring: move write flow control to a subclass/mixin. --- asyncio/selector_events.py | 98 ++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 37 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 94408f82..36901452 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -339,7 +339,67 @@ def _stop_serving(self, sock): sock.close() -class _SelectorTransport(transports.Transport): +class _FlowControlMixin(transports.Transport): + """All the logic for (write) flow control in a mix-in base class. + + The subclass must implement get_write_buffer_size(). It must call + _maybe_pause_protocol() whenever the write buffer size increases, + and _maybe_resume_protocol() whenever it decreases. It may also + override set_write_buffer_limits() (e.g. to specify different + defaults). + + The subclass constructor must call super().__init__(extra). This + will call set_write_buffer_limits(). + + The user may call set_write_buffer_limits() and + get_write_buffer_size(), and their protocol's pause_writing() and + resume_writing() may be called. + """ + + def __init__(self, extra=None): + super().__init__(extra) + self._protocol_paused = False + self.set_write_buffer_limits() + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception: + logger.exception('pause_writing() failed') + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception: + logger.exception('resume_writing() failed') + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + raise NotImplementedError + + +class _SelectorTransport(_FlowControlMixin, transports.Transport): max_size = 256 * 1024 # Buffer size passed to recv(). @@ -362,8 +422,6 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. - self._protocol_paused = False - self.set_write_buffer_limits() if self._server is not None: self._server.attach(self) @@ -410,40 +468,6 @@ def _call_connection_lost(self, exc): server.detach(self) self._server = None - def _maybe_pause_protocol(self): - size = self.get_write_buffer_size() - if size <= self._high_water: - return - if not self._protocol_paused: - self._protocol_paused = True - try: - self._protocol.pause_writing() - except Exception: - logger.exception('pause_writing() failed') - - def _maybe_resume_protocol(self): - if (self._protocol_paused and - self.get_write_buffer_size() <= self._low_water): - self._protocol_paused = False - try: - self._protocol.resume_writing() - except Exception: - logger.exception('resume_writing() failed') - - def set_write_buffer_limits(self, high=None, low=None): - if high is None: - if low is None: - high = 64*1024 - else: - high = 4*low - if low is None: - low = high // 4 - if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) - self._high_water = high - self._low_water = low - def get_write_buffer_size(self): return len(self._buffer) From 0e313f439d8c42d283e94d2c7c3be1e51dc22375 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 09:47:27 -0800 Subject: [PATCH 0900/1502] Add write flow control to unix pipes. --- asyncio/unix_events.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index a1aff3f1..05aa2721 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -246,7 +246,8 @@ def _call_connection_lost(self, exc): self._loop = None -class _UnixWritePipeTransport(transports.WriteTransport): +class _UnixWritePipeTransport(selector_events._FlowControlMixin, + transports.WriteTransport): def __init__(self, loop, pipe, protocol, waiter=None, extra=None): super().__init__(extra) @@ -277,12 +278,17 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): if waiter is not None: self._loop.call_soon(waiter.set_result, None) + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + def _read_ready(self): # Pipe was closed by peer. self._close() def write(self, data): - assert isinstance(data, bytes), repr(data) + assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) + if isinstance(data, bytearray): + data = memoryview(data) if not data: return @@ -310,6 +316,7 @@ def write(self, data): self._loop.add_writer(self._fileno, self._write_ready) self._buffer.append(data) + self._maybe_pause_protocol() def _write_ready(self): data = b''.join(self._buffer) @@ -329,7 +336,8 @@ def _write_ready(self): else: if n == len(data): self._loop.remove_writer(self._fileno) - if self._closing: + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer and self._closing: self._loop.remove_reader(self._fileno) self._call_connection_lost(None) return From 34acced42a8688d43c5a92047bbe54f4e21591ea Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 28 Jan 2014 08:37:36 -0800 Subject: [PATCH 0901/1502] Refactoring: get rid of _try_connected(). --- asyncio/base_subprocess.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index c5efda79..cc4f8cb7 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -75,19 +75,27 @@ def _post_init(self): proc = self._proc loop = self._loop if proc.stdin is not None: - transp, proto = yield from loop.connect_write_pipe( + _, pipe = yield from loop.connect_write_pipe( lambda: WriteSubprocessPipeProto(self, STDIN), proc.stdin) + self._pipes[STDIN] = pipe if proc.stdout is not None: - transp, proto = yield from loop.connect_read_pipe( + _, pipe = yield from loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, STDOUT), proc.stdout) + self._pipes[STDOUT] = pipe if proc.stderr is not None: - transp, proto = yield from loop.connect_read_pipe( + _, pipe = yield from loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, STDERR), proc.stderr) - if not self._pipes: - self._try_connected() + self._pipes[STDERR] = pipe + + assert self._pending_calls is not None + + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None def _call(self, cb, *data): if self._pending_calls is not None: @@ -95,14 +103,6 @@ def _call(self, cb, *data): else: self._loop.call_soon(cb, *data) - def _try_connected(self): - assert self._pending_calls is not None - if all(p is not None and p.connected for p in self._pipes.values()): - self._loop.call_soon(self._protocol.connection_made, self) - for callback, data in self._pending_calls: - self._loop.call_soon(callback, *data) - self._pending_calls = None - def _pipe_connection_lost(self, fd, exc): self._call(self._protocol.pipe_connection_lost, fd, exc) self._try_finish() @@ -136,19 +136,15 @@ def _call_connection_lost(self, exc): class WriteSubprocessPipeProto(protocols.BaseProtocol): - pipe = None def __init__(self, proc, fd): self.proc = proc self.fd = fd - self.connected = False + self.pipe = None self.disconnected = False - proc._pipes[fd] = self def connection_made(self, transport): - self.connected = True self.pipe = transport - self.proc._try_connected() def connection_lost(self, exc): self.disconnected = True From 7220ae3b3af11e07f36e4e2462303219e23c8231 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 09:54:06 -0800 Subject: [PATCH 0902/1502] Refactor drain logic in streams.py to be reusable. --- asyncio/streams.py | 97 +++++++++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 36 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 10d3591f..bd77cabb 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -94,8 +94,63 @@ def factory(): return (yield from loop.create_server(factory, host, port, **kwds)) -class StreamReaderProtocol(protocols.Protocol): - """Trivial helper class to adapt between Protocol and StreamReader. +class FlowControlMixin(protocols.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + + This implements the protocol methods pause_writing(), + resume_reading() and connection_lost(). If the subclass overrides + these it must call the super methods. + + StreamWriter.drain() must check for error conditions and then call + _make_drain_waiter(), which will return either () or a Future + depending on the paused state. + """ + + def __init__(self, loop=None): + self._loop = loop # May be None; we may never need it. + self._paused = False + self._drain_waiter = None + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def _make_drain_waiter(self): + if not self._paused: + return () + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._drain_waiter = waiter + return waiter + + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. (This is a helper class instead of making StreamReader itself a Protocol subclass, because the StreamReader has other potential @@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol): """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) self._stream_reader = stream_reader self._stream_writer = None - self._drain_waiter = None - self._paused = False self._client_connected_cb = client_connected_cb - self._loop = loop # May be None; we may never need it. def connection_made(self, transport): self._stream_reader.set_transport(transport) @@ -127,16 +180,7 @@ def connection_lost(self, exc): self._stream_reader.feed_eof() else: self._stream_reader.set_exception(exc) - # Also wake up the writing side. - if self._paused: - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - if exc is None: - waiter.set_result(None) - else: - waiter.set_exception(exc) + super().connection_lost(exc) def data_received(self, data): self._stream_reader.feed_data(data) @@ -144,19 +188,6 @@ def data_received(self, data): def eof_received(self): self._stream_reader.feed_eof() - def pause_writing(self): - assert not self._paused - self._paused = True - - def resume_writing(self): - assert self._paused - self._paused = False - waiter = self._drain_waiter - if waiter is not None: - self._drain_waiter = None - if not waiter.done(): - waiter.set_result(None) - class StreamWriter: """Wraps a Transport. @@ -211,17 +242,11 @@ def drain(self): completed, which will happen when the buffer is (partially) drained and the protocol is resumed. """ - if self._reader._exception is not None: + if self._reader is not None and self._reader._exception is not None: raise self._reader._exception if self._transport._conn_lost: # Uses private variable. raise ConnectionResetError('Connection lost') - if not self._protocol._paused: - return () - waiter = self._protocol._drain_waiter - assert waiter is None or waiter.cancelled() - waiter = futures.Future(loop=self._loop) - self._protocol._drain_waiter = waiter - return waiter + return self._protocol._make_drain_waiter() class StreamReader: From 30986ed26d9ed038a763387c398c010aed070e6d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 08:34:19 -0800 Subject: [PATCH 0903/1502] pass through pause/resume from subprocess pipe proto to subprocess proto. --- asyncio/base_subprocess.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index cc4f8cb7..b7cdbcef 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -150,8 +150,11 @@ def connection_lost(self, exc): self.disconnected = True self.proc._pipe_connection_lost(self.fd, exc) - def eof_received(self): - pass + def pause_writing(self): + self.proc._protocol.pause_writing() + + def resume_writing(self): + self.proc._protocol.resume_writing() class ReadSubprocessPipeProto(WriteSubprocessPipeProto, From eaa0eb342628d4a9ea7d5bd803f9000c8309d28d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 28 Jan 2014 23:40:51 +0100 Subject: [PATCH 0904/1502] examples: close the event loop --- examples/cacheclt.py | 11 +++++++---- examples/simple_tcp_server.py | 7 +++++-- examples/tcp_echo.py | 5 ++++- examples/timing_tcp_server.py | 7 +++++-- examples/udp_echo.py | 5 ++++- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/examples/cacheclt.py b/examples/cacheclt.py index cd8da070..d1891889 100644 --- a/examples/cacheclt.py +++ b/examples/cacheclt.py @@ -175,10 +175,13 @@ def main(): if args.tls: sslctx = test_utils.dummy_ssl_context() cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop) - loop.run_until_complete( - asyncio.gather( - *[testing(i, cache, loop) for i in range(args.ntasks)], - loop=loop)) + try: + loop.run_until_complete( + asyncio.gather( + *[testing(i, cache, loop) for i in range(args.ntasks)], + loop=loop)) + finally: + loop.close() @asyncio.coroutine diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py index c36710d2..0e87d5b6 100644 --- a/examples/simple_tcp_server.py +++ b/examples/simple_tcp_server.py @@ -143,8 +143,11 @@ def recv(): yield from asyncio.sleep(0.5) # creates a client and connects to our server - msg = loop.run_until_complete(client()) - server.stop(loop) + try: + msg = loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() if __name__ == '__main__': diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 3d9b39c8..45ece157 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -110,4 +110,7 @@ def start_server(loop, host, port): else: start_client(loop, args.host, args.port) - loop.run_forever() + try: + loop.run_forever() + finally: + loop.close() diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index 6f0483eb..cb43a796 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -157,8 +157,11 @@ def recv(): yield from asyncio.sleep(0.5) # creates a client and connects to our server - msg = loop.run_until_complete(client()) - server.stop(loop) + try: + msg = loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() if __name__ == '__main__': diff --git a/examples/udp_echo.py b/examples/udp_echo.py index e958f385..899728bd 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -95,4 +95,7 @@ def start_client(loop, addr): else: start_client(loop, (args.host, args.port)) - loop.run_forever() + try: + loop.run_forever() + finally: + loop.close() From bd1e249ab0c97598bf0b56dfec3afdfbe9b48073 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 28 Jan 2014 23:57:15 +0100 Subject: [PATCH 0905/1502] Fix ResourceWarning in tcp and udp echo examples --- examples/tcp_echo.py | 7 ++++--- examples/udp_echo.py | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 45ece157..3c08b15d 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -71,8 +71,7 @@ def start_client(loop, host, port): def start_server(loop, host, port): f = loop.create_server(EchoServer, host, port) - s = loop.run_until_complete(f) - print('serving on', s.sockets[0].getsockname()) + return loop.run_until_complete(f) ARGS = argparse.ArgumentParser(description="TCP Echo example.") @@ -106,11 +105,13 @@ def start_server(loop, host, port): loop.add_signal_handler(signal.SIGINT, loop.stop) if args.server: - start_server(loop, args.host, args.port) + server = start_server(loop, args.host, args.port) else: start_client(loop, args.host, args.port) try: loop.run_forever() finally: + if args.server: + server.close() loop.close() diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 899728bd..93ac7e6b 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -52,7 +52,8 @@ def connection_lost(self, exc): def start_server(loop, addr): t = asyncio.Task(loop.create_datagram_endpoint( MyServerUdpEchoProtocol, local_addr=addr)) - loop.run_until_complete(t) + transport, server = loop.run_until_complete(t) + return transport def start_client(loop, addr): @@ -91,11 +92,13 @@ def start_client(loop, addr): loop.add_signal_handler(signal.SIGINT, loop.stop) if '--server' in sys.argv: - start_server(loop, (args.host, args.port)) + server = start_server(loop, (args.host, args.port)) else: start_client(loop, (args.host, args.port)) try: loop.run_forever() finally: + if '--server' in sys.argv: + server.close() loop.close() From 6681086fed6367672b4058dc0fe604ee28b113a3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 28 Jan 2014 23:59:56 +0100 Subject: [PATCH 0906/1502] wait_for() now accepts None as timeout --- asyncio/tasks.py | 5 ++++- tests/test_tasks.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index b52933fc..d04bdc79 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -394,11 +394,14 @@ def wait_for(fut, timeout, *, loop=None): if loop is None: loop = events.get_event_loop() + fut = async(fut, loop=loop) + if timeout is None: + return (yield from fut) + waiter = futures.Future(loop=loop) timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) cb = functools.partial(_release_waiter, waiter, True) - fut = async(fut, loop=loop) fut.add_done_callback(cb) try: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index dbf130c1..778b6e0d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -380,6 +380,17 @@ def foo(): self.assertEqual(foo_running, False) + def test_wait_for_blocking(self): + loop = test_utils.TestLoop() + self.addCleanup(loop.close) + + @asyncio.coroutine + def coro(): + return 'done' + + res = loop.run_until_complete(asyncio.wait_for(coro(), timeout=None, loop=loop)) + self.assertEqual(res, 'done') + def test_wait_for_with_global_loop(self): def gen(): From 8404ad07ddc67d4eda3a6e4fad39c4addd0d6078 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 28 Jan 2014 15:53:08 -0800 Subject: [PATCH 0907/1502] Move async() call back to its original position. Issue 117. --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index d04bdc79..38ffec16 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -394,7 +394,6 @@ def wait_for(fut, timeout, *, loop=None): if loop is None: loop = events.get_event_loop() - fut = async(fut, loop=loop) if timeout is None: return (yield from fut) @@ -402,6 +401,7 @@ def wait_for(fut, timeout, *, loop=None): timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) cb = functools.partial(_release_waiter, waiter, True) + fut = async(fut, loop=loop) fut.add_done_callback(cb) try: From 0860f3cf8dd5ae02cca7e23ddd0050c06bb629a6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 29 Jan 2014 01:26:18 +0100 Subject: [PATCH 0908/1502] Fix _make_subprocess_transport(): pass extra value to the constructor --- asyncio/unix_events.py | 2 +- asyncio/windows_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 05aa2721..ac764f8a 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -159,7 +159,7 @@ def _make_subprocess_transport(self, protocol, args, shell, with events.get_child_watcher() as watcher: transp = _UnixSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, - extra=None, **kwargs) + extra=extra, **kwargs) yield from transp._post_init() watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 3c21e43b..d01de2f7 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -174,7 +174,7 @@ def _make_subprocess_transport(self, protocol, args, shell, extra=None, **kwargs): transp = _WindowsSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, - extra=None, **kwargs) + extra=extra, **kwargs) yield from transp._post_init() return transp From 9176f5f1764f2d55be5709b7069fef0b483bfdca Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 29 Jan 2014 23:15:12 +0100 Subject: [PATCH 0909/1502] subprocess_shell() and subprocess_exec() methods of BaseEventLoop now raises a ValueError instead of raising an AssertionError. Moreover, bufsize different than 0 is now considered as an error. --- asyncio/base_events.py | 19 ++++++++++++++----- tests/test_events.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5694f296..58c3520e 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -552,9 +552,14 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=True, bufsize=0, **kwargs): - assert not universal_newlines, "universal_newlines must be False" - assert shell, "shell must be True" - assert isinstance(cmd, str), cmd + if not isinstance(cmd, str): + raise ValueError("cmd must be a string") + if universal_newlines: + raise ValueError("universal_newlines must be False") + if not shell: + raise ValueError("shell must be False") + if bufsize != 0: + raise ValueError("bufsize must be 0") protocol = protocol_factory() transport = yield from self._make_subprocess_transport( protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) @@ -565,8 +570,12 @@ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=False, bufsize=0, **kwargs): - assert not universal_newlines, "universal_newlines must be False" - assert not shell, "shell must be False" + if universal_newlines: + raise ValueError("universal_newlines must be False") + if shell: + raise ValueError("shell must be False") + if bufsize != 0: + raise ValueError("bufsize must be 0") protocol = protocol_factory() transport = yield from self._make_subprocess_transport( protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) diff --git a/tests/test_events.py b/tests/test_events.py index fe5b2246..24808cb1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1491,6 +1491,38 @@ def connect(): self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) + def test_subprocess_exec_invalid_args(self): + @asyncio.coroutine + def connect(**kwds): + yield from self.loop.subprocess_exec( + asyncio.SubprocessProtocol, + 'pwd', **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=True)) + + def test_subprocess_shell_invalid_args(self): + @asyncio.coroutine + def connect(cmd=None, **kwds): + if not cmd: + cmd = 'pwd' + yield from self.loop.subprocess_shell( + asyncio.SubprocessProtocol, + cmd, **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(['ls', '-l'])) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=False)) + if sys.platform == 'win32': From 1c7da09323018c8ded395062f3157e86f5f79a2d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 30 Jan 2014 01:01:21 +0100 Subject: [PATCH 0910/1502] Fix _UnixWritePipeTransport: raise BrokenPipeError when the pipe is closed --- asyncio/unix_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index ac764f8a..98fdddee 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -283,7 +283,10 @@ def get_write_buffer_size(self): def _read_ready(self): # Pipe was closed by peer. - self._close() + if self._buffer: + self._close(BrokenPipeError()) + else: + self._close() def write(self, data): assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) From a0036120e5d24d49773c093137bc06733ede0369 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 30 Jan 2014 15:32:54 +0100 Subject: [PATCH 0911/1502] Fix granularity of test_utils.TestLoop --- asyncio/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 42b9cd75..fed28d7d 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -196,6 +196,7 @@ def gen(): next(self._gen) self._time = 0 self._timers = [] + self._granularity = 1e-9 self._selector = TestSelector() self.readers = {} From 4118e36ddec220cc9d9b0aac0d918f5ab1816517 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 30 Jan 2014 18:54:07 +0100 Subject: [PATCH 0912/1502] Future.set_exception(exc) now instanciate exc if it is a class For example, Future.set_exception(RuntimeError) is now allowed. --- asyncio/futures.py | 2 ++ tests/test_futures.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/asyncio/futures.py b/asyncio/futures.py index 9ee13e3e..d09f423c 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -301,6 +301,8 @@ def set_exception(self, exception): """ if self._state != _PENDING: raise InvalidStateError('{}: {!r}'.format(self._state, self)) + if isinstance(exception, type): + exception = exception() self._exception = exception self._state = _FINISHED self._schedule_callbacks() diff --git a/tests/test_futures.py b/tests/test_futures.py index d3a74125..8a6976b1 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -79,6 +79,11 @@ def test_exception(self): self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) self.assertFalse(f.cancel()) + def test_exception_class(self): + f = asyncio.Future(loop=self.loop) + f.set_exception(RuntimeError) + self.assertIsInstance(f.exception(), RuntimeError) + def test_yield_from_twice(self): f = asyncio.Future(loop=self.loop) From ee66957ff72373f5f3881a3f3846b634ad047ea2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 30 Jan 2014 15:11:54 +0100 Subject: [PATCH 0913/1502] overlapped.c: Fix usage of the union * read_buffer can only be used for TYPE_READ and TYPE_ACCEPT types * write_buffer can only be used for TYPE_WRITE type --- overlapped.c | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/overlapped.c b/overlapped.c index 625c76ef..6842efbb 100644 --- a/overlapped.c +++ b/overlapped.c @@ -45,9 +45,9 @@ typedef struct { /* Type of operation */ DWORD type; union { - /* Buffer used for reading (optional) */ + /* Buffer used for reading: TYPE_READ and TYPE_ACCEPT */ PyObject *read_buffer; - /* Buffer used for writing (optional) */ + /* Buffer used for writing: TYPE_WRITE */ Py_buffer write_buffer; }; } OverlappedObject; @@ -568,13 +568,15 @@ Overlapped_dealloc(OverlappedObject *self) if (self->overlapped.hEvent != NULL) CloseHandle(self->overlapped.hEvent); - if (self->write_buffer.obj) - PyBuffer_Release(&self->write_buffer); - switch (self->type) { - case TYPE_READ: - case TYPE_ACCEPT: - Py_CLEAR(self->read_buffer); + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + break; + case TYPE_WRITE: + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + break; } PyObject_Del(self); SetLastError(olderr); @@ -648,7 +650,7 @@ Overlapped_getresult(OverlappedObject *self, PyObject *args) case ERROR_MORE_DATA: break; case ERROR_BROKEN_PIPE: - if (self->read_buffer != NULL) + if ((self->type == TYPE_READ || self->type == TYPE_ACCEPT) && self->read_buffer != NULL) break; /* fall through */ default: From f2d2f7e482d51e143670d0f830cfdc057e0f952a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 30 Jan 2014 09:55:28 -0800 Subject: [PATCH 0914/1502] Normalize whitespace (use "make pep8" to verify). --- asyncio/streams.py | 1 + tests/test_streams.py | 4 ++-- tests/test_tasks.py | 23 ++++++++++++----------- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index bd77cabb..06f052a2 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -14,6 +14,7 @@ _DEFAULT_LIMIT = 2**16 + class IncompleteReadError(EOFError): """ Incomplete read error. Attributes: diff --git a/tests/test_streams.py b/tests/test_streams.py index 2e4f99f8..01d565cd 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -415,7 +415,7 @@ def client(): server = MyServer(self.loop) server.start() msg = self.loop.run_until_complete(asyncio.Task(client(), - loop=self.loop)) + loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") @@ -423,7 +423,7 @@ def client(): server = MyServer(self.loop) server.start_callback() msg = self.loop.run_until_complete(asyncio.Task(client(), - loop=self.loop)) + loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 778b6e0d..f54a0a06 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -379,7 +379,6 @@ def foo(): self.assertAlmostEqual(0.1, loop.time()) self.assertEqual(foo_running, False) - def test_wait_for_blocking(self): loop = test_utils.TestLoop() self.addCleanup(loop.close) @@ -388,7 +387,9 @@ def test_wait_for_blocking(self): def coro(): return 'done' - res = loop.run_until_complete(asyncio.wait_for(coro(), timeout=None, loop=loop)) + res = loop.run_until_complete(asyncio.wait_for(coro(), + timeout=None, + loop=loop)) self.assertEqual(res, 'done') def test_wait_for_with_global_loop(self): @@ -490,7 +491,7 @@ def test_wait_errors(self): self.assertRaises( ValueError, self.loop.run_until_complete, asyncio.wait([asyncio.sleep(10.0, loop=self.loop)], - return_when=-1, loop=self.loop)) + return_when=-1, loop=self.loop)) def test_wait_first_completed(self): @@ -508,7 +509,7 @@ def gen(): b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) task = asyncio.Task( asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, - loop=loop), + loop=loop), loop=loop) done, pending = loop.run_until_complete(task) @@ -540,7 +541,7 @@ def coro2(): b = asyncio.Task(coro2(), loop=self.loop) task = asyncio.Task( asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, - loop=self.loop), + loop=self.loop), loop=self.loop) done, pending = self.loop.run_until_complete(task) @@ -570,7 +571,7 @@ def exc(): b = asyncio.Task(exc(), loop=loop) task = asyncio.Task( asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, - loop=loop), + loop=loop), loop=loop) done, pending = loop.run_until_complete(task) @@ -604,7 +605,7 @@ def exc(): b = asyncio.Task(exc(), loop=loop) task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, - loop=loop) + loop=loop) done, pending = loop.run_until_complete(task) self.assertEqual({b}, done) @@ -670,7 +671,7 @@ def gen(): @asyncio.coroutine def foo(): done, pending = yield from asyncio.wait([b, a], timeout=0.11, - loop=loop) + loop=loop) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) @@ -874,7 +875,7 @@ def gen(): self.addCleanup(loop.close) t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), - loop=loop) + loop=loop) handle = None orig_call_later = loop.call_later @@ -1156,7 +1157,7 @@ def coro2(loop): task2 = asyncio.Task(coro2(self.loop), loop=self.loop) self.loop.run_until_complete(asyncio.wait((task1, task2), - loop=self.loop)) + loop=self.loop)) self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) # Some thorough tests for cancellation propagation through @@ -1367,7 +1368,7 @@ def test_one_exception(self): def test_return_exceptions(self): a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] fut = asyncio.gather(*self.wrap_futures(a, b, c, d), - return_exceptions=True) + return_exceptions=True) cb = Mock() fut.add_done_callback(cb) exc = ZeroDivisionError() From 67095d6679b67b9414fa53a33a571d0d3dc538c9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 30 Jan 2014 13:07:06 -0800 Subject: [PATCH 0915/1502] Minor packaging tweak by ?ric Araujo. Fixes issue #116. --- README | 6 +++--- setup.cfg | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) create mode 100644 setup.cfg diff --git a/README b/README index 34812e08..1150bafa 100644 --- a/README +++ b/README @@ -20,10 +20,10 @@ To run coverage (coverage package is required): On Windows, things are a little more complicated. Assume 'P' is your Python binary (for example C:\Python33\python.exe). -You must first build the _overlapped.pyd extension and have it placed -in the asyncio directory, as follows: +You must first build the _overlapped.pyd extension (it will be placed +in the asyncio directory): - C> P setup.py build_ext --inplace + C> P setup.py build_ext Then you can run the tests as follows: diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..da2a775d --- /dev/null +++ b/setup.cfg @@ -0,0 +1,2 @@ +[build_ext] +inplace = 1 From e92a5ff1b028c8f81a61f5588edc7fbe71b34866 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 31 Jan 2014 12:29:43 +0100 Subject: [PATCH 0916/1502] asyncio: Fix error message in BaseEventLoop.subprocess_shell(). Patch written by Vajrasky Kok. --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 58c3520e..cafd10a0 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -557,7 +557,7 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, if universal_newlines: raise ValueError("universal_newlines must be False") if not shell: - raise ValueError("shell must be False") + raise ValueError("shell must be True") if bufsize != 0: raise ValueError("bufsize must be 0") protocol = protocol_factory() From 0307f1a9fabe86e029c130e2df0b624dc4145118 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 31 Jan 2014 13:06:25 +0100 Subject: [PATCH 0917/1502] selectors: round (again) timeout away from zero for poll and epoll --- asyncio/selectors.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index b1b530af..52ee8dbd 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -8,6 +8,7 @@ from abc import ABCMeta, abstractmethod, abstractproperty from collections import namedtuple, Mapping import functools +import math import select import sys @@ -369,8 +370,9 @@ def select(self, timeout=None): elif timeout <= 0: timeout = 0 else: - # Round towards zero - timeout = int(timeout * 1000) + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = int(math.ceil(timeout * 1e3)) ready = [] try: fd_event_list = self._poll.poll(timeout) @@ -430,6 +432,10 @@ def select(self, timeout=None): timeout = -1 elif timeout <= 0: timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 max_ev = len(self._fd_to_key) ready = [] try: From ecc2418ae8ab5d237f1d0b521ed9eacff3743a64 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 31 Jan 2014 13:19:24 -0800 Subject: [PATCH 0918/1502] Copy a bunch of fixes by Victor for the Proactor event loop from the CPython repo. --- asyncio/proactor_events.py | 39 ++++++++++++++++++++++++----------- asyncio/selectors.py | 2 +- asyncio/windows_events.py | 14 +++++++++---- tests/test_base_events.py | 9 ++++++-- tests/test_proactor_events.py | 2 ++ 5 files changed, 47 insertions(+), 19 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index d2553eb7..b6b3be2d 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -205,7 +205,7 @@ def _loop_reading(self, fut=None): self.close() -class _ProactorWritePipeTransport(_ProactorBasePipeTransport, +class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" @@ -286,8 +286,27 @@ def abort(self): self._force_close(None) +class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._read_fut = self._loop._proactor.recv(self._sock, 16) + self._read_fut.add_done_callback(self._pipe_closed) + + def _pipe_closed(self, fut): + if fut.cancelled(): + # the transport has been closed + return + assert fut is self._read_fut, (fut, self._read_fut) + self._read_fut = None + assert fut.result() == b'' + if self._write_fut is not None: + self._force_close(exc) + else: + self.close() + + class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, - _ProactorWritePipeTransport, + _ProactorBaseWritePipeTransport, transports.Transport): """Transport for duplex pipes.""" @@ -299,7 +318,7 @@ def write_eof(self): class _ProactorSocketTransport(_ProactorReadPipeTransport, - _ProactorWritePipeTransport, + _ProactorBaseWritePipeTransport, transports.Transport): """Transport for connected sockets.""" @@ -335,6 +354,7 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._self_reading_future = None self._accept_futures = {} # socket file descriptor => Future + self._granularity = max(proactor.resolution, self._granularity) proactor.set_loop(self) self._make_self_pipe() @@ -353,15 +373,10 @@ def _make_read_pipe_transport(self, sock, protocol, waiter=None, return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) def _make_write_pipe_transport(self, sock, protocol, waiter=None, - extra=None, check_for_hangup=True): - if check_for_hangup: - # We want connection_lost() to be called when other end closes - return _ProactorDuplexPipeTransport(self, - sock, protocol, waiter, extra) - else: - # If other end closes we may not notice for a long time - return _ProactorWritePipeTransport(self, sock, protocol, waiter, - extra) + extra=None): + # We want connection_lost() to be called when other end closes + return _ProactorWritePipeTransport(self, + sock, protocol, waiter, extra) def close(self): if self._proactor is not None: diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 52ee8dbd..056e45c2 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -372,7 +372,7 @@ def select(self, timeout=None): else: # poll() has a resolution of 1 millisecond, round away from # zero to wait *at least* timeout seconds. - timeout = int(math.ceil(timeout * 1e3)) + timeout = math.ceil(timeout * 1e3) ready = [] try: fd_event_list = self._poll.poll(timeout) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index d01de2f7..b8574fa0 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -1,11 +1,12 @@ """Selector and proactor eventloops for Windows.""" +import _winapi import errno +import math import socket +import struct import subprocess import weakref -import struct -import _winapi from . import events from . import base_subprocess @@ -190,6 +191,7 @@ def __init__(self, concurrency=0xffffffff): self._cache = {} self._registered = weakref.WeakSet() self._stopped_serving = weakref.WeakSet() + self.resolution = 1e-3 def set_loop(self, loop): self._loop = loop @@ -325,7 +327,9 @@ def wait_for_handle(self, handle, timeout=None): if timeout is None: ms = _winapi.INFINITE else: - ms = int(timeout * 1000 + 0.5) + # RegisterWaitForSingleObject() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) # We only create ov so we can use ov.address as a key for the cache. ov = _overlapped.Overlapped(NULL) @@ -396,7 +400,9 @@ def _poll(self, timeout=None): elif timeout < 0: raise ValueError("negative timeout") else: - ms = int(timeout * 1000 + 0.5) + # GetQueuedCompletionStatus() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) if ms >= INFINITE: raise ValueError("timeout too big") while True: diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9c2bda1f..72f5c8a0 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -120,8 +120,13 @@ def cb(): self.loop.call_at(when, cb) t0 = self.loop.time() self.loop.run_forever() - t1 = self.loop.time() - self.assertTrue(0.09 <= t1-t0 <= 0.9, t1-t0) + dt = self.loop.time() - t0 + self.assertTrue(0.09 <= dt <= 0.9, + # Issue #20452: add more info in case of failure, + # to try to investigate the bug + (dt, + self.loop._granularity, + time.get_clock_info('monotonic'))) def test_run_once_in_executor_handle(self): def cb(): diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 9964f425..98abe696 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -17,6 +17,7 @@ class ProactorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.proactor = unittest.mock.Mock() + self.proactor.resolution = 1e-3 self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = unittest.mock.Mock(socket.socket) @@ -342,6 +343,7 @@ class BaseProactorEventLoopTests(unittest.TestCase): def setUp(self): self.sock = unittest.mock.Mock(socket.socket) self.proactor = unittest.mock.Mock() + self.proactor.resolution = 1e-3 self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() From 1e4d71fc98687e47dfc6b031b6822d86506b87b2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 1 Feb 2014 22:46:32 +0100 Subject: [PATCH 0919/1502] Merge (manually) the subprocess_stream into default * Add a new asyncio.subprocess module * Add new create_subprocess_exec() and create_subprocess_shell() functions * The new asyncio.subprocess.SubprocessStreamProtocol creates stream readers for stdout and stderr and a stream writer for stdin. * The new asyncio.subprocess.Process class offers an API close to the subprocess.Popen class: - pid, returncode, stdin, stdout and stderr attributes - communicate(), wait(), send_signal(), terminate() and kill() methods * Remove STDIN (0), STDOUT (1) and STDERR (2) constants from base_subprocess and unix_events, to not be confused with the symbols with the same name of subprocess and asyncio.subprocess modules * _ProactorBasePipeTransport.get_write_buffer_size() now counts also the size of the pending write * _ProactorBaseWritePipeTransport._loop_writing() may now pause the protocol if the write buffer size is greater than the high water mark (64 KB by default) * Add new subprocess examples: shell.py, subprocess_shell.py, * subprocess_attach_read_pipe.py and subprocess_attach_write_pipe.py --- asyncio/__init__.py | 2 + asyncio/base_subprocess.py | 23 ++- asyncio/proactor_events.py | 34 ++-- asyncio/subprocess.py | 198 +++++++++++++++++++++++ asyncio/unix_events.py | 7 +- examples/child_process.py | 4 +- examples/shell.py | 50 ++++++ examples/subprocess_attach_read_pipe.py | 33 ++++ examples/subprocess_attach_write_pipe.py | 33 ++++ examples/subprocess_shell.py | 85 ++++++++++ tests/test_subprocess.py | 196 ++++++++++++++++++++++ 11 files changed, 631 insertions(+), 34 deletions(-) create mode 100644 asyncio/subprocess.py create mode 100644 examples/shell.py create mode 100644 examples/subprocess_attach_read_pipe.py create mode 100644 examples/subprocess_attach_write_pipe.py create mode 100644 examples/subprocess_shell.py create mode 100644 tests/test_subprocess.py diff --git a/asyncio/__init__.py b/asyncio/__init__.py index eb22c385..3df2f803 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -24,6 +24,7 @@ from .protocols import * from .queues import * from .streams import * +from .subprocess import * from .tasks import * from .transports import * @@ -39,5 +40,6 @@ protocols.__all__ + queues.__all__ + streams.__all__ + + subprocess.__all__ + tasks.__all__ + transports.__all__) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index b7cdbcef..b78f816d 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -6,11 +6,6 @@ from . import transports -STDIN = 0 -STDOUT = 1 -STDERR = 2 - - class BaseSubprocessTransport(transports.SubprocessTransport): def __init__(self, loop, protocol, args, shell, @@ -22,11 +17,11 @@ def __init__(self, loop, protocol, args, shell, self._pipes = {} if stdin == subprocess.PIPE: - self._pipes[STDIN] = None + self._pipes[0] = None if stdout == subprocess.PIPE: - self._pipes[STDOUT] = None + self._pipes[1] = None if stderr == subprocess.PIPE: - self._pipes[STDERR] = None + self._pipes[2] = None self._pending_calls = collections.deque() self._finished = False self._returncode = None @@ -76,19 +71,19 @@ def _post_init(self): loop = self._loop if proc.stdin is not None: _, pipe = yield from loop.connect_write_pipe( - lambda: WriteSubprocessPipeProto(self, STDIN), + lambda: WriteSubprocessPipeProto(self, 0), proc.stdin) - self._pipes[STDIN] = pipe + self._pipes[0] = pipe if proc.stdout is not None: _, pipe = yield from loop.connect_read_pipe( - lambda: ReadSubprocessPipeProto(self, STDOUT), + lambda: ReadSubprocessPipeProto(self, 1), proc.stdout) - self._pipes[STDOUT] = pipe + self._pipes[1] = pipe if proc.stderr is not None: _, pipe = yield from loop.connect_read_pipe( - lambda: ReadSubprocessPipeProto(self, STDERR), + lambda: ReadSubprocessPipeProto(self, 2), proc.stderr) - self._pipes[STDERR] = pipe + self._pipes[2] = pipe assert self._pending_calls is not None diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index b6b3be2d..fb671557 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -29,6 +29,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._buffer = None # None or bytearray. self._read_fut = None self._write_fut = None + self._pending_write = 0 self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False @@ -68,6 +69,7 @@ def _force_close(self, exc): if self._read_fut: self._read_fut.cancel() self._write_fut = self._read_fut = None + self._pending_write = 0 self._buffer = None self._loop.call_soon(self._call_connection_lost, exc) @@ -128,11 +130,10 @@ def set_write_buffer_limits(self, high=None, low=None): self._low_water = low def get_write_buffer_size(self): - # NOTE: This doesn't take into account data already passed to - # send() even if send() hasn't finished yet. - if not self._buffer: - return 0 - return len(self._buffer) + size = self._pending_write + if self._buffer is not None: + size += len(self._buffer) + return size class _ProactorReadPipeTransport(_ProactorBasePipeTransport, @@ -206,7 +207,7 @@ def _loop_reading(self, fut=None): class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, - transports.WriteTransport): + transports.WriteTransport): """Transport for write pipes.""" def write(self, data): @@ -252,6 +253,7 @@ def _loop_writing(self, f=None, data=None): try: assert f is self._write_fut self._write_fut = None + self._pending_write = 0 if f: f.result() if data is None: @@ -262,15 +264,21 @@ def _loop_writing(self, f=None, data=None): self._loop.call_soon(self._call_connection_lost, None) if self._eof_written: self._sock.shutdown(socket.SHUT_WR) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() else: self._write_fut = self._loop._proactor.send(self._sock, data) - self._write_fut.add_done_callback(self._loop_writing) - # Now that we've reduced the buffer size, tell the - # protocol to resume writing if it was paused. Note that - # we do this last since the callback is called immediately - # and it may add more data to the buffer (even causing the - # protocol to be paused again). - self._maybe_resume_protocol() + if not self._write_fut.done(): + assert self._pending_write == 0 + self._pending_write = len(data) + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_pause_protocol() + else: + self._write_fut.add_done_callback(self._loop_writing) except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py new file mode 100644 index 00000000..6c4ded35 --- /dev/null +++ b/asyncio/subprocess.py @@ -0,0 +1,198 @@ +__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] + +import collections +import subprocess + +from . import events +from . import futures +from . import protocols +from . import streams +from . import tasks + + +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +DEVNULL = subprocess.DEVNULL + + +class SubprocessStreamProtocol(streams.FlowControlMixin, + protocols.SubprocessProtocol): + """Like StreamReaderProtocol, but for a subprocess.""" + + def __init__(self, limit, loop): + super().__init__(loop=loop) + self._limit = limit + self.stdin = self.stdout = self.stderr = None + self.waiter = futures.Future(loop=loop) + self._waiters = collections.deque() + self._transport = None + + def connection_made(self, transport): + self._transport = transport + if transport.get_pipe_transport(1): + self.stdout = streams.StreamReader(limit=self._limit, + loop=self._loop) + if transport.get_pipe_transport(2): + self.stderr = streams.StreamReader(limit=self._limit, + loop=self._loop) + stdin = transport.get_pipe_transport(0) + if stdin is not None: + self.stdin = streams.StreamWriter(stdin, + protocol=self, + reader=None, + loop=self._loop) + self.waiter.set_result(None) + + def pipe_data_received(self, fd, data): + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader is not None: + reader.feed_data(data) + + def pipe_connection_lost(self, fd, exc): + if fd == 0: + pipe = self.stdin + if pipe is not None: + pipe.close() + self.connection_lost(exc) + return + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader != None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + + def process_exited(self): + # wake up futures waiting for wait() + returncode = self._transport.get_returncode() + while self._waiters: + waiter = self._waiters.popleft() + waiter.set_result(returncode) + + +class Process: + def __init__(self, transport, protocol, loop): + self._transport = transport + self._protocol = protocol + self._loop = loop + self.stdin = protocol.stdin + self.stdout = protocol.stdout + self.stderr = protocol.stderr + self.pid = transport.get_pid() + + @property + def returncode(self): + return self._transport.get_returncode() + + @tasks.coroutine + def wait(self): + """Wait until the process exit and return the process return code.""" + returncode = self._transport.get_returncode() + if returncode is not None: + return returncode + + waiter = futures.Future(loop=self._loop) + self._protocol._waiters.append(waiter) + yield from waiter + return waiter.result() + + def get_subprocess(self): + return self._transport.get_extra_info('subprocess') + + def _check_alive(self): + if self._transport.get_returncode() is not None: + raise ProcessLookupError() + + def send_signal(self, signal): + self._check_alive() + self._transport.send_signal(signal) + + def terminate(self): + self._check_alive() + self._transport.terminate() + + def kill(self): + self._check_alive() + self._transport.kill() + + @tasks.coroutine + def _feed_stdin(self, input): + self.stdin.write(input) + yield from self.stdin.drain() + self.stdin.close() + + @tasks.coroutine + def _noop(self): + return None + + @tasks.coroutine + def _read_stream(self, fd): + transport = self._transport.get_pipe_transport(fd) + if fd == 2: + stream = self.stderr + else: + assert fd == 1 + stream = self.stdout + output = yield from stream.read() + transport.close() + return output + + @tasks.coroutine + def communicate(self, input=None): + loop = self._transport._loop + if input: + stdin = self._feed_stdin(input) + else: + stdin = self._noop() + if self.stdout is not None: + stdout = self._read_stream(1) + else: + stdout = self._noop() + if self.stderr is not None: + stderr = self._read_stream(2) + else: + stderr = self._noop() + stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, + loop=loop) + yield from self.wait() + return (stdout, stderr) + + +@tasks.coroutine +def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, + loop=None, limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_shell( + protocol_factory, + cmd, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + yield from protocol.waiter + return Process(transport, protocol, loop) + +@tasks.coroutine +def create_subprocess_exec(*args, stdin=None, stdout=None, stderr=None, + loop=None, limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_exec( + protocol_factory, + *args, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + yield from protocol.waiter + return Process(transport, protocol, loop) + diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 98fdddee..3ce2db8d 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -21,16 +21,11 @@ from .log import logger -__all__ = ['SelectorEventLoop', 'STDIN', 'STDOUT', 'STDERR', +__all__ = ['SelectorEventLoop', 'AbstractChildWatcher', 'SafeChildWatcher', 'FastChildWatcher', 'DefaultEventLoopPolicy', ] -STDIN = 0 -STDOUT = 1 -STDERR = 2 - - if sys.platform == 'win32': # pragma: no cover raise ImportError('Signals are not really supported on Windows') diff --git a/examples/child_process.py b/examples/child_process.py index 4410414d..0c12cb95 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -1,7 +1,9 @@ """ Example of asynchronous interaction with a child python process. -Note that on Windows we must use the IOCP event loop. +This example shows how to attach an existing Popen object and use the low level +transport-protocol API. See shell.py and subprocess_shell.py for higher level +examples. """ import os diff --git a/examples/shell.py b/examples/shell.py new file mode 100644 index 00000000..e094b613 --- /dev/null +++ b/examples/shell.py @@ -0,0 +1,50 @@ +"""Examples using create_subprocess_exec() and create_subprocess_shell().""" +import logging; logging.basicConfig() + +import asyncio +import signal +from asyncio.subprocess import PIPE + +@asyncio.coroutine +def cat(loop): + proc = yield from asyncio.create_subprocess_shell("cat", + stdin=PIPE, + stdout=PIPE) + print("pid: %s" % proc.pid) + + message = "Hello World!" + print("cat write: %r" % message) + + stdout, stderr = yield from proc.communicate(message.encode('ascii')) + print("cat read: %r" % stdout.decode('ascii')) + + exitcode = yield from proc.wait() + print("(exit code %s)" % exitcode) + +@asyncio.coroutine +def ls(loop): + proc = yield from asyncio.create_subprocess_exec("ls", + stdout=PIPE) + while True: + line = yield from proc.stdout.readline() + if not line: + break + print("ls>>", line.decode('ascii').rstrip()) + try: + proc.send_signal(signal.SIGINT) + except ProcessLookupError: + pass + +@asyncio.coroutine +def test_call(*args, timeout=None): + try: + proc = yield from asyncio.create_subprocess_exec(*args) + exitcode = yield from asyncio.wait_for(proc.wait(), timeout) + print("%s: exit code %s" % (' '.join(args), exitcode)) + except asyncio.TimeoutError: + print("timeout! (%.1f sec)" % timeout) + +loop = asyncio.get_event_loop() +loop.run_until_complete(cat(loop)) +loop.run_until_complete(ls(loop)) +loop.run_until_complete(test_call("bash", "-c", "sleep 3", timeout=1.0)) diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py new file mode 100644 index 00000000..8bec652f --- /dev/null +++ b/examples/subprocess_attach_read_pipe.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a read pipe to a subprocess.""" +import asyncio +import os, sys +from asyncio import subprocess + +code = """ +import os, sys +fd = int(sys.argv[1]) +data = os.write(fd, b'data') +os.close(fd) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(wfd)] + + pipe = open(rfd, 'rb', 0) + reader = asyncio.StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.connect_read_pipe(lambda: protocol, pipe) + + proc = yield from asyncio.create_subprocess_exec(*args, pass_fds={wfd}) + yield from proc.wait() + + os.close(wfd) + data = yield from reader.read() + print("read = %r" % data.decode()) + +loop.run_until_complete(task()) diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py new file mode 100644 index 00000000..017b827f --- /dev/null +++ b/examples/subprocess_attach_write_pipe.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a write pipe to a subprocess.""" +import asyncio +import os, sys +from asyncio import subprocess + +code = """ +import os, sys +fd = int(sys.argv[1]) +data = os.read(fd, 1024) +sys.stdout.buffer.write(data) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(rfd)] + proc = yield from asyncio.create_subprocess_exec( + *args, + pass_fds={rfd}, + stdout=subprocess.PIPE) + + pipe = open(wfd, 'wb', 0) + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, + pipe) + transport.write(b'data') + + stdout, stderr = yield from proc.communicate() + print("stdout = %r" % stdout.decode()) + +loop.run_until_complete(task()) diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py new file mode 100644 index 00000000..d0e5d65a --- /dev/null +++ b/examples/subprocess_shell.py @@ -0,0 +1,85 @@ +"""Example writing to and reading from a subprocess at the same time using +tasks.""" + +import asyncio +import os +from asyncio.subprocess import PIPE + + +@asyncio.coroutine +def send_input(writer, input): + try: + for line in input: + print('sending', len(line), 'bytes') + writer.write(line) + d = writer.drain() + if d: + print('pause writing') + yield from d + print('resume writing') + writer.close() + except BrokenPipeError: + print('stdin: broken pipe error') + except ConnectionResetError: + print('stdin: connection reset error') + +@asyncio.coroutine +def log_errors(reader): + while True: + line = yield from reader.readline() + if not line: + break + print('ERROR', repr(line)) + +@asyncio.coroutine +def read_stdout(stdout): + while True: + line = yield from stdout.readline() + print('received', repr(line)) + if not line: + break + +@asyncio.coroutine +def start(cmd, input=None, **kwds): + kwds['stdout'] = PIPE + kwds['stderr'] = PIPE + if input is None and 'stdin' not in kwds: + kwds['stdin'] = None + else: + kwds['stdin'] = PIPE + proc = yield from asyncio.create_subprocess_shell(cmd, **kwds) + + tasks = [] + if input is not None: + tasks.append(send_input(proc.stdin, input)) + else: + print('No stdin') + if proc.stderr is not None: + tasks.append(log_errors(proc.stderr)) + else: + print('No stderr') + if proc.stdout is not None: + tasks.append(read_stdout(proc.stdout)) + else: + print('No stdout') + + if tasks: + # feed stdin while consuming stdout to avoid hang + # when stdin pipe is full + yield from asyncio.wait(tasks) + + exitcode = yield from proc.wait() + print("exit code: %s" % exitcode) + + +def main(): + if os.name == 'nt': + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(start('sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) + + +if __name__ == '__main__': + main() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py new file mode 100644 index 00000000..4f38cc72 --- /dev/null +++ b/tests/test_subprocess.py @@ -0,0 +1,196 @@ +from asyncio import subprocess +import asyncio +import signal +import sys +import unittest +from test import support +if sys.platform != 'win32': + from asyncio import unix_events + +# Program exiting quickly +PROGRAM_EXIT_FAST = [sys.executable, '-c', 'pass'] + +# Program blocking +PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] + +# Program sleeping during 1 second +PROGRAM_SLEEP_1SEC = [sys.executable, '-c', 'import time; time.sleep(1)'] + +# Program copying input to output +PROGRAM_CAT = [ + sys.executable, '-c', + ';'.join(('import sys', + 'data = sys.stdin.buffer.read()', + 'sys.stdout.buffer.write(data)'))] + +class SubprocessMixin: + def test_stdin_stdout(self): + args = PROGRAM_CAT + + @asyncio.coroutine + def run(data): + proc = yield from asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + loop=self.loop) + + # feed data + proc.stdin.write(data) + yield from proc.stdin.drain() + proc.stdin.close() + + # get output and exitcode + data = yield from proc.stdout.read() + exitcode = yield from proc.wait() + return (exitcode, data) + + task = run(b'some data') + task = asyncio.wait_for(task, 10.0, loop=self.loop) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_communicate(self): + args = PROGRAM_CAT + + @asyncio.coroutine + def run(data): + proc = yield from asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + loop=self.loop) + stdout, stderr = yield from proc.communicate(data) + return proc.returncode, stdout + + task = run(b'some data') + task = asyncio.wait_for(task, 10.0, loop=self.loop) + exitcode, stdout = self.loop.run_until_complete(task) + self.assertEqual(exitcode, 0) + self.assertEqual(stdout, b'some data') + + def test_shell(self): + create = asyncio.create_subprocess_shell('exit 7', + loop=self.loop) + proc = self.loop.run_until_complete(create) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 7) + + def test_start_new_session(self): + # start the new process in a new session + create = asyncio.create_subprocess_shell('exit 8', + start_new_session=True, + loop=self.loop) + proc = self.loop.run_until_complete(create) + exitcode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(exitcode, 8) + + def test_kill(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.kill() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_terminate(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.terminate() + returncode = self.loop.run_until_complete(proc.wait()) + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_send_signal(self): + args = PROGRAM_BLOCKED + create = asyncio.create_subprocess_exec(*args, loop=self.loop) + proc = self.loop.run_until_complete(create) + proc.send_signal(signal.SIGHUP) + returncode = self.loop.run_until_complete(proc.wait()) + self.assertEqual(-signal.SIGHUP, returncode) + + def test_get_subprocess(self): + args = PROGRAM_EXIT_FAST + + @asyncio.coroutine + def run(): + proc = yield from asyncio.create_subprocess_exec(*args, + loop=self.loop) + yield from proc.wait() + + popen = proc.get_subprocess() + popen.wait() + return (proc, popen) + + proc, popen = self.loop.run_until_complete(run()) + self.assertEqual(popen.returncode, proc.returncode) + self.assertEqual(popen.pid, proc.pid) + + def test_broken_pipe(self): + large_data = b'x' * support.PIPE_MAX_SIZE + + create = asyncio.create_subprocess_exec( + *PROGRAM_SLEEP_1SEC, + stdin=subprocess.PIPE, + loop=self.loop) + proc = self.loop.run_until_complete(create) + with self.assertRaises(BrokenPipeError): + self.loop.run_until_complete(proc.communicate(large_data)) + self.loop.run_until_complete(proc.wait()) + + +if sys.platform != 'win32': + # Unix + class SubprocessWatcherMixin(SubprocessMixin): + Watcher = None + + def setUp(self): + policy = asyncio.get_event_loop_policy() + self.loop = policy.new_event_loop() + + # ensure that the event loop is passed explicitly in the code + policy.set_event_loop(None) + + watcher = self.Watcher() + watcher.attach_loop(self.loop) + policy.set_child_watcher(watcher) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + policy.set_child_watcher(None) + self.loop.close() + policy.set_event_loop(None) + + class SubprocessSafeWatcherTests(SubprocessWatcherMixin, unittest.TestCase): + Watcher = unix_events.SafeChildWatcher + + class SubprocessFastWatcherTests(SubprocessWatcherMixin, unittest.TestCase): + Watcher = unix_events.FastChildWatcher +else: + # Windows + class SubprocessProactorTests(SubprocessMixin, unittest.TestCase): + def setUp(self): + policy = asyncio.get_event_loop_policy() + self.loop = asyncio.ProactorEventLoop() + + # ensure that the event loop is passed explicitly in the code + policy.set_event_loop(None) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + self.loop.close() + policy.set_event_loop(None) + + +if __name__ == '__main__': + unittest.main() From e1cfc4ff90f78e3a71aebb499f8ccb34807390e4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 3 Feb 2014 23:05:41 +0100 Subject: [PATCH 0920/1502] Replace Process.get_subprocess() method with a Process.subprocess read-only property --- asyncio/subprocess.py | 3 ++- tests/test_subprocess.py | 18 +++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 6c4ded35..8b5baeef 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -106,7 +106,8 @@ def wait(self): yield from waiter return waiter.result() - def get_subprocess(self): + @property + def subprocess(self): return self._transport.get_extra_info('subprocess') def _check_alive(self): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 4f38cc72..785156c7 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -119,7 +119,7 @@ def test_send_signal(self): returncode = self.loop.run_until_complete(proc.wait()) self.assertEqual(-signal.SIGHUP, returncode) - def test_get_subprocess(self): + def test_subprocess(self): args = PROGRAM_EXIT_FAST @asyncio.coroutine @@ -127,14 +127,14 @@ def run(): proc = yield from asyncio.create_subprocess_exec(*args, loop=self.loop) yield from proc.wait() - - popen = proc.get_subprocess() - popen.wait() - return (proc, popen) - - proc, popen = self.loop.run_until_complete(run()) - self.assertEqual(popen.returncode, proc.returncode) - self.assertEqual(popen.pid, proc.pid) + # need to poll subprocess.Popen, otherwise the returncode + # attribute is not set + proc.subprocess.wait() + return proc + + proc = self.loop.run_until_complete(run()) + self.assertEqual(proc.subprocess.returncode, proc.returncode) + self.assertEqual(proc.subprocess.pid, proc.pid) def test_broken_pipe(self): large_data = b'x' * support.PIPE_MAX_SIZE From f6cc1b8f7f017bd6e9582b1006a82481b517de3d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 3 Feb 2014 23:08:32 +0100 Subject: [PATCH 0921/1502] Remove empty line at the end of subprocess.py --- asyncio/subprocess.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 8b5baeef..3047894b 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -196,4 +196,3 @@ def create_subprocess_exec(*args, stdin=None, stdout=None, stderr=None, stderr=stderr, **kwds) yield from protocol.waiter return Process(transport, protocol, loop) - From e862e7696b939d7be76f21f074abf2f9ab58cd57 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 3 Feb 2014 23:09:02 +0100 Subject: [PATCH 0922/1502] test_events: skip PTY tests on Mac OS X older than 10.6 --- tests/test_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 24808cb1..5158430f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -953,6 +953,9 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) def test_read_pty_output(self): proto = None @@ -1075,6 +1078,9 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) def test_write_pty(self): proto = None transport = None From fd35727ad66ab02d57dd827bab296a4ebcefa47a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 4 Feb 2014 00:04:04 +0100 Subject: [PATCH 0923/1502] Adjust unit tests regarding timings: tolerate slow buildbots, add a test on the granularity --- tests/test_base_events.py | 15 ++++++++------- tests/test_events.py | 8 ++++++++ tests/test_windows_events.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 72f5c8a0..1db77233 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -116,17 +116,18 @@ def cb(): self.loop.stop() self.loop._process_events = unittest.mock.Mock() - when = self.loop.time() + 0.1 + delay = 0.1 + + when = self.loop.time() + delay self.loop.call_at(when, cb) t0 = self.loop.time() self.loop.run_forever() dt = self.loop.time() - t0 - self.assertTrue(0.09 <= dt <= 0.9, - # Issue #20452: add more info in case of failure, - # to try to investigate the bug - (dt, - self.loop._granularity, - time.get_clock_info('monotonic'))) + + self.assertGreaterEqual(dt, delay - self.loop._granularity, dt) + # tolerate a difference of +800 ms because some Python buildbots + # are really slow + self.assertLessEqual(dt, 0.9, dt) def test_run_once_in_executor_handle(self): def cb(): diff --git a/tests/test_events.py b/tests/test_events.py index 5158430f..c11d20f9 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1185,6 +1185,14 @@ def wait(): calls.append(self.loop._run_once_counter) self.assertEqual(calls, [1, 3, 5, 6]) + def test_granularity(self): + granularity = self.loop._granularity + self.assertGreater(granularity, 0.0) + # Worst expected granularity: 1 ms on Linux (limited by poll/epoll + # resolution), 15.6 ms on Windows (limited by time.monotonic + # resolution) + self.assertLess(granularity, 0.050) + class SubprocessTestsMixin: diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 3c271ebe..846049a2 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -105,7 +105,7 @@ def test_wait_for_handle(self): self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertFalse(f.result()) - self.assertTrue(0.18 < elapsed < 0.5, elapsed) + self.assertTrue(0.18 < elapsed < 0.9, elapsed) _overlapped.SetEvent(event) From 35db74fd502ebad1dbda7b92c1ef7dd9a6d6565e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 4 Feb 2014 08:54:49 +0100 Subject: [PATCH 0924/1502] Fix _ProactorWritePipeTransport._pipe_closed() Do nothing if the pipe is already closed. _loop_writing() may call _force_close() when it gets ConnectionResetError. --- asyncio/proactor_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index fb671557..6b5707c7 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -304,9 +304,12 @@ def _pipe_closed(self, fut): if fut.cancelled(): # the transport has been closed return + assert fut.result() == b'' + if self._closing: + assert self._read_fut is None + return assert fut is self._read_fut, (fut, self._read_fut) self._read_fut = None - assert fut.result() == b'' if self._write_fut is not None: self._force_close(exc) else: From 46c9f5a0b17b3039d66f517aa06f7fd60fd8dbdd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 4 Feb 2014 14:24:32 -0800 Subject: [PATCH 0925/1502] Cosmetic improvement to test__run_once_logging() mock argument. --- tests/test_base_events.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 1db77233..0d90d3fd 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -190,7 +190,7 @@ def test__run_once(self): @unittest.mock.patch('asyncio.base_events.time') @unittest.mock.patch('asyncio.base_events.logger') - def test__run_once_logging(self, m_logging, m_time): + def test__run_once_logging(self, m_logger, m_time): # Log to INFO level if timeout > 1.0 sec. idx = -1 data = [10.0, 10.0, 12.0, 13.0] @@ -201,20 +201,18 @@ def monotonic(): return data[idx] m_time.monotonic = monotonic - m_logging.INFO = logging.INFO - m_logging.DEBUG = logging.DEBUG self.loop._scheduled.append( asyncio.TimerHandle(11.0, lambda: True, ())) self.loop._process_events = unittest.mock.Mock() self.loop._run_once() - self.assertEqual(logging.INFO, m_logging.log.call_args[0][0]) + self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) idx = -1 data = [10.0, 10.0, 10.3, 13.0] self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] self.loop._run_once() - self.assertEqual(logging.DEBUG, m_logging.log.call_args[0][0]) + self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) def test__run_once_schedule_handle(self): handle = None From 49ef421f882b676a53e6649b8093064c42f080fc Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 4 Feb 2014 16:33:33 -0500 Subject: [PATCH 0926/1502] streams.StreamReader: Use bytearray instead of deque of bytes for internal buffer --- asyncio/streams.py | 71 ++++++++++++-------------------------- tests/test_streams.py | 79 ++++++++++++++++++++++++++++--------------- 2 files changed, 74 insertions(+), 76 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 06f052a2..4df58461 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -4,8 +4,6 @@ 'open_connection', 'start_server', 'IncompleteReadError', ] -import collections - from . import events from . import futures from . import protocols @@ -259,9 +257,7 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): if loop is None: loop = events.get_event_loop() self._loop = loop - # TODO: Use a bytearray for a buffer, like the transport. - self._buffer = collections.deque() # Deque of bytes objects. - self._byte_count = 0 # Bytes in buffer. + self._buffer = bytearray() self._eof = False # Whether we're done. self._waiter = None # A future. self._exception = None @@ -285,7 +281,7 @@ def set_transport(self, transport): self._transport = transport def _maybe_resume_transport(self): - if self._paused and self._byte_count <= self._limit: + if self._paused and len(self._buffer) <= self._limit: self._paused = False self._transport.resume_reading() @@ -301,8 +297,7 @@ def feed_data(self, data): if not data: return - self._buffer.append(data) - self._byte_count += len(data) + self._buffer.extend(data) waiter = self._waiter if waiter is not None: @@ -312,7 +307,7 @@ def feed_data(self, data): if (self._transport is not None and not self._paused and - self._byte_count > 2*self._limit): + len(self._buffer) > 2*self._limit): try: self._transport.pause_reading() except NotImplementedError: @@ -338,28 +333,22 @@ def readline(self): if self._exception is not None: raise self._exception - parts = [] - parts_size = 0 + line = bytearray() not_enough = True while not_enough: while self._buffer and not_enough: - data = self._buffer.popleft() - ichar = data.find(b'\n') + ichar = self._buffer.find(b'\n') if ichar < 0: - parts.append(data) - parts_size += len(data) + line.extend(self._buffer) + self._buffer.clear() else: ichar += 1 - head, tail = data[:ichar], data[ichar:] - if tail: - self._buffer.appendleft(tail) + line.extend(self._buffer[:ichar]) + del self._buffer[:ichar] not_enough = False - parts.append(head) - parts_size += len(head) - if parts_size > self._limit: - self._byte_count -= parts_size + if len(line) > self._limit: self._maybe_resume_transport() raise ValueError('Line is too long') @@ -373,11 +362,8 @@ def readline(self): finally: self._waiter = None - line = b''.join(parts) - self._byte_count -= parts_size self._maybe_resume_transport() - - return line + return bytes(line) @tasks.coroutine def read(self, n=-1): @@ -395,36 +381,23 @@ def read(self, n=-1): finally: self._waiter = None else: - if not self._byte_count and not self._eof: + if not self._buffer and not self._eof: self._waiter = self._create_waiter('read') try: yield from self._waiter finally: self._waiter = None - if n < 0 or self._byte_count <= n: - data = b''.join(self._buffer) + if n < 0 or len(self._buffer) <= n: + data = bytes(self._buffer) self._buffer.clear() - self._byte_count = 0 - self._maybe_resume_transport() - return data - - parts = [] - parts_bytes = 0 - while self._buffer and parts_bytes < n: - data = self._buffer.popleft() - data_bytes = len(data) - if n < parts_bytes + data_bytes: - data_bytes = n - parts_bytes - data, rest = data[:data_bytes], data[data_bytes:] - self._buffer.appendleft(rest) - - parts.append(data) - parts_bytes += data_bytes - self._byte_count -= data_bytes - self._maybe_resume_transport() - - return b''.join(parts) + else: + # n > 0 and len(self._buffer) > n + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data @tasks.coroutine def readexactly(self, n): diff --git a/tests/test_streams.py b/tests/test_streams.py index 01d565cd..83474a87 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -79,13 +79,13 @@ def test_feed_empty_data(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'') - self.assertEqual(0, stream._byte_count) + self.assertEqual(b'', stream._buffer) - def test_feed_data_byte_count(self): + def test_feed_nonempty_data(self): stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(self.DATA, stream._buffer) def test_read_zero(self): # Read zero bytes. @@ -94,7 +94,7 @@ def test_read_zero(self): data = self.loop.run_until_complete(stream.read(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(self.DATA, stream._buffer) def test_read(self): # Read bytes. @@ -107,7 +107,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA, data) - self.assertFalse(stream._byte_count) + self.assertEqual(b'', stream._buffer) def test_read_line_breaks(self): # Read bytes without line breaks. @@ -118,7 +118,7 @@ def test_read_line_breaks(self): data = self.loop.run_until_complete(stream.read(5)) self.assertEqual(b'line1', data) - self.assertEqual(5, stream._byte_count) + self.assertEqual(b'line2', stream._buffer) def test_read_eof(self): # Read bytes, stop at eof. @@ -131,7 +131,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(b'', data) - self.assertFalse(stream._byte_count) + self.assertEqual(b'', stream._buffer) def test_read_until_eof(self): # Read all bytes until eof. @@ -147,7 +147,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1\nchunk2', data) - self.assertFalse(stream._byte_count) + self.assertEqual(b'', stream._buffer) def test_read_exception(self): stream = asyncio.StreamReader(loop=self.loop) @@ -161,7 +161,8 @@ def test_read_exception(self): ValueError, self.loop.run_until_complete, stream.read(2)) def test_readline(self): - # Read one line. + # Read one line. 'readline' will need to wait for the data + # to come from 'cb' stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(b'chunk1 ') read_task = asyncio.Task(stream.readline(), loop=self.loop) @@ -174,30 +175,40 @@ def cb(): line = self.loop.run_until_complete(read_task) self.assertEqual(b'chunk1 chunk2 chunk3 \n', line) - self.assertEqual(len(b'\n chunk4')-1, stream._byte_count) + self.assertEqual(b' chunk4', stream._buffer) def test_readline_limit_with_existing_data(self): - stream = asyncio.StreamReader(3, loop=self.loop) + # Read one line. The data is in StreamReader's buffer + # before the event loop is run. + + stream = asyncio.StreamReader(limit=3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1\nline2\n') self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'line2\n'], list(stream._buffer)) + # The buffer should contain the remaining data after exception + self.assertEqual(b'line2\n', stream._buffer) - stream = asyncio.StreamReader(3, loop=self.loop) + stream = asyncio.StreamReader(limit=3, loop=self.loop) stream.feed_data(b'li') stream.feed_data(b'ne1') stream.feed_data(b'li') self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'li'], list(stream._buffer)) - self.assertEqual(2, stream._byte_count) + # No b'\n' at the end. The 'limit' is set to 3. So before + # waiting for the new data in buffer, 'readline' will consume + # the entire buffer, and since the length of the consumed data + # is more than 3, it will raise a ValudError. The buffer is + # expected to be empty now. + self.assertEqual(b'', stream._buffer) def test_readline_limit(self): - stream = asyncio.StreamReader(7, loop=self.loop) + # Read one line. StreamReaders are fed with data after + # their 'readline' methods are called. + stream = asyncio.StreamReader(limit=7, loop=self.loop) def cb(): stream.feed_data(b'chunk1') stream.feed_data(b'chunk2') @@ -207,10 +218,25 @@ def cb(): self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) - self.assertEqual([b'chunk3\n'], list(stream._buffer)) - self.assertEqual(7, stream._byte_count) + # The buffer had just one line of data, and after raising + # a ValueError it should be empty. + self.assertEqual(b'', stream._buffer) + + stream = asyncio.StreamReader(limit=7, loop=self.loop) + def cb(): + stream.feed_data(b'chunk1') + stream.feed_data(b'chunk2\n') + stream.feed_data(b'chunk3\n') + stream.feed_eof() + self.loop.call_soon(cb) + + self.assertRaises( + ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'chunk3\n', stream._buffer) - def test_readline_line_byte_count(self): + def test_readline_nolimit_nowait(self): + # All needed data for the first 'readline' call will be + # in the buffer. stream = asyncio.StreamReader(loop=self.loop) stream.feed_data(self.DATA[:6]) stream.feed_data(self.DATA[6:]) @@ -218,7 +244,7 @@ def test_readline_line_byte_count(self): line = self.loop.run_until_complete(stream.readline()) self.assertEqual(b'line1\n', line) - self.assertEqual(len(self.DATA) - len(b'line1\n'), stream._byte_count) + self.assertEqual(b'line2\nline3\n', stream._buffer) def test_readline_eof(self): stream = asyncio.StreamReader(loop=self.loop) @@ -244,9 +270,7 @@ def test_readline_read_byte_count(self): data = self.loop.run_until_complete(stream.read(7)) self.assertEqual(b'line2\nl', data) - self.assertEqual( - len(self.DATA) - len(b'line1\n') - len(b'line2\nl'), - stream._byte_count) + self.assertEqual(b'ine3\n', stream._buffer) def test_readline_exception(self): stream = asyncio.StreamReader(loop=self.loop) @@ -258,6 +282,7 @@ def test_readline_exception(self): stream.set_exception(ValueError()) self.assertRaises( ValueError, self.loop.run_until_complete, stream.readline()) + self.assertEqual(b'', stream._buffer) def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). @@ -266,11 +291,11 @@ def test_readexactly_zero_or_less(self): data = self.loop.run_until_complete(stream.readexactly(0)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(self.DATA, stream._buffer) data = self.loop.run_until_complete(stream.readexactly(-1)) self.assertEqual(b'', data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(self.DATA, stream._buffer) def test_readexactly(self): # Read exact number of bytes. @@ -287,7 +312,7 @@ def cb(): data = self.loop.run_until_complete(read_task) self.assertEqual(self.DATA + self.DATA, data) - self.assertEqual(len(self.DATA), stream._byte_count) + self.assertEqual(self.DATA, stream._buffer) def test_readexactly_eof(self): # Read exact number of bytes (eof). @@ -306,7 +331,7 @@ def cb(): self.assertEqual(cm.exception.expected, n) self.assertEqual(str(cm.exception), '18 bytes read on a total of 36 expected bytes') - self.assertFalse(stream._byte_count) + self.assertEqual(b'', stream._buffer) def test_readexactly_exception(self): stream = asyncio.StreamReader(loop=self.loop) From f590413c47e2095b6adb8fc23249bdd47be6057d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 5 Feb 2014 17:36:11 -0500 Subject: [PATCH 0927/1502] streams.StreamReader.feed_data: Add assertion that stream is not in EOF state --- asyncio/streams.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/streams.py b/asyncio/streams.py index 4df58461..3da1d10f 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -294,6 +294,8 @@ def feed_eof(self): waiter.set_result(True) def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + if not data: return From 7eb2b22b2d4e673aa8c8ce3bb83c1982d4209a96 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 6 Feb 2014 00:11:28 -0500 Subject: [PATCH 0928/1502] streams.StreamReader: Add 'at_eof()' method --- asyncio/streams.py | 4 ++++ tests/test_streams.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/asyncio/streams.py b/asyncio/streams.py index 3da1d10f..8fc21474 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -293,6 +293,10 @@ def feed_eof(self): if not waiter.cancelled(): waiter.set_result(True) + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + def feed_data(self, data): assert not self._eof, 'feed_data after feed_eof' diff --git a/tests/test_streams.py b/tests/test_streams.py index 83474a87..ee3fb450 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -204,6 +204,21 @@ def test_readline_limit_with_existing_data(self): # expected to be empty now. self.assertEqual(b'', stream._buffer) + def test_at_eof(self): + stream = asyncio.StreamReader(loop=self.loop) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + self.assertFalse(stream.at_eof()) + + self.loop.run_until_complete(stream.readline()) + self.assertFalse(stream.at_eof()) + + stream.feed_data(b'some data\n') + stream.feed_eof() + self.loop.run_until_complete(stream.readline()) + self.assertTrue(stream.at_eof()) + def test_readline_limit(self): # Read one line. StreamReaders are fed with data after # their 'readline' methods are called. From b9c2442cfb782480055f88139885fdfc54497673 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 6 Feb 2014 12:01:09 -0500 Subject: [PATCH 0929/1502] tasks.gather: Fix docstring --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 38ffec16..a5708b4c 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -555,7 +555,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): All futures must share the same event loop. If all the tasks are done successfully, the returned future's result is the list of results (in the order of the original sequence, not necessarily - the order of results arrival). If *result_exception* is True, + the order of results arrival). If *return_exceptions* is True, exceptions in the tasks are treated the same as successful results, and gathered in the result list; otherwise, the first raised exception will be immediately propagated to the returned From 4b34c93accd01412a532e7aa8e2b835b662fe3be Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 6 Feb 2014 22:04:28 -0500 Subject: [PATCH 0930/1502] tasks: Fix as_completed, gather & wait to work with duplicate coroutines. Issue #114 --- asyncio/tasks.py | 7 +++--- tests/test_tasks.py | 55 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 51 insertions(+), 11 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index a5708b4c..5ad06520 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -364,7 +364,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): if loop is None: loop = events.get_event_loop() - fs = set(async(f, loop=loop) for f in fs) + fs = {async(f, loop=loop) for f in set(fs)} if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): raise ValueError('Invalid return_when value: {}'.format(return_when)) @@ -476,7 +476,7 @@ def as_completed(fs, *, loop=None, timeout=None): """ loop = loop if loop is not None else events.get_event_loop() deadline = None if timeout is None else loop.time() + timeout - todo = set(async(f, loop=loop) for f in fs) + todo = {async(f, loop=loop) for f in set(fs)} completed = collections.deque() @coroutine @@ -568,7 +568,8 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): prevent the cancellation of one child to cause other children to be cancelled.) """ - children = [async(fut, loop=loop) for fut in coros_or_futures] + arg_to_fut = {arg: async(arg, loop=loop) for arg in set(coros_or_futures)} + children = [arg_to_fut[arg] for arg in coros_or_futures] n = len(children) if n == 0: outer = futures.Future(loop=loop) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index f54a0a06..d4d4e639 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -483,6 +483,21 @@ def foo(): self.assertEqual(res, 42) + def test_wait_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('test') + + task = asyncio.Task( + asyncio.wait([c, c, coro('spam')], loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + def test_wait_errors(self): self.assertRaises( ValueError, self.loop.run_until_complete, @@ -757,14 +772,10 @@ def foo(): def test_as_completed_with_timeout(self): def gen(): - when = yield - self.assertAlmostEqual(0.12, when) - when = yield 0 - self.assertAlmostEqual(0.1, when) - when = yield 0 - self.assertAlmostEqual(0.15, when) - when = yield 0.1 - self.assertAlmostEqual(0.12, when) + yield + yield 0 + yield 0 + yield 0.1 yield 0.02 loop = test_utils.TestLoop(gen) @@ -840,6 +851,25 @@ def gen(): done, pending = loop.run_until_complete(waiter) self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + def test_as_completed_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + + @asyncio.coroutine + def runner(): + result = [] + c = coro('ham') + for f in asyncio.as_completed({c, c, coro('spam')}, loop=self.loop): + result.append((yield from f)) + return result + + fut = asyncio.Task(runner(), loop=self.loop) + self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + def test_sleep(self): def gen(): @@ -1505,6 +1535,15 @@ def coro(): gen3.close() gen4.close() + def test_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('abc') + fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. proof = 0 From 35e6b8a42923cf71df2561c90f1ea4e8c28ea39d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 7 Feb 2014 23:36:32 +0100 Subject: [PATCH 0931/1502] Remove resolution and _granularity from selectors and asyncio * Remove selectors.BaseSelector.resolution attribute * Remove asyncio.BaseEventLoop._granularity attribute --- asyncio/base_events.py | 3 +-- asyncio/proactor_events.py | 1 - asyncio/selector_events.py | 1 - asyncio/selectors.py | 21 --------------------- tests/test_base_events.py | 3 ++- tests/test_events.py | 23 +++++++---------------- 6 files changed, 10 insertions(+), 42 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index cafd10a0..db57ee86 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -96,7 +96,6 @@ def __init__(self): self._default_executor = None self._internal_fds = 0 self._running = False - self._granularity = time.get_clock_info('monotonic').resolution def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -638,7 +637,7 @@ def _run_once(self): self._process_events(event_list) # Handle 'later' callbacks that are ready. - now = self.time() + self._granularity + now = self.time() while self._scheduled: handle = self._scheduled[0] if handle._when > now: diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 6b5707c7..74566b2e 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -365,7 +365,6 @@ def __init__(self, proactor): self._selector = proactor # convenient alias self._self_reading_future = None self._accept_futures = {} # socket file descriptor => Future - self._granularity = max(proactor.resolution, self._granularity) proactor.set_loop(self) self._make_self_pipe() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 36901452..202c14b7 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -36,7 +36,6 @@ def __init__(self, selector=None): selector = selectors.DefaultSelector() logger.debug('Using selector: %s', selector.__class__.__name__) self._selector = selector - self._granularity = max(selector.resolution, self._granularity) self._make_self_pipe() def _make_socket_transport(self, sock, protocol, waiter=None, *, diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 056e45c2..bb2a45a8 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -83,11 +83,6 @@ class BaseSelector(metaclass=ABCMeta): performant implementation on the current platform. """ - @abstractproperty - def resolution(self): - """Resolution of the selector in seconds""" - return None - @abstractmethod def register(self, fileobj, events, data=None): """Register a file object. @@ -289,10 +284,6 @@ def __init__(self): self._readers = set() self._writers = set() - @property - def resolution(self): - return 1e-6 - def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) if events & EVENT_READ: @@ -345,10 +336,6 @@ def __init__(self): super().__init__() self._poll = select.poll() - @property - def resolution(self): - return 1e-3 - def register(self, fileobj, events, data=None): key = super().register(fileobj, events, data) poll_events = 0 @@ -400,10 +387,6 @@ def __init__(self): super().__init__() self._epoll = select.epoll() - @property - def resolution(self): - return 1e-3 - def fileno(self): return self._epoll.fileno() @@ -468,10 +451,6 @@ def __init__(self): super().__init__() self._kqueue = select.kqueue() - @property - def resolution(self): - return 1e-9 - def fileno(self): return self._kqueue.fileno() diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 0d90d3fd..5b056847 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -124,7 +124,8 @@ def cb(): self.loop.run_forever() dt = self.loop.time() - t0 - self.assertGreaterEqual(dt, delay - self.loop._granularity, dt) + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - 0.050, dt) # tolerate a difference of +800 ms because some Python buildbots # are really slow self.assertLessEqual(dt, 0.9, dt) diff --git a/tests/test_events.py b/tests/test_events.py index c11d20f9..c2988c0f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1170,28 +1170,19 @@ def _run_once(): orig_run_once = self.loop._run_once self.loop._run_once_counter = 0 self.loop._run_once = _run_once - calls = [] @asyncio.coroutine def wait(): loop = self.loop - calls.append(loop._run_once_counter) - yield from asyncio.sleep(loop._granularity * 10, loop=loop) - calls.append(loop._run_once_counter) - yield from asyncio.sleep(loop._granularity / 10, loop=loop) - calls.append(loop._run_once_counter) + yield from asyncio.sleep(1e-2, loop=loop) + yield from asyncio.sleep(1e-4, loop=loop) self.loop.run_until_complete(wait()) - calls.append(self.loop._run_once_counter) - self.assertEqual(calls, [1, 3, 5, 6]) - - def test_granularity(self): - granularity = self.loop._granularity - self.assertGreater(granularity, 0.0) - # Worst expected granularity: 1 ms on Linux (limited by poll/epoll - # resolution), 15.6 ms on Windows (limited by time.monotonic - # resolution) - self.assertLess(granularity, 0.050) + # The ideal number of call is 6, but on some platforms, the selector + # may sleep at little bit less than timeout depending on the resolution + # of the clock used by the kernel. Tolerate 2 useless calls on these + # platforms. + self.assertLessEqual(self.loop._run_once_counter, 8) class SubprocessTestsMixin: From b8a38df5473baf309c0b8000840338c690477739 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 9 Feb 2014 01:25:02 +0100 Subject: [PATCH 0932/1502] Remove scories of resolution/granularity --- tests/test_proactor_events.py | 2 -- tests/test_selector_events.py | 1 - tests/test_selectors.py | 4 ---- 3 files changed, 7 deletions(-) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 98abe696..9964f425 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -17,7 +17,6 @@ class ProactorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.proactor = unittest.mock.Mock() - self.proactor.resolution = 1e-3 self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = unittest.mock.Mock(socket.socket) @@ -343,7 +342,6 @@ class BaseProactorEventLoopTests(unittest.TestCase): def setUp(self): self.sock = unittest.mock.Mock(socket.socket) self.proactor = unittest.mock.Mock() - self.proactor.resolution = 1e-3 self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 4c81e75d..ad0b0be8 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -39,7 +39,6 @@ class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): selector = unittest.mock.Mock() - selector.resolution = 1e-3 self.loop = TestBaseSelectorEventLoop(selector) def test_make_socket_transport(self): diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 19098dde..0519d75a 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -9,10 +9,6 @@ class FakeSelector(selectors._BaseSelectorImpl): """Trivial non-abstract subclass of BaseSelector.""" - @property - def resolution(self): - return 1e-3 - def select(self, timeout=None): raise NotImplementedError From 8cafa02f2507231e623952d248a7a5e13623161e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 8 Feb 2014 17:31:33 -0800 Subject: [PATCH 0933/1502] Fix test bug (should use list, not set). --- tests/test_tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index d4d4e639..9abdfa5b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -860,7 +860,7 @@ def coro(s): def runner(): result = [] c = coro('ham') - for f in asyncio.as_completed({c, c, coro('spam')}, loop=self.loop): + for f in asyncio.as_completed([c, c, coro('spam')], loop=self.loop): result.append((yield from f)) return result From b40a4f48e3b3d7650bcafa1a8c25bd6f42450ca8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 9 Feb 2014 02:50:13 +0100 Subject: [PATCH 0934/1502] Remove Process.subprocess attribute; it's too easy to get inconsistent Process and Popen objects --- asyncio/subprocess.py | 4 ---- tests/test_subprocess.py | 20 -------------------- 2 files changed, 24 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 3047894b..848d64f9 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -106,10 +106,6 @@ def wait(self): yield from waiter return waiter.result() - @property - def subprocess(self): - return self._transport.get_extra_info('subprocess') - def _check_alive(self): if self._transport.get_returncode() is not None: raise ProcessLookupError() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 785156c7..1b2f05be 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -7,9 +7,6 @@ if sys.platform != 'win32': from asyncio import unix_events -# Program exiting quickly -PROGRAM_EXIT_FAST = [sys.executable, '-c', 'pass'] - # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] @@ -119,23 +116,6 @@ def test_send_signal(self): returncode = self.loop.run_until_complete(proc.wait()) self.assertEqual(-signal.SIGHUP, returncode) - def test_subprocess(self): - args = PROGRAM_EXIT_FAST - - @asyncio.coroutine - def run(): - proc = yield from asyncio.create_subprocess_exec(*args, - loop=self.loop) - yield from proc.wait() - # need to poll subprocess.Popen, otherwise the returncode - # attribute is not set - proc.subprocess.wait() - return proc - - proc = self.loop.run_until_complete(run()) - self.assertEqual(proc.subprocess.returncode, proc.returncode) - self.assertEqual(proc.subprocess.pid, proc.pid) - def test_broken_pipe(self): large_data = b'x' * support.PIPE_MAX_SIZE From fddfe52bdf89d29c8f49408bef9e45c9ef2fbaba Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 8 Feb 2014 19:44:54 -0800 Subject: [PATCH 0935/1502] Remove more relics of resolution/granularity. --- asyncio/test_utils.py | 5 ----- asyncio/windows_events.py | 1 - 2 files changed, 6 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index fed28d7d..ccb44541 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -144,10 +144,6 @@ class TestSelector(selectors.BaseSelector): def __init__(self): self.keys = {} - @property - def resolution(self): - return 1e-3 - def register(self, fileobj, events, data=None): key = selectors.SelectorKey(fileobj, 0, events, data) self.keys[fileobj] = key @@ -196,7 +192,6 @@ def gen(): next(self._gen) self._time = 0 self._timers = [] - self._granularity = 1e-9 self._selector = TestSelector() self.readers = {} diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index b8574fa0..0a2d9810 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -191,7 +191,6 @@ def __init__(self, concurrency=0xffffffff): self._cache = {} self._registered = weakref.WeakSet() self._stopped_serving = weakref.WeakSet() - self.resolution = 1e-3 def set_loop(self, loop): self._loop = loop From 5f11d839b707b29bb8e8c5501bfdae2f63161f67 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 10 Feb 2014 00:42:00 +0100 Subject: [PATCH 0936/1502] Issue #112: Inline make_handle() into Handle constructor --- asyncio/base_events.py | 2 +- asyncio/events.py | 7 +------ asyncio/selector_events.py | 4 ++-- asyncio/test_utils.py | 4 ++-- asyncio/unix_events.py | 2 +- tests/test_events.py | 4 ++-- 6 files changed, 9 insertions(+), 14 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index db57ee86..558406c2 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -240,7 +240,7 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - handle = events.make_handle(callback, args) + handle = events.Handle(callback, args) self._ready.append(handle) return handle diff --git a/asyncio/events.py b/asyncio/events.py index 62400195..4c0cbb09 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -20,6 +20,7 @@ class Handle: """Object returned by callback registration methods.""" def __init__(self, callback, args): + assert not isinstance(callback, Handle), 'A Handle is not a callback' self._callback = callback self._args = args self._cancelled = False @@ -42,12 +43,6 @@ def _run(self): self = None # Needed to break cycles when an exception occurs. -def make_handle(callback, args): - # TODO: Inline this? Or make it a private EventLoop method? - assert not isinstance(callback, Handle), 'A Handle is not a callback' - return Handle(callback, args) - - class TimerHandle(Handle): """Object returned by timed callback registration methods.""" diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 202c14b7..14231c5f 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -132,7 +132,7 @@ def _accept_connection(self, protocol_factory, sock, def add_reader(self, fd, callback, *args): """Add a reader callback.""" - handle = events.make_handle(callback, args) + handle = events.Handle(callback, args) try: key = self._selector.get_key(fd) except KeyError: @@ -167,7 +167,7 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" - handle = events.make_handle(callback, args) + handle = events.Handle(callback, args) try: key = self._selector.get_key(fd) except KeyError: diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index ccb44541..71d69cfa 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -216,7 +216,7 @@ def close(self): raise AssertionError("Time generator is not finished") def add_reader(self, fd, callback, *args): - self.readers[fd] = events.make_handle(callback, args) + self.readers[fd] = events.Handle(callback, args) def remove_reader(self, fd): self.remove_reader_count[fd] += 1 @@ -235,7 +235,7 @@ def assert_reader(self, fd, callback, *args): handle._args, args) def add_writer(self, fd, callback, *args): - self.writers[fd] = events.make_handle(callback, args) + self.writers[fd] = events.Handle(callback, args) def remove_writer(self, fd): self.remove_writer_count[fd] += 1 diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 3ce2db8d..ea79d33b 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -64,7 +64,7 @@ def add_signal_handler(self, sig, callback, *args): except ValueError as exc: raise RuntimeError(str(exc)) - handle = events.make_handle(callback, args) + handle = events.Handle(callback, args) self._signal_handlers[sig] = handle try: diff --git a/tests/test_events.py b/tests/test_events.py index c2988c0f..4fb4b254 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1660,12 +1660,12 @@ def callback(*args): '.callback')) self.assertTrue(r.endswith('())'), r) - def test_make_handle(self): + def test_handle(self): def callback(*args): return args h1 = asyncio.Handle(callback, ()) self.assertRaises( - AssertionError, asyncio.events.make_handle, h1, ()) + AssertionError, asyncio.Handle, h1, ()) @unittest.mock.patch('asyncio.events.logger') def test_callback_with_exception(self, log): From bac9e59f228446c32549c88aeec01b41317fbf49 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Feb 2014 09:16:57 -0800 Subject: [PATCH 0937/1502] Added tag 0.3.1 for changeset b01fa490bc3d From 6aa1708e3da62bd26787420cee6fc4f89a122206 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Feb 2014 09:19:59 -0800 Subject: [PATCH 0938/1502] Removed tag 0.3.1 From 5e7da72c065e7d9c72a47cc1b091e705a65e3190 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Feb 2014 09:20:12 -0800 Subject: [PATCH 0939/1502] Bump to version 0.3.1. --- Makefile | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 448796ef..87fa3ee5 100644 --- a/Makefile +++ b/Makefile @@ -41,6 +41,7 @@ clean: rm -f MANIFEST -# Make distributions for Python 3.3 +# Push a source distribution for Python 3.3 to PyPI. +# You must update the version in setup.py first. pypi: clean python3.3 setup.py sdist upload diff --git a/setup.py b/setup.py index 77db68fb..0d92df70 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="asyncio", - version="0.2.1", + version="0.3.1", description="reference implementation of PEP 3156", long_description=open("README").read(), From 2c76306c5f158a8547634d7b1d45fc45ed4f55cd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Feb 2014 09:20:16 -0800 Subject: [PATCH 0940/1502] Added tag 0.3.1 for changeset 70a228927cab From 179735c1b7fbf216f9a3e521627a7a757eb68de4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 10 Feb 2014 09:49:56 -0800 Subject: [PATCH 0941/1502] Add hint for pypi release on Windows. --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 87fa3ee5..24a8c050 100644 --- a/Makefile +++ b/Makefile @@ -43,5 +43,7 @@ clean: # Push a source distribution for Python 3.3 to PyPI. # You must update the version in setup.py first. +# The corresponding action on Windows is pypi.bat. +# A PyPI user configuration in ~/.pypirc is required. pypi: clean python3.3 setup.py sdist upload From ee01f37fa4055e6f3b2dcbb144e24b801c68b567 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 10 Feb 2014 23:55:47 +0100 Subject: [PATCH 0942/1502] Python issue #20505: BaseEventLoop uses again the resolution of the clock to decide if scheduled tasks should be executed or not. --- asyncio/base_events.py | 11 +++++++++-- tests/test_events.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 558406c2..377ea216 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -96,6 +96,7 @@ def __init__(self): self._default_executor = None self._internal_fds = 0 self._running = False + self._clock_resolution = time.get_clock_info('monotonic').resolution def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -633,14 +634,20 @@ def _run_once(self): else: logger.log(level, 'poll took %.3f seconds', t1-t0) else: + t0 = self.time() event_list = self._selector.select(timeout) + dt = self.time() - t0 + if not event_list and timeout and dt < timeout: + print("asyncio: selector.select(%.3f ms) took %.3f ms" + % (timeout*1e3, dt*1e3), + file=sys.__stderr__, flush=True) self._process_events(event_list) # Handle 'later' callbacks that are ready. - now = self.time() + end_time = self.time() + self._clock_resolution while self._scheduled: handle = self._scheduled[0] - if handle._when > now: + if handle._when >= end_time: break handle = heapq.heappop(self._scheduled) self._ready.append(handle) diff --git a/tests/test_events.py b/tests/test_events.py index 4fb4b254..3f99da4c 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1176,13 +1176,18 @@ def wait(): loop = self.loop yield from asyncio.sleep(1e-2, loop=loop) yield from asyncio.sleep(1e-4, loop=loop) + yield from asyncio.sleep(1e-6, loop=loop) + yield from asyncio.sleep(1e-8, loop=loop) + yield from asyncio.sleep(1e-10, loop=loop) self.loop.run_until_complete(wait()) - # The ideal number of call is 6, but on some platforms, the selector + # The ideal number of call is 12, but on some platforms, the selector # may sleep at little bit less than timeout depending on the resolution - # of the clock used by the kernel. Tolerate 2 useless calls on these - # platforms. - self.assertLessEqual(self.loop._run_once_counter, 8) + # of the clock used by the kernel. Tolerate a few useless calls on + # these platforms. + self.assertLessEqual(self.loop._run_once_counter, 20, + {'clock_resolution': self.loop._clock_resolution, + 'selector': self.loop._selector.__class__.__name__}) class SubprocessTestsMixin: From e5d2ffb5e283497b430bc96ac9438f52b4557624 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:29:23 +0100 Subject: [PATCH 0943/1502] Fix TestLoop, set the clock resolution --- asyncio/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 71d69cfa..7c8e1dcb 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -191,6 +191,7 @@ def gen(): self._gen = gen() next(self._gen) self._time = 0 + self._clock_resolution = 1e-9 self._timers = [] self._selector = TestSelector() From 0c83e695449904683f696f4fdf80a5c89d07c968 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:30:53 +0100 Subject: [PATCH 0944/1502] Issue #126: call_soon(), call_soon_threadsafe(), call_later(), call_at() and run_in_executor() now raise a TypeError if the callback is a coroutine function. --- asyncio/base_events.py | 6 ++++++ asyncio/test_utils.py | 5 ++++- tests/test_base_events.py | 18 ++++++++++++++++++ tests/test_proactor_events.py | 2 +- tests/test_selector_events.py | 9 +++++---- tests/test_tasks.py | 12 +++++------- 6 files changed, 39 insertions(+), 13 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 377ea216..1f0c8133 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -227,6 +227,8 @@ def call_later(self, delay, callback, *args): def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with call_at()") timer = events.TimerHandle(when, callback, args) heapq.heappush(self._scheduled, timer) return timer @@ -241,6 +243,8 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with call_soon()") handle = events.Handle(callback, args) self._ready.append(handle) return handle @@ -252,6 +256,8 @@ def call_soon_threadsafe(self, callback, *args): return handle def run_in_executor(self, executor, callback, *args): + if tasks.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with run_in_executor()") if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 7c8e1dcb..deab7c33 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -135,7 +135,7 @@ def make_test_protocol(base): if name.startswith('__') and name.endswith('__'): # skip magic names continue - dct[name] = unittest.mock.Mock(return_value=None) + dct[name] = MockCallback(return_value=None) return type('TestProtocol', (base,) + base.__bases__, dct)() @@ -274,3 +274,6 @@ def _process_events(self, event_list): def _write_to_self(self): pass + +def MockCallback(**kwargs): + return unittest.mock.Mock(spec=['__call__'], **kwargs) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 5b056847..c6950ab3 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -567,6 +567,7 @@ class Err(OSError): m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.getaddrinfo._is_coroutine = False m_sock = m_socket.socket.return_value = unittest.mock.Mock() m_sock.bind.side_effect = Err @@ -577,6 +578,7 @@ class Err(OSError): @unittest.mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): m_socket.getaddrinfo.return_value = [] + m_socket.getaddrinfo._is_coroutine = False coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) @@ -681,6 +683,22 @@ def test_accept_connection_exception(self, m_log): unittest.mock.ANY, MyProto, sock, None, None) + def test_call_coroutine(self): + @asyncio.coroutine + def coroutine_function(): + pass + + with self.assertRaises(TypeError): + self.loop.call_soon(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_later(60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, coroutine_function) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, coroutine_function) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 9964f425..6bea1a33 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -402,7 +402,7 @@ def test_socketpair(self): NotImplementedError, BaseProactorEventLoop, self.proactor) def test_make_socket_transport(self): - tr = self.loop._make_socket_transport(self.sock, unittest.mock.Mock()) + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) self.assertIsInstance(tr, _ProactorSocketTransport) def test_loop_self_reading(self): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index ad0b0be8..855a8954 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -44,8 +44,8 @@ def setUp(self): def test_make_socket_transport(self): m = unittest.mock.Mock() self.loop.add_reader = unittest.mock.Mock() - self.assertIsInstance( - self.loop._make_socket_transport(m, m), _SelectorSocketTransport) + transport = self.loop._make_socket_transport(m, asyncio.Protocol()) + self.assertIsInstance(transport, _SelectorSocketTransport) @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): @@ -54,8 +54,9 @@ def test_make_ssl_transport(self): self.loop.add_writer = unittest.mock.Mock() self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock() - self.assertIsInstance( - self.loop._make_ssl_transport(m, m, m, m), _SelectorSslTransport) + waiter = asyncio.Future(loop=self.loop) + transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) + self.assertIsInstance(transport, _SelectorSslTransport) @unittest.mock.patch('asyncio.selector_events.ssl', None) def test_make_ssl_transport_without_ssl_error(self): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 9abdfa5b..29bdaf5b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,8 +2,6 @@ import gc import unittest -import unittest.mock -from unittest.mock import Mock import asyncio from asyncio import test_utils @@ -1358,7 +1356,7 @@ def _run_loop(self, loop): def _check_success(self, **kwargs): a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) b.set_result(1) a.set_result(2) @@ -1380,7 +1378,7 @@ def test_result_exception_success(self): def test_one_exception(self): a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) exc = ZeroDivisionError() a.set_result(1) @@ -1399,7 +1397,7 @@ def test_return_exceptions(self): a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] fut = asyncio.gather(*self.wrap_futures(a, b, c, d), return_exceptions=True) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) exc = ZeroDivisionError() exc2 = RuntimeError() @@ -1460,7 +1458,7 @@ def test_constructor_homogenous_futures(self): def test_one_cancellation(self): a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] fut = asyncio.gather(a, b, c, d, e) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) a.set_result(1) b.cancel() @@ -1479,7 +1477,7 @@ def test_result_exception_one_cancellation(self): a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) for i in range(6)] fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) - cb = Mock() + cb = test_utils.MockCallback() fut.add_done_callback(cb) a.set_result(1) zde = ZeroDivisionError() From 08874a324ab15a1179ac1478883b1c3dad0a7c9d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:33:45 +0100 Subject: [PATCH 0945/1502] Remove debug traces, there are only useful on Python buildbots --- asyncio/base_events.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 1f0c8133..9c5241f3 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -640,13 +640,7 @@ def _run_once(self): else: logger.log(level, 'poll took %.3f seconds', t1-t0) else: - t0 = self.time() event_list = self._selector.select(timeout) - dt = self.time() - t0 - if not event_list and timeout and dt < timeout: - print("asyncio: selector.select(%.3f ms) took %.3f ms" - % (timeout*1e3, dt*1e3), - file=sys.__stderr__, flush=True) self._process_events(event_list) # Handle 'later' callbacks that are ready. From 9d6147fe85b25c0c67c6806a586d2a8c95ed109c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:43:45 +0100 Subject: [PATCH 0946/1502] Issue #130: Add more checks on subprocess_exec/subprocess_shell parameters --- asyncio/base_events.py | 12 ++++++--- asyncio/subprocess.py | 5 ++-- tests/test_base_events.py | 54 +++++++++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 9c5241f3..7d120424 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -558,7 +558,7 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=True, bufsize=0, **kwargs): - if not isinstance(cmd, str): + if not isinstance(cmd, (bytes, str)): raise ValueError("cmd must be a string") if universal_newlines: raise ValueError("universal_newlines must be False") @@ -572,7 +572,7 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, return transport, protocol @tasks.coroutine - def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + def subprocess_exec(self, protocol_factory, program, *args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=False, bufsize=0, **kwargs): @@ -582,9 +582,15 @@ def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, raise ValueError("shell must be False") if bufsize != 0: raise ValueError("bufsize must be 0") + popen_args = (program,) + args + for arg in popen_args: + if not isinstance(arg, (str, bytes)): + raise TypeError("program arguments must be " + "a bytes or text string, not %s" + % type(arg).__name__) protocol = protocol_factory() transport = yield from self._make_subprocess_transport( - protocol, args, False, stdin, stdout, stderr, bufsize, **kwargs) + protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) return transport, protocol def _add_callback(self, handle): diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 848d64f9..8d1a4073 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -180,7 +180,7 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, return Process(transport, protocol, loop) @tasks.coroutine -def create_subprocess_exec(*args, stdin=None, stdout=None, stderr=None, +def create_subprocess_exec(program, *args, stdin=None, stdout=None, stderr=None, loop=None, limit=streams._DEFAULT_LIMIT, **kwds): if loop is None: loop = events.get_event_loop() @@ -188,7 +188,8 @@ def create_subprocess_exec(*args, stdin=None, stdout=None, stderr=None, loop=loop) transport, protocol = yield from loop.subprocess_exec( protocol_factory, - *args, stdin=stdin, stdout=stdout, + program, *args, + stdin=stdin, stdout=stdout, stderr=stderr, **kwds) yield from protocol.waiter return Process(transport, protocol, loop) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index c6950ab3..94e2d59d 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -3,6 +3,7 @@ import errno import logging import socket +import sys import time import unittest import unittest.mock @@ -234,8 +235,57 @@ def cb(loop): self.assertEqual([handle], list(self.loop._ready)) def test_run_until_complete_type_error(self): - self.assertRaises( - TypeError, self.loop.run_until_complete, 'blah') + self.assertRaises(TypeError, + self.loop.run_until_complete, 'blah') + + def test_subprocess_exec_invalid_args(self): + args = [sys.executable, '-c', 'pass'] + + # missing program parameter (empty args) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol) + + # exepected multiple arguments, not a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, args) + + # program arguments must be strings, not int + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, sys.executable, 123) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, bufsize=4096) + + def test_subprocess_shell_invalid_args(self): + # exepected a string, not an int or a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 123) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass']) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) class MyProto(asyncio.Protocol): From aa40ec5a0f61b70839bce56ed3929c7ff423ef8f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 11:53:15 +0100 Subject: [PATCH 0947/1502] Issue #131: as_completed() and wait() now raises a TypeError if the list of futures is not a list but a Future, Task or coroutine object --- asyncio/tasks.py | 4 ++++ tests/test_tasks.py | 26 ++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 5ad06520..81a125f4 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -358,6 +358,8 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ + if isinstance(fs, futures.Future) or iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) if not fs: raise ValueError('Set of coroutines/Futures is empty.') @@ -474,6 +476,8 @@ def as_completed(fs, *, loop=None, timeout=None): Note: The futures 'f' are not necessarily members of fs. """ + if isinstance(fs, futures.Future) or iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() deadline = None if timeout is None else loop.time() + timeout todo = {async(f, loop=loop) for f in set(fs)} diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 29bdaf5b..6847de04 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -7,6 +7,11 @@ from asyncio import test_utils +@asyncio.coroutine +def coroutine_function(): + pass + + class Dummy: def __repr__(self): @@ -1338,6 +1343,27 @@ def test_gather_shield(self): child2.set_result(2) test_utils.run_briefly(self.loop) + def test_as_completed_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # as_completed() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(fut, loop=self.loop)) + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(coroutine_function(), loop=self.loop)) + + def test_wait_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # wait() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(fut, loop=self.loop)) + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(coroutine_function(), loop=self.loop)) + + # wait() expects at least a future + self.assertRaises(ValueError, self.loop.run_until_complete, + asyncio.wait([], loop=self.loop)) class GatherTestsBase: From a85d24e8fdc311a4b46bbf310f2a5c997059de0f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 11 Feb 2014 18:42:08 +0100 Subject: [PATCH 0948/1502] Python issue #20495: Skip test_read_pty_output() of test_asyncio on FreeBSD older than FreeBSD 8 --- tests/test_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 3f99da4c..d5d667a3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -956,6 +956,8 @@ def connect(): # select, poll and kqueue don't support character devices (PTY) on Mac OS X # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) + # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 + @support.requires_freebsd_version(8) def test_read_pty_output(self): proto = None From 448fd97a4dd91baee8d02a554ae9465d21e76cbf Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 12 Feb 2014 16:59:26 -0500 Subject: [PATCH 0949/1502] events: Use __slots__ in Handle and TimerHandle --- asyncio/events.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/asyncio/events.py b/asyncio/events.py index 4c0cbb09..dd9e3fb4 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -19,6 +19,8 @@ class Handle: """Object returned by callback registration methods.""" + __slots__ = ['_callback', '_args', '_cancelled'] + def __init__(self, callback, args): assert not isinstance(callback, Handle), 'A Handle is not a callback' self._callback = callback @@ -46,6 +48,8 @@ def _run(self): class TimerHandle(Handle): """Object returned by timed callback registration methods.""" + __slots__ = ['_when'] + def __init__(self, when, callback, args): assert when is not None super().__init__(callback, args) From 26f0a4e7e8e74ecae32418ec91b9a98de3bc02f0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 12 Feb 2014 17:56:02 -0800 Subject: [PATCH 0950/1502] Change as_completed() to use a Queue, to avoid O(N**2) behavior. Fixes issue #127. --- asyncio/tasks.py | 53 ++++++++++++++++++++++++++++----------------- tests/test_tasks.py | 23 +++++++++++++++++++- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 81a125f4..b7ee758d 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -463,7 +463,11 @@ def _on_completion(f): # This is *not* a @coroutine! It is just an iterator (yielding Futures). def as_completed(fs, *, loop=None, timeout=None): - """Return an iterator whose values, when waited for, are Futures. + """Return an iterator whose values are coroutines. + + When waiting for the yielded coroutines you'll get the results (or + exceptions!) of the original Futures (or coroutines), in the order + in which and as soon as they complete. This differs from PEP 3148; the proper way to use this is: @@ -471,8 +475,8 @@ def as_completed(fs, *, loop=None, timeout=None): result = yield from f # The 'yield from' may raise. # Use result. - Raises TimeoutError if the timeout occurs before all Futures are - done. + If a timeout is specified, the 'yield from' will raise + TimeoutError when the timeout occurs before all Futures are done. Note: The futures 'f' are not necessarily members of fs. """ @@ -481,27 +485,36 @@ def as_completed(fs, *, loop=None, timeout=None): loop = loop if loop is not None else events.get_event_loop() deadline = None if timeout is None else loop.time() + timeout todo = {async(f, loop=loop) for f in set(fs)} - completed = collections.deque() + from .queues import Queue # Import here to avoid circular import problem. + done = Queue(loop=loop) + timeout_handle = None + + def _on_timeout(): + for f in todo: + f.remove_done_callback(_on_completion) + done.put_nowait(None) # Queue a dummy value for _wait_for_one(). + todo.clear() # Can't do todo.remove(f) in the loop. + + def _on_completion(f): + if not todo: + return # _on_timeout() was here first. + todo.remove(f) + done.put_nowait(f) + if not todo and timeout_handle is not None: + timeout_handle.cancel() @coroutine def _wait_for_one(): - while not completed: - timeout = None - if deadline is not None: - timeout = deadline - loop.time() - if timeout < 0: - raise futures.TimeoutError() - done, pending = yield from _wait( - todo, timeout, FIRST_COMPLETED, loop) - # Multiple callers might be waiting for the same events - # and getting the same outcome. Dedupe by updating todo. - for f in done: - if f in todo: - todo.remove(f) - completed.append(f) - f = completed.popleft() - return f.result() # May raise. + f = yield from done.get() + if f is None: + # Dummy value from _on_timeout(). + raise futures.TimeoutError + return f.result() # May raise f.exception(). + for f in todo: + f.add_done_callback(_on_completion) + if todo and timeout is not None: + timeout_handle = loop.call_later(timeout, _on_timeout) for _ in range(len(todo)): yield _wait_for_one() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6847de04..024dd2ea 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -779,7 +779,6 @@ def gen(): yield 0 yield 0 yield 0.1 - yield 0.02 loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) @@ -791,6 +790,8 @@ def gen(): def foo(): values = [] for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop): + if values: + loop.advance_time(0.02) try: v = yield from f values.append((1, v)) @@ -809,6 +810,26 @@ def foo(): loop.advance_time(10) loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + def test_as_completed_with_unused_timeout(self): + + def gen(): + yield + yield 0 + yield 0.01 + + loop = test_utils.TestLoop(gen) + self.addCleanup(loop.close) + + a = asyncio.sleep(0.01, 'a', loop=loop) + + @asyncio.coroutine + def foo(): + for f in asyncio.as_completed([a], timeout=1, loop=loop): + v = yield from f + self.assertEqual(v, 'a') + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + def test_as_completed_reverse_wait(self): def gen(): From befc2deda29e6433d551bcba7cea07b463d2501c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 12 Feb 2014 18:01:27 -0800 Subject: [PATCH 0951/1502] Fuzz tester for as_completed(), by Glenn Langford. (See issue #127.) --- examples/fuzz_as_completed.py | 69 +++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 examples/fuzz_as_completed.py diff --git a/examples/fuzz_as_completed.py b/examples/fuzz_as_completed.py new file mode 100644 index 00000000..123fbf1b --- /dev/null +++ b/examples/fuzz_as_completed.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +"""Fuzz tester for as_completed(), by Glenn Langford.""" + +import asyncio +import itertools +import random +import sys + +@asyncio.coroutine +def sleeper(time): + yield from asyncio.sleep(time) + return time + +@asyncio.coroutine +def watcher(tasks,delay=False): + res = [] + for t in asyncio.as_completed(tasks): + r = yield from t + res.append(r) + if delay: + # simulate processing delay + process_time = random.random() / 10 + yield from asyncio.sleep(process_time) + #print(res) + #assert(sorted(res) == res) + if sorted(res) != res: + print('FAIL', res) + print('------------') + else: + print('.', end='') + sys.stdout.flush() + +loop = asyncio.get_event_loop() + +print('Pass 1') +# All permutations of discrete task running times must be returned +# by as_completed in the correct order. +task_times = [0, 0.1, 0.2, 0.3, 0.4 ] # 120 permutations +for times in itertools.permutations(task_times): + tasks = [ asyncio.Task(sleeper(t)) for t in times ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 2') +# Longer task times, with randomized duplicates. 100 tasks each time. +longer_task_times = [x/10 for x in range(30)] +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:500]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:100] ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 3') +# Same as pass 2, but with a random processing delay (0 - 0.1s) after +# retrieving each future from as_completed and 200 tasks. This tests whether +# the order that callbacks are triggered is preserved through to the +# as_completed caller. +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:200]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:200] ] + loop.run_until_complete(asyncio.Task(watcher(tasks, delay=True))) + +print() +loop.close() From 66c6ed72c8e94d2a3f428267b8ee05a94cea8e48 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 13 Feb 2014 09:15:35 +0100 Subject: [PATCH 0952/1502] Issue #129: BaseEventLoop.sock_connect() now raises an error if the address is not resolved (hostname instead of an IP address) for AF_INET and AF_INET6 address families. --- asyncio/base_events.py | 25 +++++++++++++++++++++++++ asyncio/proactor_events.py | 9 ++++++++- asyncio/selector_events.py | 20 ++++++++------------ tests/test_events.py | 12 ++++++++++++ 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 7d120424..3bbf6b54 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -41,6 +41,31 @@ class _StopError(BaseException): """Raised to stop the event loop.""" +def _check_resolved_address(sock, address): + # Ensure that the address is already resolved to avoid the trap of hanging + # the entire event loop when the address requires doing a DNS lookup. + family = sock.family + if family not in (socket.AF_INET, socket.AF_INET6): + return + + host, port = address + type_mask = 0 + if hasattr(socket, 'SOCK_NONBLOCK'): + type_mask |= socket.SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + type_mask |= socket.SOCK_CLOEXEC + # Use getaddrinfo(AI_NUMERICHOST) to ensure that the address is + # already resolved. + try: + socket.getaddrinfo(host, port, + family=family, + type=(sock.type & ~type_mask), + proto=sock.proto, + flags=socket.AI_NUMERICHOST) + except socket.gaierror as err: + raise ValueError("address must be resolved (IP address), got %r: %s" + % (address, err)) + def _raise_stop_error(*args): raise _StopError diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 74566b2e..5de4d3d6 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -404,7 +404,14 @@ def sock_sendall(self, sock, data): return self._proactor.send(sock, data) def sock_connect(self, sock, address): - return self._proactor.connect(sock, address) + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut = futures.Future(loop=self) + fut.set_exception(err) + return fut + else: + return self._proactor.connect(sock, address) def sock_accept(self, sock): return self._proactor.accept(sock) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 14231c5f..10b02579 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -208,6 +208,8 @@ def sock_recv(self, sock, n): return fut def _sock_recv(self, fut, registered, sock, n): + # _sock_recv() can add itself as an I/O callback if the operation can't + # be done immediatly. Don't use it directly, call sock_recv(). fd = sock.fileno() if registered: # Remove the callback early. It should be rare that the @@ -260,22 +262,16 @@ def _sock_sendall(self, fut, registered, sock, data): def sock_connect(self, sock, address): """XXX""" - # That address better not require a lookup! We're not calling - # self.getaddrinfo() for you here. But verifying this is - # complicated; the socket module doesn't have a pattern for - # IPv6 addresses (there are too many forms, apparently). fut = futures.Future(loop=self) - self._sock_connect(fut, False, sock, address) + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut.set_exception(err) + else: + self._sock_connect(fut, False, sock, address) return fut def _sock_connect(self, fut, registered, sock, address): - # TODO: Use getaddrinfo() to look up the address, to avoid the - # trap of hanging the entire event loop when the address - # requires doing a DNS lookup. (OTOH, the caller should - # already have done this, so it would be nice if we could - # easily tell whether the address needs looking up or not. I - # know how to do this for IPv4, but IPv6 addresses have many - # syntaxes.) fd = sock.fileno() if registered: self.remove_writer(fd) diff --git a/tests/test_events.py b/tests/test_events.py index d5d667a3..4300ddd0 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1191,6 +1191,18 @@ def wait(): {'clock_resolution': self.loop._clock_resolution, 'selector': self.loop._selector.__class__.__name__}) + def test_sock_connect_address(self): + address = ('www.python.org', 80) + for family in (socket.AF_INET, socket.AF_INET6): + for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): + sock = socket.socket(family, sock_type) + with sock: + connect = self.loop.sock_connect(sock, address) + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(connect) + self.assertIn('address must be resolved', + str(cm.exception)) + class SubprocessTestsMixin: From 23f4712e1235d0bae0bbd5b4b9a7fbcfe1782301 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 13 Feb 2014 10:48:15 +0100 Subject: [PATCH 0953/1502] Fix test_events.py: skip IPv6 if IPv6 is disabled on the host --- tests/test_events.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index 4300ddd0..3bb8dd81 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1192,8 +1192,12 @@ def wait(): 'selector': self.loop._selector.__class__.__name__}) def test_sock_connect_address(self): + families = [socket.AF_INET] + if support.IPV6_ENABLED: + families.append(socket.AF_INET6) + address = ('www.python.org', 80) - for family in (socket.AF_INET, socket.AF_INET6): + for family in families: for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): sock = socket.socket(family, sock_type) with sock: From e14ee367b8fd9c5095b1d00ba7f348d2006d96bc Mon Sep 17 00:00:00 2001 From: jesse Date: Tue, 18 Feb 2014 02:28:44 +0000 Subject: [PATCH 0954/1502] Update email address for A. Jesse Jiryu Davis. --- AUTHORS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index a5892b3d..79acc3d8 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,4 +1,4 @@ -A. Jesse Jiryu Davis +A. Jesse Jiryu Davis Aaron Griffith Andrew Svetlov Anthony Baire From 53da0be2eb3297331acdfe6cdbb9bc2c07ff9ec7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 18 Feb 2014 10:02:52 +0100 Subject: [PATCH 0955/1502] Skip test_read_pty_output() on OpenBSD --- tests/test_events.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 3bb8dd81..8c32a6e7 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1622,6 +1622,10 @@ def create_event_loop(self): # kqueue doesn't support character devices (PTY) on Mac OS X older # than 10.9 (Maverick) @support.requires_mac_ver(10, 9) + # Issue #20667: KqueueEventLoopTests.test_read_pty_output() + # hangs on OpenBSD 5.5 + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') def test_read_pty_output(self): super().test_read_pty_output() From 5728fe01aa244af4a04d5bcd87a2350d69213177 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 12:05:57 -0500 Subject: [PATCH 0956/1502] Add support for UNIX Domain Sockets. Closes issue #81. New APIs: - loop.create_unix_connection - loop.create_unix_server - streams.open_unix_connection - streams.start_unix_server --- asyncio/base_events.py | 7 + asyncio/events.py | 26 +++ asyncio/streams.py | 39 +++- asyncio/test_utils.py | 153 +++++++++++---- asyncio/unix_events.py | 75 +++++++- tests/test_base_events.py | 2 +- tests/test_events.py | 349 +++++++++++++++++++++++----------- tests/test_selector_events.py | 3 +- tests/test_streams.py | 195 +++++++++++++++---- tests/test_unix_events.py | 82 +++++++- 10 files changed, 738 insertions(+), 193 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 3bbf6b54..b74e9369 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -407,6 +407,13 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.setblocking(False) + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): protocol = protocol_factory() waiter = futures.Future(loop=self) if ssl: diff --git a/asyncio/events.py b/asyncio/events.py index dd9e3fb4..7841ad9b 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -220,6 +220,32 @@ def create_server(self, protocol_factory, host=None, port=None, *, """ raise NotImplementedError + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return valud is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): diff --git a/asyncio/streams.py b/asyncio/streams.py index 8fc21474..698c5c6b 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -1,9 +1,13 @@ """Stream-related things.""" __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', 'IncompleteReadError', + 'open_connection', 'start_server', + 'open_unix_connection', 'start_unix_server', + 'IncompleteReadError', ] +import socket + from . import events from . import futures from . import protocols @@ -93,6 +97,39 @@ def factory(): return (yield from loop.create_server(factory, host, port, **kwds)) +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @tasks.coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @tasks.coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index deab7c33..de2916bf 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -4,12 +4,18 @@ import contextlib import io import os +import socket +import socketserver import sys +import tempfile import threading import time import unittest import unittest.mock + +from http.server import HTTPServer from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer + try: import ssl except ImportError: # pragma: no cover @@ -70,42 +76,51 @@ def run_once(loop): loop.run_forever() -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): +class SilentWSGIRequestHandler(WSGIRequestHandler): - class SilentWSGIRequestHandler(WSGIRequestHandler): - def get_stderr(self): - return io.StringIO() + def get_stderr(self): + return io.StringIO() - def log_message(self, format, *args): - pass + def log_message(self, format, *args): + pass - class SilentWSGIServer(WSGIServer): - def handle_error(self, request, client_address): + +class SilentWSGIServer(WSGIServer): + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer pass - class SSLWSGIServer(SilentWSGIServer): - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. - here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -115,9 +130,9 @@ def app(environ, start_response): # Run the test WSGI server in a separate thread in order not to # interfere with event handling in the main thread - server_class = SSLWSGIServer if use_ssl else SilentWSGIServer - httpd = make_server(host, port, app, - server_class, SilentWSGIRequestHandler) + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) httpd.address = httpd.server_address server_thread = threading.Thread(target=httpd.serve_forever) server_thread.start() @@ -129,6 +144,75 @@ def app(environ, start_response): server_thread.join() +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + def make_test_protocol(base): dct = {} for name in dir(base): @@ -275,5 +359,6 @@ def _process_events(self, event_list): def _write_to_self(self): pass + def MockCallback(**kwargs): return unittest.mock.Mock(spec=['__call__'], **kwargs) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index ea79d33b..e0d75077 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -11,6 +11,7 @@ import threading +from . import base_events from . import base_subprocess from . import constants from . import events @@ -31,9 +32,9 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): - """Unix event loop + """Unix event loop. - Adds signal handling to SelectorEventLoop + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. """ def __init__(self, selector=None): @@ -164,6 +165,76 @@ def _make_subprocess_transport(self, protocol, args, shell, def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) + @tasks.coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + try: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except OSError: + if sock is not None: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @tasks.coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 94e2d59d..9fa98415 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -212,7 +212,7 @@ def monotonic(): idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda:True, ())] + self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())] self.loop._run_once() self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) diff --git a/tests/test_events.py b/tests/test_events.py index 8c32a6e7..c9d04c04 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -39,13 +39,14 @@ def data_file(filename): return fullname raise FileNotFoundError(filename) + ONLYCERT = data_file('ssl_cert.pem') ONLYKEY = data_file('ssl_key.pem') SIGNED_CERTFILE = data_file('keycert3.pem') SIGNING_CA = data_file('pycacert.pem') -class MyProto(asyncio.Protocol): +class MyBaseProto(asyncio.Protocol): done = None def __init__(self, loop=None): @@ -59,7 +60,6 @@ def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' - transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') def data_received(self, data): assert self.state == 'CONNECTED', self.state @@ -76,6 +76,12 @@ def connection_lost(self, exc): self.done.set_result(None) +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + class MyDatagramProto(asyncio.DatagramProtocol): done = None @@ -357,22 +363,30 @@ def remove_writer(): r.close() self.assertGreaterEqual(len(data), 200) + def _basetest_sock_client_ops(self, httpd, sock): + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + def test_sock_client_ops(self): with test_utils.run_test_server() as httpd: sock = socket.socket() - sock.setblocking(False) - self.loop.run_until_complete( - self.loop.sock_connect(sock, httpd.address)) - self.loop.run_until_complete( - self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - data = self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - # consume data - self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - sock.close() + self._basetest_sock_client_ops(httpd, sock) - self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) def test_sock_client_fail(self): # Make sure that we will get an unused port @@ -485,16 +499,26 @@ def my_handler(*args): self.loop.run_forever() self.assertEqual(caught, 1) + def _basetest_create_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + def test_create_connection(self): with test_utils.run_test_server() as httpd: - f = self.loop.create_connection( + conn_fut = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut) def test_create_connection_sock(self): with test_utils.run_test_server() as httpd: @@ -524,20 +548,37 @@ def test_create_connection_sock(self): self.assertGreater(pr.nbytes, 0) tr.close() + def _basetest_create_ssl_connection(self, connection_fut): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: - f = self.loop.create_connection( - lambda: MyProto(loop=self.loop), *httpd.address, + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, ssl=test_utils.dummy_ssl_context()) - tr, pr = self.loop.run_until_complete(f) - self.assertIsInstance(tr, asyncio.Transport) - self.assertIsInstance(pr, asyncio.Protocol) - self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertIsNotNone(tr.get_extra_info('sockname')) - self.loop.run_until_complete(pr.done) - self.assertGreater(pr.nbytes, 0) - tr.close() + + self._basetest_create_ssl_connection(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='127.0.0.1') + + self._basetest_create_ssl_connection(conn_fut) def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: @@ -561,14 +602,8 @@ def test_create_connection_local_addr_in_use(self): self.assertIn(str(httpd.address), cm.exception.strerror) def test_create_server(self): - proto = None - - def factory(): - nonlocal proto - proto = MyProto() - return proto - - f = self.loop.create_server(factory, '0.0.0.0', 0) + proto = MyProto() + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) self.assertEqual(len(server.sockets), 1) sock = server.sockets[0] @@ -605,38 +640,76 @@ def factory(): # close server server.close() - def _make_ssl_server(self, factory, certfile, keyfile=None): + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto() + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto is not None, 10) + + self.assertIsInstance(proto, MyProto) + self.assertEqual('INITIAL', proto.state) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + test_utils.run_briefly(self.loop) # windows iocp + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _create_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext - f = self.loop.create_server( - factory, '127.0.0.1', 0, ssl=sslcontext) + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) server = self.loop.run_until_complete(f) + sock = server.sockets[0] host, port = sock.getsockname() self.assertEqual(host, '127.0.0.1') return server, host, port + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): - proto = None - - class ClientMyProto(MyProto): - def connection_made(self, transport): - self.transport = transport - assert self.state == 'INITIAL', self.state - self.state = 'CONNECTED' + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto - - server, host, port = self._make_ssl_server(factory, ONLYCERT, ONLYKEY) - - f_c = self.loop.create_connection(ClientMyProto, host, port, + f_c = self.loop.create_connection(MyBaseProto, host, port, ssl=test_utils.dummy_ssl_context()) client, pr = self.loop.run_until_complete(f_c) @@ -667,16 +740,45 @@ def factory(): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verify_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + test_utils.run_briefly(self.loop) + self.assertIsInstance(proto, MyProto) + test_utils.run_briefly(self.loop) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0, + timeout=10) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -697,15 +799,36 @@ def factory(): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_match_failed(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -729,15 +852,36 @@ def factory(): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') - def test_create_server_ssl_verified(self): - proto = None + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) - def factory(): - nonlocal proto - proto = MyProto(loop=self.loop) - return proto + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True - server, host, port = self._make_ssl_server(factory, SIGNED_CERTFILE) + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 @@ -915,19 +1059,15 @@ def test_internal_fds(self): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_read_pipe(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) rpipe, wpipe = os.pipe() pipeobj = io.open(rpipe, 'rb', 1024) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) @@ -959,19 +1099,14 @@ def connect(): # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 @support.requires_freebsd_version(8) def test_read_pty_output(self): - proto = None - - def factory(): - nonlocal proto - proto = MyReadPipeProto(loop=self.loop) - return proto + proto = MyReadPipeProto(loop=self.loop) master, slave = os.openpty() master_read_obj = io.open(master, 'rb', 0) @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(factory, + t, p = yield from self.loop.connect_read_pipe(lambda: proto, master_read_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -999,21 +1134,17 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, pipeobj) + t, p = yield from self.loop.connect_write_pipe( + lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual('CONNECTED', proto.state) @@ -1045,21 +1176,16 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, pipeobj) self.assertIs(p, proto) self.assertIs(t, proto.transport) @@ -1084,21 +1210,16 @@ def connect(): # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) def test_write_pty(self): - proto = None + proto = MyWritePipeProto(loop=self.loop) transport = None - def factory(): - nonlocal proto - proto = MyWritePipeProto(loop=self.loop) - return proto - master, slave = os.openpty() slave_write_obj = io.open(slave, 'wb', 0) @asyncio.coroutine def connect(): nonlocal transport - t, p = yield from self.loop.connect_write_pipe(factory, + t, p = yield from self.loop.connect_write_pipe(lambda: proto, slave_write_obj) self.assertIs(p, proto) self.assertIs(t, proto.transport) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 855a8954..7741e191 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -55,7 +55,8 @@ def test_make_ssl_transport(self): self.loop.remove_reader = unittest.mock.Mock() self.loop.remove_writer = unittest.mock.Mock() waiter = asyncio.Future(loop=self.loop) - transport = self.loop._make_ssl_transport(m, asyncio.Protocol(), m, waiter) + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) self.assertIsInstance(transport, _SelectorSslTransport) @unittest.mock.patch('asyncio.selector_events.ssl', None) diff --git a/tests/test_streams.py b/tests/test_streams.py index ee3fb450..31e26a64 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,6 +1,8 @@ """Tests for streams.py.""" +import functools import gc +import socket import unittest import unittest.mock try: @@ -32,48 +34,85 @@ def test_ctor_global_loop(self, m_events): stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) + def _basetest_open_connection(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.readline() + data = self.loop.run_until_complete(f) + self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + writer.close() + def test_open_connection(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.readline() - data = self.loop.run_until_complete(f) - self.assertEqual(data, b'HTTP/1.0 200 OK\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) - - writer.close() + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection(conn_fut) + + def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): + try: + reader, writer = self.loop.run_until_complete(open_connection_fut) + finally: + asyncio.set_event_loop(None) + writer.write(b'GET / HTTP/1.0\r\n\r\n') + f = reader.read() + data = self.loop.run_until_complete(f) + self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + + writer.close() @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: - try: - asyncio.set_event_loop(self.loop) - f = asyncio.open_connection(*httpd.address, - ssl=test_utils.dummy_ssl_context()) - reader, writer = self.loop.run_until_complete(f) - finally: - asyncio.set_event_loop(None) - writer.write(b'GET / HTTP/1.0\r\n\r\n') - f = reader.read() - data = self.loop.run_until_complete(f) - self.assertTrue(data.endswith(b'\r\n\r\nTest message')) + conn_fut = asyncio.open_connection( + *httpd.address, + ssl=test_utils.dummy_ssl_context(), + loop=self.loop) - writer.close() + self._basetest_open_connection_no_loop_ssl(conn_fut) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_no_loop_ssl(self): + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + conn_fut = asyncio.open_unix_connection( + httpd.address, + ssl=test_utils.dummy_ssl_context(), + server_hostname='', + loop=self.loop) + + self._basetest_open_connection_no_loop_ssl(conn_fut) + + def _basetest_open_connection_error(self, open_connection_fut): + reader, writer = self.loop.run_until_complete(open_connection_fut) + writer._protocol.connection_lost(ZeroDivisionError()) + f = reader.read() + with self.assertRaises(ZeroDivisionError): + self.loop.run_until_complete(f) + writer.close() + test_utils.run_briefly(self.loop) def test_open_connection_error(self): with test_utils.run_test_server() as httpd: - f = asyncio.open_connection(*httpd.address, loop=self.loop) - reader, writer = self.loop.run_until_complete(f) - writer._protocol.connection_lost(ZeroDivisionError()) - f = reader.read() - with self.assertRaises(ZeroDivisionError): - self.loop.run_until_complete(f) + conn_fut = asyncio.open_connection(*httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) - writer.close() - test_utils.run_briefly(self.loop) + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_open_unix_connection_error(self): + with test_utils.run_test_unix_server() as httpd: + conn_fut = asyncio.open_unix_connection(httpd.address, + loop=self.loop) + self._basetest_open_connection_error(conn_fut) def test_feed_empty_data(self): stream = asyncio.StreamReader(loop=self.loop) @@ -415,10 +454,13 @@ def handle_client(self, client_reader, client_writer): client_writer.write(data) def start(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client, - '127.0.0.1', 12345, + sock=sock, loop=self.loop)) + return sock.getsockname() def handle_client_callback(self, client_reader, client_writer): task = asyncio.Task(client_reader.readline(), loop=self.loop) @@ -429,10 +471,15 @@ def done(task): task.add_done_callback(done) def start_callback(self): + sock = socket.socket() + sock.bind(('127.0.0.1', 0)) + addr = sock.getsockname() + sock.close() self.server = self.loop.run_until_complete( asyncio.start_server(self.handle_client_callback, - '127.0.0.1', 12345, + host=addr[0], port=addr[1], loop=self.loop)) + return addr def stop(self): if self.server is not None: @@ -441,9 +488,9 @@ def stop(self): self.server = None @asyncio.coroutine - def client(): + def client(addr): reader, writer = yield from asyncio.open_connection( - '127.0.0.1', 12345, loop=self.loop) + *addr, loop=self.loop) # send a line writer.write(b"hello world!\n") # read it back @@ -453,20 +500,90 @@ def client(): # test the server variant with a coroutine as client handler server = MyServer(self.loop) - server.start() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") # test the server variant with a callback as client handler server = MyServer(self.loop) - server.start_callback() - msg = self.loop.run_until_complete(asyncio.Task(client(), + addr = server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(addr), loop=self.loop)) server.stop() self.assertEqual(msg, b"hello world!\n") + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_start_unix_server(self): + + class MyServer: + + def __init__(self, loop, path): + self.server = None + self.loop = loop + self.path = path + + @asyncio.coroutine + def handle_client(self, client_reader, client_writer): + data = yield from client_reader.readline() + client_writer.write(data) + + def start(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client, + path=self.path, + loop=self.loop)) + + def handle_client_callback(self, client_reader, client_writer): + task = asyncio.Task(client_reader.readline(), loop=self.loop) + + def done(task): + client_writer.write(task.result()) + + task.add_done_callback(done) + + def start_callback(self): + self.server = self.loop.run_until_complete( + asyncio.start_unix_server(self.handle_client_callback, + path=self.path, + loop=self.loop)) + + def stop(self): + if self.server is not None: + self.server.close() + self.loop.run_until_complete(self.server.wait_closed()) + self.server = None + + @asyncio.coroutine + def client(path): + reader, writer = yield from asyncio.open_unix_connection( + path, loop=self.loop) + # send a line + writer.write(b"hello world!\n") + # read it back + msgback = yield from reader.readline() + writer.close() + return msgback + + # test the server variant with a coroutine as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + + # test the server variant with a callback as client handler + with test_utils.unix_socket_path() as path: + server = MyServer(self.loop, path) + server.start_callback() + msg = self.loop.run_until_complete(asyncio.Task(client(path), + loop=self.loop)) + server.stop() + self.assertEqual(msg, b"hello world!\n") + if __name__ == '__main__': unittest.main() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 9461ec8b..2fa1db45 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -7,8 +7,10 @@ import os import pprint import signal +import socket import stat import sys +import tempfile import threading import unittest import unittest.mock @@ -24,7 +26,7 @@ @unittest.skipUnless(signal, 'Signals are not supported') -class SelectorEventLoopTests(unittest.TestCase): +class SelectorEventLoopSignalTests(unittest.TestCase): def setUp(self): self.loop = asyncio.SelectorEventLoop() @@ -200,6 +202,84 @@ def test_close(self, m_signal): m_signal.set_wakeup_fd.assert_called_once_with(-1) +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(unittest.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(None) + + def tearDown(self): + self.loop.close() + + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegexp(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + with tempfile.NamedTemporaryFile() as file: + coro = self.loop.create_unix_server(lambda: None, file.name) + with self.assertRaisesRegexp(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=socket.socket()) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', ssl=True) + + with self.assertRaisesRegexp( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): From 57014e56b32efa69a0d59ad57ff30fffaffbece0 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 18 Feb 2014 10:21:09 -0800 Subject: [PATCH 0957/1502] Only add *_unix_* to __all__ if they are defined. --- asyncio/streams.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 698c5c6b..27d595f1 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -2,12 +2,14 @@ __all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server', - 'open_unix_connection', 'start_unix_server', 'IncompleteReadError', ] import socket +if hasattr(socket, 'AF_UNIX'): + __all__.extend(['open_unix_connection', 'start_unix_server']) + from . import events from . import futures from . import protocols From 0c43056264e3b4d338d0b524063eff370f2824f0 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 17:52:25 -0500 Subject: [PATCH 0958/1502] Add new event loop exception handling API (closes issue #80). New APIs: - loop.set_exception_handler() - loop.default_exception_handler() - loop.call_exception_handler() --- asyncio/base_events.py | 96 ++++++++++++++++- asyncio/events.py | 31 ++++-- asyncio/futures.py | 22 ++-- asyncio/proactor_events.py | 33 ++++-- asyncio/selector_events.py | 35 ++++-- asyncio/test_utils.py | 18 +++- asyncio/unix_events.py | 26 ++++- asyncio/windows_events.py | 8 +- tests/test_base_events.py | 194 +++++++++++++++++++++++++++++++--- tests/test_events.py | 44 +++++--- tests/test_futures.py | 12 +-- tests/test_proactor_events.py | 8 +- tests/test_selector_events.py | 20 +++- tests/test_unix_events.py | 39 ++++--- 14 files changed, 487 insertions(+), 99 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b74e9369..cb2499d2 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -122,6 +122,7 @@ def __init__(self): self._internal_fds = 0 self._running = False self._clock_resolution = time.get_clock_info('monotonic').resolution + self._exception_handler = None def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -254,7 +255,7 @@ def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" if tasks.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_at()") - timer = events.TimerHandle(when, callback, args) + timer = events.TimerHandle(when, callback, args, self) heapq.heappush(self._scheduled, timer) return timer @@ -270,7 +271,7 @@ def call_soon(self, callback, *args): """ if tasks.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_soon()") - handle = events.Handle(callback, args) + handle = events.Handle(callback, args, self) self._ready.append(handle) return handle @@ -625,6 +626,97 @@ def subprocess_exec(self, protocol_factory, program, *args, stdin=subprocess.PIP protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) return transport, protocol + def set_exception_handler(self, handler): + """Set handler as the new event loop exception handler. + + If handler is None, the default exception handler will + be set. + + If handler is a callable object, it should have a + matching signature to '(loop, context)', where 'loop' + will be a reference to the active event loop, 'context' + will be a dict object (see `call_exception_handler()` + documentation for details about context). + """ + if handler is not None and not callable(handler): + raise TypeError('A callable object or None is expected, ' + 'got {!r}'.format(handler)) + self._exception_handler = handler + + def default_exception_handler(self, context): + """Default exception handler. + + This is called when an exception occurs and no exception + handler is set, and can be called by a custom exception + handler that wants to defer to the default behavior. + + context parameter has the same meaning as in + `call_exception_handler()`. + """ + message = context.get('message') + if not message: + message = 'Unhandled exception in event loop' + + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = False + + log_lines = [message] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + log_lines.append('{}: {!r}'.format(key, context[key])) + + logger.error('\n'.join(log_lines), exc_info=exc_info) + + def call_exception_handler(self, context): + """Call the current event loop exception handler. + + context is a dict object containing the following keys + (new keys maybe introduced later): + - 'message': Error message; + - 'exception' (optional): Exception object; + - 'future' (optional): Future instance; + - 'handle' (optional): Handle instance; + - 'protocol' (optional): Protocol instance; + - 'transport' (optional): Transport instance; + - 'socket' (optional): Socket instance. + + Note: this method should not be overloaded in subclassed + event loops. For any custom exception handling, use + `set_exception_handler()` method. + """ + if self._exception_handler is None: + try: + self.default_exception_handler(context) + except Exception: + # Second protection layer for unexpected errors + # in the default implementation, as well as for subclassed + # event loops with overloaded "default_exception_handler". + logger.error('Exception in default exception handler', + exc_info=True) + else: + try: + self._exception_handler(self, context) + except Exception as exc: + # Exception in the user set custom exception handler. + try: + # Let's try default handler. + self.default_exception_handler({ + 'message': 'Unhandled error in exception handler', + 'exception': exc, + 'context': context, + }) + except Exception: + # Guard 'default_exception_handler' in case it's + # overloaded. + logger.error('Exception in default exception handler ' + 'while handling an unexpected error ' + 'in custom exception handler', + exc_info=True) + def _add_callback(self, handle): """Add a Handle to ready or scheduled.""" assert isinstance(handle, events.Handle), 'A Handle is required here' diff --git a/asyncio/events.py b/asyncio/events.py index 7841ad9b..f61c5b74 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -19,10 +19,11 @@ class Handle: """Object returned by callback registration methods.""" - __slots__ = ['_callback', '_args', '_cancelled'] + __slots__ = ['_callback', '_args', '_cancelled', '_loop'] - def __init__(self, callback, args): + def __init__(self, callback, args, loop): assert not isinstance(callback, Handle), 'A Handle is not a callback' + self._loop = loop self._callback = callback self._args = args self._cancelled = False @@ -39,9 +40,14 @@ def cancel(self): def _run(self): try: self._callback(*self._args) - except Exception: - logger.exception('Exception in callback %s %r', - self._callback, self._args) + except Exception as exc: + msg = 'Exception in callback {}{!r}'.format(self._callback, + self._args) + self._loop.call_exception_handler({ + 'message': msg, + 'exception': exc, + 'handle': self, + }) self = None # Needed to break cycles when an exception occurs. @@ -50,9 +56,9 @@ class TimerHandle(Handle): __slots__ = ['_when'] - def __init__(self, when, callback, args): + def __init__(self, when, callback, args, loop): assert when is not None - super().__init__(callback, args) + super().__init__(callback, args, loop) self._when = when @@ -328,6 +334,17 @@ def add_signal_handler(self, sig, callback, *args): def remove_signal_handler(self, sig): raise NotImplementedError + # Error handlers. + + def set_exception_handler(self, handler): + raise NotImplementedError + + def default_exception_handler(self, context): + raise NotImplementedError + + def call_exception_handler(self, context): + raise NotImplementedError + class AbstractEventLoopPolicy: """Abstract policy for accessing the event loop.""" diff --git a/asyncio/futures.py b/asyncio/futures.py index d09f423c..b9cd45c7 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -83,9 +83,10 @@ class itself, but instead to have a reference to a helper object in a discussion about closing files when they are collected. """ - __slots__ = ['exc', 'tb'] + __slots__ = ['exc', 'tb', 'loop'] - def __init__(self, exc): + def __init__(self, exc, loop): + self.loop = loop self.exc = exc self.tb = None @@ -102,8 +103,11 @@ def clear(self): def __del__(self): if self.tb: - logger.error('Future/Task exception was never retrieved:\n%s', - ''.join(self.tb)) + msg = 'Future/Task exception was never retrieved:\n{tb}' + context = { + 'message': msg.format(tb=''.join(self.tb)), + } + self.loop.call_exception_handler(context) class Future: @@ -173,8 +177,12 @@ def __del__(self): # has consumed the exception return exc = self._exception - logger.error('Future/Task exception was never retrieved:', - exc_info=(exc.__class__, exc, exc.__traceback__)) + context = { + 'message': 'Future/Task exception was never retrieved', + 'exception': exc, + 'future': self, + } + self._loop.call_exception_handler(context) def cancel(self): """Cancel the future and schedule callbacks. @@ -309,7 +317,7 @@ def set_exception(self, exception): if _PY34: self._log_traceback = True else: - self._tb_logger = _TracebackLogger(exception) + self._tb_logger = _TracebackLogger(exception, self._loop) # Arrange for the logger to be activated after all callbacks # have had a chance to call result() or exception(). self._loop.call_soon(self._tb_logger.activate) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 5de4d3d6..b2ac632f 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -56,7 +56,12 @@ def close(self): def _fatal_error(self, exc): if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): - logger.exception('Fatal error for %s', self) + self._loop.call_exception_handler({ + 'message': 'Fatal transport error', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) self._force_close(exc) def _force_close(self, exc): @@ -103,8 +108,13 @@ def _maybe_pause_protocol(self): self._protocol_paused = True try: self._protocol.pause_writing() - except Exception: - logger.exception('pause_writing() failed') + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) def _maybe_resume_protocol(self): if (self._protocol_paused and @@ -112,8 +122,13 @@ def _maybe_resume_protocol(self): self._protocol_paused = False try: self._protocol.resume_writing() - except Exception: - logger.exception('resume_writing() failed') + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) def set_write_buffer_limits(self, high=None, low=None): if high is None: @@ -465,9 +480,13 @@ def loop(f=None): conn, protocol, extra={'peername': addr}, server=server) f = self._proactor.accept(sock) - except OSError: + except OSError as exc: if sock.fileno() != -1: - logger.exception('Accept failed') + self.call_exception_handler({ + 'message': 'Accept failed', + 'exception': exc, + 'socket': sock, + }) sock.close() except futures.CancelledError: sock.close() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 10b02579..fb86f824 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -112,7 +112,11 @@ def _accept_connection(self, protocol_factory, sock, # Some platforms (e.g. Linux keep reporting the FD as # ready, so we remove the read handler temporarily. # We'll try again in a while. - logger.exception('Accept out of system resource (%s)', exc) + self.call_exception_handler({ + 'message': 'socket.accept() out of system resource', + 'exception': exc, + 'socket': sock, + }) self.remove_reader(sock.fileno()) self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, @@ -132,7 +136,7 @@ def _accept_connection(self, protocol_factory, sock, def add_reader(self, fd, callback, *args): """Add a reader callback.""" - handle = events.Handle(callback, args) + handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) except KeyError: @@ -167,7 +171,7 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" - handle = events.Handle(callback, args) + handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) except KeyError: @@ -364,8 +368,13 @@ def _maybe_pause_protocol(self): self._protocol_paused = True try: self._protocol.pause_writing() - except Exception: - logger.exception('pause_writing() failed') + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) def _maybe_resume_protocol(self): if (self._protocol_paused and @@ -373,8 +382,13 @@ def _maybe_resume_protocol(self): self._protocol_paused = False try: self._protocol.resume_writing() - except Exception: - logger.exception('resume_writing() failed') + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) def set_write_buffer_limits(self, high=None, low=None): if high is None: @@ -435,7 +449,12 @@ def close(self): def _fatal_error(self, exc): # Should be called from exception handler only. if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): - logger.exception('Fatal error for %s', self) + self._loop.call_exception_handler({ + 'message': 'Fatal transport error', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) self._force_close(exc) def _force_close(self, exc): diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index de2916bf..28e52430 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -4,6 +4,7 @@ import contextlib import io import os +import re import socket import socketserver import sys @@ -301,7 +302,7 @@ def close(self): raise AssertionError("Time generator is not finished") def add_reader(self, fd, callback, *args): - self.readers[fd] = events.Handle(callback, args) + self.readers[fd] = events.Handle(callback, args, self) def remove_reader(self, fd): self.remove_reader_count[fd] += 1 @@ -320,7 +321,7 @@ def assert_reader(self, fd, callback, *args): handle._args, args) def add_writer(self, fd, callback, *args): - self.writers[fd] = events.Handle(callback, args) + self.writers[fd] = events.Handle(callback, args, self) def remove_writer(self, fd): self.remove_writer_count[fd] += 1 @@ -362,3 +363,16 @@ def _write_to_self(self): def MockCallback(**kwargs): return unittest.mock.Mock(spec=['__call__'], **kwargs) + + +class MockPattern(str): + """A regex based str with a fuzzy __eq__. + + Use this helper with 'mock.assert_called_with', or anywhere + where a regexp comparison between strings is needed. + + For instance: + mock_call.assert_called_with(MockPattern('spam.*ham')) + """ + def __eq__(self, other): + return bool(re.search(str(self), other, re.S)) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index e0d75077..9a40c04d 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -65,7 +65,7 @@ def add_signal_handler(self, sig, callback, *args): except ValueError as exc: raise RuntimeError(str(exc)) - handle = events.Handle(callback, args) + handle = events.Handle(callback, args, self) self._signal_handlers[sig] = handle try: @@ -294,7 +294,12 @@ def close(self): def _fatal_error(self, exc): # should be called by exception handler only if not (isinstance(exc, OSError) and exc.errno == errno.EIO): - logger.exception('Fatal error for %s', self) + self._loop.call_exception_handler({ + 'message': 'Fatal transport error', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) self._close(exc) def _close(self, exc): @@ -441,7 +446,12 @@ def abort(self): def _fatal_error(self, exc): # should be called by exception handler only if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): - logger.exception('Fatal error for %s', self) + self._loop.call_exception_handler({ + 'message': 'Fatal transport error', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) self._close(exc) def _close(self, exc=None): @@ -582,8 +592,14 @@ def attach_loop(self, loop): def _sig_chld(self): try: self._do_waitpid_all() - except Exception: - logger.exception('Unknown exception in SIGCHLD handler') + except Exception as exc: + # self._loop should always be available here + # as '_sig_chld' is added as a signal handler + # in 'attach_loop' + self._loop.call_exception_handler({ + 'message': 'Unknown exception in SIGCHLD handler', + 'exception': exc, + }) def _compute_returncode(self, status): if os.WIFSIGNALED(status): diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 0a2d9810..c667a1c3 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -156,9 +156,13 @@ def loop(f=None): if pipe is None: return f = self._proactor.accept_pipe(pipe) - except OSError: + except OSError as exc: if pipe and pipe.fileno() != -1: - logger.exception('Pipe accept failed') + self.call_exception_handler({ + 'message': 'Pipe accept failed', + 'exception': exc, + 'pipe': pipe, + }) pipe.close() except futures.CancelledError: if pipe: diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9fa98415..f664cccf 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -15,6 +15,10 @@ from asyncio import test_utils +MOCK_ANY = unittest.mock.ANY +PY34 = sys.version_info >= (3, 4) + + class BaseEventLoopTests(unittest.TestCase): def setUp(self): @@ -49,20 +53,21 @@ def test_not_implemented(self): self.assertRaises(NotImplementedError, next, iter(gen)) def test__add_callback_handle(self): - h = asyncio.Handle(lambda: False, ()) + h = asyncio.Handle(lambda: False, (), self.loop) self.loop._add_callback(h) self.assertFalse(self.loop._scheduled) self.assertIn(h, self.loop._ready) def test__add_callback_timer(self): - h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, ()) + h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, (), + self.loop) self.loop._add_callback(h) self.assertIn(h, self.loop._scheduled) def test__add_callback_cancelled_handle(self): - h = asyncio.Handle(lambda: False, ()) + h = asyncio.Handle(lambda: False, (), self.loop) h.cancel() self.loop._add_callback(h) @@ -137,15 +142,15 @@ def cb(): self.assertRaises( AssertionError, self.loop.run_in_executor, - None, asyncio.Handle(cb, ()), ('',)) + None, asyncio.Handle(cb, (), self.loop), ('',)) self.assertRaises( AssertionError, self.loop.run_in_executor, - None, asyncio.TimerHandle(10, cb, ())) + None, asyncio.TimerHandle(10, cb, (), self.loop)) def test_run_once_in_executor_cancelled(self): def cb(): pass - h = asyncio.Handle(cb, ()) + h = asyncio.Handle(cb, (), self.loop) h.cancel() f = self.loop.run_in_executor(None, h) @@ -156,7 +161,7 @@ def cb(): def test_run_once_in_executor_plain(self): def cb(): pass - h = asyncio.Handle(cb, ()) + h = asyncio.Handle(cb, (), self.loop) f = asyncio.Future(loop=self.loop) executor = unittest.mock.Mock() executor.submit.return_value = f @@ -175,8 +180,10 @@ def cb(): f.cancel() # Don't complain about abandoned Future. def test__run_once(self): - h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, ()) - h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, ()) + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + self.loop) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + self.loop) h1.cancel() @@ -205,14 +212,15 @@ def monotonic(): m_time.monotonic = monotonic self.loop._scheduled.append( - asyncio.TimerHandle(11.0, lambda: True, ())) + asyncio.TimerHandle(11.0, lambda: True, (), self.loop)) self.loop._process_events = unittest.mock.Mock() self.loop._run_once() self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) idx = -1 data = [10.0, 10.0, 10.3, 13.0] - self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, ())] + self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, (), + self.loop)] self.loop._run_once() self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) @@ -225,7 +233,8 @@ def cb(loop): processed = True handle = loop.call_soon(lambda: True) - h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,)) + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + self.loop) self.loop._process_events = unittest.mock.Mock() self.loop._scheduled.append(h) @@ -287,6 +296,163 @@ def test_subprocess_shell_invalid_args(self): self.loop.run_until_complete, self.loop.subprocess_shell, asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) + def test_default_exc_handler_callback(self): + self.loop._process_events = unittest.mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with unittest.mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with unittest.mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = unittest.mock.Mock() + + @asyncio.coroutine + def zero_error_coro(): + yield from asyncio.sleep(0.01, loop=self.loop) + 1/0 + + # Test Future.__del__ + with unittest.mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.async(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + if PY34: + # Future.__del__ in Python 3.4 logs error with + # an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + else: + # futures._TracebackLogger logs only textual traceback + log.error.assert_called_with( + test_utils.MockPattern( + '.*exception was never retrieved.*ZeroDiv'), + exc_info=False) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + self.loop.call_soon(zero_error) + self.loop._run_once() + + self.loop._process_events = unittest.mock.Mock() + + mock_handler = unittest.mock.Mock() + self.loop.set_exception_handler(mock_handler) + run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': MOCK_ANY, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with unittest.mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + assert not mock_handler.called + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = unittest.mock.Mock() + + self.loop.set_exception_handler(handler) + + with unittest.mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = unittest.mock.Mock() + _process_events = unittest.mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with unittest.mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default exception handler', + exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with unittest.mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + class MyProto(asyncio.Protocol): done = None @@ -716,7 +882,7 @@ def test_accept_connection_retry(self): self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - @unittest.mock.patch('asyncio.selector_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_accept_connection_exception(self, m_log): sock = unittest.mock.Mock() sock.fileno.return_value = 10 @@ -725,7 +891,7 @@ def test_accept_connection_exception(self, m_log): self.loop.call_later = unittest.mock.Mock() self.loop._accept_connection(MyProto, sock) - self.assertTrue(m_log.exception.called) + self.assertTrue(m_log.error.called) self.assertFalse(sock.close.called) self.loop.remove_reader.assert_called_with(10) self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, diff --git a/tests/test_events.py b/tests/test_events.py index c9d04c04..a0a4d02c 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1788,7 +1788,7 @@ def callback(*args): return args args = () - h = asyncio.Handle(callback, args) + h = asyncio.Handle(callback, args, unittest.mock.Mock()) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1808,28 +1808,37 @@ def callback(*args): '.callback')) self.assertTrue(r.endswith('())'), r) - def test_handle(self): + def test_handle_from_handle(self): def callback(*args): return args - h1 = asyncio.Handle(callback, ()) + m_loop = object() + h1 = asyncio.Handle(callback, (), loop=m_loop) self.assertRaises( - AssertionError, asyncio.Handle, h1, ()) + AssertionError, asyncio.Handle, h1, (), m_loop) - @unittest.mock.patch('asyncio.events.logger') - def test_callback_with_exception(self, log): + def test_callback_with_exception(self): def callback(): raise ValueError() - h = asyncio.Handle(callback, ()) + m_loop = unittest.mock.Mock() + m_loop.call_exception_handler = unittest.mock.Mock() + + h = asyncio.Handle(callback, (), m_loop) h._run() - self.assertTrue(log.exception.called) + + m_loop.call_exception_handler.assert_called_with({ + 'message': test_utils.MockPattern('Exception in callback.*'), + 'exception': unittest.mock.ANY, + 'handle': h + }) class TimerTests(unittest.TestCase): def test_hash(self): when = time.monotonic() - h = asyncio.TimerHandle(when, lambda: False, ()) + h = asyncio.TimerHandle(when, lambda: False, (), + unittest.mock.Mock()) self.assertEqual(hash(h), hash(when)) def test_timer(self): @@ -1838,7 +1847,7 @@ def callback(*args): args = () when = time.monotonic() - h = asyncio.TimerHandle(when, callback, args) + h = asyncio.TimerHandle(when, callback, args, unittest.mock.Mock()) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1853,16 +1862,19 @@ def callback(*args): self.assertTrue(r.endswith('())'), r) self.assertRaises(AssertionError, - asyncio.TimerHandle, None, callback, args) + asyncio.TimerHandle, None, callback, args, + unittest.mock.Mock()) def test_timer_comparison(self): + loop = unittest.mock.Mock() + def callback(*args): return args when = time.monotonic() - h1 = asyncio.TimerHandle(when, callback, ()) - h2 = asyncio.TimerHandle(when, callback, ()) + h1 = asyncio.TimerHandle(when, callback, (), loop) + h2 = asyncio.TimerHandle(when, callback, (), loop) # TODO: Use assertLess etc. self.assertFalse(h1 < h2) self.assertFalse(h2 < h1) @@ -1878,8 +1890,8 @@ def callback(*args): h2.cancel() self.assertFalse(h1 == h2) - h1 = asyncio.TimerHandle(when, callback, ()) - h2 = asyncio.TimerHandle(when + 10.0, callback, ()) + h1 = asyncio.TimerHandle(when, callback, (), loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), loop) self.assertTrue(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) @@ -1891,7 +1903,7 @@ def callback(*args): self.assertFalse(h1 == h2) self.assertTrue(h1 != h2) - h3 = asyncio.Handle(callback, ()) + h3 = asyncio.Handle(callback, (), loop) self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3)) diff --git a/tests/test_futures.py b/tests/test_futures.py index 8a6976b1..2e4dbd4a 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -174,20 +174,20 @@ def test(): self.assertRaises(AssertionError, test) fut.cancel() - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_abandoned(self, m_log): fut = asyncio.Future(loop=self.loop) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_result_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) @@ -195,7 +195,7 @@ def test_tb_logger_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -203,7 +203,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -211,7 +211,7 @@ def test_tb_logger_exception_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.futures.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 6bea1a33..816c9732 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -207,13 +207,13 @@ def test_close_buffer(self): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) - @unittest.mock.patch('asyncio.proactor_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_fatal_error(self, m_logging): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._force_close = unittest.mock.Mock() tr._fatal_error(None) self.assertTrue(tr._force_close.called) - self.assertTrue(m_logging.exception.called) + self.assertTrue(m_logging.error.called) def test_force_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -432,7 +432,7 @@ def test_write_to_self(self): def test_process_events(self): self.loop._process_events([]) - @unittest.mock.patch('asyncio.proactor_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_create_server(self, m_log): pf = unittest.mock.Mock() call_soon = self.loop.call_soon = unittest.mock.Mock() @@ -458,7 +458,7 @@ def test_create_server(self, m_log): fut.result.side_effect = OSError() loop(fut) self.assertTrue(self.sock.close.called) - self.assertTrue(m_log.exception.called) + self.assertTrue(m_log.error.called) def test_create_server_cancel(self): pf = unittest.mock.Mock() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 7741e191..04b05780 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -23,6 +23,9 @@ from asyncio.selector_events import _SelectorDatagramTransport +MOCK_ANY = unittest.mock.ANY + + class TestBaseSelectorEventLoop(BaseSelectorEventLoop): def _make_self_pipe(self): @@ -643,14 +646,18 @@ def test_force_close(self): self.assertFalse(self.loop.readers) self.assertEqual(1, self.loop.remove_reader_count[7]) - @unittest.mock.patch('asyncio.log.logger.exception') + @unittest.mock.patch('asyncio.log.logger.error') def test_fatal_error(self, m_exc): exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) tr._force_close = unittest.mock.Mock() tr._fatal_error(exc) - m_exc.assert_called_with('Fatal error for %s', tr) + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal transport error\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + tr._force_close.assert_called_with(exc) def test_connection_lost(self): @@ -996,7 +1003,7 @@ def test_write_ready_exception(self): transport._write_ready() transport._fatal_error.assert_called_with(err) - @unittest.mock.patch('asyncio.selector_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() remove_writer = self.loop.remove_writer = unittest.mock.Mock() @@ -1651,14 +1658,17 @@ def test_sendto_ready_error_received_connection(self): self.assertFalse(transport._fatal_error.called) self.assertTrue(self.protocol.error_received.called) - @unittest.mock.patch('asyncio.log.logger.exception') + @unittest.mock.patch('asyncio.base_events.logger.error') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) err = ConnectionRefusedError() transport._fatal_error(err) self.assertFalse(self.protocol.error_received.called) - m_exc.assert_called_with('Fatal error for %s', transport) + m_exc.assert_called_with( + test_utils.MockPattern( + 'Fatal transport error\nprotocol:.*\ntransport:.*'), + exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) if __name__ == '__main__': diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 2fa1db45..e9330797 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -25,6 +25,9 @@ from asyncio import unix_events +MOCK_ANY = unittest.mock.ANY + + @unittest.skipUnless(signal, 'Signals are not supported') class SelectorEventLoopSignalTests(unittest.TestCase): @@ -45,7 +48,8 @@ def test_handle_signal_no_handler(self): self.loop._handle_signal(signal.NSIG + 1, ()) def test_handle_signal_cancelled_handler(self): - h = asyncio.Handle(unittest.mock.Mock(), ()) + h = asyncio.Handle(unittest.mock.Mock(), (), + loop=unittest.mock.Mock()) h.cancel() self.loop._signal_handlers[signal.NSIG + 1] = h self.loop.remove_signal_handler = unittest.mock.Mock() @@ -91,7 +95,7 @@ class Err(OSError): signal.SIGINT, lambda: True) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_add_signal_handler_install_error2(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG @@ -108,7 +112,7 @@ class Err(OSError): self.assertEqual(1, m_signal.set_wakeup_fd.call_count) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_add_signal_handler_install_error3(self, m_logging, m_signal): class Err(OSError): errno = errno.EINVAL @@ -153,7 +157,7 @@ def test_remove_signal_handler_2(self, m_signal): m_signal.signal.call_args[0]) @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.unix_events.logger') + @unittest.mock.patch('asyncio.base_events.logger') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -347,7 +351,7 @@ def test__read_ready_blocked(self, m_read): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) - @unittest.mock.patch('asyncio.log.logger.exception') + @unittest.mock.patch('asyncio.log.logger.error') @unittest.mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport( @@ -359,7 +363,10 @@ def test__read_ready_error(self, m_read, m_logexc): m_read.assert_called_with(5, tr.max_size) tr._close.assert_called_with(err) - m_logexc.assert_called_with('Fatal error for %s', tr) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal transport error\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) @unittest.mock.patch('os.read') def test_pause_reading(self, m_read): @@ -423,7 +430,7 @@ def test__call_connection_lost(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(2, sys.getrefcount(self.loop), + self.assertEqual(4, sys.getrefcount(self.loop), pprint.pformat(gc.get_referrers(self.loop))) def test__call_connection_lost_with_err(self): @@ -436,10 +443,11 @@ def test__call_connection_lost_with_err(self): self.pipe.close.assert_called_with() self.assertIsNone(tr._protocol) + self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(2, sys.getrefcount(self.loop), + self.assertEqual(4, sys.getrefcount(self.loop), pprint.pformat(gc.get_referrers(self.loop))) @@ -635,7 +643,7 @@ def test__write_ready_empty(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('asyncio.log.logger.exception') + @unittest.mock.patch('asyncio.log.logger.error') @unittest.mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( @@ -650,7 +658,10 @@ def test__write_ready_err(self, m_write, m_logexc): self.assertFalse(self.loop.readers) self.assertEqual([], tr._buffer) self.assertTrue(tr._closing) - m_logexc.assert_called_with('Fatal error for %s', tr) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal transport error\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) self.assertEqual(1, tr._conn_lost) test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(err) @@ -702,7 +713,7 @@ def test__call_connection_lost(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(2, sys.getrefcount(self.loop), + self.assertEqual(4, sys.getrefcount(self.loop), pprint.pformat(gc.get_referrers(self.loop))) def test__call_connection_lost_with_err(self): @@ -718,7 +729,7 @@ def test__call_connection_lost_with_err(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(2, sys.getrefcount(self.loop), + self.assertEqual(4, sys.getrefcount(self.loop), pprint.pformat(gc.get_referrers(self.loop))) def test_close(self): @@ -1285,10 +1296,10 @@ def test_sigchld_unhandled_exception(self, m): m.waitpid.side_effect = ValueError with unittest.mock.patch.object(log.logger, - "exception") as m_exception: + 'error') as m_error: self.assertEqual(self.watcher._sig_chld(), None) - self.assertTrue(m_exception.called) + self.assertTrue(m_error.called) @waitpid_mocks def test_sigchld_child_reaped_elsewhere(self, m): From 83d1ab447b2ec595f0ca5e8ef7229b86211a39c0 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 18:39:14 -0500 Subject: [PATCH 0959/1502] transports: Make _ProactorBasePipeTransport use _FlowControlMixin --- asyncio/proactor_events.py | 55 ++-------------------------- asyncio/selector_events.py | 73 ++------------------------------------ asyncio/transports.py | 70 ++++++++++++++++++++++++++++++++++++ asyncio/unix_events.py | 2 +- 4 files changed, 75 insertions(+), 125 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index b2ac632f..d72f9274 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -15,7 +15,8 @@ from .log import logger -class _ProactorBasePipeTransport(transports.BaseTransport): +class _ProactorBasePipeTransport(transports._FlowControlMixin, + transports.BaseTransport): """Base class for pipe and socket transports.""" def __init__(self, loop, sock, protocol, waiter=None, @@ -33,8 +34,6 @@ def __init__(self, loop, sock, protocol, waiter=None, self._conn_lost = 0 self._closing = False # Set when close() called. self._eof_written = False - self._protocol_paused = False - self.set_write_buffer_limits() if self._server is not None: self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) @@ -94,56 +93,6 @@ def _call_connection_lost(self, exc): server.detach(self) self._server = None - # XXX The next four methods are nearly identical to corresponding - # ones in _SelectorTransport. Maybe refactor buffer management to - # share the implementations? (Also these are really only needed - # by _ProactorWritePipeTransport but since _buffer is defined on - # the base class I am putting it here for now.) - - def _maybe_pause_protocol(self): - size = self.get_write_buffer_size() - if size <= self._high_water: - return - if not self._protocol_paused: - self._protocol_paused = True - try: - self._protocol.pause_writing() - except Exception as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.pause_writing() failed', - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - - def _maybe_resume_protocol(self): - if (self._protocol_paused and - self.get_write_buffer_size() <= self._low_water): - self._protocol_paused = False - try: - self._protocol.resume_writing() - except Exception as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.resume_writing() failed', - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - - def set_write_buffer_limits(self, high=None, low=None): - if high is None: - if low is None: - high = 64*1024 - else: - high = 4*low - if low is None: - low = high // 4 - if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) - self._high_water = high - self._low_water = low - def get_write_buffer_size(self): size = self._pending_write if self._buffer is not None: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index fb86f824..869d66e0 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -338,77 +338,8 @@ def _stop_serving(self, sock): sock.close() -class _FlowControlMixin(transports.Transport): - """All the logic for (write) flow control in a mix-in base class. - - The subclass must implement get_write_buffer_size(). It must call - _maybe_pause_protocol() whenever the write buffer size increases, - and _maybe_resume_protocol() whenever it decreases. It may also - override set_write_buffer_limits() (e.g. to specify different - defaults). - - The subclass constructor must call super().__init__(extra). This - will call set_write_buffer_limits(). - - The user may call set_write_buffer_limits() and - get_write_buffer_size(), and their protocol's pause_writing() and - resume_writing() may be called. - """ - - def __init__(self, extra=None): - super().__init__(extra) - self._protocol_paused = False - self.set_write_buffer_limits() - - def _maybe_pause_protocol(self): - size = self.get_write_buffer_size() - if size <= self._high_water: - return - if not self._protocol_paused: - self._protocol_paused = True - try: - self._protocol.pause_writing() - except Exception as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.pause_writing() failed', - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - - def _maybe_resume_protocol(self): - if (self._protocol_paused and - self.get_write_buffer_size() <= self._low_water): - self._protocol_paused = False - try: - self._protocol.resume_writing() - except Exception as exc: - self._loop.call_exception_handler({ - 'message': 'protocol.resume_writing() failed', - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - - def set_write_buffer_limits(self, high=None, low=None): - if high is None: - if low is None: - high = 64*1024 - else: - high = 4*low - if low is None: - low = high // 4 - if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) - self._high_water = high - self._low_water = low - - def get_write_buffer_size(self): - raise NotImplementedError - - -class _SelectorTransport(_FlowControlMixin, transports.Transport): +class _SelectorTransport(transports._FlowControlMixin, + transports.Transport): max_size = 256 * 1024 # Buffer size passed to recv(). diff --git a/asyncio/transports.py b/asyncio/transports.py index 67ae7fda..5b975aa7 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -219,3 +219,73 @@ def kill(self): http://docs.python.org/3/library/subprocess#subprocess.Popen.kill """ raise NotImplementedError + + +class _FlowControlMixin(Transport): + """All the logic for (write) flow control in a mix-in base class. + + The subclass must implement get_write_buffer_size(). It must call + _maybe_pause_protocol() whenever the write buffer size increases, + and _maybe_resume_protocol() whenever it decreases. It may also + override set_write_buffer_limits() (e.g. to specify different + defaults). + + The subclass constructor must call super().__init__(extra). This + will call set_write_buffer_limits(). + + The user may call set_write_buffer_limits() and + get_write_buffer_size(), and their protocol's pause_writing() and + resume_writing() may be called. + """ + + def __init__(self, extra=None): + super().__init__(extra) + self._protocol_paused = False + self.set_write_buffer_limits() + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def get_write_buffer_size(self): + raise NotImplementedError diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 9a40c04d..748452c5 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -317,7 +317,7 @@ def _call_connection_lost(self, exc): self._loop = None -class _UnixWritePipeTransport(selector_events._FlowControlMixin, +class _UnixWritePipeTransport(transports._FlowControlMixin, transports.WriteTransport): def __init__(self, loop, pipe, protocol, waiter=None, extra=None): From 346ffb37bd1c8b3280e48b737c43f2194f557273 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 19 Feb 2014 01:39:27 +0100 Subject: [PATCH 0960/1502] Issue #139: Improve error messages on "fatal errors" Mention if the error was caused by a read or a write, and be more specific on the object (ex: "pipe transport" instead of "transport"). --- asyncio/proactor_events.py | 10 +++++----- asyncio/selector_events.py | 22 +++++++++++---------- asyncio/unix_events.py | 14 +++++++------- tests/test_proactor_events.py | 12 +++++++++--- tests/test_selector_events.py | 36 +++++++++++++++++++++++++---------- tests/test_unix_events.py | 8 +++++--- 6 files changed, 64 insertions(+), 38 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index d72f9274..f45cd9c6 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -53,10 +53,10 @@ def close(self): if self._read_fut is not None: self._read_fut.cancel() - def _fatal_error(self, exc): + def _fatal_error(self, exc, message='Fatal error on pipe transport'): if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): self._loop.call_exception_handler({ - 'message': 'Fatal transport error', + 'message': message, 'exception': exc, 'transport': self, 'protocol': self._protocol, @@ -151,11 +151,11 @@ def _loop_reading(self, fut=None): self._read_fut = self._loop._proactor.recv(self._sock, 4096) except ConnectionAbortedError as exc: if not self._closing: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on pipe transport') except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on pipe transport') except futures.CancelledError: if not self._closing: raise @@ -246,7 +246,7 @@ def _loop_writing(self, f=None, data=None): except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on pipe transport') def can_write_eof(self): return True diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 869d66e0..c142356f 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -377,11 +377,11 @@ def close(self): self._conn_lost += 1 self._loop.call_soon(self._call_connection_lost, None) - def _fatal_error(self, exc): + def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): self._loop.call_exception_handler({ - 'message': 'Fatal transport error', + 'message': message, 'exception': exc, 'transport': self, 'protocol': self._protocol, @@ -452,7 +452,7 @@ def _read_ready(self): except (BlockingIOError, InterruptedError): pass except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on socket transport') else: if data: self._protocol.data_received(data) @@ -488,7 +488,7 @@ def write(self, data): except (BlockingIOError, InterruptedError): pass except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on socket transport') return else: data = data[n:] @@ -511,7 +511,7 @@ def _write_ready(self): except Exception as exc: self._loop.remove_writer(self._sock_fd) self._buffer.clear() - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on socket transport') else: if n: del self._buffer[:n] @@ -678,7 +678,7 @@ def _read_ready(self): self._loop.remove_reader(self._sock_fd) self._loop.add_writer(self._sock_fd, self._write_ready) except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on SSL transport') else: if data: self._protocol.data_received(data) @@ -712,7 +712,7 @@ def _write_ready(self): except Exception as exc: self._loop.remove_writer(self._sock_fd) self._buffer.clear() - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on SSL transport') return if n: @@ -770,7 +770,7 @@ def _read_ready(self): except OSError as exc: self._protocol.error_received(exc) except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on datagram transport') else: self._protocol.datagram_received(data, addr) @@ -805,7 +805,8 @@ def sendto(self, data, addr=None): self._protocol.error_received(exc) return except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, + 'Fatal write error on datagram transport') return # Ensure that what we buffer is immutable. @@ -827,7 +828,8 @@ def _sendto_ready(self): self._protocol.error_received(exc) return except Exception as exc: - self._fatal_error(exc) + self._fatal_error(exc, + 'Fatal write error on datagram transport') return self._maybe_resume_protocol() # May append to buffer. diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 748452c5..3a2fd18b 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -271,7 +271,7 @@ def _read_ready(self): except (BlockingIOError, InterruptedError): pass except OSError as exc: - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal read error on pipe transport') else: if data: self._protocol.data_received(data) @@ -291,11 +291,11 @@ def close(self): if not self._closing: self._close(None) - def _fatal_error(self, exc): + def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only if not (isinstance(exc, OSError) and exc.errno == errno.EIO): self._loop.call_exception_handler({ - 'message': 'Fatal transport error', + 'message': message, 'exception': exc, 'transport': self, 'protocol': self._protocol, @@ -381,7 +381,7 @@ def write(self, data): n = 0 except Exception as exc: self._conn_lost += 1 - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on pipe transport') return if n == len(data): return @@ -406,7 +406,7 @@ def _write_ready(self): # Remove writer here, _fatal_error() doesn't it # because _buffer is empty. self._loop.remove_writer(self._fileno) - self._fatal_error(exc) + self._fatal_error(exc, 'Fatal write error on pipe transport') else: if n == len(data): self._loop.remove_writer(self._fileno) @@ -443,11 +443,11 @@ def close(self): def abort(self): self._close(None) - def _fatal_error(self, exc): + def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): self._loop.call_exception_handler({ - 'message': 'Fatal transport error', + 'message': message, 'exception': exc, 'transport': self, 'protocol': self._protocol, diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 816c9732..08920690 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -69,7 +69,9 @@ def test_loop_reading_aborted(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr._loop_reading() - tr._fatal_error.assert_called_with(err) + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') def test_loop_reading_aborted_closing(self): self.loop._proactor.recv.side_effect = ConnectionAbortedError() @@ -105,7 +107,9 @@ def test_loop_reading_exception(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._fatal_error = unittest.mock.Mock() tr._loop_reading() - tr._fatal_error.assert_called_with(err) + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') def test_write(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) @@ -142,7 +146,9 @@ def test_loop_writing_err(self, m_log): tr._fatal_error = unittest.mock.Mock() tr._buffer = [b'da', b'ta'] tr._loop_writing() - tr._fatal_error.assert_called_with(err) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') tr._conn_lost = 1 tr.write(b'data') diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 04b05780..247df9e0 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -655,7 +655,7 @@ def test_fatal_error(self, m_exc): m_exc.assert_called_with( test_utils.MockPattern( - 'Fatal transport error\nprotocol:.*\ntransport:.*'), + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) tr._force_close.assert_called_with(exc) @@ -785,7 +785,9 @@ def test_read_ready_err(self, m_exc): transport._fatal_error = unittest.mock.Mock() transport._read_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on socket transport') def test_write(self): data = b'data' @@ -898,7 +900,9 @@ def test_write_exception(self, m_log): self.loop, self.sock, self.protocol) transport._fatal_error = unittest.mock.Mock() transport.write(data) - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') transport._conn_lost = 1 self.sock.reset_mock() @@ -1001,7 +1005,9 @@ def test_write_ready_exception(self): transport._fatal_error = unittest.mock.Mock() transport._buffer.extend(b'data') transport._write_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on socket transport') @unittest.mock.patch('asyncio.base_events.logger') def test_write_ready_exception_and_close(self, m_log): @@ -1237,7 +1243,9 @@ def test_read_ready_recv_exc(self): transport = self._make_one() transport._fatal_error = unittest.mock.Mock() transport._read_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on SSL transport') def test_write_ready_send(self): self.sslsock.send.return_value = 4 @@ -1319,7 +1327,9 @@ def test_write_ready_send_exc(self): transport._buffer = list_to_buffer([b'data']) transport._fatal_error = unittest.mock.Mock() transport._write_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on SSL transport') self.assertEqual(list_to_buffer(), transport._buffer) def test_write_ready_read_wants_write(self): @@ -1407,7 +1417,9 @@ def test_read_ready_err(self): transport._fatal_error = unittest.mock.Mock() transport._read_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal read error on datagram transport') def test_read_ready_oserr(self): transport = _SelectorDatagramTransport( @@ -1517,7 +1529,9 @@ def test_sendto_exception(self, m_log): transport.sendto(data, ()) self.assertTrue(transport._fatal_error.called) - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') transport._conn_lost = 1 transport._address = ('123',) @@ -1633,7 +1647,9 @@ def test_sendto_ready_exception(self): transport._buffer.append((b'data', ())) transport._sendto_ready() - transport._fatal_error.assert_called_with(err) + transport._fatal_error.assert_called_with( + err, + 'Fatal write error on datagram transport') def test_sendto_ready_error_received(self): self.sock.sendto.side_effect = ConnectionRefusedError @@ -1667,7 +1683,7 @@ def test_fatal_error_connected(self, m_exc): self.assertFalse(self.protocol.error_received.called) m_exc.assert_called_with( test_utils.MockPattern( - 'Fatal transport error\nprotocol:.*\ntransport:.*'), + 'Fatal error on transport\nprotocol:.*\ntransport:.*'), exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index e9330797..9866e33a 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -365,7 +365,7 @@ def test__read_ready_error(self, m_read, m_logexc): tr._close.assert_called_with(err) m_logexc.assert_called_with( test_utils.MockPattern( - 'Fatal transport error\nprotocol:.*\ntransport:.*'), + 'Fatal read error on pipe transport\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) @unittest.mock.patch('os.read') @@ -558,7 +558,9 @@ def test_write_err(self, m_write, m_log): m_write.assert_called_with(5, b'data') self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) - tr._fatal_error.assert_called_with(err) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') self.assertEqual(1, tr._conn_lost) tr.write(b'data') @@ -660,7 +662,7 @@ def test__write_ready_err(self, m_write, m_logexc): self.assertTrue(tr._closing) m_logexc.assert_called_with( test_utils.MockPattern( - 'Fatal transport error\nprotocol:.*\ntransport:.*'), + 'Fatal write error on pipe transport\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) self.assertEqual(1, tr._conn_lost) test_utils.run_briefly(self.loop) From 81f1d9965ad1639f882ae73e7c5115667bdcf4a3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 19 Feb 2014 01:44:23 +0100 Subject: [PATCH 0961/1502] Issue #143: UNIX domain methods, fix ResourceWarning and DeprecationWarning warnings. create_unix_server() closes the socket on any error, not only on OSError. --- asyncio/unix_events.py | 8 ++++---- tests/test_unix_events.py | 28 +++++++++++++++------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 3a2fd18b..faf4c60d 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -183,13 +183,12 @@ def create_unix_connection(self, protocol_factory, path, *, raise ValueError( 'path and sock can not be specified at the same time') + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) try: - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) sock.setblocking(False) yield from self.sock_connect(sock, path) - except OSError: - if sock is not None: - sock.close() + except: + sock.close() raise else: @@ -213,6 +212,7 @@ def create_unix_server(self, protocol_factory, path=None, *, try: sock.bind(path) except OSError as exc: + sock.close() if exc.errno == errno.EADDRINUSE: # Let's improve the error message by adding # with what exact address it occurs. diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 9866e33a..7b5196c8 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -221,17 +221,17 @@ def test_create_unix_server_existing_path_sock(self): with test_utils.unix_socket_path() as path: sock = socket.socket(socket.AF_UNIX) sock.bind(path) - - coro = self.loop.create_unix_server(lambda: None, path) - with self.assertRaisesRegexp(OSError, - 'Address.*is already in use'): - self.loop.run_until_complete(coro) + with sock: + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) def test_create_unix_server_existing_path_nonsock(self): with tempfile.NamedTemporaryFile() as file: coro = self.loop.create_unix_server(lambda: None, file.name) - with self.assertRaisesRegexp(OSError, - 'Address.*is already in use'): + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): self.loop.run_until_complete(coro) def test_create_unix_server_ssl_bool(self): @@ -248,11 +248,13 @@ def test_create_unix_server_nopath_nosock(self): self.loop.run_until_complete(coro) def test_create_unix_server_path_inetsock(self): - coro = self.loop.create_unix_server(lambda: None, path=None, - sock=socket.socket()) - with self.assertRaisesRegex(ValueError, - 'A UNIX Domain Socket was expected'): - self.loop.run_until_complete(coro) + sock = socket.socket() + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) def test_create_unix_connection_path_sock(self): coro = self.loop.create_unix_connection( @@ -278,7 +280,7 @@ def test_create_unix_connection_ssl_noserverhost(self): coro = self.loop.create_unix_connection( lambda: None, '/dev/null', ssl=True) - with self.assertRaisesRegexp( + with self.assertRaisesRegex( ValueError, 'you have to pass server_hostname when using ssl'): self.loop.run_until_complete(coro) From 93132bc49771ed3a76e0d10a0b2625405a7f30e4 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 22:25:37 -0500 Subject: [PATCH 0962/1502] Fix spelling & typos --- asyncio/events.py | 6 +++--- asyncio/protocols.py | 2 +- asyncio/selector_events.py | 2 +- asyncio/selectors.py | 2 +- asyncio/tasks.py | 2 +- asyncio/test_utils.py | 4 ++-- asyncio/unix_events.py | 4 ++-- asyncio/windows_events.py | 2 +- tests/test_base_events.py | 2 +- tests/test_futures.py | 2 +- tests/test_streams.py | 2 +- tests/test_unix_events.py | 6 +++--- 12 files changed, 18 insertions(+), 18 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index f61c5b74..1030c045 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -235,7 +235,7 @@ def create_unix_server(self, protocol_factory, path, *, sock=None, backlog=100, ssl=None): """A coroutine which creates a UNIX Domain Socket server. - The return valud is a Server object, which can be used to stop + The return value is a Server object, which can be used to stop the service. path is a str, representing a file systsem path to bind the @@ -260,7 +260,7 @@ def create_datagram_endpoint(self, protocol_factory, # Pipes and subprocesses. def connect_read_pipe(self, protocol_factory, pipe): - """Register read pipe in eventloop. + """Register read pipe in event loop. protocol_factory should instantiate object with Protocol interface. pipe is file-like object already switched to nonblocking. @@ -273,7 +273,7 @@ def connect_read_pipe(self, protocol_factory, pipe): raise NotImplementedError def connect_write_pipe(self, protocol_factory, pipe): - """Register write pipe in eventloop. + """Register write pipe in event loop. protocol_factory should instantiate object with BaseProtocol interface. Pipe is file-like object already switched to nonblocking. diff --git a/asyncio/protocols.py b/asyncio/protocols.py index 3c4f3f4a..52fc25c2 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -114,7 +114,7 @@ class SubprocessProtocol(BaseProtocol): def pipe_data_received(self, fd, data): """Called when the subprocess writes data into stdout/stderr pipe. - fd is int file dascriptor. + fd is int file descriptor. data is bytes object. """ diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index c142356f..aa427459 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -213,7 +213,7 @@ def sock_recv(self, sock, n): def _sock_recv(self, fut, registered, sock, n): # _sock_recv() can add itself as an I/O callback if the operation can't - # be done immediatly. Don't use it directly, call sock_recv(). + # be done immediately. Don't use it directly, call sock_recv(). fd = sock.fileno() if registered: # Remove the callback early. It should be rare that the diff --git a/asyncio/selectors.py b/asyncio/selectors.py index bb2a45a8..a5465e24 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -80,7 +80,7 @@ class BaseSelector(metaclass=ABCMeta): A selector can use various implementations (select(), poll(), epoll()...) depending on the platform. The default `Selector` class uses the most - performant implementation on the current platform. + efficient implementation on the current platform. """ @abstractmethod diff --git a/asyncio/tasks.py b/asyncio/tasks.py index b7ee758d..a3e7cdf1 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -181,7 +181,7 @@ def get_stack(self, *, limit=None): The frames are always ordered from oldest to newest. - The optional limit gives the maximum nummber of frames to + The optional limit gives the maximum number of frames to return; by default all available frames are returned. Its meaning differs depending on whether a stack or a traceback is returned: the newest frames of a stack are returned, but the diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 28e52430..2a8a241f 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -259,7 +259,7 @@ def gen(): when = yield ... ... = yield time_advance - Value retuned by yield is absolute time of next scheduled handler. + Value returned by yield is absolute time of next scheduled handler. Value passed to yield is time advance to move loop's time forward. """ @@ -369,7 +369,7 @@ class MockPattern(str): """A regex based str with a fuzzy __eq__. Use this helper with 'mock.assert_called_with', or anywhere - where a regexp comparison between strings is needed. + where a regex comparison between strings is needed. For instance: mock_call.assert_called_with(MockPattern('spam.*ham')) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index faf4c60d..ce45e5ff 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -1,4 +1,4 @@ -"""Selector eventloop for Unix with signal handling.""" +"""Selector event loop for Unix with signal handling.""" import errno import fcntl @@ -244,7 +244,7 @@ def _set_nonblocking(fd): class _UnixReadPipeTransport(transports.ReadTransport): - max_size = 256 * 1024 # max bytes we read in one eventloop iteration + max_size = 256 * 1024 # max bytes we read in one event loop iteration def __init__(self, loop, pipe, protocol, waiter=None, extra=None): super().__init__(extra) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index c667a1c3..e6be9d13 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -1,4 +1,4 @@ -"""Selector and proactor eventloops for Windows.""" +"""Selector and proactor event loops for Windows.""" import _winapi import errno diff --git a/tests/test_base_events.py b/tests/test_base_events.py index f664cccf..2eee3be3 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -277,7 +277,7 @@ def test_subprocess_exec_invalid_args(self): asyncio.SubprocessProtocol, *args, bufsize=4096) def test_subprocess_shell_invalid_args(self): - # exepected a string, not an int or a list + # expected a string, not an int or a list self.assertRaises(TypeError, self.loop.run_until_complete, self.loop.subprocess_shell, asyncio.SubprocessProtocol, 123) diff --git a/tests/test_futures.py b/tests/test_futures.py index 2e4dbd4a..f2b81ddd 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -38,7 +38,7 @@ def test_init_constructor_default_loop(self): asyncio.set_event_loop(None) def test_constructor_positional(self): - # Make sure Future does't accept a positional argument + # Make sure Future doesn't accept a positional argument self.assertRaises(TypeError, asyncio.Future, 42) def test_cancel(self): diff --git a/tests/test_streams.py b/tests/test_streams.py index 31e26a64..ca792f20 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -239,7 +239,7 @@ def test_readline_limit_with_existing_data(self): # No b'\n' at the end. The 'limit' is set to 3. So before # waiting for the new data in buffer, 'readline' will consume # the entire buffer, and since the length of the consumed data - # is more than 3, it will raise a ValudError. The buffer is + # is more than 3, it will raise a ValueError. The buffer is # expected to be empty now. self.assertEqual(b'', stream._buffer) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 7b5196c8..c0f205e5 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -965,7 +965,7 @@ def test_sigchld_two_children(self, m): self.assertFalse(m.WEXITSTATUS.called) self.assertFalse(m.WTERMSIG.called) - # childen are running + # children are running self.watcher._sig_chld() self.assertFalse(callback1.called) @@ -1069,7 +1069,7 @@ def test_sigchld_two_children_terminating_together(self, m): self.assertFalse(m.WEXITSTATUS.called) self.assertFalse(m.WTERMSIG.called) - # childen are running + # children are running self.watcher._sig_chld() self.assertFalse(callback1.called) @@ -1425,7 +1425,7 @@ def test_set_loop_race_condition(self, m): self.add_zombie(61, 11) self.add_zombie(62, -5) - # SIGCHLD was not catched + # SIGCHLD was not caught self.assertFalse(callback1.called) self.assertFalse(callback2.called) self.assertFalse(callback3.called) From 48d3fe3655cce9903854685be20dccfc266171cf Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 18 Feb 2014 22:54:14 -0500 Subject: [PATCH 0963/1502] pep8-ify the code. --- asyncio/base_events.py | 11 ++++++----- asyncio/subprocess.py | 5 +++-- examples/subprocess_shell.py | 3 ++- tests/test_subprocess.py | 12 ++++++++++-- tests/test_tasks.py | 4 +++- tests/test_unix_events.py | 6 ++++-- 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index cb2499d2..b94ba079 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -605,10 +605,10 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, return transport, protocol @tasks.coroutine - def subprocess_exec(self, protocol_factory, program, *args, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=False, shell=False, bufsize=0, - **kwargs): + def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, **kwargs): if universal_newlines: raise ValueError("universal_newlines must be False") if shell: @@ -623,7 +623,8 @@ def subprocess_exec(self, protocol_factory, program, *args, stdin=subprocess.PIP % type(arg).__name__) protocol = protocol_factory() transport = yield from self._make_subprocess_transport( - protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) + protocol, popen_args, False, stdin, stdout, stderr, + bufsize, **kwargs) return transport, protocol def set_exception_handler(self, handler): diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 8d1a4073..c3b01755 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -180,8 +180,9 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, return Process(transport, protocol, loop) @tasks.coroutine -def create_subprocess_exec(program, *args, stdin=None, stdout=None, stderr=None, - loop=None, limit=streams._DEFAULT_LIMIT, **kwds): +def create_subprocess_exec(program, *args, stdin=None, stdout=None, + stderr=None, loop=None, + limit=streams._DEFAULT_LIMIT, **kwds): if loop is None: loop = events.get_event_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py index d0e5d65a..ca871540 100644 --- a/examples/subprocess_shell.py +++ b/examples/subprocess_shell.py @@ -78,7 +78,8 @@ def main(): asyncio.set_event_loop(loop) else: loop = asyncio.get_event_loop() - loop.run_until_complete(start('sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) + loop.run_until_complete(start( + 'sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) if __name__ == '__main__': diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 1b2f05be..14fd17e6 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -21,6 +21,7 @@ 'sys.stdout.buffer.write(data)'))] class SubprocessMixin: + def test_stdin_stdout(self): args = PROGRAM_CAT @@ -132,6 +133,7 @@ def test_broken_pipe(self): if sys.platform != 'win32': # Unix class SubprocessWatcherMixin(SubprocessMixin): + Watcher = None def setUp(self): @@ -151,14 +153,20 @@ def tearDown(self): self.loop.close() policy.set_event_loop(None) - class SubprocessSafeWatcherTests(SubprocessWatcherMixin, unittest.TestCase): + class SubprocessSafeWatcherTests(SubprocessWatcherMixin, + unittest.TestCase): + Watcher = unix_events.SafeChildWatcher - class SubprocessFastWatcherTests(SubprocessWatcherMixin, unittest.TestCase): + class SubprocessFastWatcherTests(SubprocessWatcherMixin, + unittest.TestCase): + Watcher = unix_events.FastChildWatcher + else: # Windows class SubprocessProactorTests(SubprocessMixin, unittest.TestCase): + def setUp(self): policy = asyncio.get_event_loop_policy() self.loop = asyncio.ProactorEventLoop() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 024dd2ea..f27b9522 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -876,6 +876,7 @@ def gen(): self.assertEqual(set(f.result() for f in done), {'a', 'b'}) def test_as_completed_duplicate_coroutines(self): + @asyncio.coroutine def coro(s): return s @@ -884,7 +885,8 @@ def coro(s): def runner(): result = [] c = coro('ham') - for f in asyncio.as_completed([c, c, coro('spam')], loop=self.loop): + for f in asyncio.as_completed([c, c, coro('spam')], + loop=self.loop): result.append((yield from f)) return result diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index c0f205e5..9e489c2d 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -367,7 +367,8 @@ def test__read_ready_error(self, m_read, m_logexc): tr._close.assert_called_with(err) m_logexc.assert_called_with( test_utils.MockPattern( - 'Fatal read error on pipe transport\nprotocol:.*\ntransport:.*'), + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) @unittest.mock.patch('os.read') @@ -664,7 +665,8 @@ def test__write_ready_err(self, m_write, m_logexc): self.assertTrue(tr._closing) m_logexc.assert_called_with( test_utils.MockPattern( - 'Fatal write error on pipe transport\nprotocol:.*\ntransport:.*'), + 'Fatal write error on pipe transport' + '\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) self.assertEqual(1, tr._conn_lost) test_utils.run_briefly(self.loop) From df5bb4d69002552e1ff49c44e84ff5038580fabe Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 19 Feb 2014 11:09:02 -0500 Subject: [PATCH 0964/1502] WriteTransport.set_write_buffer_size to call _maybe_pause_protocol (closes issue #140) --- asyncio/transports.py | 8 ++++++-- tests/test_transports.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/asyncio/transports.py b/asyncio/transports.py index 5b975aa7..5f674f99 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -241,7 +241,7 @@ class _FlowControlMixin(Transport): def __init__(self, extra=None): super().__init__(extra) self._protocol_paused = False - self.set_write_buffer_limits() + self._set_write_buffer_limits() def _maybe_pause_protocol(self): size = self.get_write_buffer_size() @@ -273,7 +273,7 @@ def _maybe_resume_protocol(self): 'protocol': self._protocol, }) - def set_write_buffer_limits(self, high=None, low=None): + def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: high = 64*1024 @@ -287,5 +287,9 @@ def set_write_buffer_limits(self, high=None, low=None): self._high_water = high self._low_water = low + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + def get_write_buffer_size(self): raise NotImplementedError diff --git a/tests/test_transports.py b/tests/test_transports.py index d16db807..4c645268 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -4,6 +4,7 @@ import unittest.mock import asyncio +from asyncio import transports class TransportTests(unittest.TestCase): @@ -60,6 +61,28 @@ def test_subprocess_transport_not_implemented(self): self.assertRaises(NotImplementedError, transport.terminate) self.assertRaises(NotImplementedError, transport.kill) + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + transport = MyTransport() + transport._protocol = unittest.mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + if __name__ == '__main__': unittest.main() From 9b4bbf7641c04ffb9daf34aae299adc524962e12 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 19 Feb 2014 23:04:13 +0100 Subject: [PATCH 0965/1502] Fix tests on UNIX sockets on Mac OS X 10.4 (Tiger): don't test the sockname extra info, getsockname() has a bug on this old OS X kernel. --- tests/test_events.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index a0a4d02c..3720cc75 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -4,6 +4,7 @@ import gc import io import os +import platform import signal import socket try: @@ -40,6 +41,15 @@ def data_file(filename): raise FileNotFoundError(filename) +def osx_tiger(): + """Return True if the platform is Mac OS 10.4 or older.""" + if sys.platform != 'darwin': + return False + version = platform.mac_ver()[0] + version = tuple(map(int, version.split('.'))) + return version < (10, 5) + + ONLYCERT = data_file('ssl_cert.pem') ONLYKEY = data_file('ssl_key.pem') SIGNED_CERTFILE = data_file('keycert3.pem') @@ -499,10 +509,12 @@ def my_handler(*args): self.loop.run_forever() self.assertEqual(caught, 1) - def _basetest_create_connection(self, connection_fut): + def _basetest_create_connection(self, connection_fut, check_sockname=True): tr, pr = self.loop.run_until_complete(connection_fut) self.assertIsInstance(tr, asyncio.Transport) self.assertIsInstance(pr, asyncio.Protocol) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -515,10 +527,14 @@ def test_create_connection(self): @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + with test_utils.run_test_unix_server() as httpd: conn_fut = self.loop.create_unix_connection( lambda: MyProto(loop=self.loop), httpd.address) - self._basetest_create_connection(conn_fut) + self._basetest_create_connection(conn_fut, check_sockname) def test_create_connection_sock(self): with test_utils.run_test_server() as httpd: @@ -548,12 +564,14 @@ def test_create_connection_sock(self): self.assertGreater(pr.nbytes, 0) tr.close() - def _basetest_create_ssl_connection(self, connection_fut): + def _basetest_create_ssl_connection(self, connection_fut, + check_sockname=True): tr, pr = self.loop.run_until_complete(connection_fut) self.assertIsInstance(tr, asyncio.Transport) self.assertIsInstance(pr, asyncio.Protocol) self.assertTrue('ssl' in tr.__class__.__name__.lower()) - self.assertIsNotNone(tr.get_extra_info('sockname')) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) self.assertGreater(pr.nbytes, 0) tr.close() @@ -571,6 +589,10 @@ def test_create_ssl_connection(self): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_ssl_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + with test_utils.run_test_unix_server(use_ssl=True) as httpd: conn_fut = self.loop.create_unix_connection( lambda: MyProto(loop=self.loop), @@ -578,7 +600,7 @@ def test_create_ssl_unix_connection(self): ssl=test_utils.dummy_ssl_context(), server_hostname='127.0.0.1') - self._basetest_create_ssl_connection(conn_fut) + self._basetest_create_ssl_connection(conn_fut, check_sockname) def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: From 8228f011fafd8f3445a6c3a1974f1f5f2de0d99c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 19 Feb 2014 23:07:08 +0100 Subject: [PATCH 0966/1502] Issue #136: Add get/set_debug() methods to BaseEventLoopTests. Add also a PYTHONASYNCIODEBUG environment variable to debug coroutines since Python startup, to be able to debug coroutines defined directly in the asyncio module. --- asyncio/base_events.py | 7 +++++++ asyncio/events.py | 8 ++++++++ asyncio/tasks.py | 5 ++++- tests/test_base_events.py | 6 ++++++ tests/test_tasks.py | 28 ++++++++++++++++++++++++++++ 5 files changed, 53 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b94ba079..69caa4d7 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -123,6 +123,7 @@ def __init__(self): self._running = False self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None + self._debug = False def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): @@ -795,3 +796,9 @@ def _run_once(self): if not handle._cancelled: handle._run() handle = None # Needed to break cycles when an exception occurs. + + def get_debug(self): + return self._debug + + def set_debug(self, enabled): + self._debug = enabled diff --git a/asyncio/events.py b/asyncio/events.py index 1030c045..5362f056 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -345,6 +345,14 @@ def default_exception_handler(self, context): def call_exception_handler(self, context): raise NotImplementedError + # Debug flag management. + + def get_debug(self): + raise NotImplementedError + + def set_debug(self, enabled): + raise NotImplementedError + class AbstractEventLoopPolicy: """Abstract policy for accessing the event loop.""" diff --git a/asyncio/tasks.py b/asyncio/tasks.py index a3e7cdf1..cf7b5400 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -12,6 +12,8 @@ import functools import inspect import linecache +import os +import sys import traceback import weakref @@ -28,7 +30,8 @@ # before you define your coroutines. A downside of using this feature # is that tracebacks show entries for the CoroWrapper.__next__ method # when _DEBUG is true. -_DEBUG = False +_DEBUG = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) class CoroWrapper: diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 2eee3be3..784a39f8 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -197,6 +197,12 @@ def test__run_once(self): self.assertEqual([h2], self.loop._scheduled) self.assertTrue(self.loop._process_events.called) + def test_set_debug(self): + self.loop.set_debug(True) + self.assertTrue(self.loop.get_debug()) + self.loop.set_debug(False) + self.assertFalse(self.loop.get_debug()) + @unittest.mock.patch('asyncio.base_events.time') @unittest.mock.patch('asyncio.base_events.logger') def test__run_once_logging(self, m_logger, m_time): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index f27b9522..6d03dc78 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,7 +1,9 @@ """Tests for tasks.py.""" import gc +import os.path import unittest +from test.script_helper import assert_python_ok import asyncio from asyncio import test_utils @@ -1461,6 +1463,32 @@ def test_return_exceptions(self): cb.assert_called_once_with(fut) self.assertEqual(fut.result(), [3, 1, exc, exc2]) + def test_env_var_debug(self): + path = os.path.dirname(asyncio.__file__) + path = os.path.normpath(os.path.join(path, '..')) + code = '\n'.join(( + 'import sys', + 'sys.path.insert(0, %r)' % path, + 'import asyncio.tasks', + 'print(asyncio.tasks._DEBUG)')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + class FutureGatherTests(GatherTestsBase, unittest.TestCase): From 00c6c1bd65b6deb9d97cd815cd7dc93e0627d023 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 10:05:40 +0100 Subject: [PATCH 0967/1502] asyncio.subprocess: Fix a race condition in communicate() Use self._loop instead of self._transport._loop, because transport._loop is set to None at process exit. --- asyncio/subprocess.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index c3b01755..414e0238 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -146,7 +146,6 @@ def _read_stream(self, fd): @tasks.coroutine def communicate(self, input=None): - loop = self._transport._loop if input: stdin = self._feed_stdin(input) else: @@ -160,7 +159,7 @@ def communicate(self, input=None): else: stderr = self._noop() stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, - loop=loop) + loop=self._loop) yield from self.wait() return (stdout, stderr) From 004c727f9d5d49760eeebf240a6734b1acf5c65d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 10:32:40 +0100 Subject: [PATCH 0968/1502] asyncio: Fix _ProactorWritePipeTransport._pipe_closed() The "exc" variable was not defined, pass a BrokenPipeError exception instead. --- asyncio/proactor_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index f45cd9c6..d99e8ce7 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -275,7 +275,7 @@ def _pipe_closed(self, fut): assert fut is self._read_fut, (fut, self._read_fut) self._read_fut = None if self._write_fut is not None: - self._force_close(exc) + self._force_close(BrokenPipeError()) else: self.close() From 75714928f735c1dd686274d55e257b5d192f030b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 10:37:17 +0100 Subject: [PATCH 0969/1502] asyncio: remove unused imports and unused variables noticed by pyflakes --- asyncio/events.py | 3 --- asyncio/futures.py | 1 - asyncio/selectors.py | 3 +-- asyncio/tasks.py | 2 -- asyncio/test_utils.py | 2 +- asyncio/unix_events.py | 1 - asyncio/windows_events.py | 1 - 7 files changed, 2 insertions(+), 11 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 5362f056..57af68af 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -9,12 +9,9 @@ ] import subprocess -import sys import threading import socket -from .log import logger - class Handle: """Object returned by callback registration methods.""" diff --git a/asyncio/futures.py b/asyncio/futures.py index b9cd45c7..91ea1706 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -11,7 +11,6 @@ import traceback from . import events -from .log import logger # States for Future. _PENDING = 'PENDING' diff --git a/asyncio/selectors.py b/asyncio/selectors.py index a5465e24..9be92255 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -5,9 +5,8 @@ """ -from abc import ABCMeta, abstractmethod, abstractproperty +from abc import ABCMeta, abstractmethod from collections import namedtuple, Mapping -import functools import math import select import sys diff --git a/asyncio/tasks.py b/asyncio/tasks.py index cf7b5400..19fa654e 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -7,7 +7,6 @@ 'gather', 'shield', ] -import collections import concurrent.futures import functools import inspect @@ -486,7 +485,6 @@ def as_completed(fs, *, loop=None, timeout=None): if isinstance(fs, futures.Future) or iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() - deadline = None if timeout is None else loop.time() + timeout todo = {async(f, loop=loop) for f in set(fs)} from .queues import Queue # Import here to avoid circular import problem. done = Queue(loop=loop) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 2a8a241f..dd87789f 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -15,7 +15,7 @@ import unittest.mock from http.server import HTTPServer -from wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer +from wsgiref.simple_server import WSGIRequestHandler, WSGIServer try: import ssl diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index ce45e5ff..21255480 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -15,7 +15,6 @@ from . import base_subprocess from . import constants from . import events -from . import protocols from . import selector_events from . import tasks from . import transports diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index e6be9d13..60fb5896 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -5,7 +5,6 @@ import math import socket import struct -import subprocess import weakref from . import events From 2b8709661b2deedd52df9c6be55b431c14cd2fc2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 16:41:16 +0100 Subject: [PATCH 0970/1502] Fix _check_resolved_address() for IPv6 address --- asyncio/base_events.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 69caa4d7..1615ecbf 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -45,10 +45,13 @@ def _check_resolved_address(sock, address): # Ensure that the address is already resolved to avoid the trap of hanging # the entire event loop when the address requires doing a DNS lookup. family = sock.family - if family not in (socket.AF_INET, socket.AF_INET6): + if family == socket.AF_INET: + host, port = address + elif family == socket.AF_INET6: + host, port, flow_info, scope_id = address + else: return - host, port = address type_mask = 0 if hasattr(socket, 'SOCK_NONBLOCK'): type_mask |= socket.SOCK_NONBLOCK From 2898eb91ad8103f9190b9ccdd4d39bf4d23cd5c3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 16:51:44 +0100 Subject: [PATCH 0971/1502] Oops, and now fix also the unit test for IPv6 address: test_sock_connect_address() --- tests/test_events.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 3720cc75..f8499dc1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1335,12 +1335,11 @@ def wait(): 'selector': self.loop._selector.__class__.__name__}) def test_sock_connect_address(self): - families = [socket.AF_INET] + families = [(socket.AF_INET, ('www.python.org', 80))] if support.IPV6_ENABLED: - families.append(socket.AF_INET6) + families.append((socket.AF_INET6, ('www.python.org', 80, 0, 0))) - address = ('www.python.org', 80) - for family in families: + for family, address in families: for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): sock = socket.socket(family, sock_type) with sock: From 228cf17b63b2b5e85d97c95e40cdeb0e16717ed0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 17:12:44 +0100 Subject: [PATCH 0972/1502] Remove debug code --- examples/shell.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/shell.py b/examples/shell.py index e094b613..8ae30ca9 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -1,5 +1,4 @@ """Examples using create_subprocess_exec() and create_subprocess_shell().""" -import logging; logging.basicConfig() import asyncio import signal From 557d51585c800ead73013f7426885c0b285589cb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Feb 2014 21:59:02 +0100 Subject: [PATCH 0973/1502] _check_resolved_address() must also accept IPv6 without flow_info and scope_id: (host, port). --- asyncio/base_events.py | 2 +- tests/test_events.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 1615ecbf..80df9271 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -48,7 +48,7 @@ def _check_resolved_address(sock, address): if family == socket.AF_INET: host, port = address elif family == socket.AF_INET6: - host, port, flow_info, scope_id = address + host, port = address[:2] else: return diff --git a/tests/test_events.py b/tests/test_events.py index f8499dc1..d00af23d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1335,11 +1335,14 @@ def wait(): 'selector': self.loop._selector.__class__.__name__}) def test_sock_connect_address(self): - families = [(socket.AF_INET, ('www.python.org', 80))] + addresses = [(socket.AF_INET, ('www.python.org', 80))] if support.IPV6_ENABLED: - families.append((socket.AF_INET6, ('www.python.org', 80, 0, 0))) + addresses.extend(( + (socket.AF_INET6, ('www.python.org', 80)), + (socket.AF_INET6, ('www.python.org', 80, 0, 0)), + )) - for family, address in families: + for family, address in addresses: for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): sock = socket.socket(family, sock_type) with sock: From 3783143135d1c33b8ca75149441441165b4f069b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 24 Feb 2014 10:18:42 -0800 Subject: [PATCH 0974/1502] Release asyncio version 0.4.1, identical to CPython 3.4.0rc2. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 0d92df70..6cce4a7b 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="asyncio", - version="0.3.1", + version="0.4.1", description="reference implementation of PEP 3156", long_description=open("README").read(), From 4783d6b83583d1d83f1132df6a0291b0434f5047 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 24 Feb 2014 10:18:49 -0800 Subject: [PATCH 0975/1502] Added tag 0.4.1 for changeset 429bf62d2636 From fad69162191e75a846bd51061da75d8dcfc5ce7b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 00:27:46 +0100 Subject: [PATCH 0976/1502] Add COPYING: Apache License version 2.0 --- COPYING | 201 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 COPYING diff --git a/COPYING b/COPYING new file mode 100644 index 00000000..11069edd --- /dev/null +++ b/COPYING @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. From 5f955fe8cb6a6610064f120f81aea876fdbef141 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 10:22:59 +0100 Subject: [PATCH 0977/1502] Replace "unittest.mock" with "mock" in unit tests Use "from unittest import mock". It should simplify my work to merge new tests in Trollius, because Trollius uses "mock" backport for Python 2. --- asyncio/test_utils.py | 4 +- tests/test_base_events.py | 108 +++++++------- tests/test_events.py | 32 ++-- tests/test_futures.py | 16 +- tests/test_locks.py | 20 +-- tests/test_proactor_events.py | 78 +++++----- tests/test_queues.py | 4 +- tests/test_selector_events.py | 270 +++++++++++++++++----------------- tests/test_selectors.py | 36 ++--- tests/test_streams.py | 4 +- tests/test_transports.py | 6 +- tests/test_unix_events.py | 190 ++++++++++++------------ tests/test_windows_utils.py | 4 +- 13 files changed, 386 insertions(+), 386 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index dd87789f..9a9a10b4 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -12,7 +12,7 @@ import threading import time import unittest -import unittest.mock +from unittest import mock from http.server import HTTPServer from wsgiref.simple_server import WSGIRequestHandler, WSGIServer @@ -362,7 +362,7 @@ def _write_to_self(self): def MockCallback(**kwargs): - return unittest.mock.Mock(spec=['__call__'], **kwargs) + return mock.Mock(spec=['__call__'], **kwargs) class MockPattern(str): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 784a39f8..f7a4e3a0 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -6,7 +6,7 @@ import sys import time import unittest -import unittest.mock +from unittest import mock from test.support import find_unused_port, IPV6_ENABLED import asyncio @@ -15,7 +15,7 @@ from asyncio import test_utils -MOCK_ANY = unittest.mock.ANY +MOCK_ANY = mock.ANY PY34 = sys.version_info >= (3, 4) @@ -23,11 +23,11 @@ class BaseEventLoopTests(unittest.TestCase): def setUp(self): self.loop = base_events.BaseEventLoop() - self.loop._selector = unittest.mock.Mock() + self.loop._selector = mock.Mock() asyncio.set_event_loop(None) def test_not_implemented(self): - m = unittest.mock.Mock() + m = mock.Mock() self.assertRaises( NotImplementedError, self.loop._make_socket_transport, m, m) @@ -75,13 +75,13 @@ def test__add_callback_cancelled_handle(self): self.assertFalse(self.loop._ready) def test_set_default_executor(self): - executor = unittest.mock.Mock() + executor = mock.Mock() self.loop.set_default_executor(executor) self.assertIs(executor, self.loop._default_executor) def test_getnameinfo(self): - sockaddr = unittest.mock.Mock() - self.loop.run_in_executor = unittest.mock.Mock() + sockaddr = mock.Mock() + self.loop.run_in_executor = mock.Mock() self.loop.getnameinfo(sockaddr) self.assertEqual( (None, socket.getnameinfo, sockaddr, 0), @@ -111,7 +111,7 @@ def test_call_later_negative_delays(self): def cb(arg): calls.append(arg) - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() self.loop.call_later(-1, cb, 'a') self.loop.call_later(-2, cb, 'b') test_utils.run_briefly(self.loop) @@ -121,7 +121,7 @@ def test_time_and_call_at(self): def cb(): self.loop.stop() - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() delay = 0.1 when = self.loop.time() + delay @@ -163,7 +163,7 @@ def cb(): pass h = asyncio.Handle(cb, (), self.loop) f = asyncio.Future(loop=self.loop) - executor = unittest.mock.Mock() + executor = mock.Mock() executor.submit.return_value = f self.loop.set_default_executor(executor) @@ -171,7 +171,7 @@ def cb(): res = self.loop.run_in_executor(None, h) self.assertIs(f, res) - executor = unittest.mock.Mock() + executor = mock.Mock() executor.submit.return_value = f res = self.loop.run_in_executor(executor, h) self.assertIs(f, res) @@ -187,7 +187,7 @@ def test__run_once(self): h1.cancel() - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() self.loop._scheduled.append(h1) self.loop._scheduled.append(h2) self.loop._run_once() @@ -203,8 +203,8 @@ def test_set_debug(self): self.loop.set_debug(False) self.assertFalse(self.loop.get_debug()) - @unittest.mock.patch('asyncio.base_events.time') - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.time') + @mock.patch('asyncio.base_events.logger') def test__run_once_logging(self, m_logger, m_time): # Log to INFO level if timeout > 1.0 sec. idx = -1 @@ -219,7 +219,7 @@ def monotonic(): self.loop._scheduled.append( asyncio.TimerHandle(11.0, lambda: True, (), self.loop)) - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() self.loop._run_once() self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) @@ -242,7 +242,7 @@ def cb(loop): h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), self.loop) - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() self.loop._scheduled.append(h) self.loop._run_once() @@ -303,14 +303,14 @@ def test_subprocess_shell_invalid_args(self): asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) def test_default_exc_handler_callback(self): - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() def zero_error(fut): fut.set_result(True) 1/0 # Test call_soon (events.Handle) - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: fut = asyncio.Future(loop=self.loop) self.loop.call_soon(zero_error, fut) fut.add_done_callback(lambda fut: self.loop.stop()) @@ -320,7 +320,7 @@ def zero_error(fut): exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) # Test call_later (events.TimerHandle) - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: fut = asyncio.Future(loop=self.loop) self.loop.call_later(0.01, zero_error, fut) fut.add_done_callback(lambda fut: self.loop.stop()) @@ -330,7 +330,7 @@ def zero_error(fut): exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) def test_default_exc_handler_coro(self): - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() @asyncio.coroutine def zero_error_coro(): @@ -338,7 +338,7 @@ def zero_error_coro(): 1/0 # Test Future.__del__ - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: fut = asyncio.async(zero_error_coro(), loop=self.loop) fut.add_done_callback(lambda *args: self.loop.stop()) self.loop.run_forever() @@ -368,9 +368,9 @@ def run_loop(): self.loop.call_soon(zero_error) self.loop._run_once() - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() - mock_handler = unittest.mock.Mock() + mock_handler = mock.Mock() self.loop.set_exception_handler(mock_handler) run_loop() mock_handler.assert_called_with(self.loop, { @@ -382,7 +382,7 @@ def run_loop(): mock_handler.reset_mock() self.loop.set_exception_handler(None) - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern( @@ -401,11 +401,11 @@ def zero_error(): def handler(loop, context): raise AttributeError('spam') - self.loop._process_events = unittest.mock.Mock() + self.loop._process_events = mock.Mock() self.loop.set_exception_handler(handler) - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern( @@ -417,8 +417,8 @@ def test_default_exc_handler_broken(self): class Loop(base_events.BaseEventLoop): - _selector = unittest.mock.Mock() - _process_events = unittest.mock.Mock() + _selector = mock.Mock() + _process_events = mock.Mock() def default_exception_handler(self, context): nonlocal _context @@ -435,7 +435,7 @@ def zero_error(): loop.call_soon(zero_error) loop._run_once() - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: run_loop() log.error.assert_called_with( 'Exception in default exception handler', @@ -446,7 +446,7 @@ def custom_handler(loop, context): _context = None loop.set_exception_handler(custom_handler) - with unittest.mock.patch('asyncio.base_events.logger') as log: + with mock.patch('asyncio.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern('Exception in default exception.*' @@ -527,7 +527,7 @@ def setUp(self): def tearDown(self): self.loop.close() - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_connection_multiple_errors(self, m_socket): class MyProto(asyncio.Protocol): @@ -592,7 +592,7 @@ def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task - self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect = mock.Mock() self.loop.sock_connect.side_effect = OSError coro = self.loop.create_connection(MyProto, 'example.com', 80) @@ -609,7 +609,7 @@ def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task - self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect = mock.Mock() self.loop.sock_connect.side_effect = OSError coro = self.loop.create_connection( @@ -617,7 +617,7 @@ def getaddrinfo_task(*args, **kwds): with self.assertRaises(OSError): self.loop.run_until_complete(coro) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_connection_multiple_errors_local_addr(self, m_socket): def bind(addr): @@ -637,7 +637,7 @@ def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task - self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect = mock.Mock() self.loop.sock_connect.side_effect = OSError('Err2') coro = self.loop.create_connection( @@ -669,7 +669,7 @@ def getaddrinfo_task(*args, **kwds): OSError, self.loop.run_until_complete, coro) def test_create_connection_ssl_server_hostname_default(self): - self.loop.getaddrinfo = unittest.mock.Mock() + self.loop.getaddrinfo = mock.Mock() def mock_getaddrinfo(*args, **kwds): f = asyncio.Future(loop=self.loop) @@ -678,9 +678,9 @@ def mock_getaddrinfo(*args, **kwds): return f self.loop.getaddrinfo.side_effect = mock_getaddrinfo - self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect = mock.Mock() self.loop.sock_connect.return_value = () - self.loop._make_ssl_transport = unittest.mock.Mock() + self.loop._make_ssl_transport = mock.Mock() class _SelectorTransportMock: _sock = None @@ -696,7 +696,7 @@ def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, return transport self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport - ANY = unittest.mock.ANY + ANY = mock.ANY # First try the default server_hostname. self.loop._make_ssl_transport.reset_mock() coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) @@ -775,13 +775,13 @@ def test_create_server_no_host_port_sock(self): self.assertRaises(ValueError, self.loop.run_until_complete, fut) def test_create_server_no_getaddrinfo(self): - getaddrinfo = self.loop.getaddrinfo = unittest.mock.Mock() + getaddrinfo = self.loop.getaddrinfo = mock.Mock() getaddrinfo.return_value = [] f = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, f) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_server_cant_bind(self, m_socket): class Err(OSError): @@ -790,14 +790,14 @@ class Err(OSError): m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] m_socket.getaddrinfo._is_coroutine = False - m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock = m_socket.socket.return_value = mock.Mock() m_sock.bind.side_effect = Err fut = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo._is_coroutine = False @@ -818,7 +818,7 @@ def test_create_datagram_endpoint_addr_error(self): AssertionError, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_connect_err(self): - self.loop.sock_connect = unittest.mock.Mock() + self.loop.sock_connect = mock.Mock() self.loop.sock_connect.side_effect = OSError coro = self.loop.create_datagram_endpoint( @@ -826,7 +826,7 @@ def test_create_datagram_endpoint_connect_err(self): self.assertRaises( OSError, self.loop.run_until_complete, coro) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): m_socket.getaddrinfo = socket.getaddrinfo m_socket.socket.side_effect = OSError @@ -849,7 +849,7 @@ def test_create_datagram_endpoint_no_matching_family(self): self.assertRaises( ValueError, self.loop.run_until_complete, coro) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): m_socket.socket.return_value.setblocking.side_effect = OSError @@ -865,14 +865,14 @@ def test_create_datagram_endpoint_noaddr_nofamily(self): asyncio.DatagramProtocol) self.assertRaises(ValueError, self.loop.run_until_complete, coro) - @unittest.mock.patch('asyncio.base_events.socket') + @mock.patch('asyncio.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): class Err(OSError): pass m_socket.AF_INET6 = socket.AF_INET6 m_socket.getaddrinfo = socket.getaddrinfo - m_sock = m_socket.socket.return_value = unittest.mock.Mock() + m_sock = m_socket.socket.return_value = mock.Mock() m_sock.bind.side_effect = Err fut = self.loop.create_datagram_endpoint( @@ -882,19 +882,19 @@ class Err(OSError): self.assertTrue(m_sock.close.called) def test_accept_connection_retry(self): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.accept.side_effect = BlockingIOError() self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_accept_connection_exception(self, m_log): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') - self.loop.remove_reader = unittest.mock.Mock() - self.loop.call_later = unittest.mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.call_later = mock.Mock() self.loop._accept_connection(MyProto, sock) self.assertTrue(m_log.error.called) @@ -902,7 +902,7 @@ def test_accept_connection_exception(self, m_log): self.loop.remove_reader.assert_called_with(10) self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, # self.loop._start_serving - unittest.mock.ANY, + mock.ANY, MyProto, sock, None, None) def test_call_coroutine(self): diff --git a/tests/test_events.py b/tests/test_events.py index d00af23d..055a2aaa 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -20,7 +20,7 @@ import time import errno import unittest -import unittest.mock +from unittest import mock from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR @@ -1812,7 +1812,7 @@ def callback(*args): return args args = () - h = asyncio.Handle(callback, args, unittest.mock.Mock()) + h = asyncio.Handle(callback, args, mock.Mock()) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1844,15 +1844,15 @@ def test_callback_with_exception(self): def callback(): raise ValueError() - m_loop = unittest.mock.Mock() - m_loop.call_exception_handler = unittest.mock.Mock() + m_loop = mock.Mock() + m_loop.call_exception_handler = mock.Mock() h = asyncio.Handle(callback, (), m_loop) h._run() m_loop.call_exception_handler.assert_called_with({ 'message': test_utils.MockPattern('Exception in callback.*'), - 'exception': unittest.mock.ANY, + 'exception': mock.ANY, 'handle': h }) @@ -1862,7 +1862,7 @@ class TimerTests(unittest.TestCase): def test_hash(self): when = time.monotonic() h = asyncio.TimerHandle(when, lambda: False, (), - unittest.mock.Mock()) + mock.Mock()) self.assertEqual(hash(h), hash(when)) def test_timer(self): @@ -1871,7 +1871,7 @@ def callback(*args): args = () when = time.monotonic() - h = asyncio.TimerHandle(when, callback, args, unittest.mock.Mock()) + h = asyncio.TimerHandle(when, callback, args, mock.Mock()) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) @@ -1887,10 +1887,10 @@ def callback(*args): self.assertRaises(AssertionError, asyncio.TimerHandle, None, callback, args, - unittest.mock.Mock()) + mock.Mock()) def test_timer_comparison(self): - loop = unittest.mock.Mock() + loop = mock.Mock() def callback(*args): return args @@ -1935,7 +1935,7 @@ def callback(*args): class AbstractEventLoopTests(unittest.TestCase): def test_not_implemented(self): - f = unittest.mock.Mock() + f = mock.Mock() loop = asyncio.AbstractEventLoop() self.assertRaises( NotImplementedError, loop.run_forever) @@ -1995,13 +1995,13 @@ def test_not_implemented(self): NotImplementedError, loop.remove_signal_handler, 1) self.assertRaises( NotImplementedError, loop.connect_read_pipe, f, - unittest.mock.sentinel.pipe) + mock.sentinel.pipe) self.assertRaises( NotImplementedError, loop.connect_write_pipe, f, - unittest.mock.sentinel.pipe) + mock.sentinel.pipe) self.assertRaises( NotImplementedError, loop.subprocess_shell, f, - unittest.mock.sentinel) + mock.sentinel) self.assertRaises( NotImplementedError, loop.subprocess_exec, f) @@ -2009,7 +2009,7 @@ def test_not_implemented(self): class ProtocolsAbsTests(unittest.TestCase): def test_empty(self): - f = unittest.mock.Mock() + f = mock.Mock() p = asyncio.Protocol() self.assertIsNone(p.connection_made(f)) self.assertIsNone(p.connection_lost(f)) @@ -2055,7 +2055,7 @@ def test_get_event_loop(self): def test_get_event_loop_calls_set_event_loop(self): policy = asyncio.DefaultEventLoopPolicy() - with unittest.mock.patch.object( + with mock.patch.object( policy, "set_event_loop", wraps=policy.set_event_loop) as m_set_event_loop: @@ -2073,7 +2073,7 @@ def test_get_event_loop_after_set_none(self): policy.set_event_loop(None) self.assertRaises(AssertionError, policy.get_event_loop) - @unittest.mock.patch('asyncio.events.threading.current_thread') + @mock.patch('asyncio.events.threading.current_thread') def test_get_event_loop_thread(self, m_current_thread): def f(): diff --git a/tests/test_futures.py b/tests/test_futures.py index f2b81ddd..399e8f43 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -3,7 +3,7 @@ import concurrent.futures import threading import unittest -import unittest.mock +from unittest import mock import asyncio from asyncio import test_utils @@ -174,20 +174,20 @@ def test(): self.assertRaises(AssertionError, test) fut.cancel() - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_abandoned(self, m_log): fut = asyncio.Future(loop=self.loop) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_result_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) @@ -195,7 +195,7 @@ def test_tb_logger_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -203,7 +203,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -211,7 +211,7 @@ def test_tb_logger_exception_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_tb_logger_exception_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -236,7 +236,7 @@ def test_wrap_future_future(self): f2 = asyncio.wrap_future(f1) self.assertIs(f1, f2) - @unittest.mock.patch('asyncio.futures.events') + @mock.patch('asyncio.futures.events') def test_wrap_future_use_global_loop(self, m_events): def run(arg): return (arg, threading.get_ident()) diff --git a/tests/test_locks.py b/tests/test_locks.py index 0975f497..f542463a 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -1,7 +1,7 @@ """Tests for lock.py""" import unittest -import unittest.mock +from unittest import mock import re import asyncio @@ -27,7 +27,7 @@ def tearDown(self): self.loop.close() def test_ctor_loop(self): - loop = unittest.mock.Mock() + loop = mock.Mock() lock = asyncio.Lock(loop=loop) self.assertIs(lock._loop, loop) @@ -250,7 +250,7 @@ def tearDown(self): self.loop.close() def test_ctor_loop(self): - loop = unittest.mock.Mock() + loop = mock.Mock() ev = asyncio.Event(loop=loop) self.assertIs(ev._loop, loop) @@ -275,7 +275,7 @@ def test_repr(self): self.assertTrue(repr(ev).endswith('[set]>')) self.assertTrue(RGX_REPR.match(repr(ev))) - ev._waiters.append(unittest.mock.Mock()) + ev._waiters.append(mock.Mock()) self.assertTrue('waiters:1' in repr(ev)) self.assertTrue(RGX_REPR.match(repr(ev))) @@ -386,7 +386,7 @@ def tearDown(self): self.loop.close() def test_ctor_loop(self): - loop = unittest.mock.Mock() + loop = mock.Mock() cond = asyncio.Condition(loop=loop) self.assertIs(cond._loop, loop) @@ -644,11 +644,11 @@ def test_repr(self): self.loop.run_until_complete(cond.acquire()) self.assertTrue('locked' in repr(cond)) - cond._waiters.append(unittest.mock.Mock()) + cond._waiters.append(mock.Mock()) self.assertTrue('waiters:1' in repr(cond)) self.assertTrue(RGX_REPR.match(repr(cond))) - cond._waiters.append(unittest.mock.Mock()) + cond._waiters.append(mock.Mock()) self.assertTrue('waiters:2' in repr(cond)) self.assertTrue(RGX_REPR.match(repr(cond))) @@ -688,7 +688,7 @@ def tearDown(self): self.loop.close() def test_ctor_loop(self): - loop = unittest.mock.Mock() + loop = mock.Mock() sem = asyncio.Semaphore(loop=loop) self.assertIs(sem._loop, loop) @@ -717,11 +717,11 @@ def test_repr(self): self.assertTrue('waiters' not in repr(sem)) self.assertTrue(RGX_REPR.match(repr(sem))) - sem._waiters.append(unittest.mock.Mock()) + sem._waiters.append(mock.Mock()) self.assertTrue('waiters:1' in repr(sem)) self.assertTrue(RGX_REPR.match(repr(sem))) - sem._waiters.append(unittest.mock.Mock()) + sem._waiters.append(mock.Mock()) self.assertTrue('waiters:2' in repr(sem)) self.assertTrue(RGX_REPR.match(repr(sem))) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 08920690..5bf24a45 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -2,7 +2,7 @@ import socket import unittest -import unittest.mock +from unittest import mock import asyncio from asyncio.proactor_events import BaseProactorEventLoop @@ -16,10 +16,10 @@ class ProactorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() - self.proactor = unittest.mock.Mock() + self.proactor = mock.Mock() self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.Protocol) - self.sock = unittest.mock.Mock(socket.socket) + self.sock = mock.Mock(socket.socket) def test_ctor(self): fut = asyncio.Future(loop=self.loop) @@ -56,7 +56,7 @@ def test_loop_reading_no_data(self): self.assertRaises(AssertionError, tr._loop_reading, res) - tr.close = unittest.mock.Mock() + tr.close = mock.Mock() tr._read_fut = res tr._loop_reading(res) self.assertFalse(self.loop._proactor.recv.called) @@ -67,7 +67,7 @@ def test_loop_reading_aborted(self): err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with( err, @@ -78,7 +78,7 @@ def test_loop_reading_aborted_closing(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._closing = True - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr._loop_reading() self.assertFalse(tr._fatal_error.called) @@ -86,7 +86,7 @@ def test_loop_reading_aborted_is_fatal(self): self.loop._proactor.recv.side_effect = ConnectionAbortedError() tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._closing = False - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr._loop_reading() self.assertTrue(tr._fatal_error.called) @@ -95,8 +95,8 @@ def test_loop_reading_conn_reset_lost(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._closing = False - tr._fatal_error = unittest.mock.Mock() - tr._force_close = unittest.mock.Mock() + tr._fatal_error = mock.Mock() + tr._force_close = mock.Mock() tr._loop_reading() self.assertFalse(tr._fatal_error.called) tr._force_close.assert_called_with(err) @@ -105,7 +105,7 @@ def test_loop_reading_exception(self): err = self.loop._proactor.recv.side_effect = (OSError()) tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with( err, @@ -113,7 +113,7 @@ def test_loop_reading_exception(self): def test_write(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._loop_writing = unittest.mock.Mock() + tr._loop_writing = mock.Mock() tr.write(b'data') self.assertEqual(tr._buffer, None) tr._loop_writing.assert_called_with(data=b'data') @@ -125,8 +125,8 @@ def test_write_no_data(self): def test_write_more(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._write_fut = unittest.mock.Mock() - tr._loop_writing = unittest.mock.Mock() + tr._write_fut = mock.Mock() + tr._loop_writing = mock.Mock() tr.write(b'data') self.assertEqual(tr._buffer, b'data') self.assertFalse(tr._loop_writing.called) @@ -139,11 +139,11 @@ def test_loop_writing(self): self.loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) - @unittest.mock.patch('asyncio.proactor_events.logger') + @mock.patch('asyncio.proactor_events.logger') def test_loop_writing_err(self, m_log): err = self.loop._proactor.send.side_effect = OSError() tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr._buffer = [b'da', b'ta'] tr._loop_writing() tr._fatal_error.assert_called_with( @@ -182,7 +182,7 @@ def test_loop_writing_closing(self): def test_abort(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._force_close = unittest.mock.Mock() + tr._force_close = mock.Mock() tr.abort() tr._force_close.assert_called_with(None) @@ -201,7 +201,7 @@ def test_close(self): def test_close_write_fut(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._write_fut = unittest.mock.Mock() + tr._write_fut = mock.Mock() tr.close() test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) @@ -213,10 +213,10 @@ def test_close_buffer(self): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_fatal_error(self, m_logging): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - tr._force_close = unittest.mock.Mock() + tr._force_close = mock.Mock() tr._fatal_error(None) self.assertTrue(tr._force_close.called) self.assertTrue(m_logging.error.called) @@ -224,8 +224,8 @@ def test_fatal_error(self, m_logging): def test_force_close(self): tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) tr._buffer = [b'data'] - read_fut = tr._read_fut = unittest.mock.Mock() - write_fut = tr._write_fut = unittest.mock.Mock() + read_fut = tr._read_fut = mock.Mock() + write_fut = tr._write_fut = mock.Mock() tr._force_close(None) read_fut.cancel.assert_called_with() @@ -346,10 +346,10 @@ def test_pause_resume_reading(self): class BaseProactorEventLoopTests(unittest.TestCase): def setUp(self): - self.sock = unittest.mock.Mock(socket.socket) - self.proactor = unittest.mock.Mock() + self.sock = mock.Mock(socket.socket) + self.proactor = mock.Mock() - self.ssock, self.csock = unittest.mock.Mock(), unittest.mock.Mock() + self.ssock, self.csock = mock.Mock(), mock.Mock() class EventLoop(BaseProactorEventLoop): def _socketpair(s): @@ -357,11 +357,11 @@ def _socketpair(s): self.loop = EventLoop(self.proactor) - @unittest.mock.patch.object(BaseProactorEventLoop, 'call_soon') - @unittest.mock.patch.object(BaseProactorEventLoop, '_socketpair') + @mock.patch.object(BaseProactorEventLoop, 'call_soon') + @mock.patch.object(BaseProactorEventLoop, '_socketpair') def test_ctor(self, socketpair, call_soon): ssock, csock = socketpair.return_value = ( - unittest.mock.Mock(), unittest.mock.Mock()) + mock.Mock(), mock.Mock()) loop = BaseProactorEventLoop(self.proactor) self.assertIs(loop._ssock, ssock) self.assertIs(loop._csock, csock) @@ -377,7 +377,7 @@ def test_close_self_pipe(self): self.assertIsNone(self.loop._csock) def test_close(self): - self.loop._close_self_pipe = unittest.mock.Mock() + self.loop._close_self_pipe = mock.Mock() self.loop.close() self.assertTrue(self.loop._close_self_pipe.called) self.assertTrue(self.proactor.close.called) @@ -418,7 +418,7 @@ def test_loop_self_reading(self): self.loop._loop_self_reading) def test_loop_self_reading_fut(self): - fut = unittest.mock.Mock() + fut = mock.Mock() self.loop._loop_self_reading(fut) self.assertTrue(fut.result.called) self.proactor.recv.assert_called_with(self.ssock, 4096) @@ -426,7 +426,7 @@ def test_loop_self_reading_fut(self): self.loop._loop_self_reading) def test_loop_self_reading_exception(self): - self.loop.close = unittest.mock.Mock() + self.loop.close = mock.Mock() self.proactor.recv.side_effect = OSError() self.assertRaises(OSError, self.loop._loop_self_reading) self.assertTrue(self.loop.close.called) @@ -438,10 +438,10 @@ def test_write_to_self(self): def test_process_events(self): self.loop._process_events([]) - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_create_server(self, m_log): - pf = unittest.mock.Mock() - call_soon = self.loop.call_soon = unittest.mock.Mock() + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() self.loop._start_serving(pf, self.sock) self.assertTrue(call_soon.called) @@ -452,10 +452,10 @@ def test_create_server(self, m_log): self.proactor.accept.assert_called_with(self.sock) # conn - fut = unittest.mock.Mock() - fut.result.return_value = (unittest.mock.Mock(), unittest.mock.Mock()) + fut = mock.Mock() + fut.result.return_value = (mock.Mock(), mock.Mock()) - make_tr = self.loop._make_socket_transport = unittest.mock.Mock() + make_tr = self.loop._make_socket_transport = mock.Mock() loop(fut) self.assertTrue(fut.result.called) self.assertTrue(make_tr.called) @@ -467,8 +467,8 @@ def test_create_server(self, m_log): self.assertTrue(m_log.error.called) def test_create_server_cancel(self): - pf = unittest.mock.Mock() - call_soon = self.loop.call_soon = unittest.mock.Mock() + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() self.loop._start_serving(pf, self.sock) loop = call_soon.call_args[0][0] @@ -480,7 +480,7 @@ def test_create_server_cancel(self): self.assertTrue(self.sock.close.called) def test_stop_serving(self): - sock = unittest.mock.Mock() + sock = mock.Mock() self.loop._stop_serving(sock) self.assertTrue(sock.close.called) self.proactor._stop_serving.assert_called_with(sock) diff --git a/tests/test_queues.py b/tests/test_queues.py index fc2bf460..f79fee21 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -1,7 +1,7 @@ """Tests for queues.py""" import unittest -import unittest.mock +from unittest import mock import asyncio from asyncio import test_utils @@ -72,7 +72,7 @@ def add_putter(): self.assertTrue('_queue=[1]' in fn(q)) def test_ctor_loop(self): - loop = unittest.mock.Mock() + loop = mock.Mock() q = asyncio.Queue(loop=loop) self.assertIs(q._loop, loop) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 247df9e0..369ec32b 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -7,7 +7,7 @@ import socket import sys import unittest -import unittest.mock +from unittest import mock try: import ssl except ImportError: @@ -23,14 +23,14 @@ from asyncio.selector_events import _SelectorDatagramTransport -MOCK_ANY = unittest.mock.ANY +MOCK_ANY = mock.ANY class TestBaseSelectorEventLoop(BaseSelectorEventLoop): def _make_self_pipe(self): - self._ssock = unittest.mock.Mock() - self._csock = unittest.mock.Mock() + self._ssock = mock.Mock() + self._csock = mock.Mock() self._internal_fds += 1 @@ -41,34 +41,34 @@ def list_to_buffer(l=()): class BaseSelectorEventLoopTests(unittest.TestCase): def setUp(self): - selector = unittest.mock.Mock() + selector = mock.Mock() self.loop = TestBaseSelectorEventLoop(selector) def test_make_socket_transport(self): - m = unittest.mock.Mock() - self.loop.add_reader = unittest.mock.Mock() + m = mock.Mock() + self.loop.add_reader = mock.Mock() transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): - m = unittest.mock.Mock() - self.loop.add_reader = unittest.mock.Mock() - self.loop.add_writer = unittest.mock.Mock() - self.loop.remove_reader = unittest.mock.Mock() - self.loop.remove_writer = unittest.mock.Mock() + m = mock.Mock() + self.loop.add_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() waiter = asyncio.Future(loop=self.loop) transport = self.loop._make_ssl_transport( m, asyncio.Protocol(), m, waiter) self.assertIsInstance(transport, _SelectorSslTransport) - @unittest.mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.selector_events.ssl', None) def test_make_ssl_transport_without_ssl_error(self): - m = unittest.mock.Mock() - self.loop.add_reader = unittest.mock.Mock() - self.loop.add_writer = unittest.mock.Mock() - self.loop.remove_reader = unittest.mock.Mock() - self.loop.remove_writer = unittest.mock.Mock() + m = mock.Mock() + self.loop.add_reader = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.remove_writer = mock.Mock() with self.assertRaises(RuntimeError): self.loop._make_ssl_transport(m, m, m, m) @@ -77,10 +77,10 @@ def test_close(self): ssock.fileno.return_value = 7 csock = self.loop._csock csock.fileno.return_value = 1 - remove_reader = self.loop.remove_reader = unittest.mock.Mock() + remove_reader = self.loop.remove_reader = mock.Mock() self.loop._selector.close() - self.loop._selector = selector = unittest.mock.Mock() + self.loop._selector = selector = mock.Mock() self.loop.close() self.assertIsNone(self.loop._selector) self.assertIsNone(self.loop._csock) @@ -96,7 +96,7 @@ def test_close(self): def test_close_no_selector(self): ssock = self.loop._ssock csock = self.loop._csock - remove_reader = self.loop.remove_reader = unittest.mock.Mock() + remove_reader = self.loop.remove_reader = mock.Mock() self.loop._selector.close() self.loop._selector = None @@ -126,15 +126,15 @@ def test_write_to_self_exception(self): self.assertRaises(OSError, self.loop._write_to_self) def test_sock_recv(self): - sock = unittest.mock.Mock() - self.loop._sock_recv = unittest.mock.Mock() + sock = mock.Mock() + self.loop._sock_recv = mock.Mock() f = self.loop.sock_recv(sock, 1024) self.assertIsInstance(f, asyncio.Future) self.loop._sock_recv.assert_called_with(f, False, sock, 1024) def test__sock_recv_canceled_fut(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() @@ -143,30 +143,30 @@ def test__sock_recv_canceled_fut(self): self.assertFalse(sock.recv.called) def test__sock_recv_unregister(self): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 f = asyncio.Future(loop=self.loop) f.cancel() - self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_reader = mock.Mock() self.loop._sock_recv(f, True, sock, 1024) self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_recv_tryagain(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.recv.side_effect = BlockingIOError - self.loop.add_reader = unittest.mock.Mock() + self.loop.add_reader = mock.Mock() self.loop._sock_recv(f, False, sock, 1024) self.assertEqual((10, self.loop._sock_recv, f, True, sock, 1024), self.loop.add_reader.call_args[0]) def test__sock_recv_exception(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 err = sock.recv.side_effect = OSError() @@ -174,8 +174,8 @@ def test__sock_recv_exception(self): self.assertIs(err, f.exception()) def test_sock_sendall(self): - sock = unittest.mock.Mock() - self.loop._sock_sendall = unittest.mock.Mock() + sock = mock.Mock() + self.loop._sock_sendall = mock.Mock() f = self.loop.sock_sendall(sock, b'data') self.assertIsInstance(f, asyncio.Future) @@ -184,8 +184,8 @@ def test_sock_sendall(self): self.loop._sock_sendall.call_args[0]) def test_sock_sendall_nodata(self): - sock = unittest.mock.Mock() - self.loop._sock_sendall = unittest.mock.Mock() + sock = mock.Mock() + self.loop._sock_sendall = mock.Mock() f = self.loop.sock_sendall(sock, b'') self.assertIsInstance(f, asyncio.Future) @@ -194,7 +194,7 @@ def test_sock_sendall_nodata(self): self.assertFalse(self.loop._sock_sendall.called) def test__sock_sendall_canceled_fut(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() @@ -203,23 +203,23 @@ def test__sock_sendall_canceled_fut(self): self.assertFalse(sock.send.called) def test__sock_sendall_unregister(self): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 f = asyncio.Future(loop=self.loop) f.cancel() - self.loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = mock.Mock() self.loop._sock_sendall(f, True, sock, b'data') self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_sendall_tryagain(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.send.side_effect = BlockingIOError - self.loop.add_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() self.loop._sock_sendall(f, False, sock, b'data') self.assertEqual( (10, self.loop._sock_sendall, f, True, sock, b'data'), @@ -227,11 +227,11 @@ def test__sock_sendall_tryagain(self): def test__sock_sendall_interrupted(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.send.side_effect = InterruptedError - self.loop.add_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() self.loop._sock_sendall(f, False, sock, b'data') self.assertEqual( (10, self.loop._sock_sendall, f, True, sock, b'data'), @@ -239,7 +239,7 @@ def test__sock_sendall_interrupted(self): def test__sock_sendall_exception(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 err = sock.send.side_effect = OSError() @@ -247,7 +247,7 @@ def test__sock_sendall_exception(self): self.assertIs(f.exception(), err) def test__sock_sendall(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 @@ -258,13 +258,13 @@ def test__sock_sendall(self): self.assertIsNone(f.result()) def test__sock_sendall_partial(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 sock.send.return_value = 2 - self.loop.add_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() self.loop._sock_sendall(f, False, sock, b'data') self.assertFalse(f.done()) self.assertEqual( @@ -272,13 +272,13 @@ def test__sock_sendall_partial(self): self.loop.add_writer.call_args[0]) def test__sock_sendall_none(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) sock.fileno.return_value = 10 sock.send.return_value = 0 - self.loop.add_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() self.loop._sock_sendall(f, False, sock, b'data') self.assertFalse(f.done()) self.assertEqual( @@ -286,8 +286,8 @@ def test__sock_sendall_none(self): self.loop.add_writer.call_args[0]) def test_sock_connect(self): - sock = unittest.mock.Mock() - self.loop._sock_connect = unittest.mock.Mock() + sock = mock.Mock() + self.loop._sock_connect = mock.Mock() f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, asyncio.Future) @@ -298,7 +298,7 @@ def test_sock_connect(self): def test__sock_connect(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) @@ -307,7 +307,7 @@ def test__sock_connect(self): self.assertTrue(sock.connect.called) def test__sock_connect_canceled_fut(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() @@ -316,24 +316,24 @@ def test__sock_connect_canceled_fut(self): self.assertFalse(sock.connect.called) def test__sock_connect_unregister(self): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 f = asyncio.Future(loop=self.loop) f.cancel() - self.loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = mock.Mock() self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertEqual((10,), self.loop.remove_writer.call_args[0]) def test__sock_connect_tryagain(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.EAGAIN - self.loop.add_writer = unittest.mock.Mock() - self.loop.remove_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertEqual( @@ -343,17 +343,17 @@ def test__sock_connect_tryagain(self): def test__sock_connect_exception(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.ENOTCONN - self.loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = mock.Mock() self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) self.assertIsInstance(f.exception(), OSError) def test_sock_accept(self): - sock = unittest.mock.Mock() - self.loop._sock_accept = unittest.mock.Mock() + sock = mock.Mock() + self.loop._sock_accept = mock.Mock() f = self.loop.sock_accept(sock) self.assertIsInstance(f, asyncio.Future) @@ -363,9 +363,9 @@ def test_sock_accept(self): def test__sock_accept(self): f = asyncio.Future(loop=self.loop) - conn = unittest.mock.Mock() + conn = mock.Mock() - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.accept.return_value = conn, ('127.0.0.1', 1000) @@ -375,7 +375,7 @@ def test__sock_accept(self): self.assertEqual((False,), conn.setblocking.call_args[0]) def test__sock_accept_canceled_fut(self): - sock = unittest.mock.Mock() + sock = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() @@ -384,23 +384,23 @@ def test__sock_accept_canceled_fut(self): self.assertFalse(sock.accept.called) def test__sock_accept_unregister(self): - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 f = asyncio.Future(loop=self.loop) f.cancel() - self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_reader = mock.Mock() self.loop._sock_accept(f, True, sock) self.assertEqual((10,), self.loop.remove_reader.call_args[0]) def test__sock_accept_tryagain(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 sock.accept.side_effect = BlockingIOError - self.loop.add_reader = unittest.mock.Mock() + self.loop.add_reader = mock.Mock() self.loop._sock_accept(f, False, sock) self.assertEqual( (10, self.loop._sock_accept, f, True, sock), @@ -408,7 +408,7 @@ def test__sock_accept_tryagain(self): def test__sock_accept_exception(self): f = asyncio.Future(loop=self.loop) - sock = unittest.mock.Mock() + sock = mock.Mock() sock.fileno.return_value = 10 err = sock.accept.side_effect = OSError() @@ -428,8 +428,8 @@ def test_add_reader(self): self.assertIsNone(w) def test_add_reader_existing(self): - reader = unittest.mock.Mock() - writer = unittest.mock.Mock() + reader = mock.Mock() + writer = mock.Mock() self.loop._selector.get_key.return_value = selectors.SelectorKey( 1, 1, selectors.EVENT_WRITE, (reader, writer)) cb = lambda: True @@ -445,7 +445,7 @@ def test_add_reader_existing(self): self.assertEqual(writer, w) def test_add_reader_existing_writer(self): - writer = unittest.mock.Mock() + writer = mock.Mock() self.loop._selector.get_key.return_value = selectors.SelectorKey( 1, 1, selectors.EVENT_WRITE, (None, writer)) cb = lambda: True @@ -467,8 +467,8 @@ def test_remove_reader(self): self.assertTrue(self.loop._selector.unregister.called) def test_remove_reader_read_write(self): - reader = unittest.mock.Mock() - writer = unittest.mock.Mock() + reader = mock.Mock() + writer = mock.Mock() self.loop._selector.get_key.return_value = selectors.SelectorKey( 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) @@ -498,8 +498,8 @@ def test_add_writer(self): self.assertEqual(cb, w._callback) def test_add_writer_existing(self): - reader = unittest.mock.Mock() - writer = unittest.mock.Mock() + reader = mock.Mock() + writer = mock.Mock() self.loop._selector.get_key.return_value = selectors.SelectorKey( 1, 1, selectors.EVENT_READ, (reader, writer)) cb = lambda: True @@ -522,8 +522,8 @@ def test_remove_writer(self): self.assertTrue(self.loop._selector.unregister.called) def test_remove_writer_read_write(self): - reader = unittest.mock.Mock() - writer = unittest.mock.Mock() + reader = mock.Mock() + writer = mock.Mock() self.loop._selector.get_key.return_value = selectors.SelectorKey( 1, 1, selectors.EVENT_READ | selectors.EVENT_WRITE, (reader, writer)) @@ -541,10 +541,10 @@ def test_remove_writer_unknown(self): self.loop.remove_writer(1)) def test_process_events_read(self): - reader = unittest.mock.Mock() + reader = mock.Mock() reader._cancelled = False - self.loop._add_callback = unittest.mock.Mock() + self.loop._add_callback = mock.Mock() self.loop._process_events( [(selectors.SelectorKey( 1, 1, selectors.EVENT_READ, (reader, None)), @@ -553,10 +553,10 @@ def test_process_events_read(self): self.loop._add_callback.assert_called_with(reader) def test_process_events_read_cancelled(self): - reader = unittest.mock.Mock() + reader = mock.Mock() reader.cancelled = True - self.loop.remove_reader = unittest.mock.Mock() + self.loop.remove_reader = mock.Mock() self.loop._process_events( [(selectors.SelectorKey( 1, 1, selectors.EVENT_READ, (reader, None)), @@ -564,10 +564,10 @@ def test_process_events_read_cancelled(self): self.loop.remove_reader.assert_called_with(1) def test_process_events_write(self): - writer = unittest.mock.Mock() + writer = mock.Mock() writer._cancelled = False - self.loop._add_callback = unittest.mock.Mock() + self.loop._add_callback = mock.Mock() self.loop._process_events( [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, (None, writer)), @@ -575,9 +575,9 @@ def test_process_events_write(self): self.loop._add_callback.assert_called_with(writer) def test_process_events_write_cancelled(self): - writer = unittest.mock.Mock() + writer = mock.Mock() writer.cancelled = True - self.loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = mock.Mock() self.loop._process_events( [(selectors.SelectorKey(1, 1, selectors.EVENT_WRITE, @@ -591,7 +591,7 @@ class SelectorTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.Protocol) - self.sock = unittest.mock.Mock(socket.socket) + self.sock = mock.Mock(socket.socket) self.sock.fileno.return_value = 7 def test_ctor(self): @@ -602,7 +602,7 @@ def test_ctor(self): def test_abort(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - tr._force_close = unittest.mock.Mock() + tr._force_close = mock.Mock() tr.abort() tr._force_close.assert_called_with(None) @@ -632,8 +632,8 @@ def test_close_write_buffer(self): def test_force_close(self): tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) tr._buffer.extend(b'1') - self.loop.add_reader(7, unittest.mock.sentinel) - self.loop.add_writer(7, unittest.mock.sentinel) + self.loop.add_reader(7, mock.sentinel) + self.loop.add_writer(7, mock.sentinel) tr._force_close(None) self.assertTrue(tr._closing) @@ -646,11 +646,11 @@ def test_force_close(self): self.assertFalse(self.loop.readers) self.assertEqual(1, self.loop.remove_reader_count[7]) - @unittest.mock.patch('asyncio.log.logger.error') + @mock.patch('asyncio.log.logger.error') def test_fatal_error(self, m_exc): exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) - tr._force_close = unittest.mock.Mock() + tr._force_close = mock.Mock() tr._fatal_error(exc) m_exc.assert_called_with( @@ -682,7 +682,7 @@ class SelectorSocketTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.Protocol) - self.sock = unittest.mock.Mock(socket.socket) + self.sock = mock.Mock(socket.socket) self.sock_fd = self.sock.fileno.return_value = 7 def test_ctor(self): @@ -724,7 +724,7 @@ def test_read_ready(self): def test_read_ready_eof(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport.close = unittest.mock.Mock() + transport.close = mock.Mock() self.sock.recv.return_value = b'' transport._read_ready() @@ -735,7 +735,7 @@ def test_read_ready_eof(self): def test_read_ready_eof_keep_open(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport.close = unittest.mock.Mock() + transport.close = mock.Mock() self.sock.recv.return_value = b'' self.protocol.eof_received.return_value = True @@ -744,45 +744,45 @@ def test_read_ready_eof_keep_open(self): self.protocol.eof_received.assert_called_with() self.assertFalse(transport.close.called) - @unittest.mock.patch('logging.exception') + @mock.patch('logging.exception') def test_read_ready_tryagain(self, m_exc): self.sock.recv.side_effect = BlockingIOError transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() self.assertFalse(transport._fatal_error.called) - @unittest.mock.patch('logging.exception') + @mock.patch('logging.exception') def test_read_ready_tryagain_interrupted(self, m_exc): self.sock.recv.side_effect = InterruptedError transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() self.assertFalse(transport._fatal_error.called) - @unittest.mock.patch('logging.exception') + @mock.patch('logging.exception') def test_read_ready_conn_reset(self, m_exc): err = self.sock.recv.side_effect = ConnectionResetError() transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._force_close = unittest.mock.Mock() + transport._force_close = mock.Mock() transport._read_ready() transport._force_close.assert_called_with(err) - @unittest.mock.patch('logging.exception') + @mock.patch('logging.exception') def test_read_ready_err(self, m_exc): err = self.sock.recv.side_effect = OSError() transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() transport._fatal_error.assert_called_with( @@ -891,14 +891,14 @@ def test_write_tryagain(self): self.loop.assert_writer(7, transport._write_ready) self.assertEqual(list_to_buffer([b'data']), transport._buffer) - @unittest.mock.patch('asyncio.selector_events.logger') + @mock.patch('asyncio.selector_events.logger') def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() data = b'data' transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport.write(data) transport._fatal_error.assert_called_with( err, @@ -1002,17 +1002,17 @@ def test_write_ready_exception(self): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._buffer.extend(b'data') transport._write_ready() transport._fatal_error.assert_called_with( err, 'Fatal write error on socket transport') - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.base_events.logger') def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() - remove_writer = self.loop.remove_writer = unittest.mock.Mock() + remove_writer = self.loop.remove_writer = mock.Mock() transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) @@ -1053,11 +1053,11 @@ class SelectorSslTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.Protocol) - self.sock = unittest.mock.Mock(socket.socket) + self.sock = mock.Mock(socket.socket) self.sock.fileno.return_value = 7 - self.sslsock = unittest.mock.Mock() + self.sslsock = mock.Mock() self.sslsock.fileno.return_value = 1 - self.sslcontext = unittest.mock.Mock() + self.sslcontext = mock.Mock() self.sslcontext.wrap_socket.return_value = self.sslsock def _make_one(self, create_waiter=None): @@ -1162,7 +1162,7 @@ def test_write_closing(self): transport.write(b'data') self.assertEqual(transport._conn_lost, 2) - @unittest.mock.patch('asyncio.selector_events.logger') + @mock.patch('asyncio.selector_events.logger') def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 @@ -1182,11 +1182,11 @@ def test_read_ready_recv(self): self.assertEqual((b'data',), self.protocol.data_received.call_args[0]) def test_read_ready_write_wants_read(self): - self.loop.add_writer = unittest.mock.Mock() + self.loop.add_writer = mock.Mock() self.sslsock.recv.side_effect = BlockingIOError transport = self._make_one() transport._write_wants_read = True - transport._write_ready = unittest.mock.Mock() + transport._write_ready = mock.Mock() transport._buffer.extend(b'data') transport._read_ready() @@ -1198,7 +1198,7 @@ def test_read_ready_write_wants_read(self): def test_read_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() - transport.close = unittest.mock.Mock() + transport.close = mock.Mock() transport._read_ready() transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() @@ -1206,7 +1206,7 @@ def test_read_ready_recv_eof(self): def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() - transport._force_close = unittest.mock.Mock() + transport._force_close = mock.Mock() transport._read_ready() transport._force_close.assert_called_with(err) @@ -1226,8 +1226,8 @@ def test_read_ready_recv_retry(self): self.assertFalse(self.protocol.data_received.called) def test_read_ready_recv_write(self): - self.loop.remove_reader = unittest.mock.Mock() - self.loop.add_writer = unittest.mock.Mock() + self.loop.remove_reader = mock.Mock() + self.loop.add_writer = mock.Mock() self.sslsock.recv.side_effect = ssl.SSLWantWriteError transport = self._make_one() transport._read_ready() @@ -1241,7 +1241,7 @@ def test_read_ready_recv_write(self): def test_read_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() transport._fatal_error.assert_called_with( err, @@ -1313,7 +1313,7 @@ def test_write_ready_send_read(self): transport = self._make_one() transport._buffer = list_to_buffer([b'data']) - self.loop.remove_writer = unittest.mock.Mock() + self.loop.remove_writer = mock.Mock() self.sslsock.send.side_effect = ssl.SSLWantReadError transport._write_ready() self.assertFalse(self.protocol.data_received.called) @@ -1325,7 +1325,7 @@ def test_write_ready_send_exc(self): transport = self._make_one() transport._buffer = list_to_buffer([b'data']) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._write_ready() transport._fatal_error.assert_called_with( err, @@ -1333,11 +1333,11 @@ def test_write_ready_send_exc(self): self.assertEqual(list_to_buffer(), transport._buffer) def test_write_ready_read_wants_write(self): - self.loop.add_reader = unittest.mock.Mock() + self.loop.add_reader = mock.Mock() self.sslsock.send.side_effect = BlockingIOError transport = self._make_one() transport._read_wants_write = True - transport._read_ready = unittest.mock.Mock() + transport._read_ready = mock.Mock() transport._write_ready() self.assertFalse(transport._read_wants_write) @@ -1374,9 +1374,9 @@ def test_server_hostname(self): class SelectorSslWithoutSslTransportTests(unittest.TestCase): - @unittest.mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.selector_events.ssl', None) def test_ssl_transport_requires_ssl_module(self): - Mock = unittest.mock.Mock + Mock = mock.Mock with self.assertRaises(RuntimeError): transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) @@ -1386,7 +1386,7 @@ class SelectorDatagramTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol) - self.sock = unittest.mock.Mock(spec_set=socket.socket) + self.sock = mock.Mock(spec_set=socket.socket) self.sock.fileno.return_value = 7 def test_read_ready(self): @@ -1404,7 +1404,7 @@ def test_read_ready_tryagain(self): self.loop, self.sock, self.protocol) self.sock.recvfrom.side_effect = BlockingIOError - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() self.assertFalse(transport._fatal_error.called) @@ -1414,7 +1414,7 @@ def test_read_ready_err(self): self.loop, self.sock, self.protocol) err = self.sock.recvfrom.side_effect = RuntimeError() - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() transport._fatal_error.assert_called_with( @@ -1426,7 +1426,7 @@ def test_read_ready_oserr(self): self.loop, self.sock, self.protocol) err = self.sock.recvfrom.side_effect = OSError() - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._read_ready() self.assertFalse(transport._fatal_error.called) @@ -1518,14 +1518,14 @@ def test_sendto_tryagain(self): self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) - @unittest.mock.patch('asyncio.selector_events.logger') + @mock.patch('asyncio.selector_events.logger') def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = RuntimeError() transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport.sendto(data, ()) self.assertTrue(transport._fatal_error.called) @@ -1549,7 +1549,7 @@ def test_sendto_error_received(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport.sendto(data, ()) self.assertEqual(transport._conn_lost, 0) @@ -1562,7 +1562,7 @@ def test_sendto_error_received_connected(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport.sendto(data) self.assertFalse(transport._fatal_error.called) @@ -1643,7 +1643,7 @@ def test_sendto_ready_exception(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1656,7 +1656,7 @@ def test_sendto_ready_error_received(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1667,14 +1667,14 @@ def test_sendto_ready_error_received_connection(self): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) - transport._fatal_error = unittest.mock.Mock() + transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() self.assertFalse(transport._fatal_error.called) self.assertTrue(self.protocol.error_received.called) - @unittest.mock.patch('asyncio.base_events.logger.error') + @mock.patch('asyncio.base_events.logger.error') def test_fatal_error_connected(self, m_exc): transport = _SelectorDatagramTransport( self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 0519d75a..93929626 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -1,7 +1,7 @@ """Tests for selectors.py.""" import unittest -import unittest.mock +from unittest import mock from asyncio import selectors @@ -20,7 +20,7 @@ def test_len(self): map = selectors._SelectorMapping(s) self.assertTrue(map.__len__() == 0) - f = unittest.mock.Mock() + f = mock.Mock() f.fileno.return_value = 10 s.register(f, selectors.EVENT_READ, None) self.assertTrue(len(map) == 1) @@ -28,7 +28,7 @@ def test_len(self): def test_getitem(self): s = FakeSelector() map = selectors._SelectorMapping(s) - f = unittest.mock.Mock() + f = mock.Mock() f.fileno.return_value = 10 s.register(f, selectors.EVENT_READ, None) attended = selectors.SelectorKey(f, 10, selectors.EVENT_READ, None) @@ -38,7 +38,7 @@ def test_getitem_key_error(self): s = FakeSelector() map = selectors._SelectorMapping(s) self.assertTrue(len(map) == 0) - f = unittest.mock.Mock() + f = mock.Mock() f.fileno.return_value = 10 s.register(f, selectors.EVENT_READ, None) self.assertRaises(KeyError, map.__getitem__, 5) @@ -47,7 +47,7 @@ def test_iter(self): s = FakeSelector() map = selectors._SelectorMapping(s) self.assertTrue(len(map) == 0) - f = unittest.mock.Mock() + f = mock.Mock() f.fileno.return_value = 5 s.register(f, selectors.EVENT_READ, None) counter = 0 @@ -64,7 +64,7 @@ class BaseSelectorTests(unittest.TestCase): def test_fileobj_to_fd(self): self.assertEqual(10, selectors._fileobj_to_fd(10)) - f = unittest.mock.Mock() + f = mock.Mock() f.fileno.return_value = 10 self.assertEqual(10, selectors._fileobj_to_fd(f)) @@ -80,7 +80,7 @@ def test_selector_key_repr(self): "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key)) def test_register(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() @@ -91,10 +91,10 @@ def test_register(self): def test_register_unknown_event(self): s = FakeSelector() - self.assertRaises(ValueError, s.register, unittest.mock.Mock(), 999999) + self.assertRaises(ValueError, s.register, mock.Mock(), 999999) def test_register_already_registered(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() @@ -102,7 +102,7 @@ def test_register_already_registered(self): self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ) def test_unregister(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() @@ -111,21 +111,21 @@ def test_unregister(self): self.assertFalse(s._fd_to_key) def test_unregister_unknown(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() self.assertRaises(KeyError, s.unregister, fobj) def test_modify_unknown(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() self.assertRaises(KeyError, s.modify, fobj, 1) def test_modify(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 s = FakeSelector() @@ -137,7 +137,7 @@ def test_modify(self): s.get_key(fobj)) def test_modify_data(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 d1 = object() @@ -153,7 +153,7 @@ def test_modify_data(self): s.get_key(fobj)) def test_modify_data_use_a_shortcut(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 d1 = object() @@ -162,14 +162,14 @@ def test_modify_data_use_a_shortcut(self): s = FakeSelector() key = s.register(fobj, selectors.EVENT_READ, d1) - s.unregister = unittest.mock.Mock() - s.register = unittest.mock.Mock() + s.unregister = mock.Mock() + s.register = mock.Mock() key2 = s.modify(fobj, selectors.EVENT_READ, d2) self.assertFalse(s.unregister.called) self.assertFalse(s.register.called) def test_modify_same(self): - fobj = unittest.mock.Mock() + fobj = mock.Mock() fobj.fileno.return_value = 10 data = object() diff --git a/tests/test_streams.py b/tests/test_streams.py index ca792f20..e921dfe4 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -4,7 +4,7 @@ import gc import socket import unittest -import unittest.mock +from unittest import mock try: import ssl except ImportError: @@ -29,7 +29,7 @@ def tearDown(self): self.loop.close() gc.collect() - @unittest.mock.patch('asyncio.streams.events') + @mock.patch('asyncio.streams.events') def test_ctor_global_loop(self, m_events): stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) diff --git a/tests/test_transports.py b/tests/test_transports.py index 4c645268..cfbdf3e9 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -1,7 +1,7 @@ """Tests for transports.py.""" import unittest -import unittest.mock +from unittest import mock import asyncio from asyncio import transports @@ -23,7 +23,7 @@ def test_get_extra_info(self): def test_writelines(self): transport = asyncio.Transport() - transport.write = unittest.mock.Mock() + transport.write = mock.Mock() transport.writelines([b'line1', bytearray(b'line2'), @@ -70,7 +70,7 @@ def get_write_buffer_size(self): return 512 transport = MyTransport() - transport._protocol = unittest.mock.Mock() + transport._protocol = mock.Mock() self.assertFalse(transport._protocol_paused) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 9e489c2d..3b187de9 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -13,7 +13,7 @@ import tempfile import threading import unittest -import unittest.mock +from unittest import mock if sys.platform == 'win32': raise unittest.SkipTest('UNIX only') @@ -25,7 +25,7 @@ from asyncio import unix_events -MOCK_ANY = unittest.mock.ANY +MOCK_ANY = mock.ANY @unittest.skipUnless(signal, 'Signals are not supported') @@ -48,15 +48,15 @@ def test_handle_signal_no_handler(self): self.loop._handle_signal(signal.NSIG + 1, ()) def test_handle_signal_cancelled_handler(self): - h = asyncio.Handle(unittest.mock.Mock(), (), - loop=unittest.mock.Mock()) + h = asyncio.Handle(mock.Mock(), (), + loop=mock.Mock()) h.cancel() self.loop._signal_handlers[signal.NSIG + 1] = h - self.loop.remove_signal_handler = unittest.mock.Mock() + self.loop.remove_signal_handler = mock.Mock() self.loop._handle_signal(signal.NSIG + 1, ()) self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler_setup_error(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.set_wakeup_fd.side_effect = ValueError @@ -66,7 +66,7 @@ def test_add_signal_handler_setup_error(self, m_signal): self.loop.add_signal_handler, signal.SIGINT, lambda: True) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG @@ -76,7 +76,7 @@ def test_add_signal_handler(self, m_signal): self.assertIsInstance(h, asyncio.Handle) self.assertEqual(h._callback, cb) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler_install_error(self, m_signal): m_signal.NSIG = signal.NSIG @@ -94,8 +94,8 @@ class Err(OSError): self.loop.add_signal_handler, signal.SIGINT, lambda: True) - @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') def test_add_signal_handler_install_error2(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG @@ -111,8 +111,8 @@ class Err(OSError): self.assertFalse(m_logging.info.called) self.assertEqual(1, m_signal.set_wakeup_fd.call_count) - @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') def test_add_signal_handler_install_error3(self, m_logging, m_signal): class Err(OSError): errno = errno.EINVAL @@ -126,7 +126,7 @@ class Err(OSError): self.assertFalse(m_logging.info.called) self.assertEqual(2, m_signal.set_wakeup_fd.call_count) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_remove_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG @@ -139,7 +139,7 @@ def test_remove_signal_handler(self, m_signal): self.assertEqual( (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_remove_signal_handler_2(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.SIGINT = signal.SIGINT @@ -156,8 +156,8 @@ def test_remove_signal_handler_2(self, m_signal): (signal.SIGINT, m_signal.default_int_handler), m_signal.signal.call_args[0]) - @unittest.mock.patch('asyncio.unix_events.signal') - @unittest.mock.patch('asyncio.base_events.logger') + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -167,7 +167,7 @@ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): self.loop.remove_signal_handler(signal.SIGHUP) self.assertTrue(m_logging.info) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_remove_signal_handler_error(self, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -177,7 +177,7 @@ def test_remove_signal_handler_error(self, m_signal): self.assertRaises( OSError, self.loop.remove_signal_handler, signal.SIGHUP) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_remove_signal_handler_error2(self, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -189,7 +189,7 @@ class Err(OSError): self.assertRaises( RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) - @unittest.mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.unix_events.signal') def test_close(self, m_signal): m_signal.NSIG = signal.NSIG @@ -291,16 +291,16 @@ class UnixReadPipeTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.Protocol) - self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher = mock.patch('fcntl.fcntl') fcntl_patcher.start() self.addCleanup(fcntl_patcher.stop) - fstat_patcher = unittest.mock.patch('os.fstat') + fstat_patcher = mock.patch('os.fstat') m_fstat = fstat_patcher.start() - st = unittest.mock.Mock() + st = mock.Mock() st.st_mode = stat.S_IFIFO m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) @@ -319,7 +319,7 @@ def test_ctor_with_waiter(self): test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test__read_ready(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -329,7 +329,7 @@ def test__read_ready(self, m_read): m_read.assert_called_with(5, tr.max_size) self.protocol.data_received.assert_called_with(b'data') - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test__read_ready_eof(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -342,7 +342,7 @@ def test__read_ready_eof(self, m_read): self.protocol.eof_received.assert_called_with() self.protocol.connection_lost.assert_called_with(None) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test__read_ready_blocked(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -353,14 +353,14 @@ def test__read_ready_blocked(self, m_read): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) - @unittest.mock.patch('asyncio.log.logger.error') - @unittest.mock.patch('os.read') + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) err = OSError() m_read.side_effect = err - tr._close = unittest.mock.Mock() + tr._close = mock.Mock() tr._read_ready() m_read.assert_called_with(5, tr.max_size) @@ -371,17 +371,17 @@ def test__read_ready_error(self, m_read, m_logexc): '\nprotocol:.*\ntransport:.*'), exc_info=(OSError, MOCK_ANY, MOCK_ANY)) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test_pause_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - m = unittest.mock.Mock() + m = mock.Mock() self.loop.add_reader(5, m) tr.pause_reading() self.assertFalse(self.loop.readers) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test_resume_reading(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -389,26 +389,26 @@ def test_resume_reading(self, m_read): tr.resume_reading() self.loop.assert_reader(5, tr._read_ready) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test_close(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) - tr._close = unittest.mock.Mock() + tr._close = mock.Mock() tr.close() tr._close.assert_called_with(None) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test_close_already_closing(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) tr._closing = True - tr._close = unittest.mock.Mock() + tr._close = mock.Mock() tr.close() self.assertFalse(tr._close.called) - @unittest.mock.patch('os.read') + @mock.patch('os.read') def test__close(self, m_read): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) @@ -459,16 +459,16 @@ class UnixWritePipeTransportTests(unittest.TestCase): def setUp(self): self.loop = test_utils.TestLoop() self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) - self.pipe = unittest.mock.Mock(spec_set=io.RawIOBase) + self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - fcntl_patcher = unittest.mock.patch('fcntl.fcntl') + fcntl_patcher = mock.patch('fcntl.fcntl') fcntl_patcher.start() self.addCleanup(fcntl_patcher.stop) - fstat_patcher = unittest.mock.patch('os.fstat') + fstat_patcher = mock.patch('os.fstat') m_fstat = fstat_patcher.start() - st = unittest.mock.Mock() + st = mock.Mock() st.st_mode = stat.S_IFSOCK m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) @@ -493,7 +493,7 @@ def test_can_write_eof(self): self.loop, self.pipe, self.protocol) self.assertTrue(tr.can_write_eof()) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -504,7 +504,7 @@ def test_write(self, m_write): self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write_no_data(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -514,7 +514,7 @@ def test_write_no_data(self, m_write): self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write_partial(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -525,7 +525,7 @@ def test_write_partial(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'ta'], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write_buffer(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -537,7 +537,7 @@ def test_write_buffer(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'previous', b'data'], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write_again(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -548,15 +548,15 @@ def test_write_again(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('asyncio.unix_events.logger') - @unittest.mock.patch('os.write') + @mock.patch('asyncio.unix_events.logger') + @mock.patch('os.write') def test_write_err(self, m_write, m_log): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) err = OSError() m_write.side_effect = err - tr._fatal_error = unittest.mock.Mock() + tr._fatal_error = mock.Mock() tr.write(b'data') m_write.assert_called_with(5, b'data') self.assertFalse(self.loop.writers) @@ -576,7 +576,7 @@ def test_write_err(self, m_write, m_log): m_log.warning.assert_called_with( 'pipe closed by peer or os.write(pipe, data) raised exception.') - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_write_close(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -597,7 +597,7 @@ def test__read_ready(self): test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test__write_ready(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -609,7 +609,7 @@ def test__write_ready(self, m_write): self.assertFalse(self.loop.writers) self.assertEqual([], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test__write_ready_partial(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -622,7 +622,7 @@ def test__write_ready_partial(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'a'], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test__write_ready_again(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -635,7 +635,7 @@ def test__write_ready_again(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test__write_ready_empty(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -648,8 +648,8 @@ def test__write_ready_empty(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @unittest.mock.patch('asyncio.log.logger.error') - @unittest.mock.patch('os.write') + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -672,7 +672,7 @@ def test__write_ready_err(self, m_write, m_logexc): test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(err) - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test__write_ready_closing(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -689,7 +689,7 @@ def test__write_ready_closing(self, m_write): self.protocol.connection_lost.assert_called_with(None) self.pipe.close.assert_called_with() - @unittest.mock.patch('os.write') + @mock.patch('os.write') def test_abort(self, m_write): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) @@ -742,7 +742,7 @@ def test_close(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) - tr.write_eof = unittest.mock.Mock() + tr.write_eof = mock.Mock() tr.close() tr.write_eof.assert_called_with() @@ -750,7 +750,7 @@ def test_close_closing(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) - tr.write_eof = unittest.mock.Mock() + tr.write_eof = mock.Mock() tr._closing = True tr.close() self.assertFalse(tr.write_eof.called) @@ -777,7 +777,7 @@ def test_write_eof_pending(self): class AbstractChildWatcherTests(unittest.TestCase): def test_not_implemented(self): - f = unittest.mock.Mock() + f = mock.Mock() watcher = asyncio.AbstractChildWatcher() self.assertRaises( NotImplementedError, watcher.add_child_handler, f, f) @@ -796,7 +796,7 @@ def test_not_implemented(self): class BaseChildWatcherTests(unittest.TestCase): def test_not_implemented(self): - f = unittest.mock.Mock() + f = mock.Mock() watcher = unix_events.BaseChildWatcher() self.assertRaises( NotImplementedError, watcher._do_waitpid, f) @@ -813,14 +813,14 @@ def test_not_implemented(self): class ChildWatcherTestsMixin: - ignore_warnings = unittest.mock.patch.object(log.logger, "warning") + ignore_warnings = mock.patch.object(log.logger, "warning") def setUp(self): self.loop = test_utils.TestLoop() self.running = False self.zombies = {} - with unittest.mock.patch.object( + with mock.patch.object( self.loop, "add_signal_handler") as self.m_add_signal_handler: self.watcher = self.create_watcher() self.watcher.attach_loop(self.loop) @@ -864,8 +864,8 @@ def test_create_watcher(self): def waitpid_mocks(func): def wrapped_func(self): def patch(target, wrapper): - return unittest.mock.patch(target, wraps=wrapper, - new_callable=unittest.mock.Mock) + return mock.patch(target, wraps=wrapper, + new_callable=mock.Mock) with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ @@ -881,7 +881,7 @@ def patch(target, wrapper): @waitpid_mocks def test_sigchld(self, m): # register a child - callback = unittest.mock.Mock() + callback = mock.Mock() with self.watcher: self.running = True @@ -941,8 +941,8 @@ def test_sigchld(self, m): @waitpid_mocks def test_sigchld_two_children(self, m): - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() # register child 1 with self.watcher: @@ -1045,8 +1045,8 @@ def test_sigchld_two_children(self, m): @waitpid_mocks def test_sigchld_two_children_terminating_together(self, m): - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() # register child 1 with self.watcher: @@ -1115,7 +1115,7 @@ def test_sigchld_two_children_terminating_together(self, m): @waitpid_mocks def test_sigchld_race_condition(self, m): # register a child - callback = unittest.mock.Mock() + callback = mock.Mock() with self.watcher: # child terminates before being registered @@ -1136,8 +1136,8 @@ def test_sigchld_race_condition(self, m): @waitpid_mocks def test_sigchld_replace_handler(self, m): - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() # register a child with self.watcher: @@ -1189,7 +1189,7 @@ def test_sigchld_replace_handler(self, m): @waitpid_mocks def test_sigchld_remove_handler(self, m): - callback = unittest.mock.Mock() + callback = mock.Mock() # register a child with self.watcher: @@ -1221,7 +1221,7 @@ def test_sigchld_remove_handler(self, m): @waitpid_mocks def test_sigchld_unknown_status(self, m): - callback = unittest.mock.Mock() + callback = mock.Mock() # register a child with self.watcher: @@ -1258,9 +1258,9 @@ def test_sigchld_unknown_status(self, m): @waitpid_mocks def test_remove_child_handler(self, m): - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() - callback3 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() # register children with self.watcher: @@ -1291,7 +1291,7 @@ def test_remove_child_handler(self, m): @waitpid_mocks def test_sigchld_unhandled_exception(self, m): - callback = unittest.mock.Mock() + callback = mock.Mock() # register a child with self.watcher: @@ -1301,7 +1301,7 @@ def test_sigchld_unhandled_exception(self, m): # raise an exception m.waitpid.side_effect = ValueError - with unittest.mock.patch.object(log.logger, + with mock.patch.object(log.logger, 'error') as m_error: self.assertEqual(self.watcher._sig_chld(), None) @@ -1310,7 +1310,7 @@ def test_sigchld_unhandled_exception(self, m): @waitpid_mocks def test_sigchld_child_reaped_elsewhere(self, m): # register a child - callback = unittest.mock.Mock() + callback = mock.Mock() with self.watcher: self.running = True @@ -1346,8 +1346,8 @@ def test_sigchld_child_reaped_elsewhere(self, m): @waitpid_mocks def test_sigchld_unknown_pid_during_registration(self, m): # register two children - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() with self.ignore_warnings, self.watcher: self.running = True @@ -1367,7 +1367,7 @@ def test_sigchld_unknown_pid_during_registration(self, m): @waitpid_mocks def test_set_loop(self, m): # register a child - callback = unittest.mock.Mock() + callback = mock.Mock() with self.watcher: self.running = True @@ -1377,10 +1377,10 @@ def test_set_loop(self, m): old_loop = self.loop self.loop = test_utils.TestLoop() - with unittest.mock.patch.object( + with mock.patch.object( old_loop, "remove_signal_handler") as m_old_remove_signal_handler, \ - unittest.mock.patch.object( + mock.patch.object( self.loop, "add_signal_handler") as m_new_add_signal_handler: @@ -1401,9 +1401,9 @@ def test_set_loop(self, m): @waitpid_mocks def test_set_loop_race_condition(self, m): # register 3 children - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() - callback3 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() with self.watcher: self.running = True @@ -1415,7 +1415,7 @@ def test_set_loop_race_condition(self, m): old_loop = self.loop self.loop = None - with unittest.mock.patch.object( + with mock.patch.object( old_loop, "remove_signal_handler") as m_remove_signal_handler: self.watcher.attach_loop(None) @@ -1435,7 +1435,7 @@ def test_set_loop_race_condition(self, m): # attach a new loop self.loop = test_utils.TestLoop() - with unittest.mock.patch.object( + with mock.patch.object( self.loop, "add_signal_handler") as m_add_signal_handler: self.watcher.attach_loop(self.loop) @@ -1461,8 +1461,8 @@ def test_set_loop_race_condition(self, m): @waitpid_mocks def test_close(self, m): # register two children - callback1 = unittest.mock.Mock() - callback2 = unittest.mock.Mock() + callback1 = mock.Mock() + callback2 = mock.Mock() with self.watcher: self.running = True @@ -1479,7 +1479,7 @@ def test_close(self, m): if isinstance(self.watcher, asyncio.FastChildWatcher): self.assertEqual(len(self.watcher._zombies), 1) - with unittest.mock.patch.object( + with mock.patch.object( self.loop, "remove_signal_handler") as m_remove_signal_handler: diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index fa9d66c0..7616c73e 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -3,7 +3,7 @@ import sys import test.support import unittest -import unittest.mock +from unittest import mock if sys.platform != 'win32': raise unittest.SkipTest('Windows only') @@ -25,7 +25,7 @@ def test_winsocketpair(self): csock.close() ssock.close() - @unittest.mock.patch('asyncio.windows_utils.socket') + @mock.patch('asyncio.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): m_socket.socket.return_value.getsockname.return_value = ('', 12345) m_socket.socket.return_value.accept.return_value = object(), object() From 4329622f7d6cdde2c7dc29b3b709e343fbf0349c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 10:55:46 +0100 Subject: [PATCH 0978/1502] Cleanup test_unix_events.py (indentation) --- tests/test_unix_events.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 3b187de9..792e702c 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1302,7 +1302,7 @@ def test_sigchld_unhandled_exception(self, m): m.waitpid.side_effect = ValueError with mock.patch.object(log.logger, - 'error') as m_error: + 'error') as m_error: self.assertEqual(self.watcher._sig_chld(), None) self.assertTrue(m_error.called) @@ -1376,19 +1376,16 @@ def test_set_loop(self, m): # attach a new loop old_loop = self.loop self.loop = test_utils.TestLoop() + patch = mock.patch.object - with mock.patch.object( - old_loop, - "remove_signal_handler") as m_old_remove_signal_handler, \ - mock.patch.object( - self.loop, - "add_signal_handler") as m_new_add_signal_handler: + with patch(old_loop, "remove_signal_handler") as m_old_remove, \ + patch(self.loop, "add_signal_handler") as m_new_add: self.watcher.attach_loop(self.loop) - m_old_remove_signal_handler.assert_called_once_with( + m_old_remove.assert_called_once_with( signal.SIGCHLD) - m_new_add_signal_handler.assert_called_once_with( + m_new_add.assert_called_once_with( signal.SIGCHLD, self.watcher._sig_chld) # child terminates From 31a56009bbd16186085522a949932624ea0b54d6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 11:05:56 +0100 Subject: [PATCH 0979/1502] Fix pyflakes warnings: remove unused variables and imports --- asyncio/test_utils.py | 1 - examples/cacheclt.py | 1 - examples/cachesvr.py | 1 - examples/child_process.py | 8 ++------ examples/crawl.py | 2 -- examples/simple_tcp_server.py | 2 +- examples/subprocess_attach_read_pipe.py | 1 - examples/timing_tcp_server.py | 2 +- tests/test_base_events.py | 2 +- tests/test_events.py | 4 +--- tests/test_selector_events.py | 3 +-- tests/test_streams.py | 1 - tests/test_tasks.py | 8 ++------ tests/test_unix_events.py | 1 - tests/test_windows_events.py | 3 +-- 15 files changed, 10 insertions(+), 30 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 9a9a10b4..71d309b0 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -11,7 +11,6 @@ import tempfile import threading import time -import unittest from unittest import mock from http.server import HTTPServer diff --git a/examples/cacheclt.py b/examples/cacheclt.py index d1891889..b11a4d1a 100644 --- a/examples/cacheclt.py +++ b/examples/cacheclt.py @@ -8,7 +8,6 @@ from asyncio import test_utils import json import logging -import sys ARGS = argparse.ArgumentParser(description='Cache client example.') ARGS.add_argument( diff --git a/examples/cachesvr.py b/examples/cachesvr.py index 9c7bda91..ddb79b6f 100644 --- a/examples/cachesvr.py +++ b/examples/cachesvr.py @@ -62,7 +62,6 @@ import logging import os import random -import sys ARGS = argparse.ArgumentParser(description='Cache server example.') ARGS.add_argument( diff --git a/examples/child_process.py b/examples/child_process.py index 0c12cb95..3fac175e 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -16,9 +16,6 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '..')) import asyncio -from asyncio import streams -from asyncio import protocols - if sys.platform == 'win32': from asyncio.windows_utils import Popen, PIPE from asyncio.windows_events import ProactorEventLoop @@ -32,7 +29,6 @@ @asyncio.coroutine def connect_write_pipe(file): loop = asyncio.get_event_loop() - protocol = protocols.Protocol() transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, file) return transport @@ -43,9 +39,9 @@ def connect_write_pipe(file): @asyncio.coroutine def connect_read_pipe(file): loop = asyncio.get_event_loop() - stream_reader = streams.StreamReader(loop=loop) + stream_reader = asyncio.StreamReader(loop=loop) def factory(): - return streams.StreamReaderProtocol(stream_reader) + return asyncio.StreamReaderProtocol(stream_reader) transport, _ = yield from loop.connect_read_pipe(factory, file) return stream_reader, transport diff --git a/examples/crawl.py b/examples/crawl.py index 0e99c82e..637fb8e0 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -21,7 +21,6 @@ from http.client import BadStatusLine import logging import re -import signal import sys import time import urllib.parse @@ -510,7 +509,6 @@ def fetch(self): self.response = yield from self.request.get_response() self.body = yield from self.response.read() h_conn = self.response.get_header('connection').lower() - h_t_enc = self.response.get_header('transfer-encoding').lower() if h_conn != 'close': self.request.close(recycle=True) self.request = None diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py index 0e87d5b6..b796d9b6 100644 --- a/examples/simple_tcp_server.py +++ b/examples/simple_tcp_server.py @@ -144,7 +144,7 @@ def recv(): # creates a client and connects to our server try: - msg = loop.run_until_complete(client()) + loop.run_until_complete(client()) server.stop(loop) finally: loop.close() diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py index 8bec652f..57a1342b 100644 --- a/examples/subprocess_attach_read_pipe.py +++ b/examples/subprocess_attach_read_pipe.py @@ -2,7 +2,6 @@ """Example showing how to attach a read pipe to a subprocess.""" import asyncio import os, sys -from asyncio import subprocess code = """ import os, sys diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index cb43a796..883ce6d3 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -158,7 +158,7 @@ def recv(): # creates a client and connects to our server try: - msg = loop.run_until_complete(client()) + loop.run_until_complete(client()) server.stop(loop) finally: loop.close() diff --git a/tests/test_base_events.py b/tests/test_base_events.py index f7a4e3a0..340ca67d 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -7,7 +7,7 @@ import time import unittest from unittest import mock -from test.support import find_unused_port, IPV6_ENABLED +from test.support import IPV6_ENABLED import asyncio from asyncio import base_events diff --git a/tests/test_events.py b/tests/test_events.py index 055a2aaa..ab58cb53 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,7 +25,6 @@ import asyncio -from asyncio import events from asyncio import selector_events from asyncio import test_utils @@ -1648,13 +1647,12 @@ def connect(): def test_subprocess_wait_no_same_group(self): proto = None - transp = None @asyncio.coroutine def connect(): nonlocal proto # start the new process in a new session - transp, proto = yield from self.loop.subprocess_shell( + _, proto = yield from self.loop.subprocess_shell( functools.partial(MySubprocessProtocol, self.loop), 'exit 7', stdin=None, stdout=None, stderr=None, start_new_session=True) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 369ec32b..964b2e8e 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1,6 +1,5 @@ """Tests for selector_events.py""" -import collections import errno import gc import pprint @@ -1378,7 +1377,7 @@ class SelectorSslWithoutSslTransportTests(unittest.TestCase): def test_ssl_transport_requires_ssl_module(self): Mock = mock.Mock with self.assertRaises(RuntimeError): - transport = _SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) + _SelectorSslTransport(Mock(), Mock(), Mock(), Mock()) class SelectorDatagramTransportTests(unittest.TestCase): diff --git a/tests/test_streams.py b/tests/test_streams.py index e921dfe4..031499e8 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,6 +1,5 @@ """Tests for streams.py.""" -import functools import gc import socket import unittest diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6d03dc78..ced34312 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -830,7 +830,7 @@ def foo(): v = yield from f self.assertEqual(v, 'a') - res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) def test_as_completed_reverse_wait(self): @@ -964,13 +964,9 @@ def gen(): loop = test_utils.TestLoop(gen) self.addCleanup(loop.close) - sleepfut = None - @asyncio.coroutine def sleep(dt): - nonlocal sleepfut - sleepfut = asyncio.sleep(dt, loop=loop) - yield from sleepfut + yield from asyncio.sleep(dt, loop=loop) @asyncio.coroutine def doit(): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 792e702c..cc743839 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1459,7 +1459,6 @@ def test_set_loop_race_condition(self, m): def test_close(self, m): # register two children callback1 = mock.Mock() - callback2 = mock.Mock() with self.watcher: self.running = True diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 846049a2..f6522586 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -8,7 +8,6 @@ import _winapi import asyncio -from asyncio import test_utils from asyncio import _overlapped from asyncio import windows_events @@ -50,7 +49,7 @@ def test_double_bind(self): ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() server1 = windows_events.PipeServer(ADDRESS) with self.assertRaises(PermissionError): - server2 = windows_events.PipeServer(ADDRESS) + windows_events.PipeServer(ADDRESS) server1.close() def test_pipe(self): From 4257e2d85fe305b65740fc1f9db7aafcbc951eb8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 11:30:06 +0100 Subject: [PATCH 0980/1502] Simplify test_events.py, don't use non local variables and don't call assert methods in coroutines. It also simplify merges from Tulip to Trollius (Python 2 does not support non local variables). --- tests/test_events.py | 265 +++++++++++++------------------------------ 1 file changed, 79 insertions(+), 186 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index ab58cb53..f01d1f38 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1155,23 +1155,15 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe(self): - proto = MyWritePipeProto(loop=self.loop) - transport = None - rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) - @asyncio.coroutine - def connect(): - nonlocal transport - t, p = yield from self.loop.connect_write_pipe( - lambda: proto, pipeobj) - self.assertIs(p, proto) - self.assertIs(t, proto.transport) - self.assertEqual('CONNECTED', proto.state) - transport = t - - self.loop.run_until_complete(connect()) + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) transport.write(b'1') test_utils.run_briefly(self.loop) @@ -1197,23 +1189,14 @@ def connect(): @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): - proto = MyWritePipeProto(loop=self.loop) - transport = None - rsock, wsock = test_utils.socketpair() pipeobj = io.open(wsock.detach(), 'wb', 1024) - @asyncio.coroutine - def connect(): - nonlocal transport - t, p = yield from self.loop.connect_write_pipe(lambda: proto, - pipeobj) - self.assertIs(p, proto) - self.assertIs(t, proto.transport) - self.assertEqual('CONNECTED', proto.state) - transport = t - - self.loop.run_until_complete(connect()) + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) self.assertEqual('CONNECTED', proto.state) transport.write(b'1') @@ -1231,23 +1214,15 @@ def connect(): # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) def test_write_pty(self): - proto = MyWritePipeProto(loop=self.loop) - transport = None - master, slave = os.openpty() slave_write_obj = io.open(slave, 'wb', 0) - @asyncio.coroutine - def connect(): - nonlocal transport - t, p = yield from self.loop.connect_write_pipe(lambda: proto, - slave_write_obj) - self.assertIs(p, proto) - self.assertIs(t, proto.transport) - self.assertEqual('CONNECTED', proto.state) - transport = t - - self.loop.run_until_complete(connect()) + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) transport.write(b'1') test_utils.run_briefly(self.loop) @@ -1369,20 +1344,13 @@ def check_killed(self, returncode): self.assertEqual(-signal.SIGKILL, returncode) def test_subprocess_exec(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) @@ -1395,20 +1363,13 @@ def connect(): self.assertEqual(b'Python The Winner', proto.data[1]) def test_subprocess_interactive(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) @@ -1429,18 +1390,11 @@ def connect(): self.check_terminated(proto.returncode) def test_subprocess_shell(self): - proto = None - transp = None - - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_shell( - functools.partial(MySubprocessProtocol, self.loop), - 'echo Python') - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo Python') + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) transp.get_pipe_transport(0).close() @@ -1451,33 +1405,20 @@ def connect(): self.assertEqual(proto.data[2], b'') def test_subprocess_exitcode(self): - proto = None - - @asyncio.coroutine - def connect(): - nonlocal proto - transp, proto = yield from self.loop.subprocess_shell( - functools.partial(MySubprocessProtocol, self.loop), - 'exit 7', stdin=None, stdout=None, stderr=None) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) def test_subprocess_close_after_finish(self): - proto = None - transp = None - - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_shell( - functools.partial(MySubprocessProtocol, self.loop), - 'exit 7', stdin=None, stdout=None, stderr=None) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.assertIsNone(transp.get_pipe_transport(0)) self.assertIsNone(transp.get_pipe_transport(1)) self.assertIsNone(transp.get_pipe_transport(2)) @@ -1486,20 +1427,13 @@ def connect(): self.assertIsNone(transp.close()) def test_subprocess_kill(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) transp.kill() @@ -1507,20 +1441,13 @@ def connect(): self.check_killed(proto.returncode) def test_subprocess_terminate(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) transp.terminate() @@ -1529,20 +1456,13 @@ def connect(): @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_subprocess_send_signal(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) transp.send_signal(signal.SIGHUP) @@ -1550,20 +1470,13 @@ def connect(): self.assertEqual(-signal.SIGHUP, proto.returncode) def test_subprocess_stderr(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo2.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) stdin = transp.get_pipe_transport(0) @@ -1577,20 +1490,13 @@ def connect(): self.assertEqual(0, proto.returncode) def test_subprocess_stderr_redirect_to_stdout(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo2.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog, stderr=subprocess.STDOUT) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) stdin = transp.get_pipe_transport(0) @@ -1607,20 +1513,13 @@ def connect(): self.assertEqual(0, proto.returncode) def test_subprocess_close_client_stream(self): - proto = None - transp = None - prog = os.path.join(os.path.dirname(__file__), 'echo3.py') - @asyncio.coroutine - def connect(): - nonlocal proto, transp - transp, proto = yield from self.loop.subprocess_exec( - functools.partial(MySubprocessProtocol, self.loop), - sys.executable, prog) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.connected) stdin = transp.get_pipe_transport(0) @@ -1646,19 +1545,13 @@ def connect(): self.check_terminated(proto.returncode) def test_subprocess_wait_no_same_group(self): - proto = None - - @asyncio.coroutine - def connect(): - nonlocal proto - # start the new process in a new session - _, proto = yield from self.loop.subprocess_shell( - functools.partial(MySubprocessProtocol, self.loop), - 'exit 7', stdin=None, stdout=None, stderr=None, - start_new_session=True) - self.assertIsInstance(proto, MySubprocessProtocol) - - self.loop.run_until_complete(connect()) + # start the new process in a new session + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + _, proto = yield self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) From 281dcd5e9bfaffd05d2cc493a7f5d2e20416def7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 17:23:58 +0100 Subject: [PATCH 0981/1502] cleanup: write the long line as a single line --- asyncio/selector_events.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index aa427459..70d8a958 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -702,8 +702,7 @@ def _write_ready(self): if self._buffer: try: n = self._sock.send(self._buffer) - except (BlockingIOError, InterruptedError, - ssl.SSLWantWriteError): + except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): n = 0 except ssl.SSLWantReadError: n = 0 From b5678ef69b5a029e64e7dae9168b1c4075c44a4b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 17:34:07 +0100 Subject: [PATCH 0982/1502] windows_events.py: use more revelant names to overlapped callbacks For example: "finish_recv", not just "finish". --- asyncio/windows_events.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 60fb5896..19f25882 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -213,7 +213,7 @@ def recv(self, conn, nbytes, flags=0): else: ov.ReadFile(conn.fileno(), nbytes) - def finish(trans, key, ov): + def finish_recv(trans, key, ov): try: return ov.getresult() except OSError as exc: @@ -222,7 +222,7 @@ def finish(trans, key, ov): else: raise - return self._register(ov, conn, finish) + return self._register(ov, conn, finish_recv) def send(self, conn, buf, flags=0): self._register_with_iocp(conn) @@ -232,7 +232,7 @@ def send(self, conn, buf, flags=0): else: ov.WriteFile(conn.fileno(), buf) - def finish(trans, key, ov): + def finish_send(trans, key, ov): try: return ov.getresult() except OSError as exc: @@ -241,7 +241,7 @@ def finish(trans, key, ov): else: raise - return self._register(ov, conn, finish) + return self._register(ov, conn, finish_send) def accept(self, listener): self._register_with_iocp(listener) @@ -300,17 +300,17 @@ def accept_pipe(self, pipe): ov = _overlapped.Overlapped(NULL) ov.ConnectNamedPipe(pipe.fileno()) - def finish(trans, key, ov): + def finish_accept_pipe(trans, key, ov): ov.getresult() return pipe - return self._register(ov, pipe, finish) + return self._register(ov, pipe, finish_accept_pipe) def connect_pipe(self, address): ov = _overlapped.Overlapped(NULL) ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) - def finish(err, handle, ov): + def finish_connect_pipe(err, handle, ov): # err, handle were arguments passed to PostQueuedCompletionStatus() # in a function run in a thread pool. if err == _overlapped.ERROR_SEM_TIMEOUT: @@ -323,7 +323,7 @@ def finish(err, handle, ov): else: return windows_utils.PipeHandle(handle) - return self._register(ov, None, finish, wait_for_post=True) + return self._register(ov, None, finish_connect_pipe, wait_for_post=True) def wait_for_handle(self, handle, timeout=None): if timeout is None: @@ -339,7 +339,7 @@ def wait_for_handle(self, handle, timeout=None): handle, self._iocp, ov.address, ms) f = _WaitHandleFuture(wh, loop=self._loop) - def finish(trans, key, ov): + def finish_wait_for_handle(trans, key, ov): if not f.cancelled(): try: _overlapped.UnregisterWait(wh) @@ -355,7 +355,7 @@ def finish(trans, key, ov): return (_winapi.WaitForSingleObject(handle, 0) == _winapi.WAIT_OBJECT_0) - self._cache[ov.address] = (f, ov, None, finish) + self._cache[ov.address] = (f, ov, None, finish_wait_for_handle) return f def _register_with_iocp(self, obj): From c3f1c3d7e268577e561915d177fc19a069bc0665 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 26 Feb 2014 18:00:18 +0100 Subject: [PATCH 0983/1502] tcp_echo.py: add --iocp command line option to use IOCP event loop on Windows Don't setup a signal handler for SIGINT on Windows. --- examples/tcp_echo.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 3c08b15d..d743242a 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -2,6 +2,7 @@ """TCP echo server example.""" import argparse import asyncio +import sys try: import signal except ImportError: @@ -87,6 +88,9 @@ def start_server(loop, host, port): ARGS.add_argument( '--port', action="store", dest='port', default=9999, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', + default=False, help='Use IOCP event loop') if __name__ == '__main__': @@ -100,8 +104,15 @@ def start_server(loop, host, port): print('Please specify --server or --client\n') ARGS.print_help() else: - loop = asyncio.get_event_loop() - if signal is not None: + if args.iocp: + from asyncio import windows_events + loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + print ('Using backend: {0}'.format(loop.__class__.__name__)) + + if signal is not None and sys.platform != 'win32': loop.add_signal_handler(signal.SIGINT, loop.stop) if args.server: From d25b8bdee9c909afcf0930ef4218c4ab62e0931c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 3 Mar 2014 14:28:08 -0800 Subject: [PATCH 0984/1502] Use public interfaces to get the fileno(). Fix a typo. --- examples/crawl.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/crawl.py b/examples/crawl.py index 637fb8e0..da654cb9 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -116,7 +116,7 @@ class ConnectionPool: and recycle_connection() puts it back in. To recycle a connection, call conn.close(recycle=True). - There are limits to both the overal pool and the per-key pool. + There are limits to both the overall pool and the per-key pool. """ def __init__(self, log, max_pool=10, max_tasks=5): @@ -244,14 +244,14 @@ def __init__(self, log, pool, host, port, ssl): self.key = None def stale(self): - return self.reader is None or self.reader._eof + return self.reader is None or self.reader.at_eof() def fileno(self): writer = self.writer if writer is not None: - transport = writer._transport + transport = writer.transport if transport is not None: - sock = transport._sock + sock = transport.get_extra_info('socket') if sock is not None: return sock.fileno() return None From 24d19affc9c8e9a7772bf1afb7fd5025a98b904c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 3 Mar 2014 14:44:19 -0800 Subject: [PATCH 0985/1502] Another tiny crawl.py cleanup. --- examples/crawl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/crawl.py b/examples/crawl.py index da654cb9..4bb0b4ea 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -599,6 +599,8 @@ def report(self, stats, file=None): size, '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), file=file) + elif self.response is None: + print(self.url, 'no response object') else: size = len(self.body or b'') if self.response.status == 200: From 874ae261771829ef6ea5aa393eb161ccbc52d7c1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 4 Mar 2014 23:05:26 +0100 Subject: [PATCH 0986/1502] Issue #158: Task._step() now also sets self to None if an exception is raised. self is set to None to break a reference cycle. --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 19fa654e..0967e7e6 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -325,7 +325,7 @@ def _step(self, value=None, exc=None): 'Task got bad yield: {!r}'.format(result))) finally: self.__class__._current_tasks.pop(self._loop) - self = None + self = None # Needed to break cycles when an exception occurs. def _wakeup(self, future): try: From b934e3a7ca00a77b6393d4ff387c1a09620bd07c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 4 Mar 2014 14:38:16 -0800 Subject: [PATCH 0987/1502] Reject add/remove reader/writer when event loop is closed. --- asyncio/selector_events.py | 8 ++++++++ tests/test_events.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 70d8a958..367c5fbe 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -136,6 +136,8 @@ def _accept_connection(self, protocol_factory, sock, def add_reader(self, fd, callback, *args): """Add a reader callback.""" + if self._selector is None: + raise RuntimeError('Event loop is closed') handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) @@ -151,6 +153,8 @@ def add_reader(self, fd, callback, *args): def remove_reader(self, fd): """Remove a reader callback.""" + if self._selector is None: + return False try: key = self._selector.get_key(fd) except KeyError: @@ -171,6 +175,8 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" + if self._selector is None: + raise RuntimeError('Event loop is closed') handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) @@ -186,6 +192,8 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): """Remove a writer callback.""" + if self._selector is None: + return False try: key = self._selector.get_key(fd) except KeyError: diff --git a/tests/test_events.py b/tests/test_events.py index f01d1f38..12be8a14 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1326,6 +1326,26 @@ def test_sock_connect_address(self): self.assertIn('address must be resolved', str(cm.exception)) + def test_remove_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + loop.add_reader(r, callback) + loop.add_writer(w, callback) + loop.close() + self.assertFalse(loop.remove_reader(r)) + self.assertFalse(loop.remove_writer(w)) + + def test_add_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + loop.close() + with self.assertRaises(RuntimeError): + loop.add_reader(r, callback) + with self.assertRaises(RuntimeError): + loop.add_writer(w, callback) + class SubprocessTestsMixin: From b7a15eff4c8958bf384d5ddfea20ec6727de01ed Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 5 Mar 2014 10:37:39 -0800 Subject: [PATCH 0988/1502] Remove egg info in "make clean". --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index 24a8c050..952c93ff 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ clean: rm -f .coverage rm -rf htmlcov rm -rf build + rm -rf asyncio.egg-info rm -f MANIFEST From 21697f0ef6e89871549802725dc2e2fcc7becfb5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 5 Mar 2014 22:52:18 +0100 Subject: [PATCH 0989/1502] Issue #159: Fix windows_utils.socketpair() * Use "127.0.0.1" (IPv4) or "::1" (IPv6) host instead of "localhost", because "localhost" may be a different IP address * Reject also invalid arguments: only AF_INET/AF_INET6 with SOCK_STREAM (and proto=0) are supported --- asyncio/windows_utils.py | 17 +++++++++++++++-- tests/test_windows_utils.py | 27 +++++++++++++++++++++++---- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index aa1c0648..2a196cc7 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -36,12 +36,25 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. """ + if family == socket.AF_INET: + host = '127.0.0.1' + elif family == socket.AF_INET6: + host = '::1' + else: + raise ValueError("Ony AF_INET and AF_INET6 socket address families " + "are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + # We create a connected TCP socket. Note the trick with setblocking(0) # that prevents us from having to create a thread. lsock = socket.socket(family, type, proto) - lsock.bind(('localhost', 0)) + lsock.bind((host, 0)) lsock.listen(1) - addr, port = lsock.getsockname() + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] csock = socket.socket(family, type, proto) csock.setblocking(False) try: diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 7616c73e..9daf4340 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -1,8 +1,10 @@ """Tests for window_utils""" +import socket import sys import test.support import unittest +from test.support import IPV6_ENABLED from unittest import mock if sys.platform != 'win32': @@ -16,23 +18,40 @@ class WinsocketpairTests(unittest.TestCase): - def test_winsocketpair(self): - ssock, csock = windows_utils.socketpair() - + def check_winsocketpair(self, ssock, csock): csock.send(b'xxx') self.assertEqual(b'xxx', ssock.recv(1024)) - csock.close() ssock.close() + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + self.check_winsocketpair(ssock, csock) + + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_winsocketpair_ipv6(self): + ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) + self.check_winsocketpair(ssock, csock) + @mock.patch('asyncio.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM m_socket.socket.return_value.getsockname.return_value = ('', 12345) m_socket.socket.return_value.accept.return_value = object(), object() m_socket.socket.return_value.connect.side_effect = OSError() self.assertRaises(OSError, windows_utils.socketpair) + def test_winsocketpair_invalid_args(self): + self.assertRaises(ValueError, + windows_utils.socketpair, family=socket.AF_UNSPEC) + self.assertRaises(ValueError, + windows_utils.socketpair, type=socket.SOCK_DGRAM) + self.assertRaises(ValueError, + windows_utils.socketpair, proto=1) + + class PipeTests(unittest.TestCase): From fda58633d48219c661d1af09644e77815d79e2aa Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 6 Mar 2014 00:19:30 +0100 Subject: [PATCH 0990/1502] Fix ResourceWarning warnings --- tests/test_events.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index 12be8a14..fec5e6ea 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1330,6 +1330,8 @@ def test_remove_fds_after_closing(self): loop = self.create_event_loop() callback = lambda: None r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) loop.add_reader(r, callback) loop.add_writer(w, callback) loop.close() @@ -1340,6 +1342,8 @@ def test_add_fds_after_closing(self): loop = self.create_event_loop() callback = lambda: None r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) loop.close() with self.assertRaises(RuntimeError): loop.add_reader(r, callback) From aadd3b9e4db1e4f494683a294228eaa10e0d616c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 6 Mar 2014 00:20:26 +0100 Subject: [PATCH 0991/1502] Skip test_remove_fds_after_closing() for IocpEventLoop --- tests/test_events.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index fec5e6ea..fd7022ff 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1656,6 +1656,9 @@ def test_writer_callback_cancel(self): def test_create_datagram_endpoint(self): raise unittest.SkipTest( "IocpEventLoop does not have create_datagram_endpoint()") + + def test_remove_fds_after_closing(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") else: from asyncio import selectors From 63de3dd6e326daa47e94c577370cd67c573cb13d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 6 Mar 2014 01:00:03 +0100 Subject: [PATCH 0992/1502] Issue #157: Improve test_events.py, avoid run_briefly() which is not reliable --- asyncio/test_utils.py | 15 ++--- tests/test_events.py | 129 +++++++++++++++++++++--------------------- 2 files changed, 71 insertions(+), 73 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 71d309b0..9c3656ac 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -21,10 +21,11 @@ except ImportError: # pragma: no cover ssl = None -from . import tasks from . import base_events from . import events +from . import futures from . import selectors +from . import tasks if sys.platform == 'win32': # pragma: no cover @@ -52,18 +53,14 @@ def once(): gen.close() -def run_until(loop, pred, timeout=None): - if timeout is not None: - deadline = time.time() + timeout +def run_until(loop, pred, timeout=30): + deadline = time.time() + timeout while not pred(): if timeout is not None: timeout = deadline - time.time() if timeout <= 0: - return False - loop.run_until_complete(tasks.sleep(timeout, loop=loop)) - else: - run_briefly(loop) - return True + raise futures.TimeoutError() + loop.run_until_complete(tasks.sleep(0.001, loop=loop)) def run_once(loop): diff --git a/tests/test_events.py b/tests/test_events.py index fd7022ff..bafa8756 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -56,6 +56,7 @@ def osx_tiger(): class MyBaseProto(asyncio.Protocol): + connected = None done = None def __init__(self, loop=None): @@ -63,12 +64,15 @@ def __init__(self, loop=None): self.state = 'INITIAL' self.nbytes = 0 if loop is not None: + self.connected = asyncio.Future(loop=loop) self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) def data_received(self, data): assert self.state == 'CONNECTED', self.state @@ -330,7 +334,8 @@ def run(arg): def test_reader_callback(self): r, w = test_utils.socketpair() - bytes_read = [] + r.setblocking(False) + bytes_read = bytearray() def reader(): try: @@ -340,37 +345,40 @@ def reader(): # at least on Linux -- see man select. return if data: - bytes_read.append(data) + bytes_read.extend(data) else: self.assertTrue(self.loop.remove_reader(r.fileno())) r.close() self.loop.add_reader(r.fileno(), reader) self.loop.call_soon(w.send, b'abc') - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3) self.loop.call_soon(w.send, b'def') - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6) self.loop.call_soon(w.close) self.loop.call_soon(self.loop.stop) self.loop.run_forever() - self.assertEqual(b''.join(bytes_read), b'abcdef') + self.assertEqual(bytes_read, b'abcdef') def test_writer_callback(self): r, w = test_utils.socketpair() w.setblocking(False) - self.loop.add_writer(w.fileno(), w.send, b'x'*(256*1024)) - test_utils.run_briefly(self.loop) - def remove_writer(): - self.assertTrue(self.loop.remove_writer(w.fileno())) + def writer(data): + w.send(data) + self.loop.stop() - self.loop.call_soon(remove_writer) - self.loop.call_soon(self.loop.stop) + data = b'x' * 1024 + self.loop.add_writer(w.fileno(), writer, data) self.loop.run_forever() + + self.assertTrue(self.loop.remove_writer(w.fileno())) + self.assertFalse(self.loop.remove_writer(w.fileno())) + w.close() - data = r.recv(256*1024) + read = r.recv(len(data) * 2) r.close() - self.assertGreaterEqual(len(data), 200) + self.assertEqual(read, data) def _basetest_sock_client_ops(self, httpd, sock): sock.setblocking(False) @@ -464,10 +472,10 @@ def my_handler(): self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) # Now set a handler and handle it. self.loop.add_signal_handler(signal.SIGINT, my_handler) - test_utils.run_briefly(self.loop) + os.kill(os.getpid(), signal.SIGINT) - test_utils.run_briefly(self.loop) - self.assertEqual(caught, 1) + test_utils.run_until(self.loop, lambda: caught) + # Removing it should restore the default handler. self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) self.assertEqual(signal.getsignal(signal.SIGINT), @@ -623,7 +631,7 @@ def test_create_connection_local_addr_in_use(self): self.assertIn(str(httpd.address), cm.exception.strerror) def test_create_server(self): - proto = MyProto() + proto = MyProto(self.loop) f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) server = self.loop.run_until_complete(f) self.assertEqual(len(server.sockets), 1) @@ -633,14 +641,11 @@ def test_create_server(self): client = socket.socket() client.connect(('127.0.0.1', port)) client.sendall(b'xxx') - test_utils.run_briefly(self.loop) - test_utils.run_until(self.loop, lambda: proto is not None, 10) - self.assertIsInstance(proto, MyProto) - self.assertEqual('INITIAL', proto.state) - test_utils.run_briefly(self.loop) + + self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) - test_utils.run_until(self.loop, lambda: proto.nbytes > 0, - timeout=10) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) self.assertEqual(3, proto.nbytes) # extra info is available @@ -650,7 +655,7 @@ def test_create_server(self): # close connection proto.transport.close() - test_utils.run_briefly(self.loop) # windows iocp + self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) @@ -672,27 +677,22 @@ def _make_unix_server(self, factory, **kwargs): @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server(self): - proto = MyProto() + proto = MyProto(loop=self.loop) server, path = self._make_unix_server(lambda: proto) self.assertEqual(len(server.sockets), 1) client = socket.socket(socket.AF_UNIX) client.connect(path) client.sendall(b'xxx') - test_utils.run_briefly(self.loop) - test_utils.run_until(self.loop, lambda: proto is not None, 10) - self.assertIsInstance(proto, MyProto) - self.assertEqual('INITIAL', proto.state) - test_utils.run_briefly(self.loop) + self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) - test_utils.run_until(self.loop, lambda: proto.nbytes > 0, - timeout=10) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) self.assertEqual(3, proto.nbytes) # close connection proto.transport.close() - test_utils.run_briefly(self.loop) # windows iocp + self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) @@ -735,12 +735,10 @@ def test_create_server_ssl(self): client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') - test_utils.run_briefly(self.loop) - self.assertIsInstance(proto, MyProto) - test_utils.run_briefly(self.loop) + self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) - test_utils.run_until(self.loop, lambda: proto.nbytes > 0, - timeout=10) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) self.assertEqual(3, proto.nbytes) # extra info is available @@ -774,12 +772,9 @@ def test_create_unix_server_ssl(self): client, pr = self.loop.run_until_complete(f_c) client.write(b'xxx') - test_utils.run_briefly(self.loop) - self.assertIsInstance(proto, MyProto) - test_utils.run_briefly(self.loop) + self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) - test_utils.run_until(self.loop, lambda: proto.nbytes > 0, - timeout=10) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) self.assertEqual(3, proto.nbytes) # close connection @@ -1044,15 +1039,9 @@ def datagram_received(self, data, addr): self.assertEqual('INITIALIZED', client.state) transport.sendto(b'xxx') - for _ in range(1000): - if server.nbytes: - break - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: server.nbytes) self.assertEqual(3, server.nbytes) - for _ in range(1000): - if client.nbytes: - break - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: client.nbytes) # received self.assertEqual(8, client.nbytes) @@ -1097,11 +1086,11 @@ def connect(): self.loop.run_until_complete(connect()) os.write(wpipe, b'1') - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto.nbytes >= 1) self.assertEqual(1, proto.nbytes) os.write(wpipe, b'2345') - test_utils.run_briefly(self.loop) + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) self.assertEqual(5, proto.nbytes) @@ -1166,14 +1155,19 @@ def test_write_pipe(self): self.assertEqual('CONNECTED', proto.state) transport.write(b'1') - test_utils.run_briefly(self.loop) - data = os.read(rpipe, 1024) + + data = bytearray() + def reader(data): + chunk = os.read(rpipe, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1) self.assertEqual(b'1', data) transport.write(b'2345') - test_utils.run_briefly(self.loop) - data = os.read(rpipe, 1024) - self.assertEqual(b'2345', data) + test_utils.run_until(self.loop, lambda: reader(data) >= 5) + self.assertEqual(b'12345', data) self.assertEqual('CONNECTED', proto.state) os.close(rpipe) @@ -1225,14 +1219,21 @@ def test_write_pty(self): self.assertEqual('CONNECTED', proto.state) transport.write(b'1') - test_utils.run_briefly(self.loop) - data = os.read(master, 1024) + + data = bytearray() + def reader(data): + chunk = os.read(master, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1, + timeout=10) self.assertEqual(b'1', data) transport.write(b'2345') - test_utils.run_briefly(self.loop) - data = os.read(master, 1024) - self.assertEqual(b'2345', data) + test_utils.run_until(self.loop, lambda: reader(data) >= 5, + timeout=10) + self.assertEqual(b'12345', data) self.assertEqual('CONNECTED', proto.state) os.close(master) From 124ceb94102b7df3e580aad3b0f76234a5b876bb Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 14 Mar 2014 21:23:12 -0700 Subject: [PATCH 0993/1502] Back out inplace default for build_ext. See issue #116. --- README | 6 +++--- setup.cfg | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) delete mode 100644 setup.cfg diff --git a/README b/README index 1150bafa..34812e08 100644 --- a/README +++ b/README @@ -20,10 +20,10 @@ To run coverage (coverage package is required): On Windows, things are a little more complicated. Assume 'P' is your Python binary (for example C:\Python33\python.exe). -You must first build the _overlapped.pyd extension (it will be placed -in the asyncio directory): +You must first build the _overlapped.pyd extension and have it placed +in the asyncio directory, as follows: - C> P setup.py build_ext + C> P setup.py build_ext --inplace Then you can run the tests as follows: diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index da2a775d..00000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[build_ext] -inplace = 1 From 1791bd64f15900adae257ae315367befaa1e365c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 21 Mar 2014 09:59:47 +0100 Subject: [PATCH 0994/1502] Ensure call_soon(), call_later() and call_at() are invoked on current loop in debug mode. Raise a RuntimeError if the event loop of the current thread is different. The check should help to debug thread-safetly issue. Patch written by David Foster. --- asyncio/base_events.py | 23 ++++++++++++++++++++++- tests/test_base_events.py | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 80df9271..d2bdc07d 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -259,6 +259,8 @@ def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" if tasks.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_at()") + if self._debug: + self._assert_is_current_event_loop() timer = events.TimerHandle(when, callback, args, self) heapq.heappush(self._scheduled, timer) return timer @@ -273,15 +275,34 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ + return self._call_soon(callback, args, check_loop=True) + + def _call_soon(self, callback, args, check_loop): if tasks.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_soon()") + if self._debug and check_loop: + self._assert_is_current_event_loop() handle = events.Handle(callback, args, self) self._ready.append(handle) return handle + def _assert_is_current_event_loop(self): + """Asserts that this event loop is the current event loop. + + Non-threadsafe methods of this class make this assumption and will + likely behave incorrectly when the assumption is violated. + + Should only be called when (self._debug == True). The caller is + responsible for checking this condition for performance reasons. + """ + if events.get_event_loop() is not self: + raise RuntimeError( + "non-threadsafe operation invoked on an event loop other " + "than the current one") + def call_soon_threadsafe(self, callback, *args): """XXX""" - handle = self.call_soon(callback, *args) + handle = self._call_soon(callback, args, check_loop=False) self._write_to_self() return handle diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 340ca67d..544bd3dd 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -136,6 +136,29 @@ def cb(): # are really slow self.assertLessEqual(dt, 0.9, dt) + def test_assert_is_current_event_loop(self): + def cb(): + pass + + other_loop = base_events.BaseEventLoop() + other_loop._selector = unittest.mock.Mock() + asyncio.set_event_loop(other_loop) + + # raise RuntimeError if the event loop is different in debug mode + self.loop.set_debug(True) + with self.assertRaises(RuntimeError): + self.loop.call_soon(cb) + with self.assertRaises(RuntimeError): + self.loop.call_later(60, cb) + with self.assertRaises(RuntimeError): + self.loop.call_at(self.loop.time() + 60, cb) + + # check disabled if debug mode is disabled + self.loop.set_debug(False) + self.loop.call_soon(cb) + self.loop.call_later(60, cb) + self.loop.call_at(self.loop.time() + 60, cb) + def test_run_once_in_executor_handle(self): def cb(): pass From a2eb06d9df253e2cf2b3ae2de6f5024e8a456893 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 24 Mar 2014 10:16:33 -0700 Subject: [PATCH 0995/1502] Pull in Solaris devpoll support by Giampaolo Rodola'. --- asyncio/selectors.py | 62 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 9be92255..4e9ae6ec 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -441,6 +441,64 @@ def close(self): super().close() +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + if hasattr(select, 'kqueue'): class KqueueSelector(_BaseSelectorImpl): @@ -513,12 +571,14 @@ def close(self): super().close() -# Choose the best implementation: roughly, epoll|kqueue > poll > select. +# Choose the best implementation: roughly, epoll|kqueue|devpoll > poll > select. # select() also can't accept a FD > FD_SETSIZE (usually around 1024) if 'KqueueSelector' in globals(): DefaultSelector = KqueueSelector elif 'EpollSelector' in globals(): DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector elif 'PollSelector' in globals(): DefaultSelector = PollSelector else: From 5cdc4cad58484574f025617c95302a831ac0f0fd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 31 Mar 2014 10:29:24 -0700 Subject: [PATCH 0996/1502] Document Task.cancel() properly. --- asyncio/tasks.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 0967e7e6..a84ad261 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -250,6 +250,25 @@ def print_stack(self, *, limit=None, file=None): print(line, file=file, end='') def cancel(self): + """Request that a task to cancel itself. + + This arranges for a CancellationError to be thrown into the + wrapped coroutine on the next cycle through the event loop. + The coroutine then has a chance to clean up or even deny + the request using try/except/finally. + + Contrary to Future.cancel(), this does not guarantee that the + task will cancelled: the exception might be caught and acted + upon, delaying cancellation of the task or preventing it + completely. The task may also return a value or raise a + different exception. + + Immediately after this method is called, Task.cancelled() will + not return True (unless the task was already cancelled). A + task will be marked as cancelled when the wrapped coroutine + terminates with a CancelledError exception (even if cancel() + was not called). + """ if self.done(): return False if self._fut_waiter is not None: From f1938b57640d0f49a2ede991dad0cb24d37c52df Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 31 Mar 2014 11:31:16 -0700 Subject: [PATCH 0997/1502] Fix bad grammar. --- asyncio/tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index a84ad261..153f731a 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -258,8 +258,8 @@ def cancel(self): the request using try/except/finally. Contrary to Future.cancel(), this does not guarantee that the - task will cancelled: the exception might be caught and acted - upon, delaying cancellation of the task or preventing it + task will be cancelled: the exception might be caught and + acted upon, delaying cancellation of the task or preventing it completely. The task may also return a value or raise a different exception. From eabfc95a66ecd19b7512b653a6aea01b2d712bdb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 5 Apr 2014 23:21:07 +0200 Subject: [PATCH 0998/1502] EventLoop.create_unix_server() now raises a ValueError if path and sock are specified at the same time --- asyncio/unix_events.py | 4 ++++ tests/test_events.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 21255480..1fbdd313 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -206,6 +206,10 @@ def create_unix_server(self, protocol_factory, path=None, *, raise TypeError('ssl argument must be an SSLContext or None') if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: diff --git a/tests/test_events.py b/tests/test_events.py index bafa8756..1e64dd07 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -703,6 +703,17 @@ def test_create_unix_server(self): # close server server.close() + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_path_socket_error(self): + proto = MyProto(loop=self.loop) + sock = socket.socket() + with sock: + f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock) + with self.assertRaisesRegex(ValueError, + 'path and sock can not be specified ' + 'at the same time'): + server = self.loop.run_until_complete(f) + def _create_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 From 09bf06b66b12937770e4837f5750ff1951c623ea Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 14 Apr 2014 22:14:36 -0400 Subject: [PATCH 0999/1502] tasks: Fix CoroWrapper to workaround yield-from bug in CPython --- asyncio/tasks.py | 5 ++++- tests/test_tasks.py | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 153f731a..0366da35 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -49,7 +49,10 @@ def __iter__(self): def __next__(self): return next(self.gen) - def send(self, value): + def send(self, *value): + # We use `*value` because of a bug in CPythons prior + # to 3.4.1. See issue #21209 and test_yield_from_corowrapper + # for details. This workaround should be removed in 3.5.0. return self.gen.send(value) def throw(self, exc): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ced34312..45de8acc 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1386,6 +1386,31 @@ def test_wait_invalid_args(self): self.assertRaises(ValueError, self.loop.run_until_complete, asyncio.wait([], loop=self.loop)) + def test_yield_from_corowrapper(self): + old_debug = asyncio.tasks._DEBUG + asyncio.tasks._DEBUG = True + try: + @asyncio.coroutine + def t1(): + return (yield from t2()) + + @asyncio.coroutine + def t2(): + f = asyncio.Future(loop=self.loop) + asyncio.Task(t3(f), loop=self.loop) + return (yield from f) + + @asyncio.coroutine + def t3(f): + f.set_result((1, 2, 3)) + + task = asyncio.Task(t1(), loop=self.loop) + val = self.loop.run_until_complete(task) + self.assertEqual(val, (1, 2, 3)) + finally: + asyncio.tasks._DEBUG = old_debug + + class GatherTestsBase: def setUp(self): From 205ad123cef16b822b293966ee9659273fa10b9a Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 15 Apr 2014 11:58:49 -0400 Subject: [PATCH 1000/1502] tasks: Make sure CoroWrapper.send proxies one argument correctly --- asyncio/tasks.py | 2 ++ tests/test_tasks.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 0366da35..0785e107 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -53,6 +53,8 @@ def send(self, *value): # We use `*value` because of a bug in CPythons prior # to 3.4.1. See issue #21209 and test_yield_from_corowrapper # for details. This workaround should be removed in 3.5.0. + if len(value) == 1: + value = value[0] return self.gen.send(value) def throw(self, exc): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 45de8acc..2b90a108 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1410,6 +1410,24 @@ def t3(f): finally: asyncio.tasks._DEBUG = old_debug + def test_yield_from_corowrapper_send(self): + def foo(): + a = yield + return a + + def call(arg): + cw = asyncio.tasks.CoroWrapper(foo(), foo) + cw.send(None) + try: + cw.send(arg) + except StopIteration as ex: + return ex.args[0] + else: + raise AssertionError('StopIteration was expected') + + self.assertEqual(call((1, 2)), (1, 2)) + self.assertEqual(call('spam'), 'spam') + class GatherTestsBase: From 493dead42a0428b70f781a86a5afdf252ee0d047 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 15 Apr 2014 11:59:07 -0700 Subject: [PATCH 1001/1502] Add gi_{frame,running,code} properties to CoroWrapper. Fixes issue #163. --- asyncio/tasks.py | 12 ++++++++++++ tests/test_tasks.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 0785e107..c6c22dd2 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -63,6 +63,18 @@ def throw(self, exc): def close(self): return self.gen.close() + @property + def gi_frame(self): + return self.gen.gi_frame + + @property + def gi_running(self): + return self.gen.gi_running + + @property + def gi_code(self): + return self.gen.gi_code + def __del__(self): frame = self.gen.gi_frame if frame is not None and frame.f_lasti == -1: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 2b90a108..80571b41 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,6 +2,7 @@ import gc import os.path +import types import unittest from test.script_helper import assert_python_ok @@ -1386,6 +1387,52 @@ def test_wait_invalid_args(self): self.assertRaises(ValueError, self.loop.run_until_complete, asyncio.wait([], loop=self.loop)) + def test_corowrapper_mocks_generator(self): + + def check(): + # A function that asserts various things. + # Called twice, with different debug flag values. + + @asyncio.coroutine + def coro(): + # The actual coroutine. + self.assertTrue(gen.gi_running) + yield from fut + + # A completed Future used to run the coroutine. + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + + # Call the coroutine. + gen = coro() + + # Check some properties. + self.assertTrue(asyncio.iscoroutine(gen)) + self.assertIsInstance(gen.gi_frame, types.FrameType) + self.assertFalse(gen.gi_running) + self.assertIsInstance(gen.gi_code, types.CodeType) + + # Run it. + self.loop.run_until_complete(gen) + + # The frame should have changed. + self.assertIsNone(gen.gi_frame) + + # Save debug flag. + old_debug = asyncio.tasks._DEBUG + try: + # Test with debug flag cleared. + asyncio.tasks._DEBUG = False + check() + + # Test with debug flag set. + asyncio.tasks._DEBUG = True + check() + + finally: + # Restore original debug flag. + asyncio.tasks._DEBUG = old_debug + def test_yield_from_corowrapper(self): old_debug = asyncio.tasks._DEBUG asyncio.tasks._DEBUG = True From 088f44976339db033e777893859f1758075baf39 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 18 Apr 2014 09:51:35 -0700 Subject: [PATCH 1002/1502] Remove superfluous and useless line. (According to Benjamin in CPython repo rev 30a7e37b8441.) --- tests/test_unix_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index cc743839..744c3195 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1335,7 +1335,6 @@ def test_sigchld_child_reaped_elsewhere(self, m): with self.ignore_warnings: self.watcher._sig_chld() - callback.assert_called(m.waitpid) if isinstance(self.watcher, asyncio.FastChildWatcher): # here the FastChildWatche enters a deadlock # (there is no way to prevent it) From 76ec278badb505e81077b0cd2f9cf4dca3516e04 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 27 Apr 2014 10:29:46 -0700 Subject: [PATCH 1003/1502] Be careful accessing instance variables in __del__ (CPython issue 21340). --- asyncio/tasks.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index c6c22dd2..e8ee9475 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -76,7 +76,9 @@ def gi_code(self): return self.gen.gi_code def __del__(self): - frame = self.gen.gi_frame + # Be careful accessing self.gen.frame -- self.gen might not exist. + gen = getattr(self, 'gen', None) + frame = getattr(gen, 'gi_frame', None) if frame is not None and frame.f_lasti == -1: func = self.func code = func.__code__ From e5a919eb9143ba0ee0011162dca38512b47cefa7 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 27 Apr 2014 10:43:00 -0700 Subject: [PATCH 1004/1502] Add __weakref__ slots to Handle and CoroWrapper. Fixes issue #166. --- asyncio/events.py | 2 +- asyncio/tasks.py | 2 +- tests/test_events.py | 6 ++++++ tests/test_tasks.py | 8 ++++++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 57af68af..31592d10 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -16,7 +16,7 @@ class Handle: """Object returned by callback registration methods.""" - __slots__ = ['_callback', '_args', '_cancelled', '_loop'] + __slots__ = ['_callback', '_args', '_cancelled', '_loop', '__weakref__'] def __init__(self, callback, args, loop): assert not isinstance(callback, Handle), 'A Handle is not a callback' diff --git a/asyncio/tasks.py b/asyncio/tasks.py index e8ee9475..45a6342e 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -36,7 +36,7 @@ class CoroWrapper: # Wrapper for coroutine in _DEBUG mode. - __slots__ = ['gen', 'func', '__name__', '__doc__'] + __slots__ = ['gen', 'func', '__name__', '__doc__', '__weakref__'] def __init__(self, gen, func): assert inspect.isgenerator(gen), gen diff --git a/tests/test_events.py b/tests/test_events.py index 1e64dd07..03c41491 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -21,6 +21,7 @@ import errno import unittest from unittest import mock +import weakref from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR @@ -1786,6 +1787,11 @@ def callback(): 'handle': h }) + def test_handle_weakref(self): + wd = weakref.WeakValueDictionary() + h = asyncio.Handle(lambda: None, (), object()) + wd['h'] = h # Would fail without __weakref__ slot. + class TimerTests(unittest.TestCase): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 80571b41..45a0dc1d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -4,6 +4,7 @@ import os.path import types import unittest +import weakref from test.script_helper import assert_python_ok import asyncio @@ -1475,6 +1476,13 @@ def call(arg): self.assertEqual(call((1, 2)), (1, 2)) self.assertEqual(call('spam'), 'spam') + def test_corowrapper_weakref(self): + wd = weakref.WeakValueDictionary() + def foo(): yield from [] + cw = asyncio.tasks.CoroWrapper(foo(), foo) + wd['cw'] = cw # Would fail without __weakref__ slot. + cw.gen = None # Suppress warning from __del__. + class GatherTestsBase: From 1fe19c25d7b031cb99c8af039d21e9b88b11e212 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sun, 4 May 2014 10:11:09 -0700 Subject: [PATCH 1005/1502] Simple echo client/server example (for Twitter thread). --- examples/echo_client_tulip.py | 19 +++++++++++++++++++ examples/echo_server_tulip.py | 17 +++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 examples/echo_client_tulip.py create mode 100644 examples/echo_server_tulip.py diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py new file mode 100644 index 00000000..9b5d29b8 --- /dev/null +++ b/examples/echo_client_tulip.py @@ -0,0 +1,19 @@ +import asyncio + +END = b'Bye-bye!\n' + +@asyncio.coroutine +def echo_client(): + reader, writer = yield from asyncio.open_connection('localhost', 8000) + writer.write(b'Hello, world\n') + writer.write(b'What a fine day it is.\n') + writer.write(END) + while True: + line = yield from reader.readline() + print('received:', line) + if line == END or not line: + break + writer.close() + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_client()) diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py new file mode 100644 index 00000000..c1ccb9df --- /dev/null +++ b/examples/echo_server_tulip.py @@ -0,0 +1,17 @@ +import asyncio + +@asyncio.coroutine +def echo_server(): + yield from asyncio.start_server(handle_connection, 'localhost', 8000) + +@asyncio.coroutine +def handle_connection(reader, writer): + while True: + data = yield from reader.read(8192) + if not data: + break + writer.write(data) + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_server()) +loop.run_forever() From b044961aff40da5ec259b1ff3a5195a1e47293e8 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 6 May 2014 14:47:09 -0700 Subject: [PATCH 1006/1502] Fix the second half of bugs.python.org/issue21447: race in _write_to_self(). --- asyncio/selector_events.py | 15 +++++++++++---- tests/test_selector_events.py | 5 +++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 367c5fbe..c7df8d8d 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -87,10 +87,17 @@ def _read_from_self(self): pass def _write_to_self(self): - try: - self._csock.send(b'x') - except (BlockingIOError, InterruptedError): - pass + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is not None: + try: + csock.send(b'x') + except OSError: + pass def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 964b2e8e..0735237c 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -121,8 +121,9 @@ def test_write_to_self_tryagain(self): self.assertIsNone(self.loop._write_to_self()) def test_write_to_self_exception(self): - self.loop._csock.send.side_effect = OSError() - self.assertRaises(OSError, self.loop._write_to_self) + # _write_to_self() swallows OSError + self.loop._csock.send.side_effect = RuntimeError() + self.assertRaises(RuntimeError, self.loop._write_to_self) def test_sock_recv(self): sock = mock.Mock() From 2fc6da3188065feb4b5ebd4392ea09ffe927a37a Mon Sep 17 00:00:00 2001 From: schlamar Date: Sat, 10 May 2014 09:47:55 +0200 Subject: [PATCH 1007/1502] Removed dead code path in _run_once. --- asyncio/base_events.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index d2bdc07d..3d4a87af 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -775,11 +775,7 @@ def _run_once(self): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0]._when - deadline = max(0, when - self.time()) - if timeout is None: - timeout = deadline - else: - timeout = min(timeout, deadline) + timeout = max(0, when - self.time()) # TODO: Instrumentation only in debug mode? if logger.isEnabledFor(logging.INFO): From c80a6578102c5e3dee3f5cad8e694f31bc6141c4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 12 May 2014 09:55:06 -0700 Subject: [PATCH 1008/1502] Fix issue 168: StreamReader.read(-1) from pipe may hang if data exceeds buffer limit. --- asyncio/streams.py | 17 +++++++----- examples/subprocess_attach_read_pipe.py | 2 +- tests/test_streams.py | 36 +++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 7 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 27d595f1..e239248d 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -419,12 +419,17 @@ def read(self, n=-1): return b'' if n < 0: - while not self._eof: - self._waiter = self._create_waiter('read') - try: - yield from self._waiter - finally: - self._waiter = None + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = yield from self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) else: if not self._buffer and not self._eof: self._waiter = self._create_waiter('read') diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py index 57a1342b..a692781c 100644 --- a/examples/subprocess_attach_read_pipe.py +++ b/examples/subprocess_attach_read_pipe.py @@ -6,7 +6,7 @@ code = """ import os, sys fd = int(sys.argv[1]) -data = os.write(fd, b'data') +os.write(fd, b'data') os.close(fd) """ diff --git a/tests/test_streams.py b/tests/test_streams.py index 031499e8..23012b72 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,7 +1,9 @@ """Tests for streams.py.""" import gc +import os import socket +import sys import unittest from unittest import mock try: @@ -583,6 +585,40 @@ def client(path): server.stop() self.assertEqual(msg, b"hello world!\n") + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") + def test_read_all_from_pipe_reader(self): + # See Tulip issue 168. This test is derived from the example + # subprocess_attach_read_pipe.py, but we configure the + # StreamReader's limit so that twice it is less than the size + # of the data writter. Also we must explicitly attach a child + # watcher to the event loop. + + watcher = asyncio.get_child_watcher() + watcher.attach_loop(self.loop) + + code = """\ +import os, sys +fd = int(sys.argv[1]) +os.write(fd, b'data') +os.close(fd) +""" + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(wfd)] + + pipe = open(rfd, 'rb', 0) + reader = asyncio.StreamReader(loop=self.loop, limit=1) + protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) + transport, _ = self.loop.run_until_complete( + self.loop.connect_read_pipe(lambda: protocol, pipe)) + + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec(*args, pass_fds={wfd}, loop=self.loop)) + self.loop.run_until_complete(proc.wait()) + + os.close(wfd) + data = self.loop.run_until_complete(reader.read(-1)) + self.assertEqual(data, b'data') + if __name__ == '__main__': unittest.main() From f8091e786e7017d1e8dd271a96b003e2b8049f3d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 13 May 2014 08:41:50 -0700 Subject: [PATCH 1009/1502] Fix test failures by not cleaning up watcher in test for issue #168. --- tests/test_streams.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/test_streams.py b/tests/test_streams.py index 23012b72..1ecc8eb1 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -593,9 +593,6 @@ def test_read_all_from_pipe_reader(self): # of the data writter. Also we must explicitly attach a child # watcher to the event loop. - watcher = asyncio.get_child_watcher() - watcher.attach_loop(self.loop) - code = """\ import os, sys fd = int(sys.argv[1]) @@ -611,9 +608,15 @@ def test_read_all_from_pipe_reader(self): transport, _ = self.loop.run_until_complete( self.loop.connect_read_pipe(lambda: protocol, pipe)) - proc = self.loop.run_until_complete( - asyncio.create_subprocess_exec(*args, pass_fds={wfd}, loop=self.loop)) - self.loop.run_until_complete(proc.wait()) + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + try: + asyncio.set_child_watcher(watcher) + proc = self.loop.run_until_complete( + asyncio.create_subprocess_exec(*args, pass_fds={wfd}, loop=self.loop)) + self.loop.run_until_complete(proc.wait()) + finally: + asyncio.set_child_watcher(None) os.close(wfd) data = self.loop.run_until_complete(reader.read(-1)) From 2a0e55d58add45998ae07fb1b9eb58c34ef8b1f9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 19 May 2014 08:14:50 -0700 Subject: [PATCH 1010/1502] Add option to randomize test order. --- runtests.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/runtests.py b/runtests.py index 469ac86a..4d06f695 100644 --- a/runtests.py +++ b/runtests.py @@ -23,6 +23,7 @@ import gc import logging import os +import random import re import sys import unittest @@ -55,6 +56,10 @@ ARGS.add_argument( '--findleaks', action='store_true', dest='findleaks', help='detect tests that leak memory') +ARGS.add_argument('-r', '--randomize', action='store_true', + help='randomize test execution order.') +ARGS.add_argument('--seed', type=int, + help='random seed to reproduce a previous random run') ARGS.add_argument( '-q', action="store_true", dest='quiet', help='quiet') ARGS.add_argument( @@ -112,6 +117,14 @@ def list_dir(prefix, dir): return mods +def randomize_tests(tests, seed): + if seed is None: + seed = random.randrange(10000000) + random.seed(seed) + print("Using random seed", seed) + random.shuffle(tests._tests) + + class TestsFinder: def __init__(self, testsdir, includes=(), excludes=()): @@ -253,12 +266,16 @@ def runtests(): if args.forever: while True: tests = finder.load_tests() + if args.randomize: + randomize_tests(tests, args.seed) result = runner_factory(verbosity=v, failfast=failfast).run(tests) if not result.wasSuccessful(): sys.exit(1) else: tests = finder.load_tests() + if args.randomize: + randomize_tests(tests, args.seed) result = runner_factory(verbosity=v, failfast=failfast).run(tests) sys.exit(not result.wasSuccessful()) From 48ae84b17c30d0d4a4c20659ac37243f59946687 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 19 May 2014 08:14:56 -0700 Subject: [PATCH 1011/1502] Added tag 3.4.1 for changeset e6084a6ff3bb From 93dcad1500ec151f51521412cbc1ac86804150e4 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 19 May 2014 08:15:55 -0700 Subject: [PATCH 1012/1502] Bump version in setup.py. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6cce4a7b..14f96f26 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ setup( name="asyncio", - version="0.4.1", + version="3.4.1", description="reference implementation of PEP 3156", long_description=open("README").read(), From a7400f7e6c0ae49497303a8c73a43f8c4581e39d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 19 May 2014 08:16:03 -0700 Subject: [PATCH 1013/1502] Added tag 3.4.1 for changeset 7c85dd9f8f6e From ff2097bd21242eaa8965cb0da91c52b302e85bf1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 19 May 2014 11:22:12 -0700 Subject: [PATCH 1014/1502] Update instructions for Windows. --- Makefile | 14 ++++++++++++-- README | 10 +++++++++- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 952c93ff..eda02f2d 100644 --- a/Makefile +++ b/Makefile @@ -42,9 +42,19 @@ clean: rm -f MANIFEST +# For distribution builders only! # Push a source distribution for Python 3.3 to PyPI. # You must update the version in setup.py first. -# The corresponding action on Windows is pypi.bat. -# A PyPI user configuration in ~/.pypirc is required. +# A PyPI user configuration in ~/.pypirc is required; +# you can create a suitable confifuration using +# python setup.py register pypi: clean python3.3 setup.py sdist upload + +# The corresponding action on Windows is pypi.bat. For that to work, +# you need to install wheel and setuptools. The easiest way is to get +# pip using the get-pip.py script found here: +# https://pip.pypa.io/en/latest/installing.html#install-pip +# That will install setuptools and pip; then you can just do +# \Python33\python.exe -m pip install wheel +# after which the pypi.bat script should work. diff --git a/README b/README index 34812e08..2f3150a2 100644 --- a/README +++ b/README @@ -25,7 +25,15 @@ in the asyncio directory, as follows: C> P setup.py build_ext --inplace -Then you can run the tests as follows: +If this complains about vcvars.bat, you probably don't have the +required version of Visual Studio installed. Compiling extensions for +Python 3.3 requires Microsoft Visual C++ 2010 (MSVC 10.0) of any +edition; you can download Visual Studio Express 2010 for free from +http://www.visualstudio.com/downloads (scroll down to Visual C++ 2010 +Express). + +Once you have built the _overlapped.pyd extension successfully you can +run the tests as follows: C> P runtests.py From c7bc730e63a6cbf282de6f1b04af0932fab3093b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 20 May 2014 15:56:33 +0200 Subject: [PATCH 1015/1502] test_base_events: use mock.Mock instead of unittest.mock.Mock to simplify the synchronization with Trollius --- tests/test_base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 544bd3dd..4ba95565 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -141,7 +141,7 @@ def cb(): pass other_loop = base_events.BaseEventLoop() - other_loop._selector = unittest.mock.Mock() + other_loop._selector = mock.Mock() asyncio.set_event_loop(other_loop) # raise RuntimeError if the event loop is different in debug mode From a0b74cab842585207c2a8f5f509497fdabe44937 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Tue, 27 May 2014 21:12:32 +0300 Subject: [PATCH 1016/1502] Fix for raising exception not derived from BaseException in _SelectorSslTransport.resume_reading --- asyncio/selector_events.py | 2 +- tests/test_selector_events.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index c7df8d8d..86a8d23c 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -670,7 +670,7 @@ def pause_reading(self): def resume_reading(self): if not self._paused: - raise ('Not paused') + raise RuntimeError('Not paused') self._paused = False if self._closing: return diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 0735237c..d7fafab6 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -711,6 +711,8 @@ def test_pause_resume_reading(self): tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) + with self.assertRaises(RuntimeError): + tr.resume_reading() def test_read_ready(self): transport = _SelectorSocketTransport( @@ -1125,6 +1127,8 @@ def test_pause_resume_reading(self): tr.resume_reading() self.assertFalse(tr._paused) self.loop.assert_reader(1, tr._read_ready) + with self.assertRaises(RuntimeError): + tr.resume_reading() def test_write(self): transport = self._make_one() From 4c9025f166e183f292080c73a81c9013b1fe9da1 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 28 May 2014 16:15:18 -0700 Subject: [PATCH 1017/1502] Fix docstring typo: CancellationError should be CancelledError. --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 45a6342e..5599c18b 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -271,7 +271,7 @@ def print_stack(self, *, limit=None, file=None): def cancel(self): """Request that a task to cancel itself. - This arranges for a CancellationError to be thrown into the + This arranges for a CancelledError to be thrown into the wrapped coroutine on the next cycle through the event loop. The coroutine then has a chance to clean up or even deny the request using try/except/finally. From 5175932c13ff889435a7c6d01440776e5491e6dd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 29 May 2014 10:59:43 -0700 Subject: [PATCH 1018/1502] Make 'python3 setup.py test' work. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 14f96f26..fcd3b6aa 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ ], packages=["asyncio"], + test_suite="runtests.runtests", ext_modules=extensions, ) From 5fa18a526f8538db583d5499262fa2cadc744be7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 2 Jun 2014 23:05:42 +0200 Subject: [PATCH 1019/1502] Rephrase Task.cancel docstring --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 5599c18b..2aa568bc 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -269,7 +269,7 @@ def print_stack(self, *, limit=None, file=None): print(line, file=file, end='') def cancel(self): - """Request that a task to cancel itself. + """Request this task to cancel itself. This arranges for a CancelledError to be thrown into the wrapped coroutine on the next cycle through the event loop. From eed62171abb74af554808233d393d4d3fe6e84ed Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:06:09 +0200 Subject: [PATCH 1020/1502] Close sockets on errors Fix ResourceWarning: BaseEventLoop.create_connection(), BaseEventLoop.create_datagram_endpoint() and _UnixSelectorEventLoop.create_unix_server() now close the newly created socket on error. --- asyncio/base_events.py | 8 ++++++++ asyncio/unix_events.py | 3 +++ tests/test_base_events.py | 21 +++++++++++++++++++++ tests/test_unix_events.py | 18 ++++++++++++++++++ 4 files changed, 50 insertions(+) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 3d4a87af..1c7073c3 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -412,6 +412,10 @@ def create_connection(self, protocol_factory, host=None, port=None, *, if sock is not None: sock.close() exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise else: break else: @@ -512,6 +516,10 @@ def create_datagram_endpoint(self, protocol_factory, if sock is not None: sock.close() exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise else: break else: diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 1fbdd313..230fbc38 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -223,6 +223,9 @@ def create_unix_server(self, protocol_factory, path=None, *, raise OSError(errno.EADDRINUSE, msg) from None else: raise + except: + sock.close() + raise else: if sock is None: raise ValueError( diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 4ba95565..dbcd590b 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -583,6 +583,27 @@ def _socket(*args, **kw): self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + @mock.patch('asyncio.base_events.socket') + def test_create_connection_timeout(self, m_socket): + # Ensure that the socket is closed on timeout + sock = mock.Mock() + m_socket.socket.return_value = sock + + def getaddrinfo(*args, **kw): + fut = asyncio.Future(loop=self.loop) + addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '', + ('127.0.0.1', 80)) + fut.set_result([addr]) + return fut + self.loop.getaddrinfo = getaddrinfo + + with mock.patch.object(self.loop, 'sock_connect', + side_effect=asyncio.TimeoutError): + coro = self.loop.create_connection(MyProto, '127.0.0.1', 80) + with self.assertRaises(asyncio.TimeoutError) as cm: + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + def test_create_connection_host_port_sock(self): coro = self.loop.create_connection( MyProto, 'example.com', 80, sock=object()) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 744c3195..cec7a110 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -256,6 +256,24 @@ def test_create_unix_server_path_inetsock(self): 'A UNIX Domain Socket was expected'): self.loop.run_until_complete(coro) + @mock.patch('asyncio.unix_events.socket') + def test_create_unix_server_bind_error(self, m_socket): + # Ensure that the socket is closed on any bind error + sock = mock.Mock() + m_socket.socket.return_value = sock + + sock.bind.side_effect = OSError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + sock.bind.side_effect = MemoryError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(MemoryError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + def test_create_unix_connection_path_sock(self): coro = self.loop.create_unix_connection( lambda: None, '/dev/null', sock=object()) From e000cf5b40a6916a704bccb11c61301f66139dd6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:08:02 +0200 Subject: [PATCH 1021/1502] Make sure that socketpair() close sockets on error Close the listening socket if sock.bind() raises an exception. --- asyncio/windows_utils.py | 32 +++++++++++++++++--------------- tests/test_windows_utils.py | 9 +++++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index 2a196cc7..f7f2f358 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -51,23 +51,25 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): # We create a connected TCP socket. Note the trick with setblocking(0) # that prevents us from having to create a thread. lsock = socket.socket(family, type, proto) - lsock.bind((host, 0)) - lsock.listen(1) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - csock.setblocking(False) try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - except Exception: + lsock.bind((host, 0)) + lsock.listen(1) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + ssock, _ = lsock.accept() + csock.setblocking(True) + except: + csock.close() + raise + finally: lsock.close() - csock.close() - raise - ssock, _ = lsock.accept() - csock.setblocking(True) - lsock.close() return (ssock, csock) diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 9daf4340..b1f81da8 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -51,6 +51,15 @@ def test_winsocketpair_invalid_args(self): self.assertRaises(ValueError, windows_utils.socketpair, proto=1) + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_close(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM + sock = mock.Mock() + m_socket.socket.return_value = sock + sock.bind.side_effect = OSError + self.assertRaises(OSError, windows_utils.socketpair) + self.assertTrue(sock.close.called) class PipeTests(unittest.TestCase): From 069bf040c34ac7ab8607df93045b63d0e17761db Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:10:37 +0200 Subject: [PATCH 1022/1502] Python issue #21454: Fix asyncio.BaseEventLoop.connect_read_pipe doc The function sets the the pipe to non-blocking mode. --- asyncio/events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 31592d10..f0ad5680 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -257,11 +257,11 @@ def create_datagram_endpoint(self, protocol_factory, # Pipes and subprocesses. def connect_read_pipe(self, protocol_factory, pipe): - """Register read pipe in event loop. + """Register read pipe in event loop. Set the pipe to non-blocking mode. protocol_factory should instantiate object with Protocol interface. - pipe is file-like object already switched to nonblocking. - Return pair (transport, protocol), where transport support + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the ReadTransport interface.""" # The reason to accept file-like object instead of just file descriptor # is: we need to own pipe and close it at transport finishing From 564e4f2989577993b28b63ec38d738edc404c89d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:18:19 +0200 Subject: [PATCH 1023/1502] cleanup test_base_events: cm variable was unused --- tests/test_base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index dbcd590b..e28c3272 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -600,7 +600,7 @@ def getaddrinfo(*args, **kw): with mock.patch.object(self.loop, 'sock_connect', side_effect=asyncio.TimeoutError): coro = self.loop.create_connection(MyProto, '127.0.0.1', 80) - with self.assertRaises(asyncio.TimeoutError) as cm: + with self.assertRaises(asyncio.TimeoutError): self.loop.run_until_complete(coro) self.assertTrue(sock.close.called) From c02affb40c1e79a5c0766e508c7a4203a584e816 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:22:33 +0200 Subject: [PATCH 1024/1502] Python issue #21651: Fix ResourceWarning when running asyncio tests on Windows. Patch written by Claudiu Popa. --- tests/test_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_events.py b/tests/test_events.py index 03c41491..e19d991f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1070,6 +1070,7 @@ def datagram_received(self, data, addr): def test_internal_fds(self): loop = self.create_event_loop() if not isinstance(loop, selector_events.BaseSelectorEventLoop): + loop.close() self.skipTest('loop is not a BaseSelectorEventLoop') self.assertEqual(1, loop._internal_fds) From 2b3ad8eac190664a18630381e5d070e4c644c408 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 00:41:42 +0200 Subject: [PATCH 1025/1502] Fix tests on Windows: wait for the subprocess exit Before, regrtest failed to remove the temporary test directory because the process was still running in this directory. --- tests/test_windows_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index b1f81da8..7ea3a6d3 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -164,6 +164,8 @@ def test_popen(self): self.assertTrue(msg.upper().rstrip().startswith(out)) self.assertTrue(b"stderr".startswith(err)) + p.wait() + if __name__ == '__main__': unittest.main() From c4756add71b613f79cb2402838a938f3584e6e1c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Jun 2014 01:04:02 +0200 Subject: [PATCH 1026/1502] Tulip issue #169, Python issue #21326: Add BaseEventLoop.is_closed() method Add BaseEventLoop._closed attribute and use it to check if the event loop was closed or not, instead of checking different attributes in each subclass of BaseEventLoop. run_forever() and run_until_complete() now raises a RuntimeError('Event loop is closed') exception if the event loop was closed. BaseProactorEventLoop.close() now also cancels "accept futures". --- asyncio/base_events.py | 19 +++++++++++++++++++ asyncio/proactor_events.py | 23 +++++++++++++++-------- asyncio/selector_events.py | 16 ++++++++-------- tests/test_base_events.py | 14 ++++++++++++++ tests/test_selector_events.py | 17 ++++++++++++++--- 5 files changed, 70 insertions(+), 19 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 1c7073c3..5ee21d1c 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -119,6 +119,7 @@ def wait_closed(self): class BaseEventLoop(events.AbstractEventLoop): def __init__(self): + self._closed = False self._ready = collections.deque() self._scheduled = [] self._default_executor = None @@ -128,6 +129,11 @@ def __init__(self): self._exception_handler = None self._debug = False + def __repr__(self): + return ('<%s running=%s closed=%s debug=%s>' + % (self.__class__.__name__, self.is_running(), + self.is_closed(), self.get_debug())) + def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): """Create socket transport.""" @@ -173,8 +179,13 @@ def _process_events(self, event_list): """Process selector events.""" raise NotImplementedError + def _check_closed(self): + if self._closed: + raise RuntimeError('Event loop is closed') + def run_forever(self): """Run until stop() is called.""" + self._check_closed() if self._running: raise RuntimeError('Event loop is running.') self._running = True @@ -198,6 +209,7 @@ def run_until_complete(self, future): Return the Future's result, or raise its exception. """ + self._check_closed() future = tasks.async(future, loop=self) future.add_done_callback(_raise_stop_error) self.run_forever() @@ -222,6 +234,9 @@ def close(self): This clears the queues and shuts down the executor, but does not wait for the executor to finish. """ + if self._closed: + return + self._closed = True self._ready.clear() self._scheduled.clear() executor = self._default_executor @@ -229,6 +244,10 @@ def close(self): self._default_executor = None executor.shutdown(wait=False) + def is_closed(self): + """Returns True if the event loop was closed.""" + return self._closed + def is_running(self): """Returns running status of event loop.""" return self._running diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index d99e8ce7..757a22e8 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -353,13 +353,14 @@ def _make_write_pipe_transport(self, sock, protocol, waiter=None, sock, protocol, waiter, extra) def close(self): - if self._proactor is not None: - self._close_self_pipe() - self._proactor.close() - self._proactor = None - self._selector = None - super().close() - self._accept_futures.clear() + if self.is_closed(): + return + self._stop_accept_futures() + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + super().close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) @@ -428,6 +429,8 @@ def loop(f=None): self._make_socket_transport( conn, protocol, extra={'peername': addr}, server=server) + if self.is_closed(): + return f = self._proactor.accept(sock) except OSError as exc: if sock.fileno() != -1: @@ -448,8 +451,12 @@ def loop(f=None): def _process_events(self, event_list): pass # XXX hard work currently done in poll - def _stop_serving(self, sock): + def _stop_accept_futures(self): for future in self._accept_futures.values(): future.cancel() + self._accept_futures.clear() + + def _stop_serving(self, sock): + self._stop_accept_futures() self._proactor._stop_serving(sock) sock.close() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 86a8d23c..1f8e5c8b 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -55,11 +55,13 @@ def _make_datagram_transport(self, sock, protocol, return _SelectorDatagramTransport(self, sock, protocol, address, extra) def close(self): + if self.is_closed(): + return + self._close_self_pipe() if self._selector is not None: - self._close_self_pipe() self._selector.close() self._selector = None - super().close() + super().close() def _socketpair(self): raise NotImplementedError @@ -143,8 +145,7 @@ def _accept_connection(self, protocol_factory, sock, def add_reader(self, fd, callback, *args): """Add a reader callback.""" - if self._selector is None: - raise RuntimeError('Event loop is closed') + self._check_closed() handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) @@ -160,7 +161,7 @@ def add_reader(self, fd, callback, *args): def remove_reader(self, fd): """Remove a reader callback.""" - if self._selector is None: + if self.is_closed(): return False try: key = self._selector.get_key(fd) @@ -182,8 +183,7 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" - if self._selector is None: - raise RuntimeError('Event loop is closed') + self._check_closed() handle = events.Handle(callback, args, self) try: key = self._selector.get_key(fd) @@ -199,7 +199,7 @@ def add_writer(self, fd, callback, *args): def remove_writer(self, fd): """Remove a writer callback.""" - if self._selector is None: + if self.is_closed(): return False try: key = self._selector.get_key(fd) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index e28c3272..1611a114 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -52,6 +52,20 @@ def test_not_implemented(self): gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) self.assertRaises(NotImplementedError, next, iter(gen)) + def test_close(self): + self.assertFalse(self.loop.is_closed()) + self.loop.close() + self.assertTrue(self.loop.is_closed()) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = asyncio.Future(loop=self.loop) + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + def test__add_callback_handle(self): h = asyncio.Handle(lambda: False, (), self.loop) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index d7fafab6..36f65085 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -80,7 +80,10 @@ def test_close(self): self.loop._selector.close() self.loop._selector = selector = mock.Mock() + self.assertFalse(self.loop.is_closed()) + self.loop.close() + self.assertTrue(self.loop.is_closed()) self.assertIsNone(self.loop._selector) self.assertIsNone(self.loop._csock) self.assertIsNone(self.loop._ssock) @@ -89,9 +92,20 @@ def test_close(self): csock.close.assert_called_with() remove_reader.assert_called_with(7) + # it should be possible to call close() more than once self.loop.close() self.loop.close() + # operation blocked when the loop is closed + f = asyncio.Future(loop=self.loop) + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + fd = 0 + def callback(): + pass + self.assertRaises(RuntimeError, self.loop.add_reader, fd, callback) + self.assertRaises(RuntimeError, self.loop.add_writer, fd, callback) + def test_close_no_selector(self): ssock = self.loop._ssock csock = self.loop._csock @@ -101,9 +115,6 @@ def test_close_no_selector(self): self.loop._selector = None self.loop.close() self.assertIsNone(self.loop._selector) - self.assertFalse(ssock.close.called) - self.assertFalse(csock.close.called) - self.assertFalse(remove_reader.called) def test_socketpair(self): self.assertRaises(NotImplementedError, self.loop._socketpair) From 29ceebc4cb1ee5ce45265f9788134fecf59dccd8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 5 Jun 2014 12:05:27 +0200 Subject: [PATCH 1027/1502] Tulip issue #83, Python issue 21252: Fill some XXX docstrings --- asyncio/events.py | 35 +++++++++++++++++++++++------------ asyncio/unix_events.py | 4 ++-- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index f0ad5680..4a9a9a38 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -355,25 +355,33 @@ class AbstractEventLoopPolicy: """Abstract policy for accessing the event loop.""" def get_event_loop(self): - """XXX""" + """Get the event loop for the current context. + + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. + + It should never return None.""" raise NotImplementedError def set_event_loop(self, loop): - """XXX""" + """Set the event loop for the current context to loop.""" raise NotImplementedError def new_event_loop(self): - """XXX""" + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" raise NotImplementedError # Child processes handling (Unix only). def get_child_watcher(self): - """XXX""" + "Get the watcher for child processes." raise NotImplementedError def set_child_watcher(self, watcher): - """XXX""" + """Set the watcher for child processes.""" raise NotImplementedError @@ -447,39 +455,42 @@ def _init_event_loop_policy(): def get_event_loop_policy(): - """XXX""" + """Get the current event loop policy.""" if _event_loop_policy is None: _init_event_loop_policy() return _event_loop_policy def set_event_loop_policy(policy): - """XXX""" + """Set the current event loop policy. + + If policy is None, the default policy is restored.""" global _event_loop_policy assert policy is None or isinstance(policy, AbstractEventLoopPolicy) _event_loop_policy = policy def get_event_loop(): - """XXX""" + """Equivalent to calling get_event_loop_policy().get_event_loop().""" return get_event_loop_policy().get_event_loop() def set_event_loop(loop): - """XXX""" + """Equivalent to calling get_event_loop_policy().set_event_loop(loop).""" get_event_loop_policy().set_event_loop(loop) def new_event_loop(): - """XXX""" + """Equivalent to calling get_event_loop_policy().new_event_loop().""" return get_event_loop_policy().new_event_loop() def get_child_watcher(): - """XXX""" + """Equivalent to calling get_event_loop_policy().get_child_watcher().""" return get_event_loop_policy().get_child_watcher() def set_child_watcher(watcher): - """XXX""" + """Equivalent to calling + get_event_loop_policy().set_child_watcher(watcher).""" return get_event_loop_policy().set_child_watcher(watcher) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 230fbc38..acb327d9 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -822,7 +822,7 @@ def set_event_loop(self, loop): self._watcher.attach_loop(loop) def get_child_watcher(self): - """Get the child watcher + """Get the watcher for child processes. If not yet set, a SafeChildWatcher object is automatically created. """ @@ -832,7 +832,7 @@ def get_child_watcher(self): return self._watcher def set_child_watcher(self, watcher): - """Set the child watcher""" + """Set the watcher for child processes.""" assert watcher is None or isinstance(watcher, AbstractChildWatcher) From 01fef52482b8f8e24a8cff29beef2769533a83ea Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Jun 2014 11:15:03 +0200 Subject: [PATCH 1028/1502] wait(): mention that the future sequence must not be empty --- asyncio/tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 2aa568bc..8b8fb82e 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -387,6 +387,8 @@ def _wakeup(self, future): def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): """Wait for the Futures and coroutines given by fs to complete. + The sequence futures must not be empty. + Coroutines will be wrapped in Tasks. Returns two sets of Future: (done, pending). From 2e0955c14eccaf28f37791159efb24cec266a06d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 12 Jun 2014 18:35:14 +0200 Subject: [PATCH 1029/1502] Issue #173: Enhance repr(Handle) and repr(Task) repr(Handle) is shorter for function: "foo" instead of "". It now also includes the source of the callback, filename and line number where it was defined, if available. repr(Task) now also includes the current position in the code, filename and line number, if available. If the coroutine (generator) is done, the line number is omitted and "done" is added. --- asyncio/events.py | 30 ++++++++++++++++- asyncio/tasks.py | 10 +++++- asyncio/test_utils.py | 7 ++++ tests/test_events.py | 78 +++++++++++++++++++++++++++++++------------ tests/test_tasks.py | 29 +++++++++++----- 5 files changed, 123 insertions(+), 31 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 4a9a9a38..de161df6 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -8,9 +8,29 @@ 'get_child_watcher', 'set_child_watcher', ] +import functools +import inspect import subprocess import threading import socket +import sys + + +_PY34 = sys.version_info >= (3, 4) + +def _get_function_source(func): + if _PY34: + func = inspect.unwrap(func) + elif hasattr(func, '__wrapped__'): + func = func.__wrapped__ + if inspect.isfunction(func): + code = func.__code__ + return (code.co_filename, code.co_firstlineno) + if isinstance(func, functools.partial): + return _get_function_source(func.func) + if _PY34 and isinstance(func, functools.partialmethod): + return _get_function_source(func.func) + return None class Handle: @@ -26,7 +46,15 @@ def __init__(self, callback, args, loop): self._cancelled = False def __repr__(self): - res = 'Handle({}, {})'.format(self._callback, self._args) + cb_repr = getattr(self._callback, '__qualname__', None) + if not cb_repr: + cb_repr = str(self._callback) + + source = _get_function_source(self._callback) + if source: + cb_repr += ' at %s:%s' % source + + res = 'Handle({}, {})'.format(cb_repr, self._args) if self._cancelled: res += '' return res diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 8b8fb82e..e6fd3d38 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -188,7 +188,15 @@ def __repr__(self): i = res.find('<') if i < 0: i = len(res) - res = res[:i] + '(<{}>)'.format(self._coro.__name__) + res[i:] + text = self._coro.__name__ + coro = self._coro + if inspect.isgenerator(coro): + filename = coro.gi_code.co_filename + if coro.gi_frame is not None: + text += ' at %s:%s' % (filename, coro.gi_frame.f_lineno) + else: + text += ' done at %s' % filename + res = res[:i] + '(<{}>)'.format(text) + res[i:] return res def get_stack(self, *, limit=None): diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 9c3656ac..1062bae1 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -372,3 +372,10 @@ class MockPattern(str): """ def __eq__(self, other): return bool(re.search(str(self), other, re.S)) + + +def get_function_source(func): + source = events._get_function_source(func) + if source is None: + raise ValueError("unable to get the source of %r" % (func,)) + return source diff --git a/tests/test_events.py b/tests/test_events.py index e19d991f..2262a752 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -5,6 +5,7 @@ import io import os import platform +import re import signal import socket try: @@ -1737,52 +1738,46 @@ def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector()) +def noop(): + pass + + class HandleTests(unittest.TestCase): + def setUp(self): + self.loop = None + def test_handle(self): def callback(*args): return args args = () - h = asyncio.Handle(callback, args, mock.Mock()) + h = asyncio.Handle(callback, args, self.loop) self.assertIs(h._callback, callback) self.assertIs(h._args, args) self.assertFalse(h._cancelled) - r = repr(h) - self.assertTrue(r.startswith( - 'Handle(' - '.callback')) - self.assertTrue(r.endswith('())')) - h.cancel() self.assertTrue(h._cancelled) - r = repr(h) - self.assertTrue(r.startswith( - 'Handle(' - '.callback')) - self.assertTrue(r.endswith('())'), r) - def test_handle_from_handle(self): def callback(*args): return args - m_loop = object() - h1 = asyncio.Handle(callback, (), loop=m_loop) + h1 = asyncio.Handle(callback, (), loop=self.loop) self.assertRaises( - AssertionError, asyncio.Handle, h1, (), m_loop) + AssertionError, asyncio.Handle, h1, (), self.loop) def test_callback_with_exception(self): def callback(): raise ValueError() - m_loop = mock.Mock() - m_loop.call_exception_handler = mock.Mock() + self.loop = mock.Mock() + self.loop.call_exception_handler = mock.Mock() - h = asyncio.Handle(callback, (), m_loop) + h = asyncio.Handle(callback, (), self.loop) h._run() - m_loop.call_exception_handler.assert_called_with({ + self.loop.call_exception_handler.assert_called_with({ 'message': test_utils.MockPattern('Exception in callback.*'), 'exception': mock.ANY, 'handle': h @@ -1790,9 +1785,50 @@ def callback(): def test_handle_weakref(self): wd = weakref.WeakValueDictionary() - h = asyncio.Handle(lambda: None, (), object()) + h = asyncio.Handle(lambda: None, (), self.loop) wd['h'] = h # Would fail without __weakref__ slot. + def test_repr(self): + # simple function + h = asyncio.Handle(noop, (), self.loop) + src = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + 'Handle(noop at %s:%s, ())' % src) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + 'Handle(noop at %s:%s, ())' % src) + + # decorated function + cb = asyncio.coroutine(noop) + h = asyncio.Handle(cb, (), self.loop) + self.assertEqual(repr(h), + 'Handle(noop at %s:%s, ())' % src) + + # partial function + cb = functools.partial(noop) + h = asyncio.Handle(cb, (), self.loop) + filename, lineno = src + regex = (r'^Handle\(functools.partial\(' + r'\) at %s:%s, ' + r'\(\)\)$' % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + # partial method + if sys.version_info >= (3, 4): + method = HandleTests.test_repr + cb = functools.partialmethod(method) + src = test_utils.get_function_source(method) + h = asyncio.Handle(cb, (), self.loop) + + filename, lineno = src + regex = (r'^Handle\(functools.partialmethod\(' + r', , \) at %s:%s, ' + r'\(\)\)$' % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + class TimerTests(unittest.TestCase): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 45a0dc1d..92eb9dae 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -116,21 +116,30 @@ def notmuch(): yield from [] return 'abc' + filename, lineno = test_utils.get_function_source(notmuch) + src = "%s:%s" % (filename, lineno) + t = asyncio.Task(notmuch(), loop=self.loop) t.add_done_callback(Dummy()) - self.assertEqual(repr(t), 'Task()') + self.assertEqual(repr(t), + 'Task()' % src) + t.cancel() # Does not take immediate effect! - self.assertEqual(repr(t), 'Task()') + self.assertEqual(repr(t), + 'Task()' % src) self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) - self.assertEqual(repr(t), 'Task()') + self.assertEqual(repr(t), + 'Task()' % filename) + t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) - self.assertEqual(repr(t), "Task()") + self.assertEqual(repr(t), + "Task()" % filename) def test_task_repr_custom(self): @asyncio.coroutine - def coro(): + def notmuch(): pass class T(asyncio.Future): @@ -141,10 +150,14 @@ class MyTask(asyncio.Task, T): def __repr__(self): return super().__repr__() - gen = coro() + gen = notmuch() t = MyTask(gen, loop=self.loop) - self.assertEqual(repr(t), 'T[]()') - gen.close() + filename = gen.gi_code.co_filename + lineno = gen.gi_frame.f_lineno + # FIXME: check for the name "coro" instead of "notmuch" because + # @asyncio.coroutine drops the name of the wrapped function: + # http://bugs.python.org/issue21205 + self.assertEqual(repr(t), 'T[]()' % (filename, lineno)) def test_task_basics(self): @asyncio.coroutine From d23369a5d4698dbe14f56b909bb5f397138778e6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 16 Jun 2014 07:50:13 -0700 Subject: [PATCH 1030/1502] Do not offer to copy selectors.py to/from 3.4 CPython branch in update script. --- update_stdlib.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/update_stdlib.sh b/update_stdlib.sh index 9e054659..bb6251a0 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -42,7 +42,12 @@ for i in `(cd asyncio && ls *.py)` do if [ $i == selectors.py ] then - maybe_copy asyncio/$i Lib/$i + if [ "`(cd $CPYTHON; hg branch)`" == "3.4" ] + then + echo "Destination is 3.4 branch -- ignoring selectors.py" + else + maybe_copy asyncio/$i Lib/$i + fi else maybe_copy asyncio/$i Lib/asyncio/$i fi From 1b7b4134e6ce5a46f1e424ca69d1d2502a681d13 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 16 Jun 2014 17:07:20 +0200 Subject: [PATCH 1031/1502] Fix test_tasks for Python 3.5 On Python 3.5, generator now gets their name from the function, no more from the code. So we get the expected "notmuch" name instead of the generic "coro" name. --- tests/test_tasks.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 92eb9dae..4e239ecb 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,6 +2,7 @@ import gc import os.path +import sys import types import unittest import weakref @@ -154,10 +155,13 @@ def __repr__(self): t = MyTask(gen, loop=self.loop) filename = gen.gi_code.co_filename lineno = gen.gi_frame.f_lineno - # FIXME: check for the name "coro" instead of "notmuch" because - # @asyncio.coroutine drops the name of the wrapped function: - # http://bugs.python.org/issue21205 - self.assertEqual(repr(t), 'T[]()' % (filename, lineno)) + if sys.version_info >= (3, 5): + name = 'notmuch' + else: + # On Python < 3.5, generators inherit the name of the code, not of + # the function. See: http://bugs.python.org/issue21205 + name = 'coro' + self.assertEqual(repr(t), 'T[](<%s at %s:%s>)' % (name, filename, lineno)) def test_task_basics(self): @asyncio.coroutine From 34d9d475f2c73af685faff3fcfa21ec4c598832d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 17 Jun 2014 00:24:45 +0200 Subject: [PATCH 1032/1502] Task.__repr__() now also handles CoroWrapper --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index e6fd3d38..281bf608 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -190,7 +190,7 @@ def __repr__(self): i = len(res) text = self._coro.__name__ coro = self._coro - if inspect.isgenerator(coro): + if iscoroutine(coro): filename = coro.gi_code.co_filename if coro.gi_frame is not None: text += ' at %s:%s' % (filename, coro.gi_frame.f_lineno) From a9faa0eb690bebbac9be194259a96454d7fc7fb1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 18 Jun 2014 00:05:54 +0200 Subject: [PATCH 1033/1502] Python issue 21723: asyncio.Queue: support any type of number (ex: float) for the maximum size. Patch written by Vajrasky Kok. --- asyncio/queues.py | 6 +++--- tests/test_queues.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 6283db32..57afb053 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -105,7 +105,7 @@ def full(self): if self._maxsize <= 0: return False else: - return self.qsize() == self._maxsize + return self.qsize() >= self._maxsize @coroutine def put(self, item): @@ -126,7 +126,7 @@ def put(self, item): self._put(item) getter.set_result(self._get()) - elif self._maxsize > 0 and self._maxsize == self.qsize(): + elif self._maxsize > 0 and self._maxsize <= self.qsize(): waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) @@ -152,7 +152,7 @@ def put_nowait(self, item): self._put(item) getter.set_result(self._get()) - elif self._maxsize > 0 and self._maxsize == self.qsize(): + elif self._maxsize > 0 and self._maxsize <= self.qsize(): raise QueueFull else: self._put(item) diff --git a/tests/test_queues.py b/tests/test_queues.py index f79fee21..820234df 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -339,6 +339,21 @@ def test_nonblocking_put_exception(self): q.put_nowait(1) self.assertRaises(asyncio.QueueFull, q.put_nowait, 2) + def test_float_maxsize(self): + q = asyncio.Queue(maxsize=1.3, loop=self.loop) + q.put_nowait(1) + q.put_nowait(2) + self.assertTrue(q.full()) + self.assertRaises(asyncio.QueueFull, q.put_nowait, 3) + + q = asyncio.Queue(maxsize=1.3, loop=self.loop) + @asyncio.coroutine + def queue_put(): + yield from q.put(1) + yield from q.put(2) + self.assertTrue(q.full()) + self.loop.run_until_complete(queue_put()) + def test_put_cancelled(self): q = asyncio.Queue(loop=self.loop) From ff281d119a51fa74bf39925579616138e268eec8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 18 Jun 2014 01:09:09 +0200 Subject: [PATCH 1034/1502] Set __qualname__ attribute of CoroWrapper in @coroutine decorator - Drop __slots__ optimization of CoroWrapper to be able to set the __qualname__ attribute. - Add tests on __name__, __qualname__ and __module__ of a coroutine function and coroutine object. - Fix test_tasks when run in debug mode (PYTHONASYNCIODEBUG env var set) on Python 3.3 or 3.4 --- asyncio/tasks.py | 10 ++++++---- tests/test_tasks.py | 48 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 281bf608..eaf93f88 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -32,12 +32,12 @@ _DEBUG = (not sys.flags.ignore_environment and bool(os.environ.get('PYTHONASYNCIODEBUG'))) +_PY35 = (sys.version_info >= (3, 5)) + class CoroWrapper: # Wrapper for coroutine in _DEBUG mode. - __slots__ = ['gen', 'func', '__name__', '__doc__', '__weakref__'] - def __init__(self, gen, func): assert inspect.isgenerator(gen), gen self.gen = gen @@ -111,8 +111,10 @@ def coro(*args, **kw): @functools.wraps(func) def wrapper(*args, **kwds): w = CoroWrapper(coro(*args, **kwds), func) - w.__name__ = coro.__name__ - w.__doc__ = coro.__doc__ + w.__name__ = func.__name__ + if _PY35: + w.__qualname__ = func.__qualname__ + w.__doc__ = func.__doc__ return w wrapper._is_coroutine = True # For iscoroutinefunction(). diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4e239ecb..dcc81234 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -9,9 +9,13 @@ from test.script_helper import assert_python_ok import asyncio +from asyncio import tasks from asyncio import test_utils +PY35 = (sys.version_info >= (3, 5)) + + @asyncio.coroutine def coroutine_function(): pass @@ -117,10 +121,22 @@ def notmuch(): yield from [] return 'abc' + self.assertEqual(notmuch.__name__, 'notmuch') + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr..notmuch') + self.assertEqual(notmuch.__module__, __name__) + filename, lineno = test_utils.get_function_source(notmuch) src = "%s:%s" % (filename, lineno) - t = asyncio.Task(notmuch(), loop=self.loop) + gen = notmuch() + self.assertEqual(gen.__name__, 'notmuch') + if PY35: + self.assertEqual(gen.__qualname__, + 'TaskTests.test_task_repr..notmuch') + + t = asyncio.Task(gen, loop=self.loop) t.add_done_callback(Dummy()) self.assertEqual(repr(t), 'Task()' % src) @@ -143,6 +159,12 @@ def test_task_repr_custom(self): def notmuch(): pass + self.assertEqual(notmuch.__name__, 'notmuch') + self.assertEqual(notmuch.__module__, __name__) + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr_custom..notmuch') + class T(asyncio.Future): def __repr__(self): return 'T[]' @@ -152,16 +174,26 @@ def __repr__(self): return super().__repr__() gen = notmuch() - t = MyTask(gen, loop=self.loop) - filename = gen.gi_code.co_filename - lineno = gen.gi_frame.f_lineno - if sys.version_info >= (3, 5): - name = 'notmuch' + if PY35 or tasks._DEBUG: + # On Python >= 3.5, generators now inherit the name of the + # function, as expected, and have a qualified name (__qualname__ + # attribute). In debug mode, @coroutine decorator uses CoroWrapper + # which gets its name (__name__ attribute) from the wrapped + # coroutine function. + coro_name = 'notmuch' else: # On Python < 3.5, generators inherit the name of the code, not of # the function. See: http://bugs.python.org/issue21205 - name = 'coro' - self.assertEqual(repr(t), 'T[](<%s at %s:%s>)' % (name, filename, lineno)) + coro_name = 'coro' + self.assertEqual(gen.__name__, coro_name) + if PY35: + self.assertEqual(gen.__qualname__, + 'TaskTests.test_task_repr_custom..notmuch') + + t = MyTask(gen, loop=self.loop) + filename = gen.gi_code.co_filename + lineno = gen.gi_frame.f_lineno + self.assertEqual(repr(t), 'T[](<%s at %s:%s>)' % (coro_name, filename, lineno)) def test_task_basics(self): @asyncio.coroutine From 976fe112a6ad26a92278c87bc45811becfd1834b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 18 Jun 2014 01:19:47 +0200 Subject: [PATCH 1035/1502] Refactor test__run_once_logging() to not rely on the exact number of calls to time.monotonic(). Use a "fast select" and a "slow select" instead. --- tests/test_base_events.py | 29 +++++++++++------------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 1611a114..fb28b87e 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -240,30 +240,23 @@ def test_set_debug(self): self.loop.set_debug(False) self.assertFalse(self.loop.get_debug()) - @mock.patch('asyncio.base_events.time') @mock.patch('asyncio.base_events.logger') - def test__run_once_logging(self, m_logger, m_time): - # Log to INFO level if timeout > 1.0 sec. - idx = -1 - data = [10.0, 10.0, 12.0, 13.0] - - def monotonic(): - nonlocal data, idx - idx += 1 - return data[idx] + def test__run_once_logging(self, m_logger): + def slow_select(timeout): + time.sleep(1.0) + return [] - m_time.monotonic = monotonic - - self.loop._scheduled.append( - asyncio.TimerHandle(11.0, lambda: True, (), self.loop)) + # Log to INFO level if timeout > 1.0 sec. + self.loop._selector.select = slow_select self.loop._process_events = mock.Mock() self.loop._run_once() self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) - idx = -1 - data = [10.0, 10.0, 10.3, 13.0] - self.loop._scheduled = [asyncio.TimerHandle(11.0, lambda: True, (), - self.loop)] + def fast_select(timeout): + time.sleep(0.001) + return [] + + self.loop._selector.select = fast_select self.loop._run_once() self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) From b193771ae76b024216b595e505fff870609445f5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 18 Jun 2014 01:32:10 +0200 Subject: [PATCH 1036/1502] Refactor tests: add a base TestCase class --- asyncio/test_utils.py | 18 ++++++ tests/test_base_events.py | 11 ++-- tests/test_events.py | 14 ++--- tests/test_futures.py | 25 +++----- tests/test_locks.py | 68 +++++++--------------- tests/test_proactor_events.py | 7 ++- tests/test_queues.py | 32 ++++------- tests/test_selector_events.py | 21 +++---- tests/test_streams.py | 5 +- tests/test_subprocess.py | 10 ++-- tests/test_tasks.py | 105 ++++++++++++---------------------- tests/test_unix_events.py | 40 ++++++------- tests/test_windows_events.py | 8 +-- 13 files changed, 145 insertions(+), 219 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 1062bae1..d9c7ae2d 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -11,6 +11,7 @@ import tempfile import threading import time +import unittest from unittest import mock from http.server import HTTPServer @@ -379,3 +380,20 @@ def get_function_source(func): if source is None: raise ValueError("unable to get the source of %r" % (func,)) return source + + +class TestCase(unittest.TestCase): + def set_event_loop(self, loop, *, cleanup=True): + assert loop is not None + # ensure that the event loop is passed explicitly in asyncio + events.set_event_loop(None) + if cleanup: + self.addCleanup(loop.close) + + def new_test_loop(self, gen=None): + loop = TestLoop(gen) + self.set_event_loop(loop) + return loop + + def tearDown(self): + events.set_event_loop(None) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index fb28b87e..059b41c3 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -19,12 +19,12 @@ PY34 = sys.version_info >= (3, 4) -class BaseEventLoopTests(unittest.TestCase): +class BaseEventLoopTests(test_utils.TestCase): def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = mock.Mock() - asyncio.set_event_loop(None) + self.set_event_loop(self.loop) def test_not_implemented(self): m = mock.Mock() @@ -548,14 +548,11 @@ def connection_lost(self, exc): self.done.set_result(None) -class BaseEventLoopWithSelectorTests(unittest.TestCase): +class BaseEventLoopWithSelectorTests(test_utils.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.set_event_loop(self.loop) @mock.patch('asyncio.base_events.socket') def test_create_connection_multiple_errors(self, m_socket): diff --git a/tests/test_events.py b/tests/test_events.py index 2262a752..37e45e1d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -224,7 +224,7 @@ class EventLoopTestsMixin: def setUp(self): super().setUp() self.loop = self.create_event_loop() - asyncio.set_event_loop(None) + self.set_event_loop(self.loop) def tearDown(self): # just in case if we have transport close callbacks @@ -1629,14 +1629,14 @@ def connect(cmd=None, **kwds): if sys.platform == 'win32': - class SelectEventLoopTests(EventLoopTestsMixin, unittest.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop() class ProactorEventLoopTests(EventLoopTestsMixin, SubprocessTestsMixin, - unittest.TestCase): + test_utils.TestCase): def create_event_loop(self): return asyncio.ProactorEventLoop() @@ -1691,7 +1691,7 @@ def tearDown(self): if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, - unittest.TestCase): + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop( @@ -1716,7 +1716,7 @@ def test_write_pty(self): if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, - unittest.TestCase): + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.EpollSelector()) @@ -1724,7 +1724,7 @@ def create_event_loop(self): if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, - unittest.TestCase): + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.PollSelector()) @@ -1732,7 +1732,7 @@ def create_event_loop(self): # Should always exist. class SelectEventLoopTests(UnixEventLoopTestsMixin, SubprocessTestsMixin, - unittest.TestCase): + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector()) diff --git a/tests/test_futures.py b/tests/test_futures.py index 399e8f43..a230d614 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -13,14 +13,10 @@ def _fakefunc(f): return f -class FutureTests(unittest.TestCase): +class FutureTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def test_initial_state(self): f = asyncio.Future(loop=self.loop) @@ -30,12 +26,9 @@ def test_initial_state(self): self.assertTrue(f.cancelled()) def test_init_constructor_default_loop(self): - try: - asyncio.set_event_loop(self.loop) - f = asyncio.Future() - self.assertIs(f._loop, self.loop) - finally: - asyncio.set_event_loop(None) + asyncio.set_event_loop(self.loop) + f = asyncio.Future() + self.assertIs(f._loop, self.loop) def test_constructor_positional(self): # Make sure Future doesn't accept a positional argument @@ -264,14 +257,10 @@ def test_wrap_future_cancel2(self): self.assertTrue(f2.cancelled()) -class FutureDoneCallbackTests(unittest.TestCase): +class FutureDoneCallbackTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def run_briefly(self): test_utils.run_briefly(self.loop) diff --git a/tests/test_locks.py b/tests/test_locks.py index f542463a..9d50a71f 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -17,14 +17,10 @@ RGX_REPR = re.compile(STR_RGX_REPR) -class LockTests(unittest.TestCase): +class LockTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def test_ctor_loop(self): loop = mock.Mock() @@ -35,12 +31,9 @@ def test_ctor_loop(self): self.assertIs(lock._loop, self.loop) def test_ctor_noloop(self): - try: - asyncio.set_event_loop(self.loop) - lock = asyncio.Lock() - self.assertIs(lock._loop, self.loop) - finally: - asyncio.set_event_loop(None) + asyncio.set_event_loop(self.loop) + lock = asyncio.Lock() + self.assertIs(lock._loop, self.loop) def test_repr(self): lock = asyncio.Lock(loop=self.loop) @@ -240,14 +233,10 @@ def test_context_manager_no_yield(self): self.assertFalse(lock.locked()) -class EventTests(unittest.TestCase): +class EventTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def test_ctor_loop(self): loop = mock.Mock() @@ -258,12 +247,9 @@ def test_ctor_loop(self): self.assertIs(ev._loop, self.loop) def test_ctor_noloop(self): - try: - asyncio.set_event_loop(self.loop) - ev = asyncio.Event() - self.assertIs(ev._loop, self.loop) - finally: - asyncio.set_event_loop(None) + asyncio.set_event_loop(self.loop) + ev = asyncio.Event() + self.assertIs(ev._loop, self.loop) def test_repr(self): ev = asyncio.Event(loop=self.loop) @@ -376,14 +362,10 @@ def c1(result): self.assertTrue(t.result()) -class ConditionTests(unittest.TestCase): +class ConditionTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def test_ctor_loop(self): loop = mock.Mock() @@ -394,12 +376,9 @@ def test_ctor_loop(self): self.assertIs(cond._loop, self.loop) def test_ctor_noloop(self): - try: - asyncio.set_event_loop(self.loop) - cond = asyncio.Condition() - self.assertIs(cond._loop, self.loop) - finally: - asyncio.set_event_loop(None) + asyncio.set_event_loop(self.loop) + cond = asyncio.Condition() + self.assertIs(cond._loop, self.loop) def test_wait(self): cond = asyncio.Condition(loop=self.loop) @@ -678,14 +657,10 @@ def test_context_manager_no_yield(self): self.assertFalse(cond.locked()) -class SemaphoreTests(unittest.TestCase): +class SemaphoreTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() def test_ctor_loop(self): loop = mock.Mock() @@ -696,12 +671,9 @@ def test_ctor_loop(self): self.assertIs(sem._loop, self.loop) def test_ctor_noloop(self): - try: - asyncio.set_event_loop(self.loop) - sem = asyncio.Semaphore() - self.assertIs(sem._loop, self.loop) - finally: - asyncio.set_event_loop(None) + asyncio.set_event_loop(self.loop) + sem = asyncio.Semaphore() + self.assertIs(sem._loop, self.loop) def test_initial_value_zero(self): sem = asyncio.Semaphore(0, loop=self.loop) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 5bf24a45..ddfceae1 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -12,10 +12,10 @@ from asyncio import test_utils -class ProactorSocketTransportTests(unittest.TestCase): +class ProactorSocketTransportTests(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() + self.loop = self.new_test_loop() self.proactor = mock.Mock() self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.Protocol) @@ -343,7 +343,7 @@ def test_pause_resume_reading(self): tr.close() -class BaseProactorEventLoopTests(unittest.TestCase): +class BaseProactorEventLoopTests(test_utils.TestCase): def setUp(self): self.sock = mock.Mock(socket.socket) @@ -356,6 +356,7 @@ def _socketpair(s): return (self.ssock, self.csock) self.loop = EventLoop(self.proactor) + self.set_event_loop(self.loop, cleanup=False) @mock.patch.object(BaseProactorEventLoop, 'call_soon') @mock.patch.object(BaseProactorEventLoop, '_socketpair') diff --git a/tests/test_queues.py b/tests/test_queues.py index 820234df..32c90f47 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -7,14 +7,10 @@ from asyncio import test_utils -class _QueueTestBase(unittest.TestCase): +class _QueueTestBase(test_utils.TestCase): def setUp(self): - self.loop = test_utils.TestLoop() - asyncio.set_event_loop(None) - - def tearDown(self): - self.loop.close() + self.loop = self.new_test_loop() class QueueBasicTests(_QueueTestBase): @@ -32,8 +28,7 @@ def gen(): self.assertAlmostEqual(0.2, when) yield 0.1 - loop = test_utils.TestLoop(gen) - self.addCleanup(loop.close) + loop = self.new_test_loop(gen) q = asyncio.Queue(loop=loop) self.assertTrue(fn(q).startswith(' Date: Wed, 18 Jun 2014 03:24:50 +0200 Subject: [PATCH 1037/1502] Fix pyflakes errors - Add a missing import - Remove an unused import - Remove unused variables --- tests/test_selector_events.py | 5 +---- tests/test_tasks.py | 1 - tests/test_windows_events.py | 1 + 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index b1148d2e..7c84f03f 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -108,10 +108,7 @@ def callback(): self.assertRaises(RuntimeError, self.loop.add_writer, fd, callback) def test_close_no_selector(self): - ssock = self.loop._ssock - csock = self.loop._csock - remove_reader = self.loop.remove_reader = mock.Mock() - + self.loop.remove_reader = mock.Mock() self.loop._selector.close() self.loop._selector = None self.loop.close() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 0ed2f941..e95c7dcb 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,5 @@ """Tests for tasks.py.""" -import gc import os.path import sys import types diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index ca79c437..4ab56e6c 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -9,6 +9,7 @@ import asyncio from asyncio import _overlapped +from asyncio import test_utils from asyncio import windows_events From 88388887ddddf62bd5530232ce7700405685b7bc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 19 Jun 2014 12:55:55 +0200 Subject: [PATCH 1038/1502] Python issue 21595: BaseSelectorEventLoop._read_from_self() reads all available bytes from the "self pipe", not only a single byte. This change reduces the risk of having the pipe full and so getting the "BlockingIOError: [Errno 11] Resource temporarily unavailable" message. --- asyncio/selector_events.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 1f8e5c8b..854e8151 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -83,10 +83,15 @@ def _make_self_pipe(self): self.add_reader(self._ssock.fileno(), self._read_from_self) def _read_from_self(self): - try: - self._ssock.recv(1) - except (BlockingIOError, InterruptedError): - pass + while True: + try: + data = self._ssock.recv(4096) + if not data: + break + except InterruptedError: + continue + except BlockingIOError: + break def _write_to_self(self): # This may be called from a different thread, possibly after From 7d76e9d6f64ef6f2aecf1d429288750fa0818a61 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 19 Jun 2014 17:12:24 +0200 Subject: [PATCH 1039/1502] Tulip issue #83: document more functions in docstrings --- asyncio/base_events.py | 21 ++++++++++++++++--- asyncio/selector_events.py | 41 ++++++++++++++++++++++++++++++++++---- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5ee21d1c..93509892 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -320,7 +320,7 @@ def _assert_is_current_event_loop(self): "than the current one") def call_soon_threadsafe(self, callback, *args): - """XXX""" + """Like call_soon(), but thread safe.""" handle = self._call_soon(callback, args, check_loop=False) self._write_to_self() return handle @@ -358,7 +358,17 @@ def getnameinfo(self, sockaddr, flags=0): def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None): - """XXX""" + """Connect to a TCP server. + + Create a streaming transport connection to a given Internet host and + port: socket family AF_INET or socket.AF_INET6 depending on host (or + family if specified), socket type SOCK_STREAM. protocol_factory must be + a callable returning a protocol instance. + + This method is a coroutine which will try to establish the connection + in the background. When successful, the coroutine returns a + (transport, protocol) pair. + """ if server_hostname is not None and not ssl: raise ValueError('server_hostname is only meaningful with ssl') @@ -557,7 +567,12 @@ def create_server(self, protocol_factory, host=None, port=None, backlog=100, ssl=None, reuse_address=None): - """XXX""" + """Create a TCP server bound to host and port. + + Return an AbstractServer object which can be used to stop the service. + + This method is a coroutine. + """ if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') if host is not None or port is not None: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 854e8151..a62a8e58 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -226,7 +226,14 @@ def remove_writer(self, fd): return False def sock_recv(self, sock, n): - """XXX""" + """Receive data from the socket. + + The return value is a bytes object representing the data received. + The maximum amount of data to be received at once is specified by + nbytes. + + This method is a coroutine. + """ fut = futures.Future(loop=self) self._sock_recv(fut, False, sock, n) return fut @@ -253,7 +260,16 @@ def _sock_recv(self, fut, registered, sock, n): fut.set_result(data) def sock_sendall(self, sock, data): - """XXX""" + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. + + This method is a coroutine. + """ fut = futures.Future(loop=self) if data: self._sock_sendall(fut, False, sock, data) @@ -285,7 +301,16 @@ def _sock_sendall(self, fut, registered, sock, data): self.add_writer(fd, self._sock_sendall, fut, True, sock, data) def sock_connect(self, sock, address): - """XXX""" + """Connect to a remote socket at address. + + The address must be already resolved to avoid the trap of hanging the + entire event loop when the address requires doing a DNS lookup. For + example, it must be an IP address, not an hostname, for AF_INET and + AF_INET6 address families. Use getaddrinfo() to resolve the hostname + asynchronously. + + This method is a coroutine. + """ fut = futures.Future(loop=self) try: base_events._check_resolved_address(sock, address) @@ -318,7 +343,15 @@ def _sock_connect(self, fut, registered, sock, address): fut.set_result(None) def sock_accept(self, sock): - """XXX""" + """Accept a connection. + + The socket must be bound to an address and listening for connections. + The return value is a pair (conn, address) where conn is a new socket + object usable to send and receive data on the connection, and address + is the address bound to the socket on the other end of the connection. + + This method is a coroutine. + """ fut = futures.Future(loop=self) self._sock_accept(fut, False, sock) return fut From c2cdef149096b74459c24b9c259341e2142eee78 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 20 Jun 2014 17:32:03 +0200 Subject: [PATCH 1040/1502] Tulip issue #105: in debug mode, log callbacks taking more than 100 ms to be executed. --- asyncio/base_events.py | 32 +++++++++++++++++++++++++++----- tests/test_base_events.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 93509892..9f9067ed 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -17,6 +17,7 @@ import collections import concurrent.futures import heapq +import inspect import logging import socket import subprocess @@ -37,6 +38,15 @@ _MAX_WORKERS = 5 +def _format_handle(handle): + cb = handle._callback + if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task): + # format the task + return repr(cb.__self__) + else: + return str(handle) + + class _StopError(BaseException): """Raised to stop the event loop.""" @@ -128,6 +138,9 @@ def __init__(self): self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None self._debug = False + # In debug mode, if the execution of a callback or a step of a task + # exceed this duration in seconds, the slow callback/task is logged. + self.slow_callback_duration = 0.1 def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' @@ -823,16 +836,16 @@ def _run_once(self): if logger.isEnabledFor(logging.INFO): t0 = self.time() event_list = self._selector.select(timeout) - t1 = self.time() - if t1-t0 >= 1: + dt = self.time() - t0 + if dt >= 1: level = logging.INFO else: level = logging.DEBUG if timeout is not None: logger.log(level, 'poll %.3f took %.3f seconds', - timeout, t1-t0) + timeout, dt) else: - logger.log(level, 'poll took %.3f seconds', t1-t0) + logger.log(level, 'poll took %.3f seconds', dt) else: event_list = self._selector.select(timeout) self._process_events(event_list) @@ -855,7 +868,16 @@ def _run_once(self): ntodo = len(self._ready) for i in range(ntodo): handle = self._ready.popleft() - if not handle._cancelled: + if handle._cancelled: + continue + if self._debug: + t0 = self.time() + handle._run() + dt = self.time() - t0 + if dt >= self.slow_callback_duration: + logger.warning('Executing %s took %.3f seconds', + _format_handle(handle), dt) + else: handle._run() handle = None # Needed to break cycles when an exception occurs. diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 059b41c3..352af488 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -969,6 +969,34 @@ def coroutine_function(): with self.assertRaises(TypeError): self.loop.run_in_executor(None, coroutine_function) + @mock.patch('asyncio.base_events.logger') + def test_log_slow_callbacks(self, m_logger): + def stop_loop_cb(loop): + loop.stop() + + @asyncio.coroutine + def stop_loop_coro(loop): + yield from () + loop.stop() + + asyncio.set_event_loop(self.loop) + self.loop.set_debug(True) + self.loop.slow_callback_duration = 0.0 + + # slow callback + self.loop.call_soon(stop_loop_cb, self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing Handle.*stop_loop_cb.* took .* seconds$") + + # slow task + asyncio.async(stop_loop_coro(self.loop), loop=self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing Task.*stop_loop_coro.* took .* seconds$") + if __name__ == '__main__': unittest.main() From 367335eeaf7a1137da74f2497945a7a02499e90a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 00:02:51 +0200 Subject: [PATCH 1041/1502] BaseEventLoop._assert_is_current_event_loop() now only raises an exception if the current loop is not None. Guido van Rossum wrote: "The behavior that you can set the loop to None (and keep track of it explicitly) is part of the spec, and this should still be supported even in debug mode. The behavior that we raise an error if you are caught having multiple active loops per thread is just a debugging heuristic, and it shouldn't break code that follows the spec." --- asyncio/base_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 9f9067ed..2227a26e 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -327,7 +327,8 @@ def _assert_is_current_event_loop(self): Should only be called when (self._debug == True). The caller is responsible for checking this condition for performance reasons. """ - if events.get_event_loop() is not self: + current = events.get_event_loop() + if current is not None and current is not self: raise RuntimeError( "non-threadsafe operation invoked on an event loop other " "than the current one") From 7ae661e4732ec8d470588f4dd5221338230ca6d1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 00:11:39 +0200 Subject: [PATCH 1042/1502] Enable the debug mode of event loops when the PYTHONASYNCIODEBUG environment variable is set --- asyncio/base_events.py | 3 ++- tests/test_selector_events.py | 2 -- tests/test_subprocess.py | 4 ++-- tests/test_tasks.py | 2 ++ tests/test_unix_events.py | 8 -------- 5 files changed, 6 insertions(+), 13 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 2227a26e..0975bcb6 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -137,7 +137,8 @@ def __init__(self): self._running = False self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self._debug = False + self._debug = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 7c84f03f..35efab97 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -682,8 +682,6 @@ def test_connection_lost(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(3, sys.getrefcount(self.loop), - pprint.pformat(gc.get_referrers(self.loop))) class SelectorSocketTransportTests(test_utils.TestCase): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 3b962bf9..3204d42e 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -141,7 +141,7 @@ def setUp(self): policy = asyncio.get_event_loop_policy() self.loop = policy.new_event_loop() - # ensure that the event loop is passed explicitly in the code + # ensure that the event loop is passed explicitly in asyncio policy.set_event_loop(None) watcher = self.Watcher() @@ -172,7 +172,7 @@ def setUp(self): policy = asyncio.get_event_loop_policy() self.loop = asyncio.ProactorEventLoop() - # ensure that the event loop is passed explicitly in the code + # ensure that the event loop is passed explicitly in asyncio policy.set_event_loop(None) def tearDown(self): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e95c7dcb..3c358a21 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1712,6 +1712,8 @@ def coro(): self.assertIs(fut._loop, self.one_loop) gen1.close() gen2.close() + + self.set_event_loop(self.other_loop, cleanup=False) gen3 = coro() gen4 = coro() fut = asyncio.gather(gen3, gen4, loop=self.other_loop) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 89a4c103..0ade7f21 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -445,8 +445,6 @@ def test__call_connection_lost(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(5, sys.getrefcount(self.loop), - pprint.pformat(gc.get_referrers(self.loop))) def test__call_connection_lost_with_err(self): tr = unix_events._UnixReadPipeTransport( @@ -462,8 +460,6 @@ def test__call_connection_lost_with_err(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(5, sys.getrefcount(self.loop), - pprint.pformat(gc.get_referrers(self.loop))) class UnixWritePipeTransportTests(test_utils.TestCase): @@ -731,8 +727,6 @@ def test__call_connection_lost(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(5, sys.getrefcount(self.loop), - pprint.pformat(gc.get_referrers(self.loop))) def test__call_connection_lost_with_err(self): tr = unix_events._UnixWritePipeTransport( @@ -747,8 +741,6 @@ def test__call_connection_lost_with_err(self): self.assertEqual(2, sys.getrefcount(self.protocol), pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) - self.assertEqual(5, sys.getrefcount(self.loop), - pprint.pformat(gc.get_referrers(self.loop))) def test_close(self): tr = unix_events._UnixWritePipeTransport( From b527b7a854f2c3d4a5611ef07d70016da8a6f2ff Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 00:19:08 +0200 Subject: [PATCH 1043/1502] Add an unit test to check that setting the PYTHONASYNCIODEBUG env var enables debug mode of the event loop. --- tests/test_base_events.py | 24 ++++++++++++++++++++++++ tests/test_tasks.py | 4 ---- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 352af488..b238f4ea 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -7,6 +7,7 @@ import time import unittest from unittest import mock +from test.script_helper import assert_python_ok from test.support import IPV6_ENABLED import asyncio @@ -489,6 +490,29 @@ def custom_handler(loop, context): self.assertIs(type(_context['context']['exception']), ZeroDivisionError) + def test_env_var_debug(self): + code = '\n'.join(( + 'import asyncio', + 'loop = asyncio.get_event_loop()', + 'print(loop.get_debug())')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + class MyProto(asyncio.Protocol): done = None diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 3c358a21..4b55a8af 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1577,11 +1577,7 @@ def test_return_exceptions(self): self.assertEqual(fut.result(), [3, 1, exc, exc2]) def test_env_var_debug(self): - path = os.path.dirname(asyncio.__file__) - path = os.path.normpath(os.path.join(path, '..')) code = '\n'.join(( - 'import sys', - 'sys.path.insert(0, %r)' % path, 'import asyncio.tasks', 'print(asyncio.tasks._DEBUG)')) From fa740d29b3280b7548240525fc48080564b13aaf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 00:30:32 +0200 Subject: [PATCH 1044/1502] Tulip issue #172: only log selector timing in debug mode --- asyncio/base_events.py | 3 +-- tests/test_base_events.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 0975bcb6..2f7f1979 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -834,8 +834,7 @@ def _run_once(self): when = self._scheduled[0]._when timeout = max(0, when - self.time()) - # TODO: Instrumentation only in debug mode? - if logger.isEnabledFor(logging.INFO): + if self._debug: t0 = self.time() event_list = self._selector.select(timeout) dt = self.time() - t0 diff --git a/tests/test_base_events.py b/tests/test_base_events.py index b238f4ea..9fa3e6d2 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -247,6 +247,9 @@ def slow_select(timeout): time.sleep(1.0) return [] + # logging needs debug flag + self.loop.set_debug(True) + # Log to INFO level if timeout > 1.0 sec. self.loop._selector.select = slow_select self.loop._process_events = mock.Mock() From 3af3a872f9dd163d3a231c1ce92a0a8b09649462 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 00:59:37 +0200 Subject: [PATCH 1045/1502] Tulip issue #171: BaseEventLoop.close() now raises an exception if the event loop is running. You must first stop the event loop and then wait until it stopped, before closing it. --- asyncio/base_events.py | 4 ++++ asyncio/proactor_events.py | 2 +- asyncio/selector_events.py | 2 +- asyncio/unix_events.py | 2 +- tests/test_events.py | 9 +++++++++ 5 files changed, 16 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 2f7f1979..42d8b0b4 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -247,7 +247,11 @@ def close(self): This clears the queues and shuts down the executor, but does not wait for the executor to finish. + + The event loop must not be running. """ + if self._running: + raise RuntimeError("cannot close a running event loop") if self._closed: return self._closed = True diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 757a22e8..b76f69ee 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -355,12 +355,12 @@ def _make_write_pipe_transport(self, sock, protocol, waiter=None, def close(self): if self.is_closed(): return + super().close() self._stop_accept_futures() self._close_self_pipe() self._proactor.close() self._proactor = None self._selector = None - super().close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index a62a8e58..df64aece 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -57,11 +57,11 @@ def _make_datagram_transport(self, sock, protocol, def close(self): if self.is_closed(): return + super().close() self._close_self_pipe() if self._selector is not None: self._selector.close() self._selector = None - super().close() def _socketpair(self): raise NotImplementedError diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index acb327d9..ad4c2294 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -44,9 +44,9 @@ def _socketpair(self): return socket.socketpair() def close(self): + super().close() for sig in list(self._signal_handlers): self.remove_signal_handler(sig) - super().close() def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. diff --git a/tests/test_events.py b/tests/test_events.py index 37e45e1d..020d1230 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1365,6 +1365,15 @@ def test_add_fds_after_closing(self): with self.assertRaises(RuntimeError): loop.add_writer(w, callback) + def test_close_running_event_loop(self): + @asyncio.coroutine + def close_loop(loop): + self.loop.close() + + coro = close_loop(self.loop) + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(coro) + class SubprocessTestsMixin: From 31ca5f457cbbcf237e144e94edd7356a20dca733 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 15:13:35 +0200 Subject: [PATCH 1046/1502] Fix BaseEventLoop._assert_is_current_event_loop(): get_event_loop() raises an exception if there is no current loop --- asyncio/base_events.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 42d8b0b4..b1271429 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -332,8 +332,11 @@ def _assert_is_current_event_loop(self): Should only be called when (self._debug == True). The caller is responsible for checking this condition for performance reasons. """ - current = events.get_event_loop() - if current is not None and current is not self: + try: + current = events.get_event_loop() + except AssertionError: + return + if current is not self: raise RuntimeError( "non-threadsafe operation invoked on an event loop other " "than the current one") From d9fe07a21f9d01095ce9859dea4f9e6e93de9abf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 23 Jun 2014 22:21:58 +0200 Subject: [PATCH 1047/1502] update MANIFEST.in --- MANIFEST.in | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 317dcc3d..b647f6ac 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,11 @@ +include AUTHORS COPYING include Makefile -include *.c *.py +include overlapped.c +include check.py runtests.py +include update_stdlib.sh recursive-include examples *.py recursive-include tests *.crt recursive-include tests *.key +recursive-include tests *.pem recursive-include tests *.py From 946a40e4ca5bd2d7eaa32d9f1fb42819a4a52b58 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 23 Jun 2014 13:50:32 -0700 Subject: [PATCH 1048/1502] Make slow_select() test pass on Windows. --- tests/test_base_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9fa3e6d2..773a2848 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -244,7 +244,8 @@ def test_set_debug(self): @mock.patch('asyncio.base_events.logger') def test__run_once_logging(self, m_logger): def slow_select(timeout): - time.sleep(1.0) + # Sleep a bit longer than a second to avoid timer resolution issues. + time.sleep(1.1) return [] # logging needs debug flag From 00fec7848de1061e96d44dc3557735de4473b620 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 24 Jun 2014 22:28:54 +0200 Subject: [PATCH 1049/1502] Log an error if a Task is destroyed while it is still pending, but only on Python 3.4 and newer. --- asyncio/futures.py | 3 +++ asyncio/tasks.py | 13 +++++++++++++ tests/test_tasks.py | 45 ++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index 91ea1706..4edd2e50 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -169,6 +169,9 @@ def __repr__(self): res += '<{}>'.format(self._state) return res + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks to + # the PEP 442. if _PY34: def __del__(self): if not self._log_traceback: diff --git a/asyncio/tasks.py b/asyncio/tasks.py index eaf93f88..f5c10c86 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -32,6 +32,7 @@ _DEBUG = (not sys.flags.ignore_environment and bool(os.environ.get('PYTHONASYNCIODEBUG'))) +_PY34 = (sys.version_info >= (3, 4)) _PY35 = (sys.version_info >= (3, 5)) @@ -181,6 +182,18 @@ def __init__(self, coro, *, loop=None): self._loop.call_soon(self._step) self.__class__._all_tasks.add(self) + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks to + # the PEP 442. + if _PY34: + def __del__(self): + if self._state == futures._PENDING: + self._loop.call_exception_handler({ + 'task': self, + 'message': 'Task was destroyed but it is pending!', + }) + futures.Future.__del__(self) + def __repr__(self): res = super().__repr__() if (self._must_cancel and diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4b55a8af..d770a910 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -5,13 +5,16 @@ import types import unittest import weakref +from test import support from test.script_helper import assert_python_ok +from unittest import mock import asyncio from asyncio import tasks from asyncio import test_utils +PY34 = (sys.version_info >= (3, 4)) PY35 = (sys.version_info >= (3, 5)) @@ -1501,9 +1504,45 @@ def call(arg): def test_corowrapper_weakref(self): wd = weakref.WeakValueDictionary() def foo(): yield from [] - cw = asyncio.tasks.CoroWrapper(foo(), foo) - wd['cw'] = cw # Would fail without __weakref__ slot. - cw.gen = None # Suppress warning from __del__. + + @unittest.skipUnless(PY34, + 'need python 3.4 or later') + def test_log_destroyed_pending_task(self): + @asyncio.coroutine + def kill_me(loop): + future = asyncio.Future(loop=loop) + yield from future + # at this point, the only reference to kill_me() task is + # the Task._wakeup() method in future._callbacks + raise Exception("code never reached") + + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + + # schedule the task + coro = kill_me(self.loop) + task = asyncio.async(coro, loop=self.loop) + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task}) + + # execute the task so it waits for future + self.loop._run_once() + self.assertEqual(len(self.loop._ready), 0) + + # remove the future used in kill_me(), and references to the task + del coro.gi_frame.f_locals['future'] + coro = None + task = None + + # no more reference to kill_me() task: the task is destroyed by the GC + support.gc_collect() + + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), set()) + + mock_handler.assert_called_with(self.loop, { + 'message': 'Task was destroyed but it is pending!', + 'task': mock.ANY, + }) + mock_handler.reset_mock() class GatherTestsBase: From 5e53e04be2398960e8a0b2955557556cd8935c2f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 24 Jun 2014 22:56:14 +0200 Subject: [PATCH 1050/1502] repr(Task) now also contains the line number even if the coroutine is done: use the first line number of the code object instead of the current line number of the generator frame. The name of the coroutine is not enough because many coroutines may have the same name. It's a common case in asyncio tests for example. --- asyncio/tasks.py | 6 ++++-- tests/test_tasks.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index f5c10c86..3b41a21c 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -208,9 +208,11 @@ def __repr__(self): if iscoroutine(coro): filename = coro.gi_code.co_filename if coro.gi_frame is not None: - text += ' at %s:%s' % (filename, coro.gi_frame.f_lineno) + lineno = coro.gi_frame.f_lineno + text += ' at %s:%s' % (filename, lineno) else: - text += ' done at %s' % filename + lineno = coro.gi_code.co_firstlineno + text += ' done at %s:%s' % (filename, lineno) res = res[:i] + '(<{}>)'.format(text) + res[i:] return res diff --git a/tests/test_tasks.py b/tests/test_tasks.py index d770a910..3037f60d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -148,12 +148,14 @@ def notmuch(): self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) self.assertEqual(repr(t), - 'Task()' % filename) + 'Task()' + % (filename, lineno)) t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertEqual(repr(t), - "Task()" % filename) + "Task()" + % (filename, lineno)) def test_task_repr_custom(self): @asyncio.coroutine From daeabaa249ae476233d6aaeeb5bbaa6049556fe0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 25 Jun 2014 21:40:57 +0200 Subject: [PATCH 1051/1502] Tulip issue #177: Rewite repr() of Future, Task, Handle and TimerHandle - Uniformize repr() output to format "" - On Python 3.5+, repr(Task) uses the qualified name instead of the short name of the coroutine --- asyncio/events.py | 56 +++++++++++++++++++--------- asyncio/futures.py | 48 ++++++++++++++++-------- asyncio/tasks.py | 51 +++++++++++++++---------- tests/test_base_events.py | 4 +- tests/test_events.py | 66 +++++++++++++++++++-------------- tests/test_futures.py | 56 +++++++++++++++++++++------- tests/test_tasks.py | 78 +++++++++++++++++++++++---------------- 7 files changed, 231 insertions(+), 128 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index de161df6..40544822 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -18,6 +18,7 @@ _PY34 = sys.version_info >= (3, 4) + def _get_function_source(func): if _PY34: func = inspect.unwrap(func) @@ -33,6 +34,35 @@ def _get_function_source(func): return None +def _format_args(args): + # function formatting ('hello',) as ('hello') + args_repr = repr(args) + if len(args) == 1 and args_repr.endswith(',)'): + args_repr = args_repr[:-2] + ')' + return args_repr + + +def _format_callback(func, args, suffix=''): + if isinstance(func, functools.partial): + if args is not None: + suffix = _format_args(args) + suffix + return _format_callback(func.func, func.args, suffix) + + func_repr = getattr(func, '__qualname__', None) + if not func_repr: + func_repr = repr(func) + + if args is not None: + func_repr += _format_args(args) + if suffix: + func_repr += suffix + + source = _get_function_source(func) + if source: + func_repr += ' at %s:%s' % source + return func_repr + + class Handle: """Object returned by callback registration methods.""" @@ -46,18 +76,11 @@ def __init__(self, callback, args, loop): self._cancelled = False def __repr__(self): - cb_repr = getattr(self._callback, '__qualname__', None) - if not cb_repr: - cb_repr = str(self._callback) - - source = _get_function_source(self._callback) - if source: - cb_repr += ' at %s:%s' % source - - res = 'Handle({}, {})'.format(cb_repr, self._args) + info = [] if self._cancelled: - res += '' - return res + info.append('cancelled') + info.append(_format_callback(self._callback, self._args)) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def cancel(self): self._cancelled = True @@ -88,13 +111,12 @@ def __init__(self, when, callback, args, loop): self._when = when def __repr__(self): - res = 'TimerHandle({}, {}, {})'.format(self._when, - self._callback, - self._args) + info = [] if self._cancelled: - res += '' - - return res + info.append('cancelled') + info.append('when=%s' % self._when) + info.append(_format_callback(self._callback, self._args)) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def __hash__(self): return hash(self._when) diff --git a/asyncio/futures.py b/asyncio/futures.py index 4edd2e50..3103fe11 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -150,24 +150,40 @@ def __init__(self, *, loop=None): self._loop = loop self._callbacks = [] + def _format_callbacks(self): + cb = self._callbacks + size = len(cb) + if not size: + cb = '' + + def format_cb(callback): + return events._format_callback(callback, ()) + + if size == 1: + cb = format_cb(cb[0]) + elif size == 2: + cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + elif size > 2: + cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) + return 'cb=[%s]' % cb + + def _format_result(self): + if self._state != _FINISHED: + return None + elif self._exception is not None: + return 'exception={!r}'.format(self._exception) + else: + return 'result={!r}'.format(self._result) + def __repr__(self): - res = self.__class__.__name__ + info = [self._state.lower()] if self._state == _FINISHED: - if self._exception is not None: - res += ''.format(self._exception) - else: - res += ''.format(self._result) - elif self._callbacks: - size = len(self._callbacks) - if size > 2: - res += '<{}, [{}, <{} more>, {}]>'.format( - self._state, self._callbacks[0], - size-2, self._callbacks[-1]) - else: - res += '<{}, {}>'.format(self._state, self._callbacks) - else: - res += '<{}>'.format(self._state) - return res + info.append(self._format_result()) + if self._callbacks: + info.append(self._format_callbacks()) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) # On Python 3.3 or older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks to diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 3b41a21c..52ca33a8 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -132,6 +132,22 @@ def iscoroutine(obj): return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) +def _format_coroutine(coro): + assert iscoroutine(coro) + if _PY35: + coro_name = coro.__qualname__ + else: + coro_name = coro.__name__ + + filename = coro.gi_code.co_filename + if coro.gi_frame is not None: + lineno = coro.gi_frame.f_lineno + return '%s() at %s:%s' % (coro_name, filename, lineno) + else: + lineno = coro.gi_code.co_firstlineno + return '%s() done at %s:%s' % (coro_name, filename, lineno) + + class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -195,26 +211,21 @@ def __del__(self): futures.Future.__del__(self) def __repr__(self): - res = super().__repr__() - if (self._must_cancel and - self._state == futures._PENDING and - ')'.format(text) + res[i:] - return res + info = [] + if self._must_cancel: + info.append('cancelling') + else: + info.append(self._state.lower()) + + info.append(_format_coroutine(self._coro)) + + if self._state == futures._FINISHED: + info.append(self._format_result()) + + if self._callbacks: + info.append(self._format_callbacks()) + + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def get_stack(self, *, limit=None): """Return the list of stack frames for this task's coroutine. diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 773a2848..0aa7a8d1 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -1016,14 +1016,14 @@ def stop_loop_coro(loop): self.loop.run_forever() fmt, *args = m_logger.warning.call_args[0] self.assertRegex(fmt % tuple(args), - "^Executing Handle.*stop_loop_cb.* took .* seconds$") + "^Executing took .* seconds$") # slow task asyncio.async(stop_loop_coro(self.loop), loop=self.loop) self.loop.run_forever() fmt, *args = m_logger.warning.call_args[0] self.assertRegex(fmt % tuple(args), - "^Executing Task.*stop_loop_coro.* took .* seconds$") + "^Executing took .* seconds$") if __name__ == '__main__': diff --git a/tests/test_events.py b/tests/test_events.py index 020d1230..d3dbd3a6 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1747,7 +1747,7 @@ def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector()) -def noop(): +def noop(*args): pass @@ -1797,50 +1797,52 @@ def test_handle_weakref(self): h = asyncio.Handle(lambda: None, (), self.loop) wd['h'] = h # Would fail without __weakref__ slot. - def test_repr(self): + def test_handle_repr(self): # simple function h = asyncio.Handle(noop, (), self.loop) src = test_utils.get_function_source(noop) self.assertEqual(repr(h), - 'Handle(noop at %s:%s, ())' % src) + '' % src) # cancelled handle h.cancel() self.assertEqual(repr(h), - 'Handle(noop at %s:%s, ())' % src) + '' % src) # decorated function cb = asyncio.coroutine(noop) h = asyncio.Handle(cb, (), self.loop) self.assertEqual(repr(h), - 'Handle(noop at %s:%s, ())' % src) + '' % src) # partial function - cb = functools.partial(noop) - h = asyncio.Handle(cb, (), self.loop) + cb = functools.partial(noop, 1, 2) + h = asyncio.Handle(cb, (3,), self.loop) filename, lineno = src - regex = (r'^Handle\(functools.partial\(' - r'\) at %s:%s, ' - r'\(\)\)$' % (re.escape(filename), lineno)) + regex = (r'^$' + % (re.escape(filename), lineno)) self.assertRegex(repr(h), regex) # partial method if sys.version_info >= (3, 4): - method = HandleTests.test_repr + method = HandleTests.test_handle_repr cb = functools.partialmethod(method) src = test_utils.get_function_source(method) h = asyncio.Handle(cb, (), self.loop) filename, lineno = src - regex = (r'^Handle\(functools.partialmethod\(' - r', , \) at %s:%s, ' - r'\(\)\)$' % (re.escape(filename), lineno)) + cb_regex = r'' + cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex) + regex = (r'^$' + % (cb_regex, re.escape(filename), lineno)) self.assertRegex(repr(h), regex) - class TimerTests(unittest.TestCase): + def setUp(self): + self.loop = mock.Mock() + def test_hash(self): when = time.monotonic() h = asyncio.TimerHandle(when, lambda: False, (), @@ -1858,29 +1860,37 @@ def callback(*args): self.assertIs(h._args, args) self.assertFalse(h._cancelled) - r = repr(h) - self.assertTrue(r.endswith('())')) - + # cancel h.cancel() self.assertTrue(h._cancelled) - r = repr(h) - self.assertTrue(r.endswith('())'), r) + # when cannot be None self.assertRaises(AssertionError, asyncio.TimerHandle, None, callback, args, - mock.Mock()) + self.loop) - def test_timer_comparison(self): - loop = mock.Mock() + def test_timer_repr(self): + # simple function + h = asyncio.TimerHandle(123, noop, (), self.loop) + src = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' % src) + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % src) + + def test_timer_comparison(self): def callback(*args): return args when = time.monotonic() - h1 = asyncio.TimerHandle(when, callback, (), loop) - h2 = asyncio.TimerHandle(when, callback, (), loop) + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when, callback, (), self.loop) # TODO: Use assertLess etc. self.assertFalse(h1 < h2) self.assertFalse(h2 < h1) @@ -1896,8 +1906,8 @@ def callback(*args): h2.cancel() self.assertFalse(h1 == h2) - h1 = asyncio.TimerHandle(when, callback, (), loop) - h2 = asyncio.TimerHandle(when + 10.0, callback, (), loop) + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop) self.assertTrue(h1 < h2) self.assertFalse(h2 < h1) self.assertTrue(h1 <= h2) @@ -1909,7 +1919,7 @@ def callback(*args): self.assertFalse(h1 == h2) self.assertTrue(h1 != h2) - h3 = asyncio.Handle(callback, (), loop) + h3 = asyncio.Handle(callback, (), self.loop) self.assertIs(NotImplemented, h1.__eq__(h3)) self.assertIs(NotImplemented, h1.__ne__(h3)) diff --git a/tests/test_futures.py b/tests/test_futures.py index a230d614..8485a5e2 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -1,6 +1,7 @@ """Tests for futures.py.""" import concurrent.futures +import re import threading import unittest from unittest import mock @@ -12,6 +13,12 @@ def _fakefunc(f): return f +def first_cb(): + pass + +def last_cb(): + pass + class FutureTests(test_utils.TestCase): @@ -95,39 +102,60 @@ def fixture(): # The second "yield from f" does not yield f. self.assertEqual(next(g), ('C', 42)) # yield 'C', y. - def test_repr(self): + def test_future_repr(self): f_pending = asyncio.Future(loop=self.loop) - self.assertEqual(repr(f_pending), 'Future') + self.assertEqual(repr(f_pending), '') f_pending.cancel() f_cancelled = asyncio.Future(loop=self.loop) f_cancelled.cancel() - self.assertEqual(repr(f_cancelled), 'Future') + self.assertEqual(repr(f_cancelled), '') f_result = asyncio.Future(loop=self.loop) f_result.set_result(4) - self.assertEqual(repr(f_result), 'Future') + self.assertEqual(repr(f_result), '') self.assertEqual(f_result.result(), 4) exc = RuntimeError() f_exception = asyncio.Future(loop=self.loop) f_exception.set_exception(exc) - self.assertEqual(repr(f_exception), 'Future') + self.assertEqual(repr(f_exception), '') self.assertIs(f_exception.exception(), exc) - f_few_callbacks = asyncio.Future(loop=self.loop) - f_few_callbacks.add_done_callback(_fakefunc) - self.assertIn('Future' % fake_repr) + f_one_callbacks.cancel() + self.assertEqual(repr(f_one_callbacks), + '') + + f_two_callbacks = asyncio.Future(loop=self.loop) + f_two_callbacks.add_done_callback(first_cb) + f_two_callbacks.add_done_callback(last_cb) + first_repr = func_repr(first_cb) + last_repr = func_repr(last_cb) + self.assertRegex(repr(f_two_callbacks), + r'' + % (first_repr, last_repr)) f_many_callbacks = asyncio.Future(loop=self.loop) - for i in range(20): + f_many_callbacks.add_done_callback(first_cb) + for i in range(8): f_many_callbacks.add_done_callback(_fakefunc) - r = repr(f_many_callbacks) - self.assertIn('Future', r) + f_many_callbacks.add_done_callback(last_cb) + cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr) + self.assertRegex(repr(f_many_callbacks), + r'' % cb_regex) f_many_callbacks.cancel() + self.assertEqual(repr(f_many_callbacks), + '') def test_copy_state(self): # Test the internal _copy_state method since it's being directly diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 3037f60d..78517450 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -26,7 +26,7 @@ def coroutine_function(): class Dummy: def __repr__(self): - return 'Dummy()' + return '' def __call__(self, *args): pass @@ -122,6 +122,7 @@ def notmuch(): yield from [] return 'abc' + # test coroutine function self.assertEqual(notmuch.__name__, 'notmuch') if PY35: self.assertEqual(notmuch.__qualname__, @@ -131,72 +132,87 @@ def notmuch(): filename, lineno = test_utils.get_function_source(notmuch) src = "%s:%s" % (filename, lineno) + # test coroutine object gen = notmuch() + if PY35: + coro_qualname = 'TaskTests.test_task_repr..notmuch' + else: + coro_qualname = 'notmuch' self.assertEqual(gen.__name__, 'notmuch') if PY35: self.assertEqual(gen.__qualname__, - 'TaskTests.test_task_repr..notmuch') + coro_qualname) + # test pending Task t = asyncio.Task(gen, loop=self.loop) t.add_done_callback(Dummy()) + coro = '%s() at %s' % (coro_qualname, src) self.assertEqual(repr(t), - 'Task()' % src) + '()]>' % coro) + # test cancelling Task t.cancel() # Does not take immediate effect! self.assertEqual(repr(t), - 'Task()' % src) + '()]>' % coro) + + # test cancelled Task self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) + coro = '%s() done at %s' % (coro_qualname, src) self.assertEqual(repr(t), - 'Task()' - % (filename, lineno)) + '' % coro) + # test finished Task t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertEqual(repr(t), - "Task()" - % (filename, lineno)) + "" % coro) - def test_task_repr_custom(self): + def test_task_repr_coro_decorator(self): @asyncio.coroutine def notmuch(): - pass + # notmuch() function doesn't use yield from: it will be wrapped by + # @coroutine decorator + return 123 + # test coroutine function self.assertEqual(notmuch.__name__, 'notmuch') - self.assertEqual(notmuch.__module__, __name__) if PY35: self.assertEqual(notmuch.__qualname__, - 'TaskTests.test_task_repr_custom..notmuch') - - class T(asyncio.Future): - def __repr__(self): - return 'T[]' - - class MyTask(asyncio.Task, T): - def __repr__(self): - return super().__repr__() + 'TaskTests.test_task_repr_coro_decorator..notmuch') + self.assertEqual(notmuch.__module__, __name__) + # test coroutine object gen = notmuch() - if PY35 or tasks._DEBUG: + if PY35: # On Python >= 3.5, generators now inherit the name of the # function, as expected, and have a qualified name (__qualname__ - # attribute). In debug mode, @coroutine decorator uses CoroWrapper - # which gets its name (__name__ attribute) from the wrapped - # coroutine function. + # attribute). coro_name = 'notmuch' + coro_qualname = 'TaskTests.test_task_repr_coro_decorator..notmuch' + elif tasks._DEBUG: + # In debug mode, @coroutine decorator uses CoroWrapper which gets + # its name (__name__ attribute) from the wrapped coroutine + # function. + coro_name = coro_qualname = 'notmuch' else: # On Python < 3.5, generators inherit the name of the code, not of # the function. See: http://bugs.python.org/issue21205 - coro_name = 'coro' + coro_name = coro_qualname = 'coro' self.assertEqual(gen.__name__, coro_name) if PY35: - self.assertEqual(gen.__qualname__, - 'TaskTests.test_task_repr_custom..notmuch') + self.assertEqual(gen.__qualname__, coro_qualname) + + # format the coroutine object + code = gen.gi_code + coro = ('%s() at %s:%s' + % (coro_qualname, code.co_filename, code.co_firstlineno)) - t = MyTask(gen, loop=self.loop) - filename = gen.gi_code.co_filename - lineno = gen.gi_frame.f_lineno - self.assertEqual(repr(t), 'T[](<%s at %s:%s>)' % (coro_name, filename, lineno)) + # test pending Task + t = asyncio.Task(gen, loop=self.loop) + t.add_done_callback(Dummy()) + self.assertEqual(repr(t), + '()]>' % coro) def test_task_basics(self): @asyncio.coroutine From ee8b5e4abdaf02701b59cfbd327b2f763752692d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 25 Jun 2014 23:10:27 +0200 Subject: [PATCH 1052/1502] Python issue 21163: Fix some "Task was destroyed but it is pending!" logs in tests --- tests/test_locks.py | 1 + tests/test_queues.py | 27 +++++++++++++++++---------- tests/test_tasks.py | 10 ++++------ 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/test_locks.py b/tests/test_locks.py index 9d50a71f..8ad14863 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -783,6 +783,7 @@ def c4(result): # cleanup locked semaphore sem.release() + self.loop.run_until_complete(t4) def test_acquire_cancel(self): sem = asyncio.Semaphore(loop=self.loop) diff --git a/tests/test_queues.py b/tests/test_queues.py index 32c90f47..3d4ac51d 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -362,16 +362,21 @@ def test(): def test_put_cancelled_race(self): q = asyncio.Queue(loop=self.loop, maxsize=1) - asyncio.Task(q.put('a'), loop=self.loop) - asyncio.Task(q.put('c'), loop=self.loop) - t = asyncio.Task(q.put('b'), loop=self.loop) + put_a = asyncio.Task(q.put('a'), loop=self.loop) + put_b = asyncio.Task(q.put('b'), loop=self.loop) + put_c = asyncio.Task(q.put('X'), loop=self.loop) test_utils.run_briefly(self.loop) - t.cancel() + self.assertTrue(put_a.done()) + self.assertFalse(put_b.done()) + + put_c.cancel() test_utils.run_briefly(self.loop) - self.assertTrue(t.done()) + self.assertTrue(put_c.done()) self.assertEqual(q.get_nowait(), 'a') - self.assertEqual(q.get_nowait(), 'c') + self.assertEqual(q.get_nowait(), 'b') + + self.loop.run_until_complete(put_b) def test_put_with_waiting_getters(self): q = asyncio.Queue(loop=self.loop) @@ -431,18 +436,20 @@ def worker(): @asyncio.coroutine def test(): - for _ in range(2): - asyncio.Task(worker(), loop=self.loop) + tasks = [asyncio.Task(worker(), loop=self.loop) + for index in range(2)] yield from q.join() + return tasks - self.loop.run_until_complete(test()) + tasks = self.loop.run_until_complete(test()) self.assertEqual(sum(range(100)), accumulator) # close running generators running = False - for i in range(2): + for i in range(len(tasks)): q.put_nowait(0) + self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) def test_join_empty_queue(self): q = asyncio.JoinableQueue(loop=self.loop) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 78517450..3a23d721 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1763,16 +1763,14 @@ def coro(): gen2 = coro() fut = asyncio.gather(gen1, gen2) self.assertIs(fut._loop, self.one_loop) - gen1.close() - gen2.close() + self.one_loop.run_until_complete(fut) self.set_event_loop(self.other_loop, cleanup=False) gen3 = coro() gen4 = coro() - fut = asyncio.gather(gen3, gen4, loop=self.other_loop) - self.assertIs(fut._loop, self.other_loop) - gen3.close() - gen4.close() + fut2 = asyncio.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut2._loop, self.other_loop) + self.other_loop.run_until_complete(fut2) def test_duplicate_coroutines(self): @asyncio.coroutine From 0da4591cdd554075f3bc99db05e605e55f66a96d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 25 Jun 2014 23:29:49 +0200 Subject: [PATCH 1053/1502] Add test to check that run_until_complete() checks the loop of the future --- tests/test_base_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 0aa7a8d1..6ad08043 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -288,6 +288,12 @@ def test_run_until_complete_type_error(self): self.assertRaises(TypeError, self.loop.run_until_complete, 'blah') + def test_run_until_complete_loop(self): + task = asyncio.Future(loop=self.loop) + other_loop = self.new_test_loop() + self.assertRaises(ValueError, + other_loop.run_until_complete, task) + def test_subprocess_exec_invalid_args(self): args = [sys.executable, '-c', 'pass'] From cdb1cdaea7a5a508e60087d680a270ab5aa831a2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 25 Jun 2014 23:31:20 +0200 Subject: [PATCH 1054/1502] Python issue 21163: Fix more "Task was destroyed but it is pending!" logs in tests --- tests/test_tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 3a23d721..45089879 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -51,6 +51,7 @@ def notmuch(): self.set_event_loop(loop) t = asyncio.Task(notmuch(), loop=loop) self.assertIs(t._loop, loop) + loop.run_until_complete(t) loop.close() def test_async_coroutine(self): @@ -67,6 +68,7 @@ def notmuch(): self.set_event_loop(loop) t = asyncio.async(notmuch(), loop=loop) self.assertIs(t._loop, loop) + loop.run_until_complete(t) loop.close() def test_async_future(self): @@ -213,6 +215,7 @@ def notmuch(): t.add_done_callback(Dummy()) self.assertEqual(repr(t), '()]>' % coro) + self.loop.run_until_complete(t) def test_task_basics(self): @asyncio.coroutine From 3646b42c18d6c4aa8999b4a865c88071a31e0fbb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 25 Jun 2014 23:47:38 +0200 Subject: [PATCH 1055/1502] Python issue 21163: Fix one more "Task was destroyed but it is pending!" log in tests --- tests/test_tasks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 45089879..b19d7ccc 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -411,8 +411,10 @@ def task(): loop.stop() t = asyncio.Task(task(), loop=loop) - self.assertRaises( - RuntimeError, loop.run_until_complete, t) + with self.assertRaises(RuntimeError) as cm: + loop.run_until_complete(t) + self.assertEqual(str(cm.exception), + 'Event loop stopped before Future completed.') self.assertFalse(t.done()) self.assertEqual(x, 2) self.assertAlmostEqual(0.3, loop.time()) @@ -420,6 +422,8 @@ def task(): # close generators for w in waiters: w.close() + t.cancel() + self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t) def test_wait_for(self): From 55110b453faf301437779a27eba13d1f3ea40144 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 26 Jun 2014 01:35:01 +0200 Subject: [PATCH 1056/1502] Handle error handler: enhance formatting of the callback --- asyncio/events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 40544822..58c6bd53 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -89,8 +89,8 @@ def _run(self): try: self._callback(*self._args) except Exception as exc: - msg = 'Exception in callback {}{!r}'.format(self._callback, - self._args) + cb = _format_callback(self._callback, self._args) + msg = 'Exception in callback {}'.format(cb) self._loop.call_exception_handler({ 'message': msg, 'exception': exc, From b2b052b94c08d630b73b2d5c0e3fdecd24da0267 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Jun 2014 12:20:53 +0200 Subject: [PATCH 1057/1502] MANIFEST.in: add pypi.bat --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index b647f6ac..314325c8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,6 @@ include AUTHORS COPYING include Makefile -include overlapped.c +include overlapped.c pypi.bat include check.py runtests.py include update_stdlib.sh From a670245d509d0e8b38e364df7381d32213e93884 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Jun 2014 12:22:00 +0200 Subject: [PATCH 1058/1502] Oops, restore a removed test --- tests/test_tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index b19d7ccc..8fd3e28f 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1529,6 +1529,9 @@ def call(arg): def test_corowrapper_weakref(self): wd = weakref.WeakValueDictionary() def foo(): yield from [] + cw = asyncio.tasks.CoroWrapper(foo(), foo) + wd['cw'] = cw # Would fail without __weakref__ slot. + cw.gen = None # Suppress warning from __del__. @unittest.skipUnless(PY34, 'need python 3.4 or later') From 3430747ee4ff428c119a7db5ef8615cceb03fae5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Jun 2014 12:28:15 +0200 Subject: [PATCH 1059/1502] Tulip issue #137: In debug mode, add the traceback where the coroutine object was created to the "coroutine ... was never yield from" log --- asyncio/tasks.py | 17 ++++++++++------- tests/test_tasks.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 52ca33a8..89ec3a4c 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -43,6 +43,7 @@ def __init__(self, gen, func): assert inspect.isgenerator(gen), gen self.gen = gen self.func = func + self._source_traceback = traceback.extract_stack(sys._getframe(1)) def __iter__(self): return self @@ -81,13 +82,13 @@ def __del__(self): gen = getattr(self, 'gen', None) frame = getattr(gen, 'gi_frame', None) if frame is not None and frame.f_lasti == -1: - func = self.func - code = func.__code__ - filename = code.co_filename - lineno = code.co_firstlineno - logger.error( - 'Coroutine %r defined at %s:%s was never yielded from', - func.__name__, filename, lineno) + func = events._format_callback(self.func, ()) + tb = ''.join(traceback.format_list(self._source_traceback)) + message = ('Coroutine %s was never yielded from\n' + 'Coroutine object created at (most recent call last):\n' + '%s' + % (func, tb.rstrip())) + logger.error(message) def coroutine(func): @@ -112,6 +113,8 @@ def coro(*args, **kw): @functools.wraps(func) def wrapper(*args, **kwds): w = CoroWrapper(coro(*args, **kwds), func) + if w._source_traceback: + del w._source_traceback[-1] w.__name__ = func.__name__ if _PY35: w.__qualname__ = func.__qualname__ diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 8fd3e28f..c5eb92b8 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,7 @@ """Tests for tasks.py.""" import os.path +import re import sys import types import unittest @@ -1572,6 +1573,37 @@ def kill_me(loop): }) mock_handler.reset_mock() + @mock.patch('asyncio.tasks.logger') + def test_coroutine_never_yielded(self, m_log): + debug = asyncio.tasks._DEBUG + try: + asyncio.tasks._DEBUG = True + @asyncio.coroutine + def coro_noop(): + pass + finally: + asyncio.tasks._DEBUG = debug + + tb_filename = __file__ + tb_lineno = sys._getframe().f_lineno + 1 + coro = coro_noop() + coro = None + support.gc_collect() + + self.assertTrue(m_log.error.called) + message = m_log.error.call_args[0][0] + func_filename, func_lineno = test_utils.get_function_source(coro_noop) + regex = (r'^Coroutine %s\(\) at %s:%s was never yielded from\n' + r'Coroutine object created at \(most recent call last\):\n' + r'.*\n' + r' File "%s", line %s, in test_coroutine_never_yielded\n' + r' coro = coro_noop\(\)$' + % (re.escape(coro_noop.__qualname__), + func_filename, func_lineno, + tb_filename, tb_lineno)) + + self.assertRegex(message, re.compile(regex, re.DOTALL)) + class GatherTestsBase: From d7fa1017f58997f72ff92191828f71c6378f3555 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Jun 2014 13:35:15 +0200 Subject: [PATCH 1060/1502] Tulip issue #137: In debug mode, save traceback where Future, Task and Handle objects are created. Pass the traceback to call_exception_handler() in the 'source_traceback' key. The traceback is truncated to hide internal calls in asyncio, show only the traceback from user code. Add tests for the new source_traceback, and a test for the 'Future/Task exception was never retrieved' log. --- asyncio/base_events.py | 26 +++++++++++++++-- asyncio/events.py | 18 +++++++++--- asyncio/futures.py | 29 ++++++++++++------- asyncio/tasks.py | 14 ++++++++-- tests/test_base_events.py | 9 ++++-- tests/test_events.py | 37 ++++++++++++++++++++++-- tests/test_futures.py | 59 +++++++++++++++++++++++++++++++++++++++ tests/test_tasks.py | 14 ++++++++++ 8 files changed, 180 insertions(+), 26 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b1271429..90115e50 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -21,6 +21,7 @@ import logging import socket import subprocess +import traceback import time import os import sys @@ -290,7 +291,10 @@ def call_later(self, delay, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - return self.call_at(self.time() + delay, callback, *args) + timer = self.call_at(self.time() + delay, callback, *args) + if timer._source_traceback: + del timer._source_traceback[-1] + return timer def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" @@ -299,6 +303,8 @@ def call_at(self, when, callback, *args): if self._debug: self._assert_is_current_event_loop() timer = events.TimerHandle(when, callback, args, self) + if timer._source_traceback: + del timer._source_traceback[-1] heapq.heappush(self._scheduled, timer) return timer @@ -312,7 +318,10 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - return self._call_soon(callback, args, check_loop=True) + handle = self._call_soon(callback, args, check_loop=True) + if handle._source_traceback: + del handle._source_traceback[-1] + return handle def _call_soon(self, callback, args, check_loop): if tasks.iscoroutinefunction(callback): @@ -320,6 +329,8 @@ def _call_soon(self, callback, args, check_loop): if self._debug and check_loop: self._assert_is_current_event_loop() handle = events.Handle(callback, args, self) + if handle._source_traceback: + del handle._source_traceback[-1] self._ready.append(handle) return handle @@ -344,6 +355,8 @@ def _assert_is_current_event_loop(self): def call_soon_threadsafe(self, callback, *args): """Like call_soon(), but thread safe.""" handle = self._call_soon(callback, args, check_loop=False) + if handle._source_traceback: + del handle._source_traceback[-1] self._write_to_self() return handle @@ -757,7 +770,14 @@ def default_exception_handler(self, context): for key in sorted(context): if key in {'message', 'exception'}: continue - log_lines.append('{}: {!r}'.format(key, context[key])) + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + else: + value = repr(value) + log_lines.append('{}: {}'.format(key, value)) logger.error('\n'.join(log_lines), exc_info=exc_info) diff --git a/asyncio/events.py b/asyncio/events.py index 58c6bd53..b389cfb0 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -11,6 +11,7 @@ import functools import inspect import subprocess +import traceback import threading import socket import sys @@ -66,7 +67,8 @@ def _format_callback(func, args, suffix=''): class Handle: """Object returned by callback registration methods.""" - __slots__ = ['_callback', '_args', '_cancelled', '_loop', '__weakref__'] + __slots__ = ('_callback', '_args', '_cancelled', '_loop', + '_source_traceback', '__weakref__') def __init__(self, callback, args, loop): assert not isinstance(callback, Handle), 'A Handle is not a callback' @@ -74,6 +76,10 @@ def __init__(self, callback, args, loop): self._callback = callback self._args = args self._cancelled = False + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + else: + self._source_traceback = None def __repr__(self): info = [] @@ -91,11 +97,14 @@ def _run(self): except Exception as exc: cb = _format_callback(self._callback, self._args) msg = 'Exception in callback {}'.format(cb) - self._loop.call_exception_handler({ + context = { 'message': msg, 'exception': exc, 'handle': self, - }) + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) self = None # Needed to break cycles when an exception occurs. @@ -107,7 +116,8 @@ class TimerHandle(Handle): def __init__(self, when, callback, args, loop): assert when is not None super().__init__(callback, args, loop) - + if self._source_traceback: + del self._source_traceback[-1] self._when = when def __repr__(self): diff --git a/asyncio/futures.py b/asyncio/futures.py index 3103fe11..fcc90d13 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -82,10 +82,11 @@ class itself, but instead to have a reference to a helper object in a discussion about closing files when they are collected. """ - __slots__ = ['exc', 'tb', 'loop'] + __slots__ = ('loop', 'source_traceback', 'exc', 'tb') - def __init__(self, exc, loop): - self.loop = loop + def __init__(self, future, exc): + self.loop = future._loop + self.source_traceback = future._source_traceback self.exc = exc self.tb = None @@ -102,11 +103,12 @@ def clear(self): def __del__(self): if self.tb: - msg = 'Future/Task exception was never retrieved:\n{tb}' - context = { - 'message': msg.format(tb=''.join(self.tb)), - } - self.loop.call_exception_handler(context) + msg = 'Future/Task exception was never retrieved' + if self.source_traceback: + msg += '\nFuture/Task created at (most recent call last):\n' + msg += ''.join(traceback.format_list(self.source_traceback)) + msg += ''.join(self.tb).rstrip() + self.loop.call_exception_handler({'message': msg}) class Future: @@ -149,6 +151,10 @@ def __init__(self, *, loop=None): else: self._loop = loop self._callbacks = [] + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + else: + self._source_traceback = None def _format_callbacks(self): cb = self._callbacks @@ -196,10 +202,13 @@ def __del__(self): return exc = self._exception context = { - 'message': 'Future/Task exception was never retrieved', + 'message': ('%s exception was never retrieved' + % self.__class__.__name__), 'exception': exc, 'future': self, } + if self._source_traceback: + context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) def cancel(self): @@ -335,7 +344,7 @@ def set_exception(self, exception): if _PY34: self._log_traceback = True else: - self._tb_logger = _TracebackLogger(exception, self._loop) + self._tb_logger = _TracebackLogger(self, exception) # Arrange for the logger to be activated after all callbacks # have had a chance to call result() or exception(). self._loop.call_soon(self._tb_logger.activate) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 89ec3a4c..db0bbf3a 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -195,6 +195,8 @@ def all_tasks(cls, loop=None): def __init__(self, coro, *, loop=None): assert iscoroutine(coro), repr(coro) # Not a coroutine function! super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] self._coro = iter(coro) # Use the iterator just in case. self._fut_waiter = None self._must_cancel = False @@ -207,10 +209,13 @@ def __init__(self, coro, *, loop=None): if _PY34: def __del__(self): if self._state == futures._PENDING: - self._loop.call_exception_handler({ + context = { 'task': self, 'message': 'Task was destroyed but it is pending!', - }) + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) futures.Future.__del__(self) def __repr__(self): @@ -620,7 +625,10 @@ def async(coro_or_future, *, loop=None): raise ValueError('loop argument must agree with Future') return coro_or_future elif iscoroutine(coro_or_future): - return Task(coro_or_future, loop=loop) + task = Task(coro_or_future, loop=loop) + if task._source_traceback: + del task._source_traceback[-1] + return task else: raise TypeError('A Future or coroutine is required') diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 6ad08043..adba082b 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -406,19 +406,22 @@ def zero_error(): 1/0 def run_loop(): - self.loop.call_soon(zero_error) + handle = self.loop.call_soon(zero_error) self.loop._run_once() + return handle + self.loop.set_debug(True) self.loop._process_events = mock.Mock() mock_handler = mock.Mock() self.loop.set_exception_handler(mock_handler) - run_loop() + handle = run_loop() mock_handler.assert_called_with(self.loop, { 'exception': MOCK_ANY, 'message': test_utils.MockPattern( 'Exception in callback.*zero_error'), - 'handle': MOCK_ANY, + 'handle': handle, + 'source_traceback': handle._source_traceback, }) mock_handler.reset_mock() diff --git a/tests/test_events.py b/tests/test_events.py index d3dbd3a6..beb6cecf 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1751,10 +1751,11 @@ def noop(*args): pass -class HandleTests(unittest.TestCase): +class HandleTests(test_utils.TestCase): def setUp(self): - self.loop = None + self.loop = mock.Mock() + self.loop.get_debug.return_value = True def test_handle(self): def callback(*args): @@ -1789,7 +1790,8 @@ def callback(): self.loop.call_exception_handler.assert_called_with({ 'message': test_utils.MockPattern('Exception in callback.*'), 'exception': mock.ANY, - 'handle': h + 'handle': h, + 'source_traceback': h._source_traceback, }) def test_handle_weakref(self): @@ -1837,6 +1839,35 @@ def test_handle_repr(self): % (cb_regex, re.escape(filename), lineno)) self.assertRegex(repr(h), regex) + def test_handle_source_traceback(self): + loop = asyncio.get_event_loop_policy().new_event_loop() + loop.set_debug(True) + self.set_event_loop(loop) + + def check_source_traceback(h): + lineno = sys._getframe(1).f_lineno - 1 + self.assertIsInstance(h._source_traceback, list) + self.assertEqual(h._source_traceback[-1][:3], + (__file__, + lineno, + 'test_handle_source_traceback')) + + # call_soon + h = loop.call_soon(noop) + check_source_traceback(h) + + # call_soon_threadsafe + h = loop.call_soon_threadsafe(noop) + check_source_traceback(h) + + # call_later + h = loop.call_later(0, noop) + check_source_traceback(h) + + # call_at + h = loop.call_later(0, noop) + check_source_traceback(h) + class TimerTests(unittest.TestCase): diff --git a/tests/test_futures.py b/tests/test_futures.py index 8485a5e2..ee872615 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -2,8 +2,10 @@ import concurrent.futures import re +import sys import threading import unittest +from test import support from unittest import mock import asyncio @@ -284,6 +286,63 @@ def test_wrap_future_cancel2(self): self.assertEqual(f1.result(), 42) self.assertTrue(f2.cancelled()) + def test_future_source_traceback(self): + self.loop.set_debug(True) + + future = asyncio.Future(loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(future._source_traceback, list) + self.assertEqual(future._source_traceback[-1][:3], + (__file__, + lineno, + 'test_future_source_traceback')) + + @mock.patch('asyncio.base_events.logger') + def test_future_exception_never_retrieved(self, m_log): + self.loop.set_debug(True) + + def memroy_error(): + try: + raise MemoryError() + except BaseException as exc: + return exc + exc = memroy_error() + + future = asyncio.Future(loop=self.loop) + source_traceback = future._source_traceback + future.set_exception(exc) + future = None + test_utils.run_briefly(self.loop) + support.gc_collect() + + if sys.version_info >= (3, 4): + frame = source_traceback[-1] + regex = (r'^Future exception was never retrieved\n' + r'future: \n' + r'source_traceback: Object created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "%s", line %s, in test_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)$' + % (frame[0], frame[1])) + exc_info = (type(exc), exc, exc.__traceback__) + m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) + else: + frame = source_traceback[-1] + regex = (r'^Future/Task exception was never retrieved\n' + r'Future/Task created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "%s", line %s, in test_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + % (frame[0], frame[1])) + m_log.error.assert_called_once_with(mock.ANY, exc_info=False) + message = m_log.error.call_args[0][0] + self.assertRegex(message, re.compile(regex, re.DOTALL)) + class FutureDoneCallbackTests(test_utils.TestCase): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c5eb92b8..54b29ba9 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1546,6 +1546,7 @@ def kill_me(loop): raise Exception("code never reached") mock_handler = mock.Mock() + self.loop.set_debug(True) self.loop.set_exception_handler(mock_handler) # schedule the task @@ -1560,6 +1561,7 @@ def kill_me(loop): # remove the future used in kill_me(), and references to the task del coro.gi_frame.f_locals['future'] coro = None + source_traceback = task._source_traceback task = None # no more reference to kill_me() task: the task is destroyed by the GC @@ -1570,6 +1572,7 @@ def kill_me(loop): mock_handler.assert_called_with(self.loop, { 'message': 'Task was destroyed but it is pending!', 'task': mock.ANY, + 'source_traceback': source_traceback, }) mock_handler.reset_mock() @@ -1604,6 +1607,17 @@ def coro_noop(): self.assertRegex(message, re.compile(regex, re.DOTALL)) + def test_task_source_traceback(self): + self.loop.set_debug(True) + + task = asyncio.Task(coroutine_function(), loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(task._source_traceback, list) + self.assertEqual(task._source_traceback[-1][:3], + (__file__, + lineno, + 'test_task_source_traceback')) + class GatherTestsBase: From 58337795b1e3bb382cddd3b1ea1d5227944f3bcc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 28 Jun 2014 00:11:24 +0200 Subject: [PATCH 1061/1502] Fix unit tests on Windows, escape filenames in regex --- tests/test_futures.py | 8 ++++---- tests/test_tasks.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_futures.py b/tests/test_futures.py index ee872615..96b41d69 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -322,9 +322,9 @@ def memroy_error(): r'source_traceback: Object created at \(most recent call last\):\n' r' File' r'.*\n' - r' File "%s", line %s, in test_future_exception_never_retrieved\n' + r' File "{filename}", line {lineno}, in test_future_exception_never_retrieved\n' r' future = asyncio\.Future\(loop=self\.loop\)$' - % (frame[0], frame[1])) + ).format(filename=re.escape(frame[0]), lineno=frame[1]) exc_info = (type(exc), exc, exc.__traceback__) m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) else: @@ -333,12 +333,12 @@ def memroy_error(): r'Future/Task created at \(most recent call last\):\n' r' File' r'.*\n' - r' File "%s", line %s, in test_future_exception_never_retrieved\n' + r' File "{filename}", line {lineno}, in test_future_exception_never_retrieved\n' r' future = asyncio\.Future\(loop=self\.loop\)\n' r'Traceback \(most recent call last\):\n' r'.*\n' r'MemoryError$' - % (frame[0], frame[1])) + ).format(filename=re.escape(frame[0]), lineno=frame[1]) m_log.error.assert_called_once_with(mock.ANY, exc_info=False) message = m_log.error.call_args[0][0] self.assertRegex(message, re.compile(regex, re.DOTALL)) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 54b29ba9..dee14b2e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1602,8 +1602,8 @@ def coro_noop(): r' File "%s", line %s, in test_coroutine_never_yielded\n' r' coro = coro_noop\(\)$' % (re.escape(coro_noop.__qualname__), - func_filename, func_lineno, - tb_filename, tb_lineno)) + re.escape(func_filename), func_lineno, + re.escape(tb_filename), tb_lineno)) self.assertRegex(message, re.compile(regex, re.DOTALL)) From 5c42e4571edd787ca49c199bce8ef72e60b8685e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 28 Jun 2014 01:18:31 +0200 Subject: [PATCH 1062/1502] Fix two "Coroutine xxx was never yielded from" messages in tests --- tests/test_tasks.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index dee14b2e..b4a3092e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1423,8 +1423,10 @@ def test_as_completed_invalid_args(self): # as_completed() expects a list of futures, not a future instance self.assertRaises(TypeError, self.loop.run_until_complete, asyncio.as_completed(fut, loop=self.loop)) + coro = coroutine_function() self.assertRaises(TypeError, self.loop.run_until_complete, - asyncio.as_completed(coroutine_function(), loop=self.loop)) + asyncio.as_completed(coro, loop=self.loop)) + coro.close() def test_wait_invalid_args(self): fut = asyncio.Future(loop=self.loop) @@ -1432,8 +1434,10 @@ def test_wait_invalid_args(self): # wait() expects a list of futures, not a future instance self.assertRaises(TypeError, self.loop.run_until_complete, asyncio.wait(fut, loop=self.loop)) + coro = coroutine_function() self.assertRaises(TypeError, self.loop.run_until_complete, - asyncio.wait(coroutine_function(), loop=self.loop)) + asyncio.wait(coro, loop=self.loop)) + coro.close() # wait() expects at least a future self.assertRaises(ValueError, self.loop.run_until_complete, From a4f227270e4674ff0514d93127182f5956d95c01 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 29 Jun 2014 00:12:21 +0200 Subject: [PATCH 1063/1502] Move coroutine code in the new module asyncio.coroutines --- asyncio/__init__.py | 4 +- asyncio/base_events.py | 28 ++++---- asyncio/base_subprocess.py | 4 +- asyncio/coroutines.py | 140 ++++++++++++++++++++++++++++++++++++ asyncio/locks.py | 12 ++-- asyncio/streams.py | 18 ++--- asyncio/subprocess.py | 15 ++-- asyncio/tasks.py | 143 +++---------------------------------- asyncio/test_utils.py | 3 +- asyncio/unix_events.py | 8 +-- asyncio/windows_events.py | 11 +-- tests/test_tasks.py | 34 ++++----- 12 files changed, 221 insertions(+), 199 deletions(-) create mode 100644 asyncio/coroutines.py diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 3df2f803..789424e4 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -18,6 +18,7 @@ import _overlapped # Will also be exported. # This relies on each of the submodules having an __all__ variable. +from .coroutines import * from .events import * from .futures import * from .locks import * @@ -34,7 +35,8 @@ from .unix_events import * # pragma: no cover -__all__ = (events.__all__ + +__all__ = (coroutines.__all__ + + events.__all__ + futures.__all__ + locks.__all__ + protocols.__all__ + diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 90115e50..c42e7f98 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -26,9 +26,11 @@ import os import sys +from . import coroutines from . import events from . import futures from . import tasks +from .coroutines import coroutine from .log import logger @@ -118,7 +120,7 @@ def _wakeup(self): if not waiter.done(): waiter.set_result(waiter) - @tasks.coroutine + @coroutine def wait_closed(self): if self.sockets is None or self.waiters is None: return @@ -175,7 +177,7 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, """Create write pipe transport.""" raise NotImplementedError - @tasks.coroutine + @coroutine def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): @@ -298,7 +300,7 @@ def call_later(self, delay, callback, *args): def call_at(self, when, callback, *args): """Like call_later(), but uses an absolute time.""" - if tasks.iscoroutinefunction(callback): + if coroutines.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_at()") if self._debug: self._assert_is_current_event_loop() @@ -324,7 +326,7 @@ def call_soon(self, callback, *args): return handle def _call_soon(self, callback, args, check_loop): - if tasks.iscoroutinefunction(callback): + if coroutines.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_soon()") if self._debug and check_loop: self._assert_is_current_event_loop() @@ -361,7 +363,7 @@ def call_soon_threadsafe(self, callback, *args): return handle def run_in_executor(self, executor, callback, *args): - if tasks.iscoroutinefunction(callback): + if coroutines.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with run_in_executor()") if isinstance(callback, events.Handle): assert not args @@ -389,7 +391,7 @@ def getaddrinfo(self, host, port, *, def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) - @tasks.coroutine + @coroutine def create_connection(self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None): @@ -505,7 +507,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock, protocol_factory, ssl, server_hostname) return transport, protocol - @tasks.coroutine + @coroutine def _create_connection_transport(self, sock, protocol_factory, ssl, server_hostname): protocol = protocol_factory() @@ -521,7 +523,7 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, yield from waiter return transport, protocol - @tasks.coroutine + @coroutine def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, family=0, proto=0, flags=0): @@ -593,7 +595,7 @@ def create_datagram_endpoint(self, protocol_factory, transport = self._make_datagram_transport(sock, protocol, r_addr) return transport, protocol - @tasks.coroutine + @coroutine def create_server(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, @@ -672,7 +674,7 @@ def create_server(self, protocol_factory, host=None, port=None, self._start_serving(protocol_factory, sock, ssl, server) return server - @tasks.coroutine + @coroutine def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) @@ -680,7 +682,7 @@ def connect_read_pipe(self, protocol_factory, pipe): yield from waiter return transport, protocol - @tasks.coroutine + @coroutine def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) @@ -688,7 +690,7 @@ def connect_write_pipe(self, protocol_factory, pipe): yield from waiter return transport, protocol - @tasks.coroutine + @coroutine def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=True, bufsize=0, @@ -706,7 +708,7 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) return transport, protocol - @tasks.coroutine + @coroutine def subprocess_exec(self, protocol_factory, program, *args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index b78f816d..2f933c54 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -2,8 +2,8 @@ import subprocess from . import protocols -from . import tasks from . import transports +from .coroutines import coroutine class BaseSubprocessTransport(transports.SubprocessTransport): @@ -65,7 +65,7 @@ def terminate(self): def kill(self): self._proc.kill() - @tasks.coroutine + @coroutine def _post_init(self): proc = self._proc loop = self._loop diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py new file mode 100644 index 00000000..5b6d93f0 --- /dev/null +++ b/asyncio/coroutines.py @@ -0,0 +1,140 @@ +__all__ = ['coroutine', + 'iscoroutinefunction', 'iscoroutine'] + +import functools +import inspect +import os +import sys +import traceback + +from . import events +from . import futures +from .log import logger + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + +_PY35 = (sys.version_info >= (3, 5)) + +class CoroWrapper: + # Wrapper for coroutine in _DEBUG mode. + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + def send(self, *value): + # We use `*value` because of a bug in CPythons prior + # to 3.4.1. See issue #21209 and test_yield_from_corowrapper + # for details. This workaround should be removed in 3.5.0. + if len(value) == 1: + value = value[0] + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + @property + def gi_frame(self): + return self.gen.gi_frame + + @property + def gi_running(self): + return self.gen.gi_running + + @property + def gi_code(self): + return self.gen.gi_code + + def __del__(self): + # Be careful accessing self.gen.frame -- self.gen might not exist. + gen = getattr(self, 'gen', None) + frame = getattr(gen, 'gi_frame', None) + if frame is not None and frame.f_lasti == -1: + func = events._format_callback(self.func, ()) + tb = ''.join(traceback.format_list(self._source_traceback)) + message = ('Coroutine %s was never yielded from\n' + 'Coroutine object created at (most recent call last):\n' + '%s' + % (func, tb.rstrip())) + logger.error(message) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + if w._source_traceback: + del w._source_traceback[-1] + w.__name__ = func.__name__ + if _PY35: + w.__qualname__ = func.__qualname__ + w.__doc__ = func.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return getattr(func, '_is_coroutine', False) + + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) + + +def _format_coroutine(coro): + assert iscoroutine(coro) + if _PY35: + coro_name = coro.__qualname__ + else: + coro_name = coro.__name__ + + filename = coro.gi_code.co_filename + if coro.gi_frame is not None: + lineno = coro.gi_frame.f_lineno + return '%s() at %s:%s' % (coro_name, filename, lineno) + else: + lineno = coro.gi_code.co_firstlineno + return '%s() done at %s:%s' % (coro_name, filename, lineno) diff --git a/asyncio/locks.py b/asyncio/locks.py index 29c4434a..8d9e3b4d 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -6,7 +6,7 @@ from . import events from . import futures -from . import tasks +from .coroutines import coroutine class _ContextManager: @@ -112,7 +112,7 @@ def locked(self): """Return True if lock is acquired.""" return self._locked - @tasks.coroutine + @coroutine def acquire(self): """Acquire a lock. @@ -225,7 +225,7 @@ def clear(self): to true again.""" self._value = False - @tasks.coroutine + @coroutine def wait(self): """Block until the internal flag is true. @@ -278,7 +278,7 @@ def __repr__(self): extra = '{},waiters:{}'.format(extra, len(self._waiters)) return '<{} [{}]>'.format(res[1:-1], extra) - @tasks.coroutine + @coroutine def wait(self): """Wait until notified. @@ -306,7 +306,7 @@ def wait(self): finally: yield from self.acquire() - @tasks.coroutine + @coroutine def wait_for(self, predicate): """Wait until a predicate becomes true. @@ -402,7 +402,7 @@ def locked(self): """Returns True if semaphore can not be acquired immediately.""" return self._value == 0 - @tasks.coroutine + @coroutine def acquire(self): """Acquire a semaphore. diff --git a/asyncio/streams.py b/asyncio/streams.py index e239248d..a10b969c 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -10,10 +10,12 @@ if hasattr(socket, 'AF_UNIX'): __all__.extend(['open_unix_connection', 'start_unix_server']) +from . import coroutines from . import events from . import futures from . import protocols from . import tasks +from .coroutines import coroutine _DEFAULT_LIMIT = 2**16 @@ -33,7 +35,7 @@ def __init__(self, partial, expected): self.expected = expected -@tasks.coroutine +@coroutine def open_connection(host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -63,7 +65,7 @@ def open_connection(host=None, port=None, *, return reader, writer -@tasks.coroutine +@coroutine def start_server(client_connected_cb, host=None, port=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Start a socket server, call back for each client connected. @@ -102,7 +104,7 @@ def factory(): if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform - @tasks.coroutine + @coroutine def open_unix_connection(path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" @@ -116,7 +118,7 @@ def open_unix_connection(path=None, *, return reader, writer - @tasks.coroutine + @coroutine def start_unix_server(client_connected_cb, path=None, *, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" @@ -210,7 +212,7 @@ def connection_made(self, transport): self._loop) res = self._client_connected_cb(self._stream_reader, self._stream_writer) - if tasks.iscoroutine(res): + if coroutines.iscoroutine(res): tasks.Task(res, loop=self._loop) def connection_lost(self, exc): @@ -373,7 +375,7 @@ def _create_waiter(self, func_name): 'already waiting for incoming data' % func_name) return futures.Future(loop=self._loop) - @tasks.coroutine + @coroutine def readline(self): if self._exception is not None: raise self._exception @@ -410,7 +412,7 @@ def readline(self): self._maybe_resume_transport() return bytes(line) - @tasks.coroutine + @coroutine def read(self, n=-1): if self._exception is not None: raise self._exception @@ -449,7 +451,7 @@ def read(self, n=-1): self._maybe_resume_transport() return data - @tasks.coroutine + @coroutine def readexactly(self, n): if self._exception is not None: raise self._exception diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 414e0238..2cd6de6d 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -8,6 +8,7 @@ from . import protocols from . import streams from . import tasks +from .coroutines import coroutine PIPE = subprocess.PIPE @@ -94,7 +95,7 @@ def __init__(self, transport, protocol, loop): def returncode(self): return self._transport.get_returncode() - @tasks.coroutine + @coroutine def wait(self): """Wait until the process exit and return the process return code.""" returncode = self._transport.get_returncode() @@ -122,17 +123,17 @@ def kill(self): self._check_alive() self._transport.kill() - @tasks.coroutine + @coroutine def _feed_stdin(self, input): self.stdin.write(input) yield from self.stdin.drain() self.stdin.close() - @tasks.coroutine + @coroutine def _noop(self): return None - @tasks.coroutine + @coroutine def _read_stream(self, fd): transport = self._transport.get_pipe_transport(fd) if fd == 2: @@ -144,7 +145,7 @@ def _read_stream(self, fd): transport.close() return output - @tasks.coroutine + @coroutine def communicate(self, input=None): if input: stdin = self._feed_stdin(input) @@ -164,7 +165,7 @@ def communicate(self, input=None): return (stdout, stderr) -@tasks.coroutine +@coroutine def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, loop=None, limit=streams._DEFAULT_LIMIT, **kwds): if loop is None: @@ -178,7 +179,7 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, yield from protocol.waiter return Process(transport, protocol, loop) -@tasks.coroutine +@coroutine def create_subprocess_exec(program, *args, stdin=None, stdout=None, stderr=None, loop=None, limit=streams._DEFAULT_LIMIT, **kwds): diff --git a/asyncio/tasks.py b/asyncio/tasks.py index db0bbf3a..5b8f3eb4 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -1,7 +1,6 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['coroutine', 'Task', - 'iscoroutinefunction', 'iscoroutine', +__all__ = ['Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', 'gather', 'shield', @@ -11,146 +10,20 @@ import functools import inspect import linecache -import os import sys import traceback import weakref +from . import coroutines from . import events from . import futures +from .coroutines import coroutine from .log import logger -# If you set _DEBUG to true, @coroutine will wrap the resulting -# generator objects in a CoroWrapper instance (defined below). That -# instance will log a message when the generator is never iterated -# over, which may happen when you forget to use "yield from" with a -# coroutine call. Note that the value of the _DEBUG flag is taken -# when the decorator is used, so to be of any use it must be set -# before you define your coroutines. A downside of using this feature -# is that tracebacks show entries for the CoroWrapper.__next__ method -# when _DEBUG is true. -_DEBUG = (not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG'))) - _PY34 = (sys.version_info >= (3, 4)) _PY35 = (sys.version_info >= (3, 5)) -class CoroWrapper: - # Wrapper for coroutine in _DEBUG mode. - - def __init__(self, gen, func): - assert inspect.isgenerator(gen), gen - self.gen = gen - self.func = func - self._source_traceback = traceback.extract_stack(sys._getframe(1)) - - def __iter__(self): - return self - - def __next__(self): - return next(self.gen) - - def send(self, *value): - # We use `*value` because of a bug in CPythons prior - # to 3.4.1. See issue #21209 and test_yield_from_corowrapper - # for details. This workaround should be removed in 3.5.0. - if len(value) == 1: - value = value[0] - return self.gen.send(value) - - def throw(self, exc): - return self.gen.throw(exc) - - def close(self): - return self.gen.close() - - @property - def gi_frame(self): - return self.gen.gi_frame - - @property - def gi_running(self): - return self.gen.gi_running - - @property - def gi_code(self): - return self.gen.gi_code - - def __del__(self): - # Be careful accessing self.gen.frame -- self.gen might not exist. - gen = getattr(self, 'gen', None) - frame = getattr(gen, 'gi_frame', None) - if frame is not None and frame.f_lasti == -1: - func = events._format_callback(self.func, ()) - tb = ''.join(traceback.format_list(self._source_traceback)) - message = ('Coroutine %s was never yielded from\n' - 'Coroutine object created at (most recent call last):\n' - '%s' - % (func, tb.rstrip())) - logger.error(message) - - -def coroutine(func): - """Decorator to mark coroutines. - - If the coroutine is not yielded from before it is destroyed, - an error message is logged. - """ - if inspect.isgeneratorfunction(func): - coro = func - else: - @functools.wraps(func) - def coro(*args, **kw): - res = func(*args, **kw) - if isinstance(res, futures.Future) or inspect.isgenerator(res): - res = yield from res - return res - - if not _DEBUG: - wrapper = coro - else: - @functools.wraps(func) - def wrapper(*args, **kwds): - w = CoroWrapper(coro(*args, **kwds), func) - if w._source_traceback: - del w._source_traceback[-1] - w.__name__ = func.__name__ - if _PY35: - w.__qualname__ = func.__qualname__ - w.__doc__ = func.__doc__ - return w - - wrapper._is_coroutine = True # For iscoroutinefunction(). - return wrapper - - -def iscoroutinefunction(func): - """Return True if func is a decorated coroutine function.""" - return getattr(func, '_is_coroutine', False) - - -def iscoroutine(obj): - """Return True if obj is a coroutine object.""" - return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) - - -def _format_coroutine(coro): - assert iscoroutine(coro) - if _PY35: - coro_name = coro.__qualname__ - else: - coro_name = coro.__name__ - - filename = coro.gi_code.co_filename - if coro.gi_frame is not None: - lineno = coro.gi_frame.f_lineno - return '%s() at %s:%s' % (coro_name, filename, lineno) - else: - lineno = coro.gi_code.co_firstlineno - return '%s() done at %s:%s' % (coro_name, filename, lineno) - - class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -193,7 +66,7 @@ def all_tasks(cls, loop=None): return {t for t in cls._all_tasks if t._loop is loop} def __init__(self, coro, *, loop=None): - assert iscoroutine(coro), repr(coro) # Not a coroutine function! + assert coroutines.iscoroutine(coro), repr(coro) # Not a coroutine function! super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] @@ -225,7 +98,7 @@ def __repr__(self): else: info.append(self._state.lower()) - info.append(_format_coroutine(self._coro)) + info.append(coroutines._format_coroutine(self._coro)) if self._state == futures._FINISHED: info.append(self._format_result()) @@ -444,7 +317,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ - if isinstance(fs, futures.Future) or iscoroutine(fs): + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) if not fs: raise ValueError('Set of coroutines/Futures is empty.') @@ -566,7 +439,7 @@ def as_completed(fs, *, loop=None, timeout=None): Note: The futures 'f' are not necessarily members of fs. """ - if isinstance(fs, futures.Future) or iscoroutine(fs): + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() todo = {async(f, loop=loop) for f in set(fs)} @@ -624,7 +497,7 @@ def async(coro_or_future, *, loop=None): if loop is not None and loop is not coro_or_future._loop: raise ValueError('loop argument must agree with Future') return coro_or_future - elif iscoroutine(coro_or_future): + elif coroutines.iscoroutine(coro_or_future): task = Task(coro_or_future, loop=loop) if task._source_traceback: del task._source_traceback[-1] diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index d9c7ae2d..94054e70 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -27,6 +27,7 @@ from . import futures from . import selectors from . import tasks +from .coroutines import coroutine if sys.platform == 'win32': # pragma: no cover @@ -43,7 +44,7 @@ def dummy_ssl_context(): def run_briefly(loop): - @tasks.coroutine + @coroutine def once(): pass gen = once() diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index ad4c2294..1cb70ffa 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -16,8 +16,8 @@ from . import constants from . import events from . import selector_events -from . import tasks from . import transports +from .coroutines import coroutine from .log import logger @@ -147,7 +147,7 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, extra=None): return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) - @tasks.coroutine + @coroutine def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): @@ -164,7 +164,7 @@ def _make_subprocess_transport(self, protocol, args, shell, def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) - @tasks.coroutine + @coroutine def create_unix_connection(self, protocol_factory, path, *, ssl=None, sock=None, server_hostname=None): @@ -199,7 +199,7 @@ def create_unix_connection(self, protocol_factory, path, *, sock, protocol_factory, ssl, server_hostname) return transport, protocol - @tasks.coroutine + @coroutine def create_unix_server(self, protocol_factory, path=None, *, sock=None, backlog=100, ssl=None): if isinstance(ssl, bool): diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 19f25882..93b71b2a 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -14,8 +14,9 @@ from . import selector_events from . import tasks from . import windows_utils -from .log import logger from . import _overlapped +from .coroutines import coroutine +from .log import logger __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', @@ -129,7 +130,7 @@ def __init__(self, proactor=None): def _socketpair(self): return windows_utils.socketpair() - @tasks.coroutine + @coroutine def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) pipe = yield from f @@ -138,7 +139,7 @@ def create_pipe_connection(self, protocol_factory, address): extra={'addr': address}) return trans, protocol - @tasks.coroutine + @coroutine def start_serving_pipe(self, protocol_factory, address): server = PipeServer(address) @@ -172,7 +173,7 @@ def loop(f=None): self.call_soon(loop) return [server] - @tasks.coroutine + @coroutine def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): @@ -258,7 +259,7 @@ def finish_accept(trans, key, ov): conn.settimeout(listener.gettimeout()) return conn, conn.getpeername() - @tasks.coroutine + @coroutine def accept_coro(future, conn): # Coroutine closing the accept socket if the future is cancelled try: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index b4a3092e..d509768b 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -11,7 +11,7 @@ from unittest import mock import asyncio -from asyncio import tasks +from asyncio import coroutines from asyncio import test_utils @@ -193,7 +193,7 @@ def notmuch(): # attribute). coro_name = 'notmuch' coro_qualname = 'TaskTests.test_task_repr_coro_decorator..notmuch' - elif tasks._DEBUG: + elif coroutines._DEBUG: # In debug mode, @coroutine decorator uses CoroWrapper which gets # its name (__name__ attribute) from the wrapped coroutine # function. @@ -1475,23 +1475,23 @@ def coro(): self.assertIsNone(gen.gi_frame) # Save debug flag. - old_debug = asyncio.tasks._DEBUG + old_debug = asyncio.coroutines._DEBUG try: # Test with debug flag cleared. - asyncio.tasks._DEBUG = False + asyncio.coroutines._DEBUG = False check() # Test with debug flag set. - asyncio.tasks._DEBUG = True + asyncio.coroutines._DEBUG = True check() finally: # Restore original debug flag. - asyncio.tasks._DEBUG = old_debug + asyncio.coroutines._DEBUG = old_debug def test_yield_from_corowrapper(self): - old_debug = asyncio.tasks._DEBUG - asyncio.tasks._DEBUG = True + old_debug = asyncio.coroutines._DEBUG + asyncio.coroutines._DEBUG = True try: @asyncio.coroutine def t1(): @@ -1511,7 +1511,7 @@ def t3(f): val = self.loop.run_until_complete(task) self.assertEqual(val, (1, 2, 3)) finally: - asyncio.tasks._DEBUG = old_debug + asyncio.coroutines._DEBUG = old_debug def test_yield_from_corowrapper_send(self): def foo(): @@ -1519,7 +1519,7 @@ def foo(): return a def call(arg): - cw = asyncio.tasks.CoroWrapper(foo(), foo) + cw = asyncio.coroutines.CoroWrapper(foo(), foo) cw.send(None) try: cw.send(arg) @@ -1534,7 +1534,7 @@ def call(arg): def test_corowrapper_weakref(self): wd = weakref.WeakValueDictionary() def foo(): yield from [] - cw = asyncio.tasks.CoroWrapper(foo(), foo) + cw = asyncio.coroutines.CoroWrapper(foo(), foo) wd['cw'] = cw # Would fail without __weakref__ slot. cw.gen = None # Suppress warning from __del__. @@ -1580,16 +1580,16 @@ def kill_me(loop): }) mock_handler.reset_mock() - @mock.patch('asyncio.tasks.logger') + @mock.patch('asyncio.coroutines.logger') def test_coroutine_never_yielded(self, m_log): - debug = asyncio.tasks._DEBUG + debug = asyncio.coroutines._DEBUG try: - asyncio.tasks._DEBUG = True + asyncio.coroutines._DEBUG = True @asyncio.coroutine def coro_noop(): pass finally: - asyncio.tasks._DEBUG = debug + asyncio.coroutines._DEBUG = debug tb_filename = __file__ tb_lineno = sys._getframe().f_lineno + 1 @@ -1695,8 +1695,8 @@ def test_return_exceptions(self): def test_env_var_debug(self): code = '\n'.join(( - 'import asyncio.tasks', - 'print(asyncio.tasks._DEBUG)')) + 'import asyncio.coroutines', + 'print(asyncio.coroutines._DEBUG)')) # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string From f66bbd4bebb3a71bd951250fb5e5fd4b8abeb2a7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 30 Jun 2014 13:43:13 +0200 Subject: [PATCH 1064/1502] Sort imports --- asyncio/base_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index c42e7f98..b3d6e034 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -19,11 +19,11 @@ import heapq import inspect import logging +import os import socket import subprocess -import traceback import time -import os +import traceback import sys from . import coroutines From bd81f3ca80822571a26d679d33489fe49ac55935 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 30 Jun 2014 13:44:02 +0200 Subject: [PATCH 1065/1502] Simplify/optimize iscoroutine() Inline inspect.isgenerator(obj): replace it with isinstance(obj, types.GeneratorType). --- asyncio/coroutines.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 5b6d93f0..e9e78336 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -6,6 +6,7 @@ import os import sys import traceback +import types from . import events from . import futures @@ -119,9 +120,11 @@ def iscoroutinefunction(func): return getattr(func, '_is_coroutine', False) +_COROUTINE_TYPES = (CoroWrapper, types.GeneratorType) + def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return isinstance(obj, CoroWrapper) or inspect.isgenerator(obj) + return isinstance(obj, _COROUTINE_TYPES) def _format_coroutine(coro): From 813f2409043a8de78a26468b3e20c65184e4c464 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 30 Jun 2014 13:55:09 +0200 Subject: [PATCH 1066/1502] CoroWrapper: check at runtime if Python has the yield-from bug #21209 If Python has the bug, check if CoroWrapper.send() was called by yield-from to decide if parameters must be unpacked or not. --- asyncio/coroutines.py | 52 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index e9e78336..cdb1ea8d 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -3,6 +3,7 @@ import functools import inspect +import opcode import os import sys import traceback @@ -12,6 +13,10 @@ from . import futures from .log import logger + +# Opcode of "yield from" instruction +_YIELD_FROM = opcode.opmap['YIELD_FROM'] + # If you set _DEBUG to true, @coroutine will wrap the resulting # generator objects in a CoroWrapper instance (defined below). That # instance will log a message when the generator is never iterated @@ -26,6 +31,31 @@ _PY35 = (sys.version_info >= (3, 5)) + +# Check for CPython issue #21209 +def has_yield_from_bug(): + class MyGen: + def __init__(self): + self.send_args = None + def __iter__(self): + return self + def __next__(self): + return 42 + def send(self, *what): + self.send_args = what + return None + def yield_from_gen(gen): + yield from gen + value = (1, 2, 3) + gen = MyGen() + coro = yield_from_gen(gen) + next(coro) + coro.send(value) + return gen.send_args != (value,) +_YIELD_FROM_BUG = has_yield_from_bug() +del has_yield_from_bug + + class CoroWrapper: # Wrapper for coroutine in _DEBUG mode. @@ -41,13 +71,21 @@ def __iter__(self): def __next__(self): return next(self.gen) - def send(self, *value): - # We use `*value` because of a bug in CPythons prior - # to 3.4.1. See issue #21209 and test_yield_from_corowrapper - # for details. This workaround should be removed in 3.5.0. - if len(value) == 1: - value = value[0] - return self.gen.send(value) + if _YIELD_FROM_BUG: + # For for CPython issue #21209: using "yield from" and a custom + # generator, generator.send(tuple) unpacks the tuple instead of passing + # the tuple unchanged. Check if the caller is a generator using "yield + # from" to decide if the parameter should be unpacked or not. + def send(self, *value): + frame = sys._getframe() + caller = frame.f_back + assert caller.f_lasti >= 0 + if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: + value = value[0] + return self.gen.send(value) + else: + def send(self, value): + return self.gen.send(value) def throw(self, exc): return self.gen.throw(exc) From 78bfe476819a6922cdeadf6f72b740b85fdea62d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 30 Jun 2014 14:02:39 +0200 Subject: [PATCH 1067/1502] Fix "Task was destroyed but it is pending!" warning in test_task_source_traceback() --- tests/test_tasks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index d509768b..a5706ae5 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1621,6 +1621,7 @@ def test_task_source_traceback(self): (__file__, lineno, 'test_task_source_traceback')) + self.loop.run_until_complete(task) class GatherTestsBase: From 06fabb49da8d0069b4067f24ade6dd05bfe09d00 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 30 Jun 2014 14:48:53 +0200 Subject: [PATCH 1068/1502] Python issue #21163: BaseEventLoop.run_until_complete() and test_utils.run_briefly() don't log the "destroy pending task" message anymore. The log is redundant for run_until_complete() and useless in run_briefly(). --- asyncio/base_events.py | 7 +++++++ asyncio/tasks.py | 5 ++++- asyncio/test_utils.py | 3 +++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b3d6e034..2230dc2c 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -227,7 +227,14 @@ def run_until_complete(self, future): Return the Future's result, or raise its exception. """ self._check_closed() + + new_task = not isinstance(future, futures.Future) future = tasks.async(future, loop=self) + if new_task: + # An exception is raised if the future didn't complete, so there + # is no need to log the "destroy pending task" message + future._log_destroy_pending = False + future.add_done_callback(_raise_stop_error) self.run_forever() future.remove_done_callback(_raise_stop_error) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 5b8f3eb4..e9adf1df 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -75,13 +75,16 @@ def __init__(self, coro, *, loop=None): self._must_cancel = False self._loop.call_soon(self._step) self.__class__._all_tasks.add(self) + # If False, don't log a message if the task is destroyed whereas its + # status is still pending + self._log_destroy_pending = True # On Python 3.3 or older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks to # the PEP 442. if _PY34: def __del__(self): - if self._state == futures._PENDING: + if self._state == futures._PENDING and self._log_destroy_pending: context = { 'task': self, 'message': 'Task was destroyed but it is pending!', diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 94054e70..ef3be236 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -49,6 +49,9 @@ def once(): pass gen = once() t = tasks.Task(gen, loop=loop) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False try: loop.run_until_complete(t) finally: From a891ac21ea4b64f2d8542668403f34d39c623b7a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 1 Jul 2014 12:38:21 +0200 Subject: [PATCH 1069/1502] Fix test_sleep_cancel(): call_later() mock has no self parameter --- tests/test_tasks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a5706ae5..c64e1ef5 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -993,9 +993,9 @@ def gen(): handle = None orig_call_later = loop.call_later - def call_later(self, delay, callback, *args): + def call_later(delay, callback, *args): nonlocal handle - handle = orig_call_later(self, delay, callback, *args) + handle = orig_call_later(delay, callback, *args) return handle loop.call_later = call_later From ea6948b37c7e62344c8228b33fdc1b219720eb65 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 1 Jul 2014 20:23:39 +0200 Subject: [PATCH 1070/1502] repr(Task): include also the future the task is waiting for --- asyncio/tasks.py | 3 +++ tests/test_tasks.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index e9adf1df..dd191e77 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -109,6 +109,9 @@ def __repr__(self): if self._callbacks: info.append(self._format_callbacks()) + if self._fut_waiter is not None: + info.append('wait_for=%r' % self._fut_waiter) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def get_stack(self, *, limit=None): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c64e1ef5..83b7e61f 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -218,6 +218,17 @@ def notmuch(): '()]>' % coro) self.loop.run_until_complete(t) + def test_task_repr_wait_for(self): + @asyncio.coroutine + def wait_for(fut): + return (yield from fut) + + fut = asyncio.Future(loop=self.loop) + task = asyncio.Task(wait_for(fut), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertRegex(repr(task), + '' % re.escape(repr(fut))) + def test_task_basics(self): @asyncio.coroutine def outer(): From 6498249322fd73e07a2369eeea30cf610d8aea2e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 3 Jul 2014 00:50:19 +0200 Subject: [PATCH 1071/1502] More reliable CoroWrapper.__del__ If the constructor is interrupted by KeyboardInterrupt or the coroutine objet is destroyed lately, some the _source_traceback attribute doesn't exist anymore. --- asyncio/coroutines.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index cdb1ea8d..71a1ec4d 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -111,12 +111,14 @@ def __del__(self): frame = getattr(gen, 'gi_frame', None) if frame is not None and frame.f_lasti == -1: func = events._format_callback(self.func, ()) - tb = ''.join(traceback.format_list(self._source_traceback)) - message = ('Coroutine %s was never yielded from\n' - 'Coroutine object created at (most recent call last):\n' - '%s' - % (func, tb.rstrip())) - logger.error(message) + msg = 'Coroutine %s was never yielded from' % func + tb = getattr(self, '_source_traceback', ()) + if tb: + tb = ''.join(traceback.format_list(tb)) + msg += ('\nCoroutine object created at ' + '(most recent call last):\n') + msg += tb.rstrip() + logger.error(msg) def coroutine(func): From 7cf3851dc3b0940bb8631fadafbce1cfed6572e2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 2 Jul 2014 01:26:08 +0200 Subject: [PATCH 1072/1502] Examples: close the event loop at exit --- examples/cachesvr.py | 5 ++++- examples/echo_client_tulip.py | 1 + examples/echo_server_tulip.py | 5 ++++- examples/hello_callback.py | 5 ++++- examples/hello_coroutine.py | 5 ++++- examples/shell.py | 1 + examples/stacks.py | 1 + examples/subprocess_attach_read_pipe.py | 1 + examples/subprocess_attach_write_pipe.py | 2 ++ examples/subprocess_shell.py | 1 + 10 files changed, 23 insertions(+), 4 deletions(-) diff --git a/examples/cachesvr.py b/examples/cachesvr.py index ddb79b6f..053f9c21 100644 --- a/examples/cachesvr.py +++ b/examples/cachesvr.py @@ -238,7 +238,10 @@ def main(): svr = loop.run_until_complete(task) for sock in svr.sockets: logging.info('socket %s', sock.getsockname()) - loop.run_forever() + try: + loop.run_forever() + finally: + loop.close() if __name__ == '__main__': diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py index 9b5d29b8..88124efe 100644 --- a/examples/echo_client_tulip.py +++ b/examples/echo_client_tulip.py @@ -17,3 +17,4 @@ def echo_client(): loop = asyncio.get_event_loop() loop.run_until_complete(echo_client()) +loop.close() diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py index c1ccb9df..8167e540 100644 --- a/examples/echo_server_tulip.py +++ b/examples/echo_server_tulip.py @@ -14,4 +14,7 @@ def handle_connection(reader, writer): loop = asyncio.get_event_loop() loop.run_until_complete(echo_server()) -loop.run_forever() +try: + loop.run_forever() +finally: + loop.close() diff --git a/examples/hello_callback.py b/examples/hello_callback.py index df889e55..7ccbea1e 100644 --- a/examples/hello_callback.py +++ b/examples/hello_callback.py @@ -11,4 +11,7 @@ def print_and_repeat(loop): if __name__ == '__main__': loop = asyncio.get_event_loop() print_and_repeat(loop) - loop.run_forever() + try: + loop.run_forever() + finally: + loop.close() diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py index 8ad682d2..b9347aa8 100644 --- a/examples/hello_coroutine.py +++ b/examples/hello_coroutine.py @@ -12,4 +12,7 @@ def greet_every_two_seconds(): if __name__ == '__main__': loop = asyncio.get_event_loop() - loop.run_until_complete(greet_every_two_seconds()) + try: + loop.run_until_complete(greet_every_two_seconds()) + finally: + loop.close() diff --git a/examples/shell.py b/examples/shell.py index 8ae30ca9..7dc7caf3 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -47,3 +47,4 @@ def test_call(*args, timeout=None): loop.run_until_complete(cat(loop)) loop.run_until_complete(ls(loop)) loop.run_until_complete(test_call("bash", "-c", "sleep 3", timeout=1.0)) +loop.close() diff --git a/examples/stacks.py b/examples/stacks.py index 371d31f2..0b7e0b2c 100644 --- a/examples/stacks.py +++ b/examples/stacks.py @@ -33,6 +33,7 @@ def doit(): finally: for t in Task.all_tasks(): t.print_stack() + l.close() def main(): diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py index a692781c..d8a62420 100644 --- a/examples/subprocess_attach_read_pipe.py +++ b/examples/subprocess_attach_read_pipe.py @@ -30,3 +30,4 @@ def task(): print("read = %r" % data.decode()) loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py index 017b827f..86148774 100644 --- a/examples/subprocess_attach_write_pipe.py +++ b/examples/subprocess_attach_write_pipe.py @@ -29,5 +29,7 @@ def task(): stdout, stderr = yield from proc.communicate() print("stdout = %r" % stdout.decode()) + pipe.close() loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py index ca871540..745cb646 100644 --- a/examples/subprocess_shell.py +++ b/examples/subprocess_shell.py @@ -80,6 +80,7 @@ def main(): loop = asyncio.get_event_loop() loop.run_until_complete(start( 'sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) + loop.close() if __name__ == '__main__': From 7ec3e020ce711ff426f2ff66f6c3695f9ddb5ece Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 1 Jul 2014 21:52:29 +0200 Subject: [PATCH 1073/1502] _UnixSubprocessTransport: fix file mode of stdin Open stdin in write mode, not in read mode --- asyncio/unix_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 1cb70ffa..5f728b57 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -494,7 +494,7 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): universal_newlines=False, bufsize=bufsize, **kwargs) if stdin_w is not None: stdin.close() - self._proc.stdin = open(stdin_w.detach(), 'rb', buffering=bufsize) + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) class AbstractChildWatcher: From cfd625f332bbdf45e4427fe8530a34f05f299d4d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 3 Jul 2014 13:42:55 +0200 Subject: [PATCH 1074/1502] Add repr(CoroWrapper) --- asyncio/coroutines.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 71a1ec4d..524fa71c 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -64,6 +64,12 @@ def __init__(self, gen, func): self.gen = gen self.func = func self._source_traceback = traceback.extract_stack(sys._getframe(1)) + # __name__, __qualname__, __doc__ attributes are set by the coroutine() + # decorator + + def __repr__(self): + return ('<%s %s>' + % (self.__class__.__name__, _format_coroutine(self.gen))) def __iter__(self): return self From bdf785ead2a4cdb65a4fd345d40522b04d2a3f25 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 3 Jul 2014 13:50:53 +0200 Subject: [PATCH 1075/1502] Better repr(CoroWrapper); add unit test for repr(CoroWrapper): ensure that the qualified name is used --- asyncio/coroutines.py | 2 +- tests/test_tasks.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 524fa71c..7654a0b9 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -69,7 +69,7 @@ def __init__(self, gen, func): def __repr__(self): return ('<%s %s>' - % (self.__class__.__name__, _format_coroutine(self.gen))) + % (self.__class__.__name__, _format_coroutine(self))) def __iter__(self): return self diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 83b7e61f..eaef05b5 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -211,6 +211,10 @@ def notmuch(): coro = ('%s() at %s:%s' % (coro_qualname, code.co_filename, code.co_firstlineno)) + # test repr(CoroWrapper) + if coroutines._DEBUG: + self.assertEqual(repr(gen), '' % coro) + # test pending Task t = asyncio.Task(gen, loop=self.loop) t.add_done_callback(Dummy()) From db4d37bc0e5a8c4dca7e347374bfb5bafb7fdc02 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 3 Jul 2014 17:23:04 +0200 Subject: [PATCH 1076/1502] Add asyncio.tasks.task_factory variable In the greenio project, Task._step() should not create Task objects but GreenTask to control how tasks are executed. Luca Sbardella already asked this feature for its Pulsar project to support coroutines using yield instead of yield-from. --- asyncio/streams.py | 2 +- asyncio/tasks.py | 5 ++++- asyncio/test_utils.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index a10b969c..90b0e751 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -213,7 +213,7 @@ def connection_made(self, transport): res = self._client_connected_cb(self._stream_reader, self._stream_writer) if coroutines.iscoroutine(res): - tasks.Task(res, loop=self._loop) + tasks.task_factory(res, loop=self._loop) def connection_lost(self, exc): if exc is None: diff --git a/asyncio/tasks.py b/asyncio/tasks.py index dd191e77..3e26d4f2 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -299,6 +299,9 @@ def _wakeup(self, future): self = None # Needed to break cycles when an exception occurs. +task_factory = Task + + # wait() and as_completed() similar to those in PEP 3148. FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED @@ -504,7 +507,7 @@ def async(coro_or_future, *, loop=None): raise ValueError('loop argument must agree with Future') return coro_or_future elif coroutines.iscoroutine(coro_or_future): - task = Task(coro_or_future, loop=loop) + task = task_factory(coro_or_future, loop=loop) if task._source_traceback: del task._source_traceback[-1] return task diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index ef3be236..644413a2 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -48,7 +48,7 @@ def run_briefly(loop): def once(): pass gen = once() - t = tasks.Task(gen, loop=loop) + t = tasks.task_factory(gen, loop=loop) # Don't log a warning if the task is not done after run_until_complete(). # It occurs if the loop is stopped or if a task raises a BaseException. t._log_destroy_pending = False From e81724d15d78ac3aa9ba067fe8639963c0b9fadb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 3 Jul 2014 17:24:21 +0200 Subject: [PATCH 1077/1502] Backed out changeset b288da71fb40 Oops, I wanted to send this patch for review before --- asyncio/streams.py | 2 +- asyncio/tasks.py | 5 +---- asyncio/test_utils.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 90b0e751..a10b969c 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -213,7 +213,7 @@ def connection_made(self, transport): res = self._client_connected_cb(self._stream_reader, self._stream_writer) if coroutines.iscoroutine(res): - tasks.task_factory(res, loop=self._loop) + tasks.Task(res, loop=self._loop) def connection_lost(self, exc): if exc is None: diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 3e26d4f2..dd191e77 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -299,9 +299,6 @@ def _wakeup(self, future): self = None # Needed to break cycles when an exception occurs. -task_factory = Task - - # wait() and as_completed() similar to those in PEP 3148. FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED @@ -507,7 +504,7 @@ def async(coro_or_future, *, loop=None): raise ValueError('loop argument must agree with Future') return coro_or_future elif coroutines.iscoroutine(coro_or_future): - task = task_factory(coro_or_future, loop=loop) + task = Task(coro_or_future, loop=loop) if task._source_traceback: del task._source_traceback[-1] return task diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 644413a2..ef3be236 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -48,7 +48,7 @@ def run_briefly(loop): def once(): pass gen = once() - t = tasks.task_factory(gen, loop=loop) + t = tasks.Task(gen, loop=loop) # Don't log a warning if the task is not done after run_until_complete(). # It occurs if the loop is stopped or if a task raises a BaseException. t._log_destroy_pending = False From 052b194da87f037df01ba32860e07d54cbe81fed Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 5 Jul 2014 15:27:34 +0200 Subject: [PATCH 1078/1502] Python issue 21447, 21886: Fix a race condition when setting the result of a Future with call_soon(). Add an helper, an private method, to set the result only if the future was not cancelled. --- asyncio/futures.py | 6 ++++++ asyncio/proactor_events.py | 2 +- asyncio/queues.py | 2 +- asyncio/selector_events.py | 5 +++-- asyncio/tasks.py | 3 ++- asyncio/unix_events.py | 4 ++-- tests/test_futures.py | 6 ++++++ 7 files changed, 21 insertions(+), 7 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index fcc90d13..022fef76 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -316,6 +316,12 @@ def remove_done_callback(self, fn): # So-called internal methods (note: no set_running_or_notify_cancel()). + def _set_result_unless_cancelled(self, result): + """Helper setting the result only if the future was not cancelled.""" + if self.cancelled(): + return + self.set_result(result) + def set_result(self, result): """Mark the future done and set its result. diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index b76f69ee..a80876f3 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -38,7 +38,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def _set_extra(self, sock): self._extra['pipe'] = sock diff --git a/asyncio/queues.py b/asyncio/queues.py index 57afb053..41551a90 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -173,7 +173,7 @@ def get(self): # run, we need to defer the put for a tick to ensure that # getters and putters alternate perfectly. See # ChannelTest.test_wait. - self._loop.call_soon(putter.set_result, None) + self._loop.call_soon(putter._set_result_unless_cancelled, None) return self._get() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index df64aece..2a170340 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -481,7 +481,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): if self._closing: @@ -690,7 +690,8 @@ def _on_handshake(self): self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: - self._loop.call_soon(self._waiter.set_result, None) + self._loop.call_soon(self._waiter._set_result_unless_cancelled, + None) def pause_reading(self): # XXX This is a bit icky, given the comment at the top of diff --git a/asyncio/tasks.py b/asyncio/tasks.py index dd191e77..8c7217b7 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -487,7 +487,8 @@ def _wait_for_one(): def sleep(delay, result=None, *, loop=None): """Coroutine that completes after a given time (in seconds).""" future = futures.Future(loop=loop) - h = future._loop.call_later(delay, future.set_result, result) + h = future._loop.call_later(delay, + future._set_result_unless_cancelled, result) try: return (yield from future) finally: diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 5f728b57..535ea220 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -269,7 +269,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def _read_ready(self): try: @@ -353,7 +353,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter.set_result, None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): return sum(len(data) for data in self._buffer) diff --git a/tests/test_futures.py b/tests/test_futures.py index 96b41d69..a6071ea7 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -343,6 +343,12 @@ def memroy_error(): message = m_log.error.call_args[0][0] self.assertRegex(message, re.compile(regex, re.DOTALL)) + def test_set_result_unless_cancelled(self): + fut = asyncio.Future(loop=self.loop) + fut.cancel() + fut._set_result_unless_cancelled(2) + self.assertTrue(fut.cancelled()) + class FutureDoneCallbackTests(test_utils.TestCase): From b0edb252b950a76cf7d6b1127a601baac393b0ac Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 7 Jul 2014 01:14:30 +0200 Subject: [PATCH 1079/1502] cleanup iscoroutine() --- asyncio/coroutines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 7654a0b9..48730c22 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -166,11 +166,11 @@ def iscoroutinefunction(func): return getattr(func, '_is_coroutine', False) -_COROUTINE_TYPES = (CoroWrapper, types.GeneratorType) +_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return isinstance(obj, _COROUTINE_TYPES) + return isinstance(obj, _COROUTINE_TYPES) def _format_coroutine(coro): From bde811de3187b02d8e6e77d4b5c624cd3a0eb462 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 7 Jul 2014 17:22:34 +0200 Subject: [PATCH 1080/1502] Tulip issue #181: Faster create_connection() Call directly waiter.set_result() in the constructor of _ProactorBasePipeTransport and _SelectorSocketTransport, instead of using of delaying the call with call_soon(). --- asyncio/proactor_events.py | 2 +- asyncio/selector_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index a80876f3..23545c9e 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -38,7 +38,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter._set_result_unless_cancelled, None) + waiter.set_result(None) def _set_extra(self, sock): self._extra['pipe'] = sock diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 2a170340..628efb75 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -481,7 +481,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - self._loop.call_soon(waiter._set_result_unless_cancelled, None) + waiter.set_result(None) def pause_reading(self): if self._closing: From 7383c9a0eea006cb5397868f9c9f9f7186c9e7ac Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 7 Jul 2014 17:55:46 +0200 Subject: [PATCH 1081/1502] Backed out changeset 9b16831a863a --- asyncio/proactor_events.py | 2 +- asyncio/selector_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 23545c9e..a80876f3 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -38,7 +38,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - waiter.set_result(None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def _set_extra(self, sock): self._extra['pipe'] = sock diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 628efb75..2a170340 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -481,7 +481,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - waiter.set_result(None) + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): if self._closing: From 79352bda1833967bf50cdd8feff01da0d942d3b8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 7 Jul 2014 23:22:52 +0200 Subject: [PATCH 1082/1502] Update AbstractEventLoop: add new event loop methods; update also the unit test --- asyncio/events.py | 4 ++++ tests/test_events.py | 12 ++++++++++++ 2 files changed, 16 insertions(+) diff --git a/asyncio/events.py b/asyncio/events.py index b389cfb0..503102d1 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -200,6 +200,10 @@ def is_running(self): """Return whether the event loop is currently running.""" raise NotImplementedError + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError + def close(self): """Close the loop. diff --git a/tests/test_events.py b/tests/test_events.py index beb6cecf..5957b4f1 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1968,6 +1968,8 @@ def test_not_implemented(self): NotImplementedError, loop.stop) self.assertRaises( NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.is_closed) self.assertRaises( NotImplementedError, loop.close) self.assertRaises( @@ -2027,6 +2029,16 @@ def test_not_implemented(self): mock.sentinel) self.assertRaises( NotImplementedError, loop.subprocess_exec, f) + self.assertRaises( + NotImplementedError, loop.set_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.default_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.call_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.get_debug) + self.assertRaises( + NotImplementedError, loop.set_debug, f) class ProtocolsAbsTests(unittest.TestCase): From 823878c3d680db92a08c464200281ba13df2d317 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 7 Jul 2014 23:28:02 +0200 Subject: [PATCH 1083/1502] fix typo in the name of a test function --- tests/test_futures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_futures.py b/tests/test_futures.py index a6071ea7..157adb7f 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -301,12 +301,12 @@ def test_future_source_traceback(self): def test_future_exception_never_retrieved(self, m_log): self.loop.set_debug(True) - def memroy_error(): + def memory_error(): try: raise MemoryError() except BaseException as exc: return exc - exc = memroy_error() + exc = memory_error() future = asyncio.Future(loop=self.loop) source_traceback = future._source_traceback From 61254737202939a84eb7a14225e63b12c173bf5b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 8 Jul 2014 11:19:35 +0200 Subject: [PATCH 1084/1502] Tulip issue #185: Add a create_task() method to event loops The create_task() method can be overriden in custom event loop to implement their own task class. For example, greenio and Pulsar projects use their own task class. The create_task() method is now preferred over creating directly task using the Task class. --- asyncio/base_events.py | 6 ++++++ asyncio/events.py | 5 +++++ asyncio/streams.py | 2 +- asyncio/tasks.py | 4 +++- asyncio/test_utils.py | 2 +- tests/test_base_events.py | 24 ++++++++++++++++++++++++ tests/test_events.py | 2 ++ 7 files changed, 42 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 2230dc2c..52c5517b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -151,6 +151,12 @@ def __repr__(self): % (self.__class__.__name__, self.is_running(), self.is_closed(), self.get_debug())) + def create_task(self, coro): + """Schedule a coroutine object. + + Return a task object.""" + return tasks.Task(coro, loop=self) + def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): """Create socket transport.""" diff --git a/asyncio/events.py b/asyncio/events.py index 503102d1..1f5e5824 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -229,6 +229,11 @@ def call_at(self, when, callback, *args): def time(self): raise NotImplementedError + # Method scheduling a coroutine object: create a task. + + def create_task(self, coro): + raise NotImplementedError + # Methods for interacting with threads. def call_soon_threadsafe(self, callback, *args): diff --git a/asyncio/streams.py b/asyncio/streams.py index a10b969c..9bde218b 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -213,7 +213,7 @@ def connection_made(self, transport): res = self._client_connected_cb(self._stream_reader, self._stream_writer) if coroutines.iscoroutine(res): - tasks.Task(res, loop=self._loop) + self._loop.create_task(res) def connection_lost(self, exc): if exc is None: diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 8c7217b7..befc2967 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -505,7 +505,9 @@ def async(coro_or_future, *, loop=None): raise ValueError('loop argument must agree with Future') return coro_or_future elif coroutines.iscoroutine(coro_or_future): - task = Task(coro_or_future, loop=loop) + if loop is None: + loop = events.get_event_loop() + task = loop.create_task(coro_or_future) if task._source_traceback: del task._source_traceback[-1] return task diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index ef3be236..6abcaf1d 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -48,7 +48,7 @@ def run_briefly(loop): def once(): pass gen = once() - t = tasks.Task(gen, loop=loop) + t = loop.create_task(gen) # Don't log a warning if the task is not done after run_until_complete(). # It occurs if the loop is stopped or if a task raises a BaseException. t._log_destroy_pending = False diff --git a/tests/test_base_events.py b/tests/test_base_events.py index adba082b..f6da7c37 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -12,6 +12,7 @@ import asyncio from asyncio import base_events +from asyncio import events from asyncio import constants from asyncio import test_utils @@ -526,6 +527,29 @@ def test_env_var_debug(self): PYTHONASYNCIODEBUG='1') self.assertEqual(stdout.rstrip(), b'False') + def test_create_task(self): + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def test(): + pass + + class EventLoop(base_events.BaseEventLoop): + def create_task(self, coro): + return MyTask(coro, loop=loop) + + loop = EventLoop() + self.set_event_loop(loop) + + coro = test() + task = asyncio.async(coro, loop=loop) + self.assertIsInstance(task, MyTask) + + # make warnings quiet + task._log_destroy_pending = False + coro.close() + class MyProto(asyncio.Protocol): done = None diff --git a/tests/test_events.py b/tests/test_events.py index 5957b4f1..b89416fb 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1972,6 +1972,8 @@ def test_not_implemented(self): NotImplementedError, loop.is_closed) self.assertRaises( NotImplementedError, loop.close) + self.assertRaises( + NotImplementedError, loop.create_task, None) self.assertRaises( NotImplementedError, loop.call_later, None, None) self.assertRaises( From d7b6ba1b4703cc3b30c3c4f37f22575792cc8acf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 8 Jul 2014 11:27:59 +0200 Subject: [PATCH 1085/1502] tests: fix a warning --- tests/test_tasks.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index eaef05b5..afadc7c1 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -233,6 +233,9 @@ def wait_for(fut): self.assertRegex(repr(task), '' % re.escape(repr(fut))) + fut.set_result(None) + self.loop.run_until_complete(task) + def test_task_basics(self): @asyncio.coroutine def outer(): From f1070bb5b6d7f4255dc24524645f7be7e173db59 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 8 Jul 2014 23:55:59 +0200 Subject: [PATCH 1086/1502] Tulip issue #181: BaseEventLoop.create_datagram_endpoint() now waits until protocol.connection_made() has been called. Document also why transport constructors use a waiter. --- asyncio/base_events.py | 7 +++++-- asyncio/proactor_events.py | 1 + asyncio/selector_events.py | 13 ++++++++++--- asyncio/unix_events.py | 2 ++ tests/test_events.py | 10 ++++++++++ 5 files changed, 28 insertions(+), 5 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 52c5517b..833f81d4 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -169,7 +169,7 @@ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, raise NotImplementedError def _make_datagram_transport(self, sock, protocol, - address=None, extra=None): + address=None, waiter=None, extra=None): """Create datagram transport.""" raise NotImplementedError @@ -605,7 +605,10 @@ def create_datagram_endpoint(self, protocol_factory, raise exceptions[0] protocol = protocol_factory() - transport = self._make_datagram_transport(sock, protocol, r_addr) + waiter = futures.Future(loop=self) + transport = self._make_datagram_transport(sock, protocol, r_addr, + waiter) + yield from waiter return transport, protocol @coroutine diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index a80876f3..fa247959 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -38,6 +38,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._server.attach(self) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: + # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def _set_extra(self, sock): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 2a170340..7b364ad3 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -51,8 +51,9 @@ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, server_side, server_hostname, extra, server) def _make_datagram_transport(self, sock, protocol, - address=None, extra=None): - return _SelectorDatagramTransport(self, sock, protocol, address, extra) + address=None, waiter=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, + address, waiter, extra) def close(self): if self.is_closed(): @@ -481,6 +482,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: + # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): @@ -690,6 +692,7 @@ def _on_handshake(self): self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if self._waiter is not None: + # wait until protocol.connection_made() has been called self._loop.call_soon(self._waiter._set_result_unless_cancelled, None) @@ -806,11 +809,15 @@ class _SelectorDatagramTransport(_SelectorTransport): _buffer_factory = collections.deque - def __init__(self, loop, sock, protocol, address=None, extra=None): + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): super().__init__(loop, sock, protocol, extra) self._address = address self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): return sum(len(data) for data, _ in self._buffer) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 535ea220..764e719d 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -269,6 +269,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: + # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def _read_ready(self): @@ -353,6 +354,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: + # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): diff --git a/tests/test_events.py b/tests/test_events.py index b89416fb..e5c5729f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -522,6 +522,7 @@ def _basetest_create_connection(self, connection_fut, check_sockname=True): tr, pr = self.loop.run_until_complete(connection_fut) self.assertIsInstance(tr, asyncio.Transport) self.assertIsInstance(pr, asyncio.Protocol) + self.assertIs(pr.transport, tr) if check_sockname: self.assertIsNotNone(tr.get_extra_info('sockname')) self.loop.run_until_complete(pr.done) @@ -1045,12 +1046,21 @@ def datagram_received(self, data, addr): s_transport, server = self.loop.run_until_complete(coro) host, port = s_transport.get_extra_info('sockname') + self.assertIsInstance(s_transport, asyncio.Transport) + self.assertIsInstance(server, TestMyDatagramProto) + self.assertEqual('INITIALIZED', server.state) + self.assertIs(server.transport, s_transport) + coro = self.loop.create_datagram_endpoint( lambda: MyDatagramProto(loop=self.loop), remote_addr=(host, port)) transport, client = self.loop.run_until_complete(coro) + self.assertIsInstance(transport, asyncio.Transport) + self.assertIsInstance(client, MyDatagramProto) self.assertEqual('INITIALIZED', client.state) + self.assertIs(client.transport, transport) + transport.sendto(b'xxx') test_utils.run_until(self.loop, lambda: server.nbytes) self.assertEqual(3, server.nbytes) From 35944e0f10e6a5f12fed7460786a900d12d9eedc Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 9 Jul 2014 10:34:02 -0700 Subject: [PATCH 1087/1502] Correct author name. --- AUTHORS | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/AUTHORS b/AUTHORS index 79acc3d8..d25b4465 100644 --- a/AUTHORS +++ b/AUTHORS @@ -8,6 +8,7 @@ Aymeric Augustin Brett Cannon Charles-François Natali Christian Heimes +Donald Stufft Eli Bendersky Geert Jansen Giampaolo Rodola' @@ -19,9 +20,7 @@ Nikolay Kim Richard Oudkerk Saúl Ibarra Corretgé Serhiy Storchaka -Sonald Stufft Vajrasky Kok Victor Stinner Vladimir Kryachko Yury Selivanov - From b0dc34ef43158115fe62ea0a16874160d4d65f47 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 10 Jul 2014 21:29:48 +0200 Subject: [PATCH 1088/1502] Handle.cancel() now clears references to callback and args In debug mode, repr(Handle) now also contains the location where the Handle was created. --- asyncio/events.py | 18 +++++++++--- tests/test_events.py | 66 ++++++++++++++++++++++++++++++++++++-------- 2 files changed, 68 insertions(+), 16 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 1f5e5824..bddd7e36 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -82,14 +82,20 @@ def __init__(self, callback, args, loop): self._source_traceback = None def __repr__(self): - info = [] + info = [self.__class__.__name__] if self._cancelled: info.append('cancelled') - info.append(_format_callback(self._callback, self._args)) - return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + if self._callback is not None: + info.append(_format_callback(self._callback, self._args)) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return '<%s>' % ' '.join(info) def cancel(self): self._cancelled = True + self._callback = None + self._args = None def _run(self): try: @@ -125,7 +131,11 @@ def __repr__(self): if self._cancelled: info.append('cancelled') info.append('when=%s' % self._when) - info.append(_format_callback(self._callback, self._args)) + if self._callback is not None: + info.append(_format_callback(self._callback, self._args)) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def __hash__(self): diff --git a/tests/test_events.py b/tests/test_events.py index e5c5729f..e04c2876 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1810,27 +1810,30 @@ def test_handle_weakref(self): wd['h'] = h # Would fail without __weakref__ slot. def test_handle_repr(self): + self.loop.get_debug.return_value = False + # simple function - h = asyncio.Handle(noop, (), self.loop) - src = test_utils.get_function_source(noop) + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) self.assertEqual(repr(h), - '' % src) + '' + % (filename, lineno)) # cancelled handle h.cancel() self.assertEqual(repr(h), - '' % src) + '') # decorated function cb = asyncio.coroutine(noop) h = asyncio.Handle(cb, (), self.loop) self.assertEqual(repr(h), - '' % src) + '' + % (filename, lineno)) # partial function cb = functools.partial(noop, 1, 2) h = asyncio.Handle(cb, (3,), self.loop) - filename, lineno = src regex = (r'^$' % (re.escape(filename), lineno)) self.assertRegex(repr(h), regex) @@ -1839,16 +1842,33 @@ def test_handle_repr(self): if sys.version_info >= (3, 4): method = HandleTests.test_handle_repr cb = functools.partialmethod(method) - src = test_utils.get_function_source(method) + filename, lineno = test_utils.get_function_source(method) h = asyncio.Handle(cb, (), self.loop) - filename, lineno = src cb_regex = r'' cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex) regex = (r'^$' % (cb_regex, re.escape(filename), lineno)) self.assertRegex(repr(h), regex) + def test_handle_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % (create_filename, create_lineno)) + def test_handle_source_traceback(self): loop = asyncio.get_event_loop_policy().new_event_loop() loop.set_debug(True) @@ -1894,7 +1914,7 @@ def test_timer(self): def callback(*args): return args - args = () + args = (1, 2, 3) when = time.monotonic() h = asyncio.TimerHandle(when, callback, args, mock.Mock()) self.assertIs(h._callback, callback) @@ -1904,7 +1924,8 @@ def callback(*args): # cancel h.cancel() self.assertTrue(h._cancelled) - + self.assertIsNone(h._callback) + self.assertIsNone(h._args) # when cannot be None self.assertRaises(AssertionError, @@ -1912,6 +1933,8 @@ def callback(*args): self.loop) def test_timer_repr(self): + self.loop.get_debug.return_value = False + # simple function h = asyncio.TimerHandle(123, noop, (), self.loop) src = test_utils.get_function_source(noop) @@ -1921,8 +1944,27 @@ def test_timer_repr(self): # cancelled handle h.cancel() self.assertEqual(repr(h), - '' - % src) + '') + + def test_timer_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.TimerHandle(123, noop, (), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % (create_filename, create_lineno)) + def test_timer_comparison(self): def callback(*args): From 1417aa6877cc37c045ed7cafc5aade94c63c7e9f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 10 Jul 2014 21:49:54 +0200 Subject: [PATCH 1089/1502] Python issues 21936, 21163: Fix sporadic failures of test_future_exception_never_retrieved() --- tests/test_futures.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_futures.py b/tests/test_futures.py index 157adb7f..50e9414a 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -299,6 +299,12 @@ def test_future_source_traceback(self): @mock.patch('asyncio.base_events.logger') def test_future_exception_never_retrieved(self, m_log): + # FIXME: Python issue #21163, other tests may "leak" pending task which + # emit a warning when they are destroyed by the GC + support.gc_collect() + m_log.error.reset_mock() + # --- + self.loop.set_debug(True) def memory_error(): From ade566127fa96e258767f779ca3067f425b31f3d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 10 Jul 2014 23:43:13 +0200 Subject: [PATCH 1090/1502] Fix create_task(): truncate the traceback to hide the call to create_task() --- asyncio/base_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 833f81d4..f6d7a58f 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -155,7 +155,10 @@ def create_task(self, coro): """Schedule a coroutine object. Return a task object.""" - return tasks.Task(coro, loop=self) + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + return task def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): From 629a33e0003c25e8207b2af74e16580f366601d1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 10 Jul 2014 23:43:41 +0200 Subject: [PATCH 1091/1502] repr(Task) and repr(CoroWrapper) now also includes where these objects were created. If the coroutine is not a generator (don't use "yield from"), use the location of the function, not the location of the coro() wrapper. --- asyncio/coroutines.py | 24 +++++++++++++++------ asyncio/tasks.py | 7 ++++++- tests/test_tasks.py | 49 ++++++++++++++++++++++++++++++++++++------- 3 files changed, 66 insertions(+), 14 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 48730c22..4cbfa854 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -57,7 +57,7 @@ def yield_from_gen(gen): class CoroWrapper: - # Wrapper for coroutine in _DEBUG mode. + # Wrapper for coroutine object in _DEBUG mode. def __init__(self, gen, func): assert inspect.isgenerator(gen), gen @@ -68,8 +68,11 @@ def __init__(self, gen, func): # decorator def __repr__(self): - return ('<%s %s>' - % (self.__class__.__name__, _format_coroutine(self))) + coro_repr = _format_coroutine(self) + if self._source_traceback: + frame = self._source_traceback[-1] + coro_repr += ', created at %s:%s' % (frame[0], frame[1]) + return '<%s %s>' % (self.__class__.__name__, coro_repr) def __iter__(self): return self @@ -181,9 +184,18 @@ def _format_coroutine(coro): coro_name = coro.__name__ filename = coro.gi_code.co_filename - if coro.gi_frame is not None: + if (isinstance(coro, CoroWrapper) + and not inspect.isgeneratorfunction(coro.func)): + filename, lineno = events._get_function_source(coro.func) + if coro.gi_frame is None: + coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + else: + coro_repr = '%s() running, defined at %s:%s' % (coro_name, filename, lineno) + elif coro.gi_frame is not None: lineno = coro.gi_frame.f_lineno - return '%s() at %s:%s' % (coro_name, filename, lineno) + coro_repr = '%s() running at %s:%s' % (coro_name, filename, lineno) else: lineno = coro.gi_code.co_firstlineno - return '%s() done at %s:%s' % (coro_name, filename, lineno) + coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + + return coro_repr diff --git a/asyncio/tasks.py b/asyncio/tasks.py index befc2967..61f48223 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -101,7 +101,12 @@ def __repr__(self): else: info.append(self._state.lower()) - info.append(coroutines._format_coroutine(self._coro)) + coro = coroutines._format_coroutine(self._coro) + info.append('coro=<%s>' % coro) + + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) if self._state == futures._FINISHED: info.append(self._format_result()) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index afadc7c1..b13818f7 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -24,6 +24,19 @@ def coroutine_function(): pass +def format_coroutine(qualname, state, src, source_traceback, generator=False): + if generator: + state = '%s' % state + else: + state = '%s, defined' % state + if source_traceback is not None: + frame = source_traceback[-1] + return ('coro=<%s() %s at %s> created at %s:%s' + % (qualname, state, src, frame[0], frame[1])) + else: + return 'coro=<%s() %s at %s>' % (qualname, state, src) + + class Dummy: def __repr__(self): @@ -149,7 +162,9 @@ def notmuch(): # test pending Task t = asyncio.Task(gen, loop=self.loop) t.add_done_callback(Dummy()) - coro = '%s() at %s' % (coro_qualname, src) + + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, generator=True) self.assertEqual(repr(t), '()]>' % coro) @@ -161,13 +176,16 @@ def notmuch(): # test cancelled Task self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) - coro = '%s() done at %s' % (coro_qualname, src) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) self.assertEqual(repr(t), '' % coro) # test finished Task t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) self.assertEqual(repr(t), "" % coro) @@ -206,18 +224,35 @@ def notmuch(): if PY35: self.assertEqual(gen.__qualname__, coro_qualname) - # format the coroutine object - code = gen.gi_code - coro = ('%s() at %s:%s' - % (coro_qualname, code.co_filename, code.co_firstlineno)) - # test repr(CoroWrapper) if coroutines._DEBUG: + # format the coroutine object + if coroutines._DEBUG: + filename, lineno = test_utils.get_function_source(notmuch) + frame = gen._source_traceback[-1] + coro = ('%s() running, defined at %s:%s, created at %s:%s' + % (coro_qualname, filename, lineno, + frame[0], frame[1])) + else: + code = gen.gi_code + coro = ('%s() running at %s:%s' + % (coro_qualname, code.co_filename, code.co_firstlineno)) + self.assertEqual(repr(gen), '' % coro) # test pending Task t = asyncio.Task(gen, loop=self.loop) t.add_done_callback(Dummy()) + + # format the coroutine object + if coroutines._DEBUG: + src = '%s:%s' % test_utils.get_function_source(notmuch) + else: + code = gen.gi_code + src = '%s:%s' % (code.co_filename, code.co_firstlineno) + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, + generator=not coroutines._DEBUG) self.assertEqual(repr(t), '()]>' % coro) self.loop.run_until_complete(t) From 3aac8d7b51495651d831ac1e008a552091e18488 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 11 Jul 2014 01:03:25 +0200 Subject: [PATCH 1092/1502] CoroWrapper.__del__() now reuses repr(CoroWrapper) to log the "... was never yielded from" warning --- asyncio/coroutines.py | 3 +-- tests/test_tasks.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 4cbfa854..3d3a19f4 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -119,8 +119,7 @@ def __del__(self): gen = getattr(self, 'gen', None) frame = getattr(gen, 'gi_frame', None) if frame is not None and frame.f_lasti == -1: - func = events._format_callback(self.func, ()) - msg = 'Coroutine %s was never yielded from' % func + msg = '%r was never yielded from' % self tb = getattr(self, '_source_traceback', ()) if tb: tb = ''.join(traceback.format_list(tb)) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index b13818f7..5029cfa3 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1653,7 +1653,7 @@ def coro_noop(): self.assertTrue(m_log.error.called) message = m_log.error.call_args[0][0] func_filename, func_lineno = test_utils.get_function_source(coro_noop) - regex = (r'^Coroutine %s\(\) at %s:%s was never yielded from\n' + regex = (r'^ was never yielded from\n' r'Coroutine object created at \(most recent call last\):\n' r'.*\n' r' File "%s", line %s, in test_coroutine_never_yielded\n' From f786a201a612a537a063db12f225031b0b0d05ab Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 11 Jul 2014 01:12:58 +0200 Subject: [PATCH 1093/1502] Improve CoroWrapper: copy also the qualified name on Python 3.4, not only on Python 3.5+ --- asyncio/coroutines.py | 9 ++------- asyncio/tasks.py | 1 - tests/test_tasks.py | 9 ++------- 3 files changed, 4 insertions(+), 15 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 3d3a19f4..c28de95a 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -29,8 +29,6 @@ _DEBUG = (not sys.flags.ignore_environment and bool(os.environ.get('PYTHONASYNCIODEBUG'))) -_PY35 = (sys.version_info >= (3, 5)) - # Check for CPython issue #21209 def has_yield_from_bug(): @@ -154,7 +152,7 @@ def wrapper(*args, **kwds): if w._source_traceback: del w._source_traceback[-1] w.__name__ = func.__name__ - if _PY35: + if hasattr(func, '__qualname__'): w.__qualname__ = func.__qualname__ w.__doc__ = func.__doc__ return w @@ -177,10 +175,7 @@ def iscoroutine(obj): def _format_coroutine(coro): assert iscoroutine(coro) - if _PY35: - coro_name = coro.__qualname__ - else: - coro_name = coro.__name__ + coro_name = getattr(coro, '__qualname__', coro.__name__) filename = coro.gi_code.co_filename if (isinstance(coro, CoroWrapper) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 61f48223..3d7e5a43 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -21,7 +21,6 @@ from .log import logger _PY34 = (sys.version_info >= (3, 4)) -_PY35 = (sys.version_info >= (3, 5)) class Task(futures.Future): diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 5029cfa3..85648dc4 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -150,7 +150,7 @@ def notmuch(): # test coroutine object gen = notmuch() - if PY35: + if coroutines._DEBUG or PY35: coro_qualname = 'TaskTests.test_task_repr..notmuch' else: coro_qualname = 'notmuch' @@ -205,17 +205,12 @@ def notmuch(): # test coroutine object gen = notmuch() - if PY35: + if coroutines._DEBUG or PY35: # On Python >= 3.5, generators now inherit the name of the # function, as expected, and have a qualified name (__qualname__ # attribute). coro_name = 'notmuch' coro_qualname = 'TaskTests.test_task_repr_coro_decorator..notmuch' - elif coroutines._DEBUG: - # In debug mode, @coroutine decorator uses CoroWrapper which gets - # its name (__name__ attribute) from the wrapped coroutine - # function. - coro_name = coro_qualname = 'notmuch' else: # On Python < 3.5, generators inherit the name of the code, not of # the function. See: http://bugs.python.org/issue21205 From 9703e629cdf169184acd47fa419b499bde73a6a3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 11 Jul 2014 11:49:46 +0200 Subject: [PATCH 1094/1502] Fix some pyflakes warnings: remove unused imports --- asyncio/streams.py | 1 - asyncio/tasks.py | 1 - tests/test_base_events.py | 1 - tests/test_events.py | 2 +- tests/test_tasks.py | 9 ++++----- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 9bde218b..9b654cdb 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -14,7 +14,6 @@ from . import events from . import futures from . import protocols -from . import tasks from .coroutines import coroutine diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 3d7e5a43..78b4c4dc 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -18,7 +18,6 @@ from . import events from . import futures from .coroutines import coroutine -from .log import logger _PY34 = (sys.version_info >= (3, 4)) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index f6da7c37..8155beb6 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -12,7 +12,6 @@ import asyncio from asyncio import base_events -from asyncio import events from asyncio import constants from asyncio import test_utils diff --git a/tests/test_events.py b/tests/test_events.py index e04c2876..06552f87 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -715,7 +715,7 @@ def test_create_unix_server_path_socket_error(self): with self.assertRaisesRegex(ValueError, 'path and sock can not be specified ' 'at the same time'): - server = self.loop.run_until_complete(f) + self.loop.run_until_complete(f) def _create_ssl_context(self, certfile, keyfile=None): sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 85648dc4..ca770f90 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,5 @@ """Tests for tasks.py.""" -import os.path import re import sys import types @@ -1640,9 +1639,9 @@ def coro_noop(): asyncio.coroutines._DEBUG = debug tb_filename = __file__ - tb_lineno = sys._getframe().f_lineno + 1 - coro = coro_noop() - coro = None + tb_lineno = sys._getframe().f_lineno + 2 + # create a coroutine object but don't use it + coro_noop() support.gc_collect() self.assertTrue(m_log.error.called) @@ -1652,7 +1651,7 @@ def coro_noop(): r'Coroutine object created at \(most recent call last\):\n' r'.*\n' r' File "%s", line %s, in test_coroutine_never_yielded\n' - r' coro = coro_noop\(\)$' + r' coro_noop\(\)$' % (re.escape(coro_noop.__qualname__), re.escape(func_filename), func_lineno, re.escape(tb_filename), tb_lineno)) From 39a5c7cb41d05bc21acf4bee5dc6410dab996925 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 11 Jul 2014 11:54:05 +0200 Subject: [PATCH 1095/1502] Tulip issue #182: Improve logs of BaseEventLoop._run_once() - Don't log non-blocking poll - Only log polling with a timeout if it gets events or if it timed out after more than 1 second. --- asyncio/base_events.py | 21 ++++++++++++++------- tests/test_base_events.py | 1 + 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index f6d7a58f..3951fb75 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -882,19 +882,26 @@ def _run_once(self): when = self._scheduled[0]._when timeout = max(0, when - self.time()) - if self._debug: + if self._debug and timeout != 0: t0 = self.time() event_list = self._selector.select(timeout) dt = self.time() - t0 - if dt >= 1: + if dt >= 1.0: level = logging.INFO else: level = logging.DEBUG - if timeout is not None: - logger.log(level, 'poll %.3f took %.3f seconds', - timeout, dt) - else: - logger.log(level, 'poll took %.3f seconds', dt) + nevent = len(event_list) + if timeout is None: + logger.log(level, 'poll took %.3f ms: %s events', + dt * 1e3, nevent) + elif nevent: + logger.log(level, + 'poll %.3f ms took %.3f ms: %s events', + timeout * 1e3, dt * 1e3, nevent) + elif dt >= 1.0: + logger.log(level, + 'poll %.3f ms took %.3f ms: timeout', + timeout * 1e3, dt * 1e3) else: event_list = self._selector.select(timeout) self._process_events(event_list) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 8155beb6..27610f0d 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -25,6 +25,7 @@ class BaseEventLoopTests(test_utils.TestCase): def setUp(self): self.loop = base_events.BaseEventLoop() self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () self.set_event_loop(self.loop) def test_not_implemented(self): From 5a8250229ea43078dda647053ddd4baf19556b99 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 11 Jul 2014 22:51:49 +0200 Subject: [PATCH 1096/1502] Tulip issue #180: Make Server attributes and methods private - loop, waiters and active_count attributes are now private - attach(), detach() and wakeup() methods are now private The sockets attribute remains public. --- asyncio/base_events.py | 41 +++++++++++++++++++------------------- asyncio/proactor_events.py | 4 ++-- asyncio/selector_events.py | 4 ++-- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 3951fb75..10996d23 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -89,43 +89,44 @@ def _raise_stop_error(*args): class Server(events.AbstractServer): def __init__(self, loop, sockets): - self.loop = loop + self._loop = loop self.sockets = sockets - self.active_count = 0 - self.waiters = [] + self._active_count = 0 + self._waiters = [] - def attach(self, transport): + def _attach(self): assert self.sockets is not None - self.active_count += 1 + self._active_count += 1 - def detach(self, transport): - assert self.active_count > 0 - self.active_count -= 1 - if self.active_count == 0 and self.sockets is None: + def _detach(self): + assert self._active_count > 0 + self._active_count -= 1 + if self._active_count == 0 and self.sockets is None: self._wakeup() def close(self): sockets = self.sockets - if sockets is not None: - self.sockets = None - for sock in sockets: - self.loop._stop_serving(sock) - if self.active_count == 0: - self._wakeup() + if sockets is None: + return + self.sockets = None + for sock in sockets: + self._loop._stop_serving(sock) + if self._active_count == 0: + self._wakeup() def _wakeup(self): - waiters = self.waiters - self.waiters = None + waiters = self._waiters + self._waiters = None for waiter in waiters: if not waiter.done(): waiter.set_result(waiter) @coroutine def wait_closed(self): - if self.sockets is None or self.waiters is None: + if self.sockets is None or self._waiters is None: return - waiter = futures.Future(loop=self.loop) - self.waiters.append(waiter) + waiter = futures.Future(loop=self._loop) + self._waiters.append(waiter) yield from waiter diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index fa247959..d0b601d7 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -35,7 +35,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._closing = False # Set when close() called. self._eof_written = False if self._server is not None: - self._server.attach(self) + self._server._attach() self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: # wait until protocol.connection_made() has been called @@ -91,7 +91,7 @@ def _call_connection_lost(self, exc): self._sock.close() server = self._server if server is not None: - server.detach(self) + server._detach() self._server = None def get_write_buffer_size(self): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7b364ad3..b9650468 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -417,7 +417,7 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._conn_lost = 0 # Set when call to connection_lost scheduled. self._closing = False # Set when close() called. if self._server is not None: - self._server.attach(self) + self._server._attach() def abort(self): self._force_close(None) @@ -464,7 +464,7 @@ def _call_connection_lost(self, exc): self._loop = None server = self._server if server is not None: - server.detach(self) + server._detach() self._server = None def get_write_buffer_size(self): From af6c1a35b31da3eda0f674c872eacb2bc45a578b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 12 Jul 2014 00:15:28 +0200 Subject: [PATCH 1097/1502] BaseEventLoop.create_server() returns a Server object --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 10996d23..2f7c124a 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -626,7 +626,7 @@ def create_server(self, protocol_factory, host=None, port=None, reuse_address=None): """Create a TCP server bound to host and port. - Return an AbstractServer object which can be used to stop the service. + Return an Server object which can be used to stop the service. This method is a coroutine. """ From feb383aea1e555fa642d576ac06b92cbc975f69d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 12 Jul 2014 03:07:14 +0200 Subject: [PATCH 1098/1502] Fix ProactorEventLoop() in debug mode ProactorEventLoop._make_self_pipe() doesn't call call_soon() directly because it checks for the current loop which fails, because the method is called to build the event loop. --- asyncio/proactor_events.py | 4 +++- tests/test_proactor_events.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index d0b601d7..8e718838 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -401,7 +401,9 @@ def _make_self_pipe(self): self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 - self.call_soon(self._loop_self_reading) + # don't check the current loop because _make_self_pipe() is called + # from the event loop constructor + self._call_soon(self._loop_self_reading, (), check_loop=False) def _loop_self_reading(self, f=None): try: diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index ddfceae1..4bb4f0b3 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -358,16 +358,17 @@ def _socketpair(s): self.loop = EventLoop(self.proactor) self.set_event_loop(self.loop, cleanup=False) - @mock.patch.object(BaseProactorEventLoop, 'call_soon') + @mock.patch.object(BaseProactorEventLoop, '_call_soon') @mock.patch.object(BaseProactorEventLoop, '_socketpair') - def test_ctor(self, socketpair, call_soon): + def test_ctor(self, socketpair, _call_soon): ssock, csock = socketpair.return_value = ( mock.Mock(), mock.Mock()) loop = BaseProactorEventLoop(self.proactor) self.assertIs(loop._ssock, ssock) self.assertIs(loop._csock, csock) self.assertEqual(loop._internal_fds, 1) - call_soon.assert_called_with(loop._loop_self_reading) + _call_soon.assert_called_with(loop._loop_self_reading, (), + check_loop=False) def test_close_self_pipe(self): self.loop._close_self_pipe() From 58f42f896bc77aad41385daf2c2c28558b7b1438 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 12 Jul 2014 02:36:36 +0200 Subject: [PATCH 1099/1502] Cleanup _ProactorReadPipeTransport constructor Not need to set again _read_fut attribute to None, it is already done in the base class. --- asyncio/proactor_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 8e718838..5009b0d3 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -108,7 +108,6 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): super().__init__(loop, sock, protocol, waiter, extra, server) - self._read_fut = None self._paused = False self._loop.call_soon(self._loop_reading) From b809f1bb8e8c94cfc3d9fd0c61977614053d9c27 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 12 Jul 2014 03:07:24 +0200 Subject: [PATCH 1100/1502] Tulip issue #183: log socket events in debug mode - Log most important socket events: socket connected, new client, connection reset or closed by peer (EOF), etc. - Log time elapsed in DNS resolution (getaddrinfo) - Log pause/resume reading - Log time of SSL handshake - Log SSL handshake errors - Add a __repr__() method to many classes --- asyncio/base_events.py | 54 ++++++++++++++++++++- asyncio/proactor_events.py | 31 +++++++++++- asyncio/selector_events.py | 89 ++++++++++++++++++++++++++++++----- asyncio/unix_events.py | 36 ++++++++++++++ asyncio/windows_events.py | 12 +++++ tests/test_selector_events.py | 12 ++--- 6 files changed, 212 insertions(+), 22 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 2f7c124a..e5683fd1 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -94,6 +94,9 @@ def __init__(self, loop, sockets): self._active_count = 0 self._waiters = [] + def __repr__(self): + return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + def _attach(self): assert self.sockets is not None self._active_count += 1 @@ -274,6 +277,8 @@ def close(self): raise RuntimeError("cannot close a running event loop") if self._closed: return + if self._debug: + logger.debug("Close %r", self) self._closed = True self._ready.clear() self._scheduled.clear() @@ -400,10 +405,39 @@ def run_in_executor(self, executor, callback, *args): def set_default_executor(self, executor): self._default_executor = executor + def _getaddrinfo_debug(self, host, port, family, type, proto, flags): + msg = ["%s:%r" % (host, port)] + if family: + msg.append('family=%r' % family) + if type: + msg.append('type=%r' % type) + if proto: + msg.append('proto=%r' % proto) + if flags: + msg.append('flags=%r' % flags) + msg = ', '.join(msg) + logger.debug('Get addresss info %s', msg) + + t0 = self.time() + addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) + dt = self.time() - t0 + + msg = ('Getting addresss info %s took %.3f ms: %r' + % (msg, dt * 1e3, addrinfo)) + if dt >= self.slow_callback_duration: + logger.info(msg) + else: + logger.debug(msg) + return addrinfo + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - return self.run_in_executor(None, socket.getaddrinfo, - host, port, family, type, proto, flags) + if self._debug: + return self.run_in_executor(None, self._getaddrinfo_debug, + host, port, family, type, proto, flags) + else: + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) @@ -490,6 +524,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.close() sock = None continue + if self._debug: + logger.debug("connect %r to %r", sock, address) yield from self.sock_connect(sock, address) except OSError as exc: if sock is not None: @@ -522,6 +558,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) + if self._debug: + logger.debug("connected to %s:%r: (%r, %r)", + host, port, transport, protocol) return transport, protocol @coroutine @@ -612,6 +651,15 @@ def create_datagram_endpoint(self, protocol_factory, waiter = futures.Future(loop=self) transport = self._make_datagram_transport(sock, protocol, r_addr, waiter) + if self._debug: + if local_addr: + logger.info("Datagram endpoint local_addr=%r remote_addr=%r " + "created: (%r, %r)", + local_addr, remote_addr, transport, protocol) + else: + logger.debug("Datagram endpoint remote_addr=%r created: " + "(%r, %r)", + remote_addr, transport, protocol) yield from waiter return transport, protocol @@ -692,6 +740,8 @@ def create_server(self, protocol_factory, host=None, port=None, sock.listen(backlog) sock.setblocking(False) self._start_serving(protocol_factory, sock, ssl, server) + if self._debug: + logger.info("%r is serving", server) return server @coroutine diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 5009b0d3..d09e9faa 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -41,6 +41,23 @@ def __init__(self, loop, sock, protocol, waiter=None, # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._sock.fileno()] + if self._read_fut is not None: + ov = "pending" if self._read_fut.ov.pending else "completed" + info.append('read=%s' % ov) + if self._write_fut is not None: + if self._write_fut.ov.pending: + info.append("write=pending=%s" % self._pending_write) + else: + info.append("write=completed") + if self._buffer: + bufsize = len(self._buffer) + info.append('write_bufsize=%s' % bufsize) + if self._eof_written: + info.append('EOF written') + return '<%s>' % ' '.join(info) + def _set_extra(self, sock): self._extra['pipe'] = sock @@ -55,7 +72,10 @@ def close(self): self._read_fut.cancel() def _fatal_error(self, exc, message='Fatal error on pipe transport'): - if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, @@ -117,6 +137,8 @@ def pause_reading(self): if self._paused: raise RuntimeError('Already paused') self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) def resume_reading(self): if not self._paused: @@ -125,6 +147,8 @@ def resume_reading(self): if self._closing: return self._loop.call_soon(self._loop_reading, self._read_fut) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) def _loop_reading(self, fut=None): if self._paused: @@ -165,6 +189,8 @@ def _loop_reading(self, fut=None): if data: self._protocol.data_received(data) elif data is not None: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) keep_open = self._protocol.eof_received() if not keep_open: self.close() @@ -427,6 +453,9 @@ def loop(f=None): try: if f is not None: conn, addr = f.result() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) protocol = protocol_factory() self._make_socket_transport( conn, protocol, diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index b9650468..d79c0801 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -23,6 +23,17 @@ from .log import logger +def _test_selector_event(selector, fd, event): + # Test if the selector is monitoring 'event' events + # for the file descriptor 'fd'. + try: + key = selector.get_key(fd) + except KeyError: + return False + else: + return bool(key.events & event) + + class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. @@ -116,6 +127,9 @@ def _accept_connection(self, protocol_factory, sock, sslcontext=None, server=None): try: conn, addr = sock.accept() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) conn.setblocking(False) except (BlockingIOError, InterruptedError, ConnectionAbortedError): pass # False alarm. @@ -419,6 +433,26 @@ def __init__(self, loop, sock, protocol, extra, server=None): if self._server is not None: self._server._attach() + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._sock_fd] + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_READ) + if polling: + info.append('read=polling') + else: + info.append('read=idle') + + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_WRITE) + if polling: + state = 'polling' + else: + state = 'idle' + + bufsize = self.get_write_buffer_size() + info.append('write=<%s, bufsize=%s>' % (state, bufsize)) + return '<%s>' % ' '.join(info) + def abort(self): self._force_close(None) @@ -433,7 +467,10 @@ def close(self): def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. - if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, @@ -492,6 +529,8 @@ def pause_reading(self): raise RuntimeError('Already paused') self._paused = True self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) def resume_reading(self): if not self._paused: @@ -500,6 +539,8 @@ def resume_reading(self): if self._closing: return self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) def _read_ready(self): try: @@ -512,6 +553,8 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) keep_open = self._protocol.eof_received() if keep_open: # We're keeping the connection open so the @@ -638,31 +681,37 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, # SSL-specific extra info. (peercert is set later) self._extra.update(sslcontext=sslcontext) - self._on_handshake() + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + start_time = self._loop.time() + else: + start_time = None + self._on_handshake(start_time) - def _on_handshake(self): + def _on_handshake(self, start_time): try: self._sock.do_handshake() except ssl.SSLWantReadError: - self._loop.add_reader(self._sock_fd, self._on_handshake) + self._loop.add_reader(self._sock_fd, + self._on_handshake, start_time) return except ssl.SSLWantWriteError: - self._loop.add_writer(self._sock_fd, self._on_handshake) - return - except Exception as exc: - self._loop.remove_reader(self._sock_fd) - self._loop.remove_writer(self._sock_fd) - self._sock.close() - if self._waiter is not None: - self._waiter.set_exception(exc) + self._loop.add_writer(self._sock_fd, + self._on_handshake, start_time) return except BaseException as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed", + self, exc_info=True) self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) - raise + if isinstance(exc, Exception): + return + else: + raise self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) @@ -676,6 +725,10 @@ def _on_handshake(self): try: ssl.match_hostname(peercert, self._server_hostname) except Exception as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed " + "on matching the hostname", + self, exc_info=True) self._sock.close() if self._waiter is not None: self._waiter.set_exception(exc) @@ -696,6 +749,10 @@ def _on_handshake(self): self._loop.call_soon(self._waiter._set_result_unless_cancelled, None) + if self._loop.get_debug(): + dt = self._loop.time() - start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + def pause_reading(self): # XXX This is a bit icky, given the comment at the top of # _read_ready(). Is it possible to evoke a deadlock? I don't @@ -709,6 +766,8 @@ def pause_reading(self): raise RuntimeError('Already paused') self._paused = True self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) def resume_reading(self): if not self._paused: @@ -717,6 +776,8 @@ def resume_reading(self): if self._closing: return self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) def _read_ready(self): if self._write_wants_read: @@ -741,6 +802,8 @@ def _read_ready(self): self._protocol.data_received(data) else: try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) keep_open = self._protocol.eof_received() if keep_open: logger.warning('returning true from eof_received() ' diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 764e719d..09b875ce 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -16,6 +16,7 @@ from . import constants from . import events from . import selector_events +from . import selectors from . import transports from .coroutines import coroutine from .log import logger @@ -272,6 +273,20 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._fileno] + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_READ) + if polling: + info.append('polling') + else: + info.append('idle') + else: + info.append('closed') + return '<%s>' % ' '.join(info) + def _read_ready(self): try: data = os.read(self._fileno, self.max_size) @@ -283,6 +298,8 @@ def _read_ready(self): if data: self._protocol.data_received(data) else: + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) self._closing = True self._loop.remove_reader(self._fileno) self._loop.call_soon(self._protocol.eof_received) @@ -357,11 +374,30 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): # wait until protocol.connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._fileno] + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_WRITE) + if polling: + info.append('polling') + else: + info.append('idle') + + bufsize = self.get_write_buffer_size() + info.append('bufsize=%s' % bufsize) + else: + info.append('closed') + return '<%s>' % ' '.join(info) + def get_write_buffer_size(self): return sum(len(data) for data in self._buffer) def _read_ready(self): # Pipe was closed by peer. + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) if self._buffer: self._close(BrokenPipeError()) else: diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 93b71b2a..9d86c96b 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -40,6 +40,18 @@ def __init__(self, ov, *, loop=None): super().__init__(loop=loop) self.ov = ov + def __repr__(self): + info = [self._state.lower()] + if self.ov.pending: + info.append('overlapped=pending') + else: + info.append('overlapped=completed') + if self._state == futures._FINISHED: + info.append(self._format_result()) + if self._callbacks: + info.append(self._format_callbacks()) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + def cancel(self): try: self.ov.cancel() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 35efab97..51869316 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1092,15 +1092,15 @@ def test_on_handshake_reader_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._on_handshake() - self.loop.assert_reader(1, transport._on_handshake) + transport._on_handshake(None) + self.loop.assert_reader(1, transport._on_handshake, None) def test_on_handshake_writer_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._on_handshake() - self.loop.assert_writer(1, transport._on_handshake) + transport._on_handshake(None) + self.loop.assert_writer(1, transport._on_handshake, None) def test_on_handshake_exc(self): exc = ValueError() @@ -1108,7 +1108,7 @@ def test_on_handshake_exc(self): transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) transport._waiter = asyncio.Future(loop=self.loop) - transport._on_handshake() + transport._on_handshake(None) self.assertTrue(self.sslsock.close.called) self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) @@ -1119,7 +1119,7 @@ def test_on_handshake_base_exc(self): transport._waiter = asyncio.Future(loop=self.loop) exc = BaseException() self.sslsock.do_handshake.side_effect = exc - self.assertRaises(BaseException, transport._on_handshake) + self.assertRaises(BaseException, transport._on_handshake, None) self.assertTrue(self.sslsock.close.called) self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) From c4919ebf20156a6a9da223546b015e6566af65f9 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 12 Jul 2014 14:38:33 -0700 Subject: [PATCH 1101/1502] Clean up some docstrings and comments. Remove unused unimplemented _read_from_self(). --- asyncio/base_events.py | 85 ++++++++++++++++++++++----------------- asyncio/unix_events.py | 2 +- tests/test_base_events.py | 2 - 3 files changed, 49 insertions(+), 40 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index e5683fd1..05559e00 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -1,7 +1,7 @@ """Base implementation of event loop. The event loop can be broken up into a multiplexer (the part -responsible for notifying us of IO events) and the event loop proper, +responsible for notifying us of I/O events) and the event loop proper, which wraps a multiplexer with functionality for scheduling callbacks, immediately or at a given time in the future. @@ -70,7 +70,7 @@ def _check_resolved_address(sock, address): type_mask |= socket.SOCK_NONBLOCK if hasattr(socket, 'SOCK_CLOEXEC'): type_mask |= socket.SOCK_CLOEXEC - # Use getaddrinfo(AI_NUMERICHOST) to ensure that the address is + # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is # already resolved. try: socket.getaddrinfo(host, port, @@ -158,7 +158,8 @@ def __repr__(self): def create_task(self, coro): """Schedule a coroutine object. - Return a task object.""" + Return a task object. + """ task = tasks.Task(coro, loop=self) if task._source_traceback: del task._source_traceback[-1] @@ -197,12 +198,13 @@ def _make_subprocess_transport(self, protocol, args, shell, """Create subprocess transport.""" raise NotImplementedError - def _read_from_self(self): - """XXX""" - raise NotImplementedError - def _write_to_self(self): - """XXX""" + """Write a byte to self-pipe, to wake up the event loop. + + This may be called from a different thread. + + The subclass is responsible for implementing the self-pipe. + """ raise NotImplementedError def _process_events(self, event_list): @@ -233,7 +235,7 @@ def run_until_complete(self, future): If the argument is a coroutine, it is wrapped in a Task. - XXX TBD: It would be disastrous to call run_until_complete() + WARNING: It would be disastrous to call run_until_complete() with the same coroutine twice -- it would wrap it in two different Tasks and that can't be good. @@ -261,7 +263,7 @@ def stop(self): Every callback scheduled before stop() is called will run. Callback scheduled after stop() is called won't. However, - those callbacks will run if run() is called again later. + those callbacks will run if run_*() is called again later. """ self.call_soon(_raise_stop_error) @@ -274,7 +276,7 @@ def close(self): The event loop must not be running. """ if self._running: - raise RuntimeError("cannot close a running event loop") + raise RuntimeError("Cannot close a running event loop") if self._closed: return if self._debug: @@ -292,11 +294,16 @@ def is_closed(self): return self._closed def is_running(self): - """Returns running status of event loop.""" + """Returns True if the event loop is running.""" return self._running def time(self): - """Return the time according to the event loop's clock.""" + """Return the time according to the event loop's clock. + + This is a float expressed in seconds since an epoch, but the + epoch, precision, accuracy and drift are unspecified and may + differ per event loop. + """ return time.monotonic() def call_later(self, delay, callback, *args): @@ -306,7 +313,7 @@ def call_later(self, delay, callback, *args): can be used to cancel the call. The delay can be an int or float, expressed in seconds. It is - always a relative time. + always relative to the current time. Each callback will be called exactly once. If two callbacks are scheduled for exactly the same time, it undefined which @@ -321,7 +328,10 @@ def call_later(self, delay, callback, *args): return timer def call_at(self, when, callback, *args): - """Like call_later(), but uses an absolute time.""" + """Like call_later(), but uses an absolute time. + + Absolute time corresponds to the event loop's time() method. + """ if coroutines.iscoroutinefunction(callback): raise TypeError("coroutines cannot be used with call_at()") if self._debug: @@ -335,7 +345,7 @@ def call_at(self, when, callback, *args): def call_soon(self, callback, *args): """Arrange for a callback to be called as soon as possible. - This operates as a FIFO queue, callbacks are called in the + This operates as a FIFO queue: callbacks are called in the order in which they are registered. Each callback will be called exactly once. @@ -361,10 +371,10 @@ def _call_soon(self, callback, args, check_loop): def _assert_is_current_event_loop(self): """Asserts that this event loop is the current event loop. - Non-threadsafe methods of this class make this assumption and will + Non-thread-safe methods of this class make this assumption and will likely behave incorrectly when the assumption is violated. - Should only be called when (self._debug == True). The caller is + Should only be called when (self._debug == True). The caller is responsible for checking this condition for performance reasons. """ try: @@ -373,11 +383,11 @@ def _assert_is_current_event_loop(self): return if current is not self: raise RuntimeError( - "non-threadsafe operation invoked on an event loop other " + "Non-thread-safe operation invoked on an event loop other " "than the current one") def call_soon_threadsafe(self, callback, *args): - """Like call_soon(), but thread safe.""" + """Like call_soon(), but thread-safe.""" handle = self._call_soon(callback, args, check_loop=False) if handle._source_traceback: del handle._source_traceback[-1] @@ -386,7 +396,7 @@ def call_soon_threadsafe(self, callback, *args): def run_in_executor(self, executor, callback, *args): if coroutines.iscoroutinefunction(callback): - raise TypeError("coroutines cannot be used with run_in_executor()") + raise TypeError("Coroutines cannot be used with run_in_executor()") if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) @@ -416,13 +426,13 @@ def _getaddrinfo_debug(self, host, port, family, type, proto, flags): if flags: msg.append('flags=%r' % flags) msg = ', '.join(msg) - logger.debug('Get addresss info %s', msg) + logger.debug('Get address info %s', msg) t0 = self.time() addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) dt = self.time() - t0 - msg = ('Getting addresss info %s took %.3f ms: %r' + msg = ('Getting address info %s took %.3f ms: %r' % (msg, dt * 1e3, addrinfo)) if dt >= self.slow_callback_duration: logger.info(msg) @@ -589,7 +599,7 @@ def create_datagram_endpoint(self, protocol_factory, raise ValueError('unexpected address family') addr_pairs_info = (((family, proto), (None, None)),) else: - # join addresss by (family, protocol) + # join address by (family, protocol) addr_infos = collections.OrderedDict() for idx, addr in ((0, local_addr), (1, remote_addr)): if addr is not None: @@ -674,7 +684,7 @@ def create_server(self, protocol_factory, host=None, port=None, reuse_address=None): """Create a TCP server bound to host and port. - Return an Server object which can be used to stop the service. + Return a Server object which can be used to stop the service. This method is a coroutine. """ @@ -731,8 +741,7 @@ def create_server(self, protocol_factory, host=None, port=None, sock.close() else: if sock is None: - raise ValueError( - 'host and port was not specified and no sock specified') + raise ValueError('Neither host/port nor sock were specified') sockets = [sock] server = Server(self, sockets) @@ -808,7 +817,7 @@ def set_exception_handler(self, handler): be set. If handler is a callable object, it should have a - matching signature to '(loop, context)', where 'loop' + signature matching '(loop, context)', where 'loop' will be a reference to the active event loop, 'context' will be a dict object (see `call_exception_handler()` documentation for details about context). @@ -825,7 +834,7 @@ def default_exception_handler(self, context): handler is set, and can be called by a custom exception handler that wants to defer to the default behavior. - context parameter has the same meaning as in + The context parameter has the same meaning as in `call_exception_handler()`. """ message = context.get('message') @@ -854,10 +863,10 @@ def default_exception_handler(self, context): logger.error('\n'.join(log_lines), exc_info=exc_info) def call_exception_handler(self, context): - """Call the current event loop exception handler. + """Call the current event loop's exception handler. + + The context argument is a dict containing the following keys: - context is a dict object containing the following keys - (new keys maybe introduced later): - 'message': Error message; - 'exception' (optional): Exception object; - 'future' (optional): Future instance; @@ -866,8 +875,10 @@ def call_exception_handler(self, context): - 'transport' (optional): Transport instance; - 'socket' (optional): Socket instance. - Note: this method should not be overloaded in subclassed - event loops. For any custom exception handling, use + New keys maybe introduced in the future. + + Note: do not overload this method in an event loop subclass. + For custom exception handling, use the `set_exception_handler()` method. """ if self._exception_handler is None: @@ -892,7 +903,7 @@ def call_exception_handler(self, context): 'context': context, }) except Exception: - # Guard 'default_exception_handler' in case it's + # Guard 'default_exception_handler' in case it is # overloaded. logger.error('Exception in default exception handler ' 'while handling an unexpected error ' @@ -900,7 +911,7 @@ def call_exception_handler(self, context): exc_info=True) def _add_callback(self, handle): - """Add a Handle to ready or scheduled.""" + """Add a Handle to _scheduled (TimerHandle) or _ready.""" assert isinstance(handle, events.Handle), 'A Handle is required here' if handle._cancelled: return @@ -971,7 +982,7 @@ def _run_once(self): # Note: We run all currently scheduled callbacks, but not any # callbacks scheduled by callbacks run this time around -- # they will be run the next time (after another I/O poll). - # Use an idiom that is threadsafe without using locks. + # Use an idiom that is thread-safe without using locks. ntodo = len(self._ready) for i in range(ntodo): handle = self._ready.popleft() diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 09b875ce..a27e5291 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -565,7 +565,7 @@ def add_child_handler(self, pid, callback, *args): process 'pid' terminates. Specifying another callback for the same process replaces the previous handler. - Note: callback() must be thread-safe + Note: callback() must be thread-safe. """ raise NotImplementedError() diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 27610f0d..7bf07ed6 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -43,8 +43,6 @@ def test_not_implemented(self): NotImplementedError, self.loop._process_events, []) self.assertRaises( NotImplementedError, self.loop._write_to_self) - self.assertRaises( - NotImplementedError, self.loop._read_from_self) self.assertRaises( NotImplementedError, self.loop._make_read_pipe_transport, m, m) From 4e6c784c02192cf817ad2420dd2ab8ce60659d3e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 14 Jul 2014 16:27:55 +0200 Subject: [PATCH 1102/1502] create_connection(): add the socket in the "connected to" debug log --- asyncio/base_events.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 05559e00..5e067b8e 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -569,8 +569,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) if self._debug: - logger.debug("connected to %s:%r: (%r, %r)", - host, port, transport, protocol) + logger.debug("%r connected to %s:%r: (%r, %r)", + sock, host, port, transport, protocol) return transport, protocol @coroutine From 21cd24a1b9116555a236a77edca1dd07c3aae413 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 14 Jul 2014 17:02:34 +0200 Subject: [PATCH 1103/1502] Add BaseSubprocessTransport._pid attribute Store the pid so it is still accessible after the process exited. It's more convinient for debug. --- asyncio/base_subprocess.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 2f933c54..8c9c7e23 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -14,6 +14,7 @@ def __init__(self, loop, protocol, args, shell, super().__init__(extra) self._protocol = protocol self._loop = loop + self._pid = None self._pipes = {} if stdin == subprocess.PIPE: @@ -27,6 +28,7 @@ def __init__(self, loop, protocol, args, shell, self._returncode = None self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, **kwargs) + self._pid = self._proc.pid self._extra['subprocess'] = self._proc def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): @@ -45,7 +47,7 @@ def close(self): self.terminate() def get_pid(self): - return self._proc.pid + return self._pid def get_returncode(self): return self._returncode From cf3a91c1248ea4b7842b0b50217b47409577c6ba Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 14 Jul 2014 17:23:28 +0200 Subject: [PATCH 1104/1502] Tulip issue #184: Log subprocess events in debug mode - Log stdin, stdout and stderr transports and protocols - Log process identifier (pid) - Log connection of pipes - Log process exit - Log Process.communicate() tasks: feed stdin, read stdout and stderr - Add __repr__() method to many classes related to subprocesses --- asyncio/base_events.py | 42 ++++++++++++++++++++++++++++++++++++++ asyncio/base_subprocess.py | 36 ++++++++++++++++++++++++++++++++ asyncio/streams.py | 12 +++++++++++ asyncio/subprocess.py | 26 +++++++++++++++++++++++ asyncio/unix_events.py | 11 ++++++++++ 5 files changed, 127 insertions(+) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5e067b8e..0aeaae42 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -50,6 +50,15 @@ def _format_handle(handle): return str(handle) +def _format_pipe(fd): + if fd == subprocess.PIPE: + return '' + elif fd == subprocess.STDOUT: + return '' + else: + return repr(fd) + + class _StopError(BaseException): """Raised to stop the event loop.""" @@ -759,6 +768,9 @@ def connect_read_pipe(self, protocol_factory, pipe): waiter = futures.Future(loop=self) transport = self._make_read_pipe_transport(pipe, protocol, waiter) yield from waiter + if self._debug: + logger.debug('Read pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) return transport, protocol @coroutine @@ -767,8 +779,24 @@ def connect_write_pipe(self, protocol_factory, pipe): waiter = futures.Future(loop=self) transport = self._make_write_pipe_transport(pipe, protocol, waiter) yield from waiter + if self._debug: + logger.debug('Write pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) return transport, protocol + def _log_subprocess(self, msg, stdin, stdout, stderr): + info = [msg] + if stdin is not None: + info.append('stdin=%s' % _format_pipe(stdin)) + if stdout is not None and stderr == subprocess.STDOUT: + info.append('stdout=stderr=%s' % _format_pipe(stdout)) + else: + if stdout is not None: + info.append('stdout=%s' % _format_pipe(stdout)) + if stderr is not None: + info.append('stderr=%s' % _format_pipe(stderr)) + logger.debug(' '.join(info)) + @coroutine def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, @@ -783,8 +811,15 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, if bufsize != 0: raise ValueError("bufsize must be 0") protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'run shell command %r' % cmd + self._log_subprocess(debug_log, stdin, stdout, stderr) transport = yield from self._make_subprocess_transport( protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) return transport, protocol @coroutine @@ -805,9 +840,16 @@ def subprocess_exec(self, protocol_factory, program, *args, "a bytes or text string, not %s" % type(arg).__name__) protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'execute program %r' % program + self._log_subprocess(debug_log, stdin, stdout, stderr) transport = yield from self._make_subprocess_transport( protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) return transport, protocol def set_exception_handler(self, handler): diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 8c9c7e23..d0087793 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -4,6 +4,7 @@ from . import protocols from . import transports from .coroutines import coroutine +from .log import logger class BaseSubprocessTransport(transports.SubprocessTransport): @@ -30,6 +31,34 @@ def __init__(self, loop, protocol, args, shell, stderr=stderr, bufsize=bufsize, **kwargs) self._pid = self._proc.pid self._extra['subprocess'] = self._proc + if self._loop.get_debug(): + if isinstance(args, (bytes, str)): + program = args + else: + program = args[0] + logger.debug('process %r created: pid %s', + program, self._pid) + + def __repr__(self): + info = [self.__class__.__name__, 'pid=%s' % self._pid] + if self._returncode is not None: + info.append('returncode=%s' % self._returncode) + + stdin = self._pipes.get(0) + if stdin is not None: + info.append('stdin=%s' % stdin.pipe) + + stdout = self._pipes.get(1) + stderr = self._pipes.get(2) + if stdout is not None and stderr is stdout: + info.append('stdout=stderr=%s' % stdout.pipe) + else: + if stdout is not None: + info.append('stdout=%s' % stdout.pipe) + if stderr is not None: + info.append('stderr=%s' % stderr.pipe) + + return '<%s>' % ' '.join(info) def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): raise NotImplementedError @@ -110,6 +139,9 @@ def _pipe_data_received(self, fd, data): def _process_exited(self, returncode): assert returncode is not None, returncode assert self._returncode is None, self._returncode + if self._loop.get_debug(): + logger.info('%r exited with return code %r', + self, returncode) self._returncode = returncode self._call(self._protocol.process_exited) self._try_finish() @@ -143,6 +175,10 @@ def __init__(self, proc, fd): def connection_made(self, transport): self.pipe = transport + def __repr__(self): + return ('<%s fd=%s pipe=%r>' + % (self.__class__.__name__, self.fd, self.pipe)) + def connection_lost(self, exc): self.disconnected = True self.proc._pipe_connection_lost(self.fd, exc) diff --git a/asyncio/streams.py b/asyncio/streams.py index 9b654cdb..d18db77b 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -15,6 +15,7 @@ from . import futures from . import protocols from .coroutines import coroutine +from .log import logger _DEFAULT_LIMIT = 2**16 @@ -153,10 +154,15 @@ def __init__(self, loop=None): def pause_writing(self): assert not self._paused self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) def resume_writing(self): assert self._paused self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + waiter = self._drain_waiter if waiter is not None: self._drain_waiter = None @@ -244,6 +250,12 @@ def __init__(self, transport, protocol, reader, loop): self._reader = reader self._loop = loop + def __repr__(self): + info = [self.__class__.__name__, 'transport=%r' % self._transport] + if self._reader is not None: + info.append('reader=%r' % self._reader) + return '<%s>' % ' '.join(info) + @property def transport(self): return self._transport diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 2cd6de6d..12902f1b 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -9,6 +9,7 @@ from . import streams from . import tasks from .coroutines import coroutine +from .log import logger PIPE = subprocess.PIPE @@ -28,6 +29,16 @@ def __init__(self, limit, loop): self._waiters = collections.deque() self._transport = None + def __repr__(self): + info = [self.__class__.__name__] + if self.stdin is not None: + info.append('stdin=%r' % self.stdin) + if self.stdout is not None: + info.append('stdout=%r' % self.stdout) + if self.stderr is not None: + info.append('stderr=%r' % self.stderr) + return '<%s>' % ' '.join(info) + def connection_made(self, transport): self._transport = transport if transport.get_pipe_transport(1): @@ -91,6 +102,9 @@ def __init__(self, transport, protocol, loop): self.stderr = protocol.stderr self.pid = transport.get_pid() + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.pid) + @property def returncode(self): return self._transport.get_returncode() @@ -126,7 +140,13 @@ def kill(self): @coroutine def _feed_stdin(self, input): self.stdin.write(input) + if self._loop.get_debug(): + logger.debug('%r communicate: feed stdin (%s bytes)', + self, len(input)) yield from self.stdin.drain() + + if self._loop.get_debug(): + logger.debug('%r communicate: close stdin', self) self.stdin.close() @coroutine @@ -141,7 +161,13 @@ def _read_stream(self, fd): else: assert fd == 1 stream = self.stdout + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: read %s', self, name) output = yield from stream.read() + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: close %s', self, name) transport.close() return output diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index a27e5291..4ba4f49c 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -721,6 +721,9 @@ def _do_waitpid(self, expected_pid): return returncode = self._compute_returncode(status) + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) try: callback, args = self._callbacks.pop(pid) @@ -818,8 +821,16 @@ def _do_waitpid_all(self): if self._forks: # It may not be registered yet. self._zombies[pid] = returncode + if self._loop.get_debug(): + logger.debug('unknown process %s exited ' + 'with returncode %s', + pid, returncode) continue callback = None + else: + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + pid, returncode) if callback is None: logger.warning( From 93067f9873e9a7d9c352cc38893b0aeb54b2ff7f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 14 Jul 2014 22:25:10 +0200 Subject: [PATCH 1105/1502] tests: make quiet the logs of SSL handshake failures when running tests in debug mode --- asyncio/test_utils.py | 16 ++++++++++++++++ tests/test_events.py | 23 +++++++++++++---------- tests/test_selector_events.py | 15 ++++++++------- 3 files changed, 37 insertions(+), 17 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 6abcaf1d..840bbf94 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -3,6 +3,7 @@ import collections import contextlib import io +import logging import os import re import socket @@ -28,6 +29,7 @@ from . import selectors from . import tasks from .coroutines import coroutine +from .log import logger if sys.platform == 'win32': # pragma: no cover @@ -401,3 +403,17 @@ def new_test_loop(self, gen=None): def tearDown(self): events.set_event_loop(None) + + +@contextlib.contextmanager +def disable_logger(): + """Context manager to disable asyncio logger. + + For example, it can be used to ignore warnings in debug mode. + """ + old_level = logger.level + try: + logger.setLevel(logging.CRITICAL+1) + yield + finally: + logger.setLevel(old_level) diff --git a/tests/test_events.py b/tests/test_events.py index 06552f87..b0657495 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -819,9 +819,10 @@ def test_create_server_ssl_verify_failed(self): # no CA loaded f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) - with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): - self.loop.run_until_complete(f_c) + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) # close connection self.assertIsNone(proto.transport) @@ -845,9 +846,10 @@ def test_create_unix_server_ssl_verify_failed(self): f_c = self.loop.create_unix_connection(MyProto, path, ssl=sslcontext_client, server_hostname='invalid') - with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): - self.loop.run_until_complete(f_c) + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) # close connection self.assertIsNone(proto.transport) @@ -871,10 +873,11 @@ def test_create_server_ssl_match_failed(self): # incorrect server_hostname f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) - with self.assertRaisesRegex( - ssl.CertificateError, - "hostname '127.0.0.1' doesn't match 'localhost'"): - self.loop.run_until_complete(f_c) + with test_utils.disable_logger(): + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) # close connection proto.transport.close() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 51869316..198b14fd 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1105,13 +1105,13 @@ def test_on_handshake_writer_retry(self): def test_on_handshake_exc(self): exc = ValueError() self.sslsock.do_handshake.side_effect = exc - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext) - transport._waiter = asyncio.Future(loop=self.loop) - transport._on_handshake(None) + with test_utils.disable_logger(): + waiter = asyncio.Future(loop=self.loop) + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, waiter) + self.assertTrue(waiter.done()) + self.assertIs(exc, waiter.exception()) self.assertTrue(self.sslsock.close.called) - self.assertTrue(transport._waiter.done()) - self.assertIs(exc, transport._waiter.exception()) def test_on_handshake_base_exc(self): transport = _SelectorSslTransport( @@ -1119,7 +1119,8 @@ def test_on_handshake_base_exc(self): transport._waiter = asyncio.Future(loop=self.loop) exc = BaseException() self.sslsock.do_handshake.side_effect = exc - self.assertRaises(BaseException, transport._on_handshake, None) + with test_utils.disable_logger(): + self.assertRaises(BaseException, transport._on_handshake, None) self.assertTrue(self.sslsock.close.called) self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) From d13d4f202773102c04ef08f517f2f1f58a957f4a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 15 Jul 2014 23:42:02 +0200 Subject: [PATCH 1106/1502] test_selector_events: remove duplicate call to _on_handshake() method The _SelectorSslTransport constructor already calls it. --- tests/test_selector_events.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 198b14fd..d483edcb 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1092,15 +1092,13 @@ def test_on_handshake_reader_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._on_handshake(None) - self.loop.assert_reader(1, transport._on_handshake, None) + self.loop.assert_reader(1, transport._on_handshake, 0) def test_on_handshake_writer_retry(self): self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - transport._on_handshake(None) - self.loop.assert_writer(1, transport._on_handshake, None) + self.loop.assert_writer(1, transport._on_handshake, 0) def test_on_handshake_exc(self): exc = ValueError() @@ -1120,7 +1118,7 @@ def test_on_handshake_base_exc(self): exc = BaseException() self.sslsock.do_handshake.side_effect = exc with test_utils.disable_logger(): - self.assertRaises(BaseException, transport._on_handshake, None) + self.assertRaises(BaseException, transport._on_handshake, 0) self.assertTrue(self.sslsock.close.called) self.assertTrue(transport._waiter.done()) self.assertIs(exc, transport._waiter.exception()) From fe8cfa426d1cd8fddd663f5888fab25f4ebcc429 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 16 Jul 2014 17:25:08 +0200 Subject: [PATCH 1107/1502] Fix _on_handshake() tests --- tests/test_selector_events.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index d483edcb..c0f388d6 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1089,16 +1089,18 @@ def test_on_handshake(self): self.assertIsNone(waiter.result()) def test_on_handshake_reader_retry(self): + self.loop.set_debug(False) self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - self.loop.assert_reader(1, transport._on_handshake, 0) + self.loop.assert_reader(1, transport._on_handshake, None) def test_on_handshake_writer_retry(self): + self.loop.set_debug(False) self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError transport = _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext) - self.loop.assert_writer(1, transport._on_handshake, 0) + self.loop.assert_writer(1, transport._on_handshake, None) def test_on_handshake_exc(self): exc = ValueError() From 9ff66eab953d882ca4347ee2992b2e14544c6b70 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 16 Jul 2014 18:35:33 +0200 Subject: [PATCH 1108/1502] Python issue 21163: Ignore "destroy pending task" warnings for private tasks in asyncio.gather(). --- asyncio/tasks.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 78b4c4dc..a741bd33 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -558,21 +558,33 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): prevent the cancellation of one child to cause other children to be cancelled.) """ - arg_to_fut = {arg: async(arg, loop=loop) for arg in set(coros_or_futures)} - children = [arg_to_fut[arg] for arg in coros_or_futures] - n = len(children) - if n == 0: + if not coros_or_futures: outer = futures.Future(loop=loop) outer.set_result([]) return outer - if loop is None: - loop = children[0]._loop - for fut in children: - if fut._loop is not loop: - raise ValueError("futures are tied to different event loops") + + arg_to_fut = {} + for arg in set(coros_or_futures): + if not isinstance(arg, futures.Future): + fut = async(arg, loop=loop) + if loop is None: + loop = fut._loop + # The caller cannot control this future, the "destroy pending task" + # warning should not be emitted. + fut._log_destroy_pending = False + else: + fut = arg + if loop is None: + loop = fut._loop + elif fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + arg_to_fut[arg] = fut + + children = [arg_to_fut[arg] for arg in coros_or_futures] + nchildren = len(children) outer = _GatheringFuture(children, loop=loop) nfinished = 0 - results = [None] * n + results = [None] * nchildren def _done_callback(i, fut): nonlocal nfinished @@ -595,7 +607,7 @@ def _done_callback(i, fut): res = fut._result results[i] = res nfinished += 1 - if nfinished == n: + if nfinished == nchildren: outer.set_result(results) for i, fut in enumerate(children): From c5f3191c2276d29e041bdcca0308c680d6831e9e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 16 Jul 2014 18:50:01 +0200 Subject: [PATCH 1109/1502] Python issue 21163: Fix "destroy pending task" warning in test_wait_errors() --- asyncio/tasks.py | 4 ++-- tests/test_tasks.py | 11 +++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index a741bd33..07952c9a 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -330,14 +330,14 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) if not fs: raise ValueError('Set of coroutines/Futures is empty.') + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) if loop is None: loop = events.get_event_loop() fs = {async(f, loop=loop) for f in set(fs)} - if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): - raise ValueError('Invalid return_when value: {}'.format(return_when)) return (yield from _wait(fs, timeout, return_when, loop)) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ca770f90..e199c5a4 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -623,10 +623,13 @@ def test_wait_errors(self): ValueError, self.loop.run_until_complete, asyncio.wait(set(), loop=self.loop)) - self.assertRaises( - ValueError, self.loop.run_until_complete, - asyncio.wait([asyncio.sleep(10.0, loop=self.loop)], - return_when=-1, loop=self.loop)) + # -1 is an invalid return_when value + sleep_coro = asyncio.sleep(10.0, loop=self.loop) + wait_coro = asyncio.wait([sleep_coro], return_when=-1, loop=self.loop) + self.assertRaises(ValueError, + self.loop.run_until_complete, wait_coro) + + sleep_coro.close() def test_wait_first_completed(self): From e6573535c9215bbb4d7a78720f9fd4053902536c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 16 Jul 2014 18:53:50 +0200 Subject: [PATCH 1110/1502] test_as_completed(): disable "slow callback" warning --- tests/test_tasks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e199c5a4..7b93a0e2 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -851,6 +851,8 @@ def gen(): yield 0 loop = self.new_test_loop(gen) + # disable "slow callback" warning + loop.slow_callback_duration = 1.0 completed = set() time_shifted = False From 9e4fa8b4167ff05dd9502abbad32922919f55560 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 17 Jul 2014 12:43:13 +0200 Subject: [PATCH 1111/1502] asyncio, tulip issue 190: Process.communicate() must ignore BrokenPipeError If you want to handle the BrokenPipeError, you can easily reimplement communicate(). Add also a unit test to ensure that stdin.write() + stdin.drain() raises BrokenPipeError. --- asyncio/subprocess.py | 6 +++++- tests/test_subprocess.py | 27 ++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 12902f1b..23d6b4da 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -143,7 +143,11 @@ def _feed_stdin(self, input): if self._loop.get_debug(): logger.debug('%r communicate: feed stdin (%s bytes)', self, len(input)) - yield from self.stdin.drain() + try: + yield from self.stdin.drain() + except BrokenPipeError: + # ignore BrokenPipeError + pass if self._loop.get_debug(): logger.debug('%r communicate: close stdin', self) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 3204d42e..e41cabef 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -11,9 +11,6 @@ # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] -# Program sleeping during 1 second -PROGRAM_SLEEP_1SEC = [sys.executable, '-c', 'import time; time.sleep(1)'] - # Program copying input to output PROGRAM_CAT = [ sys.executable, '-c', @@ -118,16 +115,32 @@ def test_send_signal(self): returncode = self.loop.run_until_complete(proc.wait()) self.assertEqual(-signal.SIGHUP, returncode) - def test_broken_pipe(self): + def prepare_broken_pipe_test(self): + # buffer large enough to feed the whole pipe buffer large_data = b'x' * support.PIPE_MAX_SIZE + # the program ends before the stdin can be feeded create = asyncio.create_subprocess_exec( - *PROGRAM_SLEEP_1SEC, + sys.executable, '-c', 'pass', stdin=subprocess.PIPE, loop=self.loop) proc = self.loop.run_until_complete(create) - with self.assertRaises(BrokenPipeError): - self.loop.run_until_complete(proc.communicate(large_data)) + return (proc, large_data) + + def test_stdin_broken_pipe(self): + proc, large_data = self.prepare_broken_pipe_test() + + # drain() must raise BrokenPipeError + proc.stdin.write(large_data) + self.assertRaises(BrokenPipeError, + self.loop.run_until_complete, proc.stdin.drain()) + self.loop.run_until_complete(proc.wait()) + + def test_communicate_ignore_broken_pipe(self): + proc, large_data = self.prepare_broken_pipe_test() + + # communicate() must ignore BrokenPipeError when feeding stdin + self.loop.run_until_complete(proc.communicate(large_data)) self.loop.run_until_complete(proc.wait()) From c9de93be08f9b7475b033c62d30120f6d9082cdf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 17 Jul 2014 13:10:43 +0200 Subject: [PATCH 1112/1502] tulip issue 190: Process.communicate() now ignores ConnectionResetError too --- asyncio/subprocess.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 23d6b4da..e4c14995 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -139,17 +139,19 @@ def kill(self): @coroutine def _feed_stdin(self, input): + debug = self._loop.get_debug() self.stdin.write(input) - if self._loop.get_debug(): + if debug: logger.debug('%r communicate: feed stdin (%s bytes)', self, len(input)) try: yield from self.stdin.drain() - except BrokenPipeError: - # ignore BrokenPipeError - pass + except (BrokenPipeError, ConnectionResetError) as exc: + # communicate() ignores BrokenPipeError and ConnectionResetError + if debug: + logger.debug('%r communicate: stdin got %r', self, exc) - if self._loop.get_debug(): + if debug: logger.debug('%r communicate: close stdin', self) self.stdin.close() From ed7395c2a4bcd65d464670db44331dba009e4fb4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 17 Jul 2014 14:00:41 +0200 Subject: [PATCH 1113/1502] Fix test_stdin_broken_pipe(): drain() can also raise ConnectionResetError --- tests/test_subprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index e41cabef..d050458e 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -130,9 +130,9 @@ def prepare_broken_pipe_test(self): def test_stdin_broken_pipe(self): proc, large_data = self.prepare_broken_pipe_test() - # drain() must raise BrokenPipeError + # drain() must raise BrokenPipeError or ConnectionResetError proc.stdin.write(large_data) - self.assertRaises(BrokenPipeError, + self.assertRaises((BrokenPipeError, ConnectionResetError), self.loop.run_until_complete, proc.stdin.drain()) self.loop.run_until_complete(proc.wait()) From fc1e262e308ffd26c03c6b224bb643d5d4fec45c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 17 Jul 2014 22:18:31 +0200 Subject: [PATCH 1114/1502] Tulip issue 192, Python issue 21645: Rewrite signal handling Since Python 3.3, the C signal handler writes the signal number into the wakeup file descriptor and then schedules the Python call using Py_AddPendingCall(). asyncio uses the wakeup file descriptor to wake up the event loop, and relies on Py_AddPendingCall() to schedule the final callback with call_soon(). If the C signal handler is called in a thread different than the thread of the event loop, the loop is awaken but Py_AddPendingCall() was not called yet. In this case, the event loop has nothing to do and go to sleep again. Py_AddPendingCall() is called while the event loop is sleeping again and so the final callback is not scheduled immediatly. This patch changes how asyncio handles signals. Instead of relying on Py_AddPendingCall() and the wakeup file descriptor, asyncio now only relies on the wakeup file descriptor. asyncio reads signal numbers from the wakeup file descriptor to call its signal handler. --- asyncio/proactor_events.py | 2 +- asyncio/selector_events.py | 6 +++++- asyncio/unix_events.py | 20 ++++++++++++++++++-- tests/test_proactor_events.py | 2 +- tests/test_unix_events.py | 4 ++-- 5 files changed, 27 insertions(+), 7 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index d09e9faa..c530687d 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -443,7 +443,7 @@ def _loop_self_reading(self, f=None): f.add_done_callback(self._loop_self_reading) def _write_to_self(self): - self._csock.send(b'x') + self._csock.send(b'\0') def _start_serving(self, protocol_factory, sock, ssl=None, server=None): if ssl: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d79c0801..cd1a75aa 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -94,12 +94,16 @@ def _make_self_pipe(self): self._internal_fds += 1 self.add_reader(self._ssock.fileno(), self._read_from_self) + def _process_self_data(self, data): + pass + def _read_from_self(self): while True: try: data = self._ssock.recv(4096) if not data: break + self._process_self_data(data) except InterruptedError: continue except BlockingIOError: @@ -114,7 +118,7 @@ def _write_to_self(self): csock = self._csock if csock is not None: try: - csock.send(b'x') + csock.send(b'\0') except OSError: pass diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 4ba4f49c..73a85c11 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -31,6 +31,11 @@ raise ImportError('Signals are not really supported on Windows') +def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass + + class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Unix event loop. @@ -49,6 +54,13 @@ def close(self): for sig in list(self._signal_handlers): self.remove_signal_handler(sig) + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) + def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. @@ -69,7 +81,11 @@ def add_signal_handler(self, sig, callback, *args): self._signal_handlers[sig] = handle try: - signal.signal(sig, self._handle_signal) + # Register a dummy signal handler to ask Python to write the signal + # number in the wakup file descriptor. _process_self_data() will + # read signal numbers from this file descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + # Set SA_RESTART to limit EINTR occurrences. signal.siginterrupt(sig, False) except OSError as exc: @@ -85,7 +101,7 @@ def add_signal_handler(self, sig, callback, *args): else: raise - def _handle_signal(self, sig, arg): + def _handle_signal(self, sig): """Internal helper that is the actual signal handler.""" handle = self._signal_handlers.get(sig) if handle is None: diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 4bb4f0b3..0c536986 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -435,7 +435,7 @@ def test_loop_self_reading_exception(self): def test_write_to_self(self): self.loop._write_to_self() - self.csock.send.assert_called_with(b'x') + self.csock.send.assert_called_with(b'\0') def test_process_events(self): self.loop._process_events([]) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 0ade7f21..d355defb 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -42,7 +42,7 @@ def test_check_signal(self): ValueError, self.loop._check_signal, signal.NSIG + 1) def test_handle_signal_no_handler(self): - self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop._handle_signal(signal.NSIG + 1) def test_handle_signal_cancelled_handler(self): h = asyncio.Handle(mock.Mock(), (), @@ -50,7 +50,7 @@ def test_handle_signal_cancelled_handler(self): h.cancel() self.loop._signal_handlers[signal.NSIG + 1] = h self.loop.remove_signal_handler = mock.Mock() - self.loop._handle_signal(signal.NSIG + 1, ()) + self.loop._handle_signal(signal.NSIG + 1) self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) @mock.patch('asyncio.unix_events.signal') From 748b427fe5802121d90cb4f9d03dfad94c78535b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 17 Jul 2014 23:48:06 +0200 Subject: [PATCH 1115/1502] Python issue 21247: Fix a race condition in test_send_signal() of asyncio Add a basic synchronization mechanism to wait until the child process is ready before sending it a signal. --- tests/test_subprocess.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index d050458e..5425d9bf 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -108,11 +108,22 @@ def test_terminate(self): @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_send_signal(self): - args = PROGRAM_BLOCKED - create = asyncio.create_subprocess_exec(*args, loop=self.loop) + code = 'import time; print("sleeping", flush=True); time.sleep(3600)' + args = [sys.executable, '-c', code] + create = asyncio.create_subprocess_exec(*args, loop=self.loop, stdout=subprocess.PIPE) proc = self.loop.run_until_complete(create) - proc.send_signal(signal.SIGHUP) - returncode = self.loop.run_until_complete(proc.wait()) + + @asyncio.coroutine + def send_signal(proc): + # basic synchronization to wait until the program is sleeping + line = yield from proc.stdout.readline() + self.assertEqual(line, b'sleeping\n') + + proc.send_signal(signal.SIGHUP) + returncode = (yield from proc.wait()) + return returncode + + returncode = self.loop.run_until_complete(send_signal(proc)) self.assertEqual(-signal.SIGHUP, returncode) def prepare_broken_pipe_test(self): From 4f754ad3177eb0fcdc10ccf7804349a9453e9ff0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 18 Jul 2014 12:21:06 +0200 Subject: [PATCH 1116/1502] Fix asyncio.__all__: export also unix_events and windows_events symbols For example, on Windows, it was not possible to get ProactorEventLoop or DefaultEventLoopPolicy using "from asyncio import *". --- asyncio/__init__.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 789424e4..3911fb40 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -29,12 +29,6 @@ from .tasks import * from .transports import * -if sys.platform == 'win32': # pragma: no cover - from .windows_events import * -else: - from .unix_events import * # pragma: no cover - - __all__ = (coroutines.__all__ + events.__all__ + futures.__all__ + @@ -45,3 +39,10 @@ subprocess.__all__ + tasks.__all__ + transports.__all__) + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * + __all__ += windows_events.__all__ +else: + from .unix_events import * # pragma: no cover + __all__ += unix_events.__all__ From b10a08fdcf6a0df382eb7abfd09fab34319e3bd2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 21 Jul 2014 16:20:41 +0200 Subject: [PATCH 1117/1502] Fix test_stdin_broken_pipe(): drain() is not a coroutine --- tests/test_subprocess.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5425d9bf..a4e9df2f 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -141,10 +141,15 @@ def prepare_broken_pipe_test(self): def test_stdin_broken_pipe(self): proc, large_data = self.prepare_broken_pipe_test() + @asyncio.coroutine + def write_stdin(proc, data): + proc.stdin.write(data) + yield from proc.stdin.drain() + + coro = write_stdin(proc, large_data) # drain() must raise BrokenPipeError or ConnectionResetError - proc.stdin.write(large_data) self.assertRaises((BrokenPipeError, ConnectionResetError), - self.loop.run_until_complete, proc.stdin.drain()) + self.loop.run_until_complete, coro) self.loop.run_until_complete(proc.wait()) def test_communicate_ignore_broken_pipe(self): From 95ce36a9dcbad64a66e28ac41a7e7ee1ee98aa1c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 22 Jul 2014 12:02:58 +0200 Subject: [PATCH 1118/1502] Tulip issue #193: Convert StreamWriter.drain() to a classic coroutine Replace also _make_drain_waiter() function with a classic _drain_helper() coroutine. --- asyncio/streams.py | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index d18db77b..c77eb606 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -141,15 +141,14 @@ class FlowControlMixin(protocols.Protocol): resume_reading() and connection_lost(). If the subclass overrides these it must call the super methods. - StreamWriter.drain() must check for error conditions and then call - _make_drain_waiter(), which will return either () or a Future - depending on the paused state. + StreamWriter.drain() must wait for _drain_helper() coroutine. """ def __init__(self, loop=None): self._loop = loop # May be None; we may never need it. self._paused = False self._drain_waiter = None + self._connection_lost = False def pause_writing(self): assert not self._paused @@ -170,6 +169,7 @@ def resume_writing(self): waiter.set_result(None) def connection_lost(self, exc): + self._connection_lost = True # Wake up the writer if currently paused. if not self._paused: return @@ -184,14 +184,17 @@ def connection_lost(self, exc): else: waiter.set_exception(exc) - def _make_drain_waiter(self): + @coroutine + def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError('Connection lost') if not self._paused: - return () + return waiter = self._drain_waiter assert waiter is None or waiter.cancelled() waiter = futures.Future(loop=self._loop) self._drain_waiter = waiter - return waiter + yield from waiter class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): @@ -247,6 +250,8 @@ class StreamWriter: def __init__(self, transport, protocol, reader, loop): self._transport = transport self._protocol = protocol + # drain() expects that the reader has a exception() method + assert reader is None or isinstance(reader, StreamReader) self._reader = reader self._loop = loop @@ -278,26 +283,20 @@ def close(self): def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) + @coroutine def drain(self): - """This method has an unusual return value. + """Flush the write buffer. The intended use is to write w.write(data) yield from w.drain() - - When there's nothing to wait for, drain() returns (), and the - yield-from continues immediately. When the transport buffer - is full (the protocol is paused), drain() creates and returns - a Future and the yield-from will block until that Future is - completed, which will happen when the buffer is (partially) - drained and the protocol is resumed. """ - if self._reader is not None and self._reader._exception is not None: - raise self._reader._exception - if self._transport._conn_lost: # Uses private variable. - raise ConnectionResetError('Connection lost') - return self._protocol._make_drain_waiter() + if self._reader is not None: + exc = self._reader.exception() + if exc is not None: + raise exc + yield from self._protocol._drain_helper() class StreamReader: From 921091cb89d94cdb80272f55a0d71a00a4ae5522 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 22 Jul 2014 17:11:58 +0200 Subject: [PATCH 1119/1502] signal.set_wakeup_fd() can now raise an OSError on Python 3.5 --- asyncio/unix_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 73a85c11..5020cc5d 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -74,7 +74,7 @@ def add_signal_handler(self, sig, callback, *args): # event loop running in another thread cannot add a signal # handler. signal.set_wakeup_fd(self._csock.fileno()) - except ValueError as exc: + except (ValueError, OSError) as exc: raise RuntimeError(str(exc)) handle = events.Handle(callback, args, self) @@ -93,7 +93,7 @@ def add_signal_handler(self, sig, callback, *args): if not self._signal_handlers: try: signal.set_wakeup_fd(-1) - except ValueError as nexc: + except (ValueError, OSError) as nexc: logger.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: @@ -138,7 +138,7 @@ def remove_signal_handler(self, sig): if not self._signal_handlers: try: signal.set_wakeup_fd(-1) - except ValueError as exc: + except (ValueError, OSError) as exc: logger.info('set_wakeup_fd(-1) failed: %s', exc) return True From dc6e790f78f95af49d3a56d9cf916194cd236696 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 23 Jul 2014 18:16:37 +0200 Subject: [PATCH 1120/1502] Tulip issue #194: Don't use sys.getrefcount() in unit tests --- tests/test_selector_events.py | 4 ++-- tests/test_unix_events.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index c0f388d6..bd6c2f26 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -672,6 +672,8 @@ def test_fatal_error(self, m_exc): def test_connection_lost(self): exc = OSError() tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) tr._call_connection_lost(exc) self.protocol.connection_lost.assert_called_with(exc) @@ -679,8 +681,6 @@ def test_connection_lost(self): self.assertIsNone(tr._sock) self.assertIsNone(tr._protocol) - self.assertEqual(2, sys.getrefcount(self.protocol), - pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index d355defb..099d4d51 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -435,6 +435,8 @@ def test__close(self, m_read): def test__call_connection_lost(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) err = None tr._call_connection_lost(err) @@ -442,13 +444,13 @@ def test__call_connection_lost(self): self.pipe.close.assert_called_with() self.assertIsNone(tr._protocol) - self.assertEqual(2, sys.getrefcount(self.protocol), - pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) def test__call_connection_lost_with_err(self): tr = unix_events._UnixReadPipeTransport( self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) err = OSError() tr._call_connection_lost(err) @@ -456,9 +458,6 @@ def test__call_connection_lost_with_err(self): self.pipe.close.assert_called_with() self.assertIsNone(tr._protocol) - - self.assertEqual(2, sys.getrefcount(self.protocol), - pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) @@ -717,6 +716,8 @@ def test_abort(self, m_write): def test__call_connection_lost(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) err = None tr._call_connection_lost(err) @@ -724,13 +725,13 @@ def test__call_connection_lost(self): self.pipe.close.assert_called_with() self.assertIsNone(tr._protocol) - self.assertEqual(2, sys.getrefcount(self.protocol), - pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) def test__call_connection_lost_with_err(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) err = OSError() tr._call_connection_lost(err) @@ -738,8 +739,6 @@ def test__call_connection_lost_with_err(self): self.pipe.close.assert_called_with() self.assertIsNone(tr._protocol) - self.assertEqual(2, sys.getrefcount(self.protocol), - pprint.pformat(gc.get_referrers(self.protocol))) self.assertIsNone(tr._loop) def test_close(self): From e52b685fb9f535f3e923a80442bf241f3515a4a0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 24 Jul 2014 11:32:59 +0200 Subject: [PATCH 1121/1502] Python issue 20055: Fix BaseEventLoop.stop() docstring, incomplete sentence. Patch written by Saimadhav Heblikar. --- asyncio/base_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 0aeaae42..d0a337bd 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -270,9 +270,9 @@ def run_until_complete(self, future): def stop(self): """Stop running the event loop. - Every callback scheduled before stop() is called will run. - Callback scheduled after stop() is called won't. However, - those callbacks will run if run_*() is called again later. + Every callback scheduled before stop() is called will run. Callbacks + scheduled after stop() is called will not run. However, those callbacks + will run if run_forever is called again later. """ self.call_soon(_raise_stop_error) From 5be3bb8d22886b67d24555c54996662f60f3d63b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 24 Jul 2014 12:03:42 +0200 Subject: [PATCH 1122/1502] tests: relax timings for slow buildbots --- tests/test_windows_events.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 4ab56e6c..689deb47 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -94,14 +94,14 @@ def test_wait_for_handle(self): event = _overlapped.CreateEvent(None, True, False, None) self.addCleanup(_winapi.CloseHandle, event) - # Wait for unset event with 0.2s timeout; + # Wait for unset event with 0.5s timeout; # result should be False at timeout - f = self.loop._proactor.wait_for_handle(event, 0.2) + f = self.loop._proactor.wait_for_handle(event, 0.5) start = self.loop.time() self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertFalse(f.result()) - self.assertTrue(0.18 < elapsed < 0.9, elapsed) + self.assertTrue(0.48 < elapsed < 0.9, elapsed) _overlapped.SetEvent(event) @@ -112,7 +112,7 @@ def test_wait_for_handle(self): self.loop.run_until_complete(f) elapsed = self.loop.time() - start self.assertTrue(f.result()) - self.assertTrue(0 <= elapsed < 0.1, elapsed) + self.assertTrue(0 <= elapsed < 0.3, elapsed) _overlapped.ResetEvent(event) From 60b02fb0b1f9f5eeb1d5e921a5b8d8639ba76914 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:37:32 +0200 Subject: [PATCH 1123/1502] _OverlappedFuture truncates the source traceback to hide the call to the parent constructor (useless in debug). --- asyncio/windows_events.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 9d86c96b..fe2a9d60 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -38,6 +38,8 @@ class _OverlappedFuture(futures.Future): def __init__(self, ov, *, loop=None): super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] self.ov = ov def __repr__(self): From d9ea1117f4c0d2f530ad5f0c1576fe60455699d8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:38:06 +0200 Subject: [PATCH 1124/1502] Add the address of the overlapped object in repr(_OverlappedFuture) --- asyncio/windows_events.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index fe2a9d60..2afc94d8 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -44,10 +44,8 @@ def __init__(self, ov, *, loop=None): def __repr__(self): info = [self._state.lower()] - if self.ov.pending: - info.append('overlapped=pending') - else: - info.append('overlapped=completed') + state = 'pending' if self.ov.pending else 'completed' + info.append('overlapped=<%s, %#x>' % (state, self.ov.address)) if self._state == futures._FINISHED: info.append(self._format_result()) if self._callbacks: From 0611a11badb6baf65b789279c25ece02ce0aaf3b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:40:29 +0200 Subject: [PATCH 1125/1502] _OverlappedFuture.cancel() doesn't cancel the overlapped anymore if it is done: if it is already cancelled or completed. Log also an error if the cancellation failed. --- asyncio/windows_events.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 2afc94d8..42c330ad 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -53,10 +53,18 @@ def __repr__(self): return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) def cancel(self): - try: - self.ov.cancel() - except OSError: - pass + if not self.done(): + try: + self.ov.cancel() + except OSError as exc: + context = { + 'message': 'Cancelling an overlapped future failed', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) return super().cancel() From 19e60dfa6809673e5ac41f236ccf5ebeb4d7f018 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:42:55 +0200 Subject: [PATCH 1126/1502] Add a destructor to IocpProactor which closes it --- asyncio/windows_events.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 42c330ad..e5e8a751 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -484,6 +484,9 @@ def close(self): _winapi.CloseHandle(self._iocp) self._iocp = None + def __del__(self): + self.close() + class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): From f236ebf00a192f266a9a94c8c56570d966def004 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:43:19 +0200 Subject: [PATCH 1127/1502] Add a __repr__() method to IocpProactor --- asyncio/windows_events.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index e5e8a751..e23d2aaa 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -216,6 +216,11 @@ def __init__(self, concurrency=0xffffffff): self._registered = weakref.WeakSet() self._stopped_serving = weakref.WeakSet() + def __repr__(self): + return ('<%s overlapped#=%s result#=%s>' + % (self.__class__.__name__, len(self._cache), + len(self._results))) + def set_loop(self, loop): self._loop = loop From acbfeed912cbfb7f3e5fcb8fb76309d176735417 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:46:29 +0200 Subject: [PATCH 1128/1502] Tulip issue #195: Don't call UnregisterWait() twice if a _WaitHandleFuture is cancelled twice to fix a crash. --- asyncio/windows_events.py | 18 ++++++++++-------- tests/test_windows_events.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index e23d2aaa..2a6e44a7 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -75,13 +75,20 @@ def __init__(self, wait_handle, *, loop=None): super().__init__(loop=loop) self._wait_handle = wait_handle - def cancel(self): - super().cancel() + def _unregister(self): + if self._wait_handle is None: + return try: _overlapped.UnregisterWait(self._wait_handle) except OSError as e: if e.winerror != _overlapped.ERROR_IO_PENDING: raise + # ERROR_IO_PENDING is not an error, the wait was unregistered + self._wait_handle = None + + def cancel(self): + self._unregister() + super().cancel() class PipeServer(object): @@ -366,12 +373,7 @@ def wait_for_handle(self, handle, timeout=None): f = _WaitHandleFuture(wh, loop=self._loop) def finish_wait_for_handle(trans, key, ov): - if not f.cancelled(): - try: - _overlapped.UnregisterWait(wh) - except OSError as e: - if e.winerror != _overlapped.ERROR_IO_PENDING: - raise + f._unregister() # Note that this second wait means that we should only use # this with handles types where a successful wait has no # effect. So events or processes are all right, but locks diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 689deb47..c35c1c29 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -114,7 +114,12 @@ def test_wait_for_handle(self): self.assertTrue(f.result()) self.assertTrue(0 <= elapsed < 0.3, elapsed) - _overlapped.ResetEvent(event) + # Tulip issue #195: cancelling a done _WaitHandleFuture must not crash + f.cancel() + + def test_wait_for_handle_cancel(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) # Wait for unset event with a cancelled future; # CancelledError should be raised immediately @@ -126,6 +131,11 @@ def test_wait_for_handle(self): elapsed = self.loop.time() - start self.assertTrue(0 <= elapsed < 0.1, elapsed) + # Tulip issue #195: cancelling a _WaitHandleFuture twice must not crash + f = self.loop._proactor.wait_for_handle(event) + f.cancel() + f.cancel() + if __name__ == '__main__': unittest.main() From b0ac76ff11b714c8aa22d32b0cb41fac097bb6eb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:47:24 +0200 Subject: [PATCH 1129/1502] tests: rename "f" to "fut" --- tests/test_windows_events.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index c35c1c29..85d9669b 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -96,26 +96,26 @@ def test_wait_for_handle(self): # Wait for unset event with 0.5s timeout; # result should be False at timeout - f = self.loop._proactor.wait_for_handle(event, 0.5) + fut = self.loop._proactor.wait_for_handle(event, 0.5) start = self.loop.time() - self.loop.run_until_complete(f) + self.loop.run_until_complete(fut) elapsed = self.loop.time() - start - self.assertFalse(f.result()) + self.assertFalse(fut.result()) self.assertTrue(0.48 < elapsed < 0.9, elapsed) _overlapped.SetEvent(event) # Wait for for set event; # result should be True immediately - f = self.loop._proactor.wait_for_handle(event, 10) + fut = self.loop._proactor.wait_for_handle(event, 10) start = self.loop.time() - self.loop.run_until_complete(f) + self.loop.run_until_complete(fut) elapsed = self.loop.time() - start - self.assertTrue(f.result()) + self.assertTrue(fut.result()) self.assertTrue(0 <= elapsed < 0.3, elapsed) # Tulip issue #195: cancelling a done _WaitHandleFuture must not crash - f.cancel() + fut.cancel() def test_wait_for_handle_cancel(self): event = _overlapped.CreateEvent(None, True, False, None) @@ -123,18 +123,18 @@ def test_wait_for_handle_cancel(self): # Wait for unset event with a cancelled future; # CancelledError should be raised immediately - f = self.loop._proactor.wait_for_handle(event, 10) - f.cancel() + fut = self.loop._proactor.wait_for_handle(event, 10) + fut.cancel() start = self.loop.time() with self.assertRaises(asyncio.CancelledError): - self.loop.run_until_complete(f) + self.loop.run_until_complete(fut) elapsed = self.loop.time() - start self.assertTrue(0 <= elapsed < 0.1, elapsed) # Tulip issue #195: cancelling a _WaitHandleFuture twice must not crash - f = self.loop._proactor.wait_for_handle(event) - f.cancel() - f.cancel() + fut = self.loop._proactor.wait_for_handle(event) + fut.cancel() + fut.cancel() if __name__ == '__main__': From 7ee9e4a1c19034bd06fdbbe6ae39c874fd5c26dc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 00:50:31 +0200 Subject: [PATCH 1130/1502] IocpProactor.close(): cancel futures to cancel overlapped operations, instead of cancelling directly overlapped operations. Future objects may not call ov.cancel() if the future was cancelled or if the overlapped was already cancelled. The cancel() method of the future may also catch exceptions. Log also errors on cancellation. --- asyncio/windows_events.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 2a6e44a7..af290b7e 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -470,7 +470,7 @@ def _stop_serving(self, obj): def close(self): # Cancel remaining registered operations. - for address, (f, ov, obj, callback) in list(self._cache.items()): + for address, (fut, ov, obj, callback) in list(self._cache.items()): if obj is None: # The operation was started with connect_pipe() which # queues a task to Windows' thread pool. This cannot @@ -478,9 +478,17 @@ def close(self): del self._cache[address] else: try: - ov.cancel() - except OSError: - pass + fut.cancel() + except OSError as exc: + if self._loop is not None: + context = { + 'message': 'Cancelling a future failed', + 'exception': exc, + 'future': fut, + } + if fut._source_traceback: + context['source_traceback'] = fut._source_traceback + self._loop.call_exception_handler(context) while self._cache: if not self._poll(1): From 3346a197360371838167b3999100d004ace1244e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 10:28:24 +0200 Subject: [PATCH 1131/1502] Fix _WaitHandleFuture.cancel(): return the result of the parent cancel() method --- asyncio/windows_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index af290b7e..9a94e4f5 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -88,7 +88,7 @@ def _unregister(self): def cancel(self): self._unregister() - super().cancel() + return super().cancel() class PipeServer(object): From e67f2fd94d06ab6480cb970414a475a183f7b0c4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 12:46:29 +0200 Subject: [PATCH 1132/1502] _OverlappedFuture.cancel() now clears its reference to the overlapped object Make also the _OverlappedFuture.ov attribute private. --- asyncio/proactor_events.py | 8 ++------ asyncio/windows_events.py | 36 +++++++++++++++++++++--------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index c530687d..ab566b32 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -44,13 +44,9 @@ def __init__(self, loop, sock, protocol, waiter=None, def __repr__(self): info = [self.__class__.__name__, 'fd=%s' % self._sock.fileno()] if self._read_fut is not None: - ov = "pending" if self._read_fut.ov.pending else "completed" - info.append('read=%s' % ov) + info.append('read=%s' % self._read_fut) if self._write_fut is not None: - if self._write_fut.ov.pending: - info.append("write=pending=%s" % self._pending_write) - else: - info.append("write=completed") + info.append("write=%r" % self._write_fut) if self._buffer: bufsize = len(self._buffer) info.append('write_bufsize=%s' % bufsize) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 9a94e4f5..b6fc3252 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -40,31 +40,37 @@ def __init__(self, ov, *, loop=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] - self.ov = ov + self._ov = ov def __repr__(self): info = [self._state.lower()] - state = 'pending' if self.ov.pending else 'completed' - info.append('overlapped=<%s, %#x>' % (state, self.ov.address)) + if self._ov is not None: + state = 'pending' if self._ov.pending else 'completed' + info.append('overlapped=<%s, %#x>' % (state, self._ov.address)) if self._state == futures._FINISHED: info.append(self._format_result()) if self._callbacks: info.append(self._format_callbacks()) return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + def _cancel_overlapped(self): + if self._ov is None: + return + try: + self._ov.cancel() + except OSError as exc: + context = { + 'message': 'Cancelling an overlapped future failed', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self._ov = None + def cancel(self): - if not self.done(): - try: - self.ov.cancel() - except OSError as exc: - context = { - 'message': 'Cancelling an overlapped future failed', - 'exception': exc, - 'future': self, - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) + self._cancel_overlapped() return super().cancel() From cdef6ae9ab75490d5efca2a54b69528ec2ca0243 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 12:25:21 +0200 Subject: [PATCH 1133/1502] Check if _WaitHandleFuture completed before unregistering it in the callback. Add also _WaitHandleFuture._poll() and repr(_WaitHandleFuture). --- asyncio/windows_events.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index b6fc3252..23932a7e 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -77,10 +77,28 @@ def cancel(self): class _WaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" - def __init__(self, wait_handle, *, loop=None): + def __init__(self, handle, wait_handle, *, loop=None): super().__init__(loop=loop) + self._handle = handle self._wait_handle = wait_handle + def _poll(self): + # non-blocking wait: use a timeout of 0 millisecond + return (_winapi.WaitForSingleObject(self._handle, 0) == + _winapi.WAIT_OBJECT_0) + + def __repr__(self): + info = [self._state.lower()] + if self._wait_handle: + state = 'pending' if self._poll() else 'completed' + info.append('wait_handle=<%s, %#x>' % (state, self._wait_handle)) + info.append('handle=<%#x>' % self._handle) + if self._state == futures._FINISHED: + info.append(self._format_result()) + if self._callbacks: + info.append(self._format_callbacks()) + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + def _unregister(self): if self._wait_handle is None: return @@ -376,18 +394,18 @@ def wait_for_handle(self, handle, timeout=None): ov = _overlapped.Overlapped(NULL) wh = _overlapped.RegisterWaitWithQueue( handle, self._iocp, ov.address, ms) - f = _WaitHandleFuture(wh, loop=self._loop) + f = _WaitHandleFuture(handle, wh, loop=self._loop) def finish_wait_for_handle(trans, key, ov): - f._unregister() # Note that this second wait means that we should only use # this with handles types where a successful wait has no # effect. So events or processes are all right, but locks # or semaphores are not. Also note if the handle is # signalled and then quickly reset, then we may return # False even though we have not timed out. - return (_winapi.WaitForSingleObject(handle, 0) == - _winapi.WAIT_OBJECT_0) + done = f._poll() + f._unregister() + return done self._cache[ov.address] = (f, ov, None, finish_wait_for_handle) return f From 1f8d0d50980cff135ba0b579ba453a3b487e59cb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 12:31:35 +0200 Subject: [PATCH 1134/1502] _WaitHandleFuture now unregisters its wait handler if WaitForSingleObject() raises an exception. --- asyncio/windows_events.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 23932a7e..0842d8ad 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -403,9 +403,10 @@ def finish_wait_for_handle(trans, key, ov): # or semaphores are not. Also note if the handle is # signalled and then quickly reset, then we may return # False even though we have not timed out. - done = f._poll() - f._unregister() - return done + try: + return f._poll() + finally: + f._unregister() self._cache[ov.address] = (f, ov, None, finish_wait_for_handle) return f From 1bc118f39f16b6099bc22e24acf93b9fa5fd9a65 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 12:33:53 +0200 Subject: [PATCH 1135/1502] _OverlappedFuture.set_exception() now cancels the overlapped operation. --- asyncio/windows_events.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 0842d8ad..375003c4 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -73,6 +73,10 @@ def cancel(self): self._cancel_overlapped() return super().cancel() + def set_exception(self, exception): + super().set_exception(exception) + self._cancel_overlapped() + class _WaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" From cd0c4b3c6bb6c6070b65f3dc330680db528fd0e2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 14:04:35 +0200 Subject: [PATCH 1136/1502] test_subprocess: relax timings for slow builbots --- tests/test_subprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index a4e9df2f..b5b1012e 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -42,7 +42,7 @@ def run(data): return (exitcode, data) task = run(b'some data') - task = asyncio.wait_for(task, 10.0, loop=self.loop) + task = asyncio.wait_for(task, 60.0, loop=self.loop) exitcode, stdout = self.loop.run_until_complete(task) self.assertEqual(exitcode, 0) self.assertEqual(stdout, b'some data') @@ -61,7 +61,7 @@ def run(data): return proc.returncode, stdout task = run(b'some data') - task = asyncio.wait_for(task, 10.0, loop=self.loop) + task = asyncio.wait_for(task, 60.0, loop=self.loop) exitcode, stdout = self.loop.run_until_complete(task) self.assertEqual(exitcode, 0) self.assertEqual(stdout, b'some data') From 24a1f6152862b2ab4c671b7ba6325e3e99c35b1d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 18:33:33 +0200 Subject: [PATCH 1137/1502] Tulip issue #196: IocpProactor._poll() clears the reference to the overlapped operation when the operation is done. It would be better to clear the reference in a new _OverlappedFuture.set_result() method, but it cannot be done yet because of a weird bug. --- asyncio/windows_events.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 375003c4..9bca77bb 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -446,6 +446,11 @@ def _register(self, ov, obj, callback, wait_for_post=False): f.set_exception(e) else: f.set_result(value) + # FIXME, tulip issue #196: add _OverlappedFuture.set_result() + # method to clear the refrence, don't do it here (f may + # by a _WaitHandleFuture). Problem: clearing the reference + # in _register() if ov.pedding is False leads to weird bugs. + f._ov = None return f def _get_accept_socket(self, family): From b517ceed9c3fdf9c08ec94a641223ec1d5c9e4f6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 18:35:21 +0200 Subject: [PATCH 1138/1502] Oops, fix previous commit: I wanted to do exactly the reverse: only clear the reference in _poll(), not in _register(). --- asyncio/windows_events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 9bca77bb..65ecf340 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -446,11 +446,6 @@ def _register(self, ov, obj, callback, wait_for_post=False): f.set_exception(e) else: f.set_result(value) - # FIXME, tulip issue #196: add _OverlappedFuture.set_result() - # method to clear the refrence, don't do it here (f may - # by a _WaitHandleFuture). Problem: clearing the reference - # in _register() if ov.pedding is False leads to weird bugs. - f._ov = None return f def _get_accept_socket(self, family): @@ -494,6 +489,11 @@ def _poll(self, timeout=None): else: f.set_result(value) self._results.append(f) + # FIXME, tulip issue #196: add _OverlappedFuture.set_result() + # method to clear the refrence, don't do it here (f may + # by a _WaitHandleFuture). Problem: clearing the reference + # in _register() if ov.pedding is False leads to weird bugs. + f._ov = None ms = 0 def _stop_serving(self, obj): From c64ec6620c3528a8e9ff924bfe5f4e79b462d32a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 01:13:55 +0200 Subject: [PATCH 1139/1502] Fix runtest.py to be able to log at level DEBUG --- runtests.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/runtests.py b/runtests.py index 4d06f695..f23aaa90 100644 --- a/runtests.py +++ b/runtests.py @@ -248,18 +248,20 @@ def runtests(): ) cov.start() - finder = TestsFinder(args.testsdir, includes, excludes) logger = logging.getLogger() if v == 0: - logger.setLevel(logging.CRITICAL) + level = logging.CRITICAL elif v == 1: - logger.setLevel(logging.ERROR) + level = logging.ERROR elif v == 2: - logger.setLevel(logging.WARNING) + level = logging.WARNING elif v == 3: - logger.setLevel(logging.INFO) + level = logging.INFO elif v >= 4: - logger.setLevel(logging.DEBUG) + level = logging.DEBUG + logging.basicConfig(level=level) + + finder = TestsFinder(args.testsdir, includes, excludes) if catchbreak: installHandler() try: From 555a6a8c3558a4a814bf2b44acea001af3458581 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 25 Jul 2014 01:14:26 +0200 Subject: [PATCH 1140/1502] BaseSelectorEventLoop._write_to_self() now logs errors in debug mode --- asyncio/selector_events.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index cd1a75aa..eca48b8e 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -120,7 +120,10 @@ def _write_to_self(self): try: csock.send(b'\0') except OSError: - pass + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) def _start_serving(self, protocol_factory, sock, sslcontext=None, server=None): From 7537ea596c258152086060314848b056dca7937e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 26 Jul 2014 00:31:52 +0200 Subject: [PATCH 1141/1502] Tulip issue #196: _OverlappedFuture.set_result() now clears its reference to the overlapped object. IocpProactor._poll() now also ignores false alarms: GetQueuedCompletionStatus() returns the overlapped but it is still pending. --- asyncio/windows_events.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 65ecf340..3aa142c4 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -77,6 +77,10 @@ def set_exception(self, exception): super().set_exception(exception) self._cancel_overlapped() + def set_result(self, result): + super().set_result(result) + self._ov = None + class _WaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" @@ -478,6 +482,13 @@ def _poll(self, timeout=None): _winapi.CloseHandle(key) ms = 0 continue + + if ov.pending: + # False alarm: the overlapped operation is not completed. + # FIXME: why do we get false alarms? + self._cache[address] = (f, ov, obj, callback) + continue + if obj in self._stopped_serving: f.cancel() elif not f.cancelled(): @@ -489,11 +500,6 @@ def _poll(self, timeout=None): else: f.set_result(value) self._results.append(f) - # FIXME, tulip issue #196: add _OverlappedFuture.set_result() - # method to clear the refrence, don't do it here (f may - # by a _WaitHandleFuture). Problem: clearing the reference - # in _register() if ov.pedding is False leads to weird bugs. - f._ov = None ms = 0 def _stop_serving(self, obj): From d634f5b6a538713558878c2866f9556302df7205 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 26 Jul 2014 17:51:20 +0300 Subject: [PATCH 1142/1502] Accept optional lock object in Condition ctor (#198) --- asyncio/locks.py | 9 ++++++--- tests/test_locks.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 8d9e3b4d..574e3618 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -255,14 +255,17 @@ class Condition: A new Lock object is created and used as the underlying lock. """ - def __init__(self, *, loop=None): + def __init__(self, lock=None, *, loop=None): if loop is not None: self._loop = loop else: self._loop = events.get_event_loop() - # Lock as an attribute as in threading.Condition. - lock = Lock(loop=self._loop) + if lock is None: + lock = Lock(loop=self._loop) + elif lock._loop is not self._loop: + raise ValueError("loop argument must agree with lock") + self._lock = lock # Export the lock's locked(), acquire() and release() methods. self.locked = lock.locked diff --git a/tests/test_locks.py b/tests/test_locks.py index 8ad14863..c4e74e33 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -656,6 +656,18 @@ def test_context_manager_no_yield(self): self.assertFalse(cond.locked()) + def test_explicit_lock(self): + lock = asyncio.Lock(loop=self.loop) + cond = asyncio.Condition(lock, loop=self.loop) + + self.assertIs(lock._loop, cond._loop) + + def test_ambiguous_loops(self): + loop = self.new_test_loop() + lock = asyncio.Lock(loop=self.loop) + with self.assertRaises(ValueError): + asyncio.Condition(lock, loop=loop) + class SemaphoreTests(test_utils.TestCase): From fe1cffda9c574fd177cd7af5b93bb39e5f7ca3a3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 26 Jul 2014 19:48:23 +0200 Subject: [PATCH 1143/1502] Tulip issue 196: ProactorIocp._register() now registers the overlapped in the _cache dictionary, even if we already got the result. We need to keep a reference to the overlapped object, otherwise the memory may be reused and GetQueuedCompletionStatus() may use random bytes and behaves badly. There is still a hack for ConnectNamedPipe(): the overlapped object is not register into _cache if the overlapped object completed directly. Log also an error in debug mode in ProactorIocp._loop() if we get an unexpected event. Add a protection in ProactorIocp.close() to avoid blocking, even if it should not happen. I still don't understand exactly why some the completion of some overlapped objects are not notified. --- asyncio/windows_events.py | 53 ++++++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 3aa142c4..41be8da2 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -369,7 +369,10 @@ def finish_accept_pipe(trans, key, ov): ov.getresult() return pipe - return self._register(ov, pipe, finish_accept_pipe) + # FIXME: Tulip issue 196: why to we neeed register=False? + # See also the comment in the _register() method + return self._register(ov, pipe, finish_accept_pipe, + register=False) def connect_pipe(self, address): ov = _overlapped.Overlapped(NULL) @@ -429,17 +432,13 @@ def _register_with_iocp(self, obj): # to avoid sending notifications to completion port of ops # that succeed immediately. - def _register(self, ov, obj, callback, wait_for_post=False): + def _register(self, ov, obj, callback, + wait_for_post=False, register=True): # Return a future which will be set with the result of the # operation when it completes. The future's value is actually # the value returned by callback(). f = _OverlappedFuture(ov, loop=self._loop) - if ov.pending or wait_for_post: - # Register the overlapped operation for later. Note that - # we only store obj to prevent it from being garbage - # collected too early. - self._cache[ov.address] = (f, ov, obj, callback) - else: + if not ov.pending and not wait_for_post: # The operation has completed, so no need to postpone the # work. We cannot take this short cut if we need the # NumberOfBytes, CompletionKey values returned by @@ -450,6 +449,23 @@ def _register(self, ov, obj, callback, wait_for_post=False): f.set_exception(e) else: f.set_result(value) + # Even if GetOverlappedResult() was called, we have to wait for the + # notification of the completion in GetQueuedCompletionStatus(). + # Register the overlapped operation to keep a reference to the + # OVERLAPPED object, otherwise the memory is freed and Windows may + # read uninitialized memory. + # + # For an unknown reason, ConnectNamedPipe() behaves differently: + # the completion is not notified by GetOverlappedResult() if we + # already called GetOverlappedResult(). For this specific case, we + # don't expect notification (register is set to False). + else: + register = True + if register: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) return f def _get_accept_socket(self, family): @@ -476,6 +492,14 @@ def _poll(self, timeout=None): try: f, ov, obj, callback = self._cache.pop(address) except KeyError: + if self._loop.get_debug(): + self._loop.call_exception_handler({ + 'message': ('GetQueuedCompletionStatus() returned an ' + 'unexpected event'), + 'status': ('err=%s transferred=%s key=%#x address=%#x' + % (err, transferred, key, address)), + }) + # key is either zero, or it is used to return a pipe # handle which should be closed to avoid a leak. if key not in (0, _overlapped.INVALID_HANDLE_VALUE): @@ -483,15 +507,11 @@ def _poll(self, timeout=None): ms = 0 continue - if ov.pending: - # False alarm: the overlapped operation is not completed. - # FIXME: why do we get false alarms? - self._cache[address] = (f, ov, obj, callback) - continue - if obj in self._stopped_serving: f.cancel() - elif not f.cancelled(): + # Don't call the callback if _register() already read the result or + # if the overlapped has been cancelled + elif not f.done(): try: value = callback(transferred, key, ov) except OSError as e: @@ -516,6 +536,9 @@ def close(self): # queues a task to Windows' thread pool. This cannot # be cancelled, so just forget it. del self._cache[address] + # FIXME: Tulip issue 196: remove this case, it should not happen + elif fut.done() and not fut.cancelled(): + del self._cache[address] else: try: fut.cancel() From 89775e3d6ff5bc47e743958eac95c97180f57a9a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 10:07:03 +0200 Subject: [PATCH 1144/1502] Cleanup ProactorIocp._poll(): set the timeout to 0 after the first call to GetQueuedCompletionStatus() --- asyncio/windows_events.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 41be8da2..6d9feab2 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -484,10 +484,13 @@ def _poll(self, timeout=None): ms = math.ceil(timeout * 1e3) if ms >= INFINITE: raise ValueError("timeout too big") + while True: status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) if status is None: return + ms = 0 + err, transferred, key, address = status try: f, ov, obj, callback = self._cache.pop(address) @@ -504,7 +507,6 @@ def _poll(self, timeout=None): # handle which should be closed to avoid a leak. if key not in (0, _overlapped.INVALID_HANDLE_VALUE): _winapi.CloseHandle(key) - ms = 0 continue if obj in self._stopped_serving: @@ -520,7 +522,6 @@ def _poll(self, timeout=None): else: f.set_result(value) self._results.append(f) - ms = 0 def _stop_serving(self, obj): # obj is a socket or pipe handle. It will be closed in From b11c0a47dd0d1152a59f7d4c65d9fa62bc0c7326 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 10:14:06 +0200 Subject: [PATCH 1145/1502] test_locks: close the temporary event loop and check the condition lock --- tests/test_locks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_locks.py b/tests/test_locks.py index c4e74e33..dda4577a 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -660,10 +660,13 @@ def test_explicit_lock(self): lock = asyncio.Lock(loop=self.loop) cond = asyncio.Condition(lock, loop=self.loop) - self.assertIs(lock._loop, cond._loop) + self.assertIs(cond._lock, lock) + self.assertIs(cond._loop, lock._loop) def test_ambiguous_loops(self): loop = self.new_test_loop() + self.addCleanup(loop.close) + lock = asyncio.Lock(loop=self.loop) with self.assertRaises(ValueError): asyncio.Condition(lock, loop=loop) From f778f1d295874d1b7b3044fdab37309aefa54b51 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 12:49:33 +0200 Subject: [PATCH 1146/1502] Enhance representation of Future and Future subclasses * Add "created at filename:lineno" in the representation * Add Future._repr_info() method which can be more easily overriden than Future.__repr__(). It should now be more easy to enhance Future representation without having to modify each subclass. For example, _OverlappedFuture and _WaitHandleFuture get the new "created at" information. * Use reprlib to format Future result, and function arguments when formatting a callback, to limit the length of the representation. --- asyncio/events.py | 15 ++++++++++----- asyncio/futures.py | 26 ++++++++++++++++---------- asyncio/tasks.py | 27 ++++++++------------------- asyncio/windows_events.py | 27 ++++++++++----------------- tests/test_futures.py | 11 ++++++++++- tests/test_tasks.py | 6 ++++++ 6 files changed, 60 insertions(+), 52 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index bddd7e36..3c7a36d0 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -10,11 +10,12 @@ import functools import inspect -import subprocess -import traceback -import threading +import reprlib import socket +import subprocess import sys +import threading +import traceback _PY34 = sys.version_info >= (3, 4) @@ -36,8 +37,12 @@ def _get_function_source(func): def _format_args(args): - # function formatting ('hello',) as ('hello') - args_repr = repr(args) + """Format function arguments. + + Special case for a single parameter: ('hello',) is formatted as ('hello'). + """ + # use reprlib to limit the length of the output + args_repr = reprlib.repr(args) if len(args) == 1 and args_repr.endswith(',)'): args_repr = args_repr[:-2] + ')' return args_repr diff --git a/asyncio/futures.py b/asyncio/futures.py index 022fef76..7998fbbc 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -7,6 +7,7 @@ import concurrent.futures._base import logging +import reprlib import sys import traceback @@ -175,20 +176,25 @@ def format_cb(callback): format_cb(cb[-1])) return 'cb=[%s]' % cb - def _format_result(self): - if self._state != _FINISHED: - return None - elif self._exception is not None: - return 'exception={!r}'.format(self._exception) - else: - return 'result={!r}'.format(self._result) - - def __repr__(self): + def _repr_info(self): info = [self._state.lower()] if self._state == _FINISHED: - info.append(self._format_result()) + if self._exception is not None: + info.append('exception={!r}'.format(self._exception)) + else: + # use reprlib to limit the length of the output, especially + # for very long strings + result = reprlib.repr(self._result) + info.append('result={}'.format(result)) if self._callbacks: info.append(self._format_callbacks()) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + info = self._repr_info() return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) # On Python 3.3 or older, objects with a destructor part of a reference diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 07952c9a..92070162 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -92,30 +92,19 @@ def __del__(self): self._loop.call_exception_handler(context) futures.Future.__del__(self) - def __repr__(self): - info = [] + def _repr_info(self): + info = super()._repr_info() + if self._must_cancel: - info.append('cancelling') - else: - info.append(self._state.lower()) + # replace status + info[0] = 'cancelling' coro = coroutines._format_coroutine(self._coro) - info.append('coro=<%s>' % coro) - - if self._source_traceback: - frame = self._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) - - if self._state == futures._FINISHED: - info.append(self._format_result()) - - if self._callbacks: - info.append(self._format_callbacks()) + info.insert(1, 'coro=<%s>' % coro) if self._fut_waiter is not None: - info.append('wait_for=%r' % self._fut_waiter) - - return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + info.insert(2, 'wait_for=%r' % self._fut_waiter) + return info def get_stack(self, *, limit=None): """Return the list of stack frames for this task's coroutine. diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 6d9feab2..1db255ee 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -42,16 +42,12 @@ def __init__(self, ov, *, loop=None): del self._source_traceback[-1] self._ov = ov - def __repr__(self): - info = [self._state.lower()] + def _repr_info(self): + info = super()._repr_info() if self._ov is not None: state = 'pending' if self._ov.pending else 'completed' - info.append('overlapped=<%s, %#x>' % (state, self._ov.address)) - if self._state == futures._FINISHED: - info.append(self._format_result()) - if self._callbacks: - info.append(self._format_callbacks()) - return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + return info def _cancel_overlapped(self): if self._ov is None: @@ -95,17 +91,14 @@ def _poll(self): return (_winapi.WaitForSingleObject(self._handle, 0) == _winapi.WAIT_OBJECT_0) - def __repr__(self): - info = [self._state.lower()] + def _repr_info(self): + info = super()._repr_info() + info.insert(1, 'handle=%#x' % self._handle) if self._wait_handle: state = 'pending' if self._poll() else 'completed' - info.append('wait_handle=<%s, %#x>' % (state, self._wait_handle)) - info.append('handle=<%#x>' % self._handle) - if self._state == futures._FINISHED: - info.append(self._format_result()) - if self._callbacks: - info.append(self._format_callbacks()) - return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + info.insert(1, 'wait_handle=<%s, %#x>' + % (state, self._wait_handle)) + return info def _unregister(self): if self._wait_handle is None: diff --git a/tests/test_futures.py b/tests/test_futures.py index 50e9414a..3029a9c3 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -105,6 +105,15 @@ def fixture(): self.assertEqual(next(g), ('C', 42)) # yield 'C', y. def test_future_repr(self): + self.loop.set_debug(True) + f_pending_debug = asyncio.Future(loop=self.loop) + frame = f_pending_debug._source_traceback[-1] + self.assertEqual(repr(f_pending_debug), + '' + % (frame[0], frame[1])) + f_pending_debug.cancel() + + self.loop.set_debug(False) f_pending = asyncio.Future(loop=self.loop) self.assertEqual(repr(f_pending), '') f_pending.cancel() @@ -324,7 +333,7 @@ def memory_error(): if sys.version_info >= (3, 4): frame = source_traceback[-1] regex = (r'^Future exception was never retrieved\n' - r'future: \n' + r'future: \n' r'source_traceback: Object created at \(most recent call last\):\n' r' File' r'.*\n' diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 7b93a0e2..95cba542 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -132,6 +132,8 @@ def test_async_neither(self): asyncio.async('ok') def test_task_repr(self): + self.loop.set_debug(False) + @asyncio.coroutine def notmuch(): yield from [] @@ -189,6 +191,8 @@ def notmuch(): "" % coro) def test_task_repr_coro_decorator(self): + self.loop.set_debug(False) + @asyncio.coroutine def notmuch(): # notmuch() function doesn't use yield from: it will be wrapped by @@ -252,6 +256,8 @@ def notmuch(): self.loop.run_until_complete(t) def test_task_repr_wait_for(self): + self.loop.set_debug(False) + @asyncio.coroutine def wait_for(fut): return (yield from fut) From f439c555d42849cd390cc84c17c95be46ec75b0c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 11:36:00 +0200 Subject: [PATCH 1147/1502] _WaitHandleFuture and _OverlappedFuture: hide frames of internal calls in the source traceback. --- asyncio/windows_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 1db255ee..a5d9b219 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -83,6 +83,8 @@ class _WaitHandleFuture(futures.Future): def __init__(self, handle, wait_handle, *, loop=None): super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] self._handle = handle self._wait_handle = wait_handle @@ -399,6 +401,8 @@ def wait_for_handle(self, handle, timeout=None): wh = _overlapped.RegisterWaitWithQueue( handle, self._iocp, ov.address, ms) f = _WaitHandleFuture(handle, wh, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] def finish_wait_for_handle(trans, key, ov): # Note that this second wait means that we should only use @@ -431,6 +435,8 @@ def _register(self, ov, obj, callback, # operation when it completes. The future's value is actually # the value returned by callback(). f = _OverlappedFuture(ov, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] if not ov.pending and not wait_for_post: # The operation has completed, so no need to postpone the # work. We cannot take this short cut if we need the From 65f26cbf04bed90ba07bd01a6b0dcc11726558fa Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 11:37:24 +0200 Subject: [PATCH 1148/1502] Fix repr(_WaitHandleFuture) --- asyncio/windows_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index a5d9b219..03146ca9 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -97,7 +97,7 @@ def _repr_info(self): info = super()._repr_info() info.insert(1, 'handle=%#x' % self._handle) if self._wait_handle: - state = 'pending' if self._poll() else 'completed' + state = 'signaled' if self._poll() else 'waiting' info.insert(1, 'wait_handle=<%s, %#x>' % (state, self._wait_handle)) return info From 994506cbbd8baf99e396c43bd8271aa5a0a08805 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 11:39:35 +0200 Subject: [PATCH 1149/1502] Optimize IocpProactor.wait_for_handle() gets the result if the wait is signaled immediatly. --- asyncio/windows_events.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 03146ca9..40991ff0 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -102,7 +102,7 @@ def _repr_info(self): % (state, self._wait_handle)) return info - def _unregister(self): + def _unregister_wait(self): if self._wait_handle is None: return try: @@ -114,8 +114,17 @@ def _unregister(self): self._wait_handle = None def cancel(self): - self._unregister() - return super().cancel() + result = super().cancel() + self._unregister_wait() + return result + + def set_exception(self, exception): + super().set_exception(exception) + self._unregister_wait() + + def set_result(self, result): + super().set_result(result) + self._unregister_wait() class PipeServer(object): @@ -411,10 +420,15 @@ def finish_wait_for_handle(trans, key, ov): # or semaphores are not. Also note if the handle is # signalled and then quickly reset, then we may return # False even though we have not timed out. + return f._poll() + + if f._poll(): try: - return f._poll() - finally: - f._unregister() + result = f._poll() + except OSError as exc: + f.set_exception(exc) + else: + f.set_result(result) self._cache[ov.address] = (f, ov, None, finish_wait_for_handle) return f From 2d43b85112fb3ee7cb8863669981792016817825 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 12:50:37 +0200 Subject: [PATCH 1150/1502] _WaitHandleFuture.cancel() now notify IocpProactor through the overlapped object that the wait was cancelled. --- asyncio/windows_events.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 40991ff0..ec427d5c 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -81,10 +81,14 @@ def set_result(self, result): class _WaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" - def __init__(self, handle, wait_handle, *, loop=None): + def __init__(self, iocp, ov, handle, wait_handle, *, loop=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] + # iocp and ov are only used by cancel() to notify IocpProactor + # that the wait was cancelled + self._iocp = iocp + self._ov = ov self._handle = handle self._wait_handle = wait_handle @@ -112,9 +116,15 @@ def _unregister_wait(self): raise # ERROR_IO_PENDING is not an error, the wait was unregistered self._wait_handle = None + self._iocp = None + self._ov = None def cancel(self): result = super().cancel() + if self._ov is not None: + # signal the cancellation to the overlapped object + _overlapped.PostQueuedCompletionStatus(self._iocp, True, + 0, self._ov.address) self._unregister_wait() return result @@ -409,7 +419,7 @@ def wait_for_handle(self, handle, timeout=None): ov = _overlapped.Overlapped(NULL) wh = _overlapped.RegisterWaitWithQueue( handle, self._iocp, ov.address, ms) - f = _WaitHandleFuture(handle, wh, loop=self._loop) + f = _WaitHandleFuture(self._iocp, ov, handle, wh, loop=self._loop) if f._source_traceback: del f._source_traceback[-1] @@ -430,7 +440,7 @@ def finish_wait_for_handle(trans, key, ov): else: f.set_result(result) - self._cache[ov.address] = (f, ov, None, finish_wait_for_handle) + self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle) return f def _register_with_iocp(self, obj): From f7fcb293a9406cb7fe2ae60c72de797a00b0b2dc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 12:51:05 +0200 Subject: [PATCH 1151/1502] Remove workaround in test_futures, no more needed --- tests/test_futures.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_futures.py b/tests/test_futures.py index 3029a9c3..e5002bc8 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -308,12 +308,6 @@ def test_future_source_traceback(self): @mock.patch('asyncio.base_events.logger') def test_future_exception_never_retrieved(self, m_log): - # FIXME: Python issue #21163, other tests may "leak" pending task which - # emit a warning when they are destroyed by the GC - support.gc_collect() - m_log.error.reset_mock() - # --- - self.loop.set_debug(True) def memory_error(): From 39e06a2279c0da1d642944a6f72f19b4a464c9de Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 22:45:43 +0200 Subject: [PATCH 1152/1502] Use the new os.set_blocking() function of Python 3.5 if available --- asyncio/unix_events.py | 12 ++++++++---- tests/test_unix_events.py | 12 ++++++------ 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 5020cc5d..8d3e25eb 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -259,10 +259,14 @@ def create_unix_server(self, protocol_factory, path=None, *, return server -def _set_nonblocking(fd): - flags = fcntl.fcntl(fd, fcntl.F_GETFL) - flags = flags | os.O_NONBLOCK - fcntl.fcntl(fd, fcntl.F_SETFL, flags) +if hasattr(os, 'set_blocking'): + def _set_nonblocking(fd): + os.set_blocking(fd, False) +else: + def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) class _UnixReadPipeTransport(transports.ReadTransport): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 099d4d51..e3975982 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -306,9 +306,9 @@ def setUp(self): self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - fcntl_patcher = mock.patch('fcntl.fcntl') - fcntl_patcher.start() - self.addCleanup(fcntl_patcher.stop) + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) fstat_patcher = mock.patch('os.fstat') m_fstat = fstat_patcher.start() @@ -469,9 +469,9 @@ def setUp(self): self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - fcntl_patcher = mock.patch('fcntl.fcntl') - fcntl_patcher.start() - self.addCleanup(fcntl_patcher.stop) + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) fstat_patcher = mock.patch('os.fstat') m_fstat = fstat_patcher.start() From 8d6d6c799af8f903540c740bb40f28e422ccb1de Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 29 Jul 2014 23:04:28 +0200 Subject: [PATCH 1153/1502] Python issue 22063: socket operations (socket,recv, sock_sendall, sock_connect, sock_accept) now raise an exception in debug mode if sockets are in blocking mode. --- asyncio/proactor_events.py | 8 ++++++++ asyncio/selector_events.py | 8 ++++++++ tests/test_events.py | 18 ++++++++++++++++++ 3 files changed, 34 insertions(+) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index ab566b32..751155bf 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -385,12 +385,18 @@ def close(self): self._selector = None def sock_recv(self, sock, n): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.recv(sock, n) def sock_sendall(self, sock, data): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.send(sock, data) def sock_connect(self, sock, address): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") try: base_events._check_resolved_address(sock, address) except ValueError as err: @@ -401,6 +407,8 @@ def sock_connect(self, sock, address): return self._proactor.connect(sock, address) def sock_accept(self, sock): + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") return self._proactor.accept(sock) def _socketpair(self): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index eca48b8e..6b7bdf01 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -256,6 +256,8 @@ def sock_recv(self, sock, n): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_recv(fut, False, sock, n) return fut @@ -292,6 +294,8 @@ def sock_sendall(self, sock, data): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) if data: self._sock_sendall(fut, False, sock, data) @@ -333,6 +337,8 @@ def sock_connect(self, sock, address): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) try: base_events._check_resolved_address(sock, address) @@ -374,6 +380,8 @@ def sock_accept(self, sock): This method is a coroutine. """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_accept(fut, False, sock) return fut diff --git a/tests/test_events.py b/tests/test_events.py index b0657495..0cff00ae 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -383,6 +383,24 @@ def writer(data): self.assertEqual(read, data) def _basetest_sock_client_ops(self, httpd, sock): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) + + # test in non-blocking mode sock.setblocking(False) self.loop.run_until_complete( self.loop.sock_connect(sock, httpd.address)) From afd9f08bb8d5db644edde0326b07ea34f0ec2de1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 14:15:49 +0200 Subject: [PATCH 1154/1502] PipeServer.close() now cancels the "accept pipe" future which cancels the overlapped operation. --- asyncio/windows_events.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index ec427d5c..66aeca85 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -145,6 +145,11 @@ class PipeServer(object): def __init__(self, address): self._address = address self._free_instances = weakref.WeakSet() + # initialize the pipe attribute before calling _server_pipe_handle() + # because this function can raise an exception and the destructor calls + # the close() method + self._pipe = None + self._accept_pipe_future = None self._pipe = self._server_pipe_handle(True) def _get_unconnected_pipe(self): @@ -174,6 +179,9 @@ def _server_pipe_handle(self, first): return pipe def close(self): + if self._accept_pipe_future is not None: + self._accept_pipe_future.cancel() + self._accept_pipe_future = None # Close all instances which have not been connected to by a client. if self._address is not None: for pipe in self._free_instances: @@ -216,7 +224,7 @@ def create_pipe_connection(self, protocol_factory, address): def start_serving_pipe(self, protocol_factory, address): server = PipeServer(address) - def loop(f=None): + def loop_accept_pipe(f=None): pipe = None try: if f: @@ -241,9 +249,10 @@ def loop(f=None): if pipe: pipe.close() else: - f.add_done_callback(loop) + server._accept_pipe_future = f + f.add_done_callback(loop_accept_pipe) - self.call_soon(loop) + self.call_soon(loop_accept_pipe) return [server] @coroutine From ca0f8cc293abd1e32c52472ba54a219f4f8a5570 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 14:32:14 +0200 Subject: [PATCH 1155/1502] Fix _SelectorTransport.__repr__() if the transport was closed --- asyncio/selector_events.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 6b7bdf01..d6e3364c 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -450,22 +450,24 @@ def __init__(self, loop, sock, protocol, extra, server=None): def __repr__(self): info = [self.__class__.__name__, 'fd=%s' % self._sock_fd] - polling = _test_selector_event(self._loop._selector, - self._sock_fd, selectors.EVENT_READ) - if polling: - info.append('read=polling') - else: - info.append('read=idle') + # test if the transport was closed + if self._loop is not None: + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_READ) + if polling: + info.append('read=polling') + else: + info.append('read=idle') - polling = _test_selector_event(self._loop._selector, - self._sock_fd, selectors.EVENT_WRITE) - if polling: - state = 'polling' - else: - state = 'idle' + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_WRITE) + if polling: + state = 'polling' + else: + state = 'idle' - bufsize = self.get_write_buffer_size() - info.append('write=<%s, bufsize=%s>' % (state, bufsize)) + bufsize = self.get_write_buffer_size() + info.append('write=<%s, bufsize=%s>' % (state, bufsize)) return '<%s>' % ' '.join(info) def abort(self): From e50222808f14d21968444158f27e5aabe7c4b3dd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 14:57:41 +0200 Subject: [PATCH 1156/1502] Fix debug log in BaseEventLoop.create_connection(): get the socket object from the transport because SSL transport closes the old socket and creates a new SSL socket object. Remove also the _SelectorSslTransport._rawsock attribute: it contained the closed socket (not very useful) and it was not used. --- asyncio/base_events.py | 3 +++ asyncio/selector_events.py | 1 - tests/test_base_events.py | 3 +++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index d0a337bd..f88bd01b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -578,6 +578,9 @@ def create_connection(self, protocol_factory, host=None, port=None, *, transport, protocol = yield from self._create_connection_transport( sock, protocol_factory, ssl, server_hostname) if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') logger.debug("%r connected to %s:%r: (%r, %r)", sock, host, port, transport, protocol) return transport, protocol diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d6e3364c..0434a701 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -691,7 +691,6 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, self._server_hostname = server_hostname self._waiter = waiter - self._rawsock = rawsock self._sslcontext = sslcontext self._paused = False diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 7bf07ed6..ca12101b 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -792,6 +792,9 @@ def mock_getaddrinfo(*args, **kwds): class _SelectorTransportMock: _sock = None + def get_extra_info(self, key): + return mock.Mock() + def close(self): self._sock.close() From dfcf393bd957a60106339d02cd9a2786192f7993 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 15:37:32 +0200 Subject: [PATCH 1157/1502] Python issue 22063: socket operations (sock_recv, sock_sendall, sock_connect, sock_accept) of the proactor event loop don't raise an exception in debug mode if the socket are in blocking mode. Overlapped operations also work on blocking sockets. --- asyncio/proactor_events.py | 8 -------- tests/test_events.py | 34 ++++++++++++++++++---------------- 2 files changed, 18 insertions(+), 24 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 751155bf..ab566b32 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -385,18 +385,12 @@ def close(self): self._selector = None def sock_recv(self, sock, n): - if self.get_debug() and sock.gettimeout() != 0: - raise ValueError("the socket must be non-blocking") return self._proactor.recv(sock, n) def sock_sendall(self, sock, data): - if self.get_debug() and sock.gettimeout() != 0: - raise ValueError("the socket must be non-blocking") return self._proactor.send(sock, data) def sock_connect(self, sock, address): - if self.get_debug() and sock.gettimeout() != 0: - raise ValueError("the socket must be non-blocking") try: base_events._check_resolved_address(sock, address) except ValueError as err: @@ -407,8 +401,6 @@ def sock_connect(self, sock, address): return self._proactor.connect(sock, address) def sock_accept(self, sock): - if self.get_debug() and sock.gettimeout() != 0: - raise ValueError("the socket must be non-blocking") return self._proactor.accept(sock) def _socketpair(self): diff --git a/tests/test_events.py b/tests/test_events.py index 0cff00ae..70ba3ad3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -27,6 +27,7 @@ import asyncio +from asyncio import proactor_events from asyncio import selector_events from asyncio import test_utils @@ -383,22 +384,23 @@ def writer(data): self.assertEqual(read, data) def _basetest_sock_client_ops(self, httpd, sock): - # in debug mode, socket operations must fail - # if the socket is not in blocking mode - self.loop.set_debug(True) - sock.setblocking(True) - with self.assertRaises(ValueError): - self.loop.run_until_complete( - self.loop.sock_connect(sock, httpd.address)) - with self.assertRaises(ValueError): - self.loop.run_until_complete( - self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) - with self.assertRaises(ValueError): - self.loop.run_until_complete( - self.loop.sock_recv(sock, 1024)) - with self.assertRaises(ValueError): - self.loop.run_until_complete( - self.loop.sock_accept(sock)) + if not isinstance(self.loop, proactor_events.BaseProactorEventLoop): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) # test in non-blocking mode sock.setblocking(False) From 65e86f2eba816d72b67083e21a79ff1d6a6e4925 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 15:31:53 +0200 Subject: [PATCH 1158/1502] Fix unit tests in debug mode: mock a non-blocking socket for socket operations which now raise an exception if the socket is blocking. --- asyncio/test_utils.py | 6 ++++++ tests/test_events.py | 2 ++ tests/test_selector_events.py | 10 +++++----- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 840bbf94..ac7680de 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -417,3 +417,9 @@ def disable_logger(): yield finally: logger.setLevel(old_level) + +def mock_nonblocking_socket(): + """Create a mock of a non-blocking socket.""" + sock = mock.Mock(socket.socket) + sock.gettimeout.return_value = 0.0 + return sock diff --git a/tests/test_events.py b/tests/test_events.py index 70ba3ad3..0cfc028d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1231,6 +1231,7 @@ def reader(data): "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): rsock, wsock = test_utils.socketpair() + rsock.setblocking(False) pipeobj = io.open(wsock.detach(), 'wb', 1024) proto = MyWritePipeProto(loop=self.loop) @@ -1368,6 +1369,7 @@ def test_sock_connect_address(self): for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): sock = socket.socket(family, sock_type) with sock: + sock.setblocking(False) connect = self.loop.sock_connect(sock, address) with self.assertRaises(ValueError) as cm: self.loop.run_until_complete(connect) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index bd6c2f26..5fee411e 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -135,7 +135,7 @@ def test_write_to_self_exception(self): self.assertRaises(RuntimeError, self.loop._write_to_self) def test_sock_recv(self): - sock = mock.Mock() + sock = test_utils.mock_nonblocking_socket() self.loop._sock_recv = mock.Mock() f = self.loop.sock_recv(sock, 1024) @@ -183,7 +183,7 @@ def test__sock_recv_exception(self): self.assertIs(err, f.exception()) def test_sock_sendall(self): - sock = mock.Mock() + sock = test_utils.mock_nonblocking_socket() self.loop._sock_sendall = mock.Mock() f = self.loop.sock_sendall(sock, b'data') @@ -193,7 +193,7 @@ def test_sock_sendall(self): self.loop._sock_sendall.call_args[0]) def test_sock_sendall_nodata(self): - sock = mock.Mock() + sock = test_utils.mock_nonblocking_socket() self.loop._sock_sendall = mock.Mock() f = self.loop.sock_sendall(sock, b'') @@ -295,7 +295,7 @@ def test__sock_sendall_none(self): self.loop.add_writer.call_args[0]) def test_sock_connect(self): - sock = mock.Mock() + sock = test_utils.mock_nonblocking_socket() self.loop._sock_connect = mock.Mock() f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) @@ -361,7 +361,7 @@ def test__sock_connect_exception(self): self.assertIsInstance(f.exception(), OSError) def test_sock_accept(self): - sock = mock.Mock() + sock = test_utils.mock_nonblocking_socket() self.loop._sock_accept = mock.Mock() f = self.loop.sock_accept(sock) From 7332c7a402bc4dc20ba562752cbeaa760fa347d1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 15:51:45 +0200 Subject: [PATCH 1159/1502] _fatal_error() method of _UnixReadPipeTransport and _UnixWritePipeTransport now log all exceptions in debug mode --- asyncio/unix_events.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 8d3e25eb..656ad59e 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -337,7 +337,10 @@ def close(self): def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only - if not (isinstance(exc, OSError) and exc.errno == errno.EIO): + if (isinstance(exc, OSError) and exc.errno == errno.EIO): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, @@ -509,7 +512,10 @@ def abort(self): def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only - if not isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: self._loop.call_exception_handler({ 'message': message, 'exception': exc, From cd7a6f4f24d5e4894682f2f0ce8a2bbe56d43323 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 30 Jul 2014 15:52:28 +0200 Subject: [PATCH 1160/1502] Don't log expected errors in unit tests --- tests/test_selector_events.py | 14 +++++++++----- tests/test_subprocess.py | 8 +++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 5fee411e..df6e9916 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -58,8 +58,9 @@ def test_make_ssl_transport(self): self.loop.remove_reader = mock.Mock() self.loop.remove_writer = mock.Mock() waiter = asyncio.Future(loop=self.loop) - transport = self.loop._make_ssl_transport( - m, asyncio.Protocol(), m, waiter) + with test_utils.disable_logger(): + transport = self.loop._make_ssl_transport( + m, asyncio.Protocol(), m, waiter) self.assertIsInstance(transport, _SelectorSslTransport) @mock.patch('asyncio.selector_events.ssl', None) @@ -127,7 +128,8 @@ def test_read_from_self_exception(self): def test_write_to_self_tryagain(self): self.loop._csock.send.side_effect = BlockingIOError - self.assertIsNone(self.loop._write_to_self()) + with test_utils.disable_logger(): + self.assertIsNone(self.loop._write_to_self()) def test_write_to_self_exception(self): # _write_to_self() swallows OSError @@ -782,7 +784,8 @@ def test_read_ready_conn_reset(self, m_exc): transport = _SelectorSocketTransport( self.loop, self.sock, self.protocol) transport._force_close = mock.Mock() - transport._read_ready() + with test_utils.disable_logger(): + transport._read_ready() transport._force_close.assert_called_with(err) @mock.patch('logging.exception') @@ -1219,7 +1222,8 @@ def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() transport._force_close = mock.Mock() - transport._read_ready() + with test_utils.disable_logger(): + transport._read_ready() transport._force_close.assert_called_with(err) def test_read_ready_recv_retry(self): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b5b1012e..0e9e1ce5 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -148,15 +148,17 @@ def write_stdin(proc, data): coro = write_stdin(proc, large_data) # drain() must raise BrokenPipeError or ConnectionResetError - self.assertRaises((BrokenPipeError, ConnectionResetError), - self.loop.run_until_complete, coro) + with test_utils.disable_logger(): + self.assertRaises((BrokenPipeError, ConnectionResetError), + self.loop.run_until_complete, coro) self.loop.run_until_complete(proc.wait()) def test_communicate_ignore_broken_pipe(self): proc, large_data = self.prepare_broken_pipe_test() # communicate() must ignore BrokenPipeError when feeding stdin - self.loop.run_until_complete(proc.communicate(large_data)) + with test_utils.disable_logger(): + self.loop.run_until_complete(proc.communicate(large_data)) self.loop.run_until_complete(proc.wait()) From afea3f26dc842bc7a66b4cd89a5bb8d48d7b6984 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 25 Aug 2014 23:08:51 +0200 Subject: [PATCH 1161/1502] Tulip issue #200: _WaitHandleFuture._unregister_wait() now catchs and logs exceptions. --- asyncio/windows_events.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 66aeca85..6881789b 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -111,10 +111,17 @@ def _unregister_wait(self): return try: _overlapped.UnregisterWait(self._wait_handle) - except OSError as e: - if e.winerror != _overlapped.ERROR_IO_PENDING: - raise + except OSError as exc: # ERROR_IO_PENDING is not an error, the wait was unregistered + if exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) self._wait_handle = None self._iocp = None self._ov = None From bc50b26945606682ce129863c1fb6fd2808147c1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 25 Aug 2014 23:09:15 +0200 Subject: [PATCH 1162/1502] Tulip issue #200: Log errors in debug mode instead of simply ignoring them. --- asyncio/base_events.py | 4 ++++ asyncio/proactor_events.py | 16 +++++++++++++--- asyncio/unix_events.py | 4 +++- asyncio/windows_events.py | 3 +++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index f88bd01b..db132505 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -728,6 +728,10 @@ def create_server(self, protocol_factory, host=None, port=None, sock = socket.socket(af, socktype, proto) except socket.error: # Assume it's a bad family/type/protocol combination. + if self._debug: + logger.warning('create_server() failed to create ' + 'socket.socket(%r, %r, %r)', + af, socktype, proto, exc_info=True) continue sockets.append(sock) if reuse_address: diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index ab566b32..0ad06564 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -172,6 +172,9 @@ def _loop_reading(self, fut=None): except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') + elif self._loop.get_debug(): + logger.debug("Read error on pipe transport while closing", + exc_info=True) except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: @@ -324,12 +327,16 @@ def _set_extra(self, sock): try: self._extra['sockname'] = sock.getsockname() except (socket.error, AttributeError): - pass + if self._loop.get_debug(): + logger.warning("getsockname() failed on %r", + sock, exc_info=True) if 'peername' not in self._extra: try: self._extra['peername'] = sock.getpeername() except (socket.error, AttributeError): - pass + if self._loop.get_debug(): + logger.warning("getpeername() failed on %r", + sock, exc_info=True) def can_write_eof(self): return True @@ -462,11 +469,14 @@ def loop(f=None): except OSError as exc: if sock.fileno() != -1: self.call_exception_handler({ - 'message': 'Accept failed', + 'message': 'Accept failed on a socket', 'exception': exc, 'socket': sock, }) sock.close() + elif self._debug: + logger.debug("Accept failed on socket %r", + sock, exc_info=True) except futures.CancelledError: sock.close() else: diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 656ad59e..37310cfd 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -756,7 +756,9 @@ def _do_waitpid(self, expected_pid): except KeyError: # pragma: no cover # May happen if .remove_child_handler() is called # after os.waitpid() returns. - pass + if self._loop.get_debug(): + logger.warning("Child watcher got an unexpected pid: %r", + pid, exc_info=True) else: callback(pid, returncode, *args) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 6881789b..6763f0b7 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -252,6 +252,9 @@ def loop_accept_pipe(f=None): 'pipe': pipe, }) pipe.close() + elif self._debug: + logger.warning("Accept pipe failed on pipe %r", + pipe, exc_info=True) except futures.CancelledError: if pipe: pipe.close() From efd90f2b66beb1eefdd821b0077401246e4e3ca9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 26 Aug 2014 00:10:48 +0200 Subject: [PATCH 1163/1502] Tulip issue #203: Add _FlowControlMixin.get_write_buffer_limits() method --- asyncio/transports.py | 3 +++ tests/test_transports.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/asyncio/transports.py b/asyncio/transports.py index 5f674f99..3caf853f 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -273,6 +273,9 @@ def _maybe_resume_protocol(self): 'protocol': self._protocol, }) + def get_write_buffer_limits(self): + return (self._low_water, self._high_water) + def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: diff --git a/tests/test_transports.py b/tests/test_transports.py index cfbdf3e9..5be1b7bf 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -79,9 +79,11 @@ def get_write_buffer_size(self): transport.set_write_buffer_limits(high=1024, low=128) self.assertFalse(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 1024)) transport.set_write_buffer_limits(high=256, low=128) self.assertTrue(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 256)) if __name__ == '__main__': From e637bf51253ae860c0a343433d8244636fa70da3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 28 Aug 2014 11:14:14 +0200 Subject: [PATCH 1164/1502] Tulip issue #201: Fix a race condition in wait_for() Don't raise a TimeoutError if we reached the timeout and the future completed in the same iteration of the event loop. A side effect of the bug is that Queue.get() looses items. --- asyncio/tasks.py | 15 +++++++++------ tests/test_tasks.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 92070162..c556e448 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -330,9 +330,9 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): return (yield from _wait(fs, timeout, return_when, loop)) -def _release_waiter(waiter, value=True, *args): +def _release_waiter(waiter, *args): if not waiter.done(): - waiter.set_result(value) + waiter.set_result(None) @coroutine @@ -357,14 +357,17 @@ def wait_for(fut, timeout, *, loop=None): return (yield from fut) waiter = futures.Future(loop=loop) - timeout_handle = loop.call_later(timeout, _release_waiter, waiter, False) - cb = functools.partial(_release_waiter, waiter, True) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + cb = functools.partial(_release_waiter, waiter) fut = async(fut, loop=loop) fut.add_done_callback(cb) try: - if (yield from waiter): + # wait until the future completes or the timeout + yield from waiter + + if fut.done(): return fut.result() else: fut.remove_done_callback(cb) @@ -397,7 +400,7 @@ def _on_completion(f): if timeout_handle is not None: timeout_handle.cancel() if not waiter.done(): - waiter.set_result(False) + waiter.set_result(None) for f in fs: f.add_done_callback(_on_completion) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 95cba542..e25aa4d7 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -552,6 +552,21 @@ def foo(): self.assertTrue(fut.done()) self.assertTrue(fut.cancelled()) + def test_wait_for_race_condition(self): + + def gen(): + yield 0.1 + yield 0.1 + yield 0.1 + + loop = self.new_test_loop(gen) + + fut = asyncio.Future(loop=loop) + task = asyncio.wait_for(fut, timeout=0.2, loop=loop) + loop.call_later(0.1, fut.set_result, "ok") + res = loop.run_until_complete(task) + self.assertEqual(res, "ok") + def test_wait(self): def gen(): From 1ade959858caaf21f804092b008145d4265b0184 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 28 Aug 2014 11:21:37 +0200 Subject: [PATCH 1165/1502] runtests.py: display a message to mention if tests are run in debug or release mode --- runtests.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/runtests.py b/runtests.py index f23aaa90..d4a7aad6 100644 --- a/runtests.py +++ b/runtests.py @@ -264,6 +264,11 @@ def runtests(): finder = TestsFinder(args.testsdir, includes, excludes) if catchbreak: installHandler() + import asyncio.coroutines + if asyncio.coroutines._DEBUG: + print("Run tests in debug mode") + else: + print("Run tests in release mode") try: if args.forever: while True: From 25a84086c03469fc6bd711d662d914b83b811ff1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 31 Aug 2014 15:06:28 +0200 Subject: [PATCH 1166/1502] Tulip issue #205: Fix a race condition in BaseSelectorEventLoop.sock_connect() There is a race condition in create_connection() used with wait_for() to have a timeout. sock_connect() registers the file descriptor of the socket to be notified of write event (if connect() raises BlockingIOError). When create_connection() is cancelled with a TimeoutError, sock_connect() coroutine gets the exception, but it doesn't unregister the file descriptor for write event. create_connection() gets the TimeoutError and closes the socket. If you call again create_connection(), the new socket will likely gets the same file descriptor, which is still registered in the selector. When sock_connect() calls add_writer(), it tries to modify the entry instead of creating a new one. This issue was originally reported in the Trollius project, but the bug comes from Tulip in fact (Trollius is based on Tulip): https://bitbucket.org/enovance/trollius/issue/15/after-timeouterror-on-wait_for This change fixes the race condition. It also makes sock_connect() more reliable (and portable) is sock.connect() raises an InterruptedError. --- asyncio/selector_events.py | 44 +++++++++++++++------ tests/test_selector_events.py | 74 ++++++++++++++++++++++++----------- 2 files changed, 83 insertions(+), 35 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 0434a701..33de92e1 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -8,6 +8,7 @@ import collections import errno +import functools import socket try: import ssl @@ -345,26 +346,43 @@ def sock_connect(self, sock, address): except ValueError as err: fut.set_exception(err) else: - self._sock_connect(fut, False, sock, address) + self._sock_connect(fut, sock, address) return fut - def _sock_connect(self, fut, registered, sock, address): + def _sock_connect(self, fut, sock, address): fd = sock.fileno() - if registered: - self.remove_writer(fd) + try: + while True: + try: + sock.connect(address) + except InterruptedError: + continue + else: + break + except BlockingIOError: + fut.add_done_callback(functools.partial(self._sock_connect_done, + sock)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def _sock_connect_done(self, sock, fut): + self.remove_writer(sock.fileno()) + + def _sock_connect_cb(self, fut, sock, address): if fut.cancelled(): return + try: - if not registered: - # First time around. - sock.connect(address) - else: - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - # Jump to the except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to any except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) except (BlockingIOError, InterruptedError): - self.add_writer(fd, self._sock_connect, fut, True, sock, address) + # socket is still registered, the callback will be retried later + pass except Exception as exc: fut.set_exception(exc) else: diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index df6e9916..528da39d 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -40,8 +40,9 @@ def list_to_buffer(l=()): class BaseSelectorEventLoopTests(test_utils.TestCase): def setUp(self): - selector = mock.Mock() - self.loop = TestBaseSelectorEventLoop(selector) + self.selector = mock.Mock() + self.selector.select.return_value = [] + self.loop = TestBaseSelectorEventLoop(self.selector) self.set_event_loop(self.loop, cleanup=False) def test_make_socket_transport(self): @@ -303,63 +304,92 @@ def test_sock_connect(self): f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, asyncio.Future) self.assertEqual( - (f, False, sock, ('127.0.0.1', 8080)), + (f, sock, ('127.0.0.1', 8080)), self.loop._sock_connect.call_args[0]) + def test_sock_connect_timeout(self): + # Tulip issue #205: sock_connect() must unregister the socket on + # timeout error + + # prepare mocks + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() + sock = test_utils.mock_nonblocking_socket() + sock.connect.side_effect = BlockingIOError + + # first call to sock_connect() registers the socket + fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) + self.assertTrue(sock.connect.called) + self.assertTrue(self.loop.add_writer.called) + self.assertEqual(len(fut._callbacks), 1) + + # on timeout, the socket must be unregistered + sock.connect.reset_mock() + fut.set_exception(asyncio.TimeoutError) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(fut) + self.assertTrue(self.loop.remove_writer.called) + def test__sock_connect(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 - self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect(f, sock, ('127.0.0.1', 8080)) self.assertTrue(f.done()) self.assertIsNone(f.result()) self.assertTrue(sock.connect.called) - def test__sock_connect_canceled_fut(self): + def test__sock_connect_cb_cancelled_fut(self): sock = mock.Mock() + self.loop.remove_writer = mock.Mock() f = asyncio.Future(loop=self.loop) f.cancel() - self.loop._sock_connect(f, False, sock, ('127.0.0.1', 8080)) - self.assertFalse(sock.connect.called) + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) + self.assertFalse(sock.getsockopt.called) + + def test__sock_connect_writer(self): + # check that the fd is registered and then unregistered + self.loop._process_events = mock.Mock() + self.loop.add_writer = mock.Mock() + self.loop.remove_writer = mock.Mock() - def test__sock_connect_unregister(self): sock = mock.Mock() sock.fileno.return_value = 10 + sock.connect.side_effect = BlockingIOError + sock.getsockopt.return_value = 0 + address = ('127.0.0.1', 8080) f = asyncio.Future(loop=self.loop) - f.cancel() + self.loop._sock_connect(f, sock, address) + self.assertTrue(self.loop.add_writer.called) + self.assertEqual(10, self.loop.add_writer.call_args[0][0]) - self.loop.remove_writer = mock.Mock() - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect_cb(f, sock, address) + # need to run the event loop to execute _sock_connect_done() callback + self.loop.run_until_complete(f) self.assertEqual((10,), self.loop.remove_writer.call_args[0]) - def test__sock_connect_tryagain(self): + def test__sock_connect_cb_tryagain(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.EAGAIN - self.loop.add_writer = mock.Mock() - self.loop.remove_writer = mock.Mock() - - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) - self.assertEqual( - (10, self.loop._sock_connect, f, - True, sock, ('127.0.0.1', 8080)), - self.loop.add_writer.call_args[0]) + # check that the exception is handled + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) - def test__sock_connect_exception(self): + def test__sock_connect_cb_exception(self): f = asyncio.Future(loop=self.loop) sock = mock.Mock() sock.fileno.return_value = 10 sock.getsockopt.return_value = errno.ENOTCONN self.loop.remove_writer = mock.Mock() - self.loop._sock_connect(f, True, sock, ('127.0.0.1', 8080)) + self.loop._sock_connect_cb(f, sock, ('127.0.0.1', 8080)) self.assertIsInstance(f.exception(), OSError) def test_sock_accept(self): From eeeef365cb67138c2335f03ef48d43374c26417b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 17 Sep 2014 23:14:36 +0200 Subject: [PATCH 1167/1502] Fix Handle and TimerHandle repr in debug mode Tulip issue #206: In debug mode, keep the callback in the representation of Handle and TimerHandle after cancel(). --- asyncio/events.py | 32 +++++++++++++++++++------------- tests/test_events.py | 9 +++++---- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/asyncio/events.py b/asyncio/events.py index 3c7a36d0..b7cc3512 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -73,7 +73,7 @@ class Handle: """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', - '_source_traceback', '__weakref__') + '_source_traceback', '_repr', '__weakref__') def __init__(self, callback, args, loop): assert not isinstance(callback, Handle), 'A Handle is not a callback' @@ -81,12 +81,13 @@ def __init__(self, callback, args, loop): self._callback = callback self._args = args self._cancelled = False + self._repr = None if self._loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) else: self._source_traceback = None - def __repr__(self): + def _repr_info(self): info = [self.__class__.__name__] if self._cancelled: info.append('cancelled') @@ -95,10 +96,21 @@ def __repr__(self): if self._source_traceback: frame = self._source_traceback[-1] info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + if self._repr is not None: + return self._repr + info = self._repr_info() return '<%s>' % ' '.join(info) def cancel(self): self._cancelled = True + if self._loop.get_debug(): + # Keep a representation in debug mode to keep callback and + # parameters. For example, to log the warning "Executing took 2.5 second" + self._repr = repr(self) self._callback = None self._args = None @@ -131,17 +143,11 @@ def __init__(self, when, callback, args, loop): del self._source_traceback[-1] self._when = when - def __repr__(self): - info = [] - if self._cancelled: - info.append('cancelled') - info.append('when=%s' % self._when) - if self._callback is not None: - info.append(_format_callback(self._callback, self._args)) - if self._source_traceback: - frame = self._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) - return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + def _repr_info(self): + info = super()._repr_info() + pos = 2 if self._cancelled else 1 + info.insert(pos, 'when=%s' % self._when) + return info def __hash__(self): return hash(self._when) diff --git a/tests/test_events.py b/tests/test_events.py index 0cfc028d..7ac845a8 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1891,8 +1891,8 @@ def test_handle_repr_debug(self): # cancelled handle h.cancel() self.assertEqual(repr(h), - '' - % (create_filename, create_lineno)) + '' + % (filename, lineno, create_filename, create_lineno)) def test_handle_source_traceback(self): loop = asyncio.get_event_loop_policy().new_event_loop() @@ -1987,8 +1987,9 @@ def test_timer_repr_debug(self): # cancelled handle h.cancel() self.assertEqual(repr(h), - '' - % (create_filename, create_lineno)) + '' + % (filename, lineno, create_filename, create_lineno)) def test_timer_comparison(self): From 446daa0bef8fc0d371701e3bf53c7a7cbff2dfbb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 17 Sep 2014 23:22:47 +0200 Subject: [PATCH 1168/1502] Python issue #22369: Change "context manager protocol" to "context management protocol". Patch written by Serhiy Storchaka . --- asyncio/locks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 574e3618..b943e9dd 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -63,7 +63,7 @@ class Lock: acquire() is a coroutine and should be called with 'yield from'. - Locks also support the context manager protocol. '(yield from lock)' + Locks also support the context management protocol. '(yield from lock)' should be used as context manager expression. Usage: @@ -376,7 +376,7 @@ class Semaphore: can never go below zero; when acquire() finds that it is zero, it blocks, waiting until some other thread calls release(). - Semaphores also support the context manager protocol. + Semaphores also support the context management protocol. The optional argument gives the initial value for the internal counter; it defaults to 1. If the value given is less than 0, From 8469bb606d512650f89d07dbe2b91272f30ae907 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Sep 2014 23:16:48 -0400 Subject: [PATCH 1169/1502] tasks.py: Sync comments updates from cpython tree --- asyncio/tasks.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index c556e448..e0738021 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -77,9 +77,9 @@ def __init__(self, coro, *, loop=None): # status is still pending self._log_destroy_pending = True - # On Python 3.3 or older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks to - # the PEP 442. + # On Python 3.3 or older, objects with a destructor that are part of a + # reference cycle are never destroyed. That's not the case any more on + # Python 3.4 thanks to the PEP 442. if _PY34: def __del__(self): if self._state == futures._PENDING and self._log_destroy_pending: @@ -155,7 +155,8 @@ def print_stack(self, *, limit=None, file=None): This produces output similar to that of the traceback module, for the frames retrieved by get_stack(). The limit argument is passed to get_stack(). The file argument is an I/O stream - to which the output goes; by default it goes to sys.stderr. + to which the output is written; by default output is written + to sys.stderr. """ extracted_list = [] checked = set() @@ -184,18 +185,18 @@ def print_stack(self, *, limit=None, file=None): print(line, file=file, end='') def cancel(self): - """Request this task to cancel itself. + """Request that this task cancel itself. This arranges for a CancelledError to be thrown into the wrapped coroutine on the next cycle through the event loop. The coroutine then has a chance to clean up or even deny the request using try/except/finally. - Contrary to Future.cancel(), this does not guarantee that the + Unlike Future.cancel, this does not guarantee that the task will be cancelled: the exception might be caught and - acted upon, delaying cancellation of the task or preventing it - completely. The task may also return a value or raise a - different exception. + acted upon, delaying cancellation of the task or preventing + cancellation completely. The task may also return a value or + raise a different exception. Immediately after this method is called, Task.cancelled() will not return True (unless the task was already cancelled). A From 43e57b4cc53fc70a15cf1b56cd907f15ffd392c1 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Sep 2014 23:23:19 -0400 Subject: [PATCH 1170/1502] unix_events: Move import statement to sync code with cpython --- asyncio/unix_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 37310cfd..93c8c1c8 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -1,7 +1,6 @@ """Selector event loop for Unix with signal handling.""" import errno -import fcntl import os import signal import socket @@ -263,6 +262,8 @@ def create_unix_server(self, protocol_factory, path=None, *, def _set_nonblocking(fd): os.set_blocking(fd, False) else: + import fcntl + def _set_nonblocking(fd): flags = fcntl.fcntl(fd, fcntl.F_GETFL) flags = flags | os.O_NONBLOCK From 0dfb7bf842b3f730722c341bd1b0851637c0a313 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 25 Sep 2014 12:02:25 -0400 Subject: [PATCH 1171/1502] Improve canceled timer callback handles cleanup (CPython issue #22448) Patch by Joshua Moore-Oliva. --- asyncio/base_events.py | 44 ++++++++++++++++---- asyncio/events.py | 29 +++++++++----- tests/test_base_events.py | 84 +++++++++++++++++++++++++++++++++++---- tests/test_events.py | 14 +++++-- 4 files changed, 145 insertions(+), 26 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index db132505..5aaf58f9 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -40,6 +40,13 @@ # Argument for default thread pool executor creation. _MAX_WORKERS = 5 +# Minimum number of _scheduled timer handles before cleanup of +# cancelled handles is performed. +_MIN_SCHEDULED_TIMER_HANDLES = 100 + +# Minimum fraction of _scheduled timer handles that are cancelled +# before cleanup of cancelled handles is performed. +_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 def _format_handle(handle): cb = handle._callback @@ -145,6 +152,7 @@ def wait_closed(self): class BaseEventLoop(events.AbstractEventLoop): def __init__(self): + self._timer_cancelled_count = 0 self._closed = False self._ready = collections.deque() self._scheduled = [] @@ -349,6 +357,7 @@ def call_at(self, when, callback, *args): if timer._source_traceback: del timer._source_traceback[-1] heapq.heappush(self._scheduled, timer) + timer._scheduled = True return timer def call_soon(self, callback, *args): @@ -964,16 +973,19 @@ def _add_callback(self, handle): assert isinstance(handle, events.Handle), 'A Handle is required here' if handle._cancelled: return - if isinstance(handle, events.TimerHandle): - heapq.heappush(self._scheduled, handle) - else: - self._ready.append(handle) + assert not isinstance(handle, events.TimerHandle) + self._ready.append(handle) def _add_callback_signalsafe(self, handle): """Like _add_callback() but called from a signal handler.""" self._add_callback(handle) self._write_to_self() + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + if handle._scheduled: + self._timer_cancelled_count += 1 + def _run_once(self): """Run one full iteration of the event loop. @@ -981,9 +993,26 @@ def _run_once(self): schedules the resulting callbacks, and finally schedules 'call_later' callbacks. """ - # Remove delayed calls that were cancelled from head of queue. - while self._scheduled and self._scheduled[0]._cancelled: - heapq.heappop(self._scheduled) + + # Remove delayed calls that were cancelled if their number is too high + sched_count = len(self._scheduled) + if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and + self._timer_cancelled_count / sched_count > + _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + for handle in self._scheduled: + if handle._cancelled: + handle._scheduled = False + + self._scheduled = [x for x in self._scheduled if not x._cancelled] + self._timer_cancelled_count = 0 + + heapq.heapify(self._scheduled) + else: + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + self._timer_cancelled_count -= 1 + handle = heapq.heappop(self._scheduled) + handle._scheduled = False timeout = None if self._ready: @@ -1024,6 +1053,7 @@ def _run_once(self): if handle._when >= end_time: break handle = heapq.heappop(self._scheduled) + handle._scheduled = False self._ready.append(handle) # This is the only place where callbacks are actually *called*. diff --git a/asyncio/events.py b/asyncio/events.py index b7cc3512..806218f6 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -105,14 +105,15 @@ def __repr__(self): return '<%s>' % ' '.join(info) def cancel(self): - self._cancelled = True - if self._loop.get_debug(): - # Keep a representation in debug mode to keep callback and - # parameters. For example, to log the warning "Executing took 2.5 second" - self._repr = repr(self) - self._callback = None - self._args = None + if not self._cancelled: + self._cancelled = True + if self._loop.get_debug(): + # Keep a representation in debug mode to keep callback and + # parameters. For example, to log the warning + # "Executing took 2.5 second" + self._repr = repr(self) + self._callback = None + self._args = None def _run(self): try: @@ -134,7 +135,7 @@ def _run(self): class TimerHandle(Handle): """Object returned by timed callback registration methods.""" - __slots__ = ['_when'] + __slots__ = ['_scheduled', '_when'] def __init__(self, when, callback, args, loop): assert when is not None @@ -142,6 +143,7 @@ def __init__(self, when, callback, args, loop): if self._source_traceback: del self._source_traceback[-1] self._when = when + self._scheduled = False def _repr_info(self): info = super()._repr_info() @@ -180,6 +182,11 @@ def __ne__(self, other): equal = self.__eq__(other) return NotImplemented if equal is NotImplemented else not equal + def cancel(self): + if not self._cancelled: + self._loop._timer_handle_cancelled(self) + super().cancel() + class AbstractServer: """Abstract server returned by create_server().""" @@ -238,6 +245,10 @@ def close(self): # Methods scheduling callbacks. All these return Handles. + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError + def call_soon(self, callback, *args): return self.call_later(0, callback, *args) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index ca12101b..294872a9 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -2,6 +2,7 @@ import errno import logging +import math import socket import sys import time @@ -73,13 +74,6 @@ def test__add_callback_handle(self): self.assertFalse(self.loop._scheduled) self.assertIn(h, self.loop._ready) - def test__add_callback_timer(self): - h = asyncio.TimerHandle(time.monotonic()+10, lambda: False, (), - self.loop) - - self.loop._add_callback(h) - self.assertIn(h, self.loop._scheduled) - def test__add_callback_cancelled_handle(self): h = asyncio.Handle(lambda: False, (), self.loop) h.cancel() @@ -283,6 +277,82 @@ def cb(loop): self.assertTrue(processed) self.assertEqual([handle], list(self.loop._ready)) + def test__run_once_cancelled_event_cleanup(self): + self.loop._process_events = mock.Mock() + + self.assertTrue( + 0 < base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION < 1.0) + + def cb(): + pass + + # Set up one "blocking" event that will not be cancelled to + # ensure later cancelled events do not make it to the head + # of the queue and get cleaned. + not_cancelled_count = 1 + self.loop.call_later(3000, cb) + + # Add less than threshold (base_events._MIN_SCHEDULED_TIMER_HANDLES) + # cancelled handles, ensure they aren't removed + + cancelled_count = 2 + for x in range(2): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Add some cancelled events that will be at head and removed + cancelled_count += 2 + for x in range(2): + h = self.loop.call_later(100, cb) + h.cancel() + + # This test is invalid if _MIN_SCHEDULED_TIMER_HANDLES is too low + self.assertLessEqual(cancelled_count + not_cancelled_count, + base_events._MIN_SCHEDULED_TIMER_HANDLES) + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.loop._run_once() + + cancelled_count -= 2 + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + # Need enough events to pass _MIN_CANCELLED_TIMER_HANDLES_FRACTION + # so that deletion of cancelled events will occur on next _run_once + add_cancel_count = int(math.ceil( + base_events._MIN_SCHEDULED_TIMER_HANDLES * + base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION)) + 1 + + add_not_cancel_count = max(base_events._MIN_SCHEDULED_TIMER_HANDLES - + add_cancel_count, 0) + + # Add some events that will not be cancelled + not_cancelled_count += add_not_cancel_count + for x in range(add_not_cancel_count): + self.loop.call_later(3600, cb) + + # Add enough cancelled events + cancelled_count += add_cancel_count + for x in range(add_cancel_count): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Ensure all handles are still scheduled + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + self.loop._run_once() + + # Ensure cancelled events were removed + self.assertEqual(len(self.loop._scheduled), not_cancelled_count) + + # Ensure only uncancelled events remain scheduled + self.assertTrue(all([not x._cancelled for x in self.loop._scheduled])) + def test_run_until_complete_type_error(self): self.assertRaises(TypeError, self.loop.run_until_complete, 'blah') diff --git a/tests/test_events.py b/tests/test_events.py index 7ac845a8..a305e66d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1890,9 +1890,17 @@ def test_handle_repr_debug(self): # cancelled handle h.cancel() - self.assertEqual(repr(h), - '' - % (filename, lineno, create_filename, create_lineno)) + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # double cancellation won't overwrite _repr + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) def test_handle_source_traceback(self): loop = asyncio.get_event_loop_policy().new_event_loop() From f3e29d20c153e4e5a468968aea204b1b1c15770a Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 25 Sep 2014 19:09:54 -0400 Subject: [PATCH 1172/1502] test_tasks: Fix test_env_var_debug to use correct asyncio module (issue #207) --- tests/test_tasks.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e25aa4d7..770f2181 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,5 +1,6 @@ """Tests for tasks.py.""" +import os import re import sys import types @@ -1768,25 +1769,31 @@ def test_return_exceptions(self): self.assertEqual(fut.result(), [3, 1, exc, exc2]) def test_env_var_debug(self): + aio_path = os.path.dirname(os.path.dirname(asyncio.__file__)) + code = '\n'.join(( 'import asyncio.coroutines', 'print(asyncio.coroutines._DEBUG)')) # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = assert_python_ok('-E', '-c', code) + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='') + PYTHONASYNCIODEBUG='', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1') + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'True') sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1') + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') From f9401d73b8e9d0973f86e1d7903b2b16b78b39c2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 30 Sep 2014 18:02:52 +0200 Subject: [PATCH 1173/1502] Python issue 22448: cleanup _run_once(), only iterate once to remove delayed calls that were cancelled. --- asyncio/base_events.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5aaf58f9..3cff72ab 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -994,19 +994,22 @@ def _run_once(self): 'call_later' callbacks. """ - # Remove delayed calls that were cancelled if their number is too high sched_count = len(self._scheduled) if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and self._timer_cancelled_count / sched_count > _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + # Remove delayed calls that were cancelled if their number + # is too high + new_scheduled = [] for handle in self._scheduled: if handle._cancelled: handle._scheduled = False + else: + new_scheduled.append(handle) - self._scheduled = [x for x in self._scheduled if not x._cancelled] + heapq.heapify(new_scheduled) + self._scheduled = new_scheduled self._timer_cancelled_count = 0 - - heapq.heapify(self._scheduled) else: # Remove delayed calls that were cancelled from head of queue. while self._scheduled and self._scheduled[0]._cancelled: From b4fdeface598199da38b59ea30021243fcc4897f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Oct 2014 14:28:27 +0200 Subject: [PATCH 1174/1502] run_forever() now consumes BaseException of the temporary task If the coroutine raised a BaseException, consume the exception to not log a warning. The caller doesn't have access to the local task. --- asyncio/base_events.py | 10 +++++++++- tests/test_base_events.py | 22 +++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 3cff72ab..b6b71239 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -268,7 +268,15 @@ def run_until_complete(self, future): future._log_destroy_pending = False future.add_done_callback(_raise_stop_error) - self.run_forever() + try: + self.run_forever() + except: + if new_task and future.done() and not future.cancelled(): + # The coroutine raised a BaseException. Consume the exception + # to not log a warning, the caller doesn't have access to the + # local task. + future.exception() + raise future.remove_done_callback(_raise_stop_error) if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 294872a9..afc448c0 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -9,7 +9,7 @@ import unittest from unittest import mock from test.script_helper import assert_python_ok -from test.support import IPV6_ENABLED +from test.support import IPV6_ENABLED, gc_collect import asyncio from asyncio import base_events @@ -618,6 +618,26 @@ def create_task(self, coro): task._log_destroy_pending = False coro.close() + def test_run_forever_keyboard_interrupt(self): + # Python issue #22601: ensure that the temporary task created by + # run_forever() consumes the KeyboardInterrupt and so don't log + # a warning + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + self.loop.close() + gc_collect() + + self.assertFalse(self.loop.call_exception_handler.called) + class MyProto(asyncio.Protocol): done = None From 4f468f436eff8e599814d06667f0d23c373ccf39 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 14 Oct 2014 22:55:37 +0200 Subject: [PATCH 1175/1502] Enhance protocol representation Add "closed" or "closing" to repr() of selector and proactor transports --- asyncio/proactor_events.py | 8 +++++++- asyncio/selector_events.py | 7 ++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 0ad06564..7132300a 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -42,7 +42,13 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): - info = [self.__class__.__name__, 'fd=%s' % self._sock.fileno()] + info = [self.__class__.__name__] + fd = self._sock.fileno() + if fd < 0: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % fd) if self._read_fut is not None: info.append('read=%s' % self._read_fut) if self._write_fut is not None: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 33de92e1..a55eff78 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -467,7 +467,12 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._server._attach() def __repr__(self): - info = [self.__class__.__name__, 'fd=%s' % self._sock_fd] + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._sock_fd) # test if the transport was closed if self._loop is not None: polling = _test_selector_event(self._loop._selector, From fdefca3a38b1f5b9cbb1cd6a8c3f5fcd79021080 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 14 Oct 2014 22:53:35 +0200 Subject: [PATCH 1176/1502] Reuse socket.socketpair() on Windows if available Since Python 3.5, socket.socketpair() is now also available on Windows. Make csock blocking before calling the accept() method, and fix also a typo in an error message. --- asyncio/windows_utils.py | 84 +++++++++++++++++++------------------ tests/test_windows_utils.py | 4 ++ 2 files changed, 47 insertions(+), 41 deletions(-) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index f7f2f358..1155a77f 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -28,49 +28,51 @@ _mmap_counter = itertools.count() -# Replacement for socket.socketpair() - - -def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - """A socket pair usable as a self-pipe, for Windows. - - Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - """ - if family == socket.AF_INET: - host = '127.0.0.1' - elif family == socket.AF_INET6: - host = '::1' - else: - raise ValueError("Ony AF_INET and AF_INET6 socket address families " - "are supported") - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with setblocking(0) - # that prevents us from having to create a thread. - lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(1) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) +if hasattr(socket, 'socketpair'): + # Since Python 3.5, socket.socketpair() is now also available on Windows + socketpair = socket.socketpair +else: + # Replacement for socket.socketpair() + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + if family == socket.AF_INET: + host = '127.0.0.1' + elif family == socket.AF_INET6: + host = '::1' + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) try: - csock.setblocking(False) + lsock.bind((host, 0)) + lsock.listen(1) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - ssock, _ = lsock.accept() - csock.setblocking(True) - except: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() + return (ssock, csock) # Replacement for os.pipe() using handles instead of fds diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 7ea3a6d3..3e7a211e 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -33,6 +33,8 @@ def test_winsocketpair_ipv6(self): ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) self.check_winsocketpair(ssock, csock) + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('asyncio.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): m_socket.AF_INET = socket.AF_INET @@ -51,6 +53,8 @@ def test_winsocketpair_invalid_args(self): self.assertRaises(ValueError, windows_utils.socketpair, proto=1) + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('asyncio.windows_utils.socket') def test_winsocketpair_close(self, m_socket): m_socket.AF_INET = socket.AF_INET From c89935b04ddb12f1063dfc7dd84ded1f0595b67e Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 8 Oct 2014 11:46:30 -0700 Subject: [PATCH 1177/1502] Added tag 3.4.2 for changeset 5f2a130f7a8c From ce9a485a831375a759c38e2fea4508406cce7f0a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 5 Nov 2014 15:25:26 +0100 Subject: [PATCH 1178/1502] Move loop attribute to _FlowControlMixin Move the _loop attribute from the constructor of _SelectorTransport, _ProactorBasePipeTransport and _UnixWritePipeTransport classes to the constructor of the _FlowControlMixin class. Add also an assertion to explicit that the parent class must ensure that the loop is defined (not None) --- asyncio/proactor_events.py | 3 +-- asyncio/selector_events.py | 3 +-- asyncio/transports.py | 4 +++- asyncio/unix_events.py | 3 +-- tests/test_transports.py | 3 ++- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 7132300a..a1e2fef6 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -21,9 +21,8 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(extra) + super().__init__(extra, loop) self._set_extra(sock) - self._loop = loop self._sock = sock self._protocol = protocol self._server = server diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index a55eff78..da64a60c 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -447,7 +447,7 @@ class _SelectorTransport(transports._FlowControlMixin, _buffer_factory = bytearray # Constructs initial value for self._buffer. def __init__(self, loop, sock, protocol, extra, server=None): - super().__init__(extra) + super().__init__(extra, loop) self._extra['socket'] = sock self._extra['sockname'] = sock.getsockname() if 'peername' not in self._extra: @@ -455,7 +455,6 @@ def __init__(self, loop, sock, protocol, extra, server=None): self._extra['peername'] = sock.getpeername() except socket.error: self._extra['peername'] = None - self._loop = loop self._sock = sock self._sock_fd = sock.fileno() self._protocol = protocol diff --git a/asyncio/transports.py b/asyncio/transports.py index 3caf853f..22df3c7a 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -238,8 +238,10 @@ class _FlowControlMixin(Transport): resume_writing() may be called. """ - def __init__(self, extra=None): + def __init__(self, extra=None, loop=None): super().__init__(extra) + assert loop is not None + self._loop = loop self._protocol_paused = False self._set_write_buffer_limits() diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 93c8c1c8..b16f946a 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -369,9 +369,8 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, transports.WriteTransport): def __init__(self, loop, pipe, protocol, waiter=None, extra=None): - super().__init__(extra) + super().__init__(extra, loop) self._extra['pipe'] = pipe - self._loop = loop self._pipe = pipe self._fileno = pipe.fileno() mode = os.fstat(self._fileno).st_mode diff --git a/tests/test_transports.py b/tests/test_transports.py index 5be1b7bf..3b6e3d67 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -69,7 +69,8 @@ class MyTransport(transports._FlowControlMixin, def get_write_buffer_size(self): return 512 - transport = MyTransport() + loop = mock.Mock() + transport = MyTransport(loop=loop) transport._protocol = mock.Mock() self.assertFalse(transport._protocol_paused) From bf40c729a05acd9a8c3e97bcd82295e3ecea7396 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 5 Nov 2014 15:30:16 +0100 Subject: [PATCH 1179/1502] Python issue #22641: On Python 3.4 and newer, the default SSL context for client connections is now created using ssl.create_default_context(), for stronger security. Patch written by Antoine Pitrou. --- asyncio/selector_events.py | 13 ++++++----- asyncio/test_utils.py | 13 ++++++++++- tests/test_base_events.py | 2 +- tests/test_events.py | 48 ++++++++++++++++++++++++++++++-------- 4 files changed, 58 insertions(+), 18 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index da64a60c..116d3801 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -688,16 +688,17 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, if not sslcontext: # Client side may pass ssl=True to use a default # context; in that case the sslcontext passed is None. - # The default is the same as used by urllib with - # cadefault=True. - if hasattr(ssl, '_create_stdlib_context'): - sslcontext = ssl._create_stdlib_context( - cert_reqs=ssl.CERT_REQUIRED, - check_hostname=bool(server_hostname)) + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False else: # Fallback for Python 3.3. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 sslcontext.set_default_verify_paths() sslcontext.verify_mode = ssl.CERT_REQUIRED diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index ac7680de..3e5eee54 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -91,6 +91,13 @@ def log_message(self, format, *args): class SilentWSGIServer(WSGIServer): + request_timeout = 2 + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + return request, client_addr + def handle_error(self, request, client_address): pass @@ -138,7 +145,8 @@ def app(environ, start_response): httpd = server_class(address, SilentWSGIRequestHandler) httpd.set_app(app) httpd.address = httpd.server_address - server_thread = threading.Thread(target=httpd.serve_forever) + server_thread = threading.Thread( + target=lambda: httpd.serve_forever(poll_interval=0.05)) server_thread.start() try: yield httpd @@ -160,12 +168,15 @@ def server_bind(self): class UnixWSGIServer(UnixHTTPServer, WSGIServer): + request_timeout = 2 + def server_bind(self): UnixHTTPServer.server_bind(self) self.setup_environ() def get_request(self): request, client_addr = super().get_request() + request.settimeout(self.request_timeout) # Code in the stdlib expects that get_request # will return a socket and a tuple (host, port). # However, this isn't true for UNIX sockets, diff --git a/tests/test_base_events.py b/tests/test_base_events.py index afc448c0..d61a64c9 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -371,7 +371,7 @@ def test_subprocess_exec_invalid_args(self): self.loop.run_until_complete, self.loop.subprocess_exec, asyncio.SubprocessProtocol) - # exepected multiple arguments, not a list + # expected multiple arguments, not a list self.assertRaises(TypeError, self.loop.run_until_complete, self.loop.subprocess_exec, asyncio.SubprocessProtocol, args) diff --git a/tests/test_events.py b/tests/test_events.py index a305e66d..fe1e3add 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -606,15 +606,43 @@ def _basetest_create_ssl_connection(self, connection_fut, self.assertGreater(pr.nbytes, 0) tr.close() + def _dummy_ssl_create_context(self, purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() + + def _test_create_ssl_connection(self, httpd, create_connection, + check_sockname=True): + conn_fut = create_connection(ssl=test_utils.dummy_ssl_context()) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + # With ssl=True, ssl.create_default_context() should be called + with mock.patch('ssl.create_default_context', + side_effect=self._dummy_ssl_create_context) as m: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + self.assertEqual(m.call_count, 1) + + # With the real ssl.create_default_context(), certificate + # validation will fail + with self.assertRaises(ssl.SSLError) as cm: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: - conn_fut = self.loop.create_connection( + create_connection = functools.partial( + self.loop.create_connection, lambda: MyProto(loop=self.loop), - *httpd.address, - ssl=test_utils.dummy_ssl_context()) - - self._basetest_create_ssl_connection(conn_fut) + *httpd.address) + self._test_create_ssl_connection(httpd, create_connection) @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') @@ -624,13 +652,13 @@ def test_create_ssl_unix_connection(self): check_sockname = not osx_tiger() with test_utils.run_test_unix_server(use_ssl=True) as httpd: - conn_fut = self.loop.create_unix_connection( - lambda: MyProto(loop=self.loop), - httpd.address, - ssl=test_utils.dummy_ssl_context(), + create_connection = functools.partial( + self.loop.create_unix_connection, + lambda: MyProto(loop=self.loop), httpd.address, server_hostname='127.0.0.1') - self._basetest_create_ssl_connection(conn_fut, check_sockname) + self._test_create_ssl_connection(httpd, create_connection, + check_sockname) def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: From ad7c76c2b04315c9d7205a24789449eb7900916b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 14 Nov 2014 11:41:13 -0800 Subject: [PATCH 1180/1502] - Issue #22841: Reject coroutines in asyncio add_signal_handler(). Patch by Ludovic.Gasc. --- asyncio/unix_events.py | 3 +++ tests/test_unix_events.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index b16f946a..e49212e5 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -13,6 +13,7 @@ from . import base_events from . import base_subprocess from . import constants +from . import coroutines from . import events from . import selector_events from . import selectors @@ -66,6 +67,8 @@ def add_signal_handler(self, sig, callback, *args): Raise ValueError if the signal number is invalid or uncatchable. Raise RuntimeError if there is a problem setting up the handler. """ + if coroutines.iscoroutinefunction(callback): + raise TypeError("coroutines cannot be used with call_soon()") self._check_signal(sig) try: # set_wakeup_fd() raises ValueError if this is not the diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index e3975982..2f3fa185 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -63,6 +63,18 @@ def test_add_signal_handler_setup_error(self, m_signal): self.loop.add_signal_handler, signal.SIGINT, lambda: True) + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_coroutine_error(self, m_signal): + + @asyncio.coroutine + def simple_coroutine(): + yield from [] + + self.assertRaises( + TypeError, + self.loop.add_signal_handler, + signal.SIGINT, simple_coroutine) + @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG From f47ca38eee3778aacb5f410f773f9ad7d5500b17 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Fri, 14 Nov 2014 11:54:29 -0800 Subject: [PATCH 1181/1502] CPython issue #22784: fix test_asyncio when the ssl module isn't available. (Antoine Pitrou) --- tests/test_events.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index fe1e3add..4fe4b4c4 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -606,14 +606,15 @@ def _basetest_create_ssl_connection(self, connection_fut, self.assertGreater(pr.nbytes, 0) tr.close() - def _dummy_ssl_create_context(self, purpose=ssl.Purpose.SERVER_AUTH, *, - cafile=None, capath=None, cadata=None): - """ - A ssl.create_default_context() replacement that doesn't enable - cert validation. - """ - self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) - return test_utils.dummy_ssl_context() + if ssl: + def _dummy_ssl_create_context(self, purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() def _test_create_ssl_connection(self, httpd, create_connection, check_sockname=True): From 7395edde19df3d8d1518436fc260d55e2026797b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Nov 2014 14:11:34 +0100 Subject: [PATCH 1182/1502] runtests.py: only catch SkipTest exception when loading a module, don't catch all exceptions --- runtests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index d4a7aad6..e9bbdd8e 100644 --- a/runtests.py +++ b/runtests.py @@ -111,7 +111,7 @@ def list_dir(prefix, dir): mods.append((loader.load_module(), sourcefile)) except SyntaxError: raise - except Exception as err: + except unittest.SkipTest as err: print("Skipping '{}': {}".format(modname, err), file=sys.stderr) return mods From 03f6284674d729d9e733c7521a8c90bed9f096c9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Nov 2014 14:12:19 +0100 Subject: [PATCH 1183/1502] Fix test_events.py on Python 3.3: ssl.Purpose was introduced in Python 3.4 --- tests/test_events.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 4fe4b4c4..b05cb7ca 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -606,27 +606,29 @@ def _basetest_create_ssl_connection(self, connection_fut, self.assertGreater(pr.nbytes, 0) tr.close() - if ssl: - def _dummy_ssl_create_context(self, purpose=ssl.Purpose.SERVER_AUTH, *, - cafile=None, capath=None, cadata=None): - """ - A ssl.create_default_context() replacement that doesn't enable - cert validation. - """ - self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) - return test_utils.dummy_ssl_context() - def _test_create_ssl_connection(self, httpd, create_connection, check_sockname=True): conn_fut = create_connection(ssl=test_utils.dummy_ssl_context()) self._basetest_create_ssl_connection(conn_fut, check_sockname) - # With ssl=True, ssl.create_default_context() should be called - with mock.patch('ssl.create_default_context', - side_effect=self._dummy_ssl_create_context) as m: - conn_fut = create_connection(ssl=True) - self._basetest_create_ssl_connection(conn_fut, check_sockname) - self.assertEqual(m.call_count, 1) + # ssl.Purpose was introduced in Python 3.4 + if hasattr(ssl, 'Purpose'): + def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, + cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() + + # With ssl=True, ssl.create_default_context() should be called + with mock.patch('ssl.create_default_context', + side_effect=_dummy_ssl_create_context) as m: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + self.assertEqual(m.call_count, 1) # With the real ssl.create_default_context(), certificate # validation will fail From a67be4f5a033d23e46837f378c53357d5c96bb8f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Nov 2014 14:13:22 +0100 Subject: [PATCH 1184/1502] Fix formatting of the "Future exception was never retrieved" Add an unit test to check for non regression --- asyncio/futures.py | 7 +++-- tests/test_futures.py | 65 ++++++++++++++++++++++++++++--------------- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index 7998fbbc..40662a32 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -104,10 +104,11 @@ def clear(self): def __del__(self): if self.tb: - msg = 'Future/Task exception was never retrieved' + msg = 'Future/Task exception was never retrieved\n' if self.source_traceback: - msg += '\nFuture/Task created at (most recent call last):\n' - msg += ''.join(traceback.format_list(self.source_traceback)) + src = ''.join(traceback.format_list(self.source_traceback)) + msg += 'Future/Task created at (most recent call last):\n' + msg += '%s\n' % src.rstrip() msg += ''.join(self.tb).rstrip() self.loop.call_exception_handler({'message': msg}) diff --git a/tests/test_futures.py b/tests/test_futures.py index e5002bc8..371d3518 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -307,8 +307,8 @@ def test_future_source_traceback(self): 'test_future_source_traceback')) @mock.patch('asyncio.base_events.logger') - def test_future_exception_never_retrieved(self, m_log): - self.loop.set_debug(True) + def check_future_exception_never_retrieved(self, debug, m_log): + self.loop.set_debug(debug) def memory_error(): try: @@ -318,40 +318,59 @@ def memory_error(): exc = memory_error() future = asyncio.Future(loop=self.loop) - source_traceback = future._source_traceback + if debug: + source_traceback = future._source_traceback future.set_exception(exc) future = None test_utils.run_briefly(self.loop) support.gc_collect() if sys.version_info >= (3, 4): - frame = source_traceback[-1] - regex = (r'^Future exception was never retrieved\n' - r'future: \n' - r'source_traceback: Object created at \(most recent call last\):\n' - r' File' - r'.*\n' - r' File "{filename}", line {lineno}, in test_future_exception_never_retrieved\n' - r' future = asyncio\.Future\(loop=self\.loop\)$' - ).format(filename=re.escape(frame[0]), lineno=frame[1]) + if debug: + frame = source_traceback[-1] + regex = (r'^Future exception was never retrieved\n' + r'future: \n' + r'source_traceback: Object created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)$' + ).format(filename=re.escape(frame[0]), lineno=frame[1]) + else: + regex = (r'^Future exception was never retrieved\n' + r'future: $' + ) exc_info = (type(exc), exc, exc.__traceback__) m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) else: - frame = source_traceback[-1] - regex = (r'^Future/Task exception was never retrieved\n' - r'Future/Task created at \(most recent call last\):\n' - r' File' - r'.*\n' - r' File "{filename}", line {lineno}, in test_future_exception_never_retrieved\n' - r' future = asyncio\.Future\(loop=self\.loop\)\n' - r'Traceback \(most recent call last\):\n' - r'.*\n' - r'MemoryError$' - ).format(filename=re.escape(frame[0]), lineno=frame[1]) + if debug: + frame = source_traceback[-1] + regex = (r'^Future/Task exception was never retrieved\n' + r'Future/Task created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ).format(filename=re.escape(frame[0]), lineno=frame[1]) + else: + regex = (r'^Future/Task exception was never retrieved\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ) m_log.error.assert_called_once_with(mock.ANY, exc_info=False) message = m_log.error.call_args[0][0] self.assertRegex(message, re.compile(regex, re.DOTALL)) + def test_future_exception_never_retrieved(self): + self.check_future_exception_never_retrieved(False) + + def test_future_exception_never_retrieved_debug(self): + self.check_future_exception_never_retrieved(True) + def test_set_result_unless_cancelled(self): fut = asyncio.Future(loop=self.loop) fut.cancel() From d3ed2f61596af89d6acbd006bb63c32e16709b7a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Nov 2014 14:18:53 +0100 Subject: [PATCH 1185/1502] test_events: Ignore the "SSL handshake failed" log in debug mode --- tests/test_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_events.py b/tests/test_events.py index b05cb7ca..fab3259f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -634,7 +634,9 @@ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, # validation will fail with self.assertRaises(ssl.SSLError) as cm: conn_fut = create_connection(ssl=True) - self._basetest_create_ssl_connection(conn_fut, check_sockname) + # Ignore the "SSL handshake failed" log in debug mode + with test_utils.disable_logger(): + self._basetest_create_ssl_connection(conn_fut, check_sockname) self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') From f9f28622b5925596392ac183f17b1d235a15ae78 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 20 Nov 2014 15:03:16 +0100 Subject: [PATCH 1186/1502] Coroutine objects are now rejected with a TypeError by the following functions: * add_signal_handler() * call_at() * call_later() * call_soon() * call_soon_threadsafe() * run_in_executor() Fix also the error message of add_signal_handler() (fix the name of the function). --- asyncio/base_events.py | 11 +++++++---- asyncio/unix_events.py | 5 +++-- tests/test_base_events.py | 26 +++++++++++++++----------- tests/test_unix_events.py | 14 ++++++++++---- 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b6b71239..40dd6682 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -357,7 +357,8 @@ def call_at(self, when, callback, *args): Absolute time corresponds to the event loop's time() method. """ - if coroutines.iscoroutinefunction(callback): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with call_at()") if self._debug: self._assert_is_current_event_loop() @@ -384,7 +385,8 @@ def call_soon(self, callback, *args): return handle def _call_soon(self, callback, args, check_loop): - if coroutines.iscoroutinefunction(callback): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with call_soon()") if self._debug and check_loop: self._assert_is_current_event_loop() @@ -421,8 +423,9 @@ def call_soon_threadsafe(self, callback, *args): return handle def run_in_executor(self, executor, callback, *args): - if coroutines.iscoroutinefunction(callback): - raise TypeError("Coroutines cannot be used with run_in_executor()") + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with run_in_executor()") if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index e49212e5..efe06d4a 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -67,8 +67,9 @@ def add_signal_handler(self, sig, callback, *args): Raise ValueError if the signal number is invalid or uncatchable. Raise RuntimeError if there is a problem setting up the handler. """ - if coroutines.iscoroutinefunction(callback): - raise TypeError("coroutines cannot be used with call_soon()") + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with add_signal_handler()") self._check_signal(sig) try: # set_wakeup_fd() raises ValueError if this is not the diff --git a/tests/test_base_events.py b/tests/test_base_events.py index d61a64c9..0aa01174 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -1107,19 +1107,23 @@ def test_accept_connection_exception(self, m_log): def test_call_coroutine(self): @asyncio.coroutine - def coroutine_function(): + def simple_coroutine(): pass - with self.assertRaises(TypeError): - self.loop.call_soon(coroutine_function) - with self.assertRaises(TypeError): - self.loop.call_soon_threadsafe(coroutine_function) - with self.assertRaises(TypeError): - self.loop.call_later(60, coroutine_function) - with self.assertRaises(TypeError): - self.loop.call_at(self.loop.time() + 60, coroutine_function) - with self.assertRaises(TypeError): - self.loop.run_in_executor(None, coroutine_function) + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + with self.assertRaises(TypeError): + self.loop.call_soon(func) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(TypeError): + self.loop.call_later(60, func) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, func) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, func) @mock.patch('asyncio.base_events.logger') def test_log_slow_callbacks(self, m_logger): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 2f3fa185..b6ad0189 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -65,15 +65,21 @@ def test_add_signal_handler_setup_error(self, m_signal): @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler_coroutine_error(self, m_signal): + m_signal.NSIG = signal.NSIG @asyncio.coroutine def simple_coroutine(): yield from [] - self.assertRaises( - TypeError, - self.loop.add_signal_handler, - signal.SIGINT, simple_coroutine) + # callback must not be a coroutine function + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + self.assertRaisesRegex( + TypeError, 'coroutines cannot be used with add_signal_handler', + self.loop.add_signal_handler, + signal.SIGINT, func) @mock.patch('asyncio.unix_events.signal') def test_add_signal_handler(self, m_signal): From b7c708f81c4573ef59450b7ca098fa310c114e47 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 21 Nov 2014 00:13:50 +0100 Subject: [PATCH 1187/1502] BaseSelectorEventLoop.close() now closes the self-pipe before calling the parent close() method. If the event loop is already closed, the self-pipe is not unregistered from the selector. --- asyncio/selector_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 116d3801..f0c94c45 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -68,10 +68,12 @@ def _make_datagram_transport(self, sock, protocol, address, waiter, extra) def close(self): + if self._running: + raise RuntimeError("Cannot close a running event loop") if self.is_closed(): return - super().close() self._close_self_pipe() + super().close() if self._selector is not None: self._selector.close() self._selector = None From 95113218f7d2056b348d8ebe60f5442e93cdc29d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 25 Nov 2014 17:17:13 +0100 Subject: [PATCH 1188/1502] Python issue #22685: Set the transport of stdout and stderr StreamReader objects in the SubprocessStreamProtocol. It allows to pause the transport to not buffer too much stdout or stderr data. --- asyncio/subprocess.py | 17 ++++++++++++----- tests/test_subprocess.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index e4c14995..f6d6a141 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -41,15 +41,22 @@ def __repr__(self): def connection_made(self, transport): self._transport = transport - if transport.get_pipe_transport(1): + + stdout_transport = transport.get_pipe_transport(1) + if stdout_transport is not None: self.stdout = streams.StreamReader(limit=self._limit, loop=self._loop) - if transport.get_pipe_transport(2): + self.stdout.set_transport(stdout_transport) + + stderr_transport = transport.get_pipe_transport(2) + if stderr_transport is not None: self.stderr = streams.StreamReader(limit=self._limit, loop=self._loop) - stdin = transport.get_pipe_transport(0) - if stdin is not None: - self.stdin = streams.StreamWriter(stdin, + self.stderr.set_transport(stderr_transport) + + stdin_transport = transport.get_pipe_transport(0) + if stdin_transport is not None: + self.stdin = streams.StreamWriter(stdin_transport, protocol=self, reader=None, loop=self._loop) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 0e9e1ce5..d0ab2308 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -4,6 +4,7 @@ import signal import sys import unittest +from unittest import mock from test import support if sys.platform != 'win32': from asyncio import unix_events @@ -161,6 +162,37 @@ def test_communicate_ignore_broken_pipe(self): self.loop.run_until_complete(proc.communicate(large_data)) self.loop.run_until_complete(proc.wait()) + def test_pause_reading(self): + @asyncio.coroutine + def test_pause_reading(): + limit = 100 + + code = '\n'.join(( + 'import sys', + 'sys.stdout.write("x" * %s)' % (limit * 2 + 1), + 'sys.stdout.flush()', + )) + proc = yield from asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + limit=limit, + loop=self.loop) + stdout_transport = proc._transport.get_pipe_transport(1) + stdout_transport.pause_reading = mock.Mock() + + yield from proc.wait() + + # The child process produced more than limit bytes of output, + # the stream reader transport should pause the protocol to not + # allocate too much memory. + return stdout_transport.pause_reading.called + + # Issue #22685: Ensure that the stream reader pauses the protocol + # when the child process produces too much data + called = self.loop.run_until_complete(test_pause_reading()) + self.assertTrue(called) + if sys.platform != 'win32': # Unix From 1fd7c6d9600ff024a6068149c3a63528b5120b67 Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 25 Nov 2014 17:23:22 +0100 Subject: [PATCH 1189/1502] Python issue #22921: Don't require OpenSSL SNI to pass hostname to ssl functions. Patch by Donald Stufft. --- asyncio/selector_events.py | 2 +- tests/test_events.py | 8 -------- tests/test_selector_events.py | 2 +- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index f0c94c45..7df8b866 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -708,7 +708,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, 'server_side': server_side, 'do_handshake_on_connect': False, } - if server_hostname and not server_side and ssl.HAS_SNI: + if server_hostname and not server_side: wrap_kwargs['server_hostname'] = server_hostname sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) diff --git a/tests/test_events.py b/tests/test_events.py index fab3259f..ea657fd4 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -12,9 +12,6 @@ import ssl except ImportError: ssl = None - HAS_SNI = False -else: - from ssl import HAS_SNI import subprocess import sys import threading @@ -857,7 +854,6 @@ def test_create_unix_server_ssl(self): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -882,7 +878,6 @@ def test_create_server_ssl_verify_failed(self): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) @@ -909,7 +904,6 @@ def test_create_unix_server_ssl_verify_failed(self): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_match_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -937,7 +931,6 @@ def test_create_server_ssl_match_failed(self): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verified(self): proto = MyProto(loop=self.loop) @@ -963,7 +956,6 @@ def test_create_unix_server_ssl_verified(self): server.close() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(HAS_SNI, 'No SNI support in ssl module') def test_create_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 528da39d..8eba56c4 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1408,7 +1408,7 @@ def test_close(self): self.assertEqual(tr._conn_lost, 1) self.assertEqual(1, self.loop.remove_reader_count[1]) - @unittest.skipIf(ssl is None or not ssl.HAS_SNI, 'No SNI support') + @unittest.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): _SelectorSslTransport( self.loop, self.sock, self.protocol, self.sslcontext, From 8da949409b84ffc160dd949f3122b8acb79a9474 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 4 Dec 2014 22:24:08 +0100 Subject: [PATCH 1190/1502] Initialize more Future and Task attributes in the class definition to avoid attribute errors in destructors. --- asyncio/futures.py | 3 +-- asyncio/tasks.py | 7 ++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index 40662a32..03a4bf07 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -135,6 +135,7 @@ class Future: _result = None _exception = None _loop = None + _source_traceback = None _blocking = False # proper use of future (yield vs yield from) @@ -155,8 +156,6 @@ def __init__(self, *, loop=None): self._callbacks = [] if self._loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) - else: - self._source_traceback = None def _format_callbacks(self): cb = self._callbacks diff --git a/asyncio/tasks.py b/asyncio/tasks.py index e0738021..698ec6a3 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -41,6 +41,10 @@ class Task(futures.Future): # all running event loops. {EventLoop: Task} _current_tasks = {} + # If False, don't log a message if the task is destroyed whereas its + # status is still pending + _log_destroy_pending = True + @classmethod def current_task(cls, loop=None): """Return the currently running task in an event loop or None. @@ -73,9 +77,6 @@ def __init__(self, coro, *, loop=None): self._must_cancel = False self._loop.call_soon(self._step) self.__class__._all_tasks.add(self) - # If False, don't log a message if the task is destroyed whereas its - # status is still pending - self._log_destroy_pending = True # On Python 3.3 or older, objects with a destructor that are part of a # reference cycle are never destroyed. That's not the case any more on From f1399b67f698c37c6379c82ba98e3eb0d5422f0e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 4 Dec 2014 22:45:05 +0100 Subject: [PATCH 1191/1502] Python issue #22922: More EventLoop methods fail if the loop is closed. Initial patch written by Torsten Landschoff. create_task(), call_at(), call_soon(), call_soon_threadsafe() and run_in_executor() now raise an error if the event loop is closed. --- asyncio/base_events.py | 4 ++++ asyncio/unix_events.py | 1 + tests/test_events.py | 35 ++++++++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 40dd6682..7c38b093 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -177,6 +177,7 @@ def create_task(self, coro): Return a task object. """ + self._check_closed() task = tasks.Task(coro, loop=self) if task._source_traceback: del task._source_traceback[-1] @@ -360,6 +361,7 @@ def call_at(self, when, callback, *args): if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with call_at()") + self._check_closed() if self._debug: self._assert_is_current_event_loop() timer = events.TimerHandle(when, callback, args, self) @@ -390,6 +392,7 @@ def _call_soon(self, callback, args, check_loop): raise TypeError("coroutines cannot be used with call_soon()") if self._debug and check_loop: self._assert_is_current_event_loop() + self._check_closed() handle = events.Handle(callback, args, self) if handle._source_traceback: del handle._source_traceback[-1] @@ -426,6 +429,7 @@ def run_in_executor(self, executor, callback, *args): if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with run_in_executor()") + self._check_closed() if isinstance(callback, events.Handle): assert not args assert not isinstance(callback, events.TimerHandle) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index efe06d4a..d5db4d55 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -71,6 +71,7 @@ def add_signal_handler(self, sig, callback, *args): or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with add_signal_handler()") self._check_signal(sig) + self._check_closed() try: # set_wakeup_fd() raises ValueError if this is not the # main thread. By calling it early we ensure that an diff --git a/tests/test_events.py b/tests/test_events.py index ea657fd4..6644fbea 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -226,7 +226,8 @@ def setUp(self): def tearDown(self): # just in case if we have transport close callbacks - test_utils.run_briefly(self.loop) + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) self.loop.close() gc.collect() @@ -1434,6 +1435,38 @@ def close_loop(loop): with self.assertRaises(RuntimeError): self.loop.run_until_complete(coro) + def test_close(self): + self.loop.close() + + @asyncio.coroutine + def test(): + pass + + func = lambda: False + coro = test() + self.addCleanup(coro.close) + + # operation blocked when the loop is closed + with self.assertRaises(RuntimeError): + self.loop.run_forever() + with self.assertRaises(RuntimeError): + fut = asyncio.Future(loop=self.loop) + self.loop.run_until_complete(fut) + with self.assertRaises(RuntimeError): + self.loop.call_soon(func) + with self.assertRaises(RuntimeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(RuntimeError): + self.loop.call_later(1.0, func) + with self.assertRaises(RuntimeError): + self.loop.call_at(self.loop.time() + .0, func) + with self.assertRaises(RuntimeError): + self.loop.run_in_executor(None, func) + with self.assertRaises(RuntimeError): + self.loop.create_task(coro) + with self.assertRaises(RuntimeError): + self.loop.add_signal_handler(signal.SIGTERM, func) + class SubprocessTestsMixin: From 6988474ded05805805be4ea35012ee68b2ccdcfd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 4 Dec 2014 22:56:40 +0100 Subject: [PATCH 1192/1502] Python issue #22475: fix Task.get_stack() doc --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 698ec6a3..9aebffda 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -110,7 +110,7 @@ def _repr_info(self): def get_stack(self, *, limit=None): """Return the list of stack frames for this task's coroutine. - If the coroutine is active, this returns the stack where it is + If the coroutine is not done, this returns the stack where it is suspended. If the coroutine has completed successfully or was cancelled, this returns an empty list. If the coroutine was terminated by an exception, this returns the list of traceback From 4316872d9f0f4b6a92567e70d43d397a947cba1d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 4 Dec 2014 22:59:12 +0100 Subject: [PATCH 1193/1502] Removed duplicated words in in comments and docs. Patch written by Serhiy Storchaka. --- asyncio/futures.py | 2 +- tests/test_windows_events.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index 03a4bf07..f46d008f 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -61,7 +61,7 @@ class itself, but instead to have a reference to a helper object the Future is collected, and the helper is present, the helper object is also collected, and its __del__() method will log the traceback. When the Future's result() or exception() method is - called (and a helper object is present), it removes the the helper + called (and a helper object is present), it removes the helper object, after calling its clear() method to prevent it from logging. diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 85d9669b..b4d9398f 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -105,7 +105,7 @@ def test_wait_for_handle(self): _overlapped.SetEvent(event) - # Wait for for set event; + # Wait for set event; # result should be True immediately fut = self.loop._proactor.wait_for_handle(event, 10) start = self.loop.time() From 6c89fea1745309c97d56d4d733e7b473e06bfdb8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 4 Dec 2014 23:04:26 +0100 Subject: [PATCH 1194/1502] Python issue #22685: Fix test_pause_reading() of test_subprocess * mock also resume_reading() * ensure that resume_reading() is called --- tests/test_subprocess.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index d0ab2308..9060b9d3 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -163,13 +163,14 @@ def test_communicate_ignore_broken_pipe(self): self.loop.run_until_complete(proc.wait()) def test_pause_reading(self): + limit = 10 + size = (limit * 2 + 1) + @asyncio.coroutine def test_pause_reading(): - limit = 100 - code = '\n'.join(( 'import sys', - 'sys.stdout.write("x" * %s)' % (limit * 2 + 1), + 'sys.stdout.write("x" * %s)' % size, 'sys.stdout.flush()', )) proc = yield from asyncio.create_subprocess_exec( @@ -180,18 +181,22 @@ def test_pause_reading(): loop=self.loop) stdout_transport = proc._transport.get_pipe_transport(1) stdout_transport.pause_reading = mock.Mock() + stdout_transport.resume_reading = mock.Mock() - yield from proc.wait() + stdout, stderr = yield from proc.communicate() # The child process produced more than limit bytes of output, # the stream reader transport should pause the protocol to not # allocate too much memory. - return stdout_transport.pause_reading.called + return (stdout, stdout_transport) # Issue #22685: Ensure that the stream reader pauses the protocol # when the child process produces too much data - called = self.loop.run_until_complete(test_pause_reading()) - self.assertTrue(called) + stdout, transport = self.loop.run_until_complete(test_pause_reading()) + + self.assertEqual(stdout, b'x' * size) + self.assertTrue(transport.pause_reading.called) + self.assertTrue(transport.resume_reading.called) if sys.platform != 'win32': From dffbf3a5435eb1b6f4c680d334661f48329132cf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Dec 2014 01:40:59 +0100 Subject: [PATCH 1195/1502] Python issue #22429: Fix EventLoop.run_until_complete(), don't stop the event loop if a BaseException is raised, because the event loop is already stopped. --- asyncio/base_events.py | 14 ++++++++++++-- tests/test_base_events.py | 25 +++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 7c38b093..0c7316ea 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -102,6 +102,16 @@ def _raise_stop_error(*args): raise _StopError +def _run_until_complete_cb(fut): + exc = fut._exception + if (isinstance(exc, BaseException) + and not isinstance(exc, Exception)): + # Issue #22429: run_forever() already finished, no need to + # stop it. + return + _raise_stop_error() + + class Server(events.AbstractServer): def __init__(self, loop, sockets): @@ -268,7 +278,7 @@ def run_until_complete(self, future): # is no need to log the "destroy pending task" message future._log_destroy_pending = False - future.add_done_callback(_raise_stop_error) + future.add_done_callback(_run_until_complete_cb) try: self.run_forever() except: @@ -278,7 +288,7 @@ def run_until_complete(self, future): # local task. future.exception() raise - future.remove_done_callback(_raise_stop_error) + future.remove_done_callback(_run_until_complete_cb) if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 0aa01174..db9d732c 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -638,6 +638,31 @@ def raise_keyboard_interrupt(): self.assertFalse(self.loop.call_exception_handler.called) + def test_run_until_complete_baseexception(self): + # Python issue #22429: run_until_complete() must not schedule a pending + # call to stop() if the future raised a BaseException + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + + def func(): + self.loop.stop() + func.called = True + func.called = False + try: + self.loop.call_soon(func) + self.loop.run_forever() + except KeyboardInterrupt: + pass + self.assertTrue(func.called) + class MyProto(asyncio.Protocol): done = None From 361a1cd6d8c8c30a2bf5f694c9b04c0344731b02 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Dec 2014 01:43:04 +0100 Subject: [PATCH 1196/1502] Python issue #22922: Fix ProactorEventLoop.close() Call _stop_accept_futures() before sestting the _closed attribute, otherwise call_soon() raises an error. --- asyncio/proactor_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index a1e2fef6..4c527aa2 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -387,11 +387,13 @@ def _make_write_pipe_transport(self, sock, protocol, waiter=None, sock, protocol, waiter, extra) def close(self): + if self._running: + raise RuntimeError("Cannot close a running event loop") if self.is_closed(): return - super().close() self._stop_accept_futures() self._close_self_pipe() + super().close() self._proactor.close() self._proactor = None self._selector = None From abbbb433e5cf9430ff4a7b9bb5beed72e235d03e Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 8 Dec 2014 12:27:57 -0500 Subject: [PATCH 1197/1502] selectors: Make sure EpollSelecrtor.select() works when no FD is registered. Closes http://bugs.python.org/issue23009 --- asyncio/selectors.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 4e9ae6ec..ad7afe2a 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -418,7 +418,12 @@ def select(self, timeout=None): # epoll_wait() has a resolution of 1 millisecond, round away # from zero to wait *at least* timeout seconds. timeout = math.ceil(timeout * 1e3) * 1e-3 - max_ev = len(self._fd_to_key) + + # epoll_wait() expectcs `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + ready = [] try: fd_event_list = self._epoll.poll(timeout, max_ev) From 28f4d5327bafe61a58f8024a33264d79e737fcf4 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 8 Dec 2014 12:31:03 -0500 Subject: [PATCH 1198/1502] selectors: Fix typo. --- asyncio/selectors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index ad7afe2a..faa2d3da 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -419,7 +419,7 @@ def select(self, timeout=None): # from zero to wait *at least* timeout seconds. timeout = math.ceil(timeout * 1e3) * 1e-3 - # epoll_wait() expectcs `maxevents` to be greater than zero; + # epoll_wait() expects `maxevents` to be greater than zero; # we want to make sure that `select()` can be called when no # FD is registered. max_ev = max(len(self._fd_to_key), 1) From 65de4b6dd38fe73cfff7af90edeff299c95364e5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 11 Dec 2014 22:19:47 +0100 Subject: [PATCH 1199/1502] Tulip issue #202: Add unit test of pause/resume writing for proactor socket transport --- asyncio/proactor_events.py | 4 -- tests/test_proactor_events.py | 82 +++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 4c527aa2..e67cf65a 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -230,10 +230,6 @@ def write(self, data): assert self._buffer is None # Pass a copy, except if it's already immutable. self._loop_writing(data=bytes(data)) - # XXX Should we pause the protocol at this point - # if len(data) > self._high_water? (That would - # require keeping track of the number of bytes passed - # to a send() that hasn't finished yet.) elif not self._buffer: # WRITING -> BACKED UP # Make a mutable copy which we can extend. self._buffer = bytearray(data) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 0c536986..9e9b41a4 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -343,6 +343,88 @@ def test_pause_resume_reading(self): tr.close() + def pause_writing_transport(self, high): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.addCleanup(tr.close) + + tr.set_write_buffer_limits(high=high) + + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + self.assertFalse(self.protocol.resume_writing.called) + return tr + + def test_pause_resume_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk, must pause writing + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'large data') + self.loop._run_once() + self.assertTrue(self.protocol.pause_writing.called) + + # flush the buffer + fut.set_result(None) + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.resume_writing.called) + + def test_pause_writing_2write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (3 <= 4) + fut1 = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut1 + tr.write(b'123') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_pause_writing_3write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (1 <= 4) + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'1') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 1) + self.assertFalse(self.protocol.pause_writing.called) + + # second short write, the buffer is not full (3 <= 4) + tr.write(b'23') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_dont_pause_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk which completes immedialty, + # it should not pause writing + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + self.loop._proactor.send.return_value = fut + tr.write(b'very large data') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + + class BaseProactorEventLoopTests(test_utils.TestCase): def setUp(self): From cd00b2cb2f1614283fd318145431f19bf10d78a0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 11 Dec 2014 23:28:34 +0100 Subject: [PATCH 1200/1502] Fix subprocess for close_fds=False on Python 3.3 Mark the write end of the stdin pipe as non-inheritable. --- asyncio/unix_events.py | 22 ++++++++++++++++++++++ tests/test_subprocess.py | 21 +++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index d5db4d55..d1461fd0 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -547,6 +547,22 @@ def _call_connection_lost(self, exc): self._loop = None +if hasattr(os, 'set_inheritable'): + # Python 3.4 and newer + _set_inheritable = os.set_inheritable +else: + import fcntl + + def _set_inheritable(fd, inheritable): + cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + if not inheritable: + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + else: + fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) + + class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): @@ -558,6 +574,12 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): # other end). Notably this is needed on AIX, and works # just fine on other platforms. stdin, stdin_w = self._loop._socketpair() + + # Mark the write end of the stdin pipe as non-inheritable, + # needed by close_fds=False on Python 3.3 and older + # (Python 3.4 implements the PEP 446, socketpair returns + # non-inheritable sockets) + _set_inheritable(stdin_w.fileno(), False) self._proc = subprocess.Popen( args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, universal_newlines=False, bufsize=bufsize, **kwargs) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 9060b9d3..5c0a2c85 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -198,6 +198,27 @@ def test_pause_reading(): self.assertTrue(transport.pause_reading.called) self.assertTrue(transport.resume_reading.called) + def test_stdin_not_inheritable(self): + # Tulip issue #209: stdin must not be inheritable, otherwise + # the Process.communicate() hangs + @asyncio.coroutine + def len_message(message): + code = 'import sys; data = sys.stdin.read(); print(len(data))' + proc = yield from asyncio.create_subprocess_exec( + sys.executable, '-c', code, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, + loop=self.loop) + stdout, stderr = yield from proc.communicate(message) + exitcode = yield from proc.wait() + return (stdout, exitcode) + + output, exitcode = self.loop.run_until_complete(len_message(b'abc')) + self.assertEqual(output.rstrip(), b'3') + self.assertEqual(exitcode, 0) + if sys.platform != 'win32': # Unix From 557bcba3e5b165e493c8bddda50866f8fb789c09 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 12 Dec 2014 18:08:50 +0100 Subject: [PATCH 1201/1502] Add run_aiotest.py --- .hgeol | 4 + .hgignore | 14 + AUTHORS | 26 + COPYING | 201 ++ MANIFEST.in | 11 + Makefile | 60 + README | 44 + asyncio/__init__.py | 48 + asyncio/base_events.py | 1113 +++++++++++ asyncio/base_subprocess.py | 197 ++ asyncio/constants.py | 7 + asyncio/coroutines.py | 195 ++ asyncio/events.py | 597 ++++++ asyncio/futures.py | 411 ++++ asyncio/locks.py | 469 +++++ asyncio/log.py | 7 + asyncio/proactor_events.py | 502 +++++ asyncio/protocols.py | 129 ++ asyncio/queues.py | 288 +++ asyncio/selector_events.py | 1003 ++++++++++ asyncio/selectors.py | 590 ++++++ asyncio/streams.py | 485 +++++ asyncio/subprocess.py | 235 +++ asyncio/tasks.py | 660 +++++++ asyncio/test_utils.py | 436 ++++ asyncio/transports.py | 300 +++ asyncio/unix_events.py | 949 +++++++++ asyncio/windows_events.py | 634 ++++++ asyncio/windows_utils.py | 209 ++ check.py | 45 + examples/cacheclt.py | 213 ++ examples/cachesvr.py | 249 +++ examples/child_process.py | 128 ++ examples/crawl.py | 863 ++++++++ examples/echo_client_tulip.py | 20 + examples/echo_server_tulip.py | 20 + examples/fetch0.py | 35 + examples/fetch1.py | 78 + examples/fetch2.py | 141 ++ examples/fetch3.py | 230 +++ examples/fuzz_as_completed.py | 69 + examples/hello_callback.py | 17 + examples/hello_coroutine.py | 18 + examples/shell.py | 50 + examples/simple_tcp_server.py | 154 ++ examples/sink.py | 94 + examples/source.py | 100 + examples/source1.py | 98 + examples/stacks.py | 44 + examples/subprocess_attach_read_pipe.py | 33 + examples/subprocess_attach_write_pipe.py | 35 + examples/subprocess_shell.py | 87 + examples/tcp_echo.py | 128 ++ examples/timing_tcp_server.py | 168 ++ examples/udp_echo.py | 104 + overlapped.c | 1380 +++++++++++++ pypi.bat | 1 + run_aiotest.py | 14 + runtests.py | 302 +++ setup.py | 34 + tests/echo.py | 8 + tests/echo2.py | 6 + tests/echo3.py | 11 + tests/keycert3.pem | 73 + tests/pycacert.pem | 78 + tests/sample.crt | 14 + tests/sample.key | 15 + tests/ssl_cert.pem | 15 + tests/ssl_key.pem | 16 + tests/test_base_events.py | 1183 +++++++++++ tests/test_events.py | 2306 ++++++++++++++++++++++ tests/test_futures.py | 461 +++++ tests/test_locks.py | 858 ++++++++ tests/test_proactor_events.py | 574 ++++++ tests/test_queues.py | 476 +++++ tests/test_selector_events.py | 1737 ++++++++++++++++ tests/test_selectors.py | 214 ++ tests/test_streams.py | 628 ++++++ tests/test_subprocess.py | 275 +++ tests/test_tasks.py | 1984 +++++++++++++++++++ tests/test_transports.py | 91 + tests/test_unix_events.py | 1600 +++++++++++++++ tests/test_windows_events.py | 141 ++ tests/test_windows_utils.py | 175 ++ update_stdlib.sh | 65 + 85 files changed, 27780 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 AUTHORS create mode 100644 COPYING create mode 100644 MANIFEST.in create mode 100644 Makefile create mode 100644 README create mode 100644 asyncio/__init__.py create mode 100644 asyncio/base_events.py create mode 100644 asyncio/base_subprocess.py create mode 100644 asyncio/constants.py create mode 100644 asyncio/coroutines.py create mode 100644 asyncio/events.py create mode 100644 asyncio/futures.py create mode 100644 asyncio/locks.py create mode 100644 asyncio/log.py create mode 100644 asyncio/proactor_events.py create mode 100644 asyncio/protocols.py create mode 100644 asyncio/queues.py create mode 100644 asyncio/selector_events.py create mode 100644 asyncio/selectors.py create mode 100644 asyncio/streams.py create mode 100644 asyncio/subprocess.py create mode 100644 asyncio/tasks.py create mode 100644 asyncio/test_utils.py create mode 100644 asyncio/transports.py create mode 100644 asyncio/unix_events.py create mode 100644 asyncio/windows_events.py create mode 100644 asyncio/windows_utils.py create mode 100644 check.py create mode 100644 examples/cacheclt.py create mode 100644 examples/cachesvr.py create mode 100644 examples/child_process.py create mode 100644 examples/crawl.py create mode 100644 examples/echo_client_tulip.py create mode 100644 examples/echo_server_tulip.py create mode 100644 examples/fetch0.py create mode 100644 examples/fetch1.py create mode 100644 examples/fetch2.py create mode 100644 examples/fetch3.py create mode 100644 examples/fuzz_as_completed.py create mode 100644 examples/hello_callback.py create mode 100644 examples/hello_coroutine.py create mode 100644 examples/shell.py create mode 100644 examples/simple_tcp_server.py create mode 100644 examples/sink.py create mode 100644 examples/source.py create mode 100644 examples/source1.py create mode 100644 examples/stacks.py create mode 100644 examples/subprocess_attach_read_pipe.py create mode 100644 examples/subprocess_attach_write_pipe.py create mode 100644 examples/subprocess_shell.py create mode 100755 examples/tcp_echo.py create mode 100644 examples/timing_tcp_server.py create mode 100755 examples/udp_echo.py create mode 100644 overlapped.c create mode 100644 pypi.bat create mode 100644 run_aiotest.py create mode 100644 runtests.py create mode 100644 setup.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/keycert3.pem create mode 100644 tests/pycacert.pem create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/ssl_cert.pem create mode 100644 tests/ssl_key.pem create mode 100644 tests/test_base_events.py create mode 100644 tests/test_events.py create mode 100644 tests/test_futures.py create mode 100644 tests/test_locks.py create mode 100644 tests/test_proactor_events.py create mode 100644 tests/test_queues.py create mode 100644 tests/test_selector_events.py create mode 100644 tests/test_selectors.py create mode 100644 tests/test_streams.py create mode 100644 tests/test_subprocess.py create mode 100644 tests/test_tasks.py create mode 100644 tests/test_transports.py create mode 100644 tests/test_unix_events.py create mode 100644 tests/test_windows_events.py create mode 100644 tests/test_windows_utils.py create mode 100755 update_stdlib.sh diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..6d1136f2 --- /dev/null +++ b/.hgignore @@ -0,0 +1,14 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ +dist$ +.*\.egg-info$ diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..d25b4465 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,26 @@ +A. Jesse Jiryu Davis +Aaron Griffith +Andrew Svetlov +Anthony Baire +Antoine Pitrou +Arnaud Faure +Aymeric Augustin +Brett Cannon +Charles-François Natali +Christian Heimes +Donald Stufft +Eli Bendersky +Geert Jansen +Giampaolo Rodola' +Guido van Rossum : creator of the Tulip project and author of the PEP 3156 +Gustavo Carneiro +Jeff Quast +Jonathan Slenders +Nikolay Kim +Richard Oudkerk +Saúl Ibarra Corretgé +Serhiy Storchaka +Vajrasky Kok +Victor Stinner +Vladimir Kryachko +Yury Selivanov diff --git a/COPYING b/COPYING new file mode 100644 index 00000000..11069edd --- /dev/null +++ b/COPYING @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..314325c8 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,11 @@ +include AUTHORS COPYING +include Makefile +include overlapped.c pypi.bat +include check.py runtests.py +include update_stdlib.sh + +recursive-include examples *.py +recursive-include tests *.crt +recursive-include tests *.key +recursive-include tests *.pem +recursive-include tests *.py diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..eda02f2d --- /dev/null +++ b/Makefile @@ -0,0 +1,60 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) + +check: + $(PYTHON) check.py + +# Requires "pip install pep8". +pep8: check + pep8 --ignore E125,E127,E226 tests asyncio + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -rf dist + rm -f .coverage + rm -rf htmlcov + rm -rf build + rm -rf asyncio.egg-info + rm -f MANIFEST + + +# For distribution builders only! +# Push a source distribution for Python 3.3 to PyPI. +# You must update the version in setup.py first. +# A PyPI user configuration in ~/.pypirc is required; +# you can create a suitable confifuration using +# python setup.py register +pypi: clean + python3.3 setup.py sdist upload + +# The corresponding action on Windows is pypi.bat. For that to work, +# you need to install wheel and setuptools. The easiest way is to get +# pip using the get-pip.py script found here: +# https://pip.pypa.io/en/latest/installing.html#install-pip +# That will install setuptools and pip; then you can just do +# \Python33\python.exe -m pip install wheel +# after which the pypi.bat script should work. diff --git a/README b/README new file mode 100644 index 00000000..2f3150a2 --- /dev/null +++ b/README @@ -0,0 +1,44 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'asyncio' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + +On Windows, things are a little more complicated. Assume 'P' is your +Python binary (for example C:\Python33\python.exe). + +You must first build the _overlapped.pyd extension and have it placed +in the asyncio directory, as follows: + + C> P setup.py build_ext --inplace + +If this complains about vcvars.bat, you probably don't have the +required version of Visual Studio installed. Compiling extensions for +Python 3.3 requires Microsoft Visual C++ 2010 (MSVC 10.0) of any +edition; you can download Visual Studio Express 2010 for free from +http://www.visualstudio.com/downloads (scroll down to Visual C++ 2010 +Express). + +Once you have built the _overlapped.pyd extension successfully you can +run the tests as follows: + + C> P runtests.py + +And coverage as follows: + + C> P runtests.py --coverage + +--Guido van Rossum diff --git a/asyncio/__init__.py b/asyncio/__init__.py new file mode 100644 index 00000000..3911fb40 --- /dev/null +++ b/asyncio/__init__.py @@ -0,0 +1,48 @@ +"""The asyncio package, tracking PEP 3156.""" + +import sys + +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. +try: + from . import selectors +except ImportError: + import selectors # Will also be exported. + +if sys.platform == 'win32': + # Similar thing for _overlapped. + try: + from . import _overlapped + except ImportError: + import _overlapped # Will also be exported. + +# This relies on each of the submodules having an __all__ variable. +from .coroutines import * +from .events import * +from .futures import * +from .locks import * +from .protocols import * +from .queues import * +from .streams import * +from .subprocess import * +from .tasks import * +from .transports import * + +__all__ = (coroutines.__all__ + + events.__all__ + + futures.__all__ + + locks.__all__ + + protocols.__all__ + + queues.__all__ + + streams.__all__ + + subprocess.__all__ + + tasks.__all__ + + transports.__all__) + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * + __all__ += windows_events.__all__ +else: + from .unix_events import * # pragma: no cover + __all__ += unix_events.__all__ diff --git a/asyncio/base_events.py b/asyncio/base_events.py new file mode 100644 index 00000000..0c7316ea --- /dev/null +++ b/asyncio/base_events.py @@ -0,0 +1,1113 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of I/O events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import inspect +import logging +import os +import socket +import subprocess +import time +import traceback +import sys + +from . import coroutines +from . import events +from . import futures +from . import tasks +from .coroutines import coroutine +from .log import logger + + +__all__ = ['BaseEventLoop', 'Server'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + +# Minimum number of _scheduled timer handles before cleanup of +# cancelled handles is performed. +_MIN_SCHEDULED_TIMER_HANDLES = 100 + +# Minimum fraction of _scheduled timer handles that are cancelled +# before cleanup of cancelled handles is performed. +_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 + +def _format_handle(handle): + cb = handle._callback + if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task): + # format the task + return repr(cb.__self__) + else: + return str(handle) + + +def _format_pipe(fd): + if fd == subprocess.PIPE: + return '' + elif fd == subprocess.STDOUT: + return '' + else: + return repr(fd) + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _check_resolved_address(sock, address): + # Ensure that the address is already resolved to avoid the trap of hanging + # the entire event loop when the address requires doing a DNS lookup. + family = sock.family + if family == socket.AF_INET: + host, port = address + elif family == socket.AF_INET6: + host, port = address[:2] + else: + return + + type_mask = 0 + if hasattr(socket, 'SOCK_NONBLOCK'): + type_mask |= socket.SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + type_mask |= socket.SOCK_CLOEXEC + # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is + # already resolved. + try: + socket.getaddrinfo(host, port, + family=family, + type=(sock.type & ~type_mask), + proto=sock.proto, + flags=socket.AI_NUMERICHOST) + except socket.gaierror as err: + raise ValueError("address must be resolved (IP address), got %r: %s" + % (address, err)) + +def _raise_stop_error(*args): + raise _StopError + + +def _run_until_complete_cb(fut): + exc = fut._exception + if (isinstance(exc, BaseException) + and not isinstance(exc, Exception)): + # Issue #22429: run_forever() already finished, no need to + # stop it. + return + _raise_stop_error() + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self._loop = loop + self.sockets = sockets + self._active_count = 0 + self._waiters = [] + + def __repr__(self): + return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + + def _attach(self): + assert self.sockets is not None + self._active_count += 1 + + def _detach(self): + assert self._active_count > 0 + self._active_count -= 1 + if self._active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is None: + return + self.sockets = None + for sock in sockets: + self._loop._stop_serving(sock) + if self._active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self._waiters + self._waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @coroutine + def wait_closed(self): + if self.sockets is None or self._waiters is None: + return + waiter = futures.Future(loop=self._loop) + self._waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._timer_cancelled_count = 0 + self._closed = False + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + self._running = False + self._clock_resolution = time.get_clock_info('monotonic').resolution + self._exception_handler = None + self._debug = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + # In debug mode, if the execution of a callback or a step of a task + # exceed this duration in seconds, the slow callback/task is logged. + self.slow_callback_duration = 0.1 + + def __repr__(self): + return ('<%s running=%s closed=%s debug=%s>' + % (self.__class__.__name__, self.is_running(), + self.is_closed(), self.get_debug())) + + def create_task(self, coro): + """Schedule a coroutine object. + + Return a task object. + """ + self._check_closed() + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + return task + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _write_to_self(self): + """Write a byte to self-pipe, to wake up the event loop. + + This may be called from a different thread. + + The subclass is responsible for implementing the self-pipe. + """ + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def _check_closed(self): + if self._closed: + raise RuntimeError('Event loop is closed') + + def run_forever(self): + """Run until stop() is called.""" + self._check_closed() + if self._running: + raise RuntimeError('Event loop is running.') + self._running = True + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._running = False + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + WARNING: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + self._check_closed() + + new_task = not isinstance(future, futures.Future) + future = tasks.async(future, loop=self) + if new_task: + # An exception is raised if the future didn't complete, so there + # is no need to log the "destroy pending task" message + future._log_destroy_pending = False + + future.add_done_callback(_run_until_complete_cb) + try: + self.run_forever() + except: + if new_task and future.done() and not future.cancelled(): + # The coroutine raised a BaseException. Consume the exception + # to not log a warning, the caller doesn't have access to the + # local task. + future.exception() + raise + future.remove_done_callback(_run_until_complete_cb) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. Callbacks + scheduled after stop() is called will not run. However, those callbacks + will run if run_forever is called again later. + """ + self.call_soon(_raise_stop_error) + + def close(self): + """Close the event loop. + + This clears the queues and shuts down the executor, + but does not wait for the executor to finish. + + The event loop must not be running. + """ + if self._running: + raise RuntimeError("Cannot close a running event loop") + if self._closed: + return + if self._debug: + logger.debug("Close %r", self) + self._closed = True + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) + + def is_closed(self): + """Returns True if the event loop was closed.""" + return self._closed + + def is_running(self): + """Returns True if the event loop is running.""" + return self._running + + def time(self): + """Return the time according to the event loop's clock. + + This is a float expressed in seconds since an epoch, but the + epoch, precision, accuracy and drift are unspecified and may + differ per event loop. + """ + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always relative to the current time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + timer = self.call_at(self.time() + delay, callback, *args) + if timer._source_traceback: + del timer._source_traceback[-1] + return timer + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time. + + Absolute time corresponds to the event loop's time() method. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_at()") + self._check_closed() + if self._debug: + self._assert_is_current_event_loop() + timer = events.TimerHandle(when, callback, args, self) + if timer._source_traceback: + del timer._source_traceback[-1] + heapq.heappush(self._scheduled, timer) + timer._scheduled = True + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue: callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + handle = self._call_soon(callback, args, check_loop=True) + if handle._source_traceback: + del handle._source_traceback[-1] + return handle + + def _call_soon(self, callback, args, check_loop): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_soon()") + if self._debug and check_loop: + self._assert_is_current_event_loop() + self._check_closed() + handle = events.Handle(callback, args, self) + if handle._source_traceback: + del handle._source_traceback[-1] + self._ready.append(handle) + return handle + + def _assert_is_current_event_loop(self): + """Asserts that this event loop is the current event loop. + + Non-thread-safe methods of this class make this assumption and will + likely behave incorrectly when the assumption is violated. + + Should only be called when (self._debug == True). The caller is + responsible for checking this condition for performance reasons. + """ + try: + current = events.get_event_loop() + except AssertionError: + return + if current is not self: + raise RuntimeError( + "Non-thread-safe operation invoked on an event loop other " + "than the current one") + + def call_soon_threadsafe(self, callback, *args): + """Like call_soon(), but thread-safe.""" + handle = self._call_soon(callback, args, check_loop=False) + if handle._source_traceback: + del handle._source_traceback[-1] + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with run_in_executor()") + self._check_closed() + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def _getaddrinfo_debug(self, host, port, family, type, proto, flags): + msg = ["%s:%r" % (host, port)] + if family: + msg.append('family=%r' % family) + if type: + msg.append('type=%r' % type) + if proto: + msg.append('proto=%r' % proto) + if flags: + msg.append('flags=%r' % flags) + msg = ', '.join(msg) + logger.debug('Get address info %s', msg) + + t0 = self.time() + addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) + dt = self.time() - t0 + + msg = ('Getting address info %s took %.3f ms: %r' + % (msg, dt * 1e3, addrinfo)) + if dt >= self.slow_callback_duration: + logger.info(msg) + else: + logger.debug(msg) + return addrinfo + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + if self._debug: + return self.run_in_executor(None, self._getaddrinfo_debug, + host, port, family, type, proto, flags) + else: + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + """Connect to a TCP server. + + Create a streaming transport connection to a given Internet host and + port: socket family AF_INET or socket.AF_INET6 depending on host (or + family if specified), socket type SOCK_STREAM. protocol_factory must be + a callable returning a protocol instance. + + This method is a coroutine which will try to establish the connection + in the background. When successful, the coroutine returns a + (transport, protocol) pair. + """ + if server_hostname is not None and not ssl: + raise ValueError('server_hostname is only meaningful with ssl') + + if server_hostname is None and ssl: + # Use host as default for server_hostname. It is an error + # if host is empty or not set, e.g. when an + # already-connected socket was passed or when only a port + # is given. To avoid this error, you can pass + # server_hostname='' -- this will bypass the hostname + # check. (This also means that if host is a numeric + # IP/IPv6 address, we will attempt to verify that exact + # address; this will probably fail, but it is possible to + # create a certificate for a specific IP address, so we + # don't judge it here.) + if not host: + raise ValueError('You must set server_hostname ' + 'when using ssl without a host') + server_hostname = host + + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + if self._debug: + logger.debug("connect %r to %r", sock, address) + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') + logger.debug("%r connected to %s:%r: (%r, %r)", + sock, host, port, transport, protocol) + return transport, protocol + + @coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=server_hostname) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + yield from waiter + return transport, protocol + + @coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join address by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_datagram_transport(sock, protocol, r_addr, + waiter) + if self._debug: + if local_addr: + logger.info("Datagram endpoint local_addr=%r remote_addr=%r " + "created: (%r, %r)", + local_addr, remote_addr, transport, protocol) + else: + logger.debug("Datagram endpoint remote_addr=%r created: " + "(%r, %r)", + remote_addr, transport, protocol) + yield from waiter + return transport, protocol + + @coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """Create a TCP server bound to host and port. + + Return a Server object which can be used to stop the service. + + This method is a coroutine. + """ + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + except socket.error: + # Assume it's a bad family/type/protocol combination. + if self._debug: + logger.warning('create_server() failed to create ' + 'socket.socket(%r, %r, %r)', + af, socktype, proto, exc_info=True) + continue + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError('Neither host/port nor sock were specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + if self._debug: + logger.info("%r is serving", server) + return server + + @coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + yield from waiter + if self._debug: + logger.debug('Read pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + @coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + yield from waiter + if self._debug: + logger.debug('Write pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + def _log_subprocess(self, msg, stdin, stdout, stderr): + info = [msg] + if stdin is not None: + info.append('stdin=%s' % _format_pipe(stdin)) + if stdout is not None and stderr == subprocess.STDOUT: + info.append('stdout=stderr=%s' % _format_pipe(stdout)) + else: + if stdout is not None: + info.append('stdout=%s' % _format_pipe(stdout)) + if stderr is not None: + info.append('stderr=%s' % _format_pipe(stderr)) + logger.debug(' '.join(info)) + + @coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + if not isinstance(cmd, (bytes, str)): + raise ValueError("cmd must be a string") + if universal_newlines: + raise ValueError("universal_newlines must be False") + if not shell: + raise ValueError("shell must be True") + if bufsize != 0: + raise ValueError("bufsize must be 0") + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'run shell command %r' % cmd + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + @coroutine + def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, **kwargs): + if universal_newlines: + raise ValueError("universal_newlines must be False") + if shell: + raise ValueError("shell must be False") + if bufsize != 0: + raise ValueError("bufsize must be 0") + popen_args = (program,) + args + for arg in popen_args: + if not isinstance(arg, (str, bytes)): + raise TypeError("program arguments must be " + "a bytes or text string, not %s" + % type(arg).__name__) + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'execute program %r' % program + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, popen_args, False, stdin, stdout, stderr, + bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + def set_exception_handler(self, handler): + """Set handler as the new event loop exception handler. + + If handler is None, the default exception handler will + be set. + + If handler is a callable object, it should have a + signature matching '(loop, context)', where 'loop' + will be a reference to the active event loop, 'context' + will be a dict object (see `call_exception_handler()` + documentation for details about context). + """ + if handler is not None and not callable(handler): + raise TypeError('A callable object or None is expected, ' + 'got {!r}'.format(handler)) + self._exception_handler = handler + + def default_exception_handler(self, context): + """Default exception handler. + + This is called when an exception occurs and no exception + handler is set, and can be called by a custom exception + handler that wants to defer to the default behavior. + + The context parameter has the same meaning as in + `call_exception_handler()`. + """ + message = context.get('message') + if not message: + message = 'Unhandled exception in event loop' + + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = False + + log_lines = [message] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + else: + value = repr(value) + log_lines.append('{}: {}'.format(key, value)) + + logger.error('\n'.join(log_lines), exc_info=exc_info) + + def call_exception_handler(self, context): + """Call the current event loop's exception handler. + + The context argument is a dict containing the following keys: + + - 'message': Error message; + - 'exception' (optional): Exception object; + - 'future' (optional): Future instance; + - 'handle' (optional): Handle instance; + - 'protocol' (optional): Protocol instance; + - 'transport' (optional): Transport instance; + - 'socket' (optional): Socket instance. + + New keys maybe introduced in the future. + + Note: do not overload this method in an event loop subclass. + For custom exception handling, use the + `set_exception_handler()` method. + """ + if self._exception_handler is None: + try: + self.default_exception_handler(context) + except Exception: + # Second protection layer for unexpected errors + # in the default implementation, as well as for subclassed + # event loops with overloaded "default_exception_handler". + logger.error('Exception in default exception handler', + exc_info=True) + else: + try: + self._exception_handler(self, context) + except Exception as exc: + # Exception in the user set custom exception handler. + try: + # Let's try default handler. + self.default_exception_handler({ + 'message': 'Unhandled error in exception handler', + 'exception': exc, + 'context': context, + }) + except Exception: + # Guard 'default_exception_handler' in case it is + # overloaded. + logger.error('Exception in default exception handler ' + 'while handling an unexpected error ' + 'in custom exception handler', + exc_info=True) + + def _add_callback(self, handle): + """Add a Handle to _scheduled (TimerHandle) or _ready.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + assert not isinstance(handle, events.TimerHandle) + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + if handle._scheduled: + self._timer_cancelled_count += 1 + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + + sched_count = len(self._scheduled) + if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and + self._timer_cancelled_count / sched_count > + _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + # Remove delayed calls that were cancelled if their number + # is too high + new_scheduled = [] + for handle in self._scheduled: + if handle._cancelled: + handle._scheduled = False + else: + new_scheduled.append(handle) + + heapq.heapify(new_scheduled) + self._scheduled = new_scheduled + self._timer_cancelled_count = 0 + else: + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + self._timer_cancelled_count -= 1 + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + timeout = max(0, when - self.time()) + + if self._debug and timeout != 0: + t0 = self.time() + event_list = self._selector.select(timeout) + dt = self.time() - t0 + if dt >= 1.0: + level = logging.INFO + else: + level = logging.DEBUG + nevent = len(event_list) + if timeout is None: + logger.log(level, 'poll took %.3f ms: %s events', + dt * 1e3, nevent) + elif nevent: + logger.log(level, + 'poll %.3f ms took %.3f ms: %s events', + timeout * 1e3, dt * 1e3, nevent) + elif dt >= 1.0: + logger.log(level, + 'poll %.3f ms took %.3f ms: timeout', + timeout * 1e3, dt * 1e3) + else: + event_list = self._selector.select(timeout) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + end_time = self.time() + self._clock_resolution + while self._scheduled: + handle = self._scheduled[0] + if handle._when >= end_time: + break + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is thread-safe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if handle._cancelled: + continue + if self._debug: + t0 = self.time() + handle._run() + dt = self.time() - t0 + if dt >= self.slow_callback_duration: + logger.warning('Executing %s took %.3f seconds', + _format_handle(handle), dt) + else: + handle._run() + handle = None # Needed to break cycles when an exception occurs. + + def get_debug(self): + return self._debug + + def set_debug(self, enabled): + self._debug = enabled diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py new file mode 100644 index 00000000..d0087793 --- /dev/null +++ b/asyncio/base_subprocess.py @@ -0,0 +1,197 @@ +import collections +import subprocess + +from . import protocols +from . import transports +from .coroutines import coroutine +from .log import logger + + +class BaseSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + self._pid = None + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, + stderr=stderr, bufsize=bufsize, **kwargs) + self._pid = self._proc.pid + self._extra['subprocess'] = self._proc + if self._loop.get_debug(): + if isinstance(args, (bytes, str)): + program = args + else: + program = args[0] + logger.debug('process %r created: pid %s', + program, self._pid) + + def __repr__(self): + info = [self.__class__.__name__, 'pid=%s' % self._pid] + if self._returncode is not None: + info.append('returncode=%s' % self._returncode) + + stdin = self._pipes.get(0) + if stdin is not None: + info.append('stdin=%s' % stdin.pipe) + + stdout = self._pipes.get(1) + stderr = self._pipes.get(2) + if stdout is not None and stderr is stdout: + info.append('stdout=stderr=%s' % stdout.pipe) + else: + if stdout is not None: + info.append('stdout=%s' % stdout.pipe) + if stderr is not None: + info.append('stderr=%s' % stderr.pipe) + + return '<%s>' % ' '.join(info) + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + raise NotImplementedError + + def _make_write_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def _make_read_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def close(self): + for proto in self._pipes.values(): + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + @coroutine + def _post_init(self): + proc = self._proc + loop = self._loop + if proc.stdin is not None: + _, pipe = yield from loop.connect_write_pipe( + lambda: WriteSubprocessPipeProto(self, 0), + proc.stdin) + self._pipes[0] = pipe + if proc.stdout is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 1), + proc.stdout) + self._pipes[1] = pipe + if proc.stderr is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 2), + proc.stderr) + self._pipes[2] = pipe + + assert self._pending_calls is not None + + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + if self._loop.get_debug(): + logger.info('%r exited with return code %r', + self, returncode) + self._returncode = returncode + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._loop.call_soon(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None + + +class WriteSubprocessPipeProto(protocols.BaseProtocol): + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.pipe = None + self.disconnected = False + + def connection_made(self, transport): + self.pipe = transport + + def __repr__(self): + return ('<%s fd=%s pipe=%r>' + % (self.__class__.__name__, self.fd, self.pipe)) + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + + def pause_writing(self): + self.proc._protocol.pause_writing() + + def resume_writing(self): + self.proc._protocol.resume_writing() + + +class ReadSubprocessPipeProto(WriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) diff --git a/asyncio/constants.py b/asyncio/constants.py new file mode 100644 index 00000000..f9e12328 --- /dev/null +++ b/asyncio/constants.py @@ -0,0 +1,7 @@ +"""Constants.""" + +# After the connection is lost, log warnings after this many write()s. +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 + +# Seconds to wait before retrying accept(). +ACCEPT_RETRY_DELAY = 1 diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py new file mode 100644 index 00000000..c28de95a --- /dev/null +++ b/asyncio/coroutines.py @@ -0,0 +1,195 @@ +__all__ = ['coroutine', + 'iscoroutinefunction', 'iscoroutine'] + +import functools +import inspect +import opcode +import os +import sys +import traceback +import types + +from . import events +from . import futures +from .log import logger + + +# Opcode of "yield from" instruction +_YIELD_FROM = opcode.opmap['YIELD_FROM'] + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + + +# Check for CPython issue #21209 +def has_yield_from_bug(): + class MyGen: + def __init__(self): + self.send_args = None + def __iter__(self): + return self + def __next__(self): + return 42 + def send(self, *what): + self.send_args = what + return None + def yield_from_gen(gen): + yield from gen + value = (1, 2, 3) + gen = MyGen() + coro = yield_from_gen(gen) + next(coro) + coro.send(value) + return gen.send_args != (value,) +_YIELD_FROM_BUG = has_yield_from_bug() +del has_yield_from_bug + + +class CoroWrapper: + # Wrapper for coroutine object in _DEBUG mode. + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + # __name__, __qualname__, __doc__ attributes are set by the coroutine() + # decorator + + def __repr__(self): + coro_repr = _format_coroutine(self) + if self._source_traceback: + frame = self._source_traceback[-1] + coro_repr += ', created at %s:%s' % (frame[0], frame[1]) + return '<%s %s>' % (self.__class__.__name__, coro_repr) + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + if _YIELD_FROM_BUG: + # For for CPython issue #21209: using "yield from" and a custom + # generator, generator.send(tuple) unpacks the tuple instead of passing + # the tuple unchanged. Check if the caller is a generator using "yield + # from" to decide if the parameter should be unpacked or not. + def send(self, *value): + frame = sys._getframe() + caller = frame.f_back + assert caller.f_lasti >= 0 + if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: + value = value[0] + return self.gen.send(value) + else: + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + @property + def gi_frame(self): + return self.gen.gi_frame + + @property + def gi_running(self): + return self.gen.gi_running + + @property + def gi_code(self): + return self.gen.gi_code + + def __del__(self): + # Be careful accessing self.gen.frame -- self.gen might not exist. + gen = getattr(self, 'gen', None) + frame = getattr(gen, 'gi_frame', None) + if frame is not None and frame.f_lasti == -1: + msg = '%r was never yielded from' % self + tb = getattr(self, '_source_traceback', ()) + if tb: + tb = ''.join(traceback.format_list(tb)) + msg += ('\nCoroutine object created at ' + '(most recent call last):\n') + msg += tb.rstrip() + logger.error(msg) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + if w._source_traceback: + del w._source_traceback[-1] + w.__name__ = func.__name__ + if hasattr(func, '__qualname__'): + w.__qualname__ = func.__qualname__ + w.__doc__ = func.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return getattr(func, '_is_coroutine', False) + + +_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, _COROUTINE_TYPES) + + +def _format_coroutine(coro): + assert iscoroutine(coro) + coro_name = getattr(coro, '__qualname__', coro.__name__) + + filename = coro.gi_code.co_filename + if (isinstance(coro, CoroWrapper) + and not inspect.isgeneratorfunction(coro.func)): + filename, lineno = events._get_function_source(coro.func) + if coro.gi_frame is None: + coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + else: + coro_repr = '%s() running, defined at %s:%s' % (coro_name, filename, lineno) + elif coro.gi_frame is not None: + lineno = coro.gi_frame.f_lineno + coro_repr = '%s() running at %s:%s' % (coro_name, filename, lineno) + else: + lineno = coro.gi_code.co_firstlineno + coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + + return coro_repr diff --git a/asyncio/events.py b/asyncio/events.py new file mode 100644 index 00000000..806218f6 --- /dev/null +++ b/asyncio/events.py @@ -0,0 +1,597 @@ +"""Event loop and event loop policy.""" + +__all__ = ['AbstractEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + 'get_child_watcher', 'set_child_watcher', + ] + +import functools +import inspect +import reprlib +import socket +import subprocess +import sys +import threading +import traceback + + +_PY34 = sys.version_info >= (3, 4) + + +def _get_function_source(func): + if _PY34: + func = inspect.unwrap(func) + elif hasattr(func, '__wrapped__'): + func = func.__wrapped__ + if inspect.isfunction(func): + code = func.__code__ + return (code.co_filename, code.co_firstlineno) + if isinstance(func, functools.partial): + return _get_function_source(func.func) + if _PY34 and isinstance(func, functools.partialmethod): + return _get_function_source(func.func) + return None + + +def _format_args(args): + """Format function arguments. + + Special case for a single parameter: ('hello',) is formatted as ('hello'). + """ + # use reprlib to limit the length of the output + args_repr = reprlib.repr(args) + if len(args) == 1 and args_repr.endswith(',)'): + args_repr = args_repr[:-2] + ')' + return args_repr + + +def _format_callback(func, args, suffix=''): + if isinstance(func, functools.partial): + if args is not None: + suffix = _format_args(args) + suffix + return _format_callback(func.func, func.args, suffix) + + func_repr = getattr(func, '__qualname__', None) + if not func_repr: + func_repr = repr(func) + + if args is not None: + func_repr += _format_args(args) + if suffix: + func_repr += suffix + + source = _get_function_source(func) + if source: + func_repr += ' at %s:%s' % source + return func_repr + + +class Handle: + """Object returned by callback registration methods.""" + + __slots__ = ('_callback', '_args', '_cancelled', '_loop', + '_source_traceback', '_repr', '__weakref__') + + def __init__(self, callback, args, loop): + assert not isinstance(callback, Handle), 'A Handle is not a callback' + self._loop = loop + self._callback = callback + self._args = args + self._cancelled = False + self._repr = None + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + else: + self._source_traceback = None + + def _repr_info(self): + info = [self.__class__.__name__] + if self._cancelled: + info.append('cancelled') + if self._callback is not None: + info.append(_format_callback(self._callback, self._args)) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + if self._repr is not None: + return self._repr + info = self._repr_info() + return '<%s>' % ' '.join(info) + + def cancel(self): + if not self._cancelled: + self._cancelled = True + if self._loop.get_debug(): + # Keep a representation in debug mode to keep callback and + # parameters. For example, to log the warning + # "Executing took 2.5 second" + self._repr = repr(self) + self._callback = None + self._args = None + + def _run(self): + try: + self._callback(*self._args) + except Exception as exc: + cb = _format_callback(self._callback, self._args) + msg = 'Exception in callback {}'.format(cb) + context = { + 'message': msg, + 'exception': exc, + 'handle': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self = None # Needed to break cycles when an exception occurs. + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + __slots__ = ['_scheduled', '_when'] + + def __init__(self, when, callback, args, loop): + assert when is not None + super().__init__(callback, args, loop) + if self._source_traceback: + del self._source_traceback[-1] + self._when = when + self._scheduled = False + + def _repr_info(self): + info = super()._repr_info() + pos = 2 if self._cancelled else 1 + info.insert(pos, 'when=%s' % self._when) + return info + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + def cancel(self): + if not self._cancelled: + self._loop._timer_handle_cancelled(self) + super().cancel() + + +class AbstractServer: + """Abstract server returned by create_server().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError + + def close(self): + """Close the loop. + + The loop should not be running. + + This is idempotent and irreversible. + + No other methods should be called after this one. + """ + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Method scheduling a coroutine object: create a task. + + def create_task(self, coro): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return value is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Pipes and subprocesses. + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in event loop. Set the pipe to non-blocking mode. + + protocol_factory should instantiate object with Protocol interface. + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the + ReadTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in event loop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + # Error handlers. + + def set_exception_handler(self, handler): + raise NotImplementedError + + def default_exception_handler(self, context): + raise NotImplementedError + + def call_exception_handler(self, context): + raise NotImplementedError + + # Debug flag management. + + def get_debug(self): + raise NotImplementedError + + def set_debug(self, enabled): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """Get the event loop for the current context. + + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. + + It should never return None.""" + raise NotImplementedError + + def set_event_loop(self, loop): + """Set the event loop for the current context to loop.""" + raise NotImplementedError + + def new_event_loop(self): + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" + raise NotImplementedError + + # Child processes handling (Unix only). + + def get_child_watcher(self): + "Get the watcher for child processes." + raise NotImplementedError + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + raise NotImplementedError + + +class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop_factory = None + + class _Local(threading.local): + _loop = None + _set_called = False + + def __init__(self): + self._local = self._Local() + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._local._loop is None and + not self._local._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self.set_event_loop(self.new_event_loop()) + assert self._local._loop is not None, \ + ('There is no current event loop in thread %r.' % + threading.current_thread().name) + return self._local._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + self._local._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._local._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + return self._loop_factory() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + +# Lock for protecting the on-the-fly creation of the event loop policy. +_lock = threading.Lock() + + +def _init_event_loop_policy(): + global _event_loop_policy + with _lock: + if _event_loop_policy is None: # pragma: no branch + from . import DefaultEventLoopPolicy + _event_loop_policy = DefaultEventLoopPolicy() + + +def get_event_loop_policy(): + """Get the current event loop policy.""" + if _event_loop_policy is None: + _init_event_loop_policy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """Set the current event loop policy. + + If policy is None, the default policy is restored.""" + global _event_loop_policy + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """Equivalent to calling get_event_loop_policy().get_event_loop().""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """Equivalent to calling get_event_loop_policy().set_event_loop(loop).""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """Equivalent to calling get_event_loop_policy().new_event_loop().""" + return get_event_loop_policy().new_event_loop() + + +def get_child_watcher(): + """Equivalent to calling get_event_loop_policy().get_child_watcher().""" + return get_event_loop_policy().get_child_watcher() + + +def set_child_watcher(watcher): + """Equivalent to calling + get_event_loop_policy().set_child_watcher(watcher).""" + return get_event_loop_policy().set_child_watcher(watcher) diff --git a/asyncio/futures.py b/asyncio/futures.py new file mode 100644 index 00000000..f46d008f --- /dev/null +++ b/asyncio/futures.py @@ -0,0 +1,411 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import reprlib +import sys +import traceback + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +_PY34 = sys.version_info >= (3, 4) + +# TODO: Do we really want to depend on concurrent.futures internals? +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + # TODO: Show the future, its state, the method, and the required state. + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ('loop', 'source_traceback', 'exc', 'tb') + + def __init__(self, future, exc): + self.loop = future._loop + self.source_traceback = future._source_traceback + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + msg = 'Future/Task exception was never retrieved\n' + if self.source_traceback: + src = ''.join(traceback.format_list(self.source_traceback)) + msg += 'Future/Task created at (most recent call last):\n' + msg += '%s\n' % src.rstrip() + msg += ''.join(self.tb).rstrip() + self.loop.call_exception_handler({'message': msg}) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + _source_traceback = None + + _blocking = False # proper use of future (yield vs yield from) + + _log_traceback = False # Used for Python 3.4 and later + _tb_logger = None # Used for Python 3.3 only + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def _format_callbacks(self): + cb = self._callbacks + size = len(cb) + if not size: + cb = '' + + def format_cb(callback): + return events._format_callback(callback, ()) + + if size == 1: + cb = format_cb(cb[0]) + elif size == 2: + cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + elif size > 2: + cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) + return 'cb=[%s]' % cb + + def _repr_info(self): + info = [self._state.lower()] + if self._state == _FINISHED: + if self._exception is not None: + info.append('exception={!r}'.format(self._exception)) + else: + # use reprlib to limit the length of the output, especially + # for very long strings + result = reprlib.repr(self._result) + info.append('result={}'.format(result)) + if self._callbacks: + info.append(self._format_callbacks()) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + info = self._repr_info() + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks to + # the PEP 442. + if _PY34: + def __del__(self): + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + context = { + 'message': ('%s exception was never retrieved' + % self.__class__.__name__), + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def _set_result_unless_cancelled(self, result): + """Helper setting the result only if the future was not cancelled.""" + if self.cancelled(): + return + self.set_result(result) + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + if isinstance(exception, type): + exception = exception() + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + if _PY34: + self._log_traceback = True + else: + self._tb_logger = _TracebackLogger(self, exception) + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + if self.cancelled(): + return + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + if loop is None: + loop = events.get_event_loop() + new_future = Future(loop=loop) + + def _check_cancel_other(f): + if f.cancelled(): + fut.cancel() + + new_future.add_done_callback(_check_cancel_other) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, fut)) + return new_future diff --git a/asyncio/locks.py b/asyncio/locks.py new file mode 100644 index 00000000..b943e9dd --- /dev/null +++ b/asyncio/locks.py @@ -0,0 +1,469 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] + +import collections + +from . import events +from . import futures +from .coroutines import coroutine + + +class _ContextManager: + """Context manager. + + This enables the following idiom for acquiring and releasing a + lock around a block: + + with (yield from lock): + + + while failing loudly when accidentally using: + + with lock: + + """ + + def __init__(self, lock): + self._lock = lock + + def __enter__(self): + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + def __exit__(self, *args): + try: + self._lock.release() + finally: + self._lock = None # Crudely prevent reuse. + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context management protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return True if lock is acquired.""" + return self._locked + + @coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass + + def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # + # finally: + # lock.release() + yield from self.acquire() + return _ContextManager(self) + + +class Event: + """Asynchronous equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'set' if self._value else 'unset' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def is_set(self): + """Return True if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +class Condition: + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock=None, *, loop=None): + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + if lock is None: + lock = Lock(loop=self._loop) + elif lock._loop is not self._loop: + raise ValueError("loop argument must agree with lock") + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + @coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self.locked(): + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + try: + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + finally: + yield from self.acquire() + + @coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + pass + + def __iter__(self): + # See comment in Lock.__iter__(). + yield from self.acquire() + return _ContextManager(self) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context management protocol. + + The optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + """ + + def __init__(self, value=1, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be >= 0") + self._value = value + self._waiters = collections.deque() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( + self._value) + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._value == 0 + + @coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + """ + self._value += 1 + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + pass + + def __iter__(self): + # See comment in Lock.__iter__(). + yield from self.acquire() + return _ContextManager(self) + + +class BoundedSemaphore(Semaphore): + """A bounded semaphore implementation. + + This raises ValueError in release() if it would increase the value + above the initial value. + """ + + def __init__(self, value=1, *, loop=None): + self._bound_value = value + super().__init__(value, loop=loop) + + def release(self): + if self._value >= self._bound_value: + raise ValueError('BoundedSemaphore released too many times') + super().release() diff --git a/asyncio/log.py b/asyncio/log.py new file mode 100644 index 00000000..23a7074a --- /dev/null +++ b/asyncio/log.py @@ -0,0 +1,7 @@ +"""Logging configuration.""" + +import logging + + +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py new file mode 100644 index 00000000..e67cf65a --- /dev/null +++ b/asyncio/proactor_events.py @@ -0,0 +1,502 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +__all__ = ['BaseProactorEventLoop'] + +import socket + +from . import base_events +from . import constants +from . import futures +from . import transports +from .log import logger + + +class _ProactorBasePipeTransport(transports._FlowControlMixin, + transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra, loop) + self._set_extra(sock) + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = None # None or bytearray. + self._read_fut = None + self._write_fut = None + self._pending_write = 0 + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server._attach() + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + fd = self._sock.fileno() + if fd < 0: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % fd) + if self._read_fut is not None: + info.append('read=%s' % self._read_fut) + if self._write_fut is not None: + info.append("write=%r" % self._write_fut) + if self._buffer: + bufsize = len(self._buffer) + info.append('write_bufsize=%s' % bufsize) + if self._eof_written: + info.append('EOF written') + return '<%s>' % ' '.join(info) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + if self._read_fut: + self._read_fut.cancel() + self._write_fut = self._read_fut = None + self._pending_write = 0 + self._buffer = None + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + size = self._pending_write + if self._buffer is not None: + size += len(self._buffer) + return size + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc, 'Fatal read error on pipe transport') + elif self._loop.get_debug(): + logger.debug("Read error on pipe transport while closing", + exc_info=True) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + +class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof_written: + raise RuntimeError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + # Observable states: + # 1. IDLE: _write_fut and _buffer both None + # 2. WRITING: _write_fut set; _buffer None + # 3. BACKED UP: _write_fut set; _buffer a bytearray + # We always copy the data, so the caller can't modify it + # while we're still waiting for the I/O to happen. + if self._write_fut is None: # IDLE -> WRITING + assert self._buffer is None + # Pass a copy, except if it's already immutable. + self._loop_writing(data=bytes(data)) + elif not self._buffer: # WRITING -> BACKED UP + # Make a mutable copy which we can extend. + self._buffer = bytearray(data) + self._maybe_pause_protocol() + else: # BACKED UP + # Append to buffer (also copies). + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _loop_writing(self, f=None, data=None): + try: + assert f is self._write_fut + self._write_fut = None + self._pending_write = 0 + if f: + f.result() + if data is None: + data = self._buffer + self._buffer = None + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() + else: + self._write_fut = self._loop._proactor.send(self._sock, data) + if not self._write_fut.done(): + assert self._pending_write == 0 + self._pending_write = len(data) + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_pause_protocol() + else: + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal write error on pipe transport') + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._read_fut = self._loop._proactor.recv(self._sock, 16) + self._read_fut.add_done_callback(self._pipe_closed) + + def _pipe_closed(self, fut): + if fut.cancelled(): + # the transport has been closed + return + assert fut.result() == b'' + if self._closing: + assert self._read_fut is None + return + assert fut is self._read_fut, (fut, self._read_fut) + self._read_fut = None + if self._write_fut is not None: + self._force_close(BrokenPipeError()) + else: + self.close() + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + try: + self._extra['sockname'] = sock.getsockname() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getsockname() failed on %r", + sock, exc_info=True) + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getpeername() failed on %r", + sock, exc_info=True) + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logger.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._self_reading_future = None + self._accept_futures = {} # socket file descriptor => Future + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + # We want connection_lost() to be called when other end closes + return _ProactorWritePipeTransport(self, + sock, protocol, waiter, extra) + + def close(self): + if self._running: + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + self._stop_accept_futures() + self._close_self_pipe() + super().close() + self._proactor.close() + self._proactor = None + self._selector = None + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut = futures.Future(loop=self) + fut.set_exception(err) + return fut + else: + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + if self._self_reading_future is not None: + self._self_reading_future.cancel() + self._self_reading_future = None + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + # don't check the current loop because _make_self_pipe() is called + # from the event loop constructor + self._call_soon(self._loop_self_reading, (), check_loop=False) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except: + self.close() + raise + else: + self._self_reading_future = f + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'\0') + + def _start_serving(self, protocol_factory, sock, ssl=None, server=None): + if ssl: + raise ValueError('IocpEventLoop is incompatible with SSL.') + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + protocol = protocol_factory() + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + if self.is_closed(): + return + f = self._proactor.accept(sock) + except OSError as exc: + if sock.fileno() != -1: + self.call_exception_handler({ + 'message': 'Accept failed on a socket', + 'exception': exc, + 'socket': sock, + }) + sock.close() + elif self._debug: + logger.debug("Accept failed on socket %r", + sock, exc_info=True) + except futures.CancelledError: + sock.close() + else: + self._accept_futures[sock.fileno()] = f + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + pass # XXX hard work currently done in poll + + def _stop_accept_futures(self): + for future in self._accept_futures.values(): + future.cancel() + self._accept_futures.clear() + + def _stop_serving(self, sock): + self._stop_accept_futures() + self._proactor._stop_serving(sock) + sock.close() diff --git a/asyncio/protocols.py b/asyncio/protocols.py new file mode 100644 index 00000000..52fc25c2 --- /dev/null +++ b/asyncio/protocols.py @@ -0,0 +1,129 @@ +"""Abstract Protocol class.""" + +__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol'] + + +class BaseProtocol: + """Common base class for protocol interfaces. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + + NOTE: This is the only Protocol callback that is not called + through EventLoop.call_soon() -- if it were, it would have no + effect when it's most needed (when the app keeps writing + without yielding until pause_writing() is called). + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + + +class Protocol(BaseProtocol): + """Interface for stream protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + +class DatagramProtocol(BaseProtocol): + """Interface for datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def error_received(self, exc): + """Called when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + + +class SubprocessProtocol(BaseProtocol): + """Interface for protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when the subprocess writes data into stdout/stderr pipe. + + fd is int file descriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited.""" diff --git a/asyncio/queues.py b/asyncio/queues.py new file mode 100644 index 00000000..41551a90 --- /dev/null +++ b/asyncio/queues.py @@ -0,0 +1,288 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'QueueFull', 'QueueEmpty'] + +import collections +import heapq + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class QueueEmpty(Exception): + 'Exception raised by Queue.get(block=0)/get_nowait().' + pass + + +class QueueFull(Exception): + 'Exception raised by Queue.put(block=0)/put_nowait().' + pass + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded asyncio application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + If you yield from put(), wait until a free slot is available + before adding item. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + raise QueueFull + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If you yield from get(), wait until a item is available. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter._set_result_unless_cancelled, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py new file mode 100644 index 00000000..7df8b866 --- /dev/null +++ b/asyncio/selector_events.py @@ -0,0 +1,1003 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +__all__ = ['BaseSelectorEventLoop'] + +import collections +import errno +import functools +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from .log import logger + + +def _test_selector_event(selector, fd, event): + # Test if the selector is monitoring 'event' events + # for the file descriptor 'fd'. + try: + key = selector.get_key(fd) + except KeyError: + return False + else: + return bool(key.events & event) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + logger.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, + address, waiter, extra) + + def close(self): + if self._running: + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + self._close_self_pipe() + super().close() + if self._selector is not None: + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _process_self_data(self, data): + pass + + def _read_from_self(self): + while True: + try: + data = self._ssock.recv(4096) + if not data: + break + self._process_self_data(data) + except InterruptedError: + continue + except BlockingIOError: + break + + def _write_to_self(self): + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is not None: + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) + + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, sslcontext, server) + + def _accept_connection(self, protocol_factory, sock, + sslcontext=None, server=None): + try: + conn, addr = sock.accept() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + conn.setblocking(False) + except (BlockingIOError, InterruptedError, ConnectionAbortedError): + pass # False alarm. + except OSError as exc: + # There's nowhere to send the error, so just log it. + # TODO: Someone will want an error handler for this. + if exc.errno in (errno.EMFILE, errno.ENFILE, + errno.ENOBUFS, errno.ENOMEM): + # Some platforms (e.g. Linux keep reporting the FD as + # ready, so we remove the read handler temporarily. + # We'll try again in a while. + self.call_exception_handler({ + 'message': 'socket.accept() out of system resource', + 'exception': exc, + 'socket': sock, + }) + self.remove_reader(sock.fileno()) + self.call_later(constants.ACCEPT_RETRY_DELAY, + self._start_serving, + protocol_factory, sock, sslcontext, server) + else: + raise # The event loop will catch, log and ignore it. + else: + if sslcontext: + self._make_ssl_transport( + conn, protocol_factory(), sslcontext, None, + server_side=True, extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol_factory(), extra={'peername': addr}, + server=server) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """Receive data from the socket. + + The return value is a bytes object representing the data received. + The maximum amount of data to be received at once is specified by + nbytes. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + # _sock_recv() can add itself as an I/O callback if the operation can't + # be done immediately. Don't use it directly, call sock_recv(). + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """Connect to a remote socket at address. + + The address must be already resolved to avoid the trap of hanging the + entire event loop when the address requires doing a DNS lookup. For + example, it must be an IP address, not an hostname, for AF_INET and + AF_INET6 address families. Use getaddrinfo() to resolve the hostname + asynchronously. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut.set_exception(err) + else: + self._sock_connect(fut, sock, address) + return fut + + def _sock_connect(self, fut, sock, address): + fd = sock.fileno() + try: + while True: + try: + sock.connect(address) + except InterruptedError: + continue + else: + break + except BlockingIOError: + fut.add_done_callback(functools.partial(self._sock_connect_done, + sock)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def _sock_connect_done(self, sock, fut): + self.remove_writer(sock.fileno()) + + def _sock_connect_cb(self, fut, sock, address): + if fut.cancelled(): + return + + try: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to any except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) + except (BlockingIOError, InterruptedError): + # socket is still registered, the callback will be retried later + pass + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """Accept a connection. + + The socket must be bound to an address and listening for connections. + The return value is a pair (conn, address) where conn is a new socket + object usable to send and receive data on the connection, and address + is the address bound to the socket on the other end of the connection. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports._FlowControlMixin, + transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + _buffer_factory = bytearray # Constructs initial value for self._buffer. + + def __init__(self, loop, sock, protocol, extra, server=None): + super().__init__(extra, loop) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._server = server + self._buffer = self._buffer_factory() + self._conn_lost = 0 # Set when call to connection_lost scheduled. + self._closing = False # Set when close() called. + if self._server is not None: + self._server._attach() + + def __repr__(self): + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._sock_fd) + # test if the transport was closed + if self._loop is not None: + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_READ) + if polling: + info.append('read=polling') + else: + info.append('read=idle') + + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_WRITE) + if polling: + state = 'polling' + else: + state = 'idle' + + bufsize = self.get_write_buffer_size() + info.append('write=<%s, bufsize=%s>' % (state, bufsize)) + return '<%s>' % ' '.join(info) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._conn_lost: + return + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + if not self._closing: + self._closing = True + self._loop.remove_reader(self._sock_fd) + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + return len(self._buffer) + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on socket transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. + self._loop.remove_reader(self._sock_fd) + else: + self.close() + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof: + raise RuntimeError('Cannot call write() after write_eof()') + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Optimization: try to send now. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on socket transport') + return + else: + data = data[n:] + if not data: + return + # Not all was written; register write handler. + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _write_ready(self): + assert self._buffer, 'Data should not be empty' + + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on socket transport') + else: + if n: + del self._buffer[:n] + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + _buffer_factory = bytearray + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if server_side: + if not sslcontext: + raise ValueError('Server side ssl needs a valid SSLContext') + else: + if not sslcontext: + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname and not server_side: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + + self._server_hostname = server_hostname + self._waiter = waiter + self._sslcontext = sslcontext + self._paused = False + + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + start_time = self._loop.time() + else: + start_time = None + self._on_handshake(start_time) + + def _on_handshake(self, start_time): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, + self._on_handshake, start_time) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, + self._on_handshake, start_time) + return + except BaseException as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + if isinstance(exc, Exception): + return + else: + raise + + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + + peercert = self._sock.getpeercert() + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed " + "on matching the hostname", + self, exc_info=True) + self._sock.close() + if self._waiter is not None: + self._waiter.set_exception(exc) + return + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + + self._read_wants_write = False + self._write_wants_read = False + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(self._waiter._set_result_unless_cancelled, + None) + + if self._loop.get_debug(): + dt = self._loop.time() - start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + def pause_reading(self): + # XXX This is a bit icky, given the comment at the top of + # _read_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume_reading() again, and things will flow again. + + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + if self._write_wants_read: + self._write_wants_read = False + self._write_ready() + + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + pass + except ssl.SSLWantWriteError: + self._read_wants_write = True + self._loop.remove_reader(self._sock_fd) + self._loop.add_writer(self._sock_fd, self._write_ready) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on SSL transport') + else: + if data: + self._protocol.data_received(data) + else: + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self.close() + + def _write_ready(self): + if self._read_wants_write: + self._read_wants_write = False + self._read_ready() + + if not (self._paused or self._closing): + self._loop.add_reader(self._sock_fd, self._read_ready) + + if self._buffer: + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): + n = 0 + except ssl.SSLWantReadError: + n = 0 + self._loop.remove_writer(self._sock_fd) + self._write_wants_read = True + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on SSL transport') + return + + if n: + del self._buffer[:n] + + self._maybe_resume_protocol() # May append to buffer. + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def can_write_eof(self): + return False + + +class _SelectorDatagramTransport(_SelectorTransport): + + _buffer_factory = collections.deque + + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + self._address = address + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on datagram transport') + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._address and addr not in (None, self._address): + raise ValueError('Invalid address: must be None or %s' % + (self._address,)) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) + self._maybe_pause_protocol() + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) diff --git a/asyncio/selectors.py b/asyncio/selectors.py new file mode 100644 index 00000000..faa2d3da --- /dev/null +++ b/asyncio/selectors.py @@ -0,0 +1,590 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import math +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(metaclass=ABCMeta): + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super().close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._kqueue.close() + super().close() + + +# Choose the best implementation: roughly, epoll|kqueue|devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/asyncio/streams.py b/asyncio/streams.py new file mode 100644 index 00000000..c77eb606 --- /dev/null +++ b/asyncio/streams.py @@ -0,0 +1,485 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'open_connection', 'start_server', + 'IncompleteReadError', + ] + +import socket + +if hasattr(socket, 'AF_UNIX'): + __all__.extend(['open_unix_connection', 'start_unix_server']) + +from . import coroutines +from . import events +from . import futures +from . import protocols +from .coroutines import coroutine +from .log import logger + + +_DEFAULT_LIMIT = 2**16 + + +class IncompleteReadError(EOFError): + """ + Incomplete read error. Attributes: + + - partial: read bytes string before the end of stream was reached + - expected: total number of expected bytes + """ + def __init__(self, partial, expected): + EOFError.__init__(self, "%s bytes read on a total of %s expected bytes" + % (len(partial), expected)) + self.partial = partial + self.expected = expected + + +@coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + StreamWriter instance. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +@coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + +class FlowControlMixin(protocols.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + + This implements the protocol methods pause_writing(), + resume_reading() and connection_lost(). If the subclass overrides + these it must call the super methods. + + StreamWriter.drain() must wait for _drain_helper() coroutine. + """ + + def __init__(self, loop=None): + self._loop = loop # May be None; we may never need it. + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + @coroutine + def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError('Connection lost') + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._drain_waiter = waiter + yield from waiter + + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if coroutines.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + super().connection_lost(exc) + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + # drain() expects that the reader has a exception() method + assert reader is None or isinstance(reader, StreamReader) + self._reader = reader + self._loop = loop + + def __repr__(self): + info = [self.__class__.__name__, 'transport=%r' % self._transport] + if self._reader is not None: + info.append('reader=%r' % self._reader) + return '<%s>' % ' '.join(info) + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + @coroutine + def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + yield from w.drain() + """ + if self._reader is not None: + exc = self._reader.exception() + if exc is not None: + raise exc + yield from self._protocol._drain_helper() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self._limit = limit + if loop is None: + loop = events.get_event_loop() + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future. + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(True) + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(False) + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2*self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + def _create_waiter(self, func_name): + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError('%s() called while another coroutine is ' + 'already waiting for incoming data' % func_name) + return futures.Future(loop=self._loop) + + @coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + line = bytearray() + not_enough = True + + while not_enough: + while self._buffer and not_enough: + ichar = self._buffer.find(b'\n') + if ichar < 0: + line.extend(self._buffer) + self._buffer.clear() + else: + ichar += 1 + line.extend(self._buffer[:ichar]) + del self._buffer[:ichar] + not_enough = False + + if len(line) > self._limit: + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self._eof: + break + + if not_enough: + self._waiter = self._create_waiter('readline') + try: + yield from self._waiter + finally: + self._waiter = None + + self._maybe_resume_transport() + return bytes(line) + + @coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = yield from self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + else: + if not self._buffer and not self._eof: + self._waiter = self._create_waiter('read') + try: + yield from self._waiter + finally: + self._waiter = None + + if n < 0 or len(self._buffer) <= n: + data = bytes(self._buffer) + self._buffer.clear() + else: + # n > 0 and len(self._buffer) > n + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + @coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + # There used to be "optimized" code here. It created its own + # Future and waited until self._buffer had at least the n + # bytes, then called read(n). Unfortunately, this could pause + # the transport if the argument was larger than the pause + # limit (which is twice self._limit). So now we just read() + # into a local buffer. + + blocks = [] + while n > 0: + block = yield from self.read(n) + if not block: + partial = b''.join(blocks) + raise IncompleteReadError(partial, len(partial) + n) + blocks.append(block) + n -= len(block) + + return b''.join(blocks) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py new file mode 100644 index 00000000..f6d6a141 --- /dev/null +++ b/asyncio/subprocess.py @@ -0,0 +1,235 @@ +__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] + +import collections +import subprocess + +from . import events +from . import futures +from . import protocols +from . import streams +from . import tasks +from .coroutines import coroutine +from .log import logger + + +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +DEVNULL = subprocess.DEVNULL + + +class SubprocessStreamProtocol(streams.FlowControlMixin, + protocols.SubprocessProtocol): + """Like StreamReaderProtocol, but for a subprocess.""" + + def __init__(self, limit, loop): + super().__init__(loop=loop) + self._limit = limit + self.stdin = self.stdout = self.stderr = None + self.waiter = futures.Future(loop=loop) + self._waiters = collections.deque() + self._transport = None + + def __repr__(self): + info = [self.__class__.__name__] + if self.stdin is not None: + info.append('stdin=%r' % self.stdin) + if self.stdout is not None: + info.append('stdout=%r' % self.stdout) + if self.stderr is not None: + info.append('stderr=%r' % self.stderr) + return '<%s>' % ' '.join(info) + + def connection_made(self, transport): + self._transport = transport + + stdout_transport = transport.get_pipe_transport(1) + if stdout_transport is not None: + self.stdout = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stdout.set_transport(stdout_transport) + + stderr_transport = transport.get_pipe_transport(2) + if stderr_transport is not None: + self.stderr = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stderr.set_transport(stderr_transport) + + stdin_transport = transport.get_pipe_transport(0) + if stdin_transport is not None: + self.stdin = streams.StreamWriter(stdin_transport, + protocol=self, + reader=None, + loop=self._loop) + self.waiter.set_result(None) + + def pipe_data_received(self, fd, data): + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader is not None: + reader.feed_data(data) + + def pipe_connection_lost(self, fd, exc): + if fd == 0: + pipe = self.stdin + if pipe is not None: + pipe.close() + self.connection_lost(exc) + return + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader != None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + + def process_exited(self): + # wake up futures waiting for wait() + returncode = self._transport.get_returncode() + while self._waiters: + waiter = self._waiters.popleft() + waiter.set_result(returncode) + + +class Process: + def __init__(self, transport, protocol, loop): + self._transport = transport + self._protocol = protocol + self._loop = loop + self.stdin = protocol.stdin + self.stdout = protocol.stdout + self.stderr = protocol.stderr + self.pid = transport.get_pid() + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.pid) + + @property + def returncode(self): + return self._transport.get_returncode() + + @coroutine + def wait(self): + """Wait until the process exit and return the process return code.""" + returncode = self._transport.get_returncode() + if returncode is not None: + return returncode + + waiter = futures.Future(loop=self._loop) + self._protocol._waiters.append(waiter) + yield from waiter + return waiter.result() + + def _check_alive(self): + if self._transport.get_returncode() is not None: + raise ProcessLookupError() + + def send_signal(self, signal): + self._check_alive() + self._transport.send_signal(signal) + + def terminate(self): + self._check_alive() + self._transport.terminate() + + def kill(self): + self._check_alive() + self._transport.kill() + + @coroutine + def _feed_stdin(self, input): + debug = self._loop.get_debug() + self.stdin.write(input) + if debug: + logger.debug('%r communicate: feed stdin (%s bytes)', + self, len(input)) + try: + yield from self.stdin.drain() + except (BrokenPipeError, ConnectionResetError) as exc: + # communicate() ignores BrokenPipeError and ConnectionResetError + if debug: + logger.debug('%r communicate: stdin got %r', self, exc) + + if debug: + logger.debug('%r communicate: close stdin', self) + self.stdin.close() + + @coroutine + def _noop(self): + return None + + @coroutine + def _read_stream(self, fd): + transport = self._transport.get_pipe_transport(fd) + if fd == 2: + stream = self.stderr + else: + assert fd == 1 + stream = self.stdout + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: read %s', self, name) + output = yield from stream.read() + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: close %s', self, name) + transport.close() + return output + + @coroutine + def communicate(self, input=None): + if input: + stdin = self._feed_stdin(input) + else: + stdin = self._noop() + if self.stdout is not None: + stdout = self._read_stream(1) + else: + stdout = self._noop() + if self.stderr is not None: + stderr = self._read_stream(2) + else: + stderr = self._noop() + stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, + loop=self._loop) + yield from self.wait() + return (stdout, stderr) + + +@coroutine +def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, + loop=None, limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_shell( + protocol_factory, + cmd, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + yield from protocol.waiter + return Process(transport, protocol, loop) + +@coroutine +def create_subprocess_exec(program, *args, stdin=None, stdout=None, + stderr=None, loop=None, + limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_exec( + protocol_factory, + program, *args, + stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + yield from protocol.waiter + return Process(transport, protocol, loop) diff --git a/asyncio/tasks.py b/asyncio/tasks.py new file mode 100644 index 00000000..9aebffda --- /dev/null +++ b/asyncio/tasks.py @@ -0,0 +1,660 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', 'shield', + ] + +import concurrent.futures +import functools +import inspect +import linecache +import sys +import traceback +import weakref + +from . import coroutines +from . import events +from . import futures +from .coroutines import coroutine + +_PY34 = (sys.version_info >= (3, 4)) + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + # Dictionary containing tasks that are currently active in + # all running event loops. {EventLoop: Task} + _current_tasks = {} + + # If False, don't log a message if the task is destroyed whereas its + # status is still pending + _log_destroy_pending = True + + @classmethod + def current_task(cls, loop=None): + """Return the currently running task in an event loop or None. + + By default the current task for the current event loop is returned. + + None is returned when called not in the context of a Task. + """ + if loop is None: + loop = events.get_event_loop() + return cls._current_tasks.get(loop) + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert coroutines.iscoroutine(coro), repr(coro) # Not a coroutine function! + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._coro = iter(coro) # Use the iterator just in case. + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + # On Python 3.3 or older, objects with a destructor that are part of a + # reference cycle are never destroyed. That's not the case any more on + # Python 3.4 thanks to the PEP 442. + if _PY34: + def __del__(self): + if self._state == futures._PENDING and self._log_destroy_pending: + context = { + 'task': self, + 'message': 'Task was destroyed but it is pending!', + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + futures.Future.__del__(self) + + def _repr_info(self): + info = super()._repr_info() + + if self._must_cancel: + # replace status + info[0] = 'cancelling' + + coro = coroutines._format_coroutine(self._coro) + info.insert(1, 'coro=<%s>' % coro) + + if self._fut_waiter is not None: + info.insert(2, 'wait_for=%r' % self._fut_waiter) + return info + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is not done, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum number of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output is written; by default output is written + to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + """Request that this task cancel itself. + + This arranges for a CancelledError to be thrown into the + wrapped coroutine on the next cycle through the event loop. + The coroutine then has a chance to clean up or even deny + the request using try/except/finally. + + Unlike Future.cancel, this does not guarantee that the + task will be cancelled: the exception might be caught and + acted upon, delaying cancellation of the task or preventing + cancellation completely. The task may also return a value or + raise a different exception. + + Immediately after this method is called, Task.cancelled() will + not return True (unless the task was already cancelled). A + task will be marked as cancelled when the wrapped coroutine + terminates with a CancelledError exception (even if cancel() + was not called). + """ + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + + self.__class__._current_tasks[self._loop] = self + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + finally: + self.__class__._current_tasks.pop(self._loop) + self = None # Needed to break cycles when an exception occurs. + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + The sequence futures must not be empty. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from asyncio.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + + if loop is None: + loop = events.get_event_loop() + + fs = {async(f, loop=loop) for f in set(fs)} + + return (yield from _wait(fs, timeout, return_when, loop)) + + +def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + Usage: + + result = yield from asyncio.wait_for(fut, 10.0) + + """ + if loop is None: + loop = events.get_event_loop() + + if timeout is None: + return (yield from fut) + + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + cb = functools.partial(_release_waiter, waiter) + + fut = async(fut, loop=loop) + fut.add_done_callback(cb) + + try: + # wait until the future completes or the timeout + yield from waiter + + if fut.done(): + return fut.result() + else: + fut.remove_done_callback(cb) + fut.cancel() + raise futures.TimeoutError() + finally: + timeout_handle.cancel() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(None) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values are coroutines. + + When waiting for the yielded coroutines you'll get the results (or + exceptions!) of the original Futures (or coroutines), in the order + in which and as soon as they complete. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + If a timeout is specified, the 'yield from' will raise + TimeoutError when the timeout occurs before all Futures are done. + + Note: The futures 'f' are not necessarily members of fs. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + loop = loop if loop is not None else events.get_event_loop() + todo = {async(f, loop=loop) for f in set(fs)} + from .queues import Queue # Import here to avoid circular import problem. + done = Queue(loop=loop) + timeout_handle = None + + def _on_timeout(): + for f in todo: + f.remove_done_callback(_on_completion) + done.put_nowait(None) # Queue a dummy value for _wait_for_one(). + todo.clear() # Can't do todo.remove(f) in the loop. + + def _on_completion(f): + if not todo: + return # _on_timeout() was here first. + todo.remove(f) + done.put_nowait(f) + if not todo and timeout_handle is not None: + timeout_handle.cancel() + + @coroutine + def _wait_for_one(): + f = yield from done.get() + if f is None: + # Dummy value from _on_timeout(). + raise futures.TimeoutError + return f.result() # May raise f.exception(). + + for f in todo: + f.add_done_callback(_on_completion) + if todo and timeout is not None: + timeout_handle = loop.call_later(timeout, _on_timeout) + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, + future._set_result_unless_cancelled, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif coroutines.iscoroutine(coro_or_future): + if loop is None: + loop = events.get_event_loop() + task = loop.create_task(coro_or_future) + if task._source_traceback: + del task._source_traceback[-1] + return task + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *return_exceptions* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + if not coros_or_futures: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + + arg_to_fut = {} + for arg in set(coros_or_futures): + if not isinstance(arg, futures.Future): + fut = async(arg, loop=loop) + if loop is None: + loop = fut._loop + # The caller cannot control this future, the "destroy pending task" + # warning should not be emitted. + fut._log_destroy_pending = False + else: + fut = arg + if loop is None: + loop = fut._loop + elif fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + arg_to_fut[arg] = fut + + children = [arg_to_fut[arg] for arg in coros_or_futures] + nchildren = len(children) + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * nchildren + + def _done_callback(i, fut): + nonlocal nfinished + if outer._state != futures._PENDING: + if fut._exception is not None: + # Mark exception retrieved. + fut.exception() + return + if fut._state == futures._CANCELLED: + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == nchildren: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + # Mark inner's result as retrieved. + inner.cancelled() or inner.exception() + return + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py new file mode 100644 index 00000000..3e5eee54 --- /dev/null +++ b/asyncio/test_utils.py @@ -0,0 +1,436 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import logging +import os +import re +import socket +import socketserver +import sys +import tempfile +import threading +import time +import unittest +from unittest import mock + +from http.server import HTTPServer +from wsgiref.simple_server import WSGIRequestHandler, WSGIServer + +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import tasks +from .coroutines import coroutine +from .log import logger + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + +def run_briefly(loop): + @coroutine + def once(): + pass + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_until(loop, pred, timeout=30): + deadline = time.time() + timeout + while not pred(): + if timeout is not None: + timeout = deadline - time.time() + if timeout <= 0: + raise futures.TimeoutError() + loop.run_until_complete(tasks.sleep(0.001, loop=loop)) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +class SilentWSGIRequestHandler(WSGIRequestHandler): + + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + +class SilentWSGIServer(WSGIServer): + + request_timeout = 2 + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + return request, client_addr + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) + httpd.address = httpd.server_address + server_thread = threading.Thread( + target=lambda: httpd.serve_forever(poll_interval=0.05)) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + httpd.server_close() + server_thread.join() + + +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + request_timeout = 2 + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = MockCallback(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def __init__(self): + self.keys = {} + + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key + + def unregister(self, fileobj): + return self.keys.pop(fileobj) + + def select(self, timeout): + return [] + + def get_map(self): + return self.keys + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value returned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._clock_resolution = 1e-9 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.Handle(callback, args, self) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.Handle(callback, args, self) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass + + +def MockCallback(**kwargs): + return mock.Mock(spec=['__call__'], **kwargs) + + +class MockPattern(str): + """A regex based str with a fuzzy __eq__. + + Use this helper with 'mock.assert_called_with', or anywhere + where a regex comparison between strings is needed. + + For instance: + mock_call.assert_called_with(MockPattern('spam.*ham')) + """ + def __eq__(self, other): + return bool(re.search(str(self), other, re.S)) + + +def get_function_source(func): + source = events._get_function_source(func) + if source is None: + raise ValueError("unable to get the source of %r" % (func,)) + return source + + +class TestCase(unittest.TestCase): + def set_event_loop(self, loop, *, cleanup=True): + assert loop is not None + # ensure that the event loop is passed explicitly in asyncio + events.set_event_loop(None) + if cleanup: + self.addCleanup(loop.close) + + def new_test_loop(self, gen=None): + loop = TestLoop(gen) + self.set_event_loop(loop) + return loop + + def tearDown(self): + events.set_event_loop(None) + + +@contextlib.contextmanager +def disable_logger(): + """Context manager to disable asyncio logger. + + For example, it can be used to ignore warnings in debug mode. + """ + old_level = logger.level + try: + logger.setLevel(logging.CRITICAL+1) + yield + finally: + logger.setLevel(old_level) + +def mock_nonblocking_socket(): + """Create a mock of a non-blocking socket.""" + sock = mock.Mock(socket.socket) + sock.gettimeout.return_value = 0.0 + return sock diff --git a/asyncio/transports.py b/asyncio/transports.py new file mode 100644 index 00000000..22df3c7a --- /dev/null +++ b/asyncio/transports.py @@ -0,0 +1,300 @@ +"""Abstract Transport class.""" + +import sys + +_PY34 = sys.version_info >= (3, 4) + +__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', + ] + + +class BaseTransport: + """Base class for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """Interface for read-only transports.""" + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + raise NotImplementedError + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """Interface for write-only transports.""" + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + raise NotImplementedError + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + if not _PY34: + # In Python 3.3, bytes.join() doesn't handle memoryview. + list_of_data = ( + bytes(data) if isinstance(data, memoryview) else data + for data in list_of_data) + self.write(b''.join(list_of_data)) + + def write_eof(self): + """Close the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """Interface representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """Interface for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError + + +class _FlowControlMixin(Transport): + """All the logic for (write) flow control in a mix-in base class. + + The subclass must implement get_write_buffer_size(). It must call + _maybe_pause_protocol() whenever the write buffer size increases, + and _maybe_resume_protocol() whenever it decreases. It may also + override set_write_buffer_limits() (e.g. to specify different + defaults). + + The subclass constructor must call super().__init__(extra). This + will call set_write_buffer_limits(). + + The user may call set_write_buffer_limits() and + get_write_buffer_size(), and their protocol's pause_writing() and + resume_writing() may be called. + """ + + def __init__(self, extra=None, loop=None): + super().__init__(extra) + assert loop is not None + self._loop = loop + self._protocol_paused = False + self._set_write_buffer_limits() + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def get_write_buffer_limits(self): + return (self._low_water, self._high_water) + + def _set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + + def get_write_buffer_size(self): + raise NotImplementedError diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py new file mode 100644 index 00000000..d1461fd0 --- /dev/null +++ b/asyncio/unix_events.py @@ -0,0 +1,949 @@ +"""Selector event loop for Unix with signal handling.""" + +import errno +import os +import signal +import socket +import stat +import subprocess +import sys +import threading + + +from . import base_events +from . import base_subprocess +from . import constants +from . import coroutines +from . import events +from . import selector_events +from . import selectors +from . import transports +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'DefaultEventLoopPolicy', + ] + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass + + +class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop. + + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + super().close() + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with add_signal_handler()") + self._check_signal(sig) + self._check_closed() + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except (ValueError, OSError) as exc: + raise RuntimeError(str(exc)) + + handle = events.Handle(callback, args, self) + self._signal_handlers[sig] = handle + + try: + # Register a dummy signal handler to ask Python to write the signal + # number in the wakup file descriptor. _process_self_data() will + # read signal numbers from this file descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + + # Set SA_RESTART to limit EINTR occurrences. + signal.siginterrupt(sig, False) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as nexc: + logger.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as exc: + logger.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + with events.get_child_watcher() as watcher: + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=extra, **kwargs) + yield from transp._post_init() + watcher.add_child_handler(transp.get_pid(), + self._child_watcher_callback, transp) + + return transp + + def _child_watcher_callback(self, pid, returncode, transp): + self.call_soon_threadsafe(transp._process_exited, returncode) + + @coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + try: + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + sock.close() + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + except: + sock.close() + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + +if hasattr(os, 'set_blocking'): + def _set_nonblocking(fd): + os.set_blocking(fd, False) +else: + import fcntl + + def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one event loop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + if not (stat.S_ISFIFO(mode) or + stat.S_ISSOCK(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is for pipes/sockets only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._fileno] + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_READ) + if polling: + info.append('polling') + else: + info.append('idle') + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause_reading(self): + self._loop.remove_reader(self._fileno) + + def resume_reading(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if (isinstance(exc, OSError) and exc.errno == errno.EIO): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports._FlowControlMixin, + transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra, loop) + self._extra['pipe'] = pipe + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + is_socket = stat.S_ISSOCK(mode) + if not (is_socket or + stat.S_ISFIFO(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is only for " + "pipes, sockets and character devices") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + + # On AIX, the reader trick only works for sockets. + # On other platforms it works for pipes and sockets. + # (Exception: OS X 10.4? Issue #19294.) + if is_socket or not sys.platform.startswith("aix"): + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__, 'fd=%s' % self._fileno] + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_WRITE) + if polling: + info.append('polling') + else: + info.append('idle') + + bufsize = self.get_write_buffer_size() + info.append('bufsize=%s' % bufsize) + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + + def _read_ready(self): + # Pipe was closed by peer. + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + if self._buffer: + self._close(BrokenPipeError()) + else: + self._close() + + def write(self, data): + assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) + if isinstance(data, bytearray): + data = memoryview(data) + if not data: + return + + if self._conn_lost or self._closing: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc, 'Fatal write error on pipe transport') + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + self._maybe_pause_protocol() + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc, 'Fatal write error on pipe transport') + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer and self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + # TODO: Make the relationships between write_eof(), close(), + # abort(), _fatal_error() and _close() more straightforward. + + def write_eof(self): + if self._closing: + return + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +if hasattr(os, 'set_inheritable'): + # Python 3.4 and newer + _set_inheritable = os.set_inheritable +else: + import fcntl + + def _set_inheritable(fd, inheritable): + cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + if not inheritable: + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + else: + fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) + + +class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + stdin_w = None + if stdin == subprocess.PIPE: + # Use a socket pair for stdin, since not all platforms + # support selecting read events on the write end of a + # socket (which we use in order to detect closing of the + # other end). Notably this is needed on AIX, and works + # just fine on other platforms. + stdin, stdin_w = self._loop._socketpair() + + # Mark the write end of the stdin pipe as non-inheritable, + # needed by close_fds=False on Python 3.3 and older + # (Python 3.4 implements the PEP 446, socketpair returns + # non-inheritable sockets) + _set_inheritable(stdin_w.fileno(), False) + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + + +class AbstractChildWatcher: + """Abstract base class for monitoring child processes. + + Objects derived from this class monitor a collection of subprocesses and + report their termination or interruption by a signal. + + New callbacks are registered with .add_child_handler(). Starting a new + process must be done within a 'with' block to allow the watcher to suspend + its activity until the new process if fully registered (this is needed to + prevent a race condition in some implementations). + + Example: + with watcher: + proc = subprocess.Popen("sleep 1") + watcher.add_child_handler(proc.pid, callback) + + Notes: + Implementations of this class must be thread-safe. + + Since child watcher objects may catch the SIGCHLD signal and call + waitpid(-1), there should be only one active object per process. + """ + + def add_child_handler(self, pid, callback, *args): + """Register a new child handler. + + Arrange for callback(pid, returncode, *args) to be called when + process 'pid' terminates. Specifying another callback for the same + process replaces the previous handler. + + Note: callback() must be thread-safe. + """ + raise NotImplementedError() + + def remove_child_handler(self, pid): + """Removes the handler for process 'pid'. + + The function returns True if the handler was successfully removed, + False if there was nothing to remove.""" + + raise NotImplementedError() + + def attach_loop(self, loop): + """Attach the watcher to an event loop. + + If the watcher was previously attached to an event loop, then it is + first detached before attaching to the new loop. + + Note: loop may be None. + """ + raise NotImplementedError() + + def close(self): + """Close the watcher. + + This must be called to make sure that any underlying resource is freed. + """ + raise NotImplementedError() + + def __enter__(self): + """Enter the watcher's context and allow starting new processes + + This function must return self""" + raise NotImplementedError() + + def __exit__(self, a, b, c): + """Exit the watcher's context""" + raise NotImplementedError() + + +class BaseChildWatcher(AbstractChildWatcher): + + def __init__(self): + self._loop = None + + def close(self): + self.attach_loop(None) + + def _do_waitpid(self, expected_pid): + raise NotImplementedError() + + def _do_waitpid_all(self): + raise NotImplementedError() + + def attach_loop(self, loop): + assert loop is None or isinstance(loop, events.AbstractEventLoop) + + if self._loop is not None: + self._loop.remove_signal_handler(signal.SIGCHLD) + + self._loop = loop + if loop is not None: + loop.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + # Prevent a race condition in case a child terminated + # during the switch. + self._do_waitpid_all() + + def _sig_chld(self): + try: + self._do_waitpid_all() + except Exception as exc: + # self._loop should always be available here + # as '_sig_chld' is added as a signal handler + # in 'attach_loop' + self._loop.call_exception_handler({ + 'message': 'Unknown exception in SIGCHLD handler', + 'exception': exc, + }) + + def _compute_returncode(self, status): + if os.WIFSIGNALED(status): + # The child process died because of a signal. + return -os.WTERMSIG(status) + elif os.WIFEXITED(status): + # The child process exited (e.g sys.exit()). + return os.WEXITSTATUS(status) + else: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status + + +class SafeChildWatcher(BaseChildWatcher): + """'Safe' child watcher implementation. + + This implementation avoids disrupting other code spawning processes by + polling explicitly each process in the SIGCHLD handler instead of calling + os.waitpid(-1). + + This is a safe solution but it has a significant overhead when handling a + big number of children (O(n) each time SIGCHLD is raised) + """ + + def __init__(self): + super().__init__() + self._callbacks = {} + + def close(self): + self._callbacks.clear() + super().close() + + def __enter__(self): + return self + + def __exit__(self, a, b, c): + pass + + def add_child_handler(self, pid, callback, *args): + self._callbacks[pid] = callback, args + + # Prevent a race condition in case the child is already terminated. + self._do_waitpid(pid) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + + for pid in list(self._callbacks): + self._do_waitpid(pid) + + def _do_waitpid(self, expected_pid): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, os.WNOHANG) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + if pid == 0: + # The child process is still alive. + return + + returncode = self._compute_returncode(status) + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + try: + callback, args = self._callbacks.pop(pid) + except KeyError: # pragma: no cover + # May happen if .remove_child_handler() is called + # after os.waitpid() returns. + if self._loop.get_debug(): + logger.warning("Child watcher got an unexpected pid: %r", + pid, exc_info=True) + else: + callback(pid, returncode, *args) + + +class FastChildWatcher(BaseChildWatcher): + """'Fast' child watcher implementation. + + This implementation reaps every terminated processes by calling + os.waitpid(-1) directly, possibly breaking other code spawning processes + and waiting for their termination. + + There is no noticeable overhead when handling a big number of children + (O(1) each time a child terminates). + """ + def __init__(self): + super().__init__() + self._callbacks = {} + self._lock = threading.Lock() + self._zombies = {} + self._forks = 0 + + def close(self): + self._callbacks.clear() + self._zombies.clear() + super().close() + + def __enter__(self): + with self._lock: + self._forks += 1 + + return self + + def __exit__(self, a, b, c): + with self._lock: + self._forks -= 1 + + if self._forks or not self._zombies: + return + + collateral_victims = str(self._zombies) + self._zombies.clear() + + logger.warning( + "Caught subprocesses termination from unknown pids: %s", + collateral_victims) + + def add_child_handler(self, pid, callback, *args): + assert self._forks, "Must use the context manager" + with self._lock: + try: + returncode = self._zombies.pop(pid) + except KeyError: + # The child is running. + self._callbacks[pid] = callback, args + return + + # The child is dead already. We can fire the callback. + callback(pid, returncode, *args) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + # Because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child. + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + # No more child processes exist. + return + else: + if pid == 0: + # A child process is still alive. + return + + returncode = self._compute_returncode(status) + + with self._lock: + try: + callback, args = self._callbacks.pop(pid) + except KeyError: + # unknown child + if self._forks: + # It may not be registered yet. + self._zombies[pid] = returncode + if self._loop.get_debug(): + logger.debug('unknown process %s exited ' + 'with returncode %s', + pid, returncode) + continue + callback = None + else: + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + pid, returncode) + + if callback is None: + logger.warning( + "Caught subprocess termination from unknown pid: " + "%d -> %d", pid, returncode) + else: + callback(pid, returncode, *args) + + +class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + """XXX""" + _loop_factory = _UnixSelectorEventLoop + + def __init__(self): + super().__init__() + self._watcher = None + + def _init_watcher(self): + with events._lock: + if self._watcher is None: # pragma: no branch + self._watcher = SafeChildWatcher() + if isinstance(threading.current_thread(), + threading._MainThread): + self._watcher.attach_loop(self._local._loop) + + def set_event_loop(self, loop): + """Set the event loop. + + As a side effect, if a child watcher was set before, then calling + .set_event_loop() from the main thread will call .attach_loop(loop) on + the child watcher. + """ + + super().set_event_loop(loop) + + if self._watcher is not None and \ + isinstance(threading.current_thread(), threading._MainThread): + self._watcher.attach_loop(loop) + + def get_child_watcher(self): + """Get the watcher for child processes. + + If not yet set, a SafeChildWatcher object is automatically created. + """ + if self._watcher is None: + self._init_watcher() + + return self._watcher + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + + assert watcher is None or isinstance(watcher, AbstractChildWatcher) + + if self._watcher is not None: + self._watcher.close() + + self._watcher = watcher + +SelectorEventLoop = _UnixSelectorEventLoop +DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py new file mode 100644 index 00000000..6763f0b7 --- /dev/null +++ b/asyncio/windows_events.py @@ -0,0 +1,634 @@ +"""Selector and proactor event loops for Windows.""" + +import _winapi +import errno +import math +import socket +import struct +import weakref + +from . import events +from . import base_subprocess +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from . import _overlapped +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', + ] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._ov = ov + + def _repr_info(self): + info = super()._repr_info() + if self._ov is not None: + state = 'pending' if self._ov.pending else 'completed' + info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + return info + + def _cancel_overlapped(self): + if self._ov is None: + return + try: + self._ov.cancel() + except OSError as exc: + context = { + 'message': 'Cancelling an overlapped future failed', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self._ov = None + + def cancel(self): + self._cancel_overlapped() + return super().cancel() + + def set_exception(self, exception): + super().set_exception(exception) + self._cancel_overlapped() + + def set_result(self, result): + super().set_result(result) + self._ov = None + + +class _WaitHandleFuture(futures.Future): + """Subclass of Future which represents a wait handle.""" + + def __init__(self, iocp, ov, handle, wait_handle, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + # iocp and ov are only used by cancel() to notify IocpProactor + # that the wait was cancelled + self._iocp = iocp + self._ov = ov + self._handle = handle + self._wait_handle = wait_handle + + def _poll(self): + # non-blocking wait: use a timeout of 0 millisecond + return (_winapi.WaitForSingleObject(self._handle, 0) == + _winapi.WAIT_OBJECT_0) + + def _repr_info(self): + info = super()._repr_info() + info.insert(1, 'handle=%#x' % self._handle) + if self._wait_handle: + state = 'signaled' if self._poll() else 'waiting' + info.insert(1, 'wait_handle=<%s, %#x>' + % (state, self._wait_handle)) + return info + + def _unregister_wait(self): + if self._wait_handle is None: + return + try: + _overlapped.UnregisterWait(self._wait_handle) + except OSError as exc: + # ERROR_IO_PENDING is not an error, the wait was unregistered + if exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self._wait_handle = None + self._iocp = None + self._ov = None + + def cancel(self): + result = super().cancel() + if self._ov is not None: + # signal the cancellation to the overlapped object + _overlapped.PostQueuedCompletionStatus(self._iocp, True, + 0, self._ov.address) + self._unregister_wait() + return result + + def set_exception(self, exception): + super().set_exception(exception) + self._unregister_wait() + + def set_result(self, result): + super().set_result(result) + self._unregister_wait() + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + # initialize the pipe attribute before calling _server_pipe_handle() + # because this function can raise an exception and the destructor calls + # the close() method + self._pipe = None + self._accept_pipe_future = None + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + if self._accept_pipe_future is not None: + self._accept_pipe_future.cancel() + self._accept_pipe_future = None + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + + def loop_accept_pipe(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError as exc: + if pipe and pipe.fileno() != -1: + self.call_exception_handler({ + 'message': 'Pipe accept failed', + 'exception': exc, + 'pipe': pipe, + }) + pipe.close() + elif self._debug: + logger.warning("Accept pipe failed on pipe %r", + pipe, exc_info=True) + except futures.CancelledError: + if pipe: + pipe.close() + else: + server._accept_pipe_future = f + f.add_done_callback(loop_accept_pipe) + + self.call_soon(loop_accept_pipe) + return [server] + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + transp = _WindowsSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=extra, **kwargs) + yield from transp._post_init() + return transp + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._stopped_serving = weakref.WeakSet() + + def __repr__(self): + return ('<%s overlapped#=%s result#=%s>' + % (self.__class__.__name__, len(self._cache), + len(self._results))) + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + + def finish_recv(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_recv) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + + def finish_send(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_send) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket(listener.family) + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + @coroutine + def accept_coro(future, conn): + # Coroutine closing the accept socket if the future is cancelled + try: + yield from future + except futures.CancelledError: + conn.close() + raise + + future = self._register(ov, listener, finish_accept) + coro = accept_coro(future, conn) + tasks.async(coro, loop=self._loop) + return future + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), conn.family) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + ov.ConnectNamedPipe(pipe.fileno()) + + def finish_accept_pipe(trans, key, ov): + ov.getresult() + return pipe + + # FIXME: Tulip issue 196: why to we neeed register=False? + # See also the comment in the _register() method + return self._register(ov, pipe, finish_accept_pipe, + register=False) + + def connect_pipe(self, address): + ov = _overlapped.Overlapped(NULL) + ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) + + def finish_connect_pipe(err, handle, ov): + # err, handle were arguments passed to PostQueuedCompletionStatus() + # in a function run in a thread pool. + if err == _overlapped.ERROR_SEM_TIMEOUT: + # Connection did not succeed within time limit. + msg = _overlapped.FormatMessage(err) + raise ConnectionRefusedError(0, msg, None, err) + elif err != 0: + msg = _overlapped.FormatMessage(err) + raise OSError(0, msg, None, err) + else: + return windows_utils.PipeHandle(handle) + + return self._register(ov, None, finish_connect_pipe, wait_for_post=True) + + def wait_for_handle(self, handle, timeout=None): + if timeout is None: + ms = _winapi.INFINITE + else: + # RegisterWaitForSingleObject() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + + # We only create ov so we can use ov.address as a key for the cache. + ov = _overlapped.Overlapped(NULL) + wh = _overlapped.RegisterWaitWithQueue( + handle, self._iocp, ov.address, ms) + f = _WaitHandleFuture(self._iocp, ov, handle, wh, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + + def finish_wait_for_handle(trans, key, ov): + # Note that this second wait means that we should only use + # this with handles types where a successful wait has no + # effect. So events or processes are all right, but locks + # or semaphores are not. Also note if the handle is + # signalled and then quickly reset, then we may return + # False even though we have not timed out. + return f._poll() + + if f._poll(): + try: + result = f._poll() + except OSError as exc: + f.set_exception(exc) + else: + f.set_result(result) + + self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle) + return f + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback, + wait_for_post=False, register=True): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + if not ov.pending and not wait_for_post: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + # Even if GetOverlappedResult() was called, we have to wait for the + # notification of the completion in GetQueuedCompletionStatus(). + # Register the overlapped operation to keep a reference to the + # OVERLAPPED object, otherwise the memory is freed and Windows may + # read uninitialized memory. + # + # For an unknown reason, ConnectNamedPipe() behaves differently: + # the completion is not notified by GetOverlappedResult() if we + # already called GetOverlappedResult(). For this specific case, we + # don't expect notification (register is set to False). + else: + register = True + if register: + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _get_accept_socket(self, family): + s = socket.socket(family) + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + # GetQueuedCompletionStatus() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + if ms >= INFINITE: + raise ValueError("timeout too big") + + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + return + ms = 0 + + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + if self._loop.get_debug(): + self._loop.call_exception_handler({ + 'message': ('GetQueuedCompletionStatus() returned an ' + 'unexpected event'), + 'status': ('err=%s transferred=%s key=%#x address=%#x' + % (err, transferred, key, address)), + }) + + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + continue + + if obj in self._stopped_serving: + f.cancel() + # Don't call the callback if _register() already read the result or + # if the overlapped has been cancelled + elif not f.done(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (fut, ov, obj, callback) in list(self._cache.items()): + if obj is None: + # The operation was started with connect_pipe() which + # queues a task to Windows' thread pool. This cannot + # be cancelled, so just forget it. + del self._cache[address] + # FIXME: Tulip issue 196: remove this case, it should not happen + elif fut.done() and not fut.cancelled(): + del self._cache[address] + else: + try: + fut.cancel() + except OSError as exc: + if self._loop is not None: + context = { + 'message': 'Cancelling a future failed', + 'exception': exc, + 'future': fut, + } + if fut._source_traceback: + context['source_traceback'] = fut._source_traceback + self._loop.call_exception_handler(context) + + while self._cache: + if not self._poll(1): + logger.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None + + def __del__(self): + self.close() + + +class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + self._proc = windows_utils.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + bufsize=bufsize, **kwargs) + + def callback(f): + returncode = self._proc.poll() + self._process_exited(returncode) + + f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) + f.add_done_callback(callback) + + +SelectorEventLoop = _WindowsSelectorEventLoop + + +class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = SelectorEventLoop + + +DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py new file mode 100644 index 00000000..1155a77f --- /dev/null +++ b/asyncio/windows_utils.py @@ -0,0 +1,209 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import socket +import itertools +import msvcrt +import os +import subprocess +import tempfile +import _winapi + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + + +# Constants/globals + + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +_mmap_counter = itertools.count() + + +if hasattr(socket, 'socketpair'): + # Since Python 3.5, socket.socketpair() is now also available on Windows + socketpair = socket.socketpair +else: + # Replacement for socket.socketpair() + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + """ + if family == socket.AF_INET: + host = '127.0.0.1' + elif family == socket.AF_INET6: + host = '::1' + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen(1) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() + return (ssock, csock) + + +# Replacement for os.pipe() using handles instead of fds + + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + + +# Wrapper for a pipe handle + + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle != -1: + CloseHandle(self._handle) + self._handle = -1 + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + + +# Replacement for subprocess.Popen using overlapped pipe handles + + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + assert not kwds.get('universal_newlines') + assert kwds.get('bufsize', 0) == 0 + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True), duplex=True) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + else: + stdin_rfd = stdin + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + else: + stdout_wfd = stdout + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + elif stderr == STDOUT: + stderr_wfd = stdout_wfd + else: + stderr_wfd = stderr + try: + super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + if h is not None: + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) diff --git a/check.py b/check.py new file mode 100644 index 00000000..6db82d64 --- /dev/null +++ b/check.py @@ -0,0 +1,45 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import os +import sys + + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/cacheclt.py b/examples/cacheclt.py new file mode 100644 index 00000000..b11a4d1a --- /dev/null +++ b/examples/cacheclt.py @@ -0,0 +1,213 @@ +"""Client for cache server. + +See cachesvr.py for protocol description. +""" + +import argparse +import asyncio +from asyncio import test_utils +import json +import logging + +ARGS = argparse.ArgumentParser(description='Cache client example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--max_backoff', action='store', dest='max_backoff', + default=5, type=float, help='Max backoff on reconnect') +ARGS.add_argument( + '--ntasks', action='store', dest='ntasks', + default=10, type=int, help='Number of tester tasks') +ARGS.add_argument( + '--ntries', action='store', dest='ntries', + default=5, type=int, help='Number of request tries before giving up') + + +args = ARGS.parse_args() + + +class CacheClient: + """Multiplexing cache client. + + This wraps a single connection to the cache client. The + connection is automatically re-opened when an error occurs. + + Multiple tasks may share this object; the requests will be + serialized. + + The public API is get(), set(), delete() (all are coroutines). + """ + + def __init__(self, host, port, sslctx=None, loop=None): + self.host = host + self.port = port + self.sslctx = sslctx + self.loop = loop + self.todo = set() + self.initialized = False + self.task = asyncio.Task(self.activity(), loop=self.loop) + + @asyncio.coroutine + def get(self, key): + resp = yield from self.request('get', key) + if resp is None: + return None + return resp.get('value') + + @asyncio.coroutine + def set(self, key, value): + resp = yield from self.request('set', key, value) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def delete(self, key): + resp = yield from self.request('delete', key) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def request(self, type, key, value=None): + assert not self.task.done() + data = {'type': type, 'key': key} + if value is not None: + data['value'] = value + payload = json.dumps(data).encode('utf8') + waiter = asyncio.Future(loop=self.loop) + if self.initialized: + try: + yield from self.send(payload, waiter) + except IOError: + self.todo.add((payload, waiter)) + else: + self.todo.add((payload, waiter)) + return (yield from waiter) + + @asyncio.coroutine + def activity(self): + backoff = 0 + while True: + try: + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.sslctx, loop=self.loop) + except Exception as exc: + backoff = min(args.max_backoff, backoff + (backoff//2) + 1) + logging.info('Error connecting: %r; sleep %s', exc, backoff) + yield from asyncio.sleep(backoff, loop=self.loop) + continue + backoff = 0 + self.next_id = 0 + self.pending = {} + self. initialized = True + try: + while self.todo: + payload, waiter = self.todo.pop() + if not waiter.done(): + yield from self.send(payload, waiter) + while True: + resp_id, resp = yield from self.process() + if resp_id in self.pending: + payload, waiter = self.pending.pop(resp_id) + if not waiter.done(): + waiter.set_result(resp) + except Exception as exc: + self.initialized = False + self.writer.close() + while self.pending: + req_id, pair = self.pending.popitem() + payload, waiter = pair + if not waiter.done(): + self.todo.add(pair) + logging.info('Error processing: %r', exc) + + @asyncio.coroutine + def send(self, payload, waiter): + self.next_id += 1 + req_id = self.next_id + frame = 'request %d %d\n' % (req_id, len(payload)) + self.writer.write(frame.encode('ascii')) + self.writer.write(payload) + self.pending[req_id] = payload, waiter + yield from self.writer.drain() + + @asyncio.coroutine + def process(self): + frame = yield from self.reader.readline() + if not frame: + raise EOFError() + head, tail = frame.split(None, 1) + if head == b'error': + raise IOError('OOB error: %r' % tail) + if head != b'response': + raise IOError('Bad frame: %r' % frame) + resp_id, resp_size = map(int, tail.split()) + data = yield from self.reader.readexactly(resp_size) + if len(data) != resp_size: + raise EOFError() + resp = json.loads(data.decode('utf8')) + return resp_id, resp + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop) + try: + loop.run_until_complete( + asyncio.gather( + *[testing(i, cache, loop) for i in range(args.ntasks)], + loop=loop)) + finally: + loop.close() + + +@asyncio.coroutine +def testing(label, cache, loop): + + def w(g): + return asyncio.wait_for(g, args.timeout, loop=loop) + + key = 'foo-%s' % label + while True: + logging.info('%s %s', label, '-'*20) + try: + ret = yield from w(cache.set(key, 'hello-%s-world' % label)) + logging.info('%s set %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get %s', label, ret) + ret = yield from w(cache.delete(key)) + logging.info('%s del %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get2 %s', label, ret) + except asyncio.TimeoutError: + logging.warn('%s Timeout', label) + except Exception as exc: + logging.exception('%s Client exception: %r', label, exc) + break + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/cachesvr.py b/examples/cachesvr.py new file mode 100644 index 00000000..053f9c21 --- /dev/null +++ b/examples/cachesvr.py @@ -0,0 +1,249 @@ +"""A simple memcache-like server. + +The basic data structure maintained is a single in-memory dictionary +mapping string keys to string values, with operations get, set and +delete. (Both keys and values may contain Unicode.) + +This is a TCP server listening on port 54321. There is no +authentication. + +Requests provide an operation and return a response. A connection may +be used for multiple requests. The connection is closed when a client +sends a bad request. + +If a client is idle for over 5 seconds (i.e., it does not send another +request, or fails to read the whole response, within this time), it is +disconnected. + +Framing of requests and responses within a connection uses a +line-based protocol. The first line of a request is the frame header +and contains three whitespace-delimited token followed by LF or CRLF: + +- the keyword 'request' +- a decimal request ID; the first request is '1', the second '2', etc. +- a decimal byte count giving the size of the rest of the request + +Note that the requests ID *must* be consecutive and start at '1' for +each connection. + +Response frames look the same except the keyword is 'response'. The +response ID matches the request ID. There should be exactly one +response to each request and responses should be seen in the same +order as the requests. + +After the frame, individual requests and responses are JSON encoded. + +If the frame header or the JSON request body cannot be parsed, an +unframed error message (always starting with 'error') is written back +and the connection is closed. + +JSON-encoded requests can be: + +- {"type": "get", "key": } +- {"type": "set", "key": , "value": } +- {"type": "delete", "key": } + +Responses are also JSON-encoded: + +- {"status": "ok", "value": } # Successful get request +- {"status": "ok"} # Successful set or delete request +- {"status": "notfound"} # Key not found for get or delete request + +If the request is valid JSON but cannot be handled (e.g., the type or +key field is absent or invalid), an error response of the following +form is returned, but the connection is not closed: + +- {"error": } +""" + +import argparse +import asyncio +import json +import logging +import os +import random + +ARGS = argparse.ArgumentParser(description='Cache server example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--random_failure_percent', action='store', dest='fail_percent', + default=0, type=float, help='Fail randomly N percent of the time') +ARGS.add_argument( + '--random_failure_sleep', action='store', dest='fail_sleep', + default=0, type=float, help='Sleep time when randomly failing') +ARGS.add_argument( + '--random_response_sleep', action='store', dest='resp_sleep', + default=0, type=float, help='Sleep time before responding') + +args = ARGS.parse_args() + + +class Cache: + + def __init__(self, loop): + self.loop = loop + self.table = {} + + @asyncio.coroutine + def handle_client(self, reader, writer): + # Wrapper to log stuff and close writer (i.e., transport). + peer = writer.get_extra_info('socket').getpeername() + logging.info('got a connection from %s', peer) + try: + yield from self.frame_parser(reader, writer) + except Exception as exc: + logging.error('error %r from %s', exc, peer) + else: + logging.info('end connection from %s', peer) + finally: + writer.close() + + @asyncio.coroutine + def frame_parser(self, reader, writer): + # This takes care of the framing. + last_request_id = 0 + while True: + # Read the frame header, parse it, read the data. + # NOTE: The readline() and readexactly() calls will hang + # if the client doesn't send enough data but doesn't + # disconnect either. We add a timeout to each. (But the + # timeout should really be implemented by StreamReader.) + framing_b = yield from asyncio.wait_for( + reader.readline(), + timeout=args.timeout, loop=self.loop) + if random.random()*100 < args.fail_percent: + logging.warn('Inserting random failure') + yield from asyncio.sleep(args.fail_sleep*random.random(), + loop=self.loop) + writer.write(b'error random failure\r\n') + break + logging.debug('framing_b = %r', framing_b) + if not framing_b: + break # Clean close. + try: + frame_keyword, request_id_b, byte_count_b = framing_b.split() + except ValueError: + writer.write(b'error unparseable frame\r\n') + break + if frame_keyword != b'request': + writer.write(b'error frame does not start with request\r\n') + break + try: + request_id, byte_count = int(request_id_b), int(byte_count_b) + except ValueError: + writer.write(b'error unparsable frame parameters\r\n') + break + if request_id != last_request_id + 1 or byte_count < 2: + writer.write(b'error invalid frame parameters\r\n') + break + last_request_id = request_id + request_b = yield from asyncio.wait_for( + reader.readexactly(byte_count), + timeout=args.timeout, loop=self.loop) + try: + request = json.loads(request_b.decode('utf8')) + except ValueError: + writer.write(b'error unparsable json\r\n') + break + response = self.handle_request(request) # Not a coroutine. + if response is None: + writer.write(b'error unhandlable request\r\n') + break + response_b = json.dumps(response).encode('utf8') + b'\r\n' + byte_count = len(response_b) + framing_s = 'response {} {}\r\n'.format(request_id, byte_count) + writer.write(framing_s.encode('ascii')) + yield from asyncio.sleep(args.resp_sleep*random.random(), + loop=self.loop) + writer.write(response_b) + + def handle_request(self, request): + # This parses one request and farms it out to a specific handler. + # Return None for all errors. + if not isinstance(request, dict): + return {'error': 'request is not a dict'} + request_type = request.get('type') + if request_type is None: + return {'error': 'no type in request'} + if request_type not in {'get', 'set', 'delete'}: + return {'error': 'unknown request type'} + key = request.get('key') + if not isinstance(key, str): + return {'error': 'key is not a string'} + if request_type == 'get': + return self.handle_get(key) + if request_type == 'set': + value = request.get('value') + if not isinstance(value, str): + return {'error': 'value is not a string'} + return self.handle_set(key, value) + if request_type == 'delete': + return self.handle_delete(key) + assert False, 'bad request type' # Should have been caught above. + + def handle_get(self, key): + value = self.table.get(key) + if value is None: + return {'status': 'notfound'} + else: + return {'status': 'ok', 'value': value} + + def handle_set(self, key, value): + self.table[key] = value + return {'status': 'ok'} + + def handle_delete(self, key): + if key not in self.table: + return {'status': 'notfound'} + else: + del self.table[key] + return {'status': 'ok'} + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) + cache = Cache(loop) + task = asyncio.streams.start_server(cache.handle_client, + args.host, args.port, + ssl=sslctx, loop=loop) + svr = loop.run_until_complete(task) + for sock in svr.sockets: + logging.info('socket %s', sock.getsockname()) + try: + loop.run_forever() + finally: + loop.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..3fac175e --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,128 @@ +""" +Example of asynchronous interaction with a child python process. + +This example shows how to attach an existing Popen object and use the low level +transport-protocol API. See shell.py and subprocess_shell.py for higher level +examples. +""" + +import os +import sys + +try: + import asyncio +except ImportError: + # asyncio is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import asyncio + +if sys.platform == 'win32': + from asyncio.windows_utils import Popen, PIPE + from asyncio.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@asyncio.coroutine +def connect_write_pipe(file): + loop = asyncio.get_event_loop() + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@asyncio.coroutine +def connect_read_pipe(file): + loop = asyncio.get_event_loop() + stream_reader = asyncio.StreamReader(loop=loop) + def factory(): + return asyncio.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@asyncio.coroutine +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {asyncio.Task(stderr.readline()): stderr, + asyncio.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from asyncio.wait( + registered, timeout=timeout, + return_when=asyncio.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[asyncio.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main(loop)) + finally: + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100644 index 00000000..4bb0b4ea --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,863 @@ +#!/usr/bin/env python3.4 + +"""A simple web crawler.""" + +# TODO: +# - More organized logging (with task ID or URL?). +# - Use logging module for Logger. +# - KeyboardInterrupt in HTML parsing may hang or report unretrieved error. +# - Support gzip encoding. +# - Close connection if HTTP/1.0 response. +# - Add timeouts. (E.g. when switching networks, all seems to hang.) +# - Add arguments to specify TLS settings (e.g. cert/key files). +# - Skip reading large non-text/html files? +# - Use ETag and If-Modified-Since? +# - Handle out of file descriptors directly? (How?) + +import argparse +import asyncio +import asyncio.locks +import cgi +from http.client import BadStatusLine +import logging +import re +import sys +import time +import urllib.parse + + +ARGS = argparse.ArgumentParser(description="Web crawler") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--select', action='store_true', dest='select', + default=False, help='Use Select event loop instead of default') +ARGS.add_argument( + 'roots', nargs='*', + default=[], help='Root URL (may be repeated)') +ARGS.add_argument( + '--max_redirect', action='store', type=int, metavar='N', + default=10, help='Limit redirection chains (for 301, 302 etc.)') +ARGS.add_argument( + '--max_tries', action='store', type=int, metavar='N', + default=4, help='Limit retries on network errors') +ARGS.add_argument( + '--max_tasks', action='store', type=int, metavar='N', + default=100, help='Limit concurrent connections') +ARGS.add_argument( + '--max_pool', action='store', type=int, metavar='N', + default=100, help='Limit connection pool size') +ARGS.add_argument( + '--exclude', action='store', metavar='REGEX', + help='Exclude matching URLs') +ARGS.add_argument( + '--strict', action='store_true', + default=True, help='Strict host matching (default)') +ARGS.add_argument( + '--lenient', action='store_false', dest='strict', + default=False, help='Lenient host matching') +ARGS.add_argument( + '-v', '--verbose', action='count', dest='level', + default=1, help='Verbose logging (repeat for more verbose)') +ARGS.add_argument( + '-q', '--quiet', action='store_const', const=0, dest='level', + default=1, help='Quiet logging (opposite of --verbose)') + + +ESCAPES = [('quot', '"'), + ('gt', '>'), + ('lt', '<'), + ('amp', '&') # Must be last. + ] + + +def unescape(url): + """Turn & into &, and so on. + + This is the inverse of cgi.escape(). + """ + for name, char in ESCAPES: + url = url.replace('&' + name + ';', char) + return url + + +def fix_url(url): + """Prefix a schema-less URL with http://.""" + if '://' not in url: + url = 'http://' + url + return url + + +class Logger: + + def __init__(self, level): + self.level = level + + def _log(self, n, args): + if self.level >= n: + print(*args, file=sys.stderr, flush=True) + + def log(self, n, *args): + self._log(n, args) + + def __call__(self, n, *args): + self._log(n, args) + + +class ConnectionPool: + """A connection pool. + + To open a connection, use reserve(). To recycle it, use unreserve(). + + The pool is mostly just a mapping from (host, port, ssl) tuples to + lists of Connections. The currently active connections are *not* + in the data structure; get_connection() takes the connection out, + and recycle_connection() puts it back in. To recycle a + connection, call conn.close(recycle=True). + + There are limits to both the overall pool and the per-key pool. + """ + + def __init__(self, log, max_pool=10, max_tasks=5): + self.log = log + self.max_pool = max_pool # Overall limit. + self.max_tasks = max_tasks # Per-key limit. + self.loop = asyncio.get_event_loop() + self.connections = {} # {(host, port, ssl): [Connection, ...], ...} + self.queue = [] # [Connection, ...] + + def close(self): + """Close all connections available for reuse.""" + for conns in self.connections.values(): + for conn in conns: + conn.close() + self.connections.clear() + self.queue.clear() + + @asyncio.coroutine + def get_connection(self, host, port, ssl): + """Create or reuse a connection.""" + port = port or (443 if ssl else 80) + try: + ipaddrs = yield from self.loop.getaddrinfo(host, port) + except Exception as exc: + self.log(0, 'Exception %r for (%r, %r)' % (exc, host, port)) + raise + self.log(1, '* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs))) + + # Look for a reusable connection. + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = None + conns = self.connections.get(key) + while conns: + conn = conns.pop(0) + self.queue.remove(conn) + if not conns: + del self.connections[key] + if conn.stale(): + self.log(1, 'closing stale connection for', key) + conn.close() # Just in case. + else: + self.log(1, '* Reusing pooled connection', key, + 'FD =', conn.fileno()) + return conn + + # Create a new connection. + conn = Connection(self.log, self, host, port, ssl) + yield from conn.connect() + self.log(1, '* New connection', conn.key, 'FD =', conn.fileno()) + return conn + + def recycle_connection(self, conn): + """Make a connection available for reuse. + + This also prunes the pool if it exceeds the size limits. + """ + if conn.stale(): + conn.close() + return + + key = conn.key + conns = self.connections.setdefault(key, []) + conns.append(conn) + self.queue.append(conn) + + if len(conns) <= self.max_tasks and len(self.queue) <= self.max_pool: + return + + # Prune the queue. + + # Close stale connections for this key first. + stale = [conn for conn in conns if conn.stale()] + if stale: + for conn in stale: + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + if not conns: + del self.connections[key] + + # Close oldest connection(s) for this key if limit reached. + while len(conns) > self.max_tasks: + conn = conns.pop(0) + self.queue.remove(conn) + self.log(1, 'closing oldest connection for', key) + conn.close() + + if len(self.queue) <= self.max_pool: + return + + # Close overall stale connections. + stale = [conn for conn in self.queue if conn.stale()] + if stale: + for conn in stale: + conns = self.connections.get(conn.key) + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + + # Close oldest overall connection(s) if limit reached. + while len(self.queue) > self.max_pool: + conn = self.queue.pop(0) + conns = self.connections.get(conn.key) + c = conns.pop(0) + assert conn == c, (conn.key, conn, c, conns) + self.log(1, 'closing overall oldest connection for', conn.key) + conn.close() + + +class Connection: + + def __init__(self, log, pool, host, port, ssl): + self.log = log + self.pool = pool + self.host = host + self.port = port + self.ssl = ssl + self.reader = None + self.writer = None + self.key = None + + def stale(self): + return self.reader is None or self.reader.at_eof() + + def fileno(self): + writer = self.writer + if writer is not None: + transport = writer.transport + if transport is not None: + sock = transport.get_extra_info('socket') + if sock is not None: + return sock.fileno() + return None + + @asyncio.coroutine + def connect(self): + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.ssl) + peername = self.writer.get_extra_info('peername') + if peername: + self.host, self.port = peername[:2] + else: + self.log(1, 'NO PEERNAME???', self.host, self.port, self.ssl) + self.key = self.host, self.port, self.ssl + + def close(self, recycle=False): + if recycle and not self.stale(): + self.pool.recycle_connection(self) + else: + self.writer.close() + self.pool = self.reader = self.writer = None + + +class Request: + """HTTP request. + + Use connect() to open a connection; send_request() to send the + request; get_response() to receive the response headers. + """ + + def __init__(self, log, url, pool): + self.log = log + self.url = url + self.pool = pool + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.conn = None + + @asyncio.coroutine + def connect(self): + """Open a connection to the server.""" + self.log(1, '* Connecting to %s:%s using %s for %s' % + (self.hostname, self.port, + 'ssl' if self.ssl else 'tcp', + self.url)) + self.conn = yield from self.pool.get_connection(self.hostname, + self.port, self.ssl) + + def close(self, recycle=False): + """Close the connection, recycle if requested.""" + if self.conn is not None: + if not recycle: + self.log(1, 'closing connection for', self.conn.key) + self.conn.close(recycle) + self.conn = None + + @asyncio.coroutine + def putline(self, line): + """Write a line to the connection. + + Used for the request line and headers. + """ + self.log(2, '>', line) + self.conn.writer.write(line.encode('latin-1') + b'\r\n') + + @asyncio.coroutine + def send_request(self): + """Send the request.""" + request_line = '%s %s %s' % (self.method, self.full_path, + self.http_version) + yield from self.putline(request_line) + # TODO: What if a header is already set? + self.headers.append(('User-Agent', 'asyncio-example-crawl/0.0')) + self.headers.append(('Host', self.netloc)) + self.headers.append(('Accept', '*/*')) + ##self.headers.append(('Accept-Encoding', 'gzip')) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @asyncio.coroutine + def get_response(self): + """Receive the response.""" + response = Response(self.log, self.conn.reader) + yield from response.read_headers() + return response + + +class Response: + """HTTP response. + + Call read_headers() to receive the request headers. Then check + the status attribute and call get_header() to inspect the headers. + Finally call read() to receive the body. + """ + + def __init__(self, log, reader): + self.log = log + self.reader = reader + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @asyncio.coroutine + def getline(self): + """Read one line from the connection.""" + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.log(2, '<', line) + return line + + @asyncio.coroutine + def read_headers(self): + """Read the response status and the request headers.""" + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + self.log(0, 'bad status_line', repr(status_line)) + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=''): + """Inspect the status and return the redirect url if appropriate.""" + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=''): + """Get one header value, using a case insensitive header name.""" + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @asyncio.coroutine + def read(self): + """Read the response body. + + This honors Content-Length and Transfer-Encoding: chunked. + """ + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding').lower() == 'chunked': + self.log(2, 'parsing chunked response') + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + self.log(0, 'premature end of chunked response') + break + self.log(3, 'size_header =', repr(size_header)) + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + self.log(3, 'reading chunk of', size, 'bytes') + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + if not size: + break + body = b''.join(blocks) + self.log(1, 'chunked response had', len(body), + 'bytes in', len(blocks), 'blocks') + else: + self.log(3, 'reading until EOF') + body = yield from self.reader.read() + # TODO: Should make sure not to recycle the connection + # in this case. + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +class Fetcher: + """Logic and state for one URL. + + When found in crawler.busy, this represents a URL to be fetched or + in the process of being fetched; when found in crawler.done, this + holds the results from fetching it. + + This is usually associated with a task. This references the + crawler for the connection pool and to add more URLs to its todo + list. + + Call fetch() to do the fetching, then report() to print the results. + """ + + def __init__(self, log, url, crawler, max_redirect=10, max_tries=4): + self.log = log + self.url = url + self.crawler = crawler + # We don't loop resolving redirects here -- we just use this + # to decide whether to add the redirect URL to crawler.todo. + self.max_redirect = max_redirect + # But we do loop to retry on errors a few times. + self.max_tries = max_tries + # Everything we collect from the response goes here. + self.task = None + self.exceptions = [] + self.tries = 0 + self.request = None + self.response = None + self.body = None + self.next_url = None + self.ctype = None + self.pdict = None + self.encoding = None + self.urls = None + self.new_urls = None + + @asyncio.coroutine + def fetch(self): + """Attempt to fetch the contents of the URL. + + If successful, and the data is HTML, extract further links and + add them to the crawler. Redirects are also added back there. + """ + while self.tries < self.max_tries: + self.tries += 1 + self.request = None + try: + self.request = Request(self.log, self.url, self.crawler.pool) + yield from self.request.connect() + yield from self.request.send_request() + self.response = yield from self.request.get_response() + self.body = yield from self.response.read() + h_conn = self.response.get_header('connection').lower() + if h_conn != 'close': + self.request.close(recycle=True) + self.request = None + if self.tries > 1: + self.log(1, 'try', self.tries, 'for', self.url, 'success') + break + except (BadStatusLine, OSError) as exc: + self.exceptions.append(exc) + self.log(1, 'try', self.tries, 'for', self.url, + 'raised', repr(exc)) + ##import pdb; pdb.set_trace() + # Don't reuse the connection in this case. + finally: + if self.request is not None: + self.request.close() + else: + # We never broke out of the while loop, i.e. all tries failed. + self.log(0, 'no success for', self.url, + 'in', self.max_tries, 'tries') + return + next_url = self.response.get_redirect_url() + if next_url: + self.next_url = urllib.parse.urljoin(self.url, next_url) + if self.max_redirect > 0: + self.log(1, 'redirect to', self.next_url, 'from', self.url) + self.crawler.add_url(self.next_url, self.max_redirect-1) + else: + self.log(0, 'redirect limit reached for', self.next_url, + 'from', self.url) + else: + if self.response.status == 200: + self.ctype = self.response.get_header('content-type') + self.pdict = {} + if self.ctype: + self.ctype, self.pdict = cgi.parse_header(self.ctype) + self.encoding = self.pdict.get('charset', 'utf-8') + if self.ctype == 'text/html': + body = self.body.decode(self.encoding, 'replace') + # Replace href with (?:href|src) to follow image links. + self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + body)) + if self.urls: + self.log(1, 'got', len(self.urls), + 'distinct urls from', self.url) + self.new_urls = set() + for url in self.urls: + url = unescape(url) + url = urllib.parse.urljoin(self.url, url) + url, frag = urllib.parse.urldefrag(url) + if self.crawler.add_url(url): + self.new_urls.add(url) + + def report(self, stats, file=None): + """Print a report on the state for this URL. + + Also update the Stats instance. + """ + if self.task is not None: + if not self.task.done(): + stats.add('pending') + print(self.url, 'pending', file=file) + return + elif self.task.cancelled(): + stats.add('cancelled') + print(self.url, 'cancelled', file=file) + return + elif self.task.exception(): + stats.add('exception') + exc = self.task.exception() + stats.add('exception_' + exc.__class__.__name__) + print(self.url, exc, file=file) + return + if len(self.exceptions) == self.tries: + stats.add('fail') + exc = self.exceptions[-1] + stats.add('fail_' + str(exc.__class__.__name__)) + print(self.url, 'error', exc, file=file) + elif self.next_url: + stats.add('redirect') + print(self.url, self.response.status, 'redirect', self.next_url, + file=file) + elif self.ctype == 'text/html': + stats.add('html') + size = len(self.body or b'') + stats.add('html_bytes', size) + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), + file=file) + elif self.response is None: + print(self.url, 'no response object') + else: + size = len(self.body or b'') + if self.response.status == 200: + stats.add('other') + stats.add('other_bytes', size) + else: + stats.add('error') + stats.add('error_bytes', size) + stats.add('status_%s' % self.response.status) + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + file=file) + + +class Stats: + """Record stats of various sorts.""" + + def __init__(self): + self.stats = {} + + def add(self, key, count=1): + self.stats[key] = self.stats.get(key, 0) + count + + def report(self, file=None): + for key, count in sorted(self.stats.items()): + print('%10d' % count, key, file=file) + + +class Crawler: + """Crawl a set of URLs. + + This manages three disjoint sets of URLs (todo, busy, done). The + data structures actually store dicts -- the values in todo give + the redirect limit, while the values in busy and done are Fetcher + instances. + """ + def __init__(self, log, + roots, exclude=None, strict=True, # What to crawl. + max_redirect=10, max_tries=4, # Per-url limits. + max_tasks=10, max_pool=10, # Global limits. + ): + self.log = log + self.roots = roots + self.exclude = exclude + self.strict = strict + self.max_redirect = max_redirect + self.max_tries = max_tries + self.max_tasks = max_tasks + self.max_pool = max_pool + self.todo = {} + self.busy = {} + self.done = {} + self.pool = ConnectionPool(self.log, max_pool, max_tasks) + self.root_domains = set() + for root in roots: + parts = urllib.parse.urlparse(root) + host, port = urllib.parse.splitport(parts.netloc) + if not host: + continue + if re.match(r'\A[\d\.]*\Z', host): + self.root_domains.add(host) + else: + host = host.lower() + if self.strict: + self.root_domains.add(host) + if host.startswith('www.'): + self.root_domains.add(host[4:]) + else: + self.root_domains.add('www.' + host) + else: + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + self.root_domains.add(host) + for root in roots: + self.add_url(root) + self.governor = asyncio.locks.Semaphore(max_tasks) + self.termination = asyncio.locks.Condition() + self.t0 = time.time() + self.t1 = None + + def close(self): + """Close resources (currently only the pool).""" + self.pool.close() + + def host_okay(self, host): + """Check if a host should be crawled. + + A literal match (after lowercasing) is always good. For hosts + that don't look like IP addresses, some approximate matches + are okay depending on the strict flag. + """ + host = host.lower() + if host in self.root_domains: + return True + if re.match(r'\A[\d\.]*\Z', host): + return False + if self.strict: + return self._host_okay_strictish(host) + else: + return self._host_okay_lenient(host) + + def _host_okay_strictish(self, host): + """Check if a host should be crawled, strict-ish version. + + This checks for equality modulo an initial 'www.' component. + """ + if host.startswith('www.'): + if host[4:] in self.root_domains: + return True + else: + if 'www.' + host in self.root_domains: + return True + return False + + def _host_okay_lenient(self, host): + """Check if a host should be crawled, lenient version. + + This compares the last two components of the host. + """ + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + return host in self.root_domains + + def add_url(self, url, max_redirect=None): + """Add a URL to the todo list if not seen before.""" + if self.exclude and re.search(self.exclude, url): + return False + parts = urllib.parse.urlparse(url) + if parts.scheme not in ('http', 'https'): + self.log(2, 'skipping non-http scheme in', url) + return False + host, port = urllib.parse.splitport(parts.netloc) + if not self.host_okay(host): + self.log(2, 'skipping non-root host in', url) + return False + if max_redirect is None: + max_redirect = self.max_redirect + if url in self.todo or url in self.busy or url in self.done: + return False + self.log(1, 'adding', url, max_redirect) + self.todo[url] = max_redirect + return True + + @asyncio.coroutine + def crawl(self): + """Run the crawler until all finished.""" + with (yield from self.termination): + while self.todo or self.busy: + if self.todo: + url, max_redirect = self.todo.popitem() + fetcher = Fetcher(self.log, url, + crawler=self, + max_redirect=max_redirect, + max_tries=self.max_tries, + ) + self.busy[url] = fetcher + fetcher.task = asyncio.Task(self.fetch(fetcher)) + else: + yield from self.termination.wait() + self.t1 = time.time() + + @asyncio.coroutine + def fetch(self, fetcher): + """Call the Fetcher's fetch(), with a limit on concurrency. + + Once this returns, move the fetcher from busy to done. + """ + url = fetcher.url + with (yield from self.governor): + try: + yield from fetcher.fetch() # Fetcher gonna fetch. + finally: + # Force GC of the task, so the error is logged. + fetcher.task = None + with (yield from self.termination): + self.done[url] = fetcher + del self.busy[url] + self.termination.notify() + + def report(self, file=None): + """Print a report on all completed URLs.""" + if self.t1 is None: + self.t1 = time.time() + dt = self.t1 - self.t0 + if dt and self.max_tasks: + speed = len(self.done) / dt / self.max_tasks + else: + speed = 0 + stats = Stats() + print('*** Report ***', file=file) + try: + show = [] + show.extend(self.done.items()) + show.extend(self.busy.items()) + show.sort() + for url, fetcher in show: + fetcher.report(stats, file=file) + except KeyboardInterrupt: + print('\nInterrupted', file=file) + print('Finished', len(self.done), + 'urls in %.3f secs' % dt, + '(max_tasks=%d)' % self.max_tasks, + '(%.3f urls/sec/task)' % speed, + file=file) + stats.report(file=file) + print('Todo:', len(self.todo), file=file) + print('Busy:', len(self.busy), file=file) + print('Done:', len(self.done), file=file) + print('Date:', time.ctime(), 'local time', file=file) + + +def main(): + """Main program. + + Parse arguments, set up event loop, run crawler, print report. + """ + args = ARGS.parse_args() + if not args.roots: + print('Use --help for command line help') + return + + log = Logger(args.level) + + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + asyncio.set_event_loop(loop) + elif args.select: + loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + + roots = {fix_url(root) for root in args.roots} + + crawler = Crawler(log, + roots, exclude=args.exclude, + strict=args.strict, + max_redirect=args.max_redirect, + max_tries=args.max_tries, + max_tasks=args.max_tasks, + max_pool=args.max_pool, + ) + try: + loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. + except KeyboardInterrupt: + sys.stderr.flush() + print('\nInterrupted\n') + finally: + crawler.report() + crawler.close() + loop.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py new file mode 100644 index 00000000..88124efe --- /dev/null +++ b/examples/echo_client_tulip.py @@ -0,0 +1,20 @@ +import asyncio + +END = b'Bye-bye!\n' + +@asyncio.coroutine +def echo_client(): + reader, writer = yield from asyncio.open_connection('localhost', 8000) + writer.write(b'Hello, world\n') + writer.write(b'What a fine day it is.\n') + writer.write(END) + while True: + line = yield from reader.readline() + print('received:', line) + if line == END or not line: + break + writer.close() + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_client()) +loop.close() diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py new file mode 100644 index 00000000..8167e540 --- /dev/null +++ b/examples/echo_server_tulip.py @@ -0,0 +1,20 @@ +import asyncio + +@asyncio.coroutine +def echo_server(): + yield from asyncio.start_server(handle_connection, 'localhost', 8000) + +@asyncio.coroutine +def handle_connection(reader, writer): + while True: + data = yield from reader.read(8192) + if not data: + break + writer.write(data) + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_server()) +try: + loop.run_forever() +finally: + loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py new file mode 100644 index 00000000..180fcf26 --- /dev/null +++ b/examples/fetch0.py @@ -0,0 +1,35 @@ +"""Simplest possible HTTP client.""" + +import sys + +from asyncio import * + + +@coroutine +def fetch(): + r, w = yield from open_connection('python.org', 80) + request = 'GET / HTTP/1.0\r\n\r\n' + print('>', request, file=sys.stderr) + w.write(request.encode('latin-1')) + while True: + line = yield from r.readline() + line = line.decode('latin-1').rstrip() + if not line: + break + print('<', line, file=sys.stderr) + print(file=sys.stderr) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch()) + finally: + loop.close() + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch1.py b/examples/fetch1.py new file mode 100644 index 00000000..8dbb6e47 --- /dev/null +++ b/examples/fetch1.py @@ -0,0 +1,78 @@ +"""Fetch one URL and write its content to stdout. + +This version adds URL parsing (including SSL) and a Response object. +""" + +import sys +import urllib.parse + +from asyncio import * + + +class Response: + + def __init__(self, verbose=True): + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def read(self, reader): + @coroutine + def getline(): + return (yield from reader.readline()).decode('latin-1').rstrip() + status_line = yield from getline() + if self.verbose: print('<', status_line, file=sys.stderr) + self.http_version, status, self.reason = status_line.split(None, 2) + self.status = int(status) + while True: + header_line = yield from getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + +@coroutine +def fetch(url, verbose=True): + parts = urllib.parse.urlparse(url) + if parts.scheme == 'http': + ssl = False + elif parts.scheme == 'https': + ssl = True + else: + print('URL must use http or https.') + sys.exit(1) + port = parts.port + if port is None: + port = 443 if ssl else 80 + path = parts.path or '/' + if parts.query: + path += '?' + parts.query + request = 'GET %s HTTP/1.0\r\n\r\n' % path + if verbose: + print('>', request, file=sys.stderr, end='') + r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + w.write(request.encode('latin-1')) + response = Response(verbose) + yield from response.read(r) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch2.py b/examples/fetch2.py new file mode 100644 index 00000000..7617b59b --- /dev/null +++ b/examples/fetch2.py @@ -0,0 +1,141 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a Request object. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from asyncio import * + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @coroutine + def connect(self): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) + self.reader, self.writer = yield from open_connection(self.hostname, + self.port, + ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % + (self.writer.get_extra_info('peername'),), + file=sys.stderr) + + def putline(self, line): + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + if self.verbose: print('>', request, file=sys.stderr) + self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + if self.verbose: print('>', line, file=sys.stderr) + self.putline(line) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + return (yield from self.reader.readline()).decode('latin-1').rstrip() + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + if self.verbose: print('<', status_line, file=sys.stderr) + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True): + request = Request(url, verbose) + yield from request.connect() + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fetch3.py b/examples/fetch3.py new file mode 100644 index 00000000..9419afd2 --- /dev/null +++ b/examples/fetch3.py @@ -0,0 +1,230 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a primitive connection pool, redirect following and +chunked transfer-encoding. It also supports a --iocp flag. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from asyncio import * + + +class ConnectionPool: + # TODO: Locking? Close idle connections? + + def __init__(self, verbose=False): + self.verbose = verbose + self.connections = {} # {(host, port, ssl): (reader, writer)} + + def close(self): + for _, writer in self.connections.values(): + writer.close() + + @coroutine + def open_connection(self, host, port, ssl): + port = port or (443 if ssl else 80) + ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + if self.verbose: + print('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs)), + file=sys.stderr) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = self.connections.get(key) + if conn: + reader, writer = conn + if reader._eof: + self.connections.pop(key) + continue + if self.verbose: + print('* Reusing pooled connection', key, file=sys.stderr) + return conn + reader, writer = yield from open_connection(host, port, ssl=ssl) + host, port, *_ = writer.get_extra_info('peername') + key = host, port, ssl + self.connections[key] = reader, writer + if self.verbose: + print('* New connection', key, file=sys.stderr) + return reader, writer + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def connect(self, pool): + self.vprint('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + self.reader, self.writer = \ + yield from pool.open_connection(self.hostname, + self.port, + ssl=self.ssl) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('peername'),)) + + @coroutine + def putline(self, line): + self.vprint('>', line) + self.writer.write(line.encode('latin-1') + b'\r\n') + ##yield from self.writer.drain() + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + yield from self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def getline(self): + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.vprint('<', line) + return line + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=None): + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=None): + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding', '').lower() == 'chunked': + blocks = [] + size = -1 + while size: + size_header = yield from self.reader.readline() + if not size_header: + break + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + body = b''.join(blocks) + else: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True, max_redirect=10): + pool = ConnectionPool(verbose) + try: + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) + return body + finally: + pool.close() + + +def main(): + if '--iocp' in sys.argv: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fuzz_as_completed.py b/examples/fuzz_as_completed.py new file mode 100644 index 00000000..123fbf1b --- /dev/null +++ b/examples/fuzz_as_completed.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +"""Fuzz tester for as_completed(), by Glenn Langford.""" + +import asyncio +import itertools +import random +import sys + +@asyncio.coroutine +def sleeper(time): + yield from asyncio.sleep(time) + return time + +@asyncio.coroutine +def watcher(tasks,delay=False): + res = [] + for t in asyncio.as_completed(tasks): + r = yield from t + res.append(r) + if delay: + # simulate processing delay + process_time = random.random() / 10 + yield from asyncio.sleep(process_time) + #print(res) + #assert(sorted(res) == res) + if sorted(res) != res: + print('FAIL', res) + print('------------') + else: + print('.', end='') + sys.stdout.flush() + +loop = asyncio.get_event_loop() + +print('Pass 1') +# All permutations of discrete task running times must be returned +# by as_completed in the correct order. +task_times = [0, 0.1, 0.2, 0.3, 0.4 ] # 120 permutations +for times in itertools.permutations(task_times): + tasks = [ asyncio.Task(sleeper(t)) for t in times ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 2') +# Longer task times, with randomized duplicates. 100 tasks each time. +longer_task_times = [x/10 for x in range(30)] +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:500]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:100] ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 3') +# Same as pass 2, but with a random processing delay (0 - 0.1s) after +# retrieving each future from as_completed and 200 tasks. This tests whether +# the order that callbacks are triggered is preserved through to the +# as_completed caller. +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:200]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:200] ] + loop.run_until_complete(asyncio.Task(watcher(tasks, delay=True))) + +print() +loop.close() diff --git a/examples/hello_callback.py b/examples/hello_callback.py new file mode 100644 index 00000000..7ccbea1e --- /dev/null +++ b/examples/hello_callback.py @@ -0,0 +1,17 @@ +"""Print 'Hello World' every two seconds, using a callback.""" + +import asyncio + + +def print_and_repeat(loop): + print('Hello World') + loop.call_later(2, print_and_repeat, loop) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + print_and_repeat(loop) + try: + loop.run_forever() + finally: + loop.close() diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py new file mode 100644 index 00000000..b9347aa8 --- /dev/null +++ b/examples/hello_coroutine.py @@ -0,0 +1,18 @@ +"""Print 'Hello World' every two seconds, using a coroutine.""" + +import asyncio + + +@asyncio.coroutine +def greet_every_two_seconds(): + while True: + print('Hello World') + yield from asyncio.sleep(2) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(greet_every_two_seconds()) + finally: + loop.close() diff --git a/examples/shell.py b/examples/shell.py new file mode 100644 index 00000000..7dc7caf3 --- /dev/null +++ b/examples/shell.py @@ -0,0 +1,50 @@ +"""Examples using create_subprocess_exec() and create_subprocess_shell().""" + +import asyncio +import signal +from asyncio.subprocess import PIPE + +@asyncio.coroutine +def cat(loop): + proc = yield from asyncio.create_subprocess_shell("cat", + stdin=PIPE, + stdout=PIPE) + print("pid: %s" % proc.pid) + + message = "Hello World!" + print("cat write: %r" % message) + + stdout, stderr = yield from proc.communicate(message.encode('ascii')) + print("cat read: %r" % stdout.decode('ascii')) + + exitcode = yield from proc.wait() + print("(exit code %s)" % exitcode) + +@asyncio.coroutine +def ls(loop): + proc = yield from asyncio.create_subprocess_exec("ls", + stdout=PIPE) + while True: + line = yield from proc.stdout.readline() + if not line: + break + print("ls>>", line.decode('ascii').rstrip()) + try: + proc.send_signal(signal.SIGINT) + except ProcessLookupError: + pass + +@asyncio.coroutine +def test_call(*args, timeout=None): + try: + proc = yield from asyncio.create_subprocess_exec(*args) + exitcode = yield from asyncio.wait_for(proc.wait(), timeout) + print("%s: exit code %s" % (' '.join(args), exitcode)) + except asyncio.TimeoutError: + print("timeout! (%.1f sec)" % timeout) + +loop = asyncio.get_event_loop() +loop.run_until_complete(cat(loop)) +loop.run_until_complete(ls(loop)) +loop.run_until_complete(test_call("bash", "-c", "sleep 3", timeout=1.0)) +loop.close() diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py new file mode 100644 index 00000000..b796d9b6 --- /dev/null +++ b/examples/simple_tcp_server.py @@ -0,0 +1,154 @@ +""" +Example of a simple TCP server that is written in (mostly) coroutine +style and uses asyncio.streams.start_server() and +asyncio.streams.open_connection(). + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format(idx+1, msg) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + send("repeat 5 hello") + msg = yield from recv() + assert msg == 'begin' + while True: + msg = yield from recv() + if msg == 'end': + break + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + try: + loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/sink.py b/examples/sink.py new file mode 100644 index 00000000..d362cbb2 --- /dev/null +++ b/examples/sink.py @@ -0,0 +1,94 @@ +"""Test service that accepts connections and reads all data off them.""" + +import argparse +import os +import sys + +from asyncio import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS with a self-signed cert') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--maxsize', action='store', dest='maxsize', + default=16*1024*1024, type=int, help='Max total data size') + +server = None +args = None + + +def dprint(*args): + print('sink:', *args, file=sys.stderr) + + +class Service(Protocol): + + def connection_made(self, tr): + dprint('connection from', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.total = 0 + + def data_received(self, data): + if data == b'stop': + dprint('stopping server') + server.close() + self.tr.close() + return + self.total += len(data) + dprint('received', len(data), 'bytes; total', self.total) + if self.total > args.maxsize: + dprint('closing due to too much data') + self.tr.close() + + def connection_lost(self, how): + dprint('closed', repr(how)) + + +@coroutine +def start(loop, host, port): + global server + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) + + server = yield from loop.create_server(Service, host, port, ssl=sslctx) + dprint('serving TLS' if sslctx else 'serving', + [s.getsockname() for s in server.sockets]) + yield from server.wait_closed() + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/source.py b/examples/source.py new file mode 100644 index 00000000..7fd11fb0 --- /dev/null +++ b/examples/source.py @@ -0,0 +1,100 @@ +"""Test client that connects and sends infinite data.""" + +import argparse +import sys + +from asyncio import * +from asyncio import test_utils + + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + +args = None + + +def dprint(*args): + print('source:', *args, file=sys.stderr) + + +class Client(Protocol): + + total = 0 + + def connection_made(self, tr): + dprint('connecting to', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.lost = False + self.loop = get_event_loop() + self.waiter = Future() + if args.stop: + self.tr.write(b'stop') + self.tr.close() + else: + self.data = b'x'*args.size + self.write_some_data() + + def write_some_data(self): + if self.lost: + dprint('lost already') + return + data = self.data + size = len(data) + self.total += size + dprint('writing', size, 'bytes; total', self.total) + self.tr.write(data) + self.loop.call_soon(self.write_some_data) + + def connection_lost(self, exc): + dprint('lost connection', repr(exc)) + self.lost = True + self.waiter.set_result(None) + + +@coroutine +def start(loop, host, port): + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + tr, pr = yield from loop.create_connection(Client, host, port, + ssl=sslctx) + dprint('tr =', tr) + dprint('pr =', pr) + yield from pr.waiter + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/source1.py b/examples/source1.py new file mode 100644 index 00000000..6802e963 --- /dev/null +++ b/examples/source1.py @@ -0,0 +1,98 @@ +"""Like source.py, but uses streams.""" + +import argparse +import sys + +from asyncio import * +from asyncio import test_utils + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + + +class Debug: + """A clever little class that suppresses repetitive messages.""" + + overwriting = False + label = 'stream1:' + + def print(self, *args): + if self.overwriting: + print(file=sys.stderr) + self.overwriting = 0 + print(self.label, *args, file=sys.stderr) + + def oprint(self, *args): + self.overwriting += 1 + end = '\n' + if self.overwriting >= 3: + if self.overwriting == 3: + print(self.label, '[...]', file=sys.stderr) + end = '\r' + print(self.label, *args, file=sys.stderr, end=end, flush=True) + + +@coroutine +def start(loop, args): + d = Debug() + total = 0 + sslctx = None + if args.tls: + d.print('using dummy SSLContext') + sslctx = test_utils.dummy_ssl_context() + r, w = yield from open_connection(args.host, args.port, ssl=sslctx) + d.print('r =', r) + d.print('w =', w) + if args.stop: + w.write(b'stop') + w.close() + else: + size = args.size + data = b'x'*size + try: + while True: + total += size + d.oprint('writing', size, 'bytes; total', total) + w.write(data) + f = w.drain() + if f: + d.print('pausing') + yield from f + except (ConnectionResetError, BrokenPipeError) as exc: + d.print('caught', repr(exc)) + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/stacks.py b/examples/stacks.py new file mode 100644 index 00000000..0b7e0b2c --- /dev/null +++ b/examples/stacks.py @@ -0,0 +1,44 @@ +"""Crude demo for print_stack().""" + + +from asyncio import * + + +@coroutine +def helper(r): + print('--- helper ---') + for t in Task.all_tasks(): + t.print_stack() + print('--- end helper ---') + line = yield from r.readline() + 1/0 + return line + +def doit(): + l = get_event_loop() + lr = l.run_until_complete + r, w = lr(open_connection('python.org', 80)) + t1 = async(helper(r)) + for t in Task.all_tasks(): t.print_stack() + print('---') + l._run_once() + for t in Task.all_tasks(): t.print_stack() + print('---') + w.write(b'GET /\r\n') + w.write_eof() + try: + lr(t1) + except Exception as e: + print('catching', e) + finally: + for t in Task.all_tasks(): + t.print_stack() + l.close() + + +def main(): + doit() + + +if __name__ == '__main__': + main() diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py new file mode 100644 index 00000000..d8a62420 --- /dev/null +++ b/examples/subprocess_attach_read_pipe.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a read pipe to a subprocess.""" +import asyncio +import os, sys + +code = """ +import os, sys +fd = int(sys.argv[1]) +os.write(fd, b'data') +os.close(fd) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(wfd)] + + pipe = open(rfd, 'rb', 0) + reader = asyncio.StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.connect_read_pipe(lambda: protocol, pipe) + + proc = yield from asyncio.create_subprocess_exec(*args, pass_fds={wfd}) + yield from proc.wait() + + os.close(wfd) + data = yield from reader.read() + print("read = %r" % data.decode()) + +loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py new file mode 100644 index 00000000..86148774 --- /dev/null +++ b/examples/subprocess_attach_write_pipe.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a write pipe to a subprocess.""" +import asyncio +import os, sys +from asyncio import subprocess + +code = """ +import os, sys +fd = int(sys.argv[1]) +data = os.read(fd, 1024) +sys.stdout.buffer.write(data) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(rfd)] + proc = yield from asyncio.create_subprocess_exec( + *args, + pass_fds={rfd}, + stdout=subprocess.PIPE) + + pipe = open(wfd, 'wb', 0) + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, + pipe) + transport.write(b'data') + + stdout, stderr = yield from proc.communicate() + print("stdout = %r" % stdout.decode()) + pipe.close() + +loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py new file mode 100644 index 00000000..745cb646 --- /dev/null +++ b/examples/subprocess_shell.py @@ -0,0 +1,87 @@ +"""Example writing to and reading from a subprocess at the same time using +tasks.""" + +import asyncio +import os +from asyncio.subprocess import PIPE + + +@asyncio.coroutine +def send_input(writer, input): + try: + for line in input: + print('sending', len(line), 'bytes') + writer.write(line) + d = writer.drain() + if d: + print('pause writing') + yield from d + print('resume writing') + writer.close() + except BrokenPipeError: + print('stdin: broken pipe error') + except ConnectionResetError: + print('stdin: connection reset error') + +@asyncio.coroutine +def log_errors(reader): + while True: + line = yield from reader.readline() + if not line: + break + print('ERROR', repr(line)) + +@asyncio.coroutine +def read_stdout(stdout): + while True: + line = yield from stdout.readline() + print('received', repr(line)) + if not line: + break + +@asyncio.coroutine +def start(cmd, input=None, **kwds): + kwds['stdout'] = PIPE + kwds['stderr'] = PIPE + if input is None and 'stdin' not in kwds: + kwds['stdin'] = None + else: + kwds['stdin'] = PIPE + proc = yield from asyncio.create_subprocess_shell(cmd, **kwds) + + tasks = [] + if input is not None: + tasks.append(send_input(proc.stdin, input)) + else: + print('No stdin') + if proc.stderr is not None: + tasks.append(log_errors(proc.stderr)) + else: + print('No stderr') + if proc.stdout is not None: + tasks.append(read_stdout(proc.stdout)) + else: + print('No stdout') + + if tasks: + # feed stdin while consuming stdout to avoid hang + # when stdin pipe is full + yield from asyncio.wait(tasks) + + exitcode = yield from proc.wait() + print("exit code: %s" % exitcode) + + +def main(): + if os.name == 'nt': + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(start( + 'sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..d743242a --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import asyncio +import sys +try: + import signal +except ImportError: + signal = None + + +class EchoServer(asyncio.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = asyncio.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = asyncio.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(asyncio.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + asyncio.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + asyncio.get_event_loop().stop() + + +def start_client(loop, host, port): + t = asyncio.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.create_server(EchoServer, host, port) + return loop.run_until_complete(f) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', + default=False, help='Use IOCP event loop') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + if args.iocp: + from asyncio import windows_events + loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + print ('Using backend: {0}'.format(loop.__class__.__name__)) + + if signal is not None and sys.platform != 'win32': + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + server = start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + try: + loop.run_forever() + finally: + if args.server: + server.close() + loop.close() diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py new file mode 100644 index 00000000..883ce6d3 --- /dev/null +++ b/examples/timing_tcp_server.py @@ -0,0 +1,168 @@ +""" +A variant of simple_tcp_server.py that measures the time it takes to +send N messages for a range of N. (This was O(N**2) in a previous +version of Tulip.) + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import time +import random + +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format( + idx+1, msg + 'x'*random.randint(10, 50)) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + Ns = list(range(100, 100000, 10000)) + times = [] + + for N in Ns: + t0 = time.time() + send("repeat {} hello world ".format(N)) + msg = yield from recv() + assert msg == 'begin' + while True: + msg = (yield from reader.readline()).decode("utf-8").rstrip() + if msg == 'end': + break + t1 = time.time() + dt = t1 - t0 + print("Time taken: {:.3f} seconds ({:.6f} per repetition)" + .format(dt, dt/N)) + times.append(dt) + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + try: + loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..93ac7e6b --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import asyncio +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def error_received(self, exc): + print('Error received:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def error_received(self, exc): + print('Error received:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = asyncio.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = asyncio.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + transport, server = loop.run_until_complete(t) + return transport + + +def start_client(loop, addr): + t = asyncio.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = asyncio.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + server = start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + try: + loop.run_forever() + finally: + if '--server' in sys.argv: + server.close() + loop.close() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..6842efbb --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1380 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, + TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + union { + /* Buffer used for reading: TYPE_READ and TYPE_ACCEPT */ + PyObject *read_buffer; + /* Buffer used for writing: TYPE_WRITE */ + Py_buffer write_buffer; + }; +} OverlappedObject; + +typedef struct { + OVERLAPPED *Overlapped; + HANDLE IocpHandle; + char Address[1]; +} WaitNamedPipeAndConnectContext; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Wait for a handle + */ + +struct PostCallbackData { + HANDLE CompletionPort; + LPOVERLAPPED Overlapped; +}; + +static VOID CALLBACK +PostToQueueCallback(PVOID lpParameter, BOOL TimerOrWaitFired) +{ + struct PostCallbackData *p = (struct PostCallbackData*) lpParameter; + + PostQueuedCompletionStatus(p->CompletionPort, TimerOrWaitFired, + 0, p->Overlapped); + /* ignore possible error! */ + PyMem_Free(p); +} + +PyDoc_STRVAR( + RegisterWaitWithQueue_doc, + "RegisterWaitWithQueue(Object, CompletionPort, Overlapped, Timeout)\n" + " -> WaitHandle\n\n" + "Register wait for Object; when complete CompletionPort is notified.\n"); + +static PyObject * +overlapped_RegisterWaitWithQueue(PyObject *self, PyObject *args) +{ + HANDLE NewWaitObject; + HANDLE Object; + ULONG Milliseconds; + struct PostCallbackData data, *pdata; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_POINTER F_DWORD, + &Object, + &data.CompletionPort, + &data.Overlapped, + &Milliseconds)) + return NULL; + + pdata = PyMem_Malloc(sizeof(struct PostCallbackData)); + if (pdata == NULL) + return SetFromWindowsErr(0); + + *pdata = data; + + if (!RegisterWaitForSingleObject( + &NewWaitObject, Object, (WAITORTIMERCALLBACK)PostToQueueCallback, + pdata, Milliseconds, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE)) + { + PyMem_Free(pdata); + return SetFromWindowsErr(0); + } + + return Py_BuildValue(F_HANDLE, NewWaitObject); +} + +PyDoc_STRVAR( + UnregisterWait_doc, + "UnregisterWait(WaitHandle) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWait(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &WaitHandle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWait(WaitHandle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Event functions -- currently only used by tests + */ + +PyDoc_STRVAR( + CreateEvent_doc, + "CreateEvent(EventAttributes, ManualReset, InitialState, Name)" + " -> Handle\n\n" + "Create an event. EventAttributes must be None.\n"); + +static PyObject * +overlapped_CreateEvent(PyObject *self, PyObject *args) +{ + PyObject *EventAttributes; + BOOL ManualReset; + BOOL InitialState; + Py_UNICODE *Name; + HANDLE Event; + + if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "Z", + &EventAttributes, &ManualReset, + &InitialState, &Name)) + return NULL; + + if (EventAttributes != Py_None) { + PyErr_SetString(PyExc_ValueError, "EventAttributes must be None"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + Event = CreateEventW(NULL, ManualReset, InitialState, Name); + Py_END_ALLOW_THREADS + + if (Event == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, Event); +} + +PyDoc_STRVAR( + SetEvent_doc, + "SetEvent(Handle) -> None\n\n" + "Set event.\n"); + +static PyObject * +overlapped_SetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = SetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + ResetEvent_doc, + "ResetEvent(Handle) -> None\n\n" + "Reset event.\n"); + +static PyObject * +overlapped_ResetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = ResetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, family) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "family should AF_INET or AF_INET6.\n"); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int Family; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &Family)) + return NULL; + + if (Family == AF_INET) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (Family == AF_INET6) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Windows equivalent of os.strerror() -- compare _ctypes/callproc.c + */ + +PyDoc_STRVAR( + FormatMessage_doc, + "FormatMessage(error_code) -> error_message\n\n" + "Return error message for an error code."); + +static PyObject * +overlapped_FormatMessage(PyObject *ignore, PyObject *args) +{ + DWORD code, n; + WCHAR *lpMsgBuf; + PyObject *res; + + if (!PyArg_ParseTuple(args, F_DWORD, &code)) + return NULL; + + n = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &lpMsgBuf, + 0, + NULL); + if (n) { + while (iswspace(lpMsgBuf[n-1])) + --n; + lpMsgBuf[n] = L'\0'; + res = Py_BuildValue("u", lpMsgBuf); + } else { + res = PyUnicode_FromFormat("unknown error code %u", code); + } + LocalFree(lpMsgBuf); + return res; +} + + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + switch (self->type) { + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + break; + case TYPE_WRITE: + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + break; + } + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED + || self->type == TYPE_WAIT_NAMED_PIPE_AND_CONNECT) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if ((self->type == TYPE_READ || self->type == TYPE_ACCEPT) && self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_ConnectNamedPipe_doc, + "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" + "Start overlapped wait for a client to connect."); + +static PyObject * +Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) +{ + HANDLE Pipe; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Pipe)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_CONNECT_NAMED_PIPE; + self->handle = Pipe; + + Py_BEGIN_ALLOW_THREADS + ret = ConnectNamedPipe(Pipe, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_PIPE_CONNECTED: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +/* Unfortunately there is no way to do an overlapped connect to a + pipe. We instead use WaitNamedPipe() and CreateFile() in a thread + pool thread. If a connection succeeds within a time limit (10 + seconds) then PostQueuedCompletionStatus() is used to return the + pipe handle to the completion port. */ + +static DWORD WINAPI +WaitNamedPipeAndConnectInThread(WaitNamedPipeAndConnectContext *ctx) +{ + HANDLE PipeHandle = INVALID_HANDLE_VALUE; + DWORD Start = GetTickCount(); + DWORD Deadline = Start + 10*1000; + DWORD Error = 0; + DWORD Timeout; + BOOL Success; + + for ( ; ; ) { + Timeout = Deadline - GetTickCount(); + if ((int)Timeout < 0) + break; + Success = WaitNamedPipe(ctx->Address, Timeout); + Error = Success ? ERROR_SUCCESS : GetLastError(); + switch (Error) { + case ERROR_SUCCESS: + PipeHandle = CreateFile(ctx->Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + if (PipeHandle == INVALID_HANDLE_VALUE) + continue; + break; + case ERROR_SEM_TIMEOUT: + continue; + } + break; + } + if (!PostQueuedCompletionStatus(ctx->IocpHandle, Error, + (ULONG_PTR)PipeHandle, ctx->Overlapped)) + CloseHandle(PipeHandle); + free(ctx); + return 0; +} + +PyDoc_STRVAR( + Overlapped_WaitNamedPipeAndConnect_doc, + "WaitNamedPipeAndConnect(addr, iocp_handle) -> Overlapped[pipe_handle]\n\n" + "Start overlapped connection to address, notifying iocp_handle when\n" + "finished"); + +static PyObject * +Overlapped_WaitNamedPipeAndConnect(OverlappedObject *self, PyObject *args) +{ + char *Address; + Py_ssize_t AddressLength; + HANDLE IocpHandle; + OVERLAPPED Overlapped; + BOOL ret; + DWORD err; + WaitNamedPipeAndConnectContext *ctx; + Py_ssize_t ContextLength; + + if (!PyArg_ParseTuple(args, "s#" F_HANDLE F_POINTER, + &Address, &AddressLength, &IocpHandle, &Overlapped)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + ContextLength = (AddressLength + + offsetof(WaitNamedPipeAndConnectContext, Address)); + ctx = calloc(1, ContextLength + 1); + if (ctx == NULL) + return PyErr_NoMemory(); + memcpy(ctx->Address, Address, AddressLength + 1); + ctx->Overlapped = &self->overlapped; + ctx->IocpHandle = IocpHandle; + + self->type = TYPE_WAIT_NAMED_PIPE_AND_CONNECT; + self->handle = NULL; + + Py_BEGIN_ALLOW_THREADS + ret = QueueUserWorkItem(WaitNamedPipeAndConnectInThread, ctx, + WT_EXECUTELONGFUNCTION); + Py_END_ALLOW_THREADS + + mark_as_completed(&self->overlapped); + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + if (!ret) + return SetFromWindowsErr(err); + Py_RETURN_NONE; +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, + METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {"WaitNamedPipeAndConnect", + (PyCFunction) Overlapped_WaitNamedPipeAndConnect, + METH_VARARGS, Overlapped_WaitNamedPipeAndConnect_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"FormatMessage", overlapped_FormatMessage, + METH_VARARGS, FormatMessage_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {"RegisterWaitWithQueue", overlapped_RegisterWaitWithQueue, + METH_VARARGS, RegisterWaitWithQueue_doc}, + {"UnregisterWait", overlapped_UnregisterWait, + METH_VARARGS, UnregisterWait_doc}, + {"CreateEvent", overlapped_CreateEvent, + METH_VARARGS, CreateEvent_doc}, + {"SetEvent", overlapped_SetEvent, + METH_VARARGS, SetEvent_doc}, + {"ResetEvent", overlapped_ResetEvent, + METH_VARARGS, ResetEvent_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); + WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/pypi.bat b/pypi.bat new file mode 100644 index 00000000..5218ace3 --- /dev/null +++ b/pypi.bat @@ -0,0 +1 @@ +c:\Python33\python.exe setup.py bdist_wheel upload diff --git a/run_aiotest.py b/run_aiotest.py new file mode 100644 index 00000000..8d6fa293 --- /dev/null +++ b/run_aiotest.py @@ -0,0 +1,14 @@ +import aiotest.run +import asyncio +import sys +if sys.platform == 'win32': + from asyncio.windows_utils import socketpair +else: + from socket import socketpair + +config = aiotest.TestConfig() +config.asyncio = asyncio +config.socketpair = socketpair +config.new_event_pool_policy = asyncio.DefaultEventLoopPolicy +config.call_soon_check_closed = True +aiotest.run.main(config) diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..e9bbdd8e --- /dev/null +++ b/runtests.py @@ -0,0 +1,302 @@ +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.test_events.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import random +import re +import sys +import unittest +import textwrap +import importlib.machinery +try: + import coverage +except ImportError: + coverage = None + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument('-r', '--randomize', action='store_true', + help='randomize test execution order.') +ARGS.add_argument('--seed', type=int, + help='random seed to reproduce a previous random run') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except unittest.SkipTest as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def randomize_tests(tests, seed): + if seed is None: + seed = random.randrange(10000000) + random.seed(seed) + print("Using random seed", seed) + random.shuffle(tests._tests) + + +class TestsFinder: + + def __init__(self, testsdir, includes=(), excludes=()): + self._testsdir = testsdir + self._includes = includes + self._excludes = excludes + self.find_available_tests() + + def find_available_tests(self): + """ + Find available test classes without instantiating them. + """ + self._test_factories = [] + mods = [mod for mod, _ in load_modules(self._testsdir)] + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + self._test_factories.append(getattr(mod, name)) + + def load_tests(self): + """ + Load test cases from the available test classes and apply + optional include / exclude filters. + """ + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test_factory in self._test_factories: + tests = loader.loadTestsFromTestCase(test_factory) + if self._includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in self._includes)] + if self._excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in self._excludes)] + suite.addTests(tests) + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def runtests(): + args = ARGS.parse_args() + + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + catchbreak = args.catchbreak + findleaks = args.findleaks + runner_factory = TestRunner if findleaks else unittest.TextTestRunner + + if args.coverage: + cov = coverage.coverage(branch=True, + source=['asyncio'], + ) + cov.start() + + logger = logging.getLogger() + if v == 0: + level = logging.CRITICAL + elif v == 1: + level = logging.ERROR + elif v == 2: + level = logging.WARNING + elif v == 3: + level = logging.INFO + elif v >= 4: + level = logging.DEBUG + logging.basicConfig(level=level) + + finder = TestsFinder(args.testsdir, includes, excludes) + if catchbreak: + installHandler() + import asyncio.coroutines + if asyncio.coroutines._DEBUG: + print("Run tests in debug mode") + else: + print("Run tests in release mode") + try: + if args.forever: + while True: + tests = finder.load_tests() + if args.randomize: + randomize_tests(tests, args.seed) + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + tests = finder.load_tests() + if args.randomize: + randomize_tests(tests, args.seed) + result = runner_factory(verbosity=v, + failfast=failfast).run(tests) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) + + +if __name__ == '__main__': + runtests() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..fcd3b6aa --- /dev/null +++ b/setup.py @@ -0,0 +1,34 @@ +import os +try: + from setuptools import setup, Extension +except ImportError: + # Use distutils.core as a fallback. + # We won't be able to build the Wheel file on Windows. + from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension( + 'asyncio._overlapped', ['overlapped.c'], libraries=['ws2_32'], + ) + extensions.append(ext) + +setup( + name="asyncio", + version="3.4.1", + + description="reference implementation of PEP 3156", + long_description=open("README").read(), + url="http://www.python.org/dev/peps/pep-3156/", + + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + ], + + packages=["asyncio"], + test_suite="runtests.runtests", + + ext_modules=extensions, +) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..006364bb --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,8 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..06449673 --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,11 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/keycert3.pem b/tests/keycert3.pem new file mode 100644 index 00000000..5bfa62c4 --- /dev/null +++ b/tests/keycert3.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP +jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM +9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ +aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe +yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j +y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+ +AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW +5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL +9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9 +1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT +DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh +1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m +JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3 +RnJdHOMXWem7/w== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443281 (0xb09264b1f2da21d1) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d: + 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb: + c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99: + 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c: + f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93: + 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23: + f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5: + af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6: + 21:82:a5:3c:88:e5:be:1b:b1 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a: + e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93: + f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13: + e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92: + d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59: + 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8: + ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1: + 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75: + 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96: + 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48: + 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a: + f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6: + 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41: + a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb: + fc:a9:94:71 +-----BEGIN CERTIFICATE----- +MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv +c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C +tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola +N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1 +TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR +iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG +xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo +5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv +mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF +YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh +2EJ36/yplHE= +-----END CERTIFICATE----- diff --git a/tests/pycacert.pem b/tests/pycacert.pem new file mode 100644 index 00000000..09b1f3e0 --- /dev/null +++ b/tests/pycacert.pem @@ -0,0 +1,78 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 12723342612721443280 (0xb09264b1f2da21d0) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Jan 2 19:47:07 2023 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2: + 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4: + e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f: + e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f: + 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf: + 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d: + a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3: + e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4: + 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf: + 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c: + e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6: + c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a: + cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01: + 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87: + 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f: + 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14: + e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4: + c5:4d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + X509v3 Authority Key Identifier: + keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6: + 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d: + a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95: + 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17: + 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c: + 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4: + fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7: + 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24: + 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33: + 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61: + ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f: + 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64: + b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb: + 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3: + 5e:58:c8:9e +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx +OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV +q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/ +AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA +Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni +0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx +6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w +HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2 +2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB +AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4 +QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1 +Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O +JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR +f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf +9mmvtk57HVjsO6lTo15YyJ4= +-----END CERTIFICATE----- diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 00000000..6a1e3f3c --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 00000000..edfea8dc --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/tests/ssl_cert.pem b/tests/ssl_cert.pem new file mode 100644 index 00000000..47a7d7e3 --- /dev/null +++ b/tests/ssl_cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV +BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u +IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw +MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH +Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k +YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7 +6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt +pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd +BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G +lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1 +CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX +-----END CERTIFICATE----- diff --git a/tests/ssl_key.pem b/tests/ssl_key.pem new file mode 100644 index 00000000..3fd3bbd5 --- /dev/null +++ b/tests/ssl_key.pem @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm +LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0 +ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP +USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt +CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq +SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK +UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y +BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ +ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5 +oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik +eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F +0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS +x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/ +SPIXQuT8RMPDVNQ= +-----END PRIVATE KEY----- diff --git a/tests/test_base_events.py b/tests/test_base_events.py new file mode 100644 index 00000000..db9d732c --- /dev/null +++ b/tests/test_base_events.py @@ -0,0 +1,1183 @@ +"""Tests for base_events.py""" + +import errno +import logging +import math +import socket +import sys +import time +import unittest +from unittest import mock +from test.script_helper import assert_python_ok +from test.support import IPV6_ENABLED, gc_collect + +import asyncio +from asyncio import base_events +from asyncio import constants +from asyncio import test_utils + + +MOCK_ANY = mock.ANY +PY34 = sys.version_info >= (3, 4) + + +class BaseEventLoopTests(test_utils.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () + self.set_event_loop(self.loop) + + def test_not_implemented(self): + m = mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test_close(self): + self.assertFalse(self.loop.is_closed()) + self.loop.close() + self.assertTrue(self.loop.is_closed()) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = asyncio.Future(loop=self.loop) + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + + def test__add_callback_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_cancelled_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = mock.Mock() + self.loop.run_in_executor = mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, asyncio.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, asyncio.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = mock.Mock() + delay = 0.1 + + when = self.loop.time() + delay + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + dt = self.loop.time() - t0 + + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - 0.050, dt) + # tolerate a difference of +800 ms because some Python buildbots + # are really slow + self.assertLessEqual(dt, 0.9, dt) + + def test_assert_is_current_event_loop(self): + def cb(): + pass + + other_loop = base_events.BaseEventLoop() + other_loop._selector = mock.Mock() + asyncio.set_event_loop(other_loop) + + # raise RuntimeError if the event loop is different in debug mode + self.loop.set_debug(True) + with self.assertRaises(RuntimeError): + self.loop.call_soon(cb) + with self.assertRaises(RuntimeError): + self.loop.call_later(60, cb) + with self.assertRaises(RuntimeError): + self.loop.call_at(self.loop.time() + 60, cb) + + # check disabled if debug mode is disabled + self.loop.set_debug(False) + self.loop.call_soon(cb) + self.loop.call_later(60, cb) + self.loop.call_at(self.loop.time() + 60, cb) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.Handle(cb, (), self.loop), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.TimerHandle(10, cb, (), self.loop)) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, asyncio.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + f = asyncio.Future(loop=self.loop) + executor = mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + self.loop) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + self.loop) + + h1.cancel() + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.5 < t < 10.5, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + def test_set_debug(self): + self.loop.set_debug(True) + self.assertTrue(self.loop.get_debug()) + self.loop.set_debug(False) + self.assertFalse(self.loop.get_debug()) + + @mock.patch('asyncio.base_events.logger') + def test__run_once_logging(self, m_logger): + def slow_select(timeout): + # Sleep a bit longer than a second to avoid timer resolution issues. + time.sleep(1.1) + return [] + + # logging needs debug flag + self.loop.set_debug(True) + + # Log to INFO level if timeout > 1.0 sec. + self.loop._selector.select = slow_select + self.loop._process_events = mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) + + def fast_select(timeout): + time.sleep(0.001) + return [] + + self.loop._selector.select = fast_select + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + self.loop) + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test__run_once_cancelled_event_cleanup(self): + self.loop._process_events = mock.Mock() + + self.assertTrue( + 0 < base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION < 1.0) + + def cb(): + pass + + # Set up one "blocking" event that will not be cancelled to + # ensure later cancelled events do not make it to the head + # of the queue and get cleaned. + not_cancelled_count = 1 + self.loop.call_later(3000, cb) + + # Add less than threshold (base_events._MIN_SCHEDULED_TIMER_HANDLES) + # cancelled handles, ensure they aren't removed + + cancelled_count = 2 + for x in range(2): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Add some cancelled events that will be at head and removed + cancelled_count += 2 + for x in range(2): + h = self.loop.call_later(100, cb) + h.cancel() + + # This test is invalid if _MIN_SCHEDULED_TIMER_HANDLES is too low + self.assertLessEqual(cancelled_count + not_cancelled_count, + base_events._MIN_SCHEDULED_TIMER_HANDLES) + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.loop._run_once() + + cancelled_count -= 2 + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + # Need enough events to pass _MIN_CANCELLED_TIMER_HANDLES_FRACTION + # so that deletion of cancelled events will occur on next _run_once + add_cancel_count = int(math.ceil( + base_events._MIN_SCHEDULED_TIMER_HANDLES * + base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION)) + 1 + + add_not_cancel_count = max(base_events._MIN_SCHEDULED_TIMER_HANDLES - + add_cancel_count, 0) + + # Add some events that will not be cancelled + not_cancelled_count += add_not_cancel_count + for x in range(add_not_cancel_count): + self.loop.call_later(3600, cb) + + # Add enough cancelled events + cancelled_count += add_cancel_count + for x in range(add_cancel_count): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Ensure all handles are still scheduled + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + self.loop._run_once() + + # Ensure cancelled events were removed + self.assertEqual(len(self.loop._scheduled), not_cancelled_count) + + # Ensure only uncancelled events remain scheduled + self.assertTrue(all([not x._cancelled for x in self.loop._scheduled])) + + def test_run_until_complete_type_error(self): + self.assertRaises(TypeError, + self.loop.run_until_complete, 'blah') + + def test_run_until_complete_loop(self): + task = asyncio.Future(loop=self.loop) + other_loop = self.new_test_loop() + self.assertRaises(ValueError, + other_loop.run_until_complete, task) + + def test_subprocess_exec_invalid_args(self): + args = [sys.executable, '-c', 'pass'] + + # missing program parameter (empty args) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol) + + # expected multiple arguments, not a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, args) + + # program arguments must be strings, not int + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, sys.executable, 123) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, bufsize=4096) + + def test_subprocess_shell_invalid_args(self): + # expected a string, not an int or a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 123) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass']) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) + + def test_default_exc_handler_callback(self): + self.loop._process_events = mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = mock.Mock() + + @asyncio.coroutine + def zero_error_coro(): + yield from asyncio.sleep(0.01, loop=self.loop) + 1/0 + + # Test Future.__del__ + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.async(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + if PY34: + # Future.__del__ in Python 3.4 logs error with + # an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + else: + # futures._TracebackLogger logs only textual traceback + log.error.assert_called_with( + test_utils.MockPattern( + '.*exception was never retrieved.*ZeroDiv'), + exc_info=False) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + handle = self.loop.call_soon(zero_error) + self.loop._run_once() + return handle + + self.loop.set_debug(True) + self.loop._process_events = mock.Mock() + + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + handle = run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': handle, + 'source_traceback': handle._source_traceback, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + assert not mock_handler.called + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = mock.Mock() + + self.loop.set_exception_handler(handler) + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = mock.Mock() + _process_events = mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default exception handler', + exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + + def test_env_var_debug(self): + code = '\n'.join(( + 'import asyncio', + 'loop = asyncio.get_event_loop()', + 'print(loop.get_debug())')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + + def test_create_task(self): + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def test(): + pass + + class EventLoop(base_events.BaseEventLoop): + def create_task(self, coro): + return MyTask(coro, loop=loop) + + loop = EventLoop() + self.set_event_loop(loop) + + coro = test() + task = asyncio.async(coro, loop=loop) + self.assertIsInstance(task, MyTask) + + # make warnings quiet + task._log_destroy_pending = False + coro.close() + + def test_run_forever_keyboard_interrupt(self): + # Python issue #22601: ensure that the temporary task created by + # run_forever() consumes the KeyboardInterrupt and so don't log + # a warning + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + self.loop.close() + gc_collect() + + self.assertFalse(self.loop.call_exception_handler.called) + + def test_run_until_complete_baseexception(self): + # Python issue #22429: run_until_complete() must not schedule a pending + # call to stop() if the future raised a BaseException + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + + def func(): + self.loop.stop() + func.called = True + func.called = False + try: + self.loop.call_soon(func) + self.loop.run_forever() + except KeyboardInterrupt: + pass + self.assertTrue(func.called) + + +class MyProto(asyncio.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_timeout(self, m_socket): + # Ensure that the socket is closed on timeout + sock = mock.Mock() + m_socket.socket.return_value = sock + + def getaddrinfo(*args, **kw): + fut = asyncio.Future(loop=self.loop) + addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '', + ('127.0.0.1', 80)) + fut.set_result([addr]) + return fut + self.loop.getaddrinfo = getaddrinfo + + with mock.patch.object(self.loop, 'sock_connect', + side_effect=asyncio.TimeoutError): + coro = self.loop.create_connection(MyProto, '127.0.0.1', 80) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @asyncio.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_default(self): + self.loop.getaddrinfo = mock.Mock() + + def mock_getaddrinfo(*args, **kwds): + f = asyncio.Future(loop=self.loop) + f.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, '', ('1.2.3.4', 80))]) + return f + + self.loop.getaddrinfo.side_effect = mock_getaddrinfo + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.return_value = () + self.loop._make_ssl_transport = mock.Mock() + + class _SelectorTransportMock: + _sock = None + + def get_extra_info(self, key): + return mock.Mock() + + def close(self): + self._sock.close() + + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, + **kwds): + waiter.set_result(None) + transport = _SelectorTransportMock() + transport._sock = sock + return transport + + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport + ANY = mock.ANY + # First try the default server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org') + # Next try an explicit server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='perl.com') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com') + # Finally try an explicit empty server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='') + + def test_create_connection_no_ssl_server_hostname_errors(self): + # When not using ssl, server_hostname must be None. + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='python.org') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_errors(self): + # When using ssl, server_hostname may be None if host is non-empty. + coro = self.loop.create_connection(MyProto, '', 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + sock = socket.socket() + coro = self.loop.create_connection(MyProto, None, None, + ssl=True, sock=sock) + self.addCleanup(sock.close) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.getaddrinfo._is_coroutine = False + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + m_socket.getaddrinfo._is_coroutine = False + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @mock.patch('asyncio.base_events.logger') + def test_accept_connection_exception(self, m_log): + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + self.loop.remove_reader = mock.Mock() + self.loop.call_later = mock.Mock() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(m_log.error.called) + self.assertFalse(sock.close.called) + self.loop.remove_reader.assert_called_with(10) + self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, + # self.loop._start_serving + mock.ANY, + MyProto, sock, None, None) + + def test_call_coroutine(self): + @asyncio.coroutine + def simple_coroutine(): + pass + + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + with self.assertRaises(TypeError): + self.loop.call_soon(func) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(TypeError): + self.loop.call_later(60, func) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, func) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, func) + + @mock.patch('asyncio.base_events.logger') + def test_log_slow_callbacks(self, m_logger): + def stop_loop_cb(loop): + loop.stop() + + @asyncio.coroutine + def stop_loop_coro(loop): + yield from () + loop.stop() + + asyncio.set_event_loop(self.loop) + self.loop.set_debug(True) + self.loop.slow_callback_duration = 0.0 + + # slow callback + self.loop.call_soon(stop_loop_cb, self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing took .* seconds$") + + # slow task + asyncio.async(stop_loop_coro(self.loop), loop=self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing took .* seconds$") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 00000000..6644fbea --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,2306 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import platform +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +from unittest import mock +import weakref +from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR + + +import asyncio +from asyncio import proactor_events +from asyncio import selector_events +from asyncio import test_utils + + +def data_file(filename): + if hasattr(support, 'TEST_HOME_DIR'): + fullname = os.path.join(support.TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + + +def osx_tiger(): + """Return True if the platform is Mac OS 10.4 or older.""" + if sys.platform != 'darwin': + return False + version = platform.mac_ver()[0] + version = tuple(map(int, version.split('.'))) + return version < (10, 5) + + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') + + +class MyBaseProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = asyncio.Future(loop=loop) + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(asyncio.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + if 'EOF' not in self.state: + self.state.append('EOF') # It is okay if EOF is missed. + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(asyncio.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(asyncio.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = asyncio.Future(loop=loop) + self.completed = asyncio.Future(loop=loop) + self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: asyncio.Event(loop=loop), + 2: asyncio.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_run_until_complete_stopped(self): + @asyncio.coroutine + def cb(): + self.loop.stop() + yield from asyncio.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + r.setblocking(False) + bytes_read = bytearray() + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.extend(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3) + self.loop.call_soon(w.send, b'def') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(bytes_read, b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def writer(data): + w.send(data) + self.loop.stop() + + data = b'x' * 1024 + self.loop.add_writer(w.fileno(), writer, data) + self.loop.run_forever() + + self.assertTrue(self.loop.remove_writer(w.fileno())) + self.assertFalse(self.loop.remove_writer(w.fileno())) + + w.close() + read = r.recv(len(data) * 2) + r.close() + self.assertEqual(read, data) + + def _basetest_sock_client_ops(self, httpd, sock): + if not isinstance(self.loop, proactor_events.BaseProactorEventLoop): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) + + # test in non-blocking mode + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + self._basetest_sock_client_ops(httpd, sock) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_until(self.loop, lambda: caught) + + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.loop.call_later(0.5, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def _basetest_create_connection(self, connection_fut, check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertIs(pr.transport, tr) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut, check_sockname) + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _basetest_create_ssl_connection(self, connection_fut, + check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _test_create_ssl_connection(self, httpd, create_connection, + check_sockname=True): + conn_fut = create_connection(ssl=test_utils.dummy_ssl_context()) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + # ssl.Purpose was introduced in Python 3.4 + if hasattr(ssl, 'Purpose'): + def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, + cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() + + # With ssl=True, ssl.create_default_context() should be called + with mock.patch('ssl.create_default_context', + side_effect=_dummy_ssl_create_context) as m: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + self.assertEqual(m.call_count, 1) + + # With the real ssl.create_default_context(), certificate + # validation will fail + with self.assertRaises(ssl.SSLError) as cm: + conn_fut = create_connection(ssl=True) + # Ignore the "SSL handshake failed" log in debug mode + with test_utils.disable_logger(): + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_connection, + lambda: MyProto(loop=self.loop), + *httpd.address) + self._test_create_ssl_connection(httpd, create_connection) + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_unix_connection, + lambda: MyProto(loop=self.loop), httpd.address, + server_hostname='127.0.0.1') + + self._test_create_ssl_connection(httpd, create_connection, + check_sockname) + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = support.find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = MyProto(self.loop) + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto(loop=self.loop) + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_path_socket_error(self): + proto = MyProto(loop=self.loop) + sock = socket.socket() + with sock: + f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock) + with self.assertRaisesRegex(ValueError, + 'path and sock can not be specified ' + 'at the same time'): + self.loop.run_until_complete(f) + + def _create_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) + server = self.loop.run_until_complete(f) + + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_connection(MyBaseProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # incorrect server_hostname + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with test_utils.disable_logger(): + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + + def test_create_server_sock(self): + proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_server_dual_stack(self): + f_proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = support.find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = asyncio.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + self.assertIsInstance(s_transport, asyncio.Transport) + self.assertIsInstance(server, TestMyDatagramProto) + self.assertEqual('INITIALIZED', server.state) + self.assertIs(server.transport, s_transport) + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertIsInstance(transport, asyncio.Transport) + self.assertIsInstance(client, MyDatagramProto) + self.assertEqual('INITIALIZED', client.state) + self.assertIs(client.transport, transport) + + transport.sendto(b'xxx') + test_utils.run_until(self.loop, lambda: server.nbytes) + self.assertEqual(3, server.nbytes) + test_utils.run_until(self.loop, lambda: client.nbytes) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + loop.close() + self.skipTest('loop is not a BaseSelectorEventLoop') + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = MyReadPipeProto(loop=self.loop) + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 1) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 + @support.requires_freebsd_version(8) + def test_read_pty_output(self): + proto = MyReadPipeProto(loop=self.loop) + + master, slave = os.openpty() + master_read_obj = io.open(master, 'rb', 0) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(lambda: proto, + master_read_obj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(slave, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes) + self.assertEqual(1, proto.nbytes) + + os.write(slave, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(slave) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + + data = bytearray() + def reader(data): + chunk = os.read(rpipe, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_until(self.loop, lambda: reader(data) >= 5) + self.assertEqual(b'12345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + rsock, wsock = test_utils.socketpair() + rsock.setblocking(False) + pipeobj = io.open(wsock.detach(), 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024)) + self.assertEqual(b'1', data) + + rsock.close() + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + def test_write_pty(self): + master, slave = os.openpty() + slave_write_obj = io.open(slave, 'wb', 0) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + + data = bytearray() + def reader(data): + chunk = os.read(master, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1, + timeout=10) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_until(self.loop, lambda: reader(data) >= 5, + timeout=10) + self.assertEqual(b'12345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(master) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + if ov is not None: + self.assertTrue(ov.pending) + + @asyncio.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except asyncio.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = asyncio.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(asyncio.CancelledError, f.result) + if ov is not None: + self.assertFalse(ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + def test_timeout_rounding(self): + def _run_once(): + self.loop._run_once_counter += 1 + orig_run_once() + + orig_run_once = self.loop._run_once + self.loop._run_once_counter = 0 + self.loop._run_once = _run_once + + @asyncio.coroutine + def wait(): + loop = self.loop + yield from asyncio.sleep(1e-2, loop=loop) + yield from asyncio.sleep(1e-4, loop=loop) + yield from asyncio.sleep(1e-6, loop=loop) + yield from asyncio.sleep(1e-8, loop=loop) + yield from asyncio.sleep(1e-10, loop=loop) + + self.loop.run_until_complete(wait()) + # The ideal number of call is 12, but on some platforms, the selector + # may sleep at little bit less than timeout depending on the resolution + # of the clock used by the kernel. Tolerate a few useless calls on + # these platforms. + self.assertLessEqual(self.loop._run_once_counter, 20, + {'clock_resolution': self.loop._clock_resolution, + 'selector': self.loop._selector.__class__.__name__}) + + def test_sock_connect_address(self): + addresses = [(socket.AF_INET, ('www.python.org', 80))] + if support.IPV6_ENABLED: + addresses.extend(( + (socket.AF_INET6, ('www.python.org', 80)), + (socket.AF_INET6, ('www.python.org', 80, 0, 0)), + )) + + for family, address in addresses: + for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): + sock = socket.socket(family, sock_type) + with sock: + sock.setblocking(False) + connect = self.loop.sock_connect(sock, address) + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(connect) + self.assertIn('address must be resolved', + str(cm.exception)) + + def test_remove_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) + loop.add_reader(r, callback) + loop.add_writer(w, callback) + loop.close() + self.assertFalse(loop.remove_reader(r)) + self.assertFalse(loop.remove_writer(w)) + + def test_add_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) + loop.close() + with self.assertRaises(RuntimeError): + loop.add_reader(r, callback) + with self.assertRaises(RuntimeError): + loop.add_writer(w, callback) + + def test_close_running_event_loop(self): + @asyncio.coroutine + def close_loop(loop): + self.loop.close() + + coro = close_loop(self.loop) + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(coro) + + def test_close(self): + self.loop.close() + + @asyncio.coroutine + def test(): + pass + + func = lambda: False + coro = test() + self.addCleanup(coro.close) + + # operation blocked when the loop is closed + with self.assertRaises(RuntimeError): + self.loop.run_forever() + with self.assertRaises(RuntimeError): + fut = asyncio.Future(loop=self.loop) + self.loop.run_until_complete(fut) + with self.assertRaises(RuntimeError): + self.loop.call_soon(func) + with self.assertRaises(RuntimeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(RuntimeError): + self.loop.call_later(1.0, func) + with self.assertRaises(RuntimeError): + self.loop.call_at(self.loop.time() + .0, func) + with self.assertRaises(RuntimeError): + self.loop.run_in_executor(None, func) + with self.assertRaises(RuntimeError): + self.loop.create_task(coro) + with self.assertRaises(RuntimeError): + self.loop.add_signal_handler(signal.SIGTERM, func) + + +class SubprocessTestsMixin: + + def check_terminated(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + def check_killed(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_subprocess_exec(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + def test_subprocess_interactive(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_shell(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo Python') + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python') + self.assertEqual(proto.data[2], b'') + + def test_subprocess_exitcode(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + def test_subprocess_close_after_finish(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + def test_subprocess_kill(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.check_killed(proto.returncode) + + def test_subprocess_terminate(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.terminate() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_subprocess_send_signal(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + + def test_subprocess_stderr(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + def test_subprocess_stderr_redirect_to_stdout(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + def test_subprocess_close_client_stream(self): + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + if sys.platform != 'win32': + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + else: + # After closing the read-end of a pipe, writing to the + # write-end using os.write() fails with errno==EINVAL and + # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using + # WriteFile() we get ERROR_BROKEN_PIPE as expected.) + self.assertEqual(b'ERR:OSError', proto.data[2]) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_wait_no_same_group(self): + # start the new process in a new session + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + _, proto = yield self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + def test_subprocess_exec_invalid_args(self): + @asyncio.coroutine + def connect(**kwds): + yield from self.loop.subprocess_exec( + asyncio.SubprocessProtocol, + 'pwd', **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=True)) + + def test_subprocess_shell_invalid_args(self): + @asyncio.coroutine + def connect(cmd=None, **kwds): + if not cmd: + cmd = 'pwd' + yield from self.loop.subprocess_shell( + asyncio.SubprocessProtocol, + cmd, **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(['ls', '-l'])) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=False)) + + +if sys.platform == 'win32': + + class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + + def test_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + + def test_remove_fds_after_closing(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") +else: + from asyncio import selectors + + class UnixEventLoopTestsMixin(EventLoopTestsMixin): + def setUp(self): + super().setUp() + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + asyncio.set_child_watcher(watcher) + + def tearDown(self): + asyncio.set_child_watcher(None) + super().tearDown() + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + # Issue #20667: KqueueEventLoopTests.test_read_pty_output() + # hangs on OpenBSD 5.5 + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') + def test_read_pty_output(self): + super().test_read_pty_output() + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + super().test_write_pty() + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +def noop(*args): + pass + + +class HandleTests(test_utils.TestCase): + + def setUp(self): + self.loop = mock.Mock() + self.loop.get_debug.return_value = True + + def test_handle(self): + def callback(*args): + return args + + args = () + h = asyncio.Handle(callback, args, self.loop) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + h.cancel() + self.assertTrue(h._cancelled) + + def test_handle_from_handle(self): + def callback(*args): + return args + h1 = asyncio.Handle(callback, (), loop=self.loop) + self.assertRaises( + AssertionError, asyncio.Handle, h1, (), self.loop) + + def test_callback_with_exception(self): + def callback(): + raise ValueError() + + self.loop = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + h = asyncio.Handle(callback, (), self.loop) + h._run() + + self.loop.call_exception_handler.assert_called_with({ + 'message': test_utils.MockPattern('Exception in callback.*'), + 'exception': mock.ANY, + 'handle': h, + 'source_traceback': h._source_traceback, + }) + + def test_handle_weakref(self): + wd = weakref.WeakValueDictionary() + h = asyncio.Handle(lambda: None, (), self.loop) + wd['h'] = h # Would fail without __weakref__ slot. + + def test_handle_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + # decorated function + cb = asyncio.coroutine(noop) + h = asyncio.Handle(cb, (), self.loop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # partial function + cb = functools.partial(noop, 1, 2) + h = asyncio.Handle(cb, (3,), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + # partial method + if sys.version_info >= (3, 4): + method = HandleTests.test_handle_repr + cb = functools.partialmethod(method) + filename, lineno = test_utils.get_function_source(method) + h = asyncio.Handle(cb, (), self.loop) + + cb_regex = r'' + cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex) + regex = (r'^$' + % (cb_regex, re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + def test_handle_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # double cancellation won't overwrite _repr + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + def test_handle_source_traceback(self): + loop = asyncio.get_event_loop_policy().new_event_loop() + loop.set_debug(True) + self.set_event_loop(loop) + + def check_source_traceback(h): + lineno = sys._getframe(1).f_lineno - 1 + self.assertIsInstance(h._source_traceback, list) + self.assertEqual(h._source_traceback[-1][:3], + (__file__, + lineno, + 'test_handle_source_traceback')) + + # call_soon + h = loop.call_soon(noop) + check_source_traceback(h) + + # call_soon_threadsafe + h = loop.call_soon_threadsafe(noop) + check_source_traceback(h) + + # call_later + h = loop.call_later(0, noop) + check_source_traceback(h) + + # call_at + h = loop.call_later(0, noop) + check_source_traceback(h) + + +class TimerTests(unittest.TestCase): + + def setUp(self): + self.loop = mock.Mock() + + def test_hash(self): + when = time.monotonic() + h = asyncio.TimerHandle(when, lambda: False, (), + mock.Mock()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = (1, 2, 3) + when = time.monotonic() + h = asyncio.TimerHandle(when, callback, args, mock.Mock()) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + # cancel + h.cancel() + self.assertTrue(h._cancelled) + self.assertIsNone(h._callback) + self.assertIsNone(h._args) + + # when cannot be None + self.assertRaises(AssertionError, + asyncio.TimerHandle, None, callback, args, + self.loop) + + def test_timer_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.TimerHandle(123, noop, (), self.loop) + src = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' % src) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + def test_timer_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.TimerHandle(123, noop, (), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when, callback, (), self.loop) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = asyncio.Handle(callback, (), self.loop) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + loop = asyncio.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.is_closed) + self.assertRaises( + NotImplementedError, loop.close) + self.assertRaises( + NotImplementedError, loop.create_task, None) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + self.assertRaises( + NotImplementedError, loop.set_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.default_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.call_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.get_debug) + self.assertRaises( + NotImplementedError, loop.set_debug, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = mock.Mock() + p = asyncio.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = asyncio.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.error_received(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = asyncio.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = asyncio.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + self.assertRaises(NotImplementedError, policy.get_child_watcher) + self.assertRaises(NotImplementedError, policy.set_child_watcher, + object()) + + def test_get_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + self.assertIsNone(policy._local._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + + self.assertIs(policy._local._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_calls_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + with mock.patch.object( + policy, "set_event_loop", + wraps=policy.set_event_loop) as m_set_event_loop: + + loop = policy.get_event_loop() + + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) + + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = asyncio.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(AssertionError, policy.get_event_loop) + + @mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = asyncio.DefaultEventLoopPolicy() + self.assertRaises(AssertionError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = asyncio.get_event_loop_policy() + self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, asyncio.set_event_loop_policy, object()) + + old_policy = asyncio.get_event_loop_policy() + + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_futures.py b/tests/test_futures.py new file mode 100644 index 00000000..371d3518 --- /dev/null +++ b/tests/test_futures.py @@ -0,0 +1,461 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import re +import sys +import threading +import unittest +from test import support +from unittest import mock + +import asyncio +from asyncio import test_utils + + +def _fakefunc(f): + return f + +def first_cb(): + pass + +def last_cb(): + pass + + +class FutureTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_initial_state(self): + f = asyncio.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + asyncio.set_event_loop(self.loop) + f = asyncio.Future() + self.assertIs(f._loop, self.loop) + + def test_constructor_positional(self): + # Make sure Future doesn't accept a positional argument + self.assertRaises(TypeError, asyncio.Future, 42) + + def test_cancel(self): + f = asyncio.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(asyncio.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.exception) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception_class(self): + f = asyncio.Future(loop=self.loop) + f.set_exception(RuntimeError) + self.assertIsInstance(f.exception(), RuntimeError) + + def test_yield_from_twice(self): + f = asyncio.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_future_repr(self): + self.loop.set_debug(True) + f_pending_debug = asyncio.Future(loop=self.loop) + frame = f_pending_debug._source_traceback[-1] + self.assertEqual(repr(f_pending_debug), + '' + % (frame[0], frame[1])) + f_pending_debug.cancel() + + self.loop.set_debug(False) + f_pending = asyncio.Future(loop=self.loop) + self.assertEqual(repr(f_pending), '') + f_pending.cancel() + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), '') + + f_result = asyncio.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), '') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), '') + self.assertIs(f_exception.exception(), exc) + + def func_repr(func): + filename, lineno = test_utils.get_function_source(func) + text = '%s() at %s:%s' % (func.__qualname__, filename, lineno) + return re.escape(text) + + f_one_callbacks = asyncio.Future(loop=self.loop) + f_one_callbacks.add_done_callback(_fakefunc) + fake_repr = func_repr(_fakefunc) + self.assertRegex(repr(f_one_callbacks), + r'' % fake_repr) + f_one_callbacks.cancel() + self.assertEqual(repr(f_one_callbacks), + '') + + f_two_callbacks = asyncio.Future(loop=self.loop) + f_two_callbacks.add_done_callback(first_cb) + f_two_callbacks.add_done_callback(last_cb) + first_repr = func_repr(first_cb) + last_repr = func_repr(last_cb) + self.assertRegex(repr(f_two_callbacks), + r'' + % (first_repr, last_repr)) + + f_many_callbacks = asyncio.Future(loop=self.loop) + f_many_callbacks.add_done_callback(first_cb) + for i in range(8): + f_many_callbacks.add_done_callback(_fakefunc) + f_many_callbacks.add_done_callback(last_cb) + cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr) + self.assertRegex(repr(f_many_callbacks), + r'' % cb_regex) + f_many_callbacks.cancel() + self.assertEqual(repr(f_many_callbacks), + '') + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = asyncio.Future(loop=self.loop) + f.set_result(10) + + newf = asyncio.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = asyncio.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = asyncio.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = asyncio.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_abandoned(self, m_log): + fut = asyncio.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, asyncio.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = asyncio.Future(loop=self.loop) + f2 = asyncio.wrap_future(f1) + self.assertIs(f1, f2) + + @mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + def test_wrap_future_cancel(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(f1.cancelled()) + self.assertTrue(f2.cancelled()) + + def test_wrap_future_cancel2(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f1.set_result(42) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertFalse(f1.cancelled()) + self.assertEqual(f1.result(), 42) + self.assertTrue(f2.cancelled()) + + def test_future_source_traceback(self): + self.loop.set_debug(True) + + future = asyncio.Future(loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(future._source_traceback, list) + self.assertEqual(future._source_traceback[-1][:3], + (__file__, + lineno, + 'test_future_source_traceback')) + + @mock.patch('asyncio.base_events.logger') + def check_future_exception_never_retrieved(self, debug, m_log): + self.loop.set_debug(debug) + + def memory_error(): + try: + raise MemoryError() + except BaseException as exc: + return exc + exc = memory_error() + + future = asyncio.Future(loop=self.loop) + if debug: + source_traceback = future._source_traceback + future.set_exception(exc) + future = None + test_utils.run_briefly(self.loop) + support.gc_collect() + + if sys.version_info >= (3, 4): + if debug: + frame = source_traceback[-1] + regex = (r'^Future exception was never retrieved\n' + r'future: \n' + r'source_traceback: Object created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)$' + ).format(filename=re.escape(frame[0]), lineno=frame[1]) + else: + regex = (r'^Future exception was never retrieved\n' + r'future: $' + ) + exc_info = (type(exc), exc, exc.__traceback__) + m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) + else: + if debug: + frame = source_traceback[-1] + regex = (r'^Future/Task exception was never retrieved\n' + r'Future/Task created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ).format(filename=re.escape(frame[0]), lineno=frame[1]) + else: + regex = (r'^Future/Task exception was never retrieved\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ) + m_log.error.assert_called_once_with(mock.ANY, exc_info=False) + message = m_log.error.call_args[0][0] + self.assertRegex(message, re.compile(regex, re.DOTALL)) + + def test_future_exception_never_retrieved(self): + self.check_future_exception_never_retrieved(False) + + def test_future_exception_never_retrieved_debug(self): + self.check_future_exception_never_retrieved(True) + + def test_set_result_unless_cancelled(self): + fut = asyncio.Future(loop=self.loop) + fut.cancel() + fut._set_result_unless_cancelled(2) + self.assertTrue(fut.cancelled()) + + +class FutureDoneCallbackTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return asyncio.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_locks.py b/tests/test_locks.py new file mode 100644 index 00000000..dda4577a --- /dev/null +++ b/tests/test_locks.py @@ -0,0 +1,858 @@ +"""Tests for lock.py""" + +import unittest +from unittest import mock +import re + +import asyncio +from asyncio import test_utils + + +STR_RGX_REPR = ( + r'^<(?P.*?) object at (?P
.*?)' + r'\[(?P' + r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?' + r')\]>\Z' +) +RGX_REPR = re.compile(STR_RGX_REPR) + + +class LockTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + lock = asyncio.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = asyncio.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + lock = asyncio.Lock() + self.assertIs(lock._loop, self.loop) + + def test_repr(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + @asyncio.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + def test_lock(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = asyncio.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @asyncio.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = asyncio.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = asyncio.Future(loop=self.loop) + ta = asyncio.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = asyncio.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = asyncio.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = asyncio.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = asyncio.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_cant_reuse(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + # This spells "yield from lock" outside a generator. + cm = self.loop.run_until_complete(acquire_lock()) + with cm: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + with self.assertRaises(AttributeError): + with cm: + pass + + def test_context_manager_no_yield(self): + lock = asyncio.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(lock.locked()) + + +class EventTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + ev = asyncio.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = asyncio.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + ev = asyncio.Event() + self.assertIs(ev._loop, self.loop) + + def test_repr(self): + ev = asyncio.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + match = RGX_REPR.match(repr(ev)) + self.assertEqual(match.group('extras'), 'unset') + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + self.assertTrue(RGX_REPR.match(repr(ev))) + + ev._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(ev)) + self.assertTrue(RGX_REPR.match(repr(ev))) + + def test_wait(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @asyncio.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @asyncio.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = asyncio.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = asyncio.Event(loop=self.loop) + + wait = asyncio.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = asyncio.Event(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = asyncio.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + cond = asyncio.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = asyncio.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + cond = asyncio.Condition() + self.assertIs(cond._loop, self.loop) + + def test_wait(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = asyncio.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = asyncio.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = asyncio.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = asyncio.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = asyncio.Condition(loop=self.loop) + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + def test_repr(self): + cond = asyncio.Condition(loop=self.loop) + self.assertTrue('unlocked' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + self.loop.run_until_complete(cond.acquire()) + self.assertTrue('locked' in repr(cond)) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + def test_context_manager(self): + cond = asyncio.Condition(loop=self.loop) + + @asyncio.coroutine + def acquire_cond(): + return (yield from cond) + + with self.loop.run_until_complete(acquire_cond()): + self.assertTrue(cond.locked()) + + self.assertFalse(cond.locked()) + + def test_context_manager_no_yield(self): + cond = asyncio.Condition(loop=self.loop) + + try: + with cond: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(cond.locked()) + + def test_explicit_lock(self): + lock = asyncio.Lock(loop=self.loop) + cond = asyncio.Condition(lock, loop=self.loop) + + self.assertIs(cond._lock, lock) + self.assertIs(cond._loop, lock._loop) + + def test_ambiguous_loops(self): + loop = self.new_test_loop() + self.addCleanup(loop.close) + + lock = asyncio.Lock(loop=self.loop) + with self.assertRaises(ValueError): + asyncio.Condition(lock, loop=loop) + + +class SemaphoreTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + sem = asyncio.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = asyncio.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + sem = asyncio.Semaphore() + self.assertIs(sem._loop, self.loop) + + def test_initial_value_zero(self): + sem = asyncio.Semaphore(0, loop=self.loop) + self.assertTrue(sem.locked()) + + def test_repr(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + self.assertTrue(RGX_REPR.match(repr(sem))) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + self.assertTrue('waiters' not in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + def test_semaphore(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, asyncio.Semaphore, -1) + + def test_acquire(self): + sem = asyncio.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @asyncio.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @asyncio.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = asyncio.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + self.loop.run_until_complete(t4) + + def test_acquire_cancel(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = asyncio.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = asyncio.BoundedSemaphore(loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + def test_context_manager_no_yield(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + try: + with sem: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py new file mode 100644 index 00000000..9e9b41a4 --- /dev/null +++ b/tests/test_proactor_events.py @@ -0,0 +1,574 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +from unittest import mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +class ProactorSocketTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = False + tr._fatal_error = mock.Mock() + tr._force_close = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_write(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') + + def test_write_no_data(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = mock.Mock() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, b'data') + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = bytearray(b'data') + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @mock.patch('asyncio.proactor_events.logger') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._fatal_error = mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, None) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._write_fut = mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @mock.patch('asyncio.base_events.logger') + def test_fatal_error(self, m_logging): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._force_close = mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.error.called) + + def test_force_close(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + read_fut = tr._read_fut = mock.Mock() + write_fut = tr._write_fut = mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + + def test_call_connection_lost(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + tr.close() + + def test_pause_resume_reading(self): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause_reading() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + + def pause_writing_transport(self, high): + tr = _ProactorSocketTransport( + self.loop, self.sock, self.protocol) + self.addCleanup(tr.close) + + tr.set_write_buffer_limits(high=high) + + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + self.assertFalse(self.protocol.resume_writing.called) + return tr + + def test_pause_resume_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk, must pause writing + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'large data') + self.loop._run_once() + self.assertTrue(self.protocol.pause_writing.called) + + # flush the buffer + fut.set_result(None) + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.resume_writing.called) + + def test_pause_writing_2write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (3 <= 4) + fut1 = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut1 + tr.write(b'123') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_pause_writing_3write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (1 <= 4) + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'1') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 1) + self.assertFalse(self.protocol.pause_writing.called) + + # second short write, the buffer is not full (3 <= 4) + tr.write(b'23') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_dont_pause_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk which completes immedialty, + # it should not pause writing + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + self.loop._proactor.send.return_value = fut + tr.write(b'very large data') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + + +class BaseProactorEventLoopTests(test_utils.TestCase): + + def setUp(self): + self.sock = mock.Mock(socket.socket) + self.proactor = mock.Mock() + + self.ssock, self.csock = mock.Mock(), mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + self.set_event_loop(self.loop, cleanup=False) + + @mock.patch.object(BaseProactorEventLoop, '_call_soon') + @mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, _call_soon): + ssock, csock = socketpair.return_value = ( + mock.Mock(), mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + _call_soon.assert_called_with(loop._loop_self_reading, (), + check_loop=False) + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + def test_close(self): + self.loop._close_self_pipe = mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) + self.assertIsInstance(tr, _ProactorSocketTransport) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = mock.Mock() + self.proactor.recv.side_effect = OSError() + self.assertRaises(OSError, self.loop._loop_self_reading) + self.assertTrue(self.loop.close.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'\0') + + def test_process_events(self): + self.loop._process_events([]) + + @mock.patch('asyncio.base_events.logger') + def test_create_server(self, m_log): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = mock.Mock() + fut.result.return_value = (mock.Mock(), mock.Mock()) + + make_tr = self.loop._make_socket_transport = mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.error.called) + + def test_create_server_cancel(self): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_queues.py b/tests/test_queues.py new file mode 100644 index 00000000..3d4ac51d --- /dev/null +++ b/tests/test_queues.py @@ -0,0 +1,476 @@ +"""Tests for queues.py""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import test_utils + + +class _QueueTestBase(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + q = asyncio.Queue(loop=loop) + self.assertTrue(fn(q).startswith('= (3, 4)) +PY35 = (sys.version_info >= (3, 5)) + + +@asyncio.coroutine +def coroutine_function(): + pass + + +def format_coroutine(qualname, state, src, source_traceback, generator=False): + if generator: + state = '%s' % state + else: + state = '%s, defined' % state + if source_traceback is not None: + frame = source_traceback[-1] + return ('coro=<%s() %s at %s> created at %s:%s' + % (qualname, state, src, frame[0], frame[1])) + else: + return 'coro=<%s() %s at %s>' % (qualname, state, src) + + +class Dummy: + + def __repr__(self): + return '' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_task_class(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + t = asyncio.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.run_until_complete(t) + loop.close() + + def test_async_coroutine(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + t = asyncio.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.run_until_complete(t) + loop.close() + + def test_async_future(self): + f_orig = asyncio.Future(loop=self.loop) + f_orig.set_result('ko') + + f = asyncio.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + f = asyncio.async(f_orig, loop=loop) + + loop.close() + + f = asyncio.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t_orig = asyncio.Task(notmuch(), loop=self.loop) + t = asyncio.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + t = asyncio.async(t_orig, loop=loop) + + loop.close() + + t = asyncio.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + asyncio.async('ok') + + def test_task_repr(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def notmuch(): + yield from [] + return 'abc' + + # test coroutine function + self.assertEqual(notmuch.__name__, 'notmuch') + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr..notmuch') + self.assertEqual(notmuch.__module__, __name__) + + filename, lineno = test_utils.get_function_source(notmuch) + src = "%s:%s" % (filename, lineno) + + # test coroutine object + gen = notmuch() + if coroutines._DEBUG or PY35: + coro_qualname = 'TaskTests.test_task_repr..notmuch' + else: + coro_qualname = 'notmuch' + self.assertEqual(gen.__name__, 'notmuch') + if PY35: + self.assertEqual(gen.__qualname__, + coro_qualname) + + # test pending Task + t = asyncio.Task(gen, loop=self.loop) + t.add_done_callback(Dummy()) + + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, generator=True) + self.assertEqual(repr(t), + '()]>' % coro) + + # test cancelling Task + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), + '()]>' % coro) + + # test cancelled Task + self.assertRaises(asyncio.CancelledError, + self.loop.run_until_complete, t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + '' % coro) + + # test finished Task + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + "" % coro) + + def test_task_repr_coro_decorator(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def notmuch(): + # notmuch() function doesn't use yield from: it will be wrapped by + # @coroutine decorator + return 123 + + # test coroutine function + self.assertEqual(notmuch.__name__, 'notmuch') + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr_coro_decorator..notmuch') + self.assertEqual(notmuch.__module__, __name__) + + # test coroutine object + gen = notmuch() + if coroutines._DEBUG or PY35: + # On Python >= 3.5, generators now inherit the name of the + # function, as expected, and have a qualified name (__qualname__ + # attribute). + coro_name = 'notmuch' + coro_qualname = 'TaskTests.test_task_repr_coro_decorator..notmuch' + else: + # On Python < 3.5, generators inherit the name of the code, not of + # the function. See: http://bugs.python.org/issue21205 + coro_name = coro_qualname = 'coro' + self.assertEqual(gen.__name__, coro_name) + if PY35: + self.assertEqual(gen.__qualname__, coro_qualname) + + # test repr(CoroWrapper) + if coroutines._DEBUG: + # format the coroutine object + if coroutines._DEBUG: + filename, lineno = test_utils.get_function_source(notmuch) + frame = gen._source_traceback[-1] + coro = ('%s() running, defined at %s:%s, created at %s:%s' + % (coro_qualname, filename, lineno, + frame[0], frame[1])) + else: + code = gen.gi_code + coro = ('%s() running at %s:%s' + % (coro_qualname, code.co_filename, code.co_firstlineno)) + + self.assertEqual(repr(gen), '' % coro) + + # test pending Task + t = asyncio.Task(gen, loop=self.loop) + t.add_done_callback(Dummy()) + + # format the coroutine object + if coroutines._DEBUG: + src = '%s:%s' % test_utils.get_function_source(notmuch) + else: + code = gen.gi_code + src = '%s:%s' % (code.co_filename, code.co_firstlineno) + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, + generator=not coroutines._DEBUG) + self.assertEqual(repr(t), + '()]>' % coro) + self.loop.run_until_complete(t) + + def test_task_repr_wait_for(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def wait_for(fut): + return (yield from fut) + + fut = asyncio.Future(loop=self.loop) + task = asyncio.Task(wait_for(fut), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertRegex(repr(task), + '' % re.escape(repr(fut))) + + fut.set_result(None) + self.loop.run_until_complete(task) + + def test_task_basics(self): + @asyncio.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @asyncio.coroutine + def inner1(): + return 42 + + @asyncio.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def task(): + yield from asyncio.sleep(10.0, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @asyncio.coroutine + def task(): + yield + yield + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + return 42 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + fut3 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + pass + res = yield from fut3 + return res + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + @asyncio.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from asyncio.sleep(100, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + self.assertRaises( + asyncio.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + x = 0 + waiters = [] + + @asyncio.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(asyncio.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = asyncio.Task(task(), loop=loop) + with self.assertRaises(RuntimeError) as cm: + loop.run_until_complete(t) + self.assertEqual(str(cm.exception), + 'Event loop stopped before Future completed.') + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + t.cancel() + self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t) + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + + loop = self.new_test_loop(gen) + + foo_running = None + + @asyncio.coroutine + def foo(): + nonlocal foo_running + foo_running = True + try: + yield from asyncio.sleep(0.2, loop=loop) + finally: + foo_running = False + return 'done' + + fut = asyncio.Task(foo(), loop=loop) + + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop)) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(foo_running, False) + + def test_wait_for_blocking(self): + loop = self.new_test_loop() + + @asyncio.coroutine + def coro(): + return 'done' + + res = loop.run_until_complete(asyncio.wait_for(coro(), + timeout=None, + loop=loop)) + self.assertEqual(res, 'done') + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def foo(): + yield from asyncio.sleep(0.2, loop=loop) + return 'done' + + asyncio.set_event_loop(loop) + try: + fut = asyncio.Task(foo(), loop=loop) + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.01)) + finally: + asyncio.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertTrue(fut.done()) + self.assertTrue(fut.cancelled()) + + def test_wait_for_race_condition(self): + + def gen(): + yield 0.1 + yield 0.1 + yield 0.1 + + loop = self.new_test_loop(gen) + + fut = asyncio.Future(loop=loop) + task = asyncio.wait_for(fut, timeout=0.2, loop=loop) + loop.call_later(0.1, fut.set_result, "ok") + res = loop.run_until_complete(task) + self.assertEqual(res, "ok") + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + asyncio.set_event_loop(loop) + res = loop.run_until_complete( + asyncio.Task(foo(), loop=loop)) + + self.assertEqual(res, 42) + + def test_wait_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('test') + + task = asyncio.Task( + asyncio.wait([c, c, coro('spam')], loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + asyncio.wait(set(), loop=self.loop)) + + # -1 is an invalid return_when value + sleep_coro = asyncio.sleep(10.0, loop=self.loop) + wait_coro = asyncio.wait([sleep_coro], return_when=-1, loop=self.loop) + self.assertRaises(ValueError, + self.loop.run_until_complete, wait_coro) + + sleep_coro.close() + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + yield + yield + + a = asyncio.Task(coro1(), loop=self.loop) + b = asyncio.Task(coro2(), loop=self.loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + # first_exception, task already has exception + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = self.new_test_loop(gen) + + # first_exception, exception during waiting + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + yield from asyncio.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = asyncio.Task(sleeper(), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + asyncio.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = self.new_test_loop(gen) + # disable "slow callback" warning + loop.slow_callback_duration = 1.0 + completed = set() + time_shifted = False + + @asyncio.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from asyncio.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + yield + yield 0 + yield 0 + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.1, 'a', loop=loop) + b = asyncio.sleep(0.15, 'b', loop=loop) + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop): + if values: + loop.advance_time(0.02) + try: + v = yield from f + values.append((1, v)) + except asyncio.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertIsInstance(res[1][1], asyncio.TimeoutError) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed_with_unused_timeout(self): + + def gen(): + yield + yield 0 + yield 0.01 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.01, 'a', loop=loop) + + @asyncio.coroutine + def foo(): + for f in asyncio.as_completed([a], timeout=1, loop=loop): + v = yield from f + self.assertEqual(v, 'a') + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = asyncio.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_as_completed_duplicate_coroutines(self): + + @asyncio.coroutine + def coro(s): + return s + + @asyncio.coroutine + def runner(): + result = [] + c = coro('ham') + for f in asyncio.as_completed([c, c, coro('spam')], + loop=self.loop): + result.append((yield from f)) + return result + + fut = asyncio.Task(runner(), loop=self.loop) + self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleeper(dt, arg): + yield from asyncio.sleep(dt/2, loop=loop) + res = yield from asyncio.sleep(dt/2, arg, loop=loop) + return res + + t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(delay, callback, *args): + nonlocal handle + handle = orig_call_later(delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleep(dt): + yield from asyncio.sleep(dt, loop=loop) + + @asyncio.coroutine + def doit(): + sleeper = asyncio.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except asyncio.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro(): + yield from fut + + task = asyncio.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @asyncio.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = asyncio.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @asyncio.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(asyncio.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @asyncio.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = asyncio.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @asyncio.coroutine + def notmutch(): + raise BaseException() + + task = asyncio.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(10, loop=loop) + + base_exc = BaseException() + + @asyncio.coroutine + def notmutch(): + try: + yield from sleeper() + except asyncio.CancelledError: + raise base_exc + + task = asyncio.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(asyncio.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(asyncio.iscoroutinefunction(fn1)) + + @asyncio.coroutine + def fn2(): + yield + self.assertTrue(asyncio.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @asyncio.coroutine + def coro(): + yield + + @asyncio.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @asyncio.coroutine + def func(): + return 'test' + + self.assertTrue(asyncio.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(asyncio.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def func(): + return fut + + @asyncio.coroutine + def coro(): + fut.set_result('test') + + t1 = asyncio.Task(func(), loop=self.loop) + t2 = asyncio.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + def test_current_task(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + @asyncio.coroutine + def coro(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task) + + task = asyncio.Task(coro(self.loop), loop=self.loop) + self.loop.run_until_complete(task) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + def test_current_task_with_interleaving_tasks(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro1(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + yield from fut1 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + fut2.set_result(True) + + @asyncio.coroutine + def coro2(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + fut1.set_result(True) + yield from fut2 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + + task1 = asyncio.Task(coro1(self.loop), loop=self.loop) + task2 = asyncio.Task(coro2(self.loop), loop=self.loop) + + self.loop.run_until_complete(asyncio.wait((task1, task2), + loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except asyncio.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @asyncio.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except asyncio.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + d, p = yield from asyncio.wait([inner()], loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(asyncio.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + yield from asyncio.shield(inner(), loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + parent = asyncio.gather(child1, child2, loop=self.loop) + outer = asyncio.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + inner1 = asyncio.shield(child1, loop=self.loop) + inner2 = asyncio.shield(child2, loop=self.loop) + parent = asyncio.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), asyncio.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + def test_as_completed_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # as_completed() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(fut, loop=self.loop)) + coro = coroutine_function() + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(coro, loop=self.loop)) + coro.close() + + def test_wait_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # wait() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(fut, loop=self.loop)) + coro = coroutine_function() + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(coro, loop=self.loop)) + coro.close() + + # wait() expects at least a future + self.assertRaises(ValueError, self.loop.run_until_complete, + asyncio.wait([], loop=self.loop)) + + def test_corowrapper_mocks_generator(self): + + def check(): + # A function that asserts various things. + # Called twice, with different debug flag values. + + @asyncio.coroutine + def coro(): + # The actual coroutine. + self.assertTrue(gen.gi_running) + yield from fut + + # A completed Future used to run the coroutine. + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + + # Call the coroutine. + gen = coro() + + # Check some properties. + self.assertTrue(asyncio.iscoroutine(gen)) + self.assertIsInstance(gen.gi_frame, types.FrameType) + self.assertFalse(gen.gi_running) + self.assertIsInstance(gen.gi_code, types.CodeType) + + # Run it. + self.loop.run_until_complete(gen) + + # The frame should have changed. + self.assertIsNone(gen.gi_frame) + + # Save debug flag. + old_debug = asyncio.coroutines._DEBUG + try: + # Test with debug flag cleared. + asyncio.coroutines._DEBUG = False + check() + + # Test with debug flag set. + asyncio.coroutines._DEBUG = True + check() + + finally: + # Restore original debug flag. + asyncio.coroutines._DEBUG = old_debug + + def test_yield_from_corowrapper(self): + old_debug = asyncio.coroutines._DEBUG + asyncio.coroutines._DEBUG = True + try: + @asyncio.coroutine + def t1(): + return (yield from t2()) + + @asyncio.coroutine + def t2(): + f = asyncio.Future(loop=self.loop) + asyncio.Task(t3(f), loop=self.loop) + return (yield from f) + + @asyncio.coroutine + def t3(f): + f.set_result((1, 2, 3)) + + task = asyncio.Task(t1(), loop=self.loop) + val = self.loop.run_until_complete(task) + self.assertEqual(val, (1, 2, 3)) + finally: + asyncio.coroutines._DEBUG = old_debug + + def test_yield_from_corowrapper_send(self): + def foo(): + a = yield + return a + + def call(arg): + cw = asyncio.coroutines.CoroWrapper(foo(), foo) + cw.send(None) + try: + cw.send(arg) + except StopIteration as ex: + return ex.args[0] + else: + raise AssertionError('StopIteration was expected') + + self.assertEqual(call((1, 2)), (1, 2)) + self.assertEqual(call('spam'), 'spam') + + def test_corowrapper_weakref(self): + wd = weakref.WeakValueDictionary() + def foo(): yield from [] + cw = asyncio.coroutines.CoroWrapper(foo(), foo) + wd['cw'] = cw # Would fail without __weakref__ slot. + cw.gen = None # Suppress warning from __del__. + + @unittest.skipUnless(PY34, + 'need python 3.4 or later') + def test_log_destroyed_pending_task(self): + @asyncio.coroutine + def kill_me(loop): + future = asyncio.Future(loop=loop) + yield from future + # at this point, the only reference to kill_me() task is + # the Task._wakeup() method in future._callbacks + raise Exception("code never reached") + + mock_handler = mock.Mock() + self.loop.set_debug(True) + self.loop.set_exception_handler(mock_handler) + + # schedule the task + coro = kill_me(self.loop) + task = asyncio.async(coro, loop=self.loop) + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task}) + + # execute the task so it waits for future + self.loop._run_once() + self.assertEqual(len(self.loop._ready), 0) + + # remove the future used in kill_me(), and references to the task + del coro.gi_frame.f_locals['future'] + coro = None + source_traceback = task._source_traceback + task = None + + # no more reference to kill_me() task: the task is destroyed by the GC + support.gc_collect() + + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), set()) + + mock_handler.assert_called_with(self.loop, { + 'message': 'Task was destroyed but it is pending!', + 'task': mock.ANY, + 'source_traceback': source_traceback, + }) + mock_handler.reset_mock() + + @mock.patch('asyncio.coroutines.logger') + def test_coroutine_never_yielded(self, m_log): + debug = asyncio.coroutines._DEBUG + try: + asyncio.coroutines._DEBUG = True + @asyncio.coroutine + def coro_noop(): + pass + finally: + asyncio.coroutines._DEBUG = debug + + tb_filename = __file__ + tb_lineno = sys._getframe().f_lineno + 2 + # create a coroutine object but don't use it + coro_noop() + support.gc_collect() + + self.assertTrue(m_log.error.called) + message = m_log.error.call_args[0][0] + func_filename, func_lineno = test_utils.get_function_source(coro_noop) + regex = (r'^ was never yielded from\n' + r'Coroutine object created at \(most recent call last\):\n' + r'.*\n' + r' File "%s", line %s, in test_coroutine_never_yielded\n' + r' coro_noop\(\)$' + % (re.escape(coro_noop.__qualname__), + re.escape(func_filename), func_lineno, + re.escape(tb_filename), tb_lineno)) + + self.assertRegex(message, re.compile(regex, re.DOTALL)) + + def test_task_source_traceback(self): + self.loop.set_debug(True) + + task = asyncio.Task(coroutine_function(), loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(task._source_traceback, list) + self.assertEqual(task._source_traceback[-1][:3], + (__file__, + lineno, + 'test_task_source_traceback')) + self.loop.run_until_complete(task) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = self.new_test_loop() + self.other_loop = self.new_test_loop() + self.set_event_loop(self.one_loop, cleanup=False) + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] + fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_return_exceptions(self): + a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + def test_env_var_debug(self): + aio_path = os.path.dirname(os.path.dirname(asyncio.__file__)) + + code = '\n'.join(( + 'import asyncio.coroutines', + 'print(asyncio.coroutines._DEBUG)')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + +class FutureGatherTests(GatherTestsBase, test_utils.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + asyncio.set_event_loop(self.one_loop) + self.addCleanup(asyncio.set_event_loop, None) + fut = asyncio.gather(*seq_or_iter) + self.assertIsInstance(fut, asyncio.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = asyncio.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = asyncio.Future(loop=self.one_loop) + fut2 = asyncio.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + asyncio.gather(fut1, fut2) + with self.assertRaises(ValueError): + asyncio.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [asyncio.Future(loop=self.other_loop) for i in range(3)] + fut = asyncio.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = asyncio.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(a, b, c, d, e) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), asyncio.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) + for i in range(6)] + fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], asyncio.CancelledError) + self.assertIsInstance(res[4], asyncio.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase): + + def setUp(self): + super().setUp() + asyncio.set_event_loop(self.one_loop) + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @asyncio.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @asyncio.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = asyncio.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + self.one_loop.run_until_complete(fut) + + self.set_event_loop(self.other_loop, cleanup=False) + gen3 = coro() + gen4 = coro() + fut2 = asyncio.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut2._loop, self.other_loop) + self.other_loop.run_until_complete(fut2) + + def test_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('abc') + fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = asyncio.async(inner(), loop=self.one_loop) + child2 = asyncio.async(inner(), loop=self.one_loop) + gatherer = None + + @asyncio.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = asyncio.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(asyncio.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @asyncio.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = asyncio.Future(loop=self.one_loop) + b = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def outer(): + yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_transports.py b/tests/test_transports.py new file mode 100644 index 00000000..3b6e3d67 --- /dev/null +++ b/tests/test_transports.py @@ -0,0 +1,91 @@ +"""Tests for transports.py.""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = asyncio.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = asyncio.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = asyncio.Transport() + transport.write = mock.Mock() + + transport.writelines([b'line1', + bytearray(b'line2'), + memoryview(b'line3')]) + self.assertEqual(1, transport.write.call_count) + transport.write.assert_called_with(b'line1line2line3') + + def test_not_implemented(self): + transport = asyncio.Transport() + + self.assertRaises(NotImplementedError, + transport.set_write_buffer_limits) + self.assertRaises(NotImplementedError, transport.get_write_buffer_size) + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = asyncio.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = asyncio.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) + + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + loop = mock.Mock() + transport = MyTransport(loop=loop) + transport._protocol = mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 1024)) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 256)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py new file mode 100644 index 00000000..b6ad0189 --- /dev/null +++ b/tests/test_unix_events.py @@ -0,0 +1,1600 @@ +"""Tests for unix_events.py.""" + +import collections +import gc +import errno +import io +import os +import pprint +import signal +import socket +import stat +import sys +import tempfile +import threading +import unittest +from unittest import mock + +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + + +import asyncio +from asyncio import log +from asyncio import test_utils +from asyncio import unix_events + + +MOCK_ANY = mock.ANY + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopSignalTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1) + + def test_handle_signal_cancelled_handler(self): + h = asyncio.Handle(mock.Mock(), (), + loop=mock.Mock()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = mock.Mock() + self.loop._handle_signal(signal.NSIG + 1) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_coroutine_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + @asyncio.coroutine + def simple_coroutine(): + yield from [] + + # callback must not be a coroutine function + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + self.assertRaisesRegex( + TypeError, 'coroutines cannot be used with add_signal_handler', + self.loop.add_signal_handler, + signal.SIGINT, func) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertIsInstance(h, asyncio.Handle) + self.assertEqual(h._callback, cb) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_close(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGCHLD, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 2) + + m_signal.set_wakeup_fd.reset_mock() + + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + m_signal.set_wakeup_fd.assert_called_once_with(-1) + + +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + with sock: + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + with tempfile.NamedTemporaryFile() as file: + coro = self.loop.create_unix_server(lambda: None, file.name) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + sock = socket.socket() + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.unix_events.socket') + def test_create_unix_server_bind_error(self, m_socket): + # Ensure that the socket is closed on any bind error + sock = mock.Mock() + m_socket.socket.return_value = sock + + sock.bind.side_effect = OSError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + sock.bind.side_effect = MemoryError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(MemoryError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', ssl=True) + + with self.assertRaisesRegex( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + +class UnixReadPipeTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol, fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @mock.patch('os.read') + def test__read_ready(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + err = OSError() + m_read.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + @mock.patch('os.read') + def test_pause_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + m = mock.Mock() + self.loop.add_reader(5, m) + tr.pause_reading() + self.assertFalse(self.loop.readers) + + @mock.patch('os.read') + def test_resume_reading(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr.resume_reading() + self.loop.assert_reader(5, tr._read_ready) + + @mock.patch('os.read') + def test_close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._close = mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + tr._closing = True + tr._close = mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @mock.patch('os.read') + def test__close(self, m_read): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixReadPipeTransport( + self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + +class UnixWritePipeTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFSOCK + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def test_ctor(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol, fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertTrue(tr.can_write_eof()) + + @mock.patch('os.write') + def test_write(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @mock.patch('os.write') + def test_write_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.unix_events.logger') + @mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + err = OSError() + m_write.side_effect = err + tr._fatal_error = mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + + @mock.patch('os.write') + def test_write_close(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol) + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.write') + def test__write_ready(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal write error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @mock.patch('os.write') + def test_abort(self, m_write): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test_close(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + def test_close_closing(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof = mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = unix_events._UnixWritePipeTransport( + self.loop, self.pipe, self.protocol) + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + +class AbstractChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = asyncio.AbstractChildWatcher() + self.assertRaises( + NotImplementedError, watcher.add_child_handler, f, f) + self.assertRaises( + NotImplementedError, watcher.remove_child_handler, f) + self.assertRaises( + NotImplementedError, watcher.attach_loop, f) + self.assertRaises( + NotImplementedError, watcher.close) + self.assertRaises( + NotImplementedError, watcher.__enter__) + self.assertRaises( + NotImplementedError, watcher.__exit__, f, f, f) + + +class BaseChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = unix_events.BaseChildWatcher() + self.assertRaises( + NotImplementedError, watcher._do_waitpid, f) + + +WaitPidMocks = collections.namedtuple("WaitPidMocks", + ("waitpid", + "WIFEXITED", + "WIFSIGNALED", + "WEXITSTATUS", + "WTERMSIG", + )) + + +class ChildWatcherTestsMixin: + + ignore_warnings = mock.patch.object(log.logger, "warning") + + def setUp(self): + self.loop = self.new_test_loop() + self.running = False + self.zombies = {} + + with mock.patch.object( + self.loop, "add_signal_handler") as self.m_add_signal_handler: + self.watcher = self.create_watcher() + self.watcher.attach_loop(self.loop) + + def waitpid(self, pid, flags): + if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1: + self.assertGreater(pid, 0) + try: + if pid < 0: + return self.zombies.popitem() + else: + return pid, self.zombies.pop(pid) + except KeyError: + pass + if self.running: + return 0, 0 + else: + raise ChildProcessError() + + def add_zombie(self, pid, returncode): + self.zombies[pid] = returncode + 32768 + + def WIFEXITED(self, status): + return status >= 32768 + + def WIFSIGNALED(self, status): + return 32700 < status < 32768 + + def WEXITSTATUS(self, status): + self.assertTrue(self.WIFEXITED(status)) + return status - 32768 + + def WTERMSIG(self, status): + self.assertTrue(self.WIFSIGNALED(status)) + return 32768 - status + + def test_create_watcher(self): + self.m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + def waitpid_mocks(func): + def wrapped_func(self): + def patch(target, wrapper): + return mock.patch(target, wraps=wrapper, + new_callable=mock.Mock) + + with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ + patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ + patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \ + patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \ + patch('os.waitpid', self.waitpid) as m_waitpid: + func(self, WaitPidMocks(m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + )) + return wrapped_func + + @waitpid_mocks + def test_sigchld(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(42, callback, 9, 10, 14) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child is running + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (returncode 12) + self.running = False + self.add_zombie(42, 12) + self.watcher._sig_chld() + + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + callback.assert_called_once_with(42, 12, 9, 10, 14) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(42, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(43, callback1, 7, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(44, callback2, 147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (signal 3) + self.add_zombie(43, -3) + self.watcher._sig_chld() + + callback1.assert_called_once_with(43, -3, 7, 8) + self.assertFalse(callback2.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback1.reset_mock() + + # child 2 still running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 2 terminates (code 108) + self.add_zombie(44, 108) + self.running = False + self.watcher._sig_chld() + + callback2.assert_called_once_with(44, 108, 147, 18) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(43, 14) + self.add_zombie(44, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children_terminating_together(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(45, callback1, 17, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(46, callback2, 1147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (code 78) + # child 2 terminates (signal 5) + self.add_zombie(45, 78) + self.add_zombie(46, -5) + self.running = False + self.watcher._sig_chld() + + callback1.assert_called_once_with(45, 78, 17, 8) + callback2.assert_called_once_with(46, -5, 1147, 18) + self.assertTrue(m.WIFSIGNALED.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + m.WEXITSTATUS.reset_mock() + callback1.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(45, 14) + self.add_zombie(46, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_race_condition(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + # child terminates before being registered + self.add_zombie(50, 4) + self.watcher._sig_chld() + + self.watcher.add_child_handler(50, callback, 1, 12) + + callback.assert_called_once_with(50, 4, 1, 12) + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(50, -1) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_replace_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(51, callback1, 19) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register the same child again + with self.watcher: + self.watcher.add_child_handler(51, callback2, 21) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (signal 8) + self.running = False + self.add_zombie(51, -8) + self.watcher._sig_chld() + + callback2.assert_called_once_with(51, -8, 21) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback2.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(51, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_remove_handler(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(52, callback, 1984) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # unregister the child + self.watcher.remove_child_handler(52) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (code 99) + self.running = False + self.add_zombie(52, 99) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_unknown_status(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(53, callback, -19) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # terminate with unknown status + self.zombies[53] = 1178 + self.running = False + self.watcher._sig_chld() + + callback.assert_called_once_with(53, 1178, -19) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + callback.reset_mock() + m.WIFEXITED.reset_mock() + m.WIFSIGNALED.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(53, 101) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_remove_child_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + # register children + with self.watcher: + self.running = True + self.watcher.add_child_handler(54, callback1, 1) + self.watcher.add_child_handler(55, callback2, 2) + self.watcher.add_child_handler(56, callback3, 3) + + # remove child handler 1 + self.assertTrue(self.watcher.remove_child_handler(54)) + + # remove child handler 2 multiple times + self.assertTrue(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + + # all children terminate + self.add_zombie(54, 0) + self.add_zombie(55, 1) + self.add_zombie(56, 2) + self.running = False + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(56, 2, 3) + + @waitpid_mocks + def test_sigchld_unhandled_exception(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(57, callback) + + # raise an exception + m.waitpid.side_effect = ValueError + + with mock.patch.object(log.logger, + 'error') as m_error: + + self.assertEqual(self.watcher._sig_chld(), None) + self.assertTrue(m_error.called) + + @waitpid_mocks + def test_sigchld_child_reaped_elsewhere(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(58, callback) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates + self.running = False + self.add_zombie(58, 4) + + # waitpid is called elsewhere + os.waitpid(58, os.WNOHANG) + + m.waitpid.reset_mock() + + # sigchld + with self.ignore_warnings: + self.watcher._sig_chld() + + if isinstance(self.watcher, asyncio.FastChildWatcher): + # here the FastChildWatche enters a deadlock + # (there is no way to prevent it) + self.assertFalse(callback.called) + else: + callback.assert_called_once_with(58, 255) + + @waitpid_mocks + def test_sigchld_unknown_pid_during_registration(self, m): + # register two children + callback1 = mock.Mock() + callback2 = mock.Mock() + + with self.ignore_warnings, self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, 7) + # an unknown child terminates + self.add_zombie(593, 17) + + self.watcher._sig_chld() + + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) + + callback1.assert_called_once_with(591, 7) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_set_loop(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(60, callback) + + # attach a new loop + old_loop = self.loop + self.loop = self.new_test_loop() + patch = mock.patch.object + + with patch(old_loop, "remove_signal_handler") as m_old_remove, \ + patch(self.loop, "add_signal_handler") as m_new_add: + + self.watcher.attach_loop(self.loop) + + m_old_remove.assert_called_once_with( + signal.SIGCHLD) + m_new_add.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + # child terminates + self.running = False + self.add_zombie(60, 9) + self.watcher._sig_chld() + + callback.assert_called_once_with(60, 9) + + @waitpid_mocks + def test_set_loop_race_condition(self, m): + # register 3 children + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(61, callback1) + self.watcher.add_child_handler(62, callback2) + self.watcher.add_child_handler(622, callback3) + + # detach the loop + old_loop = self.loop + self.loop = None + + with mock.patch.object( + old_loop, "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.attach_loop(None) + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + + # child 1 & 2 terminate + self.add_zombie(61, 11) + self.add_zombie(62, -5) + + # SIGCHLD was not caught + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(callback3.called) + + # attach a new loop + self.loop = self.new_test_loop() + + with mock.patch.object( + self.loop, "add_signal_handler") as m_add_signal_handler: + + self.watcher.attach_loop(self.loop) + + m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + callback1.assert_called_once_with(61, 11) # race condition! + callback2.assert_called_once_with(62, -5) # race condition! + self.assertFalse(callback3.called) + + callback1.reset_mock() + callback2.reset_mock() + + # child 3 terminates + self.running = False + self.add_zombie(622, 19) + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(622, 19) + + @waitpid_mocks + def test_close(self, m): + # register two children + callback1 = mock.Mock() + + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(63, 9) + # other child terminates + self.add_zombie(65, 18) + self.watcher._sig_chld() + + self.watcher.add_child_handler(63, callback1) + self.watcher.add_child_handler(64, callback1) + + self.assertEqual(len(self.watcher._callbacks), 1) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertEqual(len(self.watcher._zombies), 1) + + with mock.patch.object( + self.loop, + "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.close() + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + self.assertFalse(self.watcher._callbacks) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertFalse(self.watcher._zombies) + + +class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + return asyncio.SafeChildWatcher() + + +class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + return asyncio.FastChildWatcher() + + +class PolicyTests(unittest.TestCase): + + def create_policy(self): + return asyncio.DefaultEventLoopPolicy() + + def test_get_child_watcher(self): + policy = self.create_policy() + self.assertIsNone(policy._watcher) + + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + + self.assertIs(policy._watcher, watcher) + + self.assertIs(watcher, policy.get_child_watcher()) + self.assertIsNone(watcher._loop) + + def test_get_child_watcher_after_set(self): + policy = self.create_policy() + watcher = asyncio.FastChildWatcher() + + policy.set_child_watcher(watcher) + self.assertIs(policy._watcher, watcher) + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_with_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + self.assertIsNone(policy._watcher) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIs(watcher._loop, loop) + + loop.close() + + def test_get_child_watcher_thread(self): + + def f(): + policy.set_event_loop(policy.new_event_loop()) + + self.assertIsInstance(policy.get_event_loop(), + asyncio.AbstractEventLoop) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIsNone(watcher._loop) + + policy.get_event_loop().close() + + policy = self.create_policy() + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_child_watcher_replace_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + watcher = policy.get_child_watcher() + + self.assertIs(watcher._loop, loop) + + new_loop = policy.new_event_loop() + policy.set_event_loop(new_loop) + + self.assertIs(watcher._loop, new_loop) + + policy.set_event_loop(None) + + self.assertIs(watcher._loop, None) + + loop.close() + new_loop.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py new file mode 100644 index 00000000..b4d9398f --- /dev/null +++ b/tests/test_windows_events.py @@ -0,0 +1,141 @@ +import os +import sys +import unittest + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +import asyncio +from asyncio import _overlapped +from asyncio import test_utils +from asyncio import windows_events + + +class UpperProto(asyncio.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, asyncio.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + b.close() + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = asyncio.StreamReader(loop=self.loop) + protocol = asyncio.StreamReaderProtocol(stream_reader) + trans, proto = yield from self.loop.create_pipe_connection( + lambda: protocol, ADDRESS) + self.assertIsInstance(trans, asyncio.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + return 'done' + + def test_wait_for_handle(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with 0.5s timeout; + # result should be False at timeout + fut = self.loop._proactor.wait_for_handle(event, 0.5) + start = self.loop.time() + self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + self.assertFalse(fut.result()) + self.assertTrue(0.48 < elapsed < 0.9, elapsed) + + _overlapped.SetEvent(event) + + # Wait for set event; + # result should be True immediately + fut = self.loop._proactor.wait_for_handle(event, 10) + start = self.loop.time() + self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + self.assertTrue(fut.result()) + self.assertTrue(0 <= elapsed < 0.3, elapsed) + + # Tulip issue #195: cancelling a done _WaitHandleFuture must not crash + fut.cancel() + + def test_wait_for_handle_cancel(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with a cancelled future; + # CancelledError should be raised immediately + fut = self.loop._proactor.wait_for_handle(event, 10) + fut.cancel() + start = self.loop.time() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + self.assertTrue(0 <= elapsed < 0.1, elapsed) + + # Tulip issue #195: cancelling a _WaitHandleFuture twice must not crash + fut = self.loop._proactor.wait_for_handle(event) + fut.cancel() + fut.cancel() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py new file mode 100644 index 00000000..3e7a211e --- /dev/null +++ b/tests/test_windows_utils.py @@ -0,0 +1,175 @@ +"""Tests for window_utils""" + +import socket +import sys +import test.support +import unittest +from test.support import IPV6_ENABLED +from unittest import mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +from asyncio import windows_utils +from asyncio import _overlapped + + +class WinsocketpairTests(unittest.TestCase): + + def check_winsocketpair(self, ssock, csock): + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + csock.close() + ssock.close() + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + self.check_winsocketpair(ssock, csock) + + @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_winsocketpair_ipv6(self): + ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) + self.check_winsocketpair(ssock, csock) + + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + def test_winsocketpair_invalid_args(self): + self.assertRaises(ValueError, + windows_utils.socketpair, family=socket.AF_UNSPEC) + self.assertRaises(ValueError, + windows_utils.socketpair, type=socket.SOCK_DGRAM) + self.assertRaises(ValueError, + windows_utils.socketpair, proto=1) + + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_close(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM + sock = mock.Mock() + m_socket.socket.return_value = sock + sock.bind.side_effect = OSError + self.assertRaises(OSError, windows_utils.socketpair) + self.assertTrue(sock.close.called) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + del p + test.support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + # Super-long timeout for slow buildbots. + res = _winapi.WaitForMultipleObjects(events, True, 10000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) + + p.wait() + + +if __name__ == '__main__': + unittest.main() diff --git a/update_stdlib.sh b/update_stdlib.sh new file mode 100755 index 00000000..bb6251a0 --- /dev/null +++ b/update_stdlib.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# Script to copy asyncio files to the standard library tree. +# Optional argument is the root of the Python 3.4 tree. +# Assumes you have already created Lib/asyncio and +# Lib/test/test_asyncio in the destination tree. + +CPYTHON=${1-$HOME/cpython} + +if [ ! -d $CPYTHON ] +then + echo Bad destination $CPYTHON + exit 1 +fi + +if [ ! -f asyncio/__init__.py ] +then + echo Bad current directory + exit 1 +fi + +maybe_copy() +{ + SRC=$1 + DST=$CPYTHON/$2 + if cmp $DST $SRC + then + return + fi + echo ======== $SRC === $DST ======== + diff -u $DST $SRC + echo -n "Copy $SRC? [y/N/back] " + read X + case $X in + [yY]*) echo Copying $SRC; cp $SRC $DST;; + back) echo Copying TO $SRC; cp $DST $SRC;; + *) echo Not copying $SRC;; + esac +} + +for i in `(cd asyncio && ls *.py)` +do + if [ $i == selectors.py ] + then + if [ "`(cd $CPYTHON; hg branch)`" == "3.4" ] + then + echo "Destination is 3.4 branch -- ignoring selectors.py" + else + maybe_copy asyncio/$i Lib/$i + fi + else + maybe_copy asyncio/$i Lib/asyncio/$i + fi +done + +for i in `(cd tests && ls *.py *.pem)` +do + if [ $i == test_selectors.py ] + then + continue + fi + maybe_copy tests/$i Lib/test/test_asyncio/$i +done + +maybe_copy overlapped.c Modules/overlapped.c From 1ccf8b456ff705ab47a8ad5c78dd672028f66d06 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 15 Dec 2014 17:07:12 +0100 Subject: [PATCH 1202/1502] hgignore: ignore .tox/ directory --- .hgignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.hgignore b/.hgignore index 6d1136f2..7931fad2 100644 --- a/.hgignore +++ b/.hgignore @@ -12,3 +12,4 @@ distribute-\d+.\d+.\d+.tar.gz$ build$ dist$ .*\.egg-info$ +.tox$ From 4295379b38e1f8c355671fdf1f182cd9cf41e86e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 15 Dec 2014 17:07:47 +0100 Subject: [PATCH 1203/1502] Backed out changeset 2af65c9de2a8 --- .hgignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.hgignore b/.hgignore index 7931fad2..6d1136f2 100644 --- a/.hgignore +++ b/.hgignore @@ -12,4 +12,3 @@ distribute-\d+.\d+.\d+.tar.gz$ build$ dist$ .*\.egg-info$ -.tox$ From 2efe29894ee347794e3076f6b9c5938acb30aa9b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 01:15:03 +0100 Subject: [PATCH 1204/1502] Python issue #23074: get_event_loop() now raises an exception if the thread has no event loop even if assertions are disabled. --- asyncio/base_events.py | 2 +- asyncio/events.py | 6 +++--- tests/test_events.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 0c7316ea..b1a5422b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -420,7 +420,7 @@ def _assert_is_current_event_loop(self): """ try: current = events.get_event_loop() - except AssertionError: + except RuntimeError: return if current is not self: raise RuntimeError( diff --git a/asyncio/events.py b/asyncio/events.py index 806218f6..8a7bb814 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -517,9 +517,9 @@ def get_event_loop(self): not self._local._set_called and isinstance(threading.current_thread(), threading._MainThread)): self.set_event_loop(self.new_event_loop()) - assert self._local._loop is not None, \ - ('There is no current event loop in thread %r.' % - threading.current_thread().name) + if self._local._loop is None: + raise RuntimeError('There is no current event loop in thread %r.' + % threading.current_thread().name) return self._local._loop def set_event_loop(self, loop): diff --git a/tests/test_events.py b/tests/test_events.py index 6644fbea..d7e2f348 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -2252,14 +2252,14 @@ def test_get_event_loop_calls_set_event_loop(self): def test_get_event_loop_after_set_none(self): policy = asyncio.DefaultEventLoopPolicy() policy.set_event_loop(None) - self.assertRaises(AssertionError, policy.get_event_loop) + self.assertRaises(RuntimeError, policy.get_event_loop) @mock.patch('asyncio.events.threading.current_thread') def test_get_event_loop_thread(self, m_current_thread): def f(): policy = asyncio.DefaultEventLoopPolicy() - self.assertRaises(AssertionError, policy.get_event_loop) + self.assertRaises(RuntimeError, policy.get_event_loop) th = threading.Thread(target=f) th.start() From 32de9e93e2d38fe79076e1825e2fbd2f667d2424 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 11:27:37 +0100 Subject: [PATCH 1205/1502] Add release.py --- release.py | 227 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100644 release.py diff --git a/release.py b/release.py new file mode 100644 index 00000000..49d8ee7d --- /dev/null +++ b/release.py @@ -0,0 +1,227 @@ +import contextlib +import os +import re +import shutil +import subprocess +import sys +import tempfile +import textwrap + +PY3 = (sys.version_info >= (3,)) +HG = 'hg' +_PYTHON_VERSIONS = [(3, 3)] +PYTHON_VERSIONS = [] +for pyver in _PYTHON_VERSIONS: + PYTHON_VERSIONS.append((pyver, 32)) + PYTHON_VERSIONS.append((pyver, 64)) +SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" +BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" + +class Release(object): + def __init__(self): + root = os.path.dirname(__file__) + self.root = os.path.realpath(root) + + @contextlib.contextmanager + def _popen(self, args, **kw): + env2 = kw.pop('env', {}) + env = dict(os.environ) + # Force the POSIX locale + env['LC_ALL'] = 'C' + env.update(env2) + print('+ ' + ' '.join(args)) + if PY3: + kw['universal_newlines'] = True + proc = subprocess.Popen(args, env=env, **kw) + with proc: + yield proc + + def get_output(self, *args, **kw): + with self._popen(args, stdout=subprocess.PIPE, **kw) as proc: + stdout, stderr = proc.communicate() + return stdout + + def run_command(self, *args, **kw): + with self._popen(args, **kw) as proc: + exitcode = proc.wait() + if exitcode: + sys.exit(exitcode) + + def get_local_changes(self): + status = self.get_output(HG, 'status') + return [line for line in status.splitlines() + if not line.startswith("?")] + + def remove_directory(self, name): + path = os.path.join(self.root, name) + if os.path.exists(path): + print("Remove directory: %s" % name) + shutil.rmtree(path) + + def remove_file(self, name): + path = os.path.join(self.root, name) + if os.path.exists(path): + print("Remove file: %s" % name) + os.unlink(path) + + def windows_sdk_setenv(self, pyver, bits): + if pyver >= (3, 3): + sdkver = "v7.1" + else: + sdkver = "v7.0" + setenv = os.path.join(SDK_ROOT, sdkver, 'Bin', 'SetEnv.cmd') + if not os.path.exists(setenv): + print("Unable to find Windows SDK %s for Python %s.%s" + % (sdkver, pyver[0], pyver[1])) + print("Please download and install it") + print("%s does not exists" % setenv) + sys.exit(1) + if bits == 64: + arch = '/x64' + else: + arch = '/x86' + return ["CALL", setenv, "/release", arch] + + def get_python(self, version, bits): + if bits == 32: + python = 'c:\\Python%s%s_32bit\\python.exe' % version + else: + python = 'c:\\Python%s%s\\python.exe' % version + if not os.path.exists(python): + print("Unable to find python%s.%s" % version) + print("%s does not exists" % python) + sys.exit(1) + code = ( + 'import platform, sys; ' + 'print("{ver.major}.{ver.minor} {bits}".format(' + 'ver=sys.version_info, ' + 'bits=platform.architecture()[0]))' + ) + stdout = self.get_output(python, '-c', code) + stdout = stdout.rstrip() + expected = "%s.%s %sbit" % (version[0], version[1], bits) + if stdout != expected: + print("Python version or architecture doesn't match") + print("got %r, expected %r" % (stdout, expected)) + print(python) + sys.exit(1) + return python + + def quote(self, arg): + if not re.search("[ '\"]", arg): + return arg + # FIXME: should we escape "? + return '"%s"' % arg + + def quote_args(self, args): + return ' '.join(self.quote(arg) for arg in args) + + def cleanup(self): + self.remove_directory('build') + self.remove_directory('dist') + self.remove_file('_overlapped.pyd') + self.remove_file(os.path.join('asyncio', '_overlapped.pyd')) + + def sdist_upload(self): + self.cleanup() + self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') + + def runtests(self, pyver, bits): + pythonstr = "%s.%s (%s bits)" % (pyver[0], pyver[1], bits) + python = self.get_python(pyver, bits) + args = python, 'runtests.py', '-r' + + print("Run tests in release mode with %s" % pythonstr) + self.run_command(*args) + + print("Run tests in debug mode with %s" % pythonstr) + self.run_command(*args, env={'PYTHONASYNCIODEBUG': 1}) + + def wheel_command(self, pyver, bits, *cmds): + self.cleanup() + + setenv = self.windows_sdk_setenv(pyver, bits) + + python = self.get_python(pyver, bits) + + cmd = [python, 'setup.py'] + list(cmds) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".bat", delete=False) as temp: + print("CD %s" % self.quote(self.root), file=temp) + print(self.quote_args(setenv), file=temp) + print(BATCH_FAIL_ON_ERROR, file=temp) + print("", file=temp) + print("SET DISTUTILS_USE_SDK=1", file=temp) + print("SET MSSDK=1", file=temp) + print(self.quote_args(cmd), file=temp) + print(BATCH_FAIL_ON_ERROR, file=temp) + + try: + self.run_command(temp.name) + finally: + os.unlink(temp.name) + + def test_wheel(self, pyver, bits): + self.wheel_command(pyver, bits, 'bdist_wheel') + + def publish_wheel(self, pyver, bits): + self.wheel_command(pyver, bits, 'bdist_wheel', 'upload') + + def main(self): + try: + pos = sys.argv[1:].index('--ignore') + except ValueError: + ignore = False + else: + ignore = True + del sys.argv[1+pos] + if len(sys.argv) != 2: + print("usage: %s hg_tag" % sys.argv[0]) + sys.exit(1) + + print("Directory: %s" % self.root) + os.chdir(self.root) + + if not ignore: + lines = self.get_local_changes() + else: + lines = () + if lines: + print("ERROR: Found local changes") + for line in lines: + print(line) + print("") + print("Revert local changes") + print("or use the --ignore command line option") + sys.exit(1) + + hg_tag = sys.argv[1] + self.run_command(HG, 'up', hg_tag) + + # FIXME: enable running tests + # On Windows, installing Python with the MSI doesn't install the test module, + # so asyncio tests cannot run because test.script_helper is not found. + #for pyver in PYTHON_VERSIONS: + # self.runtests(pyver, 32) + # self.runtests(pyver, 64) + + for pyver, bits in PYTHON_VERSIONS: + self.test_wheel(pyver, bits) + + self.run_command(sys.executable, 'setup.py', 'register') + + self.sdist_upload() + + for pyver, bits in PYTHON_VERSIONS: + self.publish_wheel(pyver, bits) + + print("") + print("Publish version %s" % hg_tag) + print("Uploaded:") + print("- sdist") + for pyver, bits in PYTHON_VERSIONS: + print("- Windows wheel %s bits package for Python %s.%s" + % (bits, pyver[0], pyver[1])) + +if __name__ == "__main__": + Release().main() From be0ddd550f4791d917ce37e64796e4710211c43b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 11:51:31 +0100 Subject: [PATCH 1206/1502] setup.py: set version to 3.4.3; write the release procedure --- setup.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index fcd3b6aa..a5461dc8 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,15 @@ +# Release procedure: +# - run unit tests with python 3.3 in debug mode +# - run unit tests with python dev (3.5) in debug mode +# - maybe test examples +# - update version in setup.py +# - hg ci +# - hg tag VERSION +# - hg push +# - python setup.py register sdist bdist_wheel upload +# - increment version in setup.py +# - hg ci && hg push + import os try: from setuptools import setup, Extension @@ -15,7 +27,7 @@ setup( name="asyncio", - version="3.4.1", + version="3.4.3", description="reference implementation of PEP 3156", long_description=open("README").read(), From 871c0c1ac652737cd4dc0fe3def89cbc4fccaba8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 11:54:57 +0100 Subject: [PATCH 1207/1502] Add release.py --- release.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/release.py b/release.py index 49d8ee7d..99f8f32e 100644 --- a/release.py +++ b/release.py @@ -1,3 +1,11 @@ +""" +Script to upload 32 bits and 64 bits wheel packages for Python 3.3 on Windows. + +Usage: "python release.py HG_TAG" where HG_TAG is a Mercurial tag, usually +a version number like "3.4.2". + +It requires the Windows SDK 7.1 on Windows 64 bits. +""" import contextlib import os import re @@ -21,6 +29,9 @@ class Release(object): def __init__(self): root = os.path.dirname(__file__) self.root = os.path.realpath(root) + # Set these attributes to True to run also register sdist upload + self.register = False + self.sdist = False @contextlib.contextmanager def _popen(self, args, **kw): @@ -208,17 +219,21 @@ def main(self): for pyver, bits in PYTHON_VERSIONS: self.test_wheel(pyver, bits) - self.run_command(sys.executable, 'setup.py', 'register') + if self.register: + self.run_command(sys.executable, 'setup.py', 'register') - self.sdist_upload() + if self.sdist: + self.sdist_upload() for pyver, bits in PYTHON_VERSIONS: self.publish_wheel(pyver, bits) print("") - print("Publish version %s" % hg_tag) + if self.register: + print("Publish version %s" % hg_tag) print("Uploaded:") - print("- sdist") + if self.sdist: + print("- sdist") for pyver, bits in PYTHON_VERSIONS: print("- Windows wheel %s bits package for Python %s.%s" % (bits, pyver[0], pyver[1])) From 6f9e5fb7db35ebc8e4e5e80431d517d55d3320f8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 11:59:16 +0100 Subject: [PATCH 1208/1502] update release procedure --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a5461dc8..3a3a4db3 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,8 @@ # - hg ci # - hg tag VERSION # - hg push -# - python setup.py register sdist bdist_wheel upload +# - run on Linux: python setup.py register sdist upload +# - run on Windows: python release.py VERSION # - increment version in setup.py # - hg ci && hg push From 1d83d4acc472135551675fa1a33322afaf8a4093 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 12:28:08 +0100 Subject: [PATCH 1209/1502] Copy a subset of test.support from CPython 3.5 to no more depend on the test module to run the asyncio test suite. The test module is rarely installed. --- asyncio/test_support.py | 290 ++++++++++++++++++++++++++++++++++++ tests/test_base_events.py | 12 +- tests/test_events.py | 5 +- tests/test_futures.py | 5 +- tests/test_subprocess.py | 12 +- tests/test_tasks.py | 8 +- tests/test_windows_utils.py | 11 +- update_stdlib.sh | 5 + 8 files changed, 332 insertions(+), 16 deletions(-) create mode 100644 asyncio/test_support.py diff --git a/asyncio/test_support.py b/asyncio/test_support.py new file mode 100644 index 00000000..59b3e8cf --- /dev/null +++ b/asyncio/test_support.py @@ -0,0 +1,290 @@ +# Subset of test.support from CPython 3.5, just what we need to run asyncio +# test suite. The cde is copied from CPython 3.5 to not depend on the test +# module because it is rarely installed. + +# Ignore symbol TEST_HOME_DIR: test_events works without it + +import functools +import gc +import os +import platform +import re +import socket +import subprocess +import sys + +# A constant likely larger than the underlying OS pipe buffer size, to +# make writes blocking. +# Windows limit seems to be around 512 B, and many Unix kernels have a +# 64 KiB pipe buffer size or 16 * PAGE_SIZE: take a few megs to be sure. +# (see issue #17835 for a discussion of this number). +PIPE_MAX_SIZE = 4 * 1024 * 1024 + 1 + +def strip_python_stderr(stderr): + """Strip the stderr of a Python process from potential debug output + emitted by the interpreter. + + This will typically be run on the result of the communicate() method + of a subprocess.Popen object. + """ + stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip() + return stderr + + +# Executing the interpreter in a subprocess +def _assert_python(expected_success, *args, **env_vars): + if '__isolated' in env_vars: + isolated = env_vars.pop('__isolated') + else: + isolated = not env_vars + cmd_line = [sys.executable, '-X', 'faulthandler'] + if isolated: + # isolated mode: ignore Python environment variables, ignore user + # site-packages, and don't add the current directory to sys.path + cmd_line.append('-I') + elif not env_vars: + # ignore Python environment variables + cmd_line.append('-E') + # Need to preserve the original environment, for in-place testing of + # shared library builds. + env = os.environ.copy() + # But a special flag that can be set to override -- in this case, the + # caller is responsible to pass the full environment. + if env_vars.pop('__cleanenv', None): + env = {} + env.update(env_vars) + cmd_line.extend(args) + p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + try: + out, err = p.communicate() + finally: + subprocess._cleanup() + p.stdout.close() + p.stderr.close() + rc = p.returncode + err = strip_python_stderr(err) + if (rc and expected_success) or (not rc and not expected_success): + raise AssertionError( + "Process return code is %d, " + "stderr follows:\n%s" % (rc, err.decode('ascii', 'ignore'))) + return rc, out, err + + +def assert_python_ok(*args, **env_vars): + """ + Assert that running the interpreter with `args` and optional environment + variables `env_vars` succeeds (rc == 0) and return a (return code, stdout, + stderr) tuple. + + If the __cleanenv keyword is set, env_vars is used a fresh environment. + + Python is started in isolated mode (command line option -I), + except if the __isolated keyword is set to False. + """ + return _assert_python(True, *args, **env_vars) + + +is_jython = sys.platform.startswith('java') + +def gc_collect(): + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if is_jython: + time.sleep(0.1) + gc.collect() + gc.collect() + + +HOST = "127.0.0.1" +HOSTv6 = "::1" + + +def _is_ipv6_enabled(): + """Check whether IPv6 is enabled on this host.""" + if socket.has_ipv6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind((HOSTv6, 0)) + return True + except OSError: + pass + finally: + if sock: + sock.close() + return False + +IPV6_ENABLED = _is_ipv6_enabled() + + +def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): + """Returns an unused port that should be suitable for binding. This is + achieved by creating a temporary socket with the same family and type as + the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to + the specified host address (defaults to 0.0.0.0) with the port set to 0, + eliciting an unused ephemeral port from the OS. The temporary socket is + then closed and deleted, and the ephemeral port is returned. + + Either this method or bind_port() should be used for any tests where a + server socket needs to be bound to a particular port for the duration of + the test. Which one to use depends on whether the calling code is creating + a python socket, or if an unused port needs to be provided in a constructor + or passed to an external program (i.e. the -accept argument to openssl's + s_server mode). Always prefer bind_port() over find_unused_port() where + possible. Hard coded ports should *NEVER* be used. As soon as a server + socket is bound to a hard coded port, the ability to run multiple instances + of the test simultaneously on the same host is compromised, which makes the + test a ticking time bomb in a buildbot environment. On Unix buildbots, this + may simply manifest as a failed test, which can be recovered from without + intervention in most cases, but on Windows, the entire python process can + completely and utterly wedge, requiring someone to log in to the buildbot + and manually kill the affected process. + + (This is easy to reproduce on Windows, unfortunately, and can be traced to + the SO_REUSEADDR socket option having different semantics on Windows versus + Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind, + listen and then accept connections on identical host/ports. An EADDRINUSE + OSError will be raised at some point (depending on the platform and + the order bind and listen were called on each socket). + + However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE + will ever be raised when attempting to bind two identical host/ports. When + accept() is called on each socket, the second caller's process will steal + the port from the first caller, leaving them both in an awkwardly wedged + state where they'll no longer respond to any signals or graceful kills, and + must be forcibly killed via OpenProcess()/TerminateProcess(). + + The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option + instead of SO_REUSEADDR, which effectively affords the same semantics as + SO_REUSEADDR on Unix. Given the propensity of Unix developers in the Open + Source world compared to Windows ones, this is a common mistake. A quick + look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when + openssl.exe is called with the 's_server' option, for example. See + http://bugs.python.org/issue2550 for more info. The following site also + has a very thorough description about the implications of both REUSEADDR + and EXCLUSIVEADDRUSE on Windows: + http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx) + + XXX: although this approach is a vast improvement on previous attempts to + elicit unused ports, it rests heavily on the assumption that the ephemeral + port returned to us by the OS won't immediately be dished back out to some + other process when we close and delete our temporary socket but before our + calling code has a chance to bind the returned port. We can deal with this + issue if/when we come across it. + """ + + tempsock = socket.socket(family, socktype) + port = bind_port(tempsock) + tempsock.close() + del tempsock + return port + +def bind_port(sock, host=HOST): + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, 'SO_REUSEADDR'): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise TestFailed("tests should never set the SO_REUSEADDR " \ + "socket option on TCP/IP sockets!") + if hasattr(socket, 'SO_REUSEPORT'): + try: + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: + raise TestFailed("tests should never set the SO_REUSEPORT " \ + "socket option on TCP/IP sockets!") + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + +def requires_mac_ver(*min_version): + """Decorator raising SkipTest if the OS is Mac OS X and the OS X + version if less than min_version. + + For example, @requires_mac_ver(10, 5) raises SkipTest if the OS X version + is lesser than 10.5. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if sys.platform == 'darwin': + version_txt = platform.mac_ver()[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "Mac OS X %s or higher required, not %s" + % (min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def _requires_unix_version(sysname, min_version): + """Decorator raising SkipTest if the OS is `sysname` and the version is less + than `min_version`. + + For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if + the FreeBSD version is less than 7.2. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if platform.system() == sysname: + version_txt = platform.release().split('-', 1)[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "%s version %s or higher required, not %s" + % (sysname, min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def requires_freebsd_version(*min_version): + """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version is + less than `min_version`. + + For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD + version is less than 7.2. + """ + return _requires_unix_version('FreeBSD', min_version) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index db9d732c..4e5b6ca9 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -8,13 +8,17 @@ import time import unittest from unittest import mock -from test.script_helper import assert_python_ok -from test.support import IPV6_ENABLED, gc_collect import asyncio from asyncio import base_events from asyncio import constants from asyncio import test_utils +try: + from test.script_helper import assert_python_ok + from test import support +except ImportError: + from asyncio import test_support as support + from asyncio.test_support import assert_python_ok MOCK_ANY = mock.ANY @@ -634,7 +638,7 @@ def raise_keyboard_interrupt(): except KeyboardInterrupt: pass self.loop.close() - gc_collect() + support.gc_collect() self.assertFalse(self.loop.call_exception_handler.called) @@ -1066,7 +1070,7 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): self.assertRaises( OSError, self.loop.run_until_complete, coro) - @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, diff --git a/tests/test_events.py b/tests/test_events.py index d7e2f348..06302920 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -20,13 +20,16 @@ import unittest from unittest import mock import weakref -from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR import asyncio from asyncio import proactor_events from asyncio import selector_events from asyncio import test_utils +try: + from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR +except ImportError: + from asyncio import test_support as support def data_file(filename): diff --git a/tests/test_futures.py b/tests/test_futures.py index 371d3518..f9c3ad20 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -5,11 +5,14 @@ import sys import threading import unittest -from test import support from unittest import mock import asyncio from asyncio import test_utils +try: + from test import support # gc_collect +except ImportError: + from asyncio import test_support as support def _fakefunc(f): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5c0a2c85..08c8ac24 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -1,13 +1,17 @@ -from asyncio import subprocess -from asyncio import test_utils -import asyncio import signal import sys import unittest from unittest import mock -from test import support + +import asyncio +from asyncio import subprocess +from asyncio import test_utils if sys.platform != 'win32': from asyncio import unix_events +try: + from test import support # PIPE_MAX_SIZE +except ImportError: + from asyncio import test_support as support # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 770f2181..25b21dc5 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -6,9 +6,13 @@ import types import unittest import weakref -from test import support -from test.script_helper import assert_python_ok from unittest import mock +try: + from test import support # gc_collect + from test.script_helper import assert_python_ok +except ImportError: + from asyncio import test_support as support + from asyncio.test_support import assert_python_ok import asyncio from asyncio import coroutines diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 3e7a211e..b9579491 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -2,11 +2,14 @@ import socket import sys -import test.support import unittest -from test.support import IPV6_ENABLED from unittest import mock +try: + from test import support # gc_collect, IPV6_ENABLED +except ImportError: + from asyncio import test_support as support + if sys.platform != 'win32': raise unittest.SkipTest('Windows only') @@ -28,7 +31,7 @@ def test_winsocketpair(self): ssock, csock = windows_utils.socketpair() self.check_winsocketpair(ssock, csock) - @unittest.skipUnless(IPV6_ENABLED, 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_winsocketpair_ipv6(self): ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) self.check_winsocketpair(ssock, csock) @@ -114,7 +117,7 @@ def test_pipe_handle(self): # check garbage collection of p closes handle del p - test.support.gc_collect() + support.gc_collect() try: _winapi.CloseHandle(h) except OSError as e: diff --git a/update_stdlib.sh b/update_stdlib.sh index bb6251a0..0cdbb1bd 100755 --- a/update_stdlib.sh +++ b/update_stdlib.sh @@ -40,6 +40,11 @@ maybe_copy() for i in `(cd asyncio && ls *.py)` do + if [ $i == test_support.py ] + then + continue + fi + if [ $i == selectors.py ] then if [ "`(cd $CPYTHON; hg branch)`" == "3.4" ] From c5d97b0dec38129b5717b78b10c5211048550177 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 12:36:40 +0100 Subject: [PATCH 1210/1502] Add tox.ini --- .hgignore | 1 + setup.py | 3 +-- tox.ini | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 tox.ini diff --git a/.hgignore b/.hgignore index 6d1136f2..736c7fdf 100644 --- a/.hgignore +++ b/.hgignore @@ -12,3 +12,4 @@ distribute-\d+.\d+.\d+.tar.gz$ build$ dist$ .*\.egg-info$ +\.tox$ diff --git a/setup.py b/setup.py index 3a3a4db3..87a629dc 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,5 @@ # Release procedure: -# - run unit tests with python 3.3 in debug mode -# - run unit tests with python dev (3.5) in debug mode +# - run tox (to run runtests.py and run_aiotest.py) # - maybe test examples # - update version in setup.py # - hg ci diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..040b25ab --- /dev/null +++ b/tox.ini @@ -0,0 +1,11 @@ +[tox] +envlist = py33,py34 + +[testenv] +deps= + aiotest +setenv = + PYTHONASYNCIODEBUG = 1 +commands= + python runtests.py -r {posargs} + python run_aiotest.py -r {posargs} From 558d6cd7774f4066124244ae21810bde1a46fe0d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 12:51:57 +0100 Subject: [PATCH 1211/1502] release.py now also run tests, runtests.py and run_aiotest.py Modify also release.py to run in dry run, don't upload anything, to avoid mistakes. --- release.py | 50 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/release.py b/release.py index 99f8f32e..8e0bec43 100644 --- a/release.py +++ b/release.py @@ -4,7 +4,9 @@ Usage: "python release.py HG_TAG" where HG_TAG is a Mercurial tag, usually a version number like "3.4.2". -It requires the Windows SDK 7.1 on Windows 64 bits. +Modify manually the dry_run attribute to upload files. + +It requires the Windows SDK 7.1 on Windows 64 bits and the aiotest module. """ import contextlib import os @@ -32,6 +34,8 @@ def __init__(self): # Set these attributes to True to run also register sdist upload self.register = False self.sdist = False + self.dry_run = True + self.aiotest = True @contextlib.contextmanager def _popen(self, args, **kw): @@ -140,15 +144,34 @@ def sdist_upload(self): def runtests(self, pyver, bits): pythonstr = "%s.%s (%s bits)" % (pyver[0], pyver[1], bits) python = self.get_python(pyver, bits) - args = python, 'runtests.py', '-r' + dbg_env = {'PYTHONASYNCIODEBUG': '1'} - print("Run tests in release mode with %s" % pythonstr) + self.build(pyver, bits, 'build') + if bits == 64: + arch = 'win-amd64' + else: + arch = 'win32' + build_dir = 'lib.%s-%s.%s' % (arch, pyver[0], pyver[1]) + src = os.path.join(self.root, 'build', build_dir, 'asyncio', '_overlapped.pyd') + dst = os.path.join(self.root, 'asyncio', '_overlapped.pyd') + shutil.copyfile(src, dst) + + args = (python, 'runtests.py', '-r') + print("Run runtests.py in release mode with %s" % pythonstr) self.run_command(*args) - print("Run tests in debug mode with %s" % pythonstr) - self.run_command(*args, env={'PYTHONASYNCIODEBUG': 1}) + print("Run runtests.py in debug mode with %s" % pythonstr) + self.run_command(*args, env=dbg_env) + + if self.aiotest: + args = (python, 'run_aiotest.py') + print("Run aiotest in release mode with %s" % pythonstr) + self.run_command(*args) - def wheel_command(self, pyver, bits, *cmds): + print("Run aiotest in debug mode with %s" % pythonstr) + self.run_command(*args, env=dbg_env) + + def build(self, pyver, bits, *cmds): self.cleanup() setenv = self.windows_sdk_setenv(pyver, bits) @@ -173,10 +196,10 @@ def wheel_command(self, pyver, bits, *cmds): os.unlink(temp.name) def test_wheel(self, pyver, bits): - self.wheel_command(pyver, bits, 'bdist_wheel') + self.build(pyver, bits, 'bdist_wheel') def publish_wheel(self, pyver, bits): - self.wheel_command(pyver, bits, 'bdist_wheel', 'upload') + self.build(pyver, bits, 'bdist_wheel', 'upload') def main(self): try: @@ -209,16 +232,15 @@ def main(self): hg_tag = sys.argv[1] self.run_command(HG, 'up', hg_tag) - # FIXME: enable running tests - # On Windows, installing Python with the MSI doesn't install the test module, - # so asyncio tests cannot run because test.script_helper is not found. - #for pyver in PYTHON_VERSIONS: - # self.runtests(pyver, 32) - # self.runtests(pyver, 64) + for pyver, bits in PYTHON_VERSIONS: + self.runtests(pyver, bits) for pyver, bits in PYTHON_VERSIONS: self.test_wheel(pyver, bits) + if self.dry_run: + sys.exit(0) + if self.register: self.run_command(sys.executable, 'setup.py', 'register') From 11f201a0933e49c17bc39d1aa4fbdb7441f197cc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 12:58:47 +0100 Subject: [PATCH 1212/1502] release.py: add an option to skip tests --- release.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/release.py b/release.py index 8e0bec43..f39d1e13 100644 --- a/release.py +++ b/release.py @@ -35,6 +35,7 @@ def __init__(self): self.register = False self.sdist = False self.dry_run = True + self.test = True self.aiotest = True @contextlib.contextmanager @@ -232,8 +233,9 @@ def main(self): hg_tag = sys.argv[1] self.run_command(HG, 'up', hg_tag) - for pyver, bits in PYTHON_VERSIONS: - self.runtests(pyver, bits) + if self.test: + for pyver, bits in PYTHON_VERSIONS: + self.runtests(pyver, bits) for pyver, bits in PYTHON_VERSIONS: self.test_wheel(pyver, bits) From 3299b7f0a84a1b09249a965619bb438ffdfb79c6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 14:52:45 +0100 Subject: [PATCH 1213/1502] Fix typo --- asyncio/windows_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 6763f0b7..0773d061 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -402,7 +402,7 @@ def finish_accept_pipe(trans, key, ov): ov.getresult() return pipe - # FIXME: Tulip issue 196: why to we neeed register=False? + # FIXME: Tulip issue 196: why do we need register=False? # See also the comment in the _register() method return self._register(ov, pipe, finish_accept_pipe, register=False) From 994256da4e0d40f442ce4e3d39c75b48f6b35c76 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 14:57:49 +0100 Subject: [PATCH 1214/1502] Start to write Tulip 3.4.2 changelog --- ChangeLog | 157 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 ChangeLog diff --git a/ChangeLog b/ChangeLog new file mode 100644 index 00000000..25017e73 --- /dev/null +++ b/ChangeLog @@ -0,0 +1,157 @@ +Tulip 3.4.2 +=========== + +New shiny methods like create_task(), better documentation, much better debug +mode, better tests. + +asyncio API +----------- + +* Add BaseEventLoop.create_task() method: schedule a coroutine object. + It allows other asyncio implementations to use their own Task class to + change its behaviour. + +* New BaseEventLoop methods: + + - create_task(): schedule a coroutine + - get_debug() + - is_closed() + - set_debug() + +* Add _FlowControlMixin.get_write_buffer_limits() method + +* sock_recv(), sock_sendall(), sock_connect(), sock_accept() methods of + SelectorEventLoop now raise an exception if the socket is blocking mode + +* Include unix_events/windows_events symbols in asyncio.__all__. + Examples: SelectorEventLoop, ProactorEventLoop, DefaultEventLoopPolicy. + +* attach(), detach(), loop, active_count and waiters attributes of the Server + class are now private + +* BaseEventLoop: run_forever(), run_until_complete() now raises an exception if + the event loop was closed + +* close() now raises an exception if the event loop is running, because pending + callbacks would be lost + +* Queue now accepts a float for the maximum size. + +* Process.communicate() now ignores BrokenPipeError and ConnectionResetError + exceptions, as Popen.communicate() of the subprocess module + + +Performances +------------ + +* Optimize handling of cancelled timers + + +Debug +----- + +* Future (and Task), CoroWrapper and Handle now remembers where they were + created (new _source_traceback object), traceback displayed when errors are + logged. + +* On Python 3.4 and newer, Task destrutor now logs a warning if the task was + destroyed while it was still pending. It occurs if the last reference + to the task was removed, while the coroutine didn't finish yet. + +* Much more useful events are logged: + + - Event loop closed + - Network connection + - Creation of a subprocess + - Pipe lost + - Log many errors previously silently ignored + - SSL handshake failure + - etc. + +* BaseEventLoop._debug is now True if the envrionement variable + PYTHONASYNCIODEBUG is set + +* Log the duration of DNS resolution and SSL handshake + +* Log a warning if a callback blocks the event loop longer than 100 ms + (configurable duration) + +* repr(CoroWrapper) and repr(Task) now contains the current status of the + coroutine (running, done), current filename and line number, and filename and + line number where the object was created + +* Enhance representation (repr) of transports: add the file descriptor, status + (idle, polling, writing, etc.), size of the write buffer, ... + +* Add repr(BaseEventLoop) + +* run_until_complete() doesn't log a warning anymore when called with a + coroutine object which raises an exception. + + +Bugfixes +-------- + +* windows_utils.socketpair() now ensures that sockets are closed in case + of error. + +* Rewrite bricks of the IocpProactor() to make it more reliable + +* IocpProactor destructor now closes it. + +* _OverlappedFuture.set_exception() now cancels the overlapped operation. + +* Rewrite _WaitHandleFuture: + + - cancel() is now able to signal the cancellation to the overlapped object + - _unregister_wait() now catchs and logs exceptions + +* PipeServer.close() (class used on Windows) now cancels the accept pipe + future. + +* Rewrite signal handling in the UNIX implementation of SelectorEventLoop: + use the self-pipe to store pending signals instead of registering a + signal handler calling directly _handle_signal(). The change fixes a + race condition. + +* create_unix_server(): close the socket on error. + +* Fix wait_for() + +* Rewrite gather() + +* drain() is now a classic coroutine, no more special return value (empty + tuple) + +* Rewrite SelectorEventLoop.sock_connect() to handle correctly timeout + +* Process data of the self-pipe faster to accept more pending events, + especially signals written by signal handlers: the callback reads all pending + data, not only a single byte + +* Don't try to set the result of a Future anymore if it was cancelled + (explicitly or by a timeout) + +* CoroWrapper now works around CPython issue #21209: yield from & custom + generator classes don't work together, issue with the send() method. It only + affected asyncio in debug mode on Python older than 3.4.2 + + +Misc changes +------------ + +* windows_utils.socketpair() now supports IPv6. + +* Better documentation (online & docstrings): fill remaining XXX, more examples + +* new asyncio.coroutines submodule, to ease maintenance with the trollius + project: @coroutine, _DEBUG, iscoroutine() and iscoroutinefunction() have + been moved from asyncio.tasks to asyncio.coroutines + +* Cleanup code, ex: remove unused attribute (ex: _rawsock) + +* Reuse os.set_blocking() of Python 3.5. + +* Close explicitly the event loop in Tulip examples. + +* runtests.py now mention if tests are running in release or debug mode. From 8363f1d5e72af29f20b046049659848ccd3ff1cc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 15:02:08 +0100 Subject: [PATCH 1215/1502] Changelog: add also releases --- ChangeLog | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/ChangeLog b/ChangeLog index 25017e73..e483fb2d 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,5 +1,5 @@ -Tulip 3.4.2 -=========== +2014-09-30: Tulip 3.4.2 +======================= New shiny methods like create_task(), better documentation, much better debug mode, better tests. @@ -155,3 +155,31 @@ Misc changes * Close explicitly the event loop in Tulip examples. * runtests.py now mention if tests are running in release or debug mode. + +2014-05-19: Tulip 3.4.1 +======================= + +2014-02-24: Tulip 0.4.1 +======================= + +2014-02-10: Tulip 0.3.1 +======================= + +* Add asyncio.subprocess submodule and the Process class. + +2013-11-25: Tulip 0.2.1 +======================= + +* Add support of subprocesses using transports and protocols. + +2013-10-22: Tulip 0.1.1 +======================= + +* First release. + +Creation of the project +======================= + +* 2013-10-14: The tulip package was renamed to asyncio. +* 2012-10-16: Creation of the Tulip project, started as mail threads on the + python-ideas mailing list. From ec9769dda6c9ccf2c8f9f45db162a7b06caa4d5e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 17:00:36 +0100 Subject: [PATCH 1216/1502] test_support: add missing import --- asyncio/test_support.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index 59b3e8cf..c9df5d07 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -12,6 +12,8 @@ import socket import subprocess import sys +import time + # A constant likely larger than the underlying OS pipe buffer size, to # make writes blocking. From f41e191083ecbfe00fdbe915b360c555072b11bf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 17:01:59 +0100 Subject: [PATCH 1217/1502] Fix release.py for SDK 7.0 * Enable delayed expansion, needed by SDK 7.0 * Make the project name configurable * Restore also console colors --- release.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/release.py b/release.py index f39d1e13..fe56e2c1 100644 --- a/release.py +++ b/release.py @@ -17,6 +17,7 @@ import tempfile import textwrap +PROJECT = 'asyncio' PY3 = (sys.version_info >= (3,)) HG = 'hg' _PYTHON_VERSIONS = [(3, 3)] @@ -136,7 +137,7 @@ def cleanup(self): self.remove_directory('build') self.remove_directory('dist') self.remove_file('_overlapped.pyd') - self.remove_file(os.path.join('asyncio', '_overlapped.pyd')) + self.remove_file(os.path.join(PROJECT, '_overlapped.pyd')) def sdist_upload(self): self.cleanup() @@ -153,8 +154,8 @@ def runtests(self, pyver, bits): else: arch = 'win32' build_dir = 'lib.%s-%s.%s' % (arch, pyver[0], pyver[1]) - src = os.path.join(self.root, 'build', build_dir, 'asyncio', '_overlapped.pyd') - dst = os.path.join(self.root, 'asyncio', '_overlapped.pyd') + src = os.path.join(self.root, 'build', build_dir, PROJECT, '_overlapped.pyd') + dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') shutil.copyfile(src, dst) args = (python, 'runtests.py', '-r') @@ -182,12 +183,15 @@ def build(self, pyver, bits, *cmds): cmd = [python, 'setup.py'] + list(cmds) with tempfile.NamedTemporaryFile(mode="w", suffix=".bat", delete=False) as temp: - print("CD %s" % self.quote(self.root), file=temp) + print("SETLOCAL EnableDelayedExpansion", file=temp) print(self.quote_args(setenv), file=temp) print(BATCH_FAIL_ON_ERROR, file=temp) + # Restore console colors: lightgrey on black + print("COLOR 07", file=temp) print("", file=temp) print("SET DISTUTILS_USE_SDK=1", file=temp) print("SET MSSDK=1", file=temp) + print("CD %s" % self.quote(self.root), file=temp) print(self.quote_args(cmd), file=temp) print(BATCH_FAIL_ON_ERROR, file=temp) From 849d8b1855df65a3fcd7ab59c111e075d46e1328 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 17:29:06 +0100 Subject: [PATCH 1218/1502] Add repr(PipeHandle) --- asyncio/windows_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index 1155a77f..c6e4bc9e 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -134,6 +134,13 @@ class PipeHandle: def __init__(self, handle): self._handle = handle + def __repr__(self): + if self._handle != -1: + handle = 'handle=%r' % self._handle + else: + handle = 'closed' + return '<%s %s>' % (self.__class__.__name__, handle) + @property def handle(self): return self._handle From 754d49a2624a33bc316e65b255fecf4bb8210560 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 18:12:00 +0100 Subject: [PATCH 1219/1502] release.py: make debug env var configurable --- release.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/release.py b/release.py index fe56e2c1..8b5c3bc2 100644 --- a/release.py +++ b/release.py @@ -18,13 +18,13 @@ import textwrap PROJECT = 'asyncio' +DEBUG_ENV_VAR = 'PYTHONASYNCIODEBUG' +PYTHON_VERSIONS = ( + ((3, 3), 32), + ((3, 3), 64), +) PY3 = (sys.version_info >= (3,)) HG = 'hg' -_PYTHON_VERSIONS = [(3, 3)] -PYTHON_VERSIONS = [] -for pyver in _PYTHON_VERSIONS: - PYTHON_VERSIONS.append((pyver, 32)) - PYTHON_VERSIONS.append((pyver, 64)) SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" @@ -41,15 +41,10 @@ def __init__(self): @contextlib.contextmanager def _popen(self, args, **kw): - env2 = kw.pop('env', {}) - env = dict(os.environ) - # Force the POSIX locale - env['LC_ALL'] = 'C' - env.update(env2) print('+ ' + ' '.join(args)) if PY3: kw['universal_newlines'] = True - proc = subprocess.Popen(args, env=env, **kw) + proc = subprocess.Popen(args, **kw) with proc: yield proc @@ -146,7 +141,6 @@ def sdist_upload(self): def runtests(self, pyver, bits): pythonstr = "%s.%s (%s bits)" % (pyver[0], pyver[1], bits) python = self.get_python(pyver, bits) - dbg_env = {'PYTHONASYNCIODEBUG': '1'} self.build(pyver, bits, 'build') if bits == 64: @@ -158,9 +152,15 @@ def runtests(self, pyver, bits): dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') shutil.copyfile(src, dst) + release_env = dict(os.environ) + release_env.pop(DEBUG_ENV_VAR, None) + + dbg_env = dict(os.environ) + dbg_env = {DEBUG_ENV_VAR: '1'} + args = (python, 'runtests.py', '-r') print("Run runtests.py in release mode with %s" % pythonstr) - self.run_command(*args) + self.run_command(*args, env=release_env) print("Run runtests.py in debug mode with %s" % pythonstr) self.run_command(*args, env=dbg_env) @@ -168,7 +168,7 @@ def runtests(self, pyver, bits): if self.aiotest: args = (python, 'run_aiotest.py') print("Run aiotest in release mode with %s" % pythonstr) - self.run_command(*args) + self.run_command(*args, env=release_env) print("Run aiotest in debug mode with %s" % pythonstr) self.run_command(*args, env=dbg_env) From 5921775c6443948d8bb74bee1800c18ee48623ef Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 23:41:46 +0100 Subject: [PATCH 1220/1502] Fix a race condition in BaseSubprocessTransport._try_finish() If the process exited before the _post_init() method was called, scheduling the call to _call_connection_lost() with call_soon() is wrong: connection_made() must be called before connection_lost(). Reuse the BaseSubprocessTransport._call() method to schedule the call to _call_connection_lost() to ensure that connection_made() and connection_lost() are called in the correct order. --- asyncio/base_subprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index d0087793..81698b09 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -153,7 +153,7 @@ def _try_finish(self): if all(p is not None and p.disconnected for p in self._pipes.values()): self._finished = True - self._loop.call_soon(self._call_connection_lost, None) + self._call(self._call_connection_lost, None) def _call_connection_lost(self, exc): try: From 70043a647222c9e7b5fe1bb11ae3e39e3bdd0584 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 18 Dec 2014 23:44:57 +0100 Subject: [PATCH 1221/1502] support: fix assert_python_ok() on python 3.3 Python 3.3 doesn't support the -I command line option (isolated mode). --- asyncio/test_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index c9df5d07..fab1f807 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -40,7 +40,7 @@ def _assert_python(expected_success, *args, **env_vars): else: isolated = not env_vars cmd_line = [sys.executable, '-X', 'faulthandler'] - if isolated: + if isolated and sys.version_info >= (3, 4): # isolated mode: ignore Python environment variables, ignore user # site-packages, and don't add the current directory to sys.path cmd_line.append('-I') From e32105eb32b152c6d3c03620dea83e98e5d256bf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 00:06:44 +0100 Subject: [PATCH 1222/1502] release.py: Fix code to run tests --- release.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/release.py b/release.py index 8b5c3bc2..8c00c284 100644 --- a/release.py +++ b/release.py @@ -156,7 +156,7 @@ def runtests(self, pyver, bits): release_env.pop(DEBUG_ENV_VAR, None) dbg_env = dict(os.environ) - dbg_env = {DEBUG_ENV_VAR: '1'} + dbg_env[DEBUG_ENV_VAR] = '1' args = (python, 'runtests.py', '-r') print("Run runtests.py in release mode with %s" % pythonstr) From 7cdcdfa11b1cc084f9cc9ab0822da5a8a2014da9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 16:58:10 +0100 Subject: [PATCH 1223/1502] IocpProactor.wait_for_handle() test now also checks the result of the future --- asyncio/windows_events.py | 5 +++++ tests/test_windows_events.py | 8 ++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 0773d061..d7feb1ae 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -427,6 +427,11 @@ def finish_connect_pipe(err, handle, ov): return self._register(ov, None, finish_connect_pipe, wait_for_post=True) def wait_for_handle(self, handle, timeout=None): + """Wait for a handle. + + Return a Future object. The result of the future is True if the wait + completed, or False if the wait did not complete (on timeout). + """ if timeout is None: ms = _winapi.INFINITE else: diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index b4d9398f..9b264a64 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -98,8 +98,10 @@ def test_wait_for_handle(self): # result should be False at timeout fut = self.loop._proactor.wait_for_handle(event, 0.5) start = self.loop.time() - self.loop.run_until_complete(fut) + done = self.loop.run_until_complete(fut) elapsed = self.loop.time() - start + + self.assertEqual(done, False) self.assertFalse(fut.result()) self.assertTrue(0.48 < elapsed < 0.9, elapsed) @@ -109,8 +111,10 @@ def test_wait_for_handle(self): # result should be True immediately fut = self.loop._proactor.wait_for_handle(event, 10) start = self.loop.time() - self.loop.run_until_complete(fut) + done = self.loop.run_until_complete(fut) elapsed = self.loop.time() - start + + self.assertEqual(done, True) self.assertTrue(fut.result()) self.assertTrue(0 <= elapsed < 0.3, elapsed) From 023fdeb8d3854de0691cd7b2a6692f3f8fe67fe2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 17:05:28 +0100 Subject: [PATCH 1224/1502] Cleanup runtests.py --- runtests.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/runtests.py b/runtests.py index e9bbdd8e..8cb56fe0 100644 --- a/runtests.py +++ b/runtests.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 """Run Tulip unittests. Usage: @@ -205,6 +206,17 @@ def run(self, test): return result +def _runtests(args, tests): + v = 0 if args.quiet else args.verbose + 1 + runner_factory = TestRunner if args.findleaks else unittest.TextTestRunner + if args.randomize: + randomize_tests(tests, args.seed) + runner = runner_factory(verbosity=v, failfast=args.failfast) + sys.stdout.flush() + sys.stderr.flush() + return runner.run(tests) + + def runtests(): args = ARGS.parse_args() @@ -238,9 +250,6 @@ def runtests(): v = 0 if args.quiet else args.verbose + 1 failfast = args.failfast - catchbreak = args.catchbreak - findleaks = args.findleaks - runner_factory = TestRunner if findleaks else unittest.TextTestRunner if args.coverage: cov = coverage.coverage(branch=True, @@ -262,7 +271,7 @@ def runtests(): logging.basicConfig(level=level) finder = TestsFinder(args.testsdir, includes, excludes) - if catchbreak: + if args.catchbreak: installHandler() import asyncio.coroutines if asyncio.coroutines._DEBUG: @@ -270,21 +279,14 @@ def runtests(): else: print("Run tests in release mode") try: + tests = finder.load_tests() if args.forever: while True: - tests = finder.load_tests() - if args.randomize: - randomize_tests(tests, args.seed) - result = runner_factory(verbosity=v, - failfast=failfast).run(tests) + result = _runtests(args, tests) if not result.wasSuccessful(): sys.exit(1) else: - tests = finder.load_tests() - if args.randomize: - randomize_tests(tests, args.seed) - result = runner_factory(verbosity=v, - failfast=failfast).run(tests) + result = _runtests(args, tests) sys.exit(not result.wasSuccessful()) finally: if args.coverage: From 692066b2c20cb4d08d45ee30631530b0cbb5b824 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 17:07:56 +0100 Subject: [PATCH 1225/1502] Enhance release.py * Only get the python executable once per python version * Much lesser verbose output: by default, hide all logs when building trollius --- release.py | 175 +++++++++++++++++++++++++++++++++-------------------- 1 file changed, 110 insertions(+), 65 deletions(-) diff --git a/release.py b/release.py index 8c00c284..034f2a01 100644 --- a/release.py +++ b/release.py @@ -28,6 +28,47 @@ SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" + +class PythonVersion: + def __init__(self, major, minor, bits): + self.major = major + self.minor = minor + self.bits = bits + self._executable = None + + def get_executable(self, app): + if self._executable: + return self._executable + + if self.bits == 32: + python = 'c:\\Python%s%s_32bit\\python.exe' % (self.major, self.minor) + else: + python = 'c:\\Python%s%s\\python.exe' % (self.major, self.minor) + if not os.path.exists(python): + print("Unable to find python %s" % self) + print("%s does not exists" % python) + sys.exit(1) + code = ( + 'import platform, sys; ' + 'print("{ver.major}.{ver.minor} {bits}".format(' + 'ver=sys.version_info, ' + 'bits=platform.architecture()[0]))' + ) + exitcode, stdout = app.get_output(python, '-c', code) + stdout = stdout.rstrip() + expected = "%s.%s %sbit" % (self.major, self.minor, self.bits) + if stdout != expected: + print("Python version or architecture doesn't match") + print("got %r, expected %r" % (stdout, expected)) + print(python) + sys.exit(1) + self._executable = python + return python + + def __str__(self): + return 'Python %s.%s (%s bits)' % (self.major, self.minor, self.bits) + + class Release(object): def __init__(self): root = os.path.dirname(__file__) @@ -37,11 +78,17 @@ def __init__(self): self.sdist = False self.dry_run = True self.test = True - self.aiotest = True + self.aiotest = False + self.verbose = False + self.python_versions = [ + PythonVersion(pyver[0], pyver[1], bits) + for pyver, bits in PYTHON_VERSIONS] @contextlib.contextmanager def _popen(self, args, **kw): - print('+ ' + ' '.join(args)) + verbose = kw.pop('verbose', True) + if self.verbose and verbose: + print('+ ' + ' '.join(args)) if PY3: kw['universal_newlines'] = True proc = subprocess.Popen(args, **kw) @@ -49,9 +96,11 @@ def _popen(self, args, **kw): yield proc def get_output(self, *args, **kw): - with self._popen(args, stdout=subprocess.PIPE, **kw) as proc: + kw['stdout'] = subprocess.PIPE + kw['stderr'] = subprocess.STDOUT + with self._popen(args, **kw) as proc: stdout, stderr = proc.communicate() - return stdout + return proc.returncode, stdout def run_command(self, *args, **kw): with self._popen(args, **kw) as proc: @@ -60,65 +109,42 @@ def run_command(self, *args, **kw): sys.exit(exitcode) def get_local_changes(self): - status = self.get_output(HG, 'status') + exitcode, status = self.get_output(HG, 'status') return [line for line in status.splitlines() if not line.startswith("?")] def remove_directory(self, name): path = os.path.join(self.root, name) if os.path.exists(path): - print("Remove directory: %s" % name) + if self.verbose: + print("Remove directory: %s" % name) shutil.rmtree(path) def remove_file(self, name): path = os.path.join(self.root, name) if os.path.exists(path): - print("Remove file: %s" % name) + if self.verbose: + print("Remove file: %s" % name) os.unlink(path) - def windows_sdk_setenv(self, pyver, bits): - if pyver >= (3, 3): + def windows_sdk_setenv(self, pyver): + if (pyver.major, pyver.minor) >= (3, 3): sdkver = "v7.1" else: sdkver = "v7.0" setenv = os.path.join(SDK_ROOT, sdkver, 'Bin', 'SetEnv.cmd') if not os.path.exists(setenv): - print("Unable to find Windows SDK %s for Python %s.%s" - % (sdkver, pyver[0], pyver[1])) + print("Unable to find Windows SDK %s for %s" + % (sdkver, pyver)) print("Please download and install it") print("%s does not exists" % setenv) sys.exit(1) - if bits == 64: + if pyver.bits == 64: arch = '/x64' else: arch = '/x86' return ["CALL", setenv, "/release", arch] - def get_python(self, version, bits): - if bits == 32: - python = 'c:\\Python%s%s_32bit\\python.exe' % version - else: - python = 'c:\\Python%s%s\\python.exe' % version - if not os.path.exists(python): - print("Unable to find python%s.%s" % version) - print("%s does not exists" % python) - sys.exit(1) - code = ( - 'import platform, sys; ' - 'print("{ver.major}.{ver.minor} {bits}".format(' - 'ver=sys.version_info, ' - 'bits=platform.architecture()[0]))' - ) - stdout = self.get_output(python, '-c', code) - stdout = stdout.rstrip() - expected = "%s.%s %sbit" % (version[0], version[1], bits) - if stdout != expected: - print("Python version or architecture doesn't match") - print("got %r, expected %r" % (stdout, expected)) - print(python) - sys.exit(1) - return python - def quote(self, arg): if not re.search("[ '\"]", arg): return arg @@ -129,6 +155,8 @@ def quote_args(self, args): return ' '.join(self.quote(arg) for arg in args) def cleanup(self): + if self.verbose: + print("Cleanup") self.remove_directory('build') self.remove_directory('dist') self.remove_file('_overlapped.pyd') @@ -138,16 +166,18 @@ def sdist_upload(self): self.cleanup() self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') - def runtests(self, pyver, bits): - pythonstr = "%s.%s (%s bits)" % (pyver[0], pyver[1], bits) - python = self.get_python(pyver, bits) + def runtests(self, pyver): + print("Run tests on %s" % pyver) + + python = pyver.get_executable(self) - self.build(pyver, bits, 'build') - if bits == 64: + print("Build _overlapped.pyd for %s" % pyver) + self.build(pyver, 'build') + if pyver.bits == 64: arch = 'win-amd64' else: arch = 'win32' - build_dir = 'lib.%s-%s.%s' % (arch, pyver[0], pyver[1]) + build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) src = os.path.join(self.root, 'build', build_dir, PROJECT, '_overlapped.pyd') dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') shutil.copyfile(src, dst) @@ -159,26 +189,27 @@ def runtests(self, pyver, bits): dbg_env[DEBUG_ENV_VAR] = '1' args = (python, 'runtests.py', '-r') - print("Run runtests.py in release mode with %s" % pythonstr) + print("Run runtests.py in release mode on %s" % pyver) self.run_command(*args, env=release_env) - print("Run runtests.py in debug mode with %s" % pythonstr) + print("Run runtests.py in debug mode on %s" % pyver) self.run_command(*args, env=dbg_env) if self.aiotest: args = (python, 'run_aiotest.py') - print("Run aiotest in release mode with %s" % pythonstr) + print("Run aiotest in release mode on %s" % pyver) self.run_command(*args, env=release_env) - print("Run aiotest in debug mode with %s" % pythonstr) + print("Run aiotest in debug mode on %s" % pyver) self.run_command(*args, env=dbg_env) + print("") - def build(self, pyver, bits, *cmds): + def build(self, pyver, *cmds): self.cleanup() - setenv = self.windows_sdk_setenv(pyver, bits) + setenv = self.windows_sdk_setenv(pyver) - python = self.get_python(pyver, bits) + python = pyver.get_executable(self) cmd = [python, 'setup.py'] + list(cmds) @@ -196,15 +227,25 @@ def build(self, pyver, bits, *cmds): print(BATCH_FAIL_ON_ERROR, file=temp) try: - self.run_command(temp.name) + if self.verbose: + print("Setup Windows SDK") + print("+ " + ' '.join(cmd)) + if self.verbose: + self.run_command(temp.name, verbose=False) + else: + exitcode, stdout = self.get_output(temp.name, verbose=False) + if exitcode: + sys.stdout.write(stdout) + sys.stdout.flush() finally: os.unlink(temp.name) - def test_wheel(self, pyver, bits): - self.build(pyver, bits, 'bdist_wheel') + def test_wheel(self, pyver): + print("Test building wheel package for %s" % pyver) + self.build(pyver, 'bdist_wheel') - def publish_wheel(self, pyver, bits): - self.build(pyver, bits, 'bdist_wheel', 'upload') + def publish_wheel(self, pyver): + self.build(pyver, 'bdist_wheel', 'upload') def main(self): try: @@ -235,14 +276,19 @@ def main(self): sys.exit(1) hg_tag = sys.argv[1] - self.run_command(HG, 'up', hg_tag) + print("Update repository to revision %s" % hg_tag) + exitcode, output = self.get_output(HG, 'update', hg_tag) + if exitcode: + sys.stdout.write(output) + sys.stdout.flush() + sys.exit(exitcode) if self.test: - for pyver, bits in PYTHON_VERSIONS: - self.runtests(pyver, bits) + for pyver in self.python_versions: + self.runtests(pyver) - for pyver, bits in PYTHON_VERSIONS: - self.test_wheel(pyver, bits) + for pyver in self.python_versions: + self.test_wheel(pyver) if self.dry_run: sys.exit(0) @@ -253,8 +299,8 @@ def main(self): if self.sdist: self.sdist_upload() - for pyver, bits in PYTHON_VERSIONS: - self.publish_wheel(pyver, bits) + for pyver in self.python_versions: + self.publish_wheel(pyver) print("") if self.register: @@ -262,9 +308,8 @@ def main(self): print("Uploaded:") if self.sdist: print("- sdist") - for pyver, bits in PYTHON_VERSIONS: - print("- Windows wheel %s bits package for Python %s.%s" - % (bits, pyver[0], pyver[1])) + for pyver in self.python_versions: + print("- Windows wheel package for %s" % pyver) if __name__ == "__main__": Release().main() From ee53e225562219bf4b29739eb7faefc58db9f33f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 17:45:34 +0100 Subject: [PATCH 1226/1502] Fix release.py with SDK 7.1 --- release.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/release.py b/release.py index 034f2a01..f4879706 100644 --- a/release.py +++ b/release.py @@ -129,13 +129,15 @@ def remove_file(self, name): def windows_sdk_setenv(self, pyver): if (pyver.major, pyver.minor) >= (3, 3): - sdkver = "v7.1" + path = "v7.1" + sdkver = (7, 1) else: - sdkver = "v7.0" - setenv = os.path.join(SDK_ROOT, sdkver, 'Bin', 'SetEnv.cmd') + path = "v7.0" + sdkver = (7, 0) + setenv = os.path.join(SDK_ROOT, path, 'Bin', 'SetEnv.cmd') if not os.path.exists(setenv): - print("Unable to find Windows SDK %s for %s" - % (sdkver, pyver)) + print("Unable to find Windows SDK %s.%s for %s" + % (sdkver[0], sdkver[1], pyver)) print("Please download and install it") print("%s does not exists" % setenv) sys.exit(1) @@ -143,7 +145,8 @@ def windows_sdk_setenv(self, pyver): arch = '/x64' else: arch = '/x86' - return ["CALL", setenv, "/release", arch] + cmd = ["CALL", setenv, "/release", arch] + return (cmd, sdkver) def quote(self, arg): if not re.search("[ '\"]", arg): @@ -207,7 +210,7 @@ def runtests(self, pyver): def build(self, pyver, *cmds): self.cleanup() - setenv = self.windows_sdk_setenv(pyver) + setenv, sdkver = self.windows_sdk_setenv(pyver) python = pyver.get_executable(self) @@ -230,7 +233,9 @@ def build(self, pyver, *cmds): if self.verbose: print("Setup Windows SDK") print("+ " + ' '.join(cmd)) - if self.verbose: + # SDK 7.1 uses the COLOR command which makes SetEnv.cmd failing + # if the stdout is not a TTY (if we redirect stdout into a file) + if self.verbose or sdkver >= (7, 1): self.run_command(temp.name, verbose=False) else: exitcode, stdout = self.get_output(temp.name, verbose=False) From cd4007814d635e1da805bd353afbff30942e0892 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 18:20:37 +0100 Subject: [PATCH 1227/1502] release.py: better command line interface Commands: build, test, test_wheel, release, clean --- release.py | 231 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 183 insertions(+), 48 deletions(-) diff --git a/release.py b/release.py index f4879706..8a1f6a1d 100644 --- a/release.py +++ b/release.py @@ -4,12 +4,16 @@ Usage: "python release.py HG_TAG" where HG_TAG is a Mercurial tag, usually a version number like "3.4.2". -Modify manually the dry_run attribute to upload files. +Requirements: -It requires the Windows SDK 7.1 on Windows 64 bits and the aiotest module. +- Python 3.3 and newer requires the Windows SDK 7.1 to build wheel packages +- Python 2.7 requires the Windows SDK 7.0 +- the aiotest module is required to run aiotest tests """ import contextlib +import optparse import os +import platform import re import shutil import subprocess @@ -36,6 +40,16 @@ def __init__(self, major, minor, bits): self.bits = bits self._executable = None + @staticmethod + def running(): + arch = platform.architecture()[0] + bits = int(arch[:2]) + pyver = PythonVersion(sys.version_info.major, + sys.version_info.minor, + bits) + pyver._executable = sys.executable + return pyver + def get_executable(self, app): if self._executable: return self._executable @@ -74,12 +88,15 @@ def __init__(self): root = os.path.dirname(__file__) self.root = os.path.realpath(root) # Set these attributes to True to run also register sdist upload + self.wheel = False + self.test = False self.register = False self.sdist = False - self.dry_run = True - self.test = True self.aiotest = False self.verbose = False + self.upload = False + # Release mode: enable more tests + self.release = False self.python_versions = [ PythonVersion(pyver[0], pyver[1], bits) for pyver, bits in PYTHON_VERSIONS] @@ -102,6 +119,14 @@ def get_output(self, *args, **kw): stdout, stderr = proc.communicate() return proc.returncode, stdout + def check_output(self, *args, **kw): + exitcode, output = self.get_output(*args, **kw) + if exitcode: + sys.stdout.write(output) + sys.stdout.flush() + sys.exit(1) + return output + def run_command(self, *args, **kw): with self._popen(args, **kw) as proc: exitcode = proc.wait() @@ -109,7 +134,7 @@ def run_command(self, *args, **kw): sys.exit(exitcode) def get_local_changes(self): - exitcode, status = self.get_output(HG, 'status') + status = self.check_output(HG, 'status') return [line for line in status.splitlines() if not line.startswith("?")] @@ -169,11 +194,7 @@ def sdist_upload(self): self.cleanup() self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') - def runtests(self, pyver): - print("Run tests on %s" % pyver) - - python = pyver.get_executable(self) - + def build_inplace(self, pyver): print("Build _overlapped.pyd for %s" % pyver) self.build(pyver, 'build') if pyver.bits == 64: @@ -185,23 +206,34 @@ def runtests(self, pyver): dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') shutil.copyfile(src, dst) + def runtests(self, pyver): + print("Run tests on %s" % pyver) + + if not self.options.no_compile: + self.build_inplace(pyver) + release_env = dict(os.environ) release_env.pop(DEBUG_ENV_VAR, None) dbg_env = dict(os.environ) dbg_env[DEBUG_ENV_VAR] = '1' + python = pyver.get_executable(self) args = (python, 'runtests.py', '-r') - print("Run runtests.py in release mode on %s" % pyver) - self.run_command(*args, env=release_env) + + if self.release: + print("Run runtests.py in release mode on %s" % pyver) + self.run_command(*args, env=release_env) print("Run runtests.py in debug mode on %s" % pyver) self.run_command(*args, env=dbg_env) if self.aiotest: args = (python, 'run_aiotest.py') - print("Run aiotest in release mode on %s" % pyver) - self.run_command(*args, env=release_env) + + if self.release: + print("Run aiotest in release mode on %s" % pyver) + self.run_command(*args, env=release_env) print("Run aiotest in debug mode on %s" % pyver) self.run_command(*args, env=dbg_env) @@ -231,17 +263,14 @@ def build(self, pyver, *cmds): try: if self.verbose: - print("Setup Windows SDK") + print("Setup Windows SDK %s.%s" % sdkver) print("+ " + ' '.join(cmd)) # SDK 7.1 uses the COLOR command which makes SetEnv.cmd failing # if the stdout is not a TTY (if we redirect stdout into a file) if self.verbose or sdkver >= (7, 1): self.run_command(temp.name, verbose=False) else: - exitcode, stdout = self.get_output(temp.name, verbose=False) - if exitcode: - sys.stdout.write(stdout) - sys.stdout.flush() + self.check_output(temp.name, verbose=False) finally: os.unlink(temp.name) @@ -252,22 +281,114 @@ def test_wheel(self, pyver): def publish_wheel(self, pyver): self.build(pyver, 'bdist_wheel', 'upload') - def main(self): - try: - pos = sys.argv[1:].index('--ignore') - except ValueError: - ignore = False + def parse_options(self): + parser = optparse.OptionParser( + description="Run all unittests.", + usage="%prog [options] command") + parser.add_option( + '-v', '--verbose', action="store_true", dest='verbose', + default=0, help='verbose') + parser.add_option( + '-t', '--tag', type="str", + help='Mercurial tag or revision, required to release') + parser.add_option( + '-p', '--python', type="str", + help='Only build/test one specific Python version, ex: "2.7:32"') + parser.add_option( + '-C', "--no-compile", action="store_true", + help="Don't compile the module, this options implies --running", + default=False) + parser.add_option( + '-r', "--running", action="store_true", + help='Only use the running Python version', + default=False) + parser.add_option( + '--ignore', action="store_true", + help='Ignore local changes', + default=False) + self.options, args = parser.parse_args() + if len(args) == 1: + command = args[0] + else: + command = None + + if self.options.no_compile: + self.options.running = True + + if command == 'clean': + self.options.verbose = True + elif command == 'build': + self.options.running = True + elif command == 'test_wheel': + self.wheel = True + elif command == 'test': + self.test = True + elif command == 'release': + if not self.options.tag: + print("The release command requires the --tag option") + sys.exit(1) + + self.release = True + self.wheel = True + self.test = True + #self.upload = True else: - ignore = True - del sys.argv[1+pos] - if len(sys.argv) != 2: - print("usage: %s hg_tag" % sys.argv[0]) + if command: + print("Invalid command: %s" % command) + else: + parser.print_usage() + + print("Available commands:") + print("- build: build asyncio in place, imply --running") + print("- test: run tests") + print("- test_wheel: test building wheel packages") + print("- release: run tests and publish wheel packages,") + print(" require the --tag option") + print("- clean: cleanup the project") sys.exit(1) + if self.options.python and self.options.running: + print("--python and --running options are exclusive") + sys.exit(1) + + python = self.options.python + if python: + match = re.match("^([23])\.([0-9])/(32|64)$", python) + if not match: + print("Invalid Python version: %s" % python) + print('Format of a Python version: "x.y/bits"') + print("Example: 2.7/32") + sys.exit(1) + major = int(match.group(1)) + minor = int(match.group(2)) + bits = int(match.group(3)) + self.python_versions = [PythonVersion(major, minor, bits)] + + if self.options.running: + self.python_versions = [PythonVersion.running()] + + self.verbose = self.options.verbose + self.command = command + + def main(self): + self.parse_options() + print("Directory: %s" % self.root) os.chdir(self.root) - if not ignore: + if self.command == "clean": + self.cleanup() + sys.exit(1) + + if self.command == "build": + if len(self.python_versions) != 1: + print("build command requires one specific Python version") + print("Use the --python command line option") + sys.exit(1) + pyver = self.python_versions[0] + self.build_inplace(pyver) + + if (self.register or self.upload) and (not self.options.ignore): lines = self.get_local_changes() else: lines = () @@ -280,41 +401,55 @@ def main(self): print("or use the --ignore command line option") sys.exit(1) - hg_tag = sys.argv[1] - print("Update repository to revision %s" % hg_tag) - exitcode, output = self.get_output(HG, 'update', hg_tag) - if exitcode: - sys.stdout.write(output) - sys.stdout.flush() - sys.exit(exitcode) + hg_tag = self.options.tag + if hg_tag: + print("Update repository to revision %s" % hg_tag) + self.check_output(HG, 'update', hg_tag) + + hg_rev = self.check_output(HG, 'id').rstrip() + + if self.wheel: + for pyver in self.python_versions: + self.test_wheel(pyver) if self.test: for pyver in self.python_versions: self.runtests(pyver) - for pyver in self.python_versions: - self.test_wheel(pyver) - - if self.dry_run: - sys.exit(0) - if self.register: self.run_command(sys.executable, 'setup.py', 'register') if self.sdist: self.sdist_upload() - for pyver in self.python_versions: - self.publish_wheel(pyver) + if self.upload: + for pyver in self.python_versions: + self.publish_wheel(pyver) + + hg_rev2 = self.check_output(HG, 'id').rstrip() + if hg_rev != hg_rev2: + print("ERROR: The Mercurial revision changed") + print("Before: %s" % hg_rev) + print("After: %s" % hg_rev2) + sys.exit(1) print("") + print("Mercurial revision: %s" % hg_rev) + if self.command == 'build': + print("Inplace compilation done") + if self.wheel: + print("Compilation of wheel packages succeeded") + if self.test: + print("Tests succeeded") if self.register: - print("Publish version %s" % hg_tag) - print("Uploaded:") + print("Project registered on the Python cheeseshop (PyPI)") if self.sdist: - print("- sdist") + print("Project source code uploaded to the Python cheeseshop (PyPI)") + if self.upload: + print("Wheel packages uploaded to the Python cheeseshop (PyPI)") for pyver in self.python_versions: - print("- Windows wheel package for %s" % pyver) + print("- %s" % pyver) + if __name__ == "__main__": Release().main() From 2b06827fbbd4300c914f098eb5b0172e6c8d78ca Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 22:18:09 +0100 Subject: [PATCH 1228/1502] asyncio.test_support now uses test.support and test.script_helper if available --- asyncio/test_support.py | 12 ++++++++++++ tests/test_base_events.py | 21 ++++++++------------- tests/test_events.py | 5 +---- tests/test_futures.py | 5 +---- tests/test_subprocess.py | 5 +---- tests/test_tasks.py | 29 ++++++++++++----------------- tests/test_windows_utils.py | 8 ++------ 7 files changed, 37 insertions(+), 48 deletions(-) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index fab1f807..336f3acf 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -290,3 +290,15 @@ def requires_freebsd_version(*min_version): version is less than 7.2. """ return _requires_unix_version('FreeBSD', min_version) + +# Use test.support if available +try: + from test.support import * +except ImportError: + pass + +# Use test.script_helper if available +try: + from test.script_helper import assert_python_ok +except ImportError: + pass diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 4e5b6ca9..afb65b29 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -12,13 +12,8 @@ import asyncio from asyncio import base_events from asyncio import constants +from asyncio import test_support as support from asyncio import test_utils -try: - from test.script_helper import assert_python_ok - from test import support -except ImportError: - from asyncio import test_support as support - from asyncio.test_support import assert_python_ok MOCK_ANY = mock.ANY @@ -584,19 +579,19 @@ def test_env_var_debug(self): # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = assert_python_ok('-E', '-c', code) + sts, stdout, stderr = support.assert_python_ok('-E', '-c', code) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='') + sts, stdout, stderr = support.assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1') + sts, stdout, stderr = support.assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1') + sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') self.assertEqual(stdout.rstrip(), b'False') def test_create_task(self): diff --git a/tests/test_events.py b/tests/test_events.py index 06302920..04dc8805 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,11 +25,8 @@ import asyncio from asyncio import proactor_events from asyncio import selector_events +from asyncio import test_support as support from asyncio import test_utils -try: - from test import support # find_unused_port, IPV6_ENABLED, TEST_HOME_DIR -except ImportError: - from asyncio import test_support as support def data_file(filename): diff --git a/tests/test_futures.py b/tests/test_futures.py index f9c3ad20..7c564624 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -8,11 +8,8 @@ from unittest import mock import asyncio +from asyncio import test_support as support from asyncio import test_utils -try: - from test import support # gc_collect -except ImportError: - from asyncio import test_support as support def _fakefunc(f): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 08c8ac24..5cb0f030 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -5,13 +5,10 @@ import asyncio from asyncio import subprocess +from asyncio import test_support as support from asyncio import test_utils if sys.platform != 'win32': from asyncio import unix_events -try: - from test import support # PIPE_MAX_SIZE -except ImportError: - from asyncio import test_support as support # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 25b21dc5..c0053668 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -7,15 +7,10 @@ import unittest import weakref from unittest import mock -try: - from test import support # gc_collect - from test.script_helper import assert_python_ok -except ImportError: - from asyncio import test_support as support - from asyncio.test_support import assert_python_ok import asyncio from asyncio import coroutines +from asyncio import test_support as support from asyncio import test_utils @@ -1781,23 +1776,23 @@ def test_env_var_debug(self): # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='', - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index b9579491..92db24e6 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -5,18 +5,14 @@ import unittest from unittest import mock -try: - from test import support # gc_collect, IPV6_ENABLED -except ImportError: - from asyncio import test_support as support - if sys.platform != 'win32': raise unittest.SkipTest('Windows only') import _winapi -from asyncio import windows_utils from asyncio import _overlapped +from asyncio import test_support as support +from asyncio import windows_utils class WinsocketpairTests(unittest.TestCase): From c5a51e6ce648bf1fd54b2a8d6d7249df37804627 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 22:20:44 +0100 Subject: [PATCH 1229/1502] Fix pyflakes warnings: remove unused imports and variables --- tests/test_selector_events.py | 3 --- tests/test_selectors.py | 4 ++-- tests/test_unix_events.py | 2 -- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 8eba56c4..ff114f82 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1,10 +1,7 @@ """Tests for selector_events.py""" import errno -import gc -import pprint import socket -import sys import unittest from unittest import mock try: diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 93929626..d91c78b1 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -160,11 +160,11 @@ def test_modify_data_use_a_shortcut(self): d2 = object() s = FakeSelector() - key = s.register(fobj, selectors.EVENT_READ, d1) + s.register(fobj, selectors.EVENT_READ, d1) s.unregister = mock.Mock() s.register = mock.Mock() - key2 = s.modify(fobj, selectors.EVENT_READ, d2) + s.modify(fobj, selectors.EVENT_READ, d2) self.assertFalse(s.unregister.called) self.assertFalse(s.register.called) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index b6ad0189..4b825dc8 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1,11 +1,9 @@ """Tests for unix_events.py.""" import collections -import gc import errno import io import os -import pprint import signal import socket import stat From b5f24d82b7b59f9eaa653bbfb59e0ce700f467d3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 22:45:32 +0100 Subject: [PATCH 1230/1502] setup.py: fix ResourceWarning --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 87a629dc..2581bfda 100644 --- a/setup.py +++ b/setup.py @@ -25,12 +25,15 @@ ) extensions.append(ext) +with open("README") as fp: + long_description = fp.read() + setup( name="asyncio", version="3.4.3", description="reference implementation of PEP 3156", - long_description=open("README").read(), + long_description=long_description, url="http://www.python.org/dev/peps/pep-3156/", classifiers=[ From c2de37e68791a2d9dc8871f32b378ef45deaf95f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 22:46:37 +0100 Subject: [PATCH 1231/1502] Port release.py to Python 2 and to UNIX --- release.py | 171 +++++++++++++++++++++++++++++++++++------------------ 1 file changed, 114 insertions(+), 57 deletions(-) diff --git a/release.py b/release.py index 8a1f6a1d..c43822b0 100644 --- a/release.py +++ b/release.py @@ -24,13 +24,18 @@ PROJECT = 'asyncio' DEBUG_ENV_VAR = 'PYTHONASYNCIODEBUG' PYTHON_VERSIONS = ( - ((3, 3), 32), - ((3, 3), 64), + (3, 3), ) PY3 = (sys.version_info >= (3,)) HG = 'hg' SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" +WINDOWS = (sys.platform == 'win32') + + +def get_archiecture_bits(): + arch = platform.architecture()[0] + return int(arch[:2]) class PythonVersion: @@ -42,42 +47,62 @@ def __init__(self, major, minor, bits): @staticmethod def running(): - arch = platform.architecture()[0] - bits = int(arch[:2]) + bits = get_archiecture_bits() pyver = PythonVersion(sys.version_info.major, sys.version_info.minor, bits) pyver._executable = sys.executable return pyver + def _get_executable_windows(self, app): + if self.bits == 32: + executable = 'c:\\Python%s%s_32bit\\python.exe' + else: + executable = 'c:\\Python%s%s\\python.exe' + executable = executable % (self.major, self.minor) + if not os.path.exists(executable): + print("Unable to find python %s" % self) + print("%s does not exists" % executable) + sys.exit(1) + return executable + + def _get_executable_unix(self, app): + return 'python%s.%s' % (self.major, self.minor) + def get_executable(self, app): if self._executable: return self._executable - if self.bits == 32: - python = 'c:\\Python%s%s_32bit\\python.exe' % (self.major, self.minor) + if WINDOWS: + executable = self._get_executable_windows(app) else: - python = 'c:\\Python%s%s\\python.exe' % (self.major, self.minor) - if not os.path.exists(python): - print("Unable to find python %s" % self) - print("%s does not exists" % python) - sys.exit(1) + executable = self._get_executable_unix(app) + code = ( 'import platform, sys; ' 'print("{ver.major}.{ver.minor} {bits}".format(' 'ver=sys.version_info, ' 'bits=platform.architecture()[0]))' ) - exitcode, stdout = app.get_output(python, '-c', code) - stdout = stdout.rstrip() - expected = "%s.%s %sbit" % (self.major, self.minor, self.bits) - if stdout != expected: - print("Python version or architecture doesn't match") - print("got %r, expected %r" % (stdout, expected)) - print(python) + try: + exitcode, stdout = app.get_output(executable, '-c', code, + ignore_stderr=True) + except OSError as exc: + print("Error while checking %s:" % self) + print(str(exc)) + print("Executable: %s" % executable) sys.exit(1) - self._executable = python - return python + else: + stdout = stdout.rstrip() + expected = "%s.%s %sbit" % (self.major, self.minor, self.bits) + if stdout != expected: + print("Python version or architecture doesn't match") + print("got %r, expected %r" % (stdout, expected)) + print("Executable: %s" % executable) + sys.exit(1) + + self._executable = executable + return executable def __str__(self): return 'Python %s.%s (%s bits)' % (self.major, self.minor, self.bits) @@ -97,9 +122,16 @@ def __init__(self): self.upload = False # Release mode: enable more tests self.release = False - self.python_versions = [ - PythonVersion(pyver[0], pyver[1], bits) - for pyver, bits in PYTHON_VERSIONS] + self.python_versions = [] + if WINDOWS: + supported_archs = (32, 64) + else: + bits = get_archiecture_bits() + supported_archs = (bits,) + for major, minor in PYTHON_VERSIONS: + for bits in supported_archs: + pyver = PythonVersion(major, minor, bits) + self.python_versions.append(pyver) @contextlib.contextmanager def _popen(self, args, **kw): @@ -109,15 +141,28 @@ def _popen(self, args, **kw): if PY3: kw['universal_newlines'] = True proc = subprocess.Popen(args, **kw) - with proc: + try: yield proc + except: + proc.kill() + proc.wait() + raise def get_output(self, *args, **kw): kw['stdout'] = subprocess.PIPE - kw['stderr'] = subprocess.STDOUT - with self._popen(args, **kw) as proc: - stdout, stderr = proc.communicate() - return proc.returncode, stdout + ignore_stderr = kw.pop('ignore_stderr', False) + if ignore_stderr: + devnull = open(os.path.devnull, 'wb') + kw['stderr'] = devnull + else: + kw['stderr'] = subprocess.STDOUT + try: + with self._popen(args, **kw) as proc: + stdout, stderr = proc.communicate() + return proc.returncode, stdout + finally: + if ignore_stderr: + devnull.close() def check_output(self, *args, **kw): exitcode, output = self.get_output(*args, **kw) @@ -195,21 +240,23 @@ def sdist_upload(self): self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') def build_inplace(self, pyver): - print("Build _overlapped.pyd for %s" % pyver) + print("Build for %s" % pyver) self.build(pyver, 'build') - if pyver.bits == 64: - arch = 'win-amd64' - else: - arch = 'win32' - build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) - src = os.path.join(self.root, 'build', build_dir, PROJECT, '_overlapped.pyd') - dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') - shutil.copyfile(src, dst) + + if WINDOWS: + if pyver.bits == 64: + arch = 'win-amd64' + else: + arch = 'win32' + build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) + src = os.path.join(self.root, 'build', build_dir, PROJECT, '_overlapped.pyd') + dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') + shutil.copyfile(src, dst) def runtests(self, pyver): print("Run tests on %s" % pyver) - if not self.options.no_compile: + if WINDOWS and not self.options.no_compile: self.build_inplace(pyver) release_env = dict(os.environ) @@ -239,27 +286,21 @@ def runtests(self, pyver): self.run_command(*args, env=dbg_env) print("") - def build(self, pyver, *cmds): - self.cleanup() - + def _build_windows(self, pyver, cmd): setenv, sdkver = self.windows_sdk_setenv(pyver) - python = pyver.get_executable(self) - - cmd = [python, 'setup.py'] + list(cmds) - with tempfile.NamedTemporaryFile(mode="w", suffix=".bat", delete=False) as temp: - print("SETLOCAL EnableDelayedExpansion", file=temp) - print(self.quote_args(setenv), file=temp) - print(BATCH_FAIL_ON_ERROR, file=temp) + temp.write("SETLOCAL EnableDelayedExpansion\n") + temp.write(self.quote_args(setenv) + "\n") + temp.write(BATCH_FAIL_ON_ERROR + "\n") # Restore console colors: lightgrey on black - print("COLOR 07", file=temp) - print("", file=temp) - print("SET DISTUTILS_USE_SDK=1", file=temp) - print("SET MSSDK=1", file=temp) - print("CD %s" % self.quote(self.root), file=temp) - print(self.quote_args(cmd), file=temp) - print(BATCH_FAIL_ON_ERROR, file=temp) + temp.write("COLOR 07\n") + temp.write("\n") + temp.write("SET DISTUTILS_USE_SDK=1\n") + temp.write("SET MSSDK=1\n") + temp.write("CD %s\n" % self.quote(self.root)) + temp.write(self.quote_args(cmd) + "\n") + temp.write(BATCH_FAIL_ON_ERROR + "\n") try: if self.verbose: @@ -274,12 +315,28 @@ def build(self, pyver, *cmds): finally: os.unlink(temp.name) + def _build_unix(self, pyver, cmd): + self.check_output(*cmd) + + def build(self, pyver, *cmds): + self.cleanup() + + python = pyver.get_executable(self) + cmd = [python, 'setup.py'] + list(cmds) + + if WINDOWS: + self._build_windows(pyver, cmd) + else: + self._build_unix(pyver, cmd) + def test_wheel(self, pyver): print("Test building wheel package for %s" % pyver) self.build(pyver, 'bdist_wheel') def publish_wheel(self, pyver): - self.build(pyver, 'bdist_wheel', 'upload') + # FIXME: really upload + #self.build(pyver, 'bdist_wheel', 'upload') + self.build(pyver, 'bdist_wheel') def parse_options(self): parser = optparse.OptionParser( @@ -331,7 +388,7 @@ def parse_options(self): self.release = True self.wheel = True self.test = True - #self.upload = True + self.upload = True else: if command: print("Invalid command: %s" % command) From 187d2a5d39de6895f533e3c0ba3ae368f9e3f477 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 23:24:26 +0100 Subject: [PATCH 1232/1502] MANIFEST.in: add release.py --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 314325c8..cb0cb08e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ include AUTHORS COPYING include Makefile include overlapped.c pypi.bat -include check.py runtests.py +include check.py runtests.py release.py include update_stdlib.sh recursive-include examples *.py From effa09e0796d93e105459e2b1a1716cbb2e13894 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Dec 2014 23:25:06 +0100 Subject: [PATCH 1233/1502] release.py register now upload for real publish_wheel() is now also more verbose --- release.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/release.py b/release.py index c43822b0..ec6267f2 100644 --- a/release.py +++ b/release.py @@ -334,9 +334,8 @@ def test_wheel(self, pyver): self.build(pyver, 'bdist_wheel') def publish_wheel(self, pyver): - # FIXME: really upload - #self.build(pyver, 'bdist_wheel', 'upload') - self.build(pyver, 'bdist_wheel') + print("Build wheel package for %s" % pyver) + self.build(pyver, 'bdist_wheel', 'upload') def parse_options(self): parser = optparse.OptionParser( From 1d5832923d8312b3c4b62ac6fb3b7e0ee35d6122 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 20 Dec 2014 14:17:06 +0100 Subject: [PATCH 1234/1502] release.py: fix the message of publish_wheel() --- release.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/release.py b/release.py index ec6267f2..848e3169 100644 --- a/release.py +++ b/release.py @@ -334,7 +334,7 @@ def test_wheel(self, pyver): self.build(pyver, 'bdist_wheel') def publish_wheel(self, pyver): - print("Build wheel package for %s" % pyver) + print("Build and publish wheel package for %s" % pyver) self.build(pyver, 'bdist_wheel', 'upload') def parse_options(self): From 44fd65413693d4865cc9f7c9c3dbdb887d9b4b76 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 26 Dec 2014 21:03:44 +0100 Subject: [PATCH 1235/1502] Fix doc of get and put methods of Queue --- asyncio/queues.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 41551a90..8f6c2577 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -111,8 +111,10 @@ def full(self): def put(self, item): """Put an item into the queue. - If you yield from put(), wait until a free slot is available - before adding item. + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + This method is a coroutine. """ self._consume_done_getters() if self._getters: @@ -161,7 +163,9 @@ def put_nowait(self, item): def get(self): """Remove and return an item from the queue. - If you yield from get(), wait until a item is available. + If queue is empty, wait until an item is available. + + This method is a coroutine. """ self._consume_done_putters() if self._putters: From 5d9276ffb0ac4eba272080821790edf975cbc1f1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 26 Dec 2014 21:00:05 +0100 Subject: [PATCH 1236/1502] Python issue #22926: In debug mode, call_soon(), call_at() and call_later() methods of BaseEventLoop now use the identifier of the current thread to ensure that they are called from the thread running the event loop. Before, the get_event_loop() method was used to check the thread, and no exception was raised when the thread had no event loop. Now the methods always raise an exception in debug mode when called from the wrong thread. It should help to notice misusage of the API. --- asyncio/base_events.py | 38 +++++++++--------- asyncio/proactor_events.py | 6 +-- asyncio/selector_events.py | 2 +- tests/test_base_events.py | 76 +++++++++++++++++++++++++++-------- tests/test_proactor_events.py | 7 ++-- tests/test_subprocess.py | 22 ++-------- 6 files changed, 89 insertions(+), 62 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index b1a5422b..684c9ecd 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -22,6 +22,7 @@ import os import socket import subprocess +import threading import time import traceback import sys @@ -168,7 +169,9 @@ def __init__(self): self._scheduled = [] self._default_executor = None self._internal_fds = 0 - self._running = False + # Identifier of the thread running the event loop, or None if the + # event loop is not running + self._owner = None self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None self._debug = (not sys.flags.ignore_environment @@ -246,9 +249,9 @@ def _check_closed(self): def run_forever(self): """Run until stop() is called.""" self._check_closed() - if self._running: + if self.is_running(): raise RuntimeError('Event loop is running.') - self._running = True + self._owner = threading.get_ident() try: while True: try: @@ -256,7 +259,7 @@ def run_forever(self): except _StopError: break finally: - self._running = False + self._owner = None def run_until_complete(self, future): """Run until the Future is done. @@ -311,7 +314,7 @@ def close(self): The event loop must not be running. """ - if self._running: + if self.is_running(): raise RuntimeError("Cannot close a running event loop") if self._closed: return @@ -331,7 +334,7 @@ def is_closed(self): def is_running(self): """Returns True if the event loop is running.""" - return self._running + return (self._owner is not None) def time(self): """Return the time according to the event loop's clock. @@ -373,7 +376,7 @@ def call_at(self, when, callback, *args): raise TypeError("coroutines cannot be used with call_at()") self._check_closed() if self._debug: - self._assert_is_current_event_loop() + self._check_thread() timer = events.TimerHandle(when, callback, args, self) if timer._source_traceback: del timer._source_traceback[-1] @@ -391,17 +394,17 @@ def call_soon(self, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - handle = self._call_soon(callback, args, check_loop=True) + if self._debug: + self._check_thread() + handle = self._call_soon(callback, args) if handle._source_traceback: del handle._source_traceback[-1] return handle - def _call_soon(self, callback, args, check_loop): + def _call_soon(self, callback, args): if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used with call_soon()") - if self._debug and check_loop: - self._assert_is_current_event_loop() self._check_closed() handle = events.Handle(callback, args, self) if handle._source_traceback: @@ -409,8 +412,8 @@ def _call_soon(self, callback, args, check_loop): self._ready.append(handle) return handle - def _assert_is_current_event_loop(self): - """Asserts that this event loop is the current event loop. + def _check_thread(self): + """Check that the current thread is the thread running the event loop. Non-thread-safe methods of this class make this assumption and will likely behave incorrectly when the assumption is violated. @@ -418,18 +421,17 @@ def _assert_is_current_event_loop(self): Should only be called when (self._debug == True). The caller is responsible for checking this condition for performance reasons. """ - try: - current = events.get_event_loop() - except RuntimeError: + if self._owner is None: return - if current is not self: + thread_id = threading.get_ident() + if thread_id != self._owner: raise RuntimeError( "Non-thread-safe operation invoked on an event loop other " "than the current one") def call_soon_threadsafe(self, callback, *args): """Like call_soon(), but thread-safe.""" - handle = self._call_soon(callback, args, check_loop=False) + handle = self._call_soon(callback, args) if handle._source_traceback: del handle._source_traceback[-1] self._write_to_self() diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index e67cf65a..44a81975 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -383,7 +383,7 @@ def _make_write_pipe_transport(self, sock, protocol, waiter=None, sock, protocol, waiter, extra) def close(self): - if self._running: + if self.is_running(): raise RuntimeError("Cannot close a running event loop") if self.is_closed(): return @@ -432,9 +432,7 @@ def _make_self_pipe(self): self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 - # don't check the current loop because _make_self_pipe() is called - # from the event loop constructor - self._call_soon(self._loop_self_reading, (), check_loop=False) + self.call_soon(self._loop_self_reading) def _loop_self_reading(self, f=None): try: diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7df8b866..a97709d8 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -68,7 +68,7 @@ def _make_datagram_transport(self, sock, protocol, address, waiter, extra) def close(self): - if self._running: + if self.is_running(): raise RuntimeError("Cannot close a running event loop") if self.is_closed(): return diff --git a/tests/test_base_events.py b/tests/test_base_events.py index afb65b29..5906fb73 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -5,6 +5,7 @@ import math import socket import sys +import threading import time import unittest from unittest import mock @@ -143,28 +144,71 @@ def cb(): # are really slow self.assertLessEqual(dt, 0.9, dt) - def test_assert_is_current_event_loop(self): + def check_thread(self, loop, debug): def cb(): pass - other_loop = base_events.BaseEventLoop() - other_loop._selector = mock.Mock() - asyncio.set_event_loop(other_loop) + loop.set_debug(debug) + if debug: + msg = ("Non-thread-safe operation invoked on an event loop other " + "than the current one") + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_soon(cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_later(60, cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_at(loop.time() + 60, cb) + else: + loop.call_soon(cb) + loop.call_later(60, cb) + loop.call_at(loop.time() + 60, cb) + + def test_check_thread(self): + def check_in_thread(loop, event, debug, create_loop, fut): + # wait until the event loop is running + event.wait() + + try: + if create_loop: + loop2 = base_events.BaseEventLoop() + try: + asyncio.set_event_loop(loop2) + self.check_thread(loop, debug) + finally: + asyncio.set_event_loop(None) + loop2.close() + else: + self.check_thread(loop, debug) + except Exception as exc: + loop.call_soon_threadsafe(fut.set_exception, exc) + else: + loop.call_soon_threadsafe(fut.set_result, None) + + def test_thread(loop, debug, create_loop=False): + event = threading.Event() + fut = asyncio.Future(loop=loop) + loop.call_soon(event.set) + args = (loop, event, debug, create_loop, fut) + thread = threading.Thread(target=check_in_thread, args=args) + thread.start() + loop.run_until_complete(fut) + thread.join() - # raise RuntimeError if the event loop is different in debug mode - self.loop.set_debug(True) - with self.assertRaises(RuntimeError): - self.loop.call_soon(cb) - with self.assertRaises(RuntimeError): - self.loop.call_later(60, cb) - with self.assertRaises(RuntimeError): - self.loop.call_at(self.loop.time() + 60, cb) + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + + # raise RuntimeError if the thread has no event loop + test_thread(self.loop, True) # check disabled if debug mode is disabled - self.loop.set_debug(False) - self.loop.call_soon(cb) - self.loop.call_later(60, cb) - self.loop.call_at(self.loop.time() + 60, cb) + test_thread(self.loop, False) + + # raise RuntimeError if the event loop of the thread is not the called + # event loop + test_thread(self.loop, True, create_loop=True) + + # check disabled if debug mode is disabled + test_thread(self.loop, False, create_loop=True) def test_run_once_in_executor_handle(self): def cb(): diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 9e9b41a4..82582383 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -440,17 +440,16 @@ def _socketpair(s): self.loop = EventLoop(self.proactor) self.set_event_loop(self.loop, cleanup=False) - @mock.patch.object(BaseProactorEventLoop, '_call_soon') + @mock.patch.object(BaseProactorEventLoop, 'call_soon') @mock.patch.object(BaseProactorEventLoop, '_socketpair') - def test_ctor(self, socketpair, _call_soon): + def test_ctor(self, socketpair, call_soon): ssock, csock = socketpair.return_value = ( mock.Mock(), mock.Mock()) loop = BaseProactorEventLoop(self.proactor) self.assertIs(loop._ssock, ssock) self.assertIs(loop._csock, csock) self.assertEqual(loop._internal_fds, 1) - _call_soon.assert_called_with(loop._loop_self_reading, (), - check_loop=False) + call_soon.assert_called_with(loop._loop_self_reading) def test_close_self_pipe(self): self.loop._close_self_pipe() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5cb0f030..b284e6b7 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -230,19 +230,12 @@ class SubprocessWatcherMixin(SubprocessMixin): def setUp(self): policy = asyncio.get_event_loop_policy() self.loop = policy.new_event_loop() - - # ensure that the event loop is passed explicitly in asyncio - policy.set_event_loop(None) + self.set_event_loop(self.loop) watcher = self.Watcher() watcher.attach_loop(self.loop) policy.set_child_watcher(watcher) - - def tearDown(self): - policy = asyncio.get_event_loop_policy() - policy.set_child_watcher(None) - self.loop.close() - super().tearDown() + self.addCleanup(policy.set_child_watcher, None) class SubprocessSafeWatcherTests(SubprocessWatcherMixin, test_utils.TestCase): @@ -259,17 +252,8 @@ class SubprocessFastWatcherTests(SubprocessWatcherMixin, class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase): def setUp(self): - policy = asyncio.get_event_loop_policy() self.loop = asyncio.ProactorEventLoop() - - # ensure that the event loop is passed explicitly in asyncio - policy.set_event_loop(None) - - def tearDown(self): - policy = asyncio.get_event_loop_policy() - self.loop.close() - policy.set_event_loop(None) - super().tearDown() + self.set_event_loop(self.loop) if __name__ == '__main__': From 1385ab70894cd9bafa480d7b7eed82240cd02749 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 26 Dec 2014 21:16:17 +0100 Subject: [PATCH 1237/1502] CPython doesn't have asyncio.test_support --- tests/test_base_events.py | 21 +++++++++++++-------- tests/test_events.py | 5 ++++- tests/test_futures.py | 5 ++++- tests/test_subprocess.py | 5 ++++- tests/test_tasks.py | 29 +++++++++++++++++------------ tests/test_windows_utils.py | 5 ++++- 6 files changed, 46 insertions(+), 24 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 5906fb73..6599e4ea 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -13,8 +13,13 @@ import asyncio from asyncio import base_events from asyncio import constants -from asyncio import test_support as support from asyncio import test_utils +try: + from test import support + from test.script_helper import assert_python_ok +except ImportError: + from asyncio import test_support as support + from asyncio.test_support import assert_python_ok MOCK_ANY = mock.ANY @@ -623,19 +628,19 @@ def test_env_var_debug(self): # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = support.assert_python_ok('-E', '-c', code) + sts, stdout, stderr = assert_python_ok('-E', '-c', code) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = support.assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='') + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = support.assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1') + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1') + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') self.assertEqual(stdout.rstrip(), b'False') def test_create_task(self): diff --git a/tests/test_events.py b/tests/test_events.py index 04dc8805..af2da1fe 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,8 +25,11 @@ import asyncio from asyncio import proactor_events from asyncio import selector_events -from asyncio import test_support as support from asyncio import test_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support def data_file(filename): diff --git a/tests/test_futures.py b/tests/test_futures.py index 7c564624..28637091 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -8,8 +8,11 @@ from unittest import mock import asyncio -from asyncio import test_support as support from asyncio import test_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support def _fakefunc(f): diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b284e6b7..d82cbbf0 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -5,8 +5,11 @@ import asyncio from asyncio import subprocess -from asyncio import test_support as support from asyncio import test_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support if sys.platform != 'win32': from asyncio import unix_events diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c0053668..1520fb4c 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -10,8 +10,13 @@ import asyncio from asyncio import coroutines -from asyncio import test_support as support from asyncio import test_utils +try: + from test import support + from test.script_helper import assert_python_ok +except ImportError: + from asyncio import test_support as support + from asyncio.test_support import assert_python_ok PY34 = (sys.version_info >= (3, 4)) @@ -1776,23 +1781,23 @@ def test_env_var_debug(self): # Test with -E to not fail if the unit test was run with # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, - PYTHONPATH=aio_path) + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = support.assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='', - PYTHONPATH=aio_path) + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = support.assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = support.assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 92db24e6..af5c453b 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -11,8 +11,11 @@ import _winapi from asyncio import _overlapped -from asyncio import test_support as support from asyncio import windows_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support class WinsocketpairTests(unittest.TestCase): From 08f06230ac3bc41a81950051aee0fdcb210d3e3a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 26 Dec 2014 23:52:55 +0100 Subject: [PATCH 1238/1502] MANIFEST.in: add run_aiotest.py --- MANIFEST.in | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index cb0cb08e..d0dbde14 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ include AUTHORS COPYING include Makefile include overlapped.c pypi.bat -include check.py runtests.py release.py +include check.py runtests.py run_aiotest.py release.py include update_stdlib.sh recursive-include examples *.py From ff7a0d636f0dadef6561b633704492d2b45d8613 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 26 Dec 2014 23:53:03 +0100 Subject: [PATCH 1239/1502] fix typo --- asyncio/test_support.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index 336f3acf..7a58cc04 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -1,5 +1,5 @@ # Subset of test.support from CPython 3.5, just what we need to run asyncio -# test suite. The cde is copied from CPython 3.5 to not depend on the test +# test suite. The code is copied from CPython 3.5 to not depend on the test # module because it is rarely installed. # Ignore symbol TEST_HOME_DIR: test_events works without it From 1bb49303ad926e4ce324a789681e23d89a98014f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 6 Jan 2015 01:03:18 +0100 Subject: [PATCH 1240/1502] Python issue #23046: Expose the BaseEventLoop class in the asyncio namespace --- asyncio/__init__.py | 4 +++- asyncio/base_events.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index 3911fb40..011466b3 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -18,6 +18,7 @@ import _overlapped # Will also be exported. # This relies on each of the submodules having an __all__ variable. +from .base_events import * from .coroutines import * from .events import * from .futures import * @@ -29,7 +30,8 @@ from .tasks import * from .transports import * -__all__ = (coroutines.__all__ + +__all__ = (base_events.__all__ + + coroutines.__all__ + events.__all__ + futures.__all__ + locks.__all__ + diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 684c9ecd..59f31364 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -35,7 +35,7 @@ from .log import logger -__all__ = ['BaseEventLoop', 'Server'] +__all__ = ['BaseEventLoop'] # Argument for default thread pool executor creation. From 9b36966cb84f883d190e16330326ee5049a595c2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 6 Jan 2015 01:13:16 +0100 Subject: [PATCH 1241/1502] Issue #23140: Fix cancellation of Process.wait(). Check the state of the waiter future before setting its result. --- asyncio/subprocess.py | 3 ++- tests/test_subprocess.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index f6d6a141..a8ad03c2 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -96,7 +96,8 @@ def process_exited(self): returncode = self._transport.get_returncode() while self._waiters: waiter = self._waiters.popleft() - waiter.set_result(returncode) + if not waiter.cancelled(): + waiter.set_result(returncode) class Process: diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index d82cbbf0..dfe23be2 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -223,6 +223,34 @@ def len_message(message): self.assertEqual(output.rstrip(), b'3') self.assertEqual(exitcode, 0) + def test_cancel_process_wait(self): + # Issue #23140: cancel Process.wait() + + @asyncio.coroutine + def wait_proc(proc, event): + event.set() + yield from proc.wait() + + @asyncio.coroutine + def cancel_wait(): + proc = yield from asyncio.create_subprocess_exec( + *PROGRAM_BLOCKED, + loop=self.loop) + + # Create an internal future waiting on the process exit + event = asyncio.Event(loop=self.loop) + task = self.loop.create_task(wait_proc(proc, event)) + yield from event.wait() + + # Cancel the future + task.cancel() + + # Kill the process and wait until it is done + proc.kill() + yield from proc.wait() + + self.loop.run_until_complete(cancel_wait()) + if sys.platform != 'win32': # Unix From 9f1e74d269679452f2a7d6278fe433c92cea17cc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 6 Jan 2015 01:21:57 +0100 Subject: [PATCH 1242/1502] Python issue #23140: Simplify the unit test --- tests/test_subprocess.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index dfe23be2..1fe9095d 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -226,11 +226,6 @@ def len_message(message): def test_cancel_process_wait(self): # Issue #23140: cancel Process.wait() - @asyncio.coroutine - def wait_proc(proc, event): - event.set() - yield from proc.wait() - @asyncio.coroutine def cancel_wait(): proc = yield from asyncio.create_subprocess_exec( @@ -238,9 +233,12 @@ def cancel_wait(): loop=self.loop) # Create an internal future waiting on the process exit - event = asyncio.Event(loop=self.loop) - task = self.loop.create_task(wait_proc(proc, event)) - yield from event.wait() + task = self.loop.create_task(proc.wait()) + self.loop.call_soon(task.cancel) + try: + yield from task + except asyncio.CancelledError: + pass # Cancel the future task.cancel() From f139962e14cf85658af3b563adbeb7fbde07c864 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 8 Jan 2015 12:05:43 +0100 Subject: [PATCH 1243/1502] _make_ssl_transport: make the waiter parameter optional --- asyncio/base_events.py | 2 +- asyncio/selector_events.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 59f31364..ac885a8e 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -201,7 +201,7 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None): """Create SSL transport.""" diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index a97709d8..2e7364b8 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -55,7 +55,7 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter, *, + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None): return _SelectorSslTransport( @@ -165,7 +165,7 @@ def _accept_connection(self, protocol_factory, sock, else: if sslcontext: self._make_ssl_transport( - conn, protocol_factory(), sslcontext, None, + conn, protocol_factory(), sslcontext, server_side=True, extra={'peername': addr}, server=server) else: self._make_socket_transport( From 9d4323861f0e3a69018f0b7ad00a99b0e6663835 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 00:02:52 +0100 Subject: [PATCH 1244/1502] Truncate to 80 columns --- asyncio/base_events.py | 4 ++-- asyncio/coroutines.py | 12 ++++++++---- asyncio/selector_events.py | 7 ++++--- asyncio/tasks.py | 2 +- asyncio/test_support.py | 15 ++++++++------- asyncio/unix_events.py | 3 ++- asyncio/windows_events.py | 3 ++- asyncio/windows_utils.py | 7 ++++--- release.py | 10 +++++++--- tests/test_base_events.py | 9 ++++++--- tests/test_futures.py | 24 ++++++++++++++++-------- tests/test_streams.py | 6 ++++-- tests/test_subprocess.py | 4 +++- tests/test_tasks.py | 12 ++++++++---- 14 files changed, 75 insertions(+), 43 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index ac885a8e..35c8d742 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -201,8 +201,8 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, - server_side=False, server_hostname=None, + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, extra=None, server=None): """Create SSL transport.""" raise NotImplementedError diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index c28de95a..a1b28751 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -182,14 +182,18 @@ def _format_coroutine(coro): and not inspect.isgeneratorfunction(coro.func)): filename, lineno = events._get_function_source(coro.func) if coro.gi_frame is None: - coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + coro_repr = ('%s() done, defined at %s:%s' + % (coro_name, filename, lineno)) else: - coro_repr = '%s() running, defined at %s:%s' % (coro_name, filename, lineno) + coro_repr = ('%s() running, defined at %s:%s' + % (coro_name, filename, lineno)) elif coro.gi_frame is not None: lineno = coro.gi_frame.f_lineno - coro_repr = '%s() running at %s:%s' % (coro_name, filename, lineno) + coro_repr = ('%s() running at %s:%s' + % (coro_name, filename, lineno)) else: lineno = coro.gi_code.co_firstlineno - coro_repr = '%s() done, defined at %s:%s' % (coro_name, filename, lineno) + coro_repr = ('%s() done, defined at %s:%s' + % (coro_name, filename, lineno)) return coro_repr diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 2e7364b8..69b649ce 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -55,8 +55,8 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, - server_side=False, server_hostname=None, + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, extra=None, server=None): return _SelectorSslTransport( self, rawsock, protocol, sslcontext, waiter, @@ -484,7 +484,8 @@ def __repr__(self): info.append('read=idle') polling = _test_selector_event(self._loop._selector, - self._sock_fd, selectors.EVENT_WRITE) + self._sock_fd, + selectors.EVENT_WRITE) if polling: state = 'polling' else: diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 9aebffda..8fc5beac 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -68,7 +68,7 @@ def all_tasks(cls, loop=None): return {t for t in cls._all_tasks if t._loop is loop} def __init__(self, coro, *, loop=None): - assert coroutines.iscoroutine(coro), repr(coro) # Not a coroutine function! + assert coroutines.iscoroutine(coro), repr(coro) super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] diff --git a/asyncio/test_support.py b/asyncio/test_support.py index 7a58cc04..3da47558 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -209,12 +209,13 @@ def bind_port(sock, host=HOST): if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: if hasattr(socket, 'SO_REUSEADDR'): if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: - raise TestFailed("tests should never set the SO_REUSEADDR " \ + raise TestFailed("tests should never set the SO_REUSEADDR " "socket option on TCP/IP sockets!") if hasattr(socket, 'SO_REUSEPORT'): try: - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: - raise TestFailed("tests should never set the SO_REUSEPORT " \ + reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + if reuse == 1: + raise TestFailed("tests should never set the SO_REUSEPORT " "socket option on TCP/IP sockets!") except OSError: # Python's socket module was compiled using modern headers @@ -256,8 +257,8 @@ def wrapper(*args, **kw): return decorator def _requires_unix_version(sysname, min_version): - """Decorator raising SkipTest if the OS is `sysname` and the version is less - than `min_version`. + """Decorator raising SkipTest if the OS is `sysname` and the version is + less than `min_version`. For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if the FreeBSD version is less than 7.2. @@ -283,8 +284,8 @@ def wrapper(*args, **kw): return decorator def requires_freebsd_version(*min_version): - """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version is - less than `min_version`. + """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version + is less than `min_version`. For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD version is less than 7.2. diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index d1461fd0..91e43cfc 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -69,7 +69,8 @@ def add_signal_handler(self, sig, callback, *args): """ if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): - raise TypeError("coroutines cannot be used with add_signal_handler()") + raise TypeError("coroutines cannot be used " + "with add_signal_handler()") self._check_signal(sig) self._check_closed() try: diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index d7feb1ae..9d496f2f 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -424,7 +424,8 @@ def finish_connect_pipe(err, handle, ov): else: return windows_utils.PipeHandle(handle) - return self._register(ov, None, finish_connect_pipe, wait_for_post=True) + return self._register(ov, None, finish_connect_pipe, + wait_for_post=True) def wait_for_handle(self, handle, timeout=None): """Wait for a handle. diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index c6e4bc9e..b4758123 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -36,15 +36,16 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): """A socket pair usable as a self-pipe, for Windows. - Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + Origin: https://gist.github.com/4325783, by Geert Jansen. + Public domain. """ if family == socket.AF_INET: host = '127.0.0.1' elif family == socket.AF_INET6: host = '::1' else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") + raise ValueError("Only AF_INET and AF_INET6 socket address " + "families are supported") if type != socket.SOCK_STREAM: raise ValueError("Only SOCK_STREAM socket type is supported") if proto != 0: diff --git a/release.py b/release.py index 848e3169..c7829ec6 100644 --- a/release.py +++ b/release.py @@ -249,7 +249,8 @@ def build_inplace(self, pyver): else: arch = 'win32' build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) - src = os.path.join(self.root, 'build', build_dir, PROJECT, '_overlapped.pyd') + src = os.path.join(self.root, 'build', build_dir, + PROJECT, '_overlapped.pyd') dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') shutil.copyfile(src, dst) @@ -289,7 +290,9 @@ def runtests(self, pyver): def _build_windows(self, pyver, cmd): setenv, sdkver = self.windows_sdk_setenv(pyver) - with tempfile.NamedTemporaryFile(mode="w", suffix=".bat", delete=False) as temp: + temp = tempfile.NamedTemporaryFile(mode="w", suffix=".bat", + delete=False) + with temp: temp.write("SETLOCAL EnableDelayedExpansion\n") temp.write(self.quote_args(setenv) + "\n") temp.write(BATCH_FAIL_ON_ERROR + "\n") @@ -500,7 +503,8 @@ def main(self): if self.register: print("Project registered on the Python cheeseshop (PyPI)") if self.sdist: - print("Project source code uploaded to the Python cheeseshop (PyPI)") + print("Project source code uploaded to the Python " + "cheeseshop (PyPI)") if self.upload: print("Wheel packages uploaded to the Python cheeseshop (PyPI)") for pyver in self.python_versions: diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 6599e4ea..6bf7e796 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -285,7 +285,8 @@ def test_set_debug(self): @mock.patch('asyncio.base_events.logger') def test__run_once_logging(self, m_logger): def slow_select(timeout): - # Sleep a bit longer than a second to avoid timer resolution issues. + # Sleep a bit longer than a second to avoid timer resolution + # issues. time.sleep(1.1) return [] @@ -1217,14 +1218,16 @@ def stop_loop_coro(loop): self.loop.run_forever() fmt, *args = m_logger.warning.call_args[0] self.assertRegex(fmt % tuple(args), - "^Executing took .* seconds$") + "^Executing " + "took .* seconds$") # slow task asyncio.async(stop_loop_coro(self.loop), loop=self.loop) self.loop.run_forever() fmt, *args = m_logger.warning.call_args[0] self.assertRegex(fmt % tuple(args), - "^Executing took .* seconds$") + "^Executing " + "took .* seconds$") if __name__ == '__main__': diff --git a/tests/test_futures.py b/tests/test_futures.py index 28637091..dac1e897 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -133,7 +133,8 @@ def test_future_repr(self): exc = RuntimeError() f_exception = asyncio.Future(loop=self.loop) f_exception.set_exception(exc) - self.assertEqual(repr(f_exception), '') + self.assertEqual(repr(f_exception), + '') self.assertIs(f_exception.exception(), exc) def func_repr(func): @@ -332,16 +333,21 @@ def memory_error(): if debug: frame = source_traceback[-1] regex = (r'^Future exception was never retrieved\n' - r'future: \n' - r'source_traceback: Object created at \(most recent call last\):\n' + r'future: \n' + r'source_traceback: Object ' + r'created at \(most recent call last\):\n' r' File' r'.*\n' - r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' File "{filename}", line {lineno}, ' + r'in check_future_exception_never_retrieved\n' r' future = asyncio\.Future\(loop=self\.loop\)$' - ).format(filename=re.escape(frame[0]), lineno=frame[1]) + ).format(filename=re.escape(frame[0]), + lineno=frame[1]) else: regex = (r'^Future exception was never retrieved\n' - r'future: $' + r'future: ' + r'$' ) exc_info = (type(exc), exc, exc.__traceback__) m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) @@ -352,12 +358,14 @@ def memory_error(): r'Future/Task created at \(most recent call last\):\n' r' File' r'.*\n' - r' File "{filename}", line {lineno}, in check_future_exception_never_retrieved\n' + r' File "{filename}", line {lineno}, ' + r'in check_future_exception_never_retrieved\n' r' future = asyncio\.Future\(loop=self\.loop\)\n' r'Traceback \(most recent call last\):\n' r'.*\n' r'MemoryError$' - ).format(filename=re.escape(frame[0]), lineno=frame[1]) + ).format(filename=re.escape(frame[0]), + lineno=frame[1]) else: regex = (r'^Future/Task exception was never retrieved\n' r'Traceback \(most recent call last\):\n' diff --git a/tests/test_streams.py b/tests/test_streams.py index 73a375ab..05963cf1 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -613,8 +613,10 @@ def test_read_all_from_pipe_reader(self): watcher.attach_loop(self.loop) try: asyncio.set_child_watcher(watcher) - proc = self.loop.run_until_complete( - asyncio.create_subprocess_exec(*args, pass_fds={wfd}, loop=self.loop)) + create = asyncio.create_subprocess_exec(*args, + pass_fds={wfd}, + loop=self.loop) + proc = self.loop.run_until_complete(create) self.loop.run_until_complete(proc.wait()) finally: asyncio.set_child_watcher(None) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 1fe9095d..5fc1dc0a 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -115,7 +115,9 @@ def test_terminate(self): def test_send_signal(self): code = 'import time; print("sleeping", flush=True); time.sleep(3600)' args = [sys.executable, '-c', code] - create = asyncio.create_subprocess_exec(*args, loop=self.loop, stdout=subprocess.PIPE) + create = asyncio.create_subprocess_exec(*args, + stdout=subprocess.PIPE, + loop=self.loop) proc = self.loop.run_until_complete(create) @asyncio.coroutine diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 1520fb4c..7807dc04 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -208,7 +208,8 @@ def notmuch(): self.assertEqual(notmuch.__name__, 'notmuch') if PY35: self.assertEqual(notmuch.__qualname__, - 'TaskTests.test_task_repr_coro_decorator..notmuch') + 'TaskTests.test_task_repr_coro_decorator' + '..notmuch') self.assertEqual(notmuch.__module__, __name__) # test coroutine object @@ -218,7 +219,8 @@ def notmuch(): # function, as expected, and have a qualified name (__qualname__ # attribute). coro_name = 'notmuch' - coro_qualname = 'TaskTests.test_task_repr_coro_decorator..notmuch' + coro_qualname = ('TaskTests.test_task_repr_coro_decorator' + '..notmuch') else: # On Python < 3.5, generators inherit the name of the code, not of # the function. See: http://bugs.python.org/issue21205 @@ -239,7 +241,8 @@ def notmuch(): else: code = gen.gi_code coro = ('%s() running at %s:%s' - % (coro_qualname, code.co_filename, code.co_firstlineno)) + % (coro_qualname, code.co_filename, + code.co_firstlineno)) self.assertEqual(repr(gen), '' % coro) @@ -1678,7 +1681,8 @@ def coro_noop(): self.assertTrue(m_log.error.called) message = m_log.error.call_args[0][0] func_filename, func_lineno = test_utils.get_function_source(coro_noop) - regex = (r'^ was never yielded from\n' + regex = (r'^ ' + r'was never yielded from\n' r'Coroutine object created at \(most recent call last\):\n' r'.*\n' r' File "%s", line %s, in test_coroutine_never_yielded\n' From 9dc23dffa3f1060d386d27516ec8faedbf75ba7d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 00:13:35 +0100 Subject: [PATCH 1245/1502] selectors: truncate to 80 characters --- asyncio/selectors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index faa2d3da..598845d4 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -576,7 +576,8 @@ def close(self): super().close() -# Choose the best implementation: roughly, epoll|kqueue|devpoll > poll > select. +# Choose the best implementation, roughly: +# epoll|kqueue|devpoll > poll > select. # select() also can't accept a FD > FD_SETSIZE (usually around 1024) if 'KqueueSelector' in globals(): DefaultSelector = KqueueSelector From c4c59e65d200abed893db1670d04f9cae1de8af9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 00:51:43 +0100 Subject: [PATCH 1246/1502] sock_connect(): pass directly the fd to _sock_connect_done instead of the socket --- asyncio/selector_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 69b649ce..58b61f1c 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -363,15 +363,15 @@ def _sock_connect(self, fut, sock, address): break except BlockingIOError: fut.add_done_callback(functools.partial(self._sock_connect_done, - sock)) + fd)) self.add_writer(fd, self._sock_connect_cb, fut, sock, address) except Exception as exc: fut.set_exception(exc) else: fut.set_result(None) - def _sock_connect_done(self, sock, fut): - self.remove_writer(sock.fileno()) + def _sock_connect_done(self, fd, fut): + self.remove_writer(fd) def _sock_connect_cb(self, fut, sock, address): if fut.cancelled(): From d78cd0a4a14a60b6bfd4c767b06f3875e4a1878e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 00:58:42 +0100 Subject: [PATCH 1247/1502] Cleanup gather() Use public methods instead of hacks to consume the exception of a future. --- asyncio/tasks.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 8fc5beac..7959a55a 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -582,11 +582,12 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): def _done_callback(i, fut): nonlocal nfinished - if outer._state != futures._PENDING: - if fut._exception is not None: + if outer.done(): + if not fut.cancelled(): # Mark exception retrieved. fut.exception() return + if fut._state == futures._CANCELLED: res = futures.CancelledError() if not return_exceptions: @@ -644,9 +645,11 @@ def shield(arg, *, loop=None): def _done_callback(inner): if outer.cancelled(): - # Mark inner's result as retrieved. - inner.cancelled() or inner.exception() + if not inner.cancelled(): + # Mark inner's result as retrieved. + inner.exception() return + if inner.cancelled(): outer.cancel() else: From e0fd3977396da4c8bfcae026eda2a7113e252322 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 00:59:11 +0100 Subject: [PATCH 1248/1502] Document why set_result() calls are safe --- asyncio/queues.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/asyncio/queues.py b/asyncio/queues.py index 8f6c2577..dce0d53c 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -126,6 +126,8 @@ def put(self, item): # Use _put and _get instead of passing item straight to getter, in # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) + + # getter cannot be cancelled, we just removed done getters getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize <= self.qsize(): @@ -152,6 +154,8 @@ def put_nowait(self, item): # Use _put and _get instead of passing item straight to getter, in # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) + + # getter cannot be cancelled, we just removed done getters getter.set_result(self._get()) elif self._maxsize > 0 and self._maxsize <= self.qsize(): @@ -200,6 +204,8 @@ def get_nowait(self): item, putter = self._putters.popleft() self._put(item) # Wake putter on next tick. + + # getter cannot be cancelled, we just removed done putters putter.set_result(None) return self._get() From 84858d6e17b05660f0b3e45dc7d77b2bce184982 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 15:16:08 +0100 Subject: [PATCH 1249/1502] Replace test_selectors.py with the file of Python 3.5 adapted for asyncio and Python 3.3 * Use time.time if time.monotonic is not available * Get socketpair from asyncio.test_utils * Get selectors from asyncio.selectors --- tests/test_selectors.py | 556 ++++++++++++++++++++++++++++------------ 1 file changed, 395 insertions(+), 161 deletions(-) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index d91c78b1..3d5ef918 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -1,214 +1,448 @@ -"""Tests for selectors.py.""" - +import errno +import os +import random +import signal +import sys +from test import support +from time import sleep import unittest -from unittest import mock - +import unittest.mock +try: + from time import monotonic as time +except ImportError: + from time import time as time +try: + import resource +except ImportError: + resource = None from asyncio import selectors +from asyncio.test_utils import socketpair -class FakeSelector(selectors._BaseSelectorImpl): - """Trivial non-abstract subclass of BaseSelector.""" +def find_ready_matching(ready, flag): + match = [] + for key, events in ready: + if events & flag: + match.append(key.fileobj) + return match - def select(self, timeout=None): - raise NotImplementedError +class BaseSelectorTestCase(unittest.TestCase): -class _SelectorMappingTests(unittest.TestCase): + def make_socketpair(self): + rd, wr = socketpair() + self.addCleanup(rd.close) + self.addCleanup(wr.close) + return rd, wr - def test_len(self): - s = FakeSelector() - map = selectors._SelectorMapping(s) - self.assertTrue(map.__len__() == 0) - - f = mock.Mock() - f.fileno.return_value = 10 - s.register(f, selectors.EVENT_READ, None) - self.assertTrue(len(map) == 1) + def test_register(self): + s = self.SELECTOR() + self.addCleanup(s.close) - def test_getitem(self): - s = FakeSelector() - map = selectors._SelectorMapping(s) - f = mock.Mock() - f.fileno.return_value = 10 - s.register(f, selectors.EVENT_READ, None) - attended = selectors.SelectorKey(f, 10, selectors.EVENT_READ, None) - self.assertEqual(attended, map.__getitem__(f)) + rd, wr = self.make_socketpair() - def test_getitem_key_error(self): - s = FakeSelector() - map = selectors._SelectorMapping(s) - self.assertTrue(len(map) == 0) - f = mock.Mock() - f.fileno.return_value = 10 - s.register(f, selectors.EVENT_READ, None) - self.assertRaises(KeyError, map.__getitem__, 5) + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertIsInstance(key, selectors.SelectorKey) + self.assertEqual(key.fileobj, rd) + self.assertEqual(key.fd, rd.fileno()) + self.assertEqual(key.events, selectors.EVENT_READ) + self.assertEqual(key.data, "data") - def test_iter(self): - s = FakeSelector() - map = selectors._SelectorMapping(s) - self.assertTrue(len(map) == 0) - f = mock.Mock() - f.fileno.return_value = 5 - s.register(f, selectors.EVENT_READ, None) - counter = 0 - for fileno in map.__iter__(): - self.assertEqual(5, fileno) - counter += 1 + # register an unknown event + self.assertRaises(ValueError, s.register, 0, 999999) - for idx in map: - self.assertEqual(f, map[idx].fileobj) - self.assertEqual(1, counter) + # register an invalid FD + self.assertRaises(ValueError, s.register, -10, selectors.EVENT_READ) + # register twice + self.assertRaises(KeyError, s.register, rd, selectors.EVENT_READ) -class BaseSelectorTests(unittest.TestCase): - def test_fileobj_to_fd(self): - self.assertEqual(10, selectors._fileobj_to_fd(10)) + # register the same FD, but with a different object + self.assertRaises(KeyError, s.register, rd.fileno(), + selectors.EVENT_READ) - f = mock.Mock() - f.fileno.return_value = 10 - self.assertEqual(10, selectors._fileobj_to_fd(f)) + def test_unregister(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(rd, selectors.EVENT_READ) + s.unregister(rd) + + # unregister an unknown file obj + self.assertRaises(KeyError, s.unregister, 999999) + + # unregister twice + self.assertRaises(KeyError, s.unregister, rd) + + def test_unregister_after_fd_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + r, w = rd.fileno(), wr.fileno() + s.register(r, selectors.EVENT_READ) + s.register(w, selectors.EVENT_WRITE) + rd.close() + wr.close() + s.unregister(r) + s.unregister(w) + + @unittest.skipUnless(os.name == 'posix', "requires posix") + def test_unregister_after_fd_close_and_reuse(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + r, w = rd.fileno(), wr.fileno() + s.register(r, selectors.EVENT_READ) + s.register(w, selectors.EVENT_WRITE) + rd2, wr2 = self.make_socketpair() + rd.close() + wr.close() + os.dup2(rd2.fileno(), r) + os.dup2(wr2.fileno(), w) + self.addCleanup(os.close, r) + self.addCleanup(os.close, w) + s.unregister(r) + s.unregister(w) + + def test_unregister_after_socket_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) + rd, wr = self.make_socketpair() + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + rd.close() + wr.close() + s.unregister(rd) + s.unregister(wr) - f.fileno.side_effect = AttributeError - self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + def test_modify(self): + s = self.SELECTOR() + self.addCleanup(s.close) - f.fileno.return_value = -1 - self.assertRaises(ValueError, selectors._fileobj_to_fd, f) + rd, wr = self.make_socketpair() - def test_selector_key_repr(self): - key = selectors.SelectorKey(10, 10, selectors.EVENT_READ, None) - self.assertEqual( - "SelectorKey(fileobj=10, fd=10, events=1, data=None)", repr(key)) + key = s.register(rd, selectors.EVENT_READ) - def test_register(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + # modify events + key2 = s.modify(rd, selectors.EVENT_WRITE) + self.assertNotEqual(key.events, key2.events) + self.assertEqual(key2, s.get_key(rd)) - s = FakeSelector() - key = s.register(fobj, selectors.EVENT_READ) - self.assertIsInstance(key, selectors.SelectorKey) - self.assertEqual(key.fd, 10) - self.assertIs(key, s._fd_to_key[10]) + s.unregister(rd) - def test_register_unknown_event(self): - s = FakeSelector() - self.assertRaises(ValueError, s.register, mock.Mock(), 999999) + # modify data + d1 = object() + d2 = object() - def test_register_already_registered(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + key = s.register(rd, selectors.EVENT_READ, d1) + key2 = s.modify(rd, selectors.EVENT_READ, d2) + self.assertEqual(key.events, key2.events) + self.assertNotEqual(key.data, key2.data) + self.assertEqual(key2, s.get_key(rd)) + self.assertEqual(key2.data, d2) - s = FakeSelector() - s.register(fobj, selectors.EVENT_READ) - self.assertRaises(KeyError, s.register, fobj, selectors.EVENT_READ) + # modify unknown file obj + self.assertRaises(KeyError, s.modify, 999999, selectors.EVENT_READ) - def test_unregister(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + # modify use a shortcut + d3 = object() + s.register = unittest.mock.Mock() + s.unregister = unittest.mock.Mock() - s = FakeSelector() - s.register(fobj, selectors.EVENT_READ) - s.unregister(fobj) - self.assertFalse(s._fd_to_key) + s.modify(rd, selectors.EVENT_READ, d3) + self.assertFalse(s.register.called) + self.assertFalse(s.unregister.called) - def test_unregister_unknown(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + def test_close(self): + s = self.SELECTOR() + self.addCleanup(s.close) - s = FakeSelector() - self.assertRaises(KeyError, s.unregister, fobj) + rd, wr = self.make_socketpair() - def test_modify_unknown(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) - s = FakeSelector() - self.assertRaises(KeyError, s.modify, fobj, 1) + s.close() + self.assertRaises(KeyError, s.get_key, rd) + self.assertRaises(KeyError, s.get_key, wr) - def test_modify(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + def test_get_key(self): + s = self.SELECTOR() + self.addCleanup(s.close) - s = FakeSelector() - key = s.register(fobj, selectors.EVENT_READ) - key2 = s.modify(fobj, selectors.EVENT_WRITE) - self.assertNotEqual(key.events, key2.events) - self.assertEqual( - selectors.SelectorKey(fobj, 10, selectors.EVENT_WRITE, None), - s.get_key(fobj)) + rd, wr = self.make_socketpair() - def test_modify_data(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertEqual(key, s.get_key(rd)) - d1 = object() - d2 = object() + # unknown file obj + self.assertRaises(KeyError, s.get_key, 999999) - s = FakeSelector() - key = s.register(fobj, selectors.EVENT_READ, d1) - key2 = s.modify(fobj, selectors.EVENT_READ, d2) - self.assertEqual(key.events, key2.events) - self.assertNotEqual(key.data, key2.data) - self.assertEqual( - selectors.SelectorKey(fobj, 10, selectors.EVENT_READ, d2), - s.get_key(fobj)) + def test_get_map(self): + s = self.SELECTOR() + self.addCleanup(s.close) - def test_modify_data_use_a_shortcut(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + rd, wr = self.make_socketpair() - d1 = object() - d2 = object() + keys = s.get_map() + self.assertFalse(keys) + self.assertEqual(len(keys), 0) + self.assertEqual(list(keys), []) + key = s.register(rd, selectors.EVENT_READ, "data") + self.assertIn(rd, keys) + self.assertEqual(key, keys[rd]) + self.assertEqual(len(keys), 1) + self.assertEqual(list(keys), [rd.fileno()]) + self.assertEqual(list(keys.values()), [key]) - s = FakeSelector() - s.register(fobj, selectors.EVENT_READ, d1) + # unknown file obj + with self.assertRaises(KeyError): + keys[999999] - s.unregister = mock.Mock() - s.register = mock.Mock() - s.modify(fobj, selectors.EVENT_READ, d2) - self.assertFalse(s.unregister.called) - self.assertFalse(s.register.called) + # Read-only mapping + with self.assertRaises(TypeError): + del keys[rd] - def test_modify_same(self): - fobj = mock.Mock() - fobj.fileno.return_value = 10 + def test_select(self): + s = self.SELECTOR() + self.addCleanup(s.close) - data = object() + rd, wr = self.make_socketpair() - s = FakeSelector() - key = s.register(fobj, selectors.EVENT_READ, data) - key2 = s.modify(fobj, selectors.EVENT_READ, data) - self.assertIs(key, key2) + s.register(rd, selectors.EVENT_READ) + wr_key = s.register(wr, selectors.EVENT_WRITE) - def test_select(self): - s = FakeSelector() - self.assertRaises(NotImplementedError, s.select) + result = s.select() + for key, events in result: + self.assertTrue(isinstance(key, selectors.SelectorKey)) + self.assertTrue(events) + self.assertFalse(events & ~(selectors.EVENT_READ | + selectors.EVENT_WRITE)) - def test_close(self): - s = FakeSelector() - s.register(1, selectors.EVENT_READ) - - s.close() - self.assertFalse(s._fd_to_key) + self.assertEqual([(wr_key, selectors.EVENT_WRITE)], result) def test_context_manager(self): - s = FakeSelector() + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() with s as sel: - sel.register(1, selectors.EVENT_READ) + sel.register(rd, selectors.EVENT_READ) + sel.register(wr, selectors.EVENT_WRITE) + + self.assertRaises(KeyError, s.get_key, rd) + self.assertRaises(KeyError, s.get_key, wr) + + def test_fileno(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + if hasattr(s, 'fileno'): + fd = s.fileno() + self.assertTrue(isinstance(fd, int)) + self.assertGreaterEqual(fd, 0) + + def test_selector(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + NUM_SOCKETS = 12 + MSG = b" This is a test." + MSG_LEN = len(MSG) + readers = [] + writers = [] + r2w = {} + w2r = {} + + for i in range(NUM_SOCKETS): + rd, wr = self.make_socketpair() + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + readers.append(rd) + writers.append(wr) + r2w[rd] = wr + w2r[wr] = rd + + bufs = [] + + while writers: + ready = s.select() + ready_writers = find_ready_matching(ready, selectors.EVENT_WRITE) + if not ready_writers: + self.fail("no sockets ready for writing") + wr = random.choice(ready_writers) + wr.send(MSG) + + for i in range(10): + ready = s.select() + ready_readers = find_ready_matching(ready, + selectors.EVENT_READ) + if ready_readers: + break + # there might be a delay between the write to the write end and + # the read end is reported ready + sleep(0.1) + else: + self.fail("no sockets ready for reading") + self.assertEqual([w2r[wr]], ready_readers) + rd = ready_readers[0] + buf = rd.recv(MSG_LEN) + self.assertEqual(len(buf), MSG_LEN) + bufs.append(buf) + s.unregister(r2w[rd]) + s.unregister(rd) + writers.remove(r2w[rd]) + + self.assertEqual(bufs, [MSG] * NUM_SOCKETS) + + @unittest.skipIf(sys.platform == 'win32', + 'select.select() cannot be used with empty fd sets') + def test_empty_select(self): + s = self.SELECTOR() + self.addCleanup(s.close) + self.assertEqual(s.select(timeout=0), []) + + def test_timeout(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + s.register(wr, selectors.EVENT_WRITE) + t = time() + self.assertEqual(1, len(s.select(0))) + self.assertEqual(1, len(s.select(-1))) + self.assertLess(time() - t, 0.5) + + s.unregister(wr) + s.register(rd, selectors.EVENT_READ) + t = time() + self.assertFalse(s.select(0)) + self.assertFalse(s.select(-1)) + self.assertLess(time() - t, 0.5) + + t0 = time() + self.assertFalse(s.select(1)) + t1 = time() + dt = t1 - t0 + # Tolerate 2.0 seconds for very slow buildbots + self.assertTrue(0.8 <= dt <= 2.0, dt) + + @unittest.skipUnless(hasattr(signal, "alarm"), + "signal.alarm() required for this test") + def test_select_interrupt(self): + s = self.SELECTOR() + self.addCleanup(s.close) + + rd, wr = self.make_socketpair() + + orig_alrm_handler = signal.signal(signal.SIGALRM, lambda *args: None) + self.addCleanup(signal.signal, signal.SIGALRM, orig_alrm_handler) + self.addCleanup(signal.alarm, 0) + + signal.alarm(1) + + s.register(rd, selectors.EVENT_READ) + t = time() + self.assertFalse(s.select(2)) + self.assertLess(time() - t, 2.5) + + +class ScalableSelectorMixIn: + + # see issue #18963 for why it's skipped on older OS X versions + @support.requires_mac_ver(10, 5) + @unittest.skipUnless(resource, "Test needs resource module") + def test_above_fd_setsize(self): + # A scalable implementation should have no problem with more than + # FD_SETSIZE file descriptors. Since we don't know the value, we just + # try to set the soft RLIMIT_NOFILE to the hard RLIMIT_NOFILE ceiling. + soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE) + try: + resource.setrlimit(resource.RLIMIT_NOFILE, (hard, hard)) + self.addCleanup(resource.setrlimit, resource.RLIMIT_NOFILE, + (soft, hard)) + NUM_FDS = min(hard, 2**16) + except (OSError, ValueError): + NUM_FDS = soft + + # guard for already allocated FDs (stdin, stdout...) + NUM_FDS -= 32 + + s = self.SELECTOR() + self.addCleanup(s.close) + + for i in range(NUM_FDS // 2): + try: + rd, wr = self.make_socketpair() + except OSError: + # too many FDs, skip - note that we should only catch EMFILE + # here, but apparently *BSD and Solaris can fail upon connect() + # or bind() with EADDRNOTAVAIL, so let's be safe + self.skipTest("FD limit reached") + + try: + s.register(rd, selectors.EVENT_READ) + s.register(wr, selectors.EVENT_WRITE) + except OSError as e: + if e.errno == errno.ENOSPC: + # this can be raised by epoll if we go over + # fs.epoll.max_user_watches sysctl + self.skipTest("FD limit reached") + raise + + self.assertEqual(NUM_FDS // 2, len(s.select())) + + +class DefaultSelectorTestCase(BaseSelectorTestCase): + + SELECTOR = selectors.DefaultSelector + + +class SelectSelectorTestCase(BaseSelectorTestCase): + + SELECTOR = selectors.SelectSelector + + +@unittest.skipUnless(hasattr(selectors, 'PollSelector'), + "Test needs selectors.PollSelector") +class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'PollSelector', None) + + +@unittest.skipUnless(hasattr(selectors, 'EpollSelector'), + "Test needs selectors.EpollSelector") +class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'EpollSelector', None) + + +@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'), + "Test needs selectors.KqueueSelector)") +class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): + + SELECTOR = getattr(selectors, 'KqueueSelector', None) + + +@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'), + "Test needs selectors.DevpollSelector") +class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): - self.assertFalse(s._fd_to_key) + SELECTOR = getattr(selectors, 'DevpollSelector', None) - def test_key_from_fd(self): - s = FakeSelector() - key = s.register(1, selectors.EVENT_READ) - self.assertIs(key, s._key_from_fd(1)) - self.assertIsNone(s._key_from_fd(10)) - if hasattr(selectors.DefaultSelector, 'fileno'): - def test_fileno(self): - self.assertIsInstance(selectors.DefaultSelector().fileno(), int) +def test_main(): + tests = [DefaultSelectorTestCase, SelectSelectorTestCase, + PollSelectorTestCase, EpollSelectorTestCase, + KqueueSelectorTestCase, DevpollSelectorTestCase] + support.run_unittest(*tests) + support.reap_children() -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + test_main() From d3b7c6a9cee52266f24f2edbf4e3796c07ea2dbf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 15:33:52 +0100 Subject: [PATCH 1250/1502] tox.ini: add py35 env --- tox.ini | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 040b25ab..37a1fe4a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py33,py34 +envlist = py33,py34,py35 [testenv] deps= @@ -9,3 +9,6 @@ setenv = commands= python runtests.py -r {posargs} python run_aiotest.py -r {posargs} + +[testenv:py35] +basepython = python3.5 From aef31b1a43f6632ef37689c41d96a6edf9e62724 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 15:38:48 +0100 Subject: [PATCH 1251/1502] Remove outdated TODO/XXX * Yes, futures errors (Error, CancelledError, TimeoutError, ...) are aliases of concurrent.futures errors * InvalidStateError: the state is already logged in the message when the exception is raised * call_exception_handler() now makes possible to decide how to handle exceptions * Add a docstring to _UnixDefaultEventLoopPolicy --- asyncio/futures.py | 2 -- asyncio/proactor_events.py | 3 ++- asyncio/selector_events.py | 1 - asyncio/unix_events.py | 5 +---- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/asyncio/futures.py b/asyncio/futures.py index f46d008f..e0e12f05 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -20,7 +20,6 @@ _PY34 = sys.version_info >= (3, 4) -# TODO: Do we really want to depend on concurrent.futures internals? Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError TimeoutError = concurrent.futures.TimeoutError @@ -30,7 +29,6 @@ class InvalidStateError(Error): """The operation is not allowed in this state.""" - # TODO: Show the future, its state, the method, and the required state. class _TracebackLogger: diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 44a81975..0a4d0685 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -487,7 +487,8 @@ def loop(f=None): self.call_soon(loop) def _process_events(self, event_list): - pass # XXX hard work currently done in poll + # Events are processed in the IocpProactor._poll() method + pass def _stop_accept_futures(self): for future in self._accept_futures.values(): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 58b61f1c..307a9add 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -145,7 +145,6 @@ def _accept_connection(self, protocol_factory, sock, pass # False alarm. except OSError as exc: # There's nowhere to send the error, so just log it. - # TODO: Someone will want an error handler for this. if exc.errno in (errno.EMFILE, errno.ENFILE, errno.ENOBUFS, errno.ENOMEM): # Some platforms (e.g. Linux keep reporting the FD as diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 91e43cfc..1a4d4183 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -496,9 +496,6 @@ def _write_ready(self): def can_write_eof(self): return True - # TODO: Make the relationships between write_eof(), close(), - # abort(), _fatal_error() and _close() more straightforward. - def write_eof(self): if self._closing: return @@ -897,7 +894,7 @@ def _do_waitpid_all(self): class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): - """XXX""" + """UNIX event loop policy with a watcher for child processes.""" _loop_factory = _UnixSelectorEventLoop def __init__(self): From 0fefcb7261413bc4656676378030f6161592e4d1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 15:41:50 +0100 Subject: [PATCH 1252/1502] Tulip issue #184: FlowControlMixin constructor now get the event loop if the loop parameter is not set Add unit tests to ensure that constructor of StreamReader and StreamReaderProtocol classes get the event loop. --- asyncio/streams.py | 10 +++++++--- tests/test_streams.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index c77eb606..5a96b241 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -145,7 +145,10 @@ class FlowControlMixin(protocols.Protocol): """ def __init__(self, loop=None): - self._loop = loop # May be None; we may never need it. + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop self._paused = False self._drain_waiter = None self._connection_lost = False @@ -306,8 +309,9 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # it also doubles as half the buffer limit. self._limit = limit if loop is None: - loop = events.get_event_loop() - self._loop = loop + self._loop = events.get_event_loop() + else: + self._loop = loop self._buffer = bytearray() self._eof = False # Whether we're done. self._waiter = None # A future. diff --git a/tests/test_streams.py b/tests/test_streams.py index 05963cf1..a18603af 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -625,6 +625,25 @@ def test_read_all_from_pipe_reader(self): data = self.loop.run_until_complete(reader.read(-1)) self.assertEqual(data, b'data') + def test_streamreader_constructor(self): + self.addCleanup(asyncio.set_event_loop, None) + asyncio.set_event_loop(self.loop) + + # Tulip issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + reader = asyncio.StreamReader() + self.assertIs(reader._loop, self.loop) + + def test_streamreaderprotocol_constructor(self): + self.addCleanup(asyncio.set_event_loop, None) + asyncio.set_event_loop(self.loop) + + # Tulip issue #184: Ensure that StreamReaderProtocol constructor + # retrieves the current loop if the loop parameter is not set + reader = mock.Mock() + protocol = asyncio.StreamReaderProtocol(reader) + self.assertIs(protocol._loop, self.loop) + if __name__ == '__main__': unittest.main() From d8051e36dd4b5e8d13708a8ce07ca7ced902493c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 16:04:19 +0100 Subject: [PATCH 1253/1502] tox.ini: add py3_release env to run tests in release mode --- tox.ini | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 37a1fe4a..192a0b9f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py33,py34,py35 +envlist = py33,py34,py3_release [testenv] deps= @@ -10,5 +10,11 @@ commands= python runtests.py -r {posargs} python run_aiotest.py -r {posargs} +[testenv:py3_release] +# Run tests in debug mode +setenv = + PYTHONASYNCIODEBUG = +basepython = python3 + [testenv:py35] basepython = python3.5 From ed483878f14caeb22986da63eab05cc4ab59cad3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 16:10:48 +0100 Subject: [PATCH 1254/1502] Make the release.py program executable on UNIX --- release.py | 1 + 1 file changed, 1 insertion(+) mode change 100644 => 100755 release.py diff --git a/release.py b/release.py old mode 100644 new mode 100755 index c7829ec6..3fea4a94 --- a/release.py +++ b/release.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 """ Script to upload 32 bits and 64 bits wheel packages for Python 3.3 on Windows. From f98e47c04b861042b9c88be5a82006237a9b96f4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 21:29:29 +0100 Subject: [PATCH 1255/1502] Python issue #23209: Break some reference cycles in asyncio. Patch written by Martin Richard. --- asyncio/base_subprocess.py | 1 + asyncio/futures.py | 2 +- asyncio/selectors.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 81698b09..afc434de 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -182,6 +182,7 @@ def __repr__(self): def connection_lost(self, exc): self.disconnected = True self.proc._pipe_connection_lost(self.fd, exc) + self.proc = None def pause_writing(self): self.proc._protocol.pause_writing() diff --git a/asyncio/futures.py b/asyncio/futures.py index e0e12f05..19212a94 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -405,5 +405,5 @@ def _check_cancel_other(f): new_future.add_done_callback(_check_cancel_other) fut.add_done_callback( lambda future: loop.call_soon_threadsafe( - new_future._copy_state, fut)) + new_future._copy_state, future)) return new_future diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 598845d4..5850dead 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -256,6 +256,7 @@ def modify(self, fileobj, events, data=None): def close(self): self._fd_to_key.clear() + self._map = None def get_map(self): return self._map From c3fed610874521d88680df3ca7902797baa40040 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Jan 2015 21:58:29 +0100 Subject: [PATCH 1256/1502] Python issue #23209: Revert change on selectors, test_selectors failed. --- asyncio/selectors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 5850dead..598845d4 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -256,7 +256,6 @@ def modify(self, fileobj, events, data=None): def close(self): self._fd_to_key.clear() - self._map = None def get_map(self): return self._map From d9072d48f9204c9554abe1413ecd3046b6d6faf0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 13 Jan 2015 10:04:24 +0100 Subject: [PATCH 1257/1502] Python issue #23209, #23225: selectors.BaseSelector.get_key() now raises a RuntimeError if the selector is closed. And selectors.BaseSelector.close() now clears its internal reference to the selector mapping to break a reference cycle. Initial patch written by Martin Richard. --- asyncio/selectors.py | 3 +++ tests/test_selectors.py | 11 +++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/asyncio/selectors.py b/asyncio/selectors.py index 598845d4..6d569c30 100644 --- a/asyncio/selectors.py +++ b/asyncio/selectors.py @@ -174,6 +174,8 @@ def get_key(self, fileobj): SelectorKey for this file object """ mapping = self.get_map() + if mapping is None: + raise RuntimeError('Selector is closed') try: return mapping[fileobj] except KeyError: @@ -256,6 +258,7 @@ def modify(self, fileobj, events, data=None): def close(self): self._fd_to_key.clear() + self._map = None def get_map(self): return self._map diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 3d5ef918..49b5b8d0 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -159,14 +159,17 @@ def test_close(self): s = self.SELECTOR() self.addCleanup(s.close) + mapping = s.get_map() rd, wr = self.make_socketpair() s.register(rd, selectors.EVENT_READ) s.register(wr, selectors.EVENT_WRITE) s.close() - self.assertRaises(KeyError, s.get_key, rd) - self.assertRaises(KeyError, s.get_key, wr) + self.assertRaises(RuntimeError, s.get_key, rd) + self.assertRaises(RuntimeError, s.get_key, wr) + self.assertRaises(KeyError, mapping.__getitem__, rd) + self.assertRaises(KeyError, mapping.__getitem__, wr) def test_get_key(self): s = self.SELECTOR() @@ -233,8 +236,8 @@ def test_context_manager(self): sel.register(rd, selectors.EVENT_READ) sel.register(wr, selectors.EVENT_WRITE) - self.assertRaises(KeyError, s.get_key, rd) - self.assertRaises(KeyError, s.get_key, wr) + self.assertRaises(RuntimeError, s.get_key, rd) + self.assertRaises(RuntimeError, s.get_key, wr) def test_fileno(self): s = self.SELECTOR() From d2cafa0a613a054c0696151c88950be4dad5933d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 13 Jan 2015 16:14:44 +0100 Subject: [PATCH 1258/1502] Python issue #22922: Fix ProactorEventLoop.close() Close the IocpProactor before closing the event loop. IocpProactor.close() can call loop.call_soon(), which is forbidden when the event loop is closed. --- asyncio/proactor_events.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 0a4d0685..5986e37f 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -387,13 +387,19 @@ def close(self): raise RuntimeError("Cannot close a running event loop") if self.is_closed(): return + + # Call these methods before closing the event loop (before calling + # BaseEventLoop.close), because they can schedule callbacks with + # call_soon(), which is forbidden when the event loop is closed. self._stop_accept_futures() self._close_self_pipe() - super().close() self._proactor.close() self._proactor = None self._selector = None + # Close the event loop + super().close() + def sock_recv(self, sock, n): return self._proactor.recv(sock, n) From 9d5f74089b11fb2bdc7cbb4abbdafcd931bab029 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 13 Jan 2015 16:14:53 +0100 Subject: [PATCH 1259/1502] Tulip issue 184: Fix test_pipe() on Windows Pass explicitly the event loop to StreamReaderProtocol. --- tests/test_windows_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 9b264a64..f9b3dd15 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -67,7 +67,8 @@ def _test_pipe(self): clients = [] for i in range(5): stream_reader = asyncio.StreamReader(loop=self.loop) - protocol = asyncio.StreamReaderProtocol(stream_reader) + protocol = asyncio.StreamReaderProtocol(stream_reader, + loop=self.loop) trans, proto = yield from self.loop.create_pipe_connection( lambda: protocol, ADDRESS) self.assertIsInstance(trans, asyncio.Transport) From 0e0b0fe989adbf08c2cd507b964df060b27a93ba Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 00:15:15 +0100 Subject: [PATCH 1260/1502] Python issue #22560: New SSL implementation based on ssl.MemoryBIO The new SSL implementation is based on the new ssl.MemoryBIO which is only available on Python 3.5. On Python 3.4 and older, the legacy SSL implementation (using SSL_write, SSL_read, etc.) is used. The proactor event loop only supports the new implementation. The new asyncio.sslproto module adds _SSLPipe, SSLProtocol and _SSLProtocolTransport classes. _SSLPipe allows to "wrap" or "unwrap" a socket (switch between cleartext and SSL/TLS). Patch written by Antoine Pitrou. sslproto.py is based on gruvi/ssl.py of the gruvi project written by Geert Jansen. This change adds SSL support to ProactorEventLoop on Python 3.5 and newer! It becomes also possible to implement STARTTTLS: switch a cleartext socket to SSL. --- asyncio/proactor_events.py | 31 +- asyncio/selector_events.py | 45 +-- asyncio/sslproto.py | 640 ++++++++++++++++++++++++++++++++++ asyncio/test_utils.py | 5 + tests/test_events.py | 58 ++- tests/test_selector_events.py | 6 +- 6 files changed, 747 insertions(+), 38 deletions(-) create mode 100644 asyncio/sslproto.py diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 5986e37f..4716bb59 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -11,6 +11,7 @@ from . import base_events from . import constants from . import futures +from . import sslproto from . import transports from .log import logger @@ -367,6 +368,20 @@ def _make_socket_transport(self, sock, protocol, waiter=None, return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + if not sslproto._is_sslproto_available(): + raise NotImplementedError("Proactor event loop requires Python 3.5" + " or newer (ssl.MemoryBIO) to support " + "SSL") + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _ProactorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorDuplexPipeTransport(self, @@ -455,9 +470,8 @@ def _loop_self_reading(self, f=None): def _write_to_self(self): self._csock.send(b'\0') - def _start_serving(self, protocol_factory, sock, ssl=None, server=None): - if ssl: - raise ValueError('IocpEventLoop is incompatible with SSL.') + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): def loop(f=None): try: @@ -467,9 +481,14 @@ def loop(f=None): logger.debug("%r got a new connection from %r: %r", server, addr, conn) protocol = protocol_factory() - self._make_socket_transport( - conn, protocol, - extra={'peername': addr}, server=server) + if sslcontext is not None: + self._make_ssl_transport( + conn, protocol, sslcontext, server_side=True, + extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) if self.is_closed(): return f = self._proactor.accept(sock) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 307a9add..b2f29c70 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -10,6 +10,7 @@ import errno import functools import socket +import sys try: import ssl except ImportError: # pragma: no cover @@ -21,6 +22,7 @@ from . import futures from . import selectors from . import transports +from . import sslproto from .log import logger @@ -58,6 +60,24 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, *, server_side=False, server_hostname=None, extra=None, server=None): + if not sslproto._is_sslproto_available(): + return self._make_legacy_ssl_transport( + rawsock, protocol, sslcontext, waiter, + server_side=server_side, server_hostname=server_hostname, + extra=extra, server=server) + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _SelectorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, + waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used + # on Python 3.4 and older, when ssl.MemoryBIO is not available. return _SelectorSslTransport( self, rawsock, protocol, sslcontext, waiter, server_side, server_hostname, extra, server) @@ -508,7 +528,8 @@ def close(self): def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. - if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if isinstance(exc, (BrokenPipeError, + ConnectionResetError, ConnectionAbortedError)): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -683,26 +704,8 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, if ssl is None: raise RuntimeError('stdlib ssl module not available') - if server_side: - if not sslcontext: - raise ValueError('Server side ssl needs a valid SSLContext') - else: - if not sslcontext: - # Client side may pass ssl=True to use a default - # context; in that case the sslcontext passed is None. - # The default is secure for client connections. - if hasattr(ssl, 'create_default_context'): - # Python 3.4+: use up-to-date strong settings. - sslcontext = ssl.create_default_context() - if not server_hostname: - sslcontext.check_hostname = False - else: - # Fallback for Python 3.3. - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + if not sslcontext: + sslcontext = sslproto._create_transport_context(server_side, server_hostname) wrap_kwargs = { 'server_side': server_side, diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py new file mode 100644 index 00000000..987c158e --- /dev/null +++ b/asyncio/sslproto.py @@ -0,0 +1,640 @@ +import collections +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import protocols +from . import transports +from .log import logger + + +def _create_transport_context(server_side, server_hostname): + if server_side: + raise ValueError('Server side SSL needs a valid SSLContext') + + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + return sslcontext + + +def _is_sslproto_available(): + return hasattr(ssl, "MemoryBIO") + + +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" + + +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + do_handshake(). To shutdown SSL again, call unwrap(). + """ + + max_size = 256 * 1024 # Buffer size passed to read() + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the ssl.SSLContext to use. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = _UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + self._handshake_cb = None + self._shutdown_cb = None + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal ssl.SSLObject instance. + + Return None if the pipe is not wrapped. + """ + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """ + Whether a security layer is currently in effect. + + Return False during handshake. + """ + return self._state == _WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != _UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = _DO_HANDSHAKE + self._handshake_cb = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == _UNWRAPPED: + raise RuntimeError('no security layer present') + if self._state == _SHUTDOWN: + raise RuntimeError('shutdown in progress') + assert self._state in (_WRAPPED, _DO_HANDSHAKE) + self._state = _SHUTDOWN + self._shutdown_cb = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling shutdown(). + """ + if self._state == _UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + if data: + appdata = [data] + else: + appdata = [] + return ([], appdata) + + self._need_ssldata = False + if data: + self._incoming.write(data) + + ssldata = [] + appdata = [] + try: + if self._state == _DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = _WRAPPED + if self._handshake_cb: + self._handshake_cb(None) + if only_handshake: + return (ssldata, appdata) + # Handshake done: execute the wrapped block + + if self._state == _WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.max_size) + appdata.append(chunk) + if not chunk: # close_notify + break + + elif self._state == _SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = _UNWRAPPED + if self._shutdown_cb: + self._shutdown_cb() + + elif self._state == _UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as exc: + if getattr(exc, 'errno', None) not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == _DO_HANDSHAKE and self._handshake_cb: + self._handshake_cb(exc) + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the id() must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + assert 0 <= offset <= len(data) + if self._state == _UNWRAPPED: + # pass through data in unwrapped mode + if offset < len(data): + ssldata = [data[offset:]] + else: + ssldata = [] + return (ssldata, len(data)) + + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as exc: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + if exc.reason == 'PROTOCOL_IS_SHUTDOWN': + exc.errno = ssl.SSL_ERROR_WANT_READ + if exc.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) + + +class _SSLProtocolTransport(transports._FlowControlMixin, + transports.Transport): + + def __init__(self, loop, ssl_protocol, app_protocol): + self._loop = loop + self._ssl_protocol = ssl_protocol + self._app_protocol = app_protocol + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._ssl_protocol._get_extra_info(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + self._ssl_protocol._start_shutdown() + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + self._ssl_protocol._transport.pause_reading() + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + self._ssl_protocol._transport.resume_reading() + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._transport.set_write_buffer_limits(high, low) + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data: expecting a bytes-like instance, got {!r}" + .format(type(data).__name__)) + if not data: + return + self._ssl_protocol._write_appdata(data) + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + return False + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + self._ssl_protocol._abort() + + +class SSLProtocol(protocols.Protocol): + """SSL protocol. + + Implementation of SSL on top of a socket using incoming and outgoing + buffers which are ssl.MemoryBIO objects. + """ + + def __init__(self, loop, app_protocol, sslcontext, waiter, + server_side=False, server_hostname=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = _create_transport_context(server_side, server_hostname) + + self._server_side = server_side + if server_hostname and not server_side: + self._server_hostname = server_hostname + else: + self._server_hostname = None + self._sslcontext = sslcontext + # SSL-specific extra info. More info are set when the handshake + # completes. + self._extra = dict(sslcontext=sslcontext) + + # App data write buffering + self._write_backlog = collections.deque() + self._write_buffer_size = 0 + + self._waiter = waiter + self._closing = False + self._loop = loop + self._app_protocol = app_protocol + self._app_transport = _SSLProtocolTransport(self._loop, + self, self._app_protocol) + self._sslpipe = None + self._session_established = False + self._in_handshake = False + self._in_shutdown = False + + def connection_made(self, transport): + """Called when the low-level connection is made. + + Start the SSL handshake. + """ + self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) + self._start_handshake() + + def connection_lost(self, exc): + """Called when the low-level connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._transport = None + self._app_transport = None + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() + + def data_received(self, data): + """Called when some SSL data is received. + + The argument is a bytes object. + """ + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except ssl.SSLError as e: + if self._loop.get_debug(): + logger.warning('%r: SSL error %s (reason %s)', + self, e.errno, e.reason) + self._abort() + return + + for chunk in ssldata: + self._transport.write(chunk) + + for chunk in appdata: + if chunk: + self._app_protocol.data_received(chunk) + else: + self._start_shutdown() + break + + def eof_received(self): + """Called when the other end of the low-level stream + is half-closed. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self._transport.close() + + def _get_extra_info(self, name, default=None): + if name in self._extra: + return self._extra[name] + else: + return self._transport.get_extra_info(name, default) + + def _start_shutdown(self): + if self._in_shutdown: + return + self._in_shutdown = True + self._write_appdata(b'') + + def _write_appdata(self, data): + self._write_backlog.append((data, 0)) + self._write_buffer_size += len(data) + self._process_write_backlog() + + def _start_handshake(self): + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + self._handshake_start_time = self._loop.time() + else: + self._handshake_start_time = None + self._in_handshake = True + # (b'', 1) is a special value in _process_write_backlog() to do + # the SSL handshake + self._write_backlog.append((b'', 1)) + self._loop.call_soon(self._process_write_backlog) + + def _on_handshake_complete(self, handshake_exc): + self._in_handshake = False + + sslobj = self._sslpipe.ssl_object + peercert = None if handshake_exc else sslobj.getpeercert() + try: + if handshake_exc is not None: + raise handshake_exc + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname + and self._sslcontext.verify_mode != ssl.CERT_NONE): + ssl.match_hostname(peercert, self._server_hostname) + except BaseException as exc: + if self._loop.get_debug(): + if isinstance(exc, ssl.CertificateError): + logger.warning("%r: SSL handshake failed " + "on verifying the certificate", + self, exc_info=True) + else: + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._transport.close() + if isinstance(exc, Exception): + if self._waiter is not None: + self._waiter.set_exception(exc) + return + else: + raise + + if self._loop.get_debug(): + dt = self._loop.time() - self._handshake_start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=sslobj.cipher(), + compression=sslobj.compression(), + ) + self._app_protocol.connection_made(self._app_transport) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._waiter._set_result_unless_cancelled(None) + self._session_established = True + # In case transport.write() was already called + self._process_write_backlog() + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None: + return + + try: + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata = self._sslpipe.do_handshake(self._on_handshake_complete) + offset = 1 + else: + ssldata = self._sslpipe.shutdown(self._finalize) + offset = 1 + + for chunk in ssldata: + self._transport.write(chunk) + + if offset < len(data): + self._write_backlog[0] = (data, offset) + # A short write means that a write is blocked on a read + # We need to enable reading if it is paused! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() + break + + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= len(data) + except BaseException as exc: + if self._in_handshake: + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self._transport, + 'protocol': self, + }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + if self._transport is not None: + self._transport.close() + + def _abort(self): + if self._transport is not None: + try: + self._transport.abort() + finally: + self._finalize() diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 3e5eee54..180bafa1 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -434,3 +434,8 @@ def mock_nonblocking_socket(): sock = mock.Mock(socket.socket) sock.gettimeout.return_value = 0.0 return sock + + +def force_legacy_ssl_support(): + return mock.patch('asyncio.sslproto._is_sslproto_available', + return_value=False) diff --git a/tests/test_events.py b/tests/test_events.py index af2da1fe..a2c6dc94 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -650,6 +650,10 @@ def test_create_ssl_connection(self): *httpd.address) self._test_create_ssl_connection(httpd, create_connection) + def test_legacy_create_ssl_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_connection() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_ssl_unix_connection(self): @@ -666,6 +670,10 @@ def test_create_ssl_unix_connection(self): self._test_create_ssl_connection(httpd, create_connection, check_sockname) + def test_legacy_create_ssl_unix_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_unix_connection() + def test_create_connection_local_addr(self): with test_utils.run_test_server() as httpd: port = support.find_unused_port() @@ -826,6 +834,10 @@ def test_create_server_ssl(self): # stop serving server.close() + def test_legacy_create_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl(self): @@ -857,6 +869,10 @@ def test_create_unix_server_ssl(self): # stop serving server.close() + def test_legacy_create_unix_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl() + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) @@ -881,6 +897,10 @@ def test_create_server_ssl_verify_failed(self): self.assertIsNone(proto.transport) server.close() + def test_legacy_create_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verify_failed() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verify_failed(self): @@ -907,6 +927,10 @@ def test_create_unix_server_ssl_verify_failed(self): self.assertIsNone(proto.transport) server.close() + def test_legacy_create_unix_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verify_failed() + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_match_failed(self): proto = MyProto(loop=self.loop) @@ -934,6 +958,10 @@ def test_create_server_ssl_match_failed(self): proto.transport.close() server.close() + def test_legacy_create_server_ssl_match_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_match_failed() + @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verified(self): @@ -958,6 +986,11 @@ def test_create_unix_server_ssl_verified(self): proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_unix_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verified() @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_verified(self): @@ -982,6 +1015,11 @@ def test_create_server_ssl_verified(self): proto.transport.close() client.close() server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verified() def test_create_server_sock(self): proto = asyncio.Future(loop=self.loop) @@ -1746,20 +1784,20 @@ class ProactorEventLoopTests(EventLoopTestsMixin, def create_event_loop(self): return asyncio.ProactorEventLoop() - def test_create_ssl_connection(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_verify_failed(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_match_failed(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") - def test_create_server_ssl_verified(self): - raise unittest.SkipTest("IocpEventLoop incompatible with SSL") + def test_legacy_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") def test_reader_callback(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index ff114f82..360327aa 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -59,9 +59,13 @@ def test_make_ssl_transport(self): with test_utils.disable_logger(): transport = self.loop._make_ssl_transport( m, asyncio.Protocol(), m, waiter) - self.assertIsInstance(transport, _SelectorSslTransport) + # Sanity check + class_name = transport.__class__.__name__ + self.assertIn("ssl", class_name.lower()) + self.assertIn("transport", class_name.lower()) @mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): m = mock.Mock() self.loop.add_reader = mock.Mock() From 51193dd1b20c0d26841c48836a8f9078568d9d92 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 00:52:39 +0100 Subject: [PATCH 1261/1502] Python issue #23198: Reactor StreamReader - Add a new _wakeup_waiter() method - Replace _create_waiter() method with a _wait_for_data() coroutine function - Use the value None instead of True or False to wake up the waiter --- asyncio/streams.py | 47 ++++++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 5a96b241..7ff16a48 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -313,8 +313,8 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): else: self._loop = loop self._buffer = bytearray() - self._eof = False # Whether we're done. - self._waiter = None # A future. + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() self._exception = None self._transport = None self._paused = False @@ -331,6 +331,14 @@ def set_exception(self, exc): if not waiter.cancelled(): waiter.set_exception(exc) + def _wakeup_waiter(self): + """Wakeup read() or readline() function waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + def set_transport(self, transport): assert self._transport is None, 'Transport already set' self._transport = transport @@ -342,11 +350,7 @@ def _maybe_resume_transport(self): def feed_eof(self): self._eof = True - waiter = self._waiter - if waiter is not None: - self._waiter = None - if not waiter.cancelled(): - waiter.set_result(True) + self._wakeup_waiter() def at_eof(self): """Return True if the buffer is empty and 'feed_eof' was called.""" @@ -359,12 +363,7 @@ def feed_data(self, data): return self._buffer.extend(data) - - waiter = self._waiter - if waiter is not None: - self._waiter = None - if not waiter.cancelled(): - waiter.set_result(False) + self._wakeup_waiter() if (self._transport is not None and not self._paused and @@ -379,7 +378,8 @@ def feed_data(self, data): else: self._paused = True - def _create_waiter(self, func_name): + def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called.""" # StreamReader uses a future to link the protocol feed_data() method # to a read coroutine. Running two read coroutines at the same time # would have an unexpected behaviour. It would not possible to know @@ -387,7 +387,12 @@ def _create_waiter(self, func_name): if self._waiter is not None: raise RuntimeError('%s() called while another coroutine is ' 'already waiting for incoming data' % func_name) - return futures.Future(loop=self._loop) + + self._waiter = futures.Future(loop=self._loop) + try: + yield from self._waiter + finally: + self._waiter = None @coroutine def readline(self): @@ -417,11 +422,7 @@ def readline(self): break if not_enough: - self._waiter = self._create_waiter('readline') - try: - yield from self._waiter - finally: - self._waiter = None + yield from self._wait_for_data('readline') self._maybe_resume_transport() return bytes(line) @@ -448,11 +449,7 @@ def read(self, n=-1): return b''.join(blocks) else: if not self._buffer and not self._eof: - self._waiter = self._create_waiter('read') - try: - yield from self._waiter - finally: - self._waiter = None + yield from self._wait_for_data('read') if n < 0 or len(self._buffer) <= n: data = bytes(self._buffer) From da030f978e38e95c07e97adf24d9d4f171a04384 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 02:06:07 +0100 Subject: [PATCH 1262/1502] Python issue #23173: If an exception is raised during the creation of a subprocess, kill the subprocess (close pipes, kill and read the return status). Log an error in such case. --- asyncio/base_subprocess.py | 77 ++++++++++++++++++++++++++------------ asyncio/subprocess.py | 12 +++++- 2 files changed, 63 insertions(+), 26 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index afc434de..0787ad70 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -96,32 +96,61 @@ def terminate(self): def kill(self): self._proc.kill() + def _kill_wait(self): + """Close pipes, kill the subprocess and read its return status. + + Function called when an exception is raised during the creation + of a subprocess. + """ + if self._loop.get_debug(): + logger.warning('Exception during subprocess creation, ' + 'kill the subprocess %r', + self, + exc_info=True) + + proc = self._proc + if proc.stdout: + proc.stdout.close() + if proc.stderr: + proc.stderr.close() + if proc.stdin: + proc.stdin.close() + try: + proc.kill() + except ProcessLookupError: + pass + proc.wait() + @coroutine def _post_init(self): - proc = self._proc - loop = self._loop - if proc.stdin is not None: - _, pipe = yield from loop.connect_write_pipe( - lambda: WriteSubprocessPipeProto(self, 0), - proc.stdin) - self._pipes[0] = pipe - if proc.stdout is not None: - _, pipe = yield from loop.connect_read_pipe( - lambda: ReadSubprocessPipeProto(self, 1), - proc.stdout) - self._pipes[1] = pipe - if proc.stderr is not None: - _, pipe = yield from loop.connect_read_pipe( - lambda: ReadSubprocessPipeProto(self, 2), - proc.stderr) - self._pipes[2] = pipe - - assert self._pending_calls is not None - - self._loop.call_soon(self._protocol.connection_made, self) - for callback, data in self._pending_calls: - self._loop.call_soon(callback, *data) - self._pending_calls = None + try: + proc = self._proc + loop = self._loop + if proc.stdin is not None: + _, pipe = yield from loop.connect_write_pipe( + lambda: WriteSubprocessPipeProto(self, 0), + proc.stdin) + self._pipes[0] = pipe + if proc.stdout is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 1), + proc.stdout) + self._pipes[1] = pipe + if proc.stderr is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 2), + proc.stderr) + self._pipes[2] = pipe + + assert self._pending_calls is not None + + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + except: + self._kill_wait() + raise def _call(self, cb, *data): if self._pending_calls is not None: diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index a8ad03c2..d83442ef 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -216,7 +216,11 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, protocol_factory, cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwds) - yield from protocol.waiter + try: + yield from protocol.waiter + except: + transport._kill_wait() + raise return Process(transport, protocol, loop) @coroutine @@ -232,5 +236,9 @@ def create_subprocess_exec(program, *args, stdin=None, stdout=None, program, *args, stdin=stdin, stdout=stdout, stderr=stderr, **kwds) - yield from protocol.waiter + try: + yield from protocol.waiter + except: + transport._kill_wait() + raise return Process(transport, protocol, loop) From a1e0e99652fe4b368a945ecc92800c28541f1182 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 02:09:23 +0100 Subject: [PATCH 1263/1502] Python issue #23173: Fix SubprocessStreamProtocol.connection_made() to handle cancelled waiter. Add unit test cancelling subprocess methods. --- asyncio/subprocess.py | 4 +++- tests/test_subprocess.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index d83442ef..a028339c 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -60,7 +60,9 @@ def connection_made(self, transport): protocol=self, reader=None, loop=self._loop) - self.waiter.set_result(None) + + if not self.waiter.cancelled(): + self.waiter.set_result(None) def pipe_data_received(self, fd, data): if fd == 1: diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5fc1dc0a..b2f1b953 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -251,6 +251,42 @@ def cancel_wait(): self.loop.run_until_complete(cancel_wait()) + def test_cancel_make_subprocess_transport_exec(self): + @asyncio.coroutine + def cancel_make_transport(): + coro = asyncio.create_subprocess_exec(*PROGRAM_BLOCKED, + loop=self.loop) + task = self.loop.create_task(coro) + + self.loop.call_soon(task.cancel) + try: + yield from task + except asyncio.CancelledError: + pass + + # ignore the log: + # "Exception during subprocess creation, kill the subprocess" + with test_utils.disable_logger(): + self.loop.run_until_complete(cancel_make_transport()) + + def test_cancel_post_init(self): + @asyncio.coroutine + def cancel_make_transport(): + coro = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + task = self.loop.create_task(coro) + + self.loop.call_soon(task.cancel) + try: + yield from task + except asyncio.CancelledError: + pass + + # ignore the log: + # "Exception during subprocess creation, kill the subprocess" + with test_utils.disable_logger(): + self.loop.run_until_complete(cancel_make_transport()) + if sys.platform != 'win32': # Unix From 9d44c821817898cf8d0452080374f5d0d9c35f9d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 16:50:41 +0100 Subject: [PATCH 1264/1502] Cleanup sslproto.py --- asyncio/sslproto.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index 987c158e..dc03cf59 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -530,10 +530,11 @@ def _on_handshake_complete(self, handshake_exc): self._in_handshake = False sslobj = self._sslpipe.ssl_object - peercert = None if handshake_exc else sslobj.getpeercert() try: if handshake_exc is not None: raise handshake_exc + + peercert = sslobj.getpeercert() if not hasattr(self._sslcontext, 'check_hostname'): # Verify hostname if requested, Python 3.4+ uses check_hostname # and checks the hostname in do_handshake() From 398fdbe6590522e803cc4c4eac353bf5fc344552 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 16:51:38 +0100 Subject: [PATCH 1265/1502] Python issue #23197: On SSL handshake failure, check if the waiter is cancelled before setting its exception. Add unit tests for this case. --- asyncio/selector_events.py | 2 +- asyncio/sslproto.py | 2 +- tests/test_selector_events.py | 20 ++++++++++++---- tests/test_sslproto.py | 45 +++++++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 6 deletions(-) create mode 100644 tests/test_sslproto.py diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index b2f29c70..ca862648 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -750,7 +750,7 @@ def _on_handshake(self, start_time): self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._sock.close() - if self._waiter is not None: + if self._waiter is not None and not self._waiter.cancelled(): self._waiter.set_exception(exc) if isinstance(exc, Exception): return diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index dc03cf59..541e2527 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -552,7 +552,7 @@ def _on_handshake_complete(self, handshake_exc): self, exc_info=True) self._transport.close() if isinstance(exc, Exception): - if self._waiter is not None: + if self._waiter is not None and not self._waiter.cancelled(): self._waiter.set_exception(exc) return else: diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 360327aa..64c2e650 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1148,16 +1148,28 @@ def test_on_handshake_exc(self): self.assertTrue(self.sslsock.close.called) def test_on_handshake_base_exc(self): + waiter = asyncio.Future(loop=self.loop) transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext) - transport._waiter = asyncio.Future(loop=self.loop) + self.loop, self.sock, self.protocol, self.sslcontext, waiter) exc = BaseException() self.sslsock.do_handshake.side_effect = exc with test_utils.disable_logger(): self.assertRaises(BaseException, transport._on_handshake, 0) self.assertTrue(self.sslsock.close.called) - self.assertTrue(transport._waiter.done()) - self.assertIs(exc, transport._waiter.exception()) + self.assertTrue(waiter.done()) + self.assertIs(exc, waiter.exception()) + + def test_cancel_handshake(self): + # Python issue #23197: cancelling an handshake must not raise an + # exception or log an error, even if the handshake failed + waiter = asyncio.Future(loop=self.loop) + transport = _SelectorSslTransport( + self.loop, self.sock, self.protocol, self.sslcontext, waiter) + waiter.cancel() + exc = ValueError() + self.sslsock.do_handshake.side_effect = exc + with test_utils.disable_logger(): + transport._on_handshake(0) def test_pause_resume_reading(self): tr = self._make_one() diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py new file mode 100644 index 00000000..053fefe7 --- /dev/null +++ b/tests/test_sslproto.py @@ -0,0 +1,45 @@ +"""Tests for asyncio/sslproto.py.""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import sslproto +from asyncio import test_utils + + +class SslProtoHandshakeTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + def test_cancel_handshake(self): + # Python issue #23197: cancelling an handshake must not raise an + # exception or log an error, even if the handshake failed + sslcontext = test_utils.dummy_ssl_context() + app_proto = asyncio.Protocol() + waiter = asyncio.Future(loop=self.loop) + ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, + waiter) + handshake_fut = asyncio.Future(loop=self.loop) + + def do_handshake(callback): + exc = Exception() + callback(exc) + handshake_fut.set_result(None) + return [] + + waiter.cancel() + transport = mock.Mock() + sslpipe = mock.Mock() + sslpipe.do_handshake.side_effect = do_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) + + with test_utils.disable_logger(): + self.loop.run_until_complete(handshake_fut) + + +if __name__ == '__main__': + unittest.main() From 0acde31423c4a9704a2f38dbc6eab4a8c681b3dd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 17:12:59 +0100 Subject: [PATCH 1266/1502] Python issue #23197: On SSL handshake failure on matching hostname, check if the waiter is cancelled before setting its exception. --- asyncio/selector_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index ca862648..074a8df0 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -774,7 +774,8 @@ def _on_handshake(self, start_time): "on matching the hostname", self, exc_info=True) self._sock.close() - if self._waiter is not None: + if (self._waiter is not None + and not self._waiter.cancelled()): self._waiter.set_exception(exc) return From 3fd2599228979af641ae452fa8b6d69e403d115d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:55:43 +0100 Subject: [PATCH 1267/1502] UNIX pipe transports: add closed/closing in repr() Add "closed" or "closing" state in the __repr__() method of _UnixReadPipeTransport and _UnixWritePipeTransport classes. --- asyncio/unix_events.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 1a4d4183..14b48438 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -301,7 +301,12 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): - info = [self.__class__.__name__, 'fd=%s' % self._fileno] + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) if self._pipe is not None: polling = selector_events._test_selector_event( self._loop._selector, @@ -404,7 +409,12 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): - info = [self.__class__.__name__, 'fd=%s' % self._fileno] + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) if self._pipe is not None: polling = selector_events._test_selector_event( self._loop._selector, From 4bb6d524dbf571800cf0f2b33da340757ea29f1d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:58:14 +0100 Subject: [PATCH 1268/1502] Tests: explicitly close event loops and transports --- tests/test_base_events.py | 1 + tests/test_events.py | 5 +++++ tests/test_futures.py | 1 + tests/test_selector_events.py | 1 + tests/test_unix_events.py | 1 + 5 files changed, 9 insertions(+) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 6bf7e796..bd6c0d80 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -409,6 +409,7 @@ def test_run_until_complete_type_error(self): def test_run_until_complete_loop(self): task = asyncio.Future(loop=self.loop) other_loop = self.new_test_loop() + self.addCleanup(other_loop.close) self.assertRaises(ValueError, other_loop.run_until_complete, task) diff --git a/tests/test_events.py b/tests/test_events.py index a2c6dc94..d40c8e46 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1585,6 +1585,7 @@ def test_subprocess_shell(self): self.assertTrue(all(f.done() for f in proto.disconnects.values())) self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python') self.assertEqual(proto.data[2], b'') + transp.close() def test_subprocess_exitcode(self): connect = self.loop.subprocess_shell( @@ -1594,6 +1595,7 @@ def test_subprocess_exitcode(self): self.assertIsInstance(proto, MySubprocessProtocol) self.loop.run_until_complete(proto.completed) self.assertEqual(7, proto.returncode) + transp.close() def test_subprocess_close_after_finish(self): connect = self.loop.subprocess_shell( @@ -1621,6 +1623,7 @@ def test_subprocess_kill(self): transp.kill() self.loop.run_until_complete(proto.completed) self.check_killed(proto.returncode) + transp.close() def test_subprocess_terminate(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') @@ -1635,6 +1638,7 @@ def test_subprocess_terminate(self): transp.terminate() self.loop.run_until_complete(proto.completed) self.check_terminated(proto.returncode) + transp.close() @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_subprocess_send_signal(self): @@ -1650,6 +1654,7 @@ def test_subprocess_send_signal(self): transp.send_signal(signal.SIGHUP) self.loop.run_until_complete(proto.completed) self.assertEqual(-signal.SIGHUP, proto.returncode) + transp.close() def test_subprocess_stderr(self): prog = os.path.join(os.path.dirname(__file__), 'echo2.py') diff --git a/tests/test_futures.py b/tests/test_futures.py index dac1e897..c8b6829f 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -29,6 +29,7 @@ class FutureTests(test_utils.TestCase): def setUp(self): self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) def test_initial_state(self): f = asyncio.Future(loop=self.loop) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 64c2e650..3e8392e7 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1744,6 +1744,7 @@ def test_fatal_error_connected(self, m_exc): test_utils.MockPattern( 'Fatal error on transport\nprotocol:.*\ntransport:.*'), exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) + transport.close() if __name__ == '__main__': diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 4b825dc8..5f4b0244 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -598,6 +598,7 @@ def test_write_err(self, m_write, m_log): # This is a bit overspecified. :-( m_log.warning.assert_called_with( 'pipe closed by peer or os.write(pipe, data) raised exception.') + tr.close() @mock.patch('os.write') def test_write_close(self, m_write): From de9990e17df33f63e5ba85a8beeafbfe4cc99c19 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:45:48 +0100 Subject: [PATCH 1269/1502] Fix BaseSubprocessTransport._kill_wait() Set the _returncode attribute, so close() doesn't try to terminate the process. --- asyncio/base_subprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 0787ad70..d607e8d3 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -119,7 +119,7 @@ def _kill_wait(self): proc.kill() except ProcessLookupError: pass - proc.wait() + self._returncode = proc.wait() @coroutine def _post_init(self): From 233af8cdf021f16a2bba31dafff70f0616f48fec Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:48:59 +0100 Subject: [PATCH 1270/1502] cleanup BaseSelectorEventLoop Create the protocol on a separated line for readability and ease debugging. --- asyncio/selector_events.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 074a8df0..4d3e5d9e 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -182,13 +182,14 @@ def _accept_connection(self, protocol_factory, sock, else: raise # The event loop will catch, log and ignore it. else: + protocol = protocol_factory() if sslcontext: self._make_ssl_transport( - conn, protocol_factory(), sslcontext, + conn, protocol, sslcontext, server_side=True, extra={'peername': addr}, server=server) else: self._make_socket_transport( - conn, protocol_factory(), extra={'peername': addr}, + conn, protocol , extra={'peername': addr}, server=server) # It's now up to the protocol to handle the connection. From ed21baf4a83e259a57a98940e8c3deb02defd2dd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:53:15 +0100 Subject: [PATCH 1271/1502] TestLoop.close() now calls the close() method of the parent class (BaseEventLoop) --- asyncio/test_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 180bafa1..6eedc583 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -307,6 +307,7 @@ def advance_time(self, advance): self._time += advance def close(self): + super().close() if self._check_on_close: try: self._gen.send(0) From 158f6213056c55d5378f8630216ef2ba4e4e1312 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:46:58 +0100 Subject: [PATCH 1272/1502] Fix BaseSubprocessTransport.close() Ignore pipes for which the protocol is not set yet (still equal to None). --- asyncio/base_subprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index d607e8d3..f5e7dfec 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -71,6 +71,8 @@ def _make_read_subprocess_pipe_proto(self, fd): def close(self): for proto in self._pipes.values(): + if proto is None: + continue proto.pipe.close() if self._returncode is None: self.terminate() From d5cc0307ee96db861b40d4ec6f4a217f1a40fcd3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 23:21:36 +0100 Subject: [PATCH 1273/1502] _ProactorBasePipeTransport now sets _sock to None when the transport is closed --- asyncio/proactor_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 4716bb59..0ecb44eb 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -111,6 +111,7 @@ def _call_connection_lost(self, exc): if hasattr(self._sock, 'shutdown'): self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() + self._sock = None server = self._server if server is not None: server._detach() From f2e08a0325a28cdd3af1223312c9982525cd12bb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:44:41 +0100 Subject: [PATCH 1274/1502] Fix BaseEventLoop._create_connection_transport() Close the transport if the creation of the transport (if the waiter) gets an exception. --- asyncio/base_events.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 35c8d742..5df5b83b 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -634,7 +634,12 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, else: transport = self._make_socket_transport(sock, protocol, waiter) - yield from waiter + try: + yield from waiter + except Exception as exc: + transport.close() + raise + return transport, protocol @coroutine From 58d27b7c3070e90436acb19e17da97f1bb3999f2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 23:43:03 +0100 Subject: [PATCH 1275/1502] Fix test_events on Python older than 3.5 Skip SSL tests on the ProactorEventLoop if ssl.MemoryIO is missing --- tests/test_events.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_events.py b/tests/test_events.py index d40c8e46..a38c90eb 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -25,6 +25,7 @@ import asyncio from asyncio import proactor_events from asyncio import selector_events +from asyncio import sslproto from asyncio import test_utils try: from test import support @@ -1789,6 +1790,22 @@ class ProactorEventLoopTests(EventLoopTestsMixin, def create_event_loop(self): return asyncio.ProactorEventLoop() + if not sslproto._is_sslproto_available(): + def test_create_ssl_connection(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_match_failed(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_verified(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + def test_legacy_create_ssl_connection(self): raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") From 20a103cc90135494162e819f98d0edfc1f1fba6b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 14 Jan 2015 22:57:27 +0100 Subject: [PATCH 1276/1502] PipeHandle now uses None instead of -1 for a closed handle Sort also imports in windows_utils. --- asyncio/windows_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index b4758123..e6642960 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -7,13 +7,13 @@ if sys.platform != 'win32': # pragma: no cover raise ImportError('win32 only') -import socket +import _winapi import itertools import msvcrt import os +import socket import subprocess import tempfile -import _winapi __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] @@ -136,7 +136,7 @@ def __init__(self, handle): self._handle = handle def __repr__(self): - if self._handle != -1: + if self._handle is not None: handle = 'handle=%r' % self._handle else: handle = 'closed' @@ -150,9 +150,9 @@ def fileno(self): return self._handle def close(self, *, CloseHandle=_winapi.CloseHandle): - if self._handle != -1: + if self._handle is not None: CloseHandle(self._handle) - self._handle = -1 + self._handle = None __del__ = close From 47152d63f3ffa9182afb80456d552012a02bc38c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 09:32:18 +0100 Subject: [PATCH 1277/1502] StreamWriter: close() now clears the reference to the transport StreamWriter now raises an exception if it is closed: write(), writelines(), write_eof(), can_write_eof(), get_extra_info(), drain(). --- asyncio/streams.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 7ff16a48..12ab1c52 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -258,8 +258,22 @@ def __init__(self, transport, protocol, reader, loop): self._reader = reader self._loop = loop + def close(self): + if self._transport is None: + return + self._transport.close() + self._transport = None + + def _check_closed(self): + if self._transport is None: + raise RuntimeError('StreamWriter is closed') + def __repr__(self): - info = [self.__class__.__name__, 'transport=%r' % self._transport] + info = [self.__class__.__name__] + if self._transport is not None: + info.append('transport=%r' % self._transport) + else: + info.append('closed') if self._reader is not None: info.append('reader=%r' % self._reader) return '<%s>' % ' '.join(info) @@ -269,21 +283,23 @@ def transport(self): return self._transport def write(self, data): + self._check_closed() self._transport.write(data) def writelines(self, data): + self._check_closed() self._transport.writelines(data) def write_eof(self): + self._check_closed() return self._transport.write_eof() def can_write_eof(self): + self._check_closed() return self._transport.can_write_eof() - def close(self): - return self._transport.close() - def get_extra_info(self, name, default=None): + self._check_closed() return self._transport.get_extra_info(name, default) @coroutine @@ -295,6 +311,7 @@ def drain(self): w.write(data) yield from w.drain() """ + self._check_closed() if self._reader is not None: exc = self._reader.exception() if exc is not None: From 1da6e1d82c7a466d7d14b928b0f2b1fa568f6417 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 09:40:16 +0100 Subject: [PATCH 1278/1502] Python issue #22560: Fix SSLProtocol._on_handshake_complete() Don't call immediatly self._process_write_backlog() but schedule the call using call_soon(). _on_handshake_complete() can be called indirectly from _process_write_backlog(), and _process_write_backlog() is not reentrant. --- asyncio/sslproto.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index 541e2527..c7fb4e7c 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -572,8 +572,12 @@ def _on_handshake_complete(self, handshake_exc): # wait until protocol.connection_made() has been called self._waiter._set_result_unless_cancelled(None) self._session_established = True - # In case transport.write() was already called - self._process_write_backlog() + # In case transport.write() was already called. Don't call + # immediatly _process_write_backlog(), but schedule it: + # _on_handshake_complete() can be called indirectly from + # _process_write_backlog(), and _process_write_backlog() is not + # reentrant. + self._loop.call_soon(self._process_write_backlog) def _process_write_backlog(self): # Try to make progress on the write backlog. From 3ce21be63a0a481c56958749b4aa46665b212387 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 12:52:05 +0100 Subject: [PATCH 1279/1502] Python issue #23242: SubprocessStreamProtocol now closes the subprocess transport at subprocess exit. Clear also its reference to the transport. --- asyncio/subprocess.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index a028339c..c848a21a 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -94,8 +94,11 @@ def pipe_connection_lost(self, fd, exc): reader.set_exception(exc) def process_exited(self): - # wake up futures waiting for wait() returncode = self._transport.get_returncode() + self._transport.close() + self._transport = None + + # wake up futures waiting for wait() while self._waiters: waiter = self._waiters.popleft() if not waiter.cancelled(): From 8addd76f4b56b1d8239f624fc97d1e6a9674b4e7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 12:50:35 +0100 Subject: [PATCH 1280/1502] SSLProtocol: set the _transport attribute in the constructor --- asyncio/sslproto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index c7fb4e7c..117dc565 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -417,6 +417,7 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._session_established = False self._in_handshake = False self._in_shutdown = False + self._transport = None def connection_made(self, transport): """Called when the low-level connection is made. From 693fa4b2242d0d64c3343797396ec61fd5774bc0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 12:54:49 +0100 Subject: [PATCH 1281/1502] Python issue #23243: Fix _UnixWritePipeTransport.close() Do nothing if the transport is already closed. Before it was not possible to close the transport twice. --- asyncio/unix_events.py | 2 +- tests/test_unix_events.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 14b48438..9f4005cb 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -516,7 +516,7 @@ def write_eof(self): self._loop.call_soon(self._call_connection_lost, None) def close(self): - if not self._closing: + if self._pipe is not None and not self._closing: # write_eof is all what we needed to close the write pipe self.write_eof() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 5f4b0244..4a68ce36 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -766,6 +766,9 @@ def test_close(self): tr.close() tr.write_eof.assert_called_with() + # closing the transport twice must not fail + tr.close() + def test_close_closing(self): tr = unix_events._UnixWritePipeTransport( self.loop, self.pipe, self.protocol) From ae6f24cb14b8036f7ea1fd808183e8e3450275e4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 13:10:31 +0100 Subject: [PATCH 1282/1502] Python issue #23243: Close explicitly event loops in tests --- tests/test_base_events.py | 1 + tests/test_proactor_events.py | 4 ++++ tests/test_selector_events.py | 16 +++++++++++++++- 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index bd6c0d80..9e7c50cc 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -590,6 +590,7 @@ def default_exception_handler(self, context): raise ValueError('spam') loop = Loop() + self.addCleanup(loop.close) asyncio.set_event_loop(loop) def run_loop(): diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 82582383..08c622a8 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -16,6 +16,7 @@ class ProactorSocketTransportTests(test_utils.TestCase): def setUp(self): self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) self.proactor = mock.Mock() self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.Protocol) @@ -459,6 +460,9 @@ def test_close_self_pipe(self): self.assertIsNone(self.loop._ssock) self.assertIsNone(self.loop._csock) + # Don't call close(): _close_self_pipe() cannot be called twice + self.loop._closed = True + def test_close(self): self.loop._close_self_pipe = mock.Mock() self.loop.close() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 3e8392e7..f5194193 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -24,6 +24,11 @@ class TestBaseSelectorEventLoop(BaseSelectorEventLoop): + def close(self): + # Don't call the close() method of the parent class, because the + # selector is mocked + self._closed = True + def _make_self_pipe(self): self._ssock = mock.Mock() self._csock = mock.Mock() @@ -40,7 +45,7 @@ def setUp(self): self.selector = mock.Mock() self.selector.select.return_value = [] self.loop = TestBaseSelectorEventLoop(self.selector) - self.set_event_loop(self.loop, cleanup=False) + self.set_event_loop(self.loop) def test_make_socket_transport(self): m = mock.Mock() @@ -76,6 +81,15 @@ def test_make_ssl_transport_without_ssl_error(self): self.loop._make_ssl_transport(m, m, m, m) def test_close(self): + class EventLoop(BaseSelectorEventLoop): + def _make_self_pipe(self): + self._ssock = mock.Mock() + self._csock = mock.Mock() + self._internal_fds += 1 + + self.loop = EventLoop(self.selector) + self.set_event_loop(self.loop) + ssock = self.loop._ssock ssock.fileno.return_value = 7 csock = self.loop._csock From eba6a462551421fd0be31899436dcfdbead086be Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 13:08:17 +0100 Subject: [PATCH 1283/1502] Python issue #23243: Close explicitly transports in tests --- tests/test_proactor_events.py | 87 ++++++----- tests/test_selector_events.py | 262 +++++++++++++++------------------- tests/test_unix_events.py | 154 ++++++++------------ 3 files changed, 226 insertions(+), 277 deletions(-) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 08c622a8..dee147e5 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -12,6 +12,15 @@ from asyncio import test_utils +def close_transport(transport): + # Don't call transport.close() because the event loop and the IOCP proactor + # are mocked + if transport._sock is None: + return + transport._sock.close() + transport._sock = None + + class ProactorSocketTransportTests(test_utils.TestCase): def setUp(self): @@ -22,17 +31,22 @@ def setUp(self): self.protocol = test_utils.make_test_protocol(asyncio.Protocol) self.sock = mock.Mock(socket.socket) + def socket_transport(self, waiter=None): + transport = _ProactorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + self.addCleanup(close_transport, transport) + return transport + def test_ctor(self): fut = asyncio.Future(loop=self.loop) - tr = _ProactorSocketTransport( - self.loop, self.sock, self.protocol, fut) + tr = self.socket_transport(waiter=fut) test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) self.protocol.connection_made(tr) self.proactor.recv.assert_called_with(self.sock, 4096) def test_loop_reading(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._loop_reading() self.loop._proactor.recv.assert_called_with(self.sock, 4096) self.assertFalse(self.protocol.data_received.called) @@ -42,8 +56,7 @@ def test_loop_reading_data(self): res = asyncio.Future(loop=self.loop) res.set_result(b'data') - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - + tr = self.socket_transport() tr._read_fut = res tr._loop_reading(res) self.loop._proactor.recv.assert_called_with(self.sock, 4096) @@ -53,8 +66,7 @@ def test_loop_reading_no_data(self): res = asyncio.Future(loop=self.loop) res.set_result(b'') - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) - + tr = self.socket_transport() self.assertRaises(AssertionError, tr._loop_reading, res) tr.close = mock.Mock() @@ -67,7 +79,7 @@ def test_loop_reading_no_data(self): def test_loop_reading_aborted(self): err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._fatal_error = mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with( @@ -77,7 +89,7 @@ def test_loop_reading_aborted(self): def test_loop_reading_aborted_closing(self): self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._closing = True tr._fatal_error = mock.Mock() tr._loop_reading() @@ -85,7 +97,7 @@ def test_loop_reading_aborted_closing(self): def test_loop_reading_aborted_is_fatal(self): self.loop._proactor.recv.side_effect = ConnectionAbortedError() - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._closing = False tr._fatal_error = mock.Mock() tr._loop_reading() @@ -94,7 +106,7 @@ def test_loop_reading_aborted_is_fatal(self): def test_loop_reading_conn_reset_lost(self): err = self.loop._proactor.recv.side_effect = ConnectionResetError() - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._closing = False tr._fatal_error = mock.Mock() tr._force_close = mock.Mock() @@ -105,7 +117,7 @@ def test_loop_reading_conn_reset_lost(self): def test_loop_reading_exception(self): err = self.loop._proactor.recv.side_effect = (OSError()) - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._fatal_error = mock.Mock() tr._loop_reading() tr._fatal_error.assert_called_with( @@ -113,19 +125,19 @@ def test_loop_reading_exception(self): 'Fatal read error on pipe transport') def test_write(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._loop_writing = mock.Mock() tr.write(b'data') self.assertEqual(tr._buffer, None) tr._loop_writing.assert_called_with(data=b'data') def test_write_no_data(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr.write(b'') self.assertFalse(tr._buffer) def test_write_more(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._write_fut = mock.Mock() tr._loop_writing = mock.Mock() tr.write(b'data') @@ -133,7 +145,7 @@ def test_write_more(self): self.assertFalse(tr._loop_writing.called) def test_loop_writing(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._buffer = bytearray(b'data') tr._loop_writing() self.loop._proactor.send.assert_called_with(self.sock, b'data') @@ -143,7 +155,7 @@ def test_loop_writing(self): @mock.patch('asyncio.proactor_events.logger') def test_loop_writing_err(self, m_log): err = self.loop._proactor.send.side_effect = OSError() - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._fatal_error = mock.Mock() tr._buffer = [b'da', b'ta'] tr._loop_writing() @@ -164,7 +176,7 @@ def test_loop_writing_stop(self): fut = asyncio.Future(loop=self.loop) fut.set_result(b'data') - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._write_fut = fut tr._loop_writing(fut) self.assertIsNone(tr._write_fut) @@ -173,7 +185,7 @@ def test_loop_writing_closing(self): fut = asyncio.Future(loop=self.loop) fut.set_result(1) - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._write_fut = fut tr.close() tr._loop_writing(fut) @@ -182,13 +194,13 @@ def test_loop_writing_closing(self): self.protocol.connection_lost.assert_called_with(None) def test_abort(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._force_close = mock.Mock() tr.abort() tr._force_close.assert_called_with(None) def test_close(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr.close() test_utils.run_briefly(self.loop) self.protocol.connection_lost.assert_called_with(None) @@ -201,14 +213,14 @@ def test_close(self): self.assertFalse(self.protocol.connection_lost.called) def test_close_write_fut(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._write_fut = mock.Mock() tr.close() test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) def test_close_buffer(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._buffer = [b'data'] tr.close() test_utils.run_briefly(self.loop) @@ -216,14 +228,14 @@ def test_close_buffer(self): @mock.patch('asyncio.base_events.logger') def test_fatal_error(self, m_logging): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._force_close = mock.Mock() tr._fatal_error(None) self.assertTrue(tr._force_close.called) self.assertTrue(m_logging.error.called) def test_force_close(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._buffer = [b'data'] read_fut = tr._read_fut = mock.Mock() write_fut = tr._write_fut = mock.Mock() @@ -237,14 +249,14 @@ def test_force_close(self): self.assertEqual(tr._conn_lost, 1) def test_force_close_idempotent(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._closing = True tr._force_close(None) test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) def test_fatal_error_2(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._buffer = [b'data'] tr._force_close(None) @@ -253,14 +265,13 @@ def test_fatal_error_2(self): self.assertEqual(None, tr._buffer) def test_call_connection_lost(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() tr._call_connection_lost(None) self.assertTrue(self.protocol.connection_lost.called) self.assertTrue(self.sock.close.called) def test_write_eof(self): - tr = _ProactorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() self.assertTrue(tr.can_write_eof()) tr.write_eof() self.sock.shutdown.assert_called_with(socket.SHUT_WR) @@ -269,7 +280,7 @@ def test_write_eof(self): tr.close() def test_write_eof_buffer(self): - tr = _ProactorSocketTransport(self.loop, self.sock, self.protocol) + tr = self.socket_transport() f = asyncio.Future(loop=self.loop) tr._loop._proactor.send.return_value = f tr.write(b'data') @@ -313,11 +324,10 @@ def test_write_eof_duplex_pipe(self): self.assertFalse(tr.can_write_eof()) with self.assertRaises(NotImplementedError): tr.write_eof() - tr.close() + close_transport(tr) def test_pause_resume_reading(self): - tr = _ProactorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() futures = [] for msg in [b'data1', b'data2', b'data3', b'data4', b'']: f = asyncio.Future(loop=self.loop) @@ -345,10 +355,7 @@ def test_pause_resume_reading(self): def pause_writing_transport(self, high): - tr = _ProactorSocketTransport( - self.loop, self.sock, self.protocol) - self.addCleanup(tr.close) - + tr = self.socket_transport() tr.set_write_buffer_limits(high=high) self.assertEqual(tr.get_write_buffer_size(), 0) @@ -439,7 +446,7 @@ def _socketpair(s): return (self.ssock, self.csock) self.loop = EventLoop(self.proactor) - self.set_event_loop(self.loop, cleanup=False) + self.set_event_loop(self.loop) @mock.patch.object(BaseProactorEventLoop, 'call_soon') @mock.patch.object(BaseProactorEventLoop, '_socketpair') @@ -451,6 +458,7 @@ def test_ctor(self, socketpair, call_soon): self.assertIs(loop._csock, csock) self.assertEqual(loop._internal_fds, 1) call_soon.assert_called_with(loop._loop_self_reading) + loop.close() def test_close_self_pipe(self): self.loop._close_self_pipe() @@ -497,6 +505,7 @@ def test_socketpair(self): def test_make_socket_transport(self): tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) self.assertIsInstance(tr, _ProactorSocketTransport) + close_transport(tr) def test_loop_self_reading(self): self.loop._loop_self_reading() diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index f5194193..f99d04d4 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -39,6 +39,15 @@ def list_to_buffer(l=()): return bytearray().join(l) +def close_transport(transport): + # Don't call transport.close() because the event loop and the selector + # are mocked + if transport._sock is None: + return + transport._sock.close() + transport._sock = None + + class BaseSelectorEventLoopTests(test_utils.TestCase): def setUp(self): @@ -52,6 +61,7 @@ def test_make_socket_transport(self): self.loop.add_reader = mock.Mock() transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) + close_transport(transport) @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): @@ -64,11 +74,19 @@ def test_make_ssl_transport(self): with test_utils.disable_logger(): transport = self.loop._make_ssl_transport( m, asyncio.Protocol(), m, waiter) + # execute the handshake while the logger is disabled + # to ignore SSL handshake failure + test_utils.run_briefly(self.loop) + # Sanity check class_name = transport.__class__.__name__ self.assertIn("ssl", class_name.lower()) self.assertIn("transport", class_name.lower()) + transport.close() + # execute pending callbacks to close the socket transport + test_utils.run_briefly(self.loop) + @mock.patch('asyncio.selector_events.ssl', None) @mock.patch('asyncio.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): @@ -650,21 +668,27 @@ def setUp(self): self.sock = mock.Mock(socket.socket) self.sock.fileno.return_value = 7 + def create_transport(self): + transport = _SelectorTransport(self.loop, self.sock, self.protocol, + None) + self.addCleanup(close_transport, transport) + return transport + def test_ctor(self): - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() self.assertIs(tr._loop, self.loop) self.assertIs(tr._sock, self.sock) self.assertIs(tr._sock_fd, 7) def test_abort(self): - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() tr._force_close = mock.Mock() tr.abort() tr._force_close.assert_called_with(None) def test_close(self): - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() tr.close() self.assertTrue(tr._closing) @@ -677,7 +701,7 @@ def test_close(self): self.assertEqual(1, self.loop.remove_reader_count[7]) def test_close_write_buffer(self): - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() tr._buffer.extend(b'data') tr.close() @@ -686,7 +710,7 @@ def test_close_write_buffer(self): self.assertFalse(self.protocol.connection_lost.called) def test_force_close(self): - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() tr._buffer.extend(b'1') self.loop.add_reader(7, mock.sentinel) self.loop.add_writer(7, mock.sentinel) @@ -705,7 +729,7 @@ def test_force_close(self): @mock.patch('asyncio.log.logger.error') def test_fatal_error(self, m_exc): exc = OSError() - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() tr._force_close = mock.Mock() tr._fatal_error(exc) @@ -718,7 +742,7 @@ def test_fatal_error(self, m_exc): def test_connection_lost(self): exc = OSError() - tr = _SelectorTransport(self.loop, self.sock, self.protocol, None) + tr = self.create_transport() self.assertIsNotNone(tr._protocol) self.assertIsNotNone(tr._loop) tr._call_connection_lost(exc) @@ -739,9 +763,14 @@ def setUp(self): self.sock = mock.Mock(socket.socket) self.sock_fd = self.sock.fileno.return_value = 7 + def socket_transport(self, waiter=None): + transport = _SelectorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + self.addCleanup(close_transport, transport) + return transport + def test_ctor(self): - tr = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() self.loop.assert_reader(7, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) @@ -749,14 +778,12 @@ def test_ctor(self): def test_ctor_with_waiter(self): fut = asyncio.Future(loop=self.loop) - _SelectorSocketTransport( - self.loop, self.sock, self.protocol, fut) + self.socket_transport(waiter=fut) test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) def test_pause_resume_reading(self): - tr = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) tr.pause_reading() @@ -769,8 +796,7 @@ def test_pause_resume_reading(self): tr.resume_reading() def test_read_ready(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() self.sock.recv.return_value = b'data' transport._read_ready() @@ -778,8 +804,7 @@ def test_read_ready(self): self.protocol.data_received.assert_called_with(b'data') def test_read_ready_eof(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.close = mock.Mock() self.sock.recv.return_value = b'' @@ -789,8 +814,7 @@ def test_read_ready_eof(self): transport.close.assert_called_with() def test_read_ready_eof_keep_open(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.close = mock.Mock() self.sock.recv.return_value = b'' @@ -804,8 +828,7 @@ def test_read_ready_eof_keep_open(self): def test_read_ready_tryagain(self, m_exc): self.sock.recv.side_effect = BlockingIOError - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._fatal_error = mock.Mock() transport._read_ready() @@ -815,8 +838,7 @@ def test_read_ready_tryagain(self, m_exc): def test_read_ready_tryagain_interrupted(self, m_exc): self.sock.recv.side_effect = InterruptedError - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._fatal_error = mock.Mock() transport._read_ready() @@ -826,8 +848,7 @@ def test_read_ready_tryagain_interrupted(self, m_exc): def test_read_ready_conn_reset(self, m_exc): err = self.sock.recv.side_effect = ConnectionResetError() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._force_close = mock.Mock() with test_utils.disable_logger(): transport._read_ready() @@ -837,8 +858,7 @@ def test_read_ready_conn_reset(self, m_exc): def test_read_ready_err(self, m_exc): err = self.sock.recv.side_effect = OSError() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._fatal_error = mock.Mock() transport._read_ready() @@ -850,8 +870,7 @@ def test_write(self): data = b'data' self.sock.send.return_value = len(data) - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.sock.send.assert_called_with(data) @@ -859,8 +878,7 @@ def test_write_bytearray(self): data = bytearray(b'data') self.sock.send.return_value = len(data) - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.sock.send.assert_called_with(data) self.assertEqual(data, bytearray(b'data')) # Hasn't been mutated. @@ -869,22 +887,19 @@ def test_write_memoryview(self): data = memoryview(b'data') self.sock.send.return_value = len(data) - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.sock.send.assert_called_with(data) def test_write_no_data(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer.extend(b'data') transport.write(b'') self.assertFalse(self.sock.send.called) self.assertEqual(list_to_buffer([b'data']), transport._buffer) def test_write_buffer(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer.extend(b'data1') transport.write(b'data2') self.assertFalse(self.sock.send.called) @@ -895,8 +910,7 @@ def test_write_partial(self): data = b'data' self.sock.send.return_value = 2 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.loop.assert_writer(7, transport._write_ready) @@ -906,8 +920,7 @@ def test_write_partial_bytearray(self): data = bytearray(b'data') self.sock.send.return_value = 2 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.loop.assert_writer(7, transport._write_ready) @@ -918,8 +931,7 @@ def test_write_partial_memoryview(self): data = memoryview(b'data') self.sock.send.return_value = 2 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.loop.assert_writer(7, transport._write_ready) @@ -930,8 +942,7 @@ def test_write_partial_none(self): self.sock.send.return_value = 0 self.sock.fileno.return_value = 7 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.loop.assert_writer(7, transport._write_ready) @@ -941,8 +952,7 @@ def test_write_tryagain(self): self.sock.send.side_effect = BlockingIOError data = b'data' - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.write(data) self.loop.assert_writer(7, transport._write_ready) @@ -953,8 +963,7 @@ def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() data = b'data' - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._fatal_error = mock.Mock() transport.write(data) transport._fatal_error.assert_called_with( @@ -973,13 +982,11 @@ def test_write_exception(self, m_log): m_log.warning.assert_called_with('socket.send() raised exception.') def test_write_str(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() self.assertRaises(TypeError, transport.write, 'str') def test_write_closing(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.close() self.assertEqual(transport._conn_lost, 1) transport.write(b'data') @@ -989,8 +996,7 @@ def test_write_ready(self): data = b'data' self.sock.send.return_value = len(data) - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() @@ -1001,8 +1007,7 @@ def test_write_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._closing = True transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) @@ -1013,8 +1018,7 @@ def test_write_ready_closing(self): self.protocol.connection_lost.assert_called_with(None) def test_write_ready_no_data(self): - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() # This is an internal error. self.assertRaises(AssertionError, transport._write_ready) @@ -1022,8 +1026,7 @@ def test_write_ready_partial(self): data = b'data' self.sock.send.return_value = 2 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() @@ -1034,8 +1037,7 @@ def test_write_ready_partial_none(self): data = b'data' self.sock.send.return_value = 0 - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer.extend(data) self.loop.add_writer(7, transport._write_ready) transport._write_ready() @@ -1045,8 +1047,7 @@ def test_write_ready_partial_none(self): def test_write_ready_tryagain(self): self.sock.send.side_effect = BlockingIOError - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._buffer = list_to_buffer([b'data1', b'data2']) self.loop.add_writer(7, transport._write_ready) transport._write_ready() @@ -1057,8 +1058,7 @@ def test_write_ready_tryagain(self): def test_write_ready_exception(self): err = self.sock.send.side_effect = OSError() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport._fatal_error = mock.Mock() transport._buffer.extend(b'data') transport._write_ready() @@ -1071,16 +1071,14 @@ def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() remove_writer = self.loop.remove_writer = mock.Mock() - transport = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + transport = self.socket_transport() transport.close() transport._buffer.extend(b'data') transport._write_ready() remove_writer.assert_called_with(self.sock_fd) def test_write_eof(self): - tr = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() self.assertTrue(tr.can_write_eof()) tr.write_eof() self.sock.shutdown.assert_called_with(socket.SHUT_WR) @@ -1089,8 +1087,7 @@ def test_write_eof(self): tr.close() def test_write_eof_buffer(self): - tr = _SelectorSocketTransport( - self.loop, self.sock, self.protocol) + tr = self.socket_transport() self.sock.send.side_effect = BlockingIOError tr.write(b'data') tr.write_eof() @@ -1117,9 +1114,15 @@ def setUp(self): self.sslcontext = mock.Mock() self.sslcontext.wrap_socket.return_value = self.sslsock + def ssl_transport(self, waiter=None, server_hostname=None): + transport = _SelectorSslTransport(self.loop, self.sock, self.protocol, + self.sslcontext, waiter=waiter, + server_hostname=server_hostname) + self.addCleanup(close_transport, transport) + return transport + def _make_one(self, create_waiter=None): - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext) + transport = self.ssl_transport() self.sock.reset_mock() self.sslsock.reset_mock() self.sslcontext.reset_mock() @@ -1128,9 +1131,7 @@ def _make_one(self, create_waiter=None): def test_on_handshake(self): waiter = asyncio.Future(loop=self.loop) - tr = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, - waiter=waiter) + tr = self.ssl_transport(waiter=waiter) self.assertTrue(self.sslsock.do_handshake.called) self.loop.assert_reader(1, tr._read_ready) test_utils.run_briefly(self.loop) @@ -1139,15 +1140,13 @@ def test_on_handshake(self): def test_on_handshake_reader_retry(self): self.loop.set_debug(False) self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext) + transport = self.ssl_transport() self.loop.assert_reader(1, transport._on_handshake, None) def test_on_handshake_writer_retry(self): self.loop.set_debug(False) self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext) + transport = self.ssl_transport() self.loop.assert_writer(1, transport._on_handshake, None) def test_on_handshake_exc(self): @@ -1155,16 +1154,14 @@ def test_on_handshake_exc(self): self.sslsock.do_handshake.side_effect = exc with test_utils.disable_logger(): waiter = asyncio.Future(loop=self.loop) - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, waiter) + transport = self.ssl_transport(waiter=waiter) self.assertTrue(waiter.done()) self.assertIs(exc, waiter.exception()) self.assertTrue(self.sslsock.close.called) def test_on_handshake_base_exc(self): waiter = asyncio.Future(loop=self.loop) - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, waiter) + transport = self.ssl_transport(waiter=waiter) exc = BaseException() self.sslsock.do_handshake.side_effect = exc with test_utils.disable_logger(): @@ -1177,8 +1174,7 @@ def test_cancel_handshake(self): # Python issue #23197: cancelling an handshake must not raise an # exception or log an error, even if the handshake failed waiter = asyncio.Future(loop=self.loop) - transport = _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, waiter) + transport = self.ssl_transport(waiter=waiter) waiter.cancel() exc = ValueError() self.sslsock.do_handshake.side_effect = exc @@ -1437,9 +1433,7 @@ def test_close(self): @unittest.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): - _SelectorSslTransport( - self.loop, self.sock, self.protocol, self.sslcontext, - server_hostname='localhost') + self.ssl_transport(server_hostname='localhost') self.sslcontext.wrap_socket.assert_called_with( self.sock, do_handshake_on_connect=False, server_side=False, server_hostname='localhost') @@ -1462,9 +1456,15 @@ def setUp(self): self.sock = mock.Mock(spec_set=socket.socket) self.sock.fileno.return_value = 7 + def datagram_transport(self, address=None): + transport = _SelectorDatagramTransport(self.loop, self.sock, + self.protocol, + address=address) + self.addCleanup(close_transport, transport) + return transport + def test_read_ready(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() self.sock.recvfrom.return_value = (b'data', ('0.0.0.0', 1234)) transport._read_ready() @@ -1473,8 +1473,7 @@ def test_read_ready(self): b'data', ('0.0.0.0', 1234)) def test_read_ready_tryagain(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() self.sock.recvfrom.side_effect = BlockingIOError transport._fatal_error = mock.Mock() @@ -1483,8 +1482,7 @@ def test_read_ready_tryagain(self): self.assertFalse(transport._fatal_error.called) def test_read_ready_err(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() err = self.sock.recvfrom.side_effect = RuntimeError() transport._fatal_error = mock.Mock() @@ -1495,8 +1493,7 @@ def test_read_ready_err(self): 'Fatal read error on datagram transport') def test_read_ready_oserr(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() err = self.sock.recvfrom.side_effect = OSError() transport._fatal_error = mock.Mock() @@ -1507,8 +1504,7 @@ def test_read_ready_oserr(self): def test_sendto(self): data = b'data' - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport.sendto(data, ('0.0.0.0', 1234)) self.assertTrue(self.sock.sendto.called) self.assertEqual( @@ -1516,8 +1512,7 @@ def test_sendto(self): def test_sendto_bytearray(self): data = bytearray(b'data') - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport.sendto(data, ('0.0.0.0', 1234)) self.assertTrue(self.sock.sendto.called) self.assertEqual( @@ -1525,16 +1520,14 @@ def test_sendto_bytearray(self): def test_sendto_memoryview(self): data = memoryview(b'data') - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport.sendto(data, ('0.0.0.0', 1234)) self.assertTrue(self.sock.sendto.called) self.assertEqual( self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) def test_sendto_no_data(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.append((b'data', ('0.0.0.0', 12345))) transport.sendto(b'', ()) self.assertFalse(self.sock.sendto.called) @@ -1542,8 +1535,7 @@ def test_sendto_no_data(self): [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) def test_sendto_buffer(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.append((b'data1', ('0.0.0.0', 12345))) transport.sendto(b'data2', ('0.0.0.0', 12345)) self.assertFalse(self.sock.sendto.called) @@ -1554,8 +1546,7 @@ def test_sendto_buffer(self): def test_sendto_buffer_bytearray(self): data2 = bytearray(b'data2') - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.append((b'data1', ('0.0.0.0', 12345))) transport.sendto(data2, ('0.0.0.0', 12345)) self.assertFalse(self.sock.sendto.called) @@ -1567,8 +1558,7 @@ def test_sendto_buffer_bytearray(self): def test_sendto_buffer_memoryview(self): data2 = memoryview(b'data2') - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.append((b'data1', ('0.0.0.0', 12345))) transport.sendto(data2, ('0.0.0.0', 12345)) self.assertFalse(self.sock.sendto.called) @@ -1583,8 +1573,7 @@ def test_sendto_tryagain(self): self.sock.sendto.side_effect = BlockingIOError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport.sendto(data, ('0.0.0.0', 12345)) self.loop.assert_writer(7, transport._sendto_ready) @@ -1596,8 +1585,7 @@ def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = RuntimeError() - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._fatal_error = mock.Mock() transport.sendto(data, ()) @@ -1620,8 +1608,7 @@ def test_sendto_error_received(self): self.sock.sendto.side_effect = ConnectionRefusedError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._fatal_error = mock.Mock() transport.sendto(data, ()) @@ -1633,8 +1620,7 @@ def test_sendto_error_received_connected(self): self.sock.send.side_effect = ConnectionRefusedError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport = self.datagram_transport(address=('0.0.0.0', 1)) transport._fatal_error = mock.Mock() transport.sendto(data) @@ -1642,19 +1628,16 @@ def test_sendto_error_received_connected(self): self.assertTrue(self.protocol.error_received.called) def test_sendto_str(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() self.assertRaises(TypeError, transport.sendto, 'str', ()) def test_sendto_connected_addr(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport = self.datagram_transport(address=('0.0.0.0', 1)) self.assertRaises( ValueError, transport.sendto, b'str', ('0.0.0.0', 2)) def test_sendto_closing(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol, address=(1,)) + transport = self.datagram_transport(address=(1,)) transport.close() self.assertEqual(transport._conn_lost, 1) transport.sendto(b'data', (1,)) @@ -1664,8 +1647,7 @@ def test_sendto_ready(self): data = b'data' self.sock.sendto.return_value = len(data) - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.append((data, ('0.0.0.0', 12345))) self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() @@ -1678,8 +1660,7 @@ def test_sendto_ready_closing(self): data = b'data' self.sock.send.return_value = len(data) - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._closing = True transport._buffer.append((data, ())) self.loop.add_writer(7, transport._sendto_ready) @@ -1690,8 +1671,7 @@ def test_sendto_ready_closing(self): self.protocol.connection_lost.assert_called_with(None) def test_sendto_ready_no_data(self): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() self.assertFalse(self.sock.sendto.called) @@ -1700,8 +1680,7 @@ def test_sendto_ready_no_data(self): def test_sendto_ready_tryagain(self): self.sock.sendto.side_effect = BlockingIOError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._buffer.extend([(b'data1', ()), (b'data2', ())]) self.loop.add_writer(7, transport._sendto_ready) transport._sendto_ready() @@ -1714,8 +1693,7 @@ def test_sendto_ready_tryagain(self): def test_sendto_ready_exception(self): err = self.sock.sendto.side_effect = RuntimeError() - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1727,8 +1705,7 @@ def test_sendto_ready_exception(self): def test_sendto_ready_error_received(self): self.sock.sendto.side_effect = ConnectionRefusedError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol) + transport = self.datagram_transport() transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1738,8 +1715,7 @@ def test_sendto_ready_error_received(self): def test_sendto_ready_error_received_connection(self): self.sock.send.side_effect = ConnectionRefusedError - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport = self.datagram_transport(address=('0.0.0.0', 1)) transport._fatal_error = mock.Mock() transport._buffer.append((b'data', ())) transport._sendto_ready() @@ -1749,8 +1725,7 @@ def test_sendto_ready_error_received_connection(self): @mock.patch('asyncio.base_events.logger.error') def test_fatal_error_connected(self, m_exc): - transport = _SelectorDatagramTransport( - self.loop, self.sock, self.protocol, ('0.0.0.0', 1)) + transport = self.datagram_transport(address=('0.0.0.0', 1)) err = ConnectionRefusedError() transport._fatal_error(err) self.assertFalse(self.protocol.error_received.called) @@ -1758,7 +1733,6 @@ def test_fatal_error_connected(self, m_exc): test_utils.MockPattern( 'Fatal error on transport\nprotocol:.*\ntransport:.*'), exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) - transport.close() if __name__ == '__main__': diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 4a68ce36..126196da 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -26,6 +26,15 @@ MOCK_ANY = mock.ANY +def close_pipe_transport(transport): + # Don't call transport.close() because the event loop and the selector + # are mocked + if transport._pipe is None: + return + transport._pipe.close() + transport._pipe = None + + @unittest.skipUnless(signal, 'Signals are not supported') class SelectorEventLoopSignalTests(test_utils.TestCase): @@ -333,24 +342,28 @@ def setUp(self): m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) + def read_pipe_transport(self, waiter=None): + transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + def test_ctor(self): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() self.loop.assert_reader(5, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): fut = asyncio.Future(loop=self.loop) - unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol, fut) + tr = self.read_pipe_transport(waiter=fut) test_utils.run_briefly(self.loop) self.assertIsNone(fut.result()) @mock.patch('os.read') def test__read_ready(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() m_read.return_value = b'data' tr._read_ready() @@ -359,8 +372,7 @@ def test__read_ready(self, m_read): @mock.patch('os.read') def test__read_ready_eof(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() m_read.return_value = b'' tr._read_ready() @@ -372,8 +384,7 @@ def test__read_ready_eof(self, m_read): @mock.patch('os.read') def test__read_ready_blocked(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() m_read.side_effect = BlockingIOError tr._read_ready() @@ -384,8 +395,7 @@ def test__read_ready_blocked(self, m_read): @mock.patch('asyncio.log.logger.error') @mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() err = OSError() m_read.side_effect = err tr._close = mock.Mock() @@ -401,9 +411,7 @@ def test__read_ready_error(self, m_read, m_logexc): @mock.patch('os.read') def test_pause_reading(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.read_pipe_transport() m = mock.Mock() self.loop.add_reader(5, m) tr.pause_reading() @@ -411,26 +419,20 @@ def test_pause_reading(self, m_read): @mock.patch('os.read') def test_resume_reading(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.read_pipe_transport() tr.resume_reading() self.loop.assert_reader(5, tr._read_ready) @mock.patch('os.read') def test_close(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.read_pipe_transport() tr._close = mock.Mock() tr.close() tr._close.assert_called_with(None) @mock.patch('os.read') def test_close_already_closing(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.read_pipe_transport() tr._closing = True tr._close = mock.Mock() tr.close() @@ -438,9 +440,7 @@ def test_close_already_closing(self, m_read): @mock.patch('os.read') def test__close(self, m_read): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.read_pipe_transport() err = object() tr._close(err) self.assertTrue(tr._closing) @@ -449,8 +449,7 @@ def test__close(self, m_read): self.protocol.connection_lost.assert_called_with(err) def test__call_connection_lost(self): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() self.assertIsNotNone(tr._protocol) self.assertIsNotNone(tr._loop) @@ -463,8 +462,7 @@ def test__call_connection_lost(self): self.assertIsNone(tr._loop) def test__call_connection_lost_with_err(self): - tr = unix_events._UnixReadPipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.read_pipe_transport() self.assertIsNotNone(tr._protocol) self.assertIsNotNone(tr._loop) @@ -496,31 +494,33 @@ def setUp(self): m_fstat.return_value = st self.addCleanup(fstat_patcher.stop) + def write_pipe_transport(self, waiter=None): + transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + def test_ctor(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() self.loop.assert_reader(5, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): fut = asyncio.Future(loop=self.loop) - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol, fut) + tr = self.write_pipe_transport(waiter=fut) self.loop.assert_reader(5, tr._read_ready) test_utils.run_briefly(self.loop) self.assertEqual(None, fut.result()) def test_can_write_eof(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() self.assertTrue(tr.can_write_eof()) @mock.patch('os.write') def test_write(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() m_write.return_value = 4 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -529,9 +529,7 @@ def test_write(self, m_write): @mock.patch('os.write') def test_write_no_data(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() tr.write(b'') self.assertFalse(m_write.called) self.assertFalse(self.loop.writers) @@ -539,9 +537,7 @@ def test_write_no_data(self, m_write): @mock.patch('os.write') def test_write_partial(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() m_write.return_value = 2 tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -550,9 +546,7 @@ def test_write_partial(self, m_write): @mock.patch('os.write') def test_write_buffer(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'previous'] tr.write(b'data') @@ -562,9 +556,7 @@ def test_write_buffer(self, m_write): @mock.patch('os.write') def test_write_again(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() m_write.side_effect = BlockingIOError() tr.write(b'data') m_write.assert_called_with(5, b'data') @@ -574,9 +566,7 @@ def test_write_again(self, m_write): @mock.patch('asyncio.unix_events.logger') @mock.patch('os.write') def test_write_err(self, m_write, m_log): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() err = OSError() m_write.side_effect = err tr._fatal_error = mock.Mock() @@ -602,8 +592,7 @@ def test_write_err(self, m_write, m_log): @mock.patch('os.write') def test_write_close(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() tr._read_ready() # pipe was closed by peer tr.write(b'data') @@ -612,8 +601,7 @@ def test_write_close(self, m_write): self.assertEqual(tr._conn_lost, 2) def test__read_ready(self): - tr = unix_events._UnixWritePipeTransport(self.loop, self.pipe, - self.protocol) + tr = self.write_pipe_transport() tr._read_ready() self.assertFalse(self.loop.readers) self.assertFalse(self.loop.writers) @@ -623,8 +611,7 @@ def test__read_ready(self): @mock.patch('os.write') def test__write_ready(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 4 @@ -635,9 +622,7 @@ def test__write_ready(self, m_write): @mock.patch('os.write') def test__write_ready_partial(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 3 @@ -648,9 +633,7 @@ def test__write_ready_partial(self, m_write): @mock.patch('os.write') def test__write_ready_again(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.side_effect = BlockingIOError() @@ -661,9 +644,7 @@ def test__write_ready_again(self, m_write): @mock.patch('os.write') def test__write_ready_empty(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.return_value = 0 @@ -675,9 +656,7 @@ def test__write_ready_empty(self, m_write): @mock.patch('asyncio.log.logger.error') @mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._buffer = [b'da', b'ta'] m_write.side_effect = err = OSError() @@ -698,9 +677,7 @@ def test__write_ready_err(self, m_write, m_logexc): @mock.patch('os.write') def test__write_ready_closing(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) tr._closing = True tr._buffer = [b'da', b'ta'] @@ -715,9 +692,7 @@ def test__write_ready_closing(self, m_write): @mock.patch('os.write') def test_abort(self, m_write): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() self.loop.add_writer(5, tr._write_ready) self.loop.add_reader(5, tr._read_ready) tr._buffer = [b'da', b'ta'] @@ -731,8 +706,7 @@ def test_abort(self, m_write): self.protocol.connection_lost.assert_called_with(None) def test__call_connection_lost(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() self.assertIsNotNone(tr._protocol) self.assertIsNotNone(tr._loop) @@ -745,8 +719,7 @@ def test__call_connection_lost(self): self.assertIsNone(tr._loop) def test__call_connection_lost_with_err(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() self.assertIsNotNone(tr._protocol) self.assertIsNotNone(tr._loop) @@ -759,9 +732,7 @@ def test__call_connection_lost_with_err(self): self.assertIsNone(tr._loop) def test_close(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() tr.write_eof = mock.Mock() tr.close() tr.write_eof.assert_called_with() @@ -770,18 +741,14 @@ def test_close(self): tr.close() def test_close_closing(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() tr.write_eof = mock.Mock() tr._closing = True tr.close() self.assertFalse(tr.write_eof.called) def test_write_eof(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) - + tr = self.write_pipe_transport() tr.write_eof() self.assertTrue(tr._closing) self.assertFalse(self.loop.readers) @@ -789,8 +756,7 @@ def test_write_eof(self): self.protocol.connection_lost.assert_called_with(None) def test_write_eof_pending(self): - tr = unix_events._UnixWritePipeTransport( - self.loop, self.pipe, self.protocol) + tr = self.write_pipe_transport() tr._buffer = [b'data'] tr.write_eof() self.assertTrue(tr._closing) From d9de03aa87da736e43eb074de52afadc6716d2fd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 13:32:10 +0100 Subject: [PATCH 1284/1502] Fix _ProactorBasePipeTransport.__repr__() Check if the _sock attribute is None to check if the transport is closed. --- asyncio/proactor_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 0ecb44eb..a177d32a 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -43,12 +43,12 @@ def __init__(self, loop, sock, protocol, waiter=None, def __repr__(self): info = [self.__class__.__name__] - fd = self._sock.fileno() - if fd < 0: + if self._sock is None: info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % fd) + if self._sock is not None: + info.append('fd=%s' % self._sock.fileno()) if self._read_fut is not None: info.append('read=%s' % self._read_fut) if self._write_fut is not None: From 873bc97904998dc8a9876401ae5974ae98d892b3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 13:40:03 +0100 Subject: [PATCH 1285/1502] Fix _ProactorBasePipeTransport.close() Set the _read_fut attribute to None after cancelling it. This change should fix a race condition with _ProactorWritePipeTransport._pipe_closed(). --- asyncio/proactor_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index a177d32a..6d8641fe 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -72,6 +72,7 @@ def close(self): self._loop.call_soon(self._call_connection_lost, None) if self._read_fut is not None: self._read_fut.cancel() + self._read_fut = None def _fatal_error(self, exc, message='Fatal error on pipe transport'): if isinstance(exc, (BrokenPipeError, ConnectionResetError)): @@ -93,9 +94,10 @@ def _force_close(self, exc): self._conn_lost += 1 if self._write_fut: self._write_fut.cancel() + self._write_fut = None if self._read_fut: self._read_fut.cancel() - self._write_fut = self._read_fut = None + self._read_fut = None self._pending_write = 0 self._buffer = None self._loop.call_soon(self._call_connection_lost, exc) From f9c91e5e4423aba4eba50901e9a4f459e53530c8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 14:18:33 +0100 Subject: [PATCH 1286/1502] Close the transport on subprocess creation failure --- asyncio/unix_events.py | 6 +++++- asyncio/windows_events.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 9f4005cb..97f9addd 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -177,7 +177,11 @@ def _make_subprocess_transport(self, protocol, args, shell, transp = _UnixSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=extra, **kwargs) - yield from transp._post_init() + try: + yield from transp._post_init() + except: + transp.close() + raise watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 9d496f2f..82d09663 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -272,7 +272,12 @@ def _make_subprocess_transport(self, protocol, args, shell, transp = _WindowsSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=extra, **kwargs) - yield from transp._post_init() + try: + yield from transp._post_init() + except: + transp.close() + raise + return transp From cf01d4281847ea88d541e6b289e26c1ab9a2d2dd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 14:23:32 +0100 Subject: [PATCH 1287/1502] Close transports in tests * Use test_utils.run_briefly() to execute pending calls to really close transports * sslproto: mock also _SSLPipe.shutdown(), it's need to close the transport * pipe test: the test doesn't close explicitly the PipeHandle, so ignore the warning instead * test_popen: use the context manager ("with p:") to explicitly close pipes --- tests/test_selector_events.py | 2 ++ tests/test_sslproto.py | 4 ++++ tests/test_subprocess.py | 1 + tests/test_windows_utils.py | 11 ++++++++--- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index f99d04d4..ad86ada3 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1180,6 +1180,8 @@ def test_cancel_handshake(self): self.sslsock.do_handshake.side_effect = exc with test_utils.disable_logger(): transport._on_handshake(0) + transport.close() + test_utils.run_briefly(self.loop) def test_pause_resume_reading(self): tr = self._make_one() diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 053fefe7..812dedbe 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -33,6 +33,7 @@ def do_handshake(callback): waiter.cancel() transport = mock.Mock() sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' sslpipe.do_handshake.side_effect = do_handshake with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): ssl_proto.connection_made(transport) @@ -40,6 +41,9 @@ def do_handshake(callback): with test_utils.disable_logger(): self.loop.run_until_complete(handshake_fut) + # Close the transport + ssl_proto._app_transport.close() + if __name__ == '__main__': unittest.main() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b2f1b953..a4c1f698 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -286,6 +286,7 @@ def cancel_make_transport(): # "Exception during subprocess creation, kill the subprocess" with test_utils.disable_logger(): self.loop.run_until_complete(cancel_make_transport()) + test_utils.run_briefly(self.loop) if sys.platform != 'win32': diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index af5c453b..d48b8bcb 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -3,6 +3,7 @@ import socket import sys import unittest +import warnings from unittest import mock if sys.platform != 'win32': @@ -115,8 +116,10 @@ def test_pipe_handle(self): self.assertEqual(p.handle, h) # check garbage collection of p closes handle - del p - support.gc_collect() + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "", ResourceWarning) + del p + support.gc_collect() try: _winapi.CloseHandle(h) except OSError as e: @@ -170,7 +173,9 @@ def test_popen(self): self.assertTrue(msg.upper().rstrip().startswith(out)) self.assertTrue(b"stderr".startswith(err)) - p.wait() + # The context manager calls wait() and closes resources + with p: + pass if __name__ == '__main__': From 408fbab530fd4abe49249a636a10f10f44d07a21 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 16:25:54 +0100 Subject: [PATCH 1288/1502] Python issue #23219: cancelling wait_for() now cancels the task --- asyncio/tasks.py | 12 ++++++++---- tests/test_tasks.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 7959a55a..63412a97 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -347,10 +347,9 @@ def wait_for(fut, timeout, *, loop=None): it cancels the task and raises TimeoutError. To avoid the task cancellation, wrap it in shield(). - Usage: - - result = yield from asyncio.wait_for(fut, 10.0) + If the wait is cancelled, the task is also cancelled. + This function is a coroutine. """ if loop is None: loop = events.get_event_loop() @@ -367,7 +366,12 @@ def wait_for(fut, timeout, *, loop=None): try: # wait until the future completes or the timeout - yield from waiter + try: + yield from waiter + except futures.CancelledError: + fut.remove_done_callback(cb) + fut.cancel() + raise if fut.done(): return fut.result() diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 7807dc04..06447d77 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1705,6 +1705,33 @@ def test_task_source_traceback(self): 'test_task_source_traceback')) self.loop.run_until_complete(task) + def _test_cancel_wait_for(self, timeout): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + @asyncio.coroutine + def blocking_coroutine(): + fut = asyncio.Future(loop=loop) + # Block: fut result is never set + yield from fut + + task = loop.create_task(blocking_coroutine()) + + wait = loop.create_task(asyncio.wait_for(task, timeout, loop=loop)) + loop.call_soon(wait.cancel) + + self.assertRaises(asyncio.CancelledError, + loop.run_until_complete, wait) + + # Python issue #23219: cancelling the wait must also cancel the task + self.assertTrue(task.cancelled()) + + def test_cancel_blocking_wait_for(self): + self._test_cancel_wait_for(None) + + def test_cancel_wait_for(self): + self._test_cancel_wait_for(60.0) + class GatherTestsBase: From 416f59d4dd534da6cbd68036b1c3d2e7ef249921 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 21:53:27 +0100 Subject: [PATCH 1289/1502] Backout changeset 3a06020af8cf StreamWriter: close() now clears the reference to the transport StreamWriter now raises an exception if it is closed: write(), writelines(), write_eof(), can_write_eof(), get_extra_info(), drain(). --- asyncio/streams.py | 25 ++++--------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/asyncio/streams.py b/asyncio/streams.py index 12ab1c52..7ff16a48 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -258,22 +258,8 @@ def __init__(self, transport, protocol, reader, loop): self._reader = reader self._loop = loop - def close(self): - if self._transport is None: - return - self._transport.close() - self._transport = None - - def _check_closed(self): - if self._transport is None: - raise RuntimeError('StreamWriter is closed') - def __repr__(self): - info = [self.__class__.__name__] - if self._transport is not None: - info.append('transport=%r' % self._transport) - else: - info.append('closed') + info = [self.__class__.__name__, 'transport=%r' % self._transport] if self._reader is not None: info.append('reader=%r' % self._reader) return '<%s>' % ' '.join(info) @@ -283,23 +269,21 @@ def transport(self): return self._transport def write(self, data): - self._check_closed() self._transport.write(data) def writelines(self, data): - self._check_closed() self._transport.writelines(data) def write_eof(self): - self._check_closed() return self._transport.write_eof() def can_write_eof(self): - self._check_closed() return self._transport.can_write_eof() + def close(self): + return self._transport.close() + def get_extra_info(self, name, default=None): - self._check_closed() return self._transport.get_extra_info(name, default) @coroutine @@ -311,7 +295,6 @@ def drain(self): w.write(data) yield from w.drain() """ - self._check_closed() if self._reader is not None: exc = self._reader.exception() if exc is not None: From e87e26f866e407c1a82639f0d6c1424efafc346f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 15 Jan 2015 22:59:13 +0100 Subject: [PATCH 1290/1502] Issue #22685: Fix test_pause_reading() of test_subprocess Override the connect_read_pipe() method of the loop to mock immediatly pause_reading() and resume_reading() methods. The test failed randomly on FreeBSD 9 buildbot and on Windows using trollius. --- tests/test_subprocess.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index a4c1f698..ecc2c9d8 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -179,6 +179,18 @@ def test_pause_reading(): 'sys.stdout.write("x" * %s)' % size, 'sys.stdout.flush()', )) + + connect_read_pipe = self.loop.connect_read_pipe + + @asyncio.coroutine + def connect_read_pipe_mock(*args, **kw): + transport, protocol = yield from connect_read_pipe(*args, **kw) + transport.pause_reading = mock.Mock() + transport.resume_reading = mock.Mock() + return (transport, protocol) + + self.loop.connect_read_pipe = connect_read_pipe_mock + proc = yield from asyncio.create_subprocess_exec( sys.executable, '-c', code, stdin=asyncio.subprocess.PIPE, @@ -186,8 +198,6 @@ def test_pause_reading(): limit=limit, loop=self.loop) stdout_transport = proc._transport.get_pipe_transport(1) - stdout_transport.pause_reading = mock.Mock() - stdout_transport.resume_reading = mock.Mock() stdout, stderr = yield from proc.communicate() From 21192c42f2b540a560b29346e0de549920c01ad1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:35:59 +0100 Subject: [PATCH 1291/1502] Cleanup BaseEventLoop._create_connection_transport() Remove the exc variable, it's not used. --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5df5b83b..739296b7 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -636,7 +636,7 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, try: yield from waiter - except Exception as exc: + except Exception: transport.close() raise From 777da23a5ca3f0364d30b33bd34043225b5bbc73 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:36:09 +0100 Subject: [PATCH 1292/1502] tox.ini: fix comment --- tox.ini | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 192a0b9f..6209ff4e 100644 --- a/tox.ini +++ b/tox.ini @@ -4,6 +4,7 @@ envlist = py33,py34,py3_release [testenv] deps= aiotest +# Run tests in debug mode setenv = PYTHONASYNCIODEBUG = 1 commands= @@ -11,7 +12,7 @@ commands= python run_aiotest.py -r {posargs} [testenv:py3_release] -# Run tests in debug mode +# Run tests in release mode setenv = PYTHONASYNCIODEBUG = basepython = python3 From 346737d6c57b5806870346ef54ad71a355dc200d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:36:27 +0100 Subject: [PATCH 1293/1502] runtests.py: rephrase the message mentionning randomization of tests --- runtests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.py b/runtests.py index 8cb56fe0..c38b0c18 100644 --- a/runtests.py +++ b/runtests.py @@ -122,7 +122,7 @@ def randomize_tests(tests, seed): if seed is None: seed = random.randrange(10000000) random.seed(seed) - print("Using random seed", seed) + print("Randomize test execution order (seed: %s)" % seed) random.shuffle(tests._tests) From ad1724a21157a7c35fbec316bdc47cf03c135a89 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:36:35 +0100 Subject: [PATCH 1294/1502] release.py: fix typo --- release.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/release.py b/release.py index 3fea4a94..3d95a98f 100755 --- a/release.py +++ b/release.py @@ -34,7 +34,7 @@ WINDOWS = (sys.platform == 'win32') -def get_archiecture_bits(): +def get_architecture_bits(): arch = platform.architecture()[0] return int(arch[:2]) @@ -48,7 +48,7 @@ def __init__(self, major, minor, bits): @staticmethod def running(): - bits = get_archiecture_bits() + bits = get_architecture_bits() pyver = PythonVersion(sys.version_info.major, sys.version_info.minor, bits) @@ -127,7 +127,7 @@ def __init__(self): if WINDOWS: supported_archs = (32, 64) else: - bits = get_archiecture_bits() + bits = get_architecture_bits() supported_archs = (bits,) for major, minor in PYTHON_VERSIONS: for bits in supported_archs: From 2b5becd0cc5ac065ed06847fdec062c813176b58 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:36:47 +0100 Subject: [PATCH 1295/1502] tests: Remove unused function; inline another function --- tests/test_streams.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tests/test_streams.py b/tests/test_streams.py index a18603af..2273049b 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -415,10 +415,6 @@ def test_exception_waiter(self): def set_err(): stream.set_exception(ValueError()) - @asyncio.coroutine - def readline(): - yield from stream.readline() - t1 = asyncio.Task(stream.readline(), loop=self.loop) t2 = asyncio.Task(set_err(), loop=self.loop) @@ -429,11 +425,7 @@ def readline(): def test_exception_cancel(self): stream = asyncio.StreamReader(loop=self.loop) - @asyncio.coroutine - def read_a_line(): - yield from stream.readline() - - t = asyncio.Task(read_a_line(), loop=self.loop) + t = asyncio.Task(stream.readline(), loop=self.loop) test_utils.run_briefly(self.loop) t.cancel() test_utils.run_briefly(self.loop) From 3c0eabfa3fe1fedb95fa35157f76c0a7edbb6d66 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:45:51 +0100 Subject: [PATCH 1296/1502] Enhance BaseProactorEventLoop._loop_self_reading() * Handle correctly CancelledError: just exit * On error, log the exception and exit Don't try to close the event loop, it is probably running and so it cannot be closed. --- asyncio/proactor_events.py | 12 +++++++++--- tests/test_proactor_events.py | 5 +++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 6d8641fe..ed170622 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -463,9 +463,15 @@ def _loop_self_reading(self, f=None): if f is not None: f.result() # may raise f = self._proactor.recv(self._ssock, 4096) - except: - self.close() - raise + except futures.CancelledError: + # _close_self_pipe() has been called, stop waiting for data + return + except Exception as exc: + self.call_exception_handler({ + 'message': 'Error on reading from the event loop self pipe', + 'exception': exc, + 'loop': self, + }) else: self._self_reading_future = f f.add_done_callback(self._loop_self_reading) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index dee147e5..33a8a671 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -523,9 +523,10 @@ def test_loop_self_reading_fut(self): def test_loop_self_reading_exception(self): self.loop.close = mock.Mock() + self.loop.call_exception_handler = mock.Mock() self.proactor.recv.side_effect = OSError() - self.assertRaises(OSError, self.loop._loop_self_reading) - self.assertTrue(self.loop.close.called) + self.loop._loop_self_reading() + self.assertTrue(self.loop.call_exception_handler.called) def test_write_to_self(self): self.loop._write_to_self() From c3a18fdbc5f71e225f8b6dd59e8e39a9fc432ea3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 16 Jan 2015 17:50:58 +0100 Subject: [PATCH 1297/1502] pyflakes: remove unused import --- asyncio/selector_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 4d3e5d9e..24f84615 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -10,7 +10,6 @@ import errno import functools import socket -import sys try: import ssl except ImportError: # pragma: no cover From f1fbd1e697cacdc1926231253d382e4779f39393 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 21 Jan 2015 00:21:25 +0100 Subject: [PATCH 1298/1502] test_selectors: use asyncio.test_support if test.support is missing --- tests/test_selectors.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 49b5b8d0..a33f0fa4 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -3,10 +3,13 @@ import random import signal import sys -from test import support from time import sleep import unittest import unittest.mock +try: + from test import support +except ImportError: + from asyncio import test_support as support try: from time import monotonic as time except ImportError: From 550000013c96fbfa295c443c1c6bdbfdd59bea1d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 21 Jan 2015 22:40:49 +0100 Subject: [PATCH 1299/1502] Python issue #23095: Rewrite _WaitHandleFuture.cancel() This change fixes a race conditon related to _WaitHandleFuture.cancel() leading to Python crash or "GetQueuedCompletionStatus() returned an unexpected event" logs. Before, the overlapped object was destroyed too early, it was possible that the wait completed whereas the overlapped object was already destroyed. Sometimes, a different overlapped was allocated at the same address, leading to unexpected completition. _WaitHandleFuture.cancel() now waits until the wait is cancelled to clear its reference to the overlapped object. To wait until the cancellation is done, UnregisterWaitEx() is used with an event instead of UnregisterWait(). To wait for this event, a new _WaitCancelFuture class was added. It's a simplified version of _WaitCancelFuture. For example, its cancel() method calls UnregisterWait(), not UnregisterWaitEx(). _WaitCancelFuture should not be cancelled. The overlapped object is kept alive in _WaitHandleFuture until the wait is unregistered. Other changes: * Add _overlapped.UnregisterWaitEx() * Remove fast-path in IocpProactor.wait_for_handle() to immediatly set the result if the wait already completed. I'm not sure that it's safe to call immediatly UnregisterWaitEx() before the completion was signaled. * Add IocpProactor._unregistered() to forget an overlapped which may never be signaled, but may be signaled for the next loop iteration. It avoids to block forever IocpProactor.close() if a wait was cancelled, and it may also avoid some "... unexpected event ..." warnings. --- asyncio/windows_events.py | 168 ++++++++++++++++++++++++++++++-------- overlapped.c | 25 ++++++ 2 files changed, 159 insertions(+), 34 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 82d09663..5105426f 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -78,20 +78,23 @@ def set_result(self, result): self._ov = None -class _WaitHandleFuture(futures.Future): +class _BaseWaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" - def __init__(self, iocp, ov, handle, wait_handle, *, loop=None): + def __init__(self, ov, handle, wait_handle, *, loop=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] - # iocp and ov are only used by cancel() to notify IocpProactor - # that the wait was cancelled - self._iocp = iocp + # Keep a reference to the Overlapped object to keep it alive until the + # wait is unregistered self._ov = ov self._handle = handle self._wait_handle = wait_handle + # Should we call UnregisterWaitEx() if the wait completes + # or is cancelled? + self._registered = True + def _poll(self): # non-blocking wait: use a timeout of 0 millisecond return (_winapi.WaitForSingleObject(self._handle, 0) == @@ -99,21 +102,32 @@ def _poll(self): def _repr_info(self): info = super()._repr_info() - info.insert(1, 'handle=%#x' % self._handle) - if self._wait_handle: + info.append('handle=%#x' % self._handle) + if self._handle is not None: state = 'signaled' if self._poll() else 'waiting' - info.insert(1, 'wait_handle=<%s, %#x>' - % (state, self._wait_handle)) + info.append(state) + if self._wait_handle is not None: + info.append('wait_handle=%#x' % self._wait_handle) return info + def _unregister_wait_cb(self, fut): + # The wait was unregistered: it's not safe to destroy the Overlapped + # object + self._ov = None + def _unregister_wait(self): - if self._wait_handle is None: + if not self._registered: return + self._registered = False + try: _overlapped.UnregisterWait(self._wait_handle) except OSError as exc: - # ERROR_IO_PENDING is not an error, the wait was unregistered - if exc.winerror != _overlapped.ERROR_IO_PENDING: + self._wait_handle = None + if exc.winerror == _overlapped.ERROR_IO_PENDING: + # ERROR_IO_PENDING is not an error, the wait was unregistered + self._unregister_wait_cb(None) + elif exc.winerror != _overlapped.ERROR_IO_PENDING: context = { 'message': 'Failed to unregister the wait handle', 'exception': exc, @@ -122,26 +136,91 @@ def _unregister_wait(self): if self._source_traceback: context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) - self._wait_handle = None - self._iocp = None - self._ov = None + else: + self._wait_handle = None + self._unregister_wait_cb(None) def cancel(self): - result = super().cancel() - if self._ov is not None: - # signal the cancellation to the overlapped object - _overlapped.PostQueuedCompletionStatus(self._iocp, True, - 0, self._ov.address) self._unregister_wait() - return result + return super().cancel() def set_exception(self, exception): - super().set_exception(exception) self._unregister_wait() + super().set_exception(exception) def set_result(self, result): - super().set_result(result) self._unregister_wait() + super().set_result(result) + + +class _WaitCancelFuture(_BaseWaitHandleFuture): + """Subclass of Future which represents a wait for the cancellation of a + _WaitHandleFuture using an event. + """ + + def __init__(self, ov, event, wait_handle, *, loop=None): + super().__init__(ov, event, wait_handle, loop=loop) + + self._done_callback = None + + def _schedule_callbacks(self): + super(_WaitCancelFuture, self)._schedule_callbacks() + if self._done_callback is not None: + self._done_callback(self) + + +class _WaitHandleFuture(_BaseWaitHandleFuture): + def __init__(self, ov, handle, wait_handle, proactor, *, loop=None): + super().__init__(ov, handle, wait_handle, loop=loop) + self._proactor = proactor + self._unregister_proactor = True + self._event = _overlapped.CreateEvent(None, True, False, None) + self._event_fut = None + + def _unregister_wait_cb(self, fut): + if self._event is not None: + _winapi.CloseHandle(self._event) + self._event = None + self._event_fut = None + + # If the wait was cancelled, the wait may never be signalled, so + # it's required to unregister it. Otherwise, IocpProactor.close() will + # wait forever for an event which will never come. + # + # If the IocpProactor already received the event, it's safe to call + # _unregister() because we kept a reference to the Overlapped object + # which is used as an unique key. + self._proactor._unregister(self._ov) + self._proactor = None + + super()._unregister_wait_cb(fut) + + def _unregister_wait(self): + if not self._registered: + return + self._registered = False + + try: + _overlapped.UnregisterWaitEx(self._wait_handle, self._event) + except OSError as exc: + self._wait_handle = None + if exc.winerror == _overlapped.ERROR_IO_PENDING: + # ERROR_IO_PENDING is not an error, the wait was unregistered + self._unregister_wait_cb(None) + elif exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + else: + self._wait_handle = None + self._event_fut = self._proactor._wait_cancel( + self._event, + self._unregister_wait_cb) class PipeServer(object): @@ -291,6 +370,7 @@ def __init__(self, concurrency=0xffffffff): _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) self._cache = {} self._registered = weakref.WeakSet() + self._unregistered = [] self._stopped_serving = weakref.WeakSet() def __repr__(self): @@ -438,6 +518,16 @@ def wait_for_handle(self, handle, timeout=None): Return a Future object. The result of the future is True if the wait completed, or False if the wait did not complete (on timeout). """ + return self._wait_for_handle(handle, timeout, False) + + def _wait_cancel(self, event, done_callback): + fut = self._wait_for_handle(event, None, True) + # add_done_callback() cannot be used because the wait may only complete + # in IocpProactor.close(), while the event loop is not running. + fut._done_callback = done_callback + return fut + + def _wait_for_handle(self, handle, timeout, _is_cancel): if timeout is None: ms = _winapi.INFINITE else: @@ -447,9 +537,13 @@ def wait_for_handle(self, handle, timeout=None): # We only create ov so we can use ov.address as a key for the cache. ov = _overlapped.Overlapped(NULL) - wh = _overlapped.RegisterWaitWithQueue( + wait_handle = _overlapped.RegisterWaitWithQueue( handle, self._iocp, ov.address, ms) - f = _WaitHandleFuture(self._iocp, ov, handle, wh, loop=self._loop) + if _is_cancel: + f = _WaitCancelFuture(ov, handle, wait_handle, loop=self._loop) + else: + f = _WaitHandleFuture(ov, handle, wait_handle, self, + loop=self._loop) if f._source_traceback: del f._source_traceback[-1] @@ -462,14 +556,6 @@ def finish_wait_for_handle(trans, key, ov): # False even though we have not timed out. return f._poll() - if f._poll(): - try: - result = f._poll() - except OSError as exc: - f.set_exception(exc) - else: - f.set_result(result) - self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle) return f @@ -521,6 +607,15 @@ def _register(self, ov, obj, callback, self._cache[ov.address] = (f, ov, obj, callback) return f + def _unregister(self, ov): + """Unregister an overlapped object. + + Call this method when its future has been cancelled. The event can + already be signalled (pending in the proactor event queue). It is also + safe if the event is never signalled (because it was cancelled). + """ + self._unregistered.append(ov) + def _get_accept_socket(self, family): s = socket.socket(family) s.settimeout(0) @@ -541,7 +636,7 @@ def _poll(self, timeout=None): while True: status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) if status is None: - return + break ms = 0 err, transferred, key, address = status @@ -576,6 +671,11 @@ def _poll(self, timeout=None): f.set_result(value) self._results.append(f) + # Remove unregisted futures + for ov in self._unregistered: + self._cache.pop(ov.address, None) + self._unregistered.clear() + def _stop_serving(self, obj): # obj is a socket or pipe handle. It will be closed in # BaseProactorEventLoop._stop_serving() which will make any diff --git a/overlapped.c b/overlapped.c index 6842efbb..d22c626e 100644 --- a/overlapped.c +++ b/overlapped.c @@ -309,6 +309,29 @@ overlapped_UnregisterWait(PyObject *self, PyObject *args) Py_RETURN_NONE; } +PyDoc_STRVAR( + UnregisterWaitEx_doc, + "UnregisterWaitEx(WaitHandle, Event) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWaitEx(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle, Event; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, &WaitHandle, &Event)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWaitEx(WaitHandle, Event); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + /* * Event functions -- currently only used by tests */ @@ -1319,6 +1342,8 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, RegisterWaitWithQueue_doc}, {"UnregisterWait", overlapped_UnregisterWait, METH_VARARGS, UnregisterWait_doc}, + {"UnregisterWaitEx", overlapped_UnregisterWaitEx, + METH_VARARGS, UnregisterWaitEx_doc}, {"CreateEvent", overlapped_CreateEvent, METH_VARARGS, CreateEvent_doc}, {"SetEvent", overlapped_SetEvent, From 0bb6858b7c201843f7823ef272c9dfcf7eb9226b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 00:15:23 +0100 Subject: [PATCH 1300/1502] BaseEventLoop._create_connection_transport() catchs any exception, not only Exception --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 739296b7..1ceeb2d2 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -636,7 +636,7 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, try: yield from waiter - except Exception: + except: transport.close() raise From ae97759d878519f510d752865340c5a18f06ecc8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 00:15:53 +0100 Subject: [PATCH 1301/1502] Python issue #23095: IocpProactor.close() must not cancel pending _WaitCancelFuture futures --- asyncio/windows_events.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 5105426f..3cb5690f 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -163,6 +163,9 @@ def __init__(self, ov, event, wait_handle, *, loop=None): self._done_callback = None + def cancel(self): + raise RuntimeError("_WaitCancelFuture must not be cancelled") + def _schedule_callbacks(self): super(_WaitCancelFuture, self)._schedule_callbacks() if self._done_callback is not None: @@ -693,6 +696,9 @@ def close(self): # FIXME: Tulip issue 196: remove this case, it should not happen elif fut.done() and not fut.cancelled(): del self._cache[address] + elif isinstance(fut, _WaitCancelFuture): + # _WaitCancelFuture must not be cancelled + pass else: try: fut.cancel() From 2fe68a363d4ebcb01826cb21338d78b691bcf244 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 09:24:03 +0100 Subject: [PATCH 1302/1502] IocpProactor.close(): don't cancel futures which are already cancelled --- asyncio/windows_events.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 3cb5690f..315455aa 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -693,12 +693,16 @@ def close(self): # queues a task to Windows' thread pool. This cannot # be cancelled, so just forget it. del self._cache[address] - # FIXME: Tulip issue 196: remove this case, it should not happen - elif fut.done() and not fut.cancelled(): - del self._cache[address] + elif fut.cancelled(): + # Nothing to do with cancelled futures + pass elif isinstance(fut, _WaitCancelFuture): # _WaitCancelFuture must not be cancelled pass + elif fut.done(): + # FIXME: Tulip issue 196: remove this case, it should not + # happen + del self._cache[address] else: try: fut.cancel() From bb31eae65a094c5fa8fc6e53d89404815216ccb8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 10:00:03 +0100 Subject: [PATCH 1303/1502] release.py: Fix help --- release.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/release.py b/release.py index 3d95a98f..a5acbc88 100755 --- a/release.py +++ b/release.py @@ -396,7 +396,8 @@ def parse_options(self): if command: print("Invalid command: %s" % command) else: - parser.print_usage() + parser.print_help() + print("") print("Available commands:") print("- build: build asyncio in place, imply --running") From 8f386cd5fd5410b076845665ec8211541b827638 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 09:44:29 +0100 Subject: [PATCH 1304/1502] Python issue #23293: Rewrite IocpProactor.connect_pipe() Add _overlapped.ConnectPipe() which tries to connect to the pipe for asynchronous I/O (overlapped): call CreateFile() in a loop until it doesn't fail with ERROR_PIPE_BUSY. Use an increasing delay between 1 ms and 100 ms. Remove Overlapped.WaitNamedPipeAndConnect() which is no more used. --- asyncio/windows_events.py | 43 ++++++++------ overlapped.c | 115 ++++++++------------------------------ 2 files changed, 48 insertions(+), 110 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 315455aa..7d0dbe9d 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -29,6 +29,12 @@ ERROR_CONNECTION_REFUSED = 1225 ERROR_CONNECTION_ABORTED = 1236 +# Initial delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_INIT_DELAY = 0.001 + +# Maximum delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_MAX_DELAY = 0.100 + class _OverlappedFuture(futures.Future): """Subclass of Future which represents an overlapped operation. @@ -495,25 +501,28 @@ def finish_accept_pipe(trans, key, ov): return self._register(ov, pipe, finish_accept_pipe, register=False) - def connect_pipe(self, address): - ov = _overlapped.Overlapped(NULL) - ov.WaitNamedPipeAndConnect(address, self._iocp, ov.address) - - def finish_connect_pipe(err, handle, ov): - # err, handle were arguments passed to PostQueuedCompletionStatus() - # in a function run in a thread pool. - if err == _overlapped.ERROR_SEM_TIMEOUT: - # Connection did not succeed within time limit. - msg = _overlapped.FormatMessage(err) - raise ConnectionRefusedError(0, msg, None, err) - elif err != 0: - msg = _overlapped.FormatMessage(err) - raise OSError(0, msg, None, err) + def _connect_pipe(self, fut, address, delay): + # Unfortunately there is no way to do an overlapped connect to a pipe. + # Call CreateFile() in a loop until it doesn't fail with + # ERROR_PIPE_BUSY + try: + handle = _overlapped.ConnectPipe(address) + except OSError as exc: + if exc.winerror == _overlapped.ERROR_PIPE_BUSY: + # Polling: retry later + delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) + self._loop.call_later(delay, + self._connect_pipe, fut, address, delay) else: - return windows_utils.PipeHandle(handle) + fut.set_exception(exc) + else: + pipe = windows_utils.PipeHandle(handle) + fut.set_result(pipe) - return self._register(ov, None, finish_connect_pipe, - wait_for_post=True) + def connect_pipe(self, address): + fut = futures.Future(loop=self._loop) + self._connect_pipe(fut, address, CONNECT_PIPE_INIT_DELAY) + return fut def wait_for_handle(self, handle, timeout=None): """Wait for a handle. diff --git a/overlapped.c b/overlapped.c index d22c626e..8fe2e247 100644 --- a/overlapped.c +++ b/overlapped.c @@ -52,12 +52,6 @@ typedef struct { }; } OverlappedObject; -typedef struct { - OVERLAPPED *Overlapped; - HANDLE IocpHandle; - char Address[1]; -} WaitNamedPipeAndConnectContext; - /* * Map Windows error codes to subclasses of OSError */ @@ -1133,99 +1127,33 @@ Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) } } -/* Unfortunately there is no way to do an overlapped connect to a - pipe. We instead use WaitNamedPipe() and CreateFile() in a thread - pool thread. If a connection succeeds within a time limit (10 - seconds) then PostQueuedCompletionStatus() is used to return the - pipe handle to the completion port. */ - -static DWORD WINAPI -WaitNamedPipeAndConnectInThread(WaitNamedPipeAndConnectContext *ctx) -{ - HANDLE PipeHandle = INVALID_HANDLE_VALUE; - DWORD Start = GetTickCount(); - DWORD Deadline = Start + 10*1000; - DWORD Error = 0; - DWORD Timeout; - BOOL Success; - - for ( ; ; ) { - Timeout = Deadline - GetTickCount(); - if ((int)Timeout < 0) - break; - Success = WaitNamedPipe(ctx->Address, Timeout); - Error = Success ? ERROR_SUCCESS : GetLastError(); - switch (Error) { - case ERROR_SUCCESS: - PipeHandle = CreateFile(ctx->Address, - GENERIC_READ | GENERIC_WRITE, - 0, NULL, OPEN_EXISTING, - FILE_FLAG_OVERLAPPED, NULL); - if (PipeHandle == INVALID_HANDLE_VALUE) - continue; - break; - case ERROR_SEM_TIMEOUT: - continue; - } - break; - } - if (!PostQueuedCompletionStatus(ctx->IocpHandle, Error, - (ULONG_PTR)PipeHandle, ctx->Overlapped)) - CloseHandle(PipeHandle); - free(ctx); - return 0; -} - PyDoc_STRVAR( - Overlapped_WaitNamedPipeAndConnect_doc, - "WaitNamedPipeAndConnect(addr, iocp_handle) -> Overlapped[pipe_handle]\n\n" - "Start overlapped connection to address, notifying iocp_handle when\n" - "finished"); + ConnectPipe_doc, + "ConnectPipe(addr) -> pipe_handle\n\n" + "Connect to the pipe for asynchronous I/O (overlapped)."); static PyObject * -Overlapped_WaitNamedPipeAndConnect(OverlappedObject *self, PyObject *args) +ConnectPipe(OverlappedObject *self, PyObject *args) { - char *Address; - Py_ssize_t AddressLength; - HANDLE IocpHandle; - OVERLAPPED Overlapped; - BOOL ret; - DWORD err; - WaitNamedPipeAndConnectContext *ctx; - Py_ssize_t ContextLength; + PyObject *AddressObj; + wchar_t *Address; + HANDLE PipeHandle; - if (!PyArg_ParseTuple(args, "s#" F_HANDLE F_POINTER, - &Address, &AddressLength, &IocpHandle, &Overlapped)) + if (!PyArg_ParseTuple(args, "U", &AddressObj)) return NULL; - if (self->type != TYPE_NONE) { - PyErr_SetString(PyExc_ValueError, "operation already attempted"); + Address = PyUnicode_AsWideCharString(AddressObj, NULL); + if (Address == NULL) return NULL; - } - ContextLength = (AddressLength + - offsetof(WaitNamedPipeAndConnectContext, Address)); - ctx = calloc(1, ContextLength + 1); - if (ctx == NULL) - return PyErr_NoMemory(); - memcpy(ctx->Address, Address, AddressLength + 1); - ctx->Overlapped = &self->overlapped; - ctx->IocpHandle = IocpHandle; - - self->type = TYPE_WAIT_NAMED_PIPE_AND_CONNECT; - self->handle = NULL; - - Py_BEGIN_ALLOW_THREADS - ret = QueueUserWorkItem(WaitNamedPipeAndConnectInThread, ctx, - WT_EXECUTELONGFUNCTION); - Py_END_ALLOW_THREADS - - mark_as_completed(&self->overlapped); - - self->error = err = ret ? ERROR_SUCCESS : GetLastError(); - if (!ret) - return SetFromWindowsErr(err); - Py_RETURN_NONE; + PipeHandle = CreateFileW(Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + PyMem_Free(Address); + if (PipeHandle == INVALID_HANDLE_VALUE) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, PipeHandle); } static PyObject* @@ -1262,9 +1190,6 @@ static PyMethodDef Overlapped_methods[] = { METH_VARARGS, Overlapped_DisconnectEx_doc}, {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, - {"WaitNamedPipeAndConnect", - (PyCFunction) Overlapped_WaitNamedPipeAndConnect, - METH_VARARGS, Overlapped_WaitNamedPipeAndConnect_doc}, {NULL} }; @@ -1350,6 +1275,9 @@ static PyMethodDef overlapped_functions[] = { METH_VARARGS, SetEvent_doc}, {"ResetEvent", overlapped_ResetEvent, METH_VARARGS, ResetEvent_doc}, + {"ConnectPipe", + (PyCFunction) ConnectPipe, + METH_VARARGS, ConnectPipe_doc}, {NULL} }; @@ -1394,6 +1322,7 @@ PyInit__overlapped(void) WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); + WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY); WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); WINAPI_CONSTANT(F_HANDLE, NULL); From 8b4b4f1fd83c33c645d58660756a4ed6bac5b174 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 22 Jan 2015 23:35:14 +0100 Subject: [PATCH 1305/1502] Tulip issue #204: Fix IocpProactor.accept_pipe() Overlapped.ConnectNamedPipe() now returns a boolean: True if the pipe is connected (if ConnectNamedPipe() failed with ERROR_PIPE_CONNECTED), False if the connection is in progress. This change removes multiple hacks in IocpProactor. --- asyncio/windows_events.py | 41 ++++++++++++++++----------------------- overlapped.c | 4 ++-- 2 files changed, 19 insertions(+), 26 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 7d0dbe9d..42c5f6e1 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -490,16 +490,21 @@ def finish_connect(trans, key, ov): def accept_pipe(self, pipe): self._register_with_iocp(pipe) ov = _overlapped.Overlapped(NULL) - ov.ConnectNamedPipe(pipe.fileno()) + connected = ov.ConnectNamedPipe(pipe.fileno()) + + if connected: + # ConnectNamePipe() failed with ERROR_PIPE_CONNECTED which means + # that the pipe is connected. There is no need to wait for the + # completion of the connection. + f = futures.Future(loop=self._loop) + f.set_result(pipe) + return f def finish_accept_pipe(trans, key, ov): ov.getresult() return pipe - # FIXME: Tulip issue 196: why do we need register=False? - # See also the comment in the _register() method - return self._register(ov, pipe, finish_accept_pipe, - register=False) + return self._register(ov, pipe, finish_accept_pipe) def _connect_pipe(self, fut, address, delay): # Unfortunately there is no way to do an overlapped connect to a pipe. @@ -581,15 +586,14 @@ def _register_with_iocp(self, obj): # to avoid sending notifications to completion port of ops # that succeed immediately. - def _register(self, ov, obj, callback, - wait_for_post=False, register=True): + def _register(self, ov, obj, callback): # Return a future which will be set with the result of the # operation when it completes. The future's value is actually # the value returned by callback(). f = _OverlappedFuture(ov, loop=self._loop) if f._source_traceback: del f._source_traceback[-1] - if not ov.pending and not wait_for_post: + if not ov.pending: # The operation has completed, so no need to postpone the # work. We cannot take this short cut if we need the # NumberOfBytes, CompletionKey values returned by @@ -605,18 +609,11 @@ def _register(self, ov, obj, callback, # Register the overlapped operation to keep a reference to the # OVERLAPPED object, otherwise the memory is freed and Windows may # read uninitialized memory. - # - # For an unknown reason, ConnectNamedPipe() behaves differently: - # the completion is not notified by GetOverlappedResult() if we - # already called GetOverlappedResult(). For this specific case, we - # don't expect notification (register is set to False). - else: - register = True - if register: - # Register the overlapped operation for later. Note that - # we only store obj to prevent it from being garbage - # collected too early. - self._cache[ov.address] = (f, ov, obj, callback) + + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) return f def _unregister(self, ov): @@ -708,10 +705,6 @@ def close(self): elif isinstance(fut, _WaitCancelFuture): # _WaitCancelFuture must not be cancelled pass - elif fut.done(): - # FIXME: Tulip issue 196: remove this case, it should not - # happen - del self._cache[address] else: try: fut.cancel() diff --git a/overlapped.c b/overlapped.c index 8fe2e247..4661152d 100644 --- a/overlapped.c +++ b/overlapped.c @@ -1117,10 +1117,10 @@ Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) switch (err) { case ERROR_PIPE_CONNECTED: mark_as_completed(&self->overlapped); - Py_RETURN_NONE; + Py_RETURN_TRUE; case ERROR_SUCCESS: case ERROR_IO_PENDING: - Py_RETURN_NONE; + Py_RETURN_FALSE; default: self->type = TYPE_NOT_STARTED; return SetFromWindowsErr(err); From 995e2c1ac445c018dc0f8279e7f5b691cbc97c3a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 23 Jan 2015 00:52:06 +0100 Subject: [PATCH 1306/1502] Close transports on error Fix create_datagram_endpoint(), connect_read_pipe() and connect_write_pipe(): close the transport if the task is cancelled or on error. --- asyncio/base_events.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 1ceeb2d2..e43441ea 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -723,7 +723,13 @@ def create_datagram_endpoint(self, protocol_factory, logger.debug("Datagram endpoint remote_addr=%r created: " "(%r, %r)", remote_addr, transport, protocol) - yield from waiter + + try: + yield from waiter + except: + transport.close() + raise + return transport, protocol @coroutine @@ -815,7 +821,13 @@ def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) transport = self._make_read_pipe_transport(pipe, protocol, waiter) - yield from waiter + + try: + yield from waiter + except: + transport.close() + raise + if self._debug: logger.debug('Read pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) @@ -826,7 +838,13 @@ def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = futures.Future(loop=self) transport = self._make_write_pipe_transport(pipe, protocol, waiter) - yield from waiter + + try: + yield from waiter + except: + transport.close() + raise + if self._debug: logger.debug('Write pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) From adc9c3e0995cb0ea68faeb1cbc39e68ca9d2d2c2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 23 Jan 2015 01:14:44 +0100 Subject: [PATCH 1307/1502] Python issue #23293: Cleanup IocpProactor.close() The special case for connect_pipe() is not more needed. connect_pipe() doesn't use overlapped operations anymore. --- asyncio/windows_events.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 42c5f6e1..6c7e0580 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -694,12 +694,7 @@ def _stop_serving(self, obj): def close(self): # Cancel remaining registered operations. for address, (fut, ov, obj, callback) in list(self._cache.items()): - if obj is None: - # The operation was started with connect_pipe() which - # queues a task to Windows' thread pool. This cannot - # be cancelled, so just forget it. - del self._cache[address] - elif fut.cancelled(): + if fut.cancelled(): # Nothing to do with cancelled futures pass elif isinstance(fut, _WaitCancelFuture): From 86bbe9e5844677c734994239acdbd19499bc67bd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 23 Jan 2015 02:59:20 +0100 Subject: [PATCH 1308/1502] Tulip issue #219: Fix comments in simple_tcp_server.py example Fixing some comments. The server listens on 12345 and client connects on 12345, but the comments state 1234. Patch written by bryan.neff. --- examples/simple_tcp_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py index b796d9b6..5f874ffc 100644 --- a/examples/simple_tcp_server.py +++ b/examples/simple_tcp_server.py @@ -4,7 +4,7 @@ asyncio.streams.open_connection(). Note that running this example starts both the TCP server and client -in the same process. It listens on port 1234 on 127.0.0.1, so it will +in the same process. It listens on port 12345 on 127.0.0.1, so it will fail if this port is currently in use. """ @@ -83,7 +83,7 @@ def _handle_client(self, client_reader, client_writer): def start(self, loop): """ - Starts the TCP server, so that it listens on port 1234. + Starts the TCP server, so that it listens on port 12345. For each client that connects, the accept_client method gets called. This method runs the loop until the server sockets From a4aa7c1e86fe288b7786c2afe8c64e1775cb2382 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Fri, 23 Jan 2015 23:30:31 -0500 Subject: [PATCH 1309/1502] Tulip issue #220: Merge JoinableQueue with Queue. To more closely match the standard Queue, asyncio.Queue has "join" and "task_done". JoinableQueue is deleted. --- asyncio/queues.py | 95 +++++++++++++++++--------------------------- tests/test_queues.py | 10 ++--- 2 files changed, 42 insertions(+), 63 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index dce0d53c..37b0c41c 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -1,7 +1,6 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', - 'QueueFull', 'QueueEmpty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] import collections import heapq @@ -45,6 +44,9 @@ def __init__(self, maxsize=0, *, loop=None): self._getters = collections.deque() # Pairs of (item, Future). self._putters = collections.deque() + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() self._init(maxsize) def _init(self, maxsize): @@ -55,6 +57,8 @@ def _get(self): def _put(self, item): self._queue.append(item) + self._unfinished_tasks += 1 + self._finished.clear() def __repr__(self): return '<{} at {:#x} {}>'.format( @@ -71,6 +75,8 @@ def _format(self): result += ' _getters[{}]'.format(len(self._getters)) if self._putters: result += ' _putters[{}]'.format(len(self._putters)) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) return result def _consume_done_getters(self): @@ -122,9 +128,6 @@ def put(self, item): 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - - # Use _put and _get instead of passing item straight to getter, in - # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) # getter cannot be cancelled, we just removed done getters @@ -150,9 +153,6 @@ def put_nowait(self, item): 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - - # Use _put and _get instead of passing item straight to getter, in - # case a subclass has logic that must run (e.g. JoinableQueue). self._put(item) # getter cannot be cancelled, we just removed done getters @@ -215,56 +215,6 @@ def get_nowait(self): else: raise QueueEmpty - -class PriorityQueue(Queue): - """A subclass of Queue; retrieves entries in priority order (lowest first). - - Entries are typically tuples of the form: (priority number, data). - """ - - def _init(self, maxsize): - self._queue = [] - - def _put(self, item, heappush=heapq.heappush): - heappush(self._queue, item) - - def _get(self, heappop=heapq.heappop): - return heappop(self._queue) - - -class LifoQueue(Queue): - """A subclass of Queue that retrieves most recently added entries first.""" - - def _init(self, maxsize): - self._queue = [] - - def _put(self, item): - self._queue.append(item) - - def _get(self): - return self._queue.pop() - - -class JoinableQueue(Queue): - """A subclass of Queue with task_done() and join() methods.""" - - def __init__(self, maxsize=0, *, loop=None): - super().__init__(maxsize=maxsize, loop=loop) - self._unfinished_tasks = 0 - self._finished = locks.Event(loop=self._loop) - self._finished.set() - - def _format(self): - result = Queue._format(self) - if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) - return result - - def _put(self, item): - super()._put(item) - self._unfinished_tasks += 1 - self._finished.clear() - def task_done(self): """Indicate that a formerly enqueued task is complete. @@ -296,3 +246,32 @@ def join(self): """ if self._unfinished_tasks > 0: yield from self._finished.wait() + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() diff --git a/tests/test_queues.py b/tests/test_queues.py index 3d4ac51d..a73539d1 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -408,14 +408,14 @@ def test_order(self): self.assertEqual([1, 2, 3], items) -class JoinableQueueTests(_QueueTestBase): +class QueueJoinTests(_QueueTestBase): def test_task_done_underflow(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) self.assertRaises(ValueError, q.task_done) def test_task_done(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) for i in range(100): q.put_nowait(i) @@ -452,7 +452,7 @@ def test(): self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) def test_join_empty_queue(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) # Test that a queue join()s successfully, and before anything else # (done twice for insurance). @@ -465,7 +465,7 @@ def join(): self.loop.run_until_complete(join()) def test_format(self): - q = asyncio.JoinableQueue(loop=self.loop) + q = asyncio.Queue(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') q._unfinished_tasks = 2 From dc19c32eb65881e18d15e7e357bda881700542d6 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Fri, 23 Jan 2015 23:31:27 -0500 Subject: [PATCH 1310/1502] Docstring for Queue.join shouldn't mention threads. --- asyncio/queues.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 37b0c41c..b0fb3873 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -240,8 +240,8 @@ def join(self): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the - queue. The count goes down whenever a consumer thread calls task_done() - to indicate that the item was retrieved and all work on it is complete. + queue. The count goes down whenever a consumer calls task_done() to + indicate that the item was retrieved and all work on it is complete. When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: From 684f3be00011d3c6cc4f81f5cb61c157e5eb5205 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 10:52:45 +0100 Subject: [PATCH 1311/1502] Python issue #23208: Add BaseEventLoop._current_handle In debug mode, BaseEventLoop._run_once() now sets the BaseEventLoop._current_handle attribute to the handle currently executed. In release mode or when no handle is executed, the attribute is None. BaseEventLoop.default_exception_handler() displays the traceback of the current handle if available. --- .hgeol | 4 + .hgignore | 15 + AUTHORS | 26 + COPYING | 201 ++ ChangeLog | 185 ++ MANIFEST.in | 11 + Makefile | 60 + README | 44 + asyncio/__init__.py | 50 + asyncio/base_events.py | 1151 +++++++++++ asyncio/base_subprocess.py | 229 +++ asyncio/constants.py | 7 + asyncio/coroutines.py | 199 ++ asyncio/events.py | 597 ++++++ asyncio/futures.py | 409 ++++ asyncio/locks.py | 469 +++++ asyncio/log.py | 7 + asyncio/proactor_events.py | 535 +++++ asyncio/protocols.py | 129 ++ asyncio/queues.py | 298 +++ asyncio/selector_events.py | 1007 +++++++++ asyncio/selectors.py | 594 ++++++ asyncio/sslproto.py | 646 ++++++ asyncio/streams.py | 486 +++++ asyncio/subprocess.py | 249 +++ asyncio/tasks.py | 667 ++++++ asyncio/test_support.py | 305 +++ asyncio/test_utils.py | 442 ++++ asyncio/transports.py | 300 +++ asyncio/unix_events.py | 961 +++++++++ asyncio/windows_events.py | 752 +++++++ asyncio/windows_utils.py | 217 ++ check.py | 45 + examples/cacheclt.py | 213 ++ examples/cachesvr.py | 249 +++ examples/child_process.py | 128 ++ examples/crawl.py | 863 ++++++++ examples/echo_client_tulip.py | 20 + examples/echo_server_tulip.py | 20 + examples/fetch0.py | 35 + examples/fetch1.py | 78 + examples/fetch2.py | 141 ++ examples/fetch3.py | 230 +++ examples/fuzz_as_completed.py | 69 + examples/hello_callback.py | 17 + examples/hello_coroutine.py | 18 + examples/shell.py | 50 + examples/simple_tcp_server.py | 154 ++ examples/sink.py | 94 + examples/source.py | 100 + examples/source1.py | 98 + examples/stacks.py | 44 + examples/subprocess_attach_read_pipe.py | 33 + examples/subprocess_attach_write_pipe.py | 35 + examples/subprocess_shell.py | 87 + examples/tcp_echo.py | 128 ++ examples/timing_tcp_server.py | 168 ++ examples/udp_echo.py | 104 + overlapped.c | 1334 ++++++++++++ pypi.bat | 1 + release.py | 517 +++++ run_aiotest.py | 14 + runtests.py | 304 +++ setup.py | 49 + tests/echo.py | 8 + tests/echo2.py | 6 + tests/echo3.py | 11 + tests/keycert3.pem | 73 + tests/pycacert.pem | 78 + tests/sample.crt | 14 + tests/sample.key | 15 + tests/ssl_cert.pem | 15 + tests/ssl_key.pem | 16 + tests/test_base_events.py | 1236 +++++++++++ tests/test_events.py | 2369 ++++++++++++++++++++++ tests/test_futures.py | 473 +++++ tests/test_locks.py | 858 ++++++++ tests/test_proactor_events.py | 587 ++++++ tests/test_queues.py | 476 +++++ tests/test_selector_events.py | 1741 ++++++++++++++++ tests/test_selectors.py | 454 +++++ tests/test_sslproto.py | 49 + tests/test_streams.py | 641 ++++++ tests/test_subprocess.py | 338 +++ tests/test_tasks.py | 2019 ++++++++++++++++++ tests/test_transports.py | 91 + tests/test_unix_events.py | 1568 ++++++++++++++ tests/test_windows_events.py | 146 ++ tests/test_windows_utils.py | 182 ++ tox.ini | 21 + update_stdlib.sh | 70 + 91 files changed, 30247 insertions(+) create mode 100644 .hgeol create mode 100644 .hgignore create mode 100644 AUTHORS create mode 100644 COPYING create mode 100644 ChangeLog create mode 100644 MANIFEST.in create mode 100644 Makefile create mode 100644 README create mode 100644 asyncio/__init__.py create mode 100644 asyncio/base_events.py create mode 100644 asyncio/base_subprocess.py create mode 100644 asyncio/constants.py create mode 100644 asyncio/coroutines.py create mode 100644 asyncio/events.py create mode 100644 asyncio/futures.py create mode 100644 asyncio/locks.py create mode 100644 asyncio/log.py create mode 100644 asyncio/proactor_events.py create mode 100644 asyncio/protocols.py create mode 100644 asyncio/queues.py create mode 100644 asyncio/selector_events.py create mode 100644 asyncio/selectors.py create mode 100644 asyncio/sslproto.py create mode 100644 asyncio/streams.py create mode 100644 asyncio/subprocess.py create mode 100644 asyncio/tasks.py create mode 100644 asyncio/test_support.py create mode 100644 asyncio/test_utils.py create mode 100644 asyncio/transports.py create mode 100644 asyncio/unix_events.py create mode 100644 asyncio/windows_events.py create mode 100644 asyncio/windows_utils.py create mode 100644 check.py create mode 100644 examples/cacheclt.py create mode 100644 examples/cachesvr.py create mode 100644 examples/child_process.py create mode 100644 examples/crawl.py create mode 100644 examples/echo_client_tulip.py create mode 100644 examples/echo_server_tulip.py create mode 100644 examples/fetch0.py create mode 100644 examples/fetch1.py create mode 100644 examples/fetch2.py create mode 100644 examples/fetch3.py create mode 100644 examples/fuzz_as_completed.py create mode 100644 examples/hello_callback.py create mode 100644 examples/hello_coroutine.py create mode 100644 examples/shell.py create mode 100644 examples/simple_tcp_server.py create mode 100644 examples/sink.py create mode 100644 examples/source.py create mode 100644 examples/source1.py create mode 100644 examples/stacks.py create mode 100644 examples/subprocess_attach_read_pipe.py create mode 100644 examples/subprocess_attach_write_pipe.py create mode 100644 examples/subprocess_shell.py create mode 100755 examples/tcp_echo.py create mode 100644 examples/timing_tcp_server.py create mode 100755 examples/udp_echo.py create mode 100644 overlapped.c create mode 100644 pypi.bat create mode 100755 release.py create mode 100644 run_aiotest.py create mode 100644 runtests.py create mode 100644 setup.py create mode 100644 tests/echo.py create mode 100644 tests/echo2.py create mode 100644 tests/echo3.py create mode 100644 tests/keycert3.pem create mode 100644 tests/pycacert.pem create mode 100644 tests/sample.crt create mode 100644 tests/sample.key create mode 100644 tests/ssl_cert.pem create mode 100644 tests/ssl_key.pem create mode 100644 tests/test_base_events.py create mode 100644 tests/test_events.py create mode 100644 tests/test_futures.py create mode 100644 tests/test_locks.py create mode 100644 tests/test_proactor_events.py create mode 100644 tests/test_queues.py create mode 100644 tests/test_selector_events.py create mode 100644 tests/test_selectors.py create mode 100644 tests/test_sslproto.py create mode 100644 tests/test_streams.py create mode 100644 tests/test_subprocess.py create mode 100644 tests/test_tasks.py create mode 100644 tests/test_transports.py create mode 100644 tests/test_unix_events.py create mode 100644 tests/test_windows_events.py create mode 100644 tests/test_windows_utils.py create mode 100644 tox.ini create mode 100755 update_stdlib.sh diff --git a/.hgeol b/.hgeol new file mode 100644 index 00000000..8233b6dd --- /dev/null +++ b/.hgeol @@ -0,0 +1,4 @@ +[patterns] +** = native +.hgignore = native +.hgeol = native diff --git a/.hgignore b/.hgignore new file mode 100644 index 00000000..736c7fdf --- /dev/null +++ b/.hgignore @@ -0,0 +1,15 @@ +.*\.py[co]$ +.*~$ +.*\.orig$ +.*\#.*$ +.*@.*$ +\.coverage$ +htmlcov$ +\.DS_Store$ +venv$ +distribute_setup.py$ +distribute-\d+.\d+.\d+.tar.gz$ +build$ +dist$ +.*\.egg-info$ +\.tox$ diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 00000000..d25b4465 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,26 @@ +A. Jesse Jiryu Davis +Aaron Griffith +Andrew Svetlov +Anthony Baire +Antoine Pitrou +Arnaud Faure +Aymeric Augustin +Brett Cannon +Charles-François Natali +Christian Heimes +Donald Stufft +Eli Bendersky +Geert Jansen +Giampaolo Rodola' +Guido van Rossum : creator of the Tulip project and author of the PEP 3156 +Gustavo Carneiro +Jeff Quast +Jonathan Slenders +Nikolay Kim +Richard Oudkerk +Saúl Ibarra Corretgé +Serhiy Storchaka +Vajrasky Kok +Victor Stinner +Vladimir Kryachko +Yury Selivanov diff --git a/COPYING b/COPYING new file mode 100644 index 00000000..11069edd --- /dev/null +++ b/COPYING @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/ChangeLog b/ChangeLog new file mode 100644 index 00000000..e483fb2d --- /dev/null +++ b/ChangeLog @@ -0,0 +1,185 @@ +2014-09-30: Tulip 3.4.2 +======================= + +New shiny methods like create_task(), better documentation, much better debug +mode, better tests. + +asyncio API +----------- + +* Add BaseEventLoop.create_task() method: schedule a coroutine object. + It allows other asyncio implementations to use their own Task class to + change its behaviour. + +* New BaseEventLoop methods: + + - create_task(): schedule a coroutine + - get_debug() + - is_closed() + - set_debug() + +* Add _FlowControlMixin.get_write_buffer_limits() method + +* sock_recv(), sock_sendall(), sock_connect(), sock_accept() methods of + SelectorEventLoop now raise an exception if the socket is blocking mode + +* Include unix_events/windows_events symbols in asyncio.__all__. + Examples: SelectorEventLoop, ProactorEventLoop, DefaultEventLoopPolicy. + +* attach(), detach(), loop, active_count and waiters attributes of the Server + class are now private + +* BaseEventLoop: run_forever(), run_until_complete() now raises an exception if + the event loop was closed + +* close() now raises an exception if the event loop is running, because pending + callbacks would be lost + +* Queue now accepts a float for the maximum size. + +* Process.communicate() now ignores BrokenPipeError and ConnectionResetError + exceptions, as Popen.communicate() of the subprocess module + + +Performances +------------ + +* Optimize handling of cancelled timers + + +Debug +----- + +* Future (and Task), CoroWrapper and Handle now remembers where they were + created (new _source_traceback object), traceback displayed when errors are + logged. + +* On Python 3.4 and newer, Task destrutor now logs a warning if the task was + destroyed while it was still pending. It occurs if the last reference + to the task was removed, while the coroutine didn't finish yet. + +* Much more useful events are logged: + + - Event loop closed + - Network connection + - Creation of a subprocess + - Pipe lost + - Log many errors previously silently ignored + - SSL handshake failure + - etc. + +* BaseEventLoop._debug is now True if the envrionement variable + PYTHONASYNCIODEBUG is set + +* Log the duration of DNS resolution and SSL handshake + +* Log a warning if a callback blocks the event loop longer than 100 ms + (configurable duration) + +* repr(CoroWrapper) and repr(Task) now contains the current status of the + coroutine (running, done), current filename and line number, and filename and + line number where the object was created + +* Enhance representation (repr) of transports: add the file descriptor, status + (idle, polling, writing, etc.), size of the write buffer, ... + +* Add repr(BaseEventLoop) + +* run_until_complete() doesn't log a warning anymore when called with a + coroutine object which raises an exception. + + +Bugfixes +-------- + +* windows_utils.socketpair() now ensures that sockets are closed in case + of error. + +* Rewrite bricks of the IocpProactor() to make it more reliable + +* IocpProactor destructor now closes it. + +* _OverlappedFuture.set_exception() now cancels the overlapped operation. + +* Rewrite _WaitHandleFuture: + + - cancel() is now able to signal the cancellation to the overlapped object + - _unregister_wait() now catchs and logs exceptions + +* PipeServer.close() (class used on Windows) now cancels the accept pipe + future. + +* Rewrite signal handling in the UNIX implementation of SelectorEventLoop: + use the self-pipe to store pending signals instead of registering a + signal handler calling directly _handle_signal(). The change fixes a + race condition. + +* create_unix_server(): close the socket on error. + +* Fix wait_for() + +* Rewrite gather() + +* drain() is now a classic coroutine, no more special return value (empty + tuple) + +* Rewrite SelectorEventLoop.sock_connect() to handle correctly timeout + +* Process data of the self-pipe faster to accept more pending events, + especially signals written by signal handlers: the callback reads all pending + data, not only a single byte + +* Don't try to set the result of a Future anymore if it was cancelled + (explicitly or by a timeout) + +* CoroWrapper now works around CPython issue #21209: yield from & custom + generator classes don't work together, issue with the send() method. It only + affected asyncio in debug mode on Python older than 3.4.2 + + +Misc changes +------------ + +* windows_utils.socketpair() now supports IPv6. + +* Better documentation (online & docstrings): fill remaining XXX, more examples + +* new asyncio.coroutines submodule, to ease maintenance with the trollius + project: @coroutine, _DEBUG, iscoroutine() and iscoroutinefunction() have + been moved from asyncio.tasks to asyncio.coroutines + +* Cleanup code, ex: remove unused attribute (ex: _rawsock) + +* Reuse os.set_blocking() of Python 3.5. + +* Close explicitly the event loop in Tulip examples. + +* runtests.py now mention if tests are running in release or debug mode. + +2014-05-19: Tulip 3.4.1 +======================= + +2014-02-24: Tulip 0.4.1 +======================= + +2014-02-10: Tulip 0.3.1 +======================= + +* Add asyncio.subprocess submodule and the Process class. + +2013-11-25: Tulip 0.2.1 +======================= + +* Add support of subprocesses using transports and protocols. + +2013-10-22: Tulip 0.1.1 +======================= + +* First release. + +Creation of the project +======================= + +* 2013-10-14: The tulip package was renamed to asyncio. +* 2012-10-16: Creation of the Tulip project, started as mail threads on the + python-ideas mailing list. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..d0dbde14 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,11 @@ +include AUTHORS COPYING +include Makefile +include overlapped.c pypi.bat +include check.py runtests.py run_aiotest.py release.py +include update_stdlib.sh + +recursive-include examples *.py +recursive-include tests *.crt +recursive-include tests *.key +recursive-include tests *.pem +recursive-include tests *.py diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..eda02f2d --- /dev/null +++ b/Makefile @@ -0,0 +1,60 @@ +# Some simple testing tasks (sorry, UNIX only). + +PYTHON=python3 +VERBOSE=$(V) +V= 0 +FLAGS= + +test: + $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS) + +vtest: + $(PYTHON) runtests.py -v 1 $(FLAGS) + +testloop: + while sleep 1; do $(PYTHON) runtests.py -v $(VERBOSE) $(FLAGS); done + +# See runtests.py for coverage installation instructions. +cov coverage: + $(PYTHON) runtests.py --coverage -v $(VERBOSE) $(FLAGS) + +check: + $(PYTHON) check.py + +# Requires "pip install pep8". +pep8: check + pep8 --ignore E125,E127,E226 tests asyncio + +clean: + rm -rf `find . -name __pycache__` + rm -f `find . -type f -name '*.py[co]' ` + rm -f `find . -type f -name '*~' ` + rm -f `find . -type f -name '.*~' ` + rm -f `find . -type f -name '@*' ` + rm -f `find . -type f -name '#*#' ` + rm -f `find . -type f -name '*.orig' ` + rm -f `find . -type f -name '*.rej' ` + rm -rf dist + rm -f .coverage + rm -rf htmlcov + rm -rf build + rm -rf asyncio.egg-info + rm -f MANIFEST + + +# For distribution builders only! +# Push a source distribution for Python 3.3 to PyPI. +# You must update the version in setup.py first. +# A PyPI user configuration in ~/.pypirc is required; +# you can create a suitable confifuration using +# python setup.py register +pypi: clean + python3.3 setup.py sdist upload + +# The corresponding action on Windows is pypi.bat. For that to work, +# you need to install wheel and setuptools. The easiest way is to get +# pip using the get-pip.py script found here: +# https://pip.pypa.io/en/latest/installing.html#install-pip +# That will install setuptools and pip; then you can just do +# \Python33\python.exe -m pip install wheel +# after which the pypi.bat script should work. diff --git a/README b/README new file mode 100644 index 00000000..2f3150a2 --- /dev/null +++ b/README @@ -0,0 +1,44 @@ +Tulip is the codename for my reference implementation of PEP 3156. + +PEP 3156: http://www.python.org/dev/peps/pep-3156/ + +*** This requires Python 3.3 or later! *** + +Copyright/license: Open source, Apache 2.0. Enjoy. + +Master Mercurial repo: http://code.google.com/p/tulip/ + +The actual code lives in the 'asyncio' subdirectory. +Tests are in the 'tests' subdirectory. + +To run tests: + - make test + +To run coverage (coverage package is required): + - make coverage + +On Windows, things are a little more complicated. Assume 'P' is your +Python binary (for example C:\Python33\python.exe). + +You must first build the _overlapped.pyd extension and have it placed +in the asyncio directory, as follows: + + C> P setup.py build_ext --inplace + +If this complains about vcvars.bat, you probably don't have the +required version of Visual Studio installed. Compiling extensions for +Python 3.3 requires Microsoft Visual C++ 2010 (MSVC 10.0) of any +edition; you can download Visual Studio Express 2010 for free from +http://www.visualstudio.com/downloads (scroll down to Visual C++ 2010 +Express). + +Once you have built the _overlapped.pyd extension successfully you can +run the tests as follows: + + C> P runtests.py + +And coverage as follows: + + C> P runtests.py --coverage + +--Guido van Rossum diff --git a/asyncio/__init__.py b/asyncio/__init__.py new file mode 100644 index 00000000..011466b3 --- /dev/null +++ b/asyncio/__init__.py @@ -0,0 +1,50 @@ +"""The asyncio package, tracking PEP 3156.""" + +import sys + +# The selectors module is in the stdlib in Python 3.4 but not in 3.3. +# Do this first, so the other submodules can use "from . import selectors". +# Prefer asyncio/selectors.py over the stdlib one, as ours may be newer. +try: + from . import selectors +except ImportError: + import selectors # Will also be exported. + +if sys.platform == 'win32': + # Similar thing for _overlapped. + try: + from . import _overlapped + except ImportError: + import _overlapped # Will also be exported. + +# This relies on each of the submodules having an __all__ variable. +from .base_events import * +from .coroutines import * +from .events import * +from .futures import * +from .locks import * +from .protocols import * +from .queues import * +from .streams import * +from .subprocess import * +from .tasks import * +from .transports import * + +__all__ = (base_events.__all__ + + coroutines.__all__ + + events.__all__ + + futures.__all__ + + locks.__all__ + + protocols.__all__ + + queues.__all__ + + streams.__all__ + + subprocess.__all__ + + tasks.__all__ + + transports.__all__) + +if sys.platform == 'win32': # pragma: no cover + from .windows_events import * + __all__ += windows_events.__all__ +else: + from .unix_events import * # pragma: no cover + __all__ += unix_events.__all__ diff --git a/asyncio/base_events.py b/asyncio/base_events.py new file mode 100644 index 00000000..1c51a7cf --- /dev/null +++ b/asyncio/base_events.py @@ -0,0 +1,1151 @@ +"""Base implementation of event loop. + +The event loop can be broken up into a multiplexer (the part +responsible for notifying us of I/O events) and the event loop proper, +which wraps a multiplexer with functionality for scheduling callbacks, +immediately or at a given time in the future. + +Whenever a public API takes a callback, subsequent positional +arguments will be passed to the callback if/when it is called. This +avoids the proliferation of trivial lambdas implementing closures. +Keyword arguments for the callback are not supported; this is a +conscious design decision, leaving the door open for keyword arguments +to modify the meaning of the API call itself. +""" + + +import collections +import concurrent.futures +import heapq +import inspect +import logging +import os +import socket +import subprocess +import threading +import time +import traceback +import sys + +from . import coroutines +from . import events +from . import futures +from . import tasks +from .coroutines import coroutine +from .log import logger + + +__all__ = ['BaseEventLoop'] + + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + +# Minimum number of _scheduled timer handles before cleanup of +# cancelled handles is performed. +_MIN_SCHEDULED_TIMER_HANDLES = 100 + +# Minimum fraction of _scheduled timer handles that are cancelled +# before cleanup of cancelled handles is performed. +_MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 + +def _format_handle(handle): + cb = handle._callback + if inspect.ismethod(cb) and isinstance(cb.__self__, tasks.Task): + # format the task + return repr(cb.__self__) + else: + return str(handle) + + +def _format_pipe(fd): + if fd == subprocess.PIPE: + return '' + elif fd == subprocess.STDOUT: + return '' + else: + return repr(fd) + + +class _StopError(BaseException): + """Raised to stop the event loop.""" + + +def _check_resolved_address(sock, address): + # Ensure that the address is already resolved to avoid the trap of hanging + # the entire event loop when the address requires doing a DNS lookup. + family = sock.family + if family == socket.AF_INET: + host, port = address + elif family == socket.AF_INET6: + host, port = address[:2] + else: + return + + type_mask = 0 + if hasattr(socket, 'SOCK_NONBLOCK'): + type_mask |= socket.SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + type_mask |= socket.SOCK_CLOEXEC + # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is + # already resolved. + try: + socket.getaddrinfo(host, port, + family=family, + type=(sock.type & ~type_mask), + proto=sock.proto, + flags=socket.AI_NUMERICHOST) + except socket.gaierror as err: + raise ValueError("address must be resolved (IP address), got %r: %s" + % (address, err)) + +def _raise_stop_error(*args): + raise _StopError + + +def _run_until_complete_cb(fut): + exc = fut._exception + if (isinstance(exc, BaseException) + and not isinstance(exc, Exception)): + # Issue #22429: run_forever() already finished, no need to + # stop it. + return + _raise_stop_error() + + +class Server(events.AbstractServer): + + def __init__(self, loop, sockets): + self._loop = loop + self.sockets = sockets + self._active_count = 0 + self._waiters = [] + + def __repr__(self): + return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + + def _attach(self): + assert self.sockets is not None + self._active_count += 1 + + def _detach(self): + assert self._active_count > 0 + self._active_count -= 1 + if self._active_count == 0 and self.sockets is None: + self._wakeup() + + def close(self): + sockets = self.sockets + if sockets is None: + return + self.sockets = None + for sock in sockets: + self._loop._stop_serving(sock) + if self._active_count == 0: + self._wakeup() + + def _wakeup(self): + waiters = self._waiters + self._waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + @coroutine + def wait_closed(self): + if self.sockets is None or self._waiters is None: + return + waiter = futures.Future(loop=self._loop) + self._waiters.append(waiter) + yield from waiter + + +class BaseEventLoop(events.AbstractEventLoop): + + def __init__(self): + self._timer_cancelled_count = 0 + self._closed = False + self._ready = collections.deque() + self._scheduled = [] + self._default_executor = None + self._internal_fds = 0 + # Identifier of the thread running the event loop, or None if the + # event loop is not running + self._owner = None + self._clock_resolution = time.get_clock_info('monotonic').resolution + self._exception_handler = None + self._debug = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + # In debug mode, if the execution of a callback or a step of a task + # exceed this duration in seconds, the slow callback/task is logged. + self.slow_callback_duration = 0.1 + self._current_handle = None + + def __repr__(self): + return ('<%s running=%s closed=%s debug=%s>' + % (self.__class__.__name__, self.is_running(), + self.is_closed(), self.get_debug())) + + def create_task(self, coro): + """Schedule a coroutine object. + + Return a task object. + """ + self._check_closed() + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + return task + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + """Create socket transport.""" + raise NotImplementedError + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + """Create SSL transport.""" + raise NotImplementedError + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + """Create datagram transport.""" + raise NotImplementedError + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create read pipe transport.""" + raise NotImplementedError + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + """Create write pipe transport.""" + raise NotImplementedError + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + """Create subprocess transport.""" + raise NotImplementedError + + def _write_to_self(self): + """Write a byte to self-pipe, to wake up the event loop. + + This may be called from a different thread. + + The subclass is responsible for implementing the self-pipe. + """ + raise NotImplementedError + + def _process_events(self, event_list): + """Process selector events.""" + raise NotImplementedError + + def _check_closed(self): + if self._closed: + raise RuntimeError('Event loop is closed') + + def run_forever(self): + """Run until stop() is called.""" + self._check_closed() + if self.is_running(): + raise RuntimeError('Event loop is running.') + self._owner = threading.get_ident() + try: + while True: + try: + self._run_once() + except _StopError: + break + finally: + self._owner = None + + def run_until_complete(self, future): + """Run until the Future is done. + + If the argument is a coroutine, it is wrapped in a Task. + + WARNING: It would be disastrous to call run_until_complete() + with the same coroutine twice -- it would wrap it in two + different Tasks and that can't be good. + + Return the Future's result, or raise its exception. + """ + self._check_closed() + + new_task = not isinstance(future, futures.Future) + future = tasks.async(future, loop=self) + if new_task: + # An exception is raised if the future didn't complete, so there + # is no need to log the "destroy pending task" message + future._log_destroy_pending = False + + future.add_done_callback(_run_until_complete_cb) + try: + self.run_forever() + except: + if new_task and future.done() and not future.cancelled(): + # The coroutine raised a BaseException. Consume the exception + # to not log a warning, the caller doesn't have access to the + # local task. + future.exception() + raise + future.remove_done_callback(_run_until_complete_cb) + if not future.done(): + raise RuntimeError('Event loop stopped before Future completed.') + + return future.result() + + def stop(self): + """Stop running the event loop. + + Every callback scheduled before stop() is called will run. Callbacks + scheduled after stop() is called will not run. However, those callbacks + will run if run_forever is called again later. + """ + self.call_soon(_raise_stop_error) + + def close(self): + """Close the event loop. + + This clears the queues and shuts down the executor, + but does not wait for the executor to finish. + + The event loop must not be running. + """ + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self._closed: + return + if self._debug: + logger.debug("Close %r", self) + self._closed = True + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) + + def is_closed(self): + """Returns True if the event loop was closed.""" + return self._closed + + def is_running(self): + """Returns True if the event loop is running.""" + return (self._owner is not None) + + def time(self): + """Return the time according to the event loop's clock. + + This is a float expressed in seconds since an epoch, but the + epoch, precision, accuracy and drift are unspecified and may + differ per event loop. + """ + return time.monotonic() + + def call_later(self, delay, callback, *args): + """Arrange for a callback to be called at a given time. + + Return a Handle: an opaque object with a cancel() method that + can be used to cancel the call. + + The delay can be an int or float, expressed in seconds. It is + always relative to the current time. + + Each callback will be called exactly once. If two callbacks + are scheduled for exactly the same time, it undefined which + will be called first. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + timer = self.call_at(self.time() + delay, callback, *args) + if timer._source_traceback: + del timer._source_traceback[-1] + return timer + + def call_at(self, when, callback, *args): + """Like call_later(), but uses an absolute time. + + Absolute time corresponds to the event loop's time() method. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_at()") + self._check_closed() + if self._debug: + self._check_thread() + timer = events.TimerHandle(when, callback, args, self) + if timer._source_traceback: + del timer._source_traceback[-1] + heapq.heappush(self._scheduled, timer) + timer._scheduled = True + return timer + + def call_soon(self, callback, *args): + """Arrange for a callback to be called as soon as possible. + + This operates as a FIFO queue: callbacks are called in the + order in which they are registered. Each callback will be + called exactly once. + + Any positional arguments after the callback will be passed to + the callback when it is called. + """ + if self._debug: + self._check_thread() + handle = self._call_soon(callback, args) + if handle._source_traceback: + del handle._source_traceback[-1] + return handle + + def _call_soon(self, callback, args): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with call_soon()") + self._check_closed() + handle = events.Handle(callback, args, self) + if handle._source_traceback: + del handle._source_traceback[-1] + self._ready.append(handle) + return handle + + def _check_thread(self): + """Check that the current thread is the thread running the event loop. + + Non-thread-safe methods of this class make this assumption and will + likely behave incorrectly when the assumption is violated. + + Should only be called when (self._debug == True). The caller is + responsible for checking this condition for performance reasons. + """ + if self._owner is None: + return + thread_id = threading.get_ident() + if thread_id != self._owner: + raise RuntimeError( + "Non-thread-safe operation invoked on an event loop other " + "than the current one") + + def call_soon_threadsafe(self, callback, *args): + """Like call_soon(), but thread-safe.""" + handle = self._call_soon(callback, args) + if handle._source_traceback: + del handle._source_traceback[-1] + self._write_to_self() + return handle + + def run_in_executor(self, executor, callback, *args): + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used with run_in_executor()") + self._check_closed() + if isinstance(callback, events.Handle): + assert not args + assert not isinstance(callback, events.TimerHandle) + if callback._cancelled: + f = futures.Future(loop=self) + f.set_result(None) + return f + callback, args = callback._callback, callback._args + if executor is None: + executor = self._default_executor + if executor is None: + executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + self._default_executor = executor + return futures.wrap_future(executor.submit(callback, *args), loop=self) + + def set_default_executor(self, executor): + self._default_executor = executor + + def _getaddrinfo_debug(self, host, port, family, type, proto, flags): + msg = ["%s:%r" % (host, port)] + if family: + msg.append('family=%r' % family) + if type: + msg.append('type=%r' % type) + if proto: + msg.append('proto=%r' % proto) + if flags: + msg.append('flags=%r' % flags) + msg = ', '.join(msg) + logger.debug('Get address info %s', msg) + + t0 = self.time() + addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) + dt = self.time() - t0 + + msg = ('Getting address info %s took %.3f ms: %r' + % (msg, dt * 1e3, addrinfo)) + if dt >= self.slow_callback_duration: + logger.info(msg) + else: + logger.debug(msg) + return addrinfo + + def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): + if self._debug: + return self.run_in_executor(None, self._getaddrinfo_debug, + host, port, family, type, proto, flags) + else: + return self.run_in_executor(None, socket.getaddrinfo, + host, port, family, type, proto, flags) + + def getnameinfo(self, sockaddr, flags=0): + return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + + @coroutine + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + """Connect to a TCP server. + + Create a streaming transport connection to a given Internet host and + port: socket family AF_INET or socket.AF_INET6 depending on host (or + family if specified), socket type SOCK_STREAM. protocol_factory must be + a callable returning a protocol instance. + + This method is a coroutine which will try to establish the connection + in the background. When successful, the coroutine returns a + (transport, protocol) pair. + """ + if server_hostname is not None and not ssl: + raise ValueError('server_hostname is only meaningful with ssl') + + if server_hostname is None and ssl: + # Use host as default for server_hostname. It is an error + # if host is empty or not set, e.g. when an + # already-connected socket was passed or when only a port + # is given. To avoid this error, you can pass + # server_hostname='' -- this will bypass the hostname + # check. (This also means that if host is a numeric + # IP/IPv6 address, we will attempt to verify that exact + # address; this will probably fail, but it is possible to + # create a certificate for a specific IP address, so we + # don't judge it here.) + if not host: + raise ValueError('You must set server_hostname ' + 'when using ssl without a host') + server_hostname = host + + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + f1 = self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs = [f1] + if local_addr is not None: + f2 = self.getaddrinfo( + *local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags) + fs.append(f2) + else: + f2 = None + + yield from tasks.wait(fs, loop=self) + + infos = f1.result() + if not infos: + raise OSError('getaddrinfo() returned empty list') + if f2 is not None: + laddr_infos = f2.result() + if not laddr_infos: + raise OSError('getaddrinfo() returned empty list') + + exceptions = [] + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + if f2 is not None: + for _, _, _, _, laddr in laddr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + exc = OSError( + exc.errno, 'error while ' + 'attempting to bind on address ' + '{!r}: {}'.format( + laddr, exc.strerror.lower())) + exceptions.append(exc) + else: + sock.close() + sock = None + continue + if self._debug: + logger.debug("connect %r to %r", sock, address) + yield from self.sock_connect(sock, address) + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + if len(exceptions) == 1: + raise exceptions[0] + else: + # If they all have the same str(), raise one. + model = str(exceptions[0]) + if all(str(exc) == model for exc in exceptions): + raise exceptions[0] + # Raise a combined exception so the user can see all + # the various error messages. + raise OSError('Multiple exceptions: {}'.format( + ', '.join(str(exc) for exc in exceptions))) + + elif sock is None: + raise ValueError( + 'host and port was not specified and no sock specified') + + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + if self._debug: + # Get the socket from the transport because SSL transport closes + # the old socket and creates a new SSL socket + sock = transport.get_extra_info('socket') + logger.debug("%r connected to %s:%r: (%r, %r)", + sock, host, port, transport, protocol) + return transport, protocol + + @coroutine + def _create_connection_transport(self, sock, protocol_factory, ssl, + server_hostname): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + if ssl: + sslcontext = None if isinstance(ssl, bool) else ssl + transport = self._make_ssl_transport( + sock, protocol, sslcontext, waiter, + server_side=False, server_hostname=server_hostname) + else: + transport = self._make_socket_transport(sock, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + return transport, protocol + + @coroutine + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + """Create datagram connection.""" + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join address by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + raise exceptions[0] + + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_datagram_transport(sock, protocol, r_addr, + waiter) + if self._debug: + if local_addr: + logger.info("Datagram endpoint local_addr=%r remote_addr=%r " + "created: (%r, %r)", + local_addr, remote_addr, transport, protocol) + else: + logger.debug("Datagram endpoint remote_addr=%r created: " + "(%r, %r)", + remote_addr, transport, protocol) + + try: + yield from waiter + except: + transport.close() + raise + + return transport, protocol + + @coroutine + def create_server(self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None): + """Create a TCP server bound to host and port. + + Return a Server object which can be used to stop the service. + + This method is a coroutine. + """ + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + if host is not None or port is not None: + if sock is not None: + raise ValueError( + 'host/port and sock can not be specified at the same time') + + AF_INET6 = getattr(socket, 'AF_INET6', 0) + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + sockets = [] + if host == '': + host = None + + infos = yield from self.getaddrinfo( + host, port, family=family, + type=socket.SOCK_STREAM, proto=0, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + completed = False + try: + for res in infos: + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + except socket.error: + # Assume it's a bad family/type/protocol combination. + if self._debug: + logger.warning('create_server() failed to create ' + 'socket.socket(%r, %r, %r)', + af, socktype, proto, exc_info=True) + continue + sockets.append(sock) + if reuse_address: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + # Disable IPv4/IPv6 dual stack support (enabled by + # default on Linux) which makes a single socket + # listen on both address families. + if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + sock.setsockopt(socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + True) + try: + sock.bind(sa) + except OSError as err: + raise OSError(err.errno, 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) + completed = True + finally: + if not completed: + for sock in sockets: + sock.close() + else: + if sock is None: + raise ValueError('Neither host/port nor sock were specified') + sockets = [sock] + + server = Server(self, sockets) + for sock in sockets: + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + if self._debug: + logger.info("%r is serving", server) + return server + + @coroutine + def connect_read_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_read_pipe_transport(pipe, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + if self._debug: + logger.debug('Read pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + @coroutine + def connect_write_pipe(self, protocol_factory, pipe): + protocol = protocol_factory() + waiter = futures.Future(loop=self) + transport = self._make_write_pipe_transport(pipe, protocol, waiter) + + try: + yield from waiter + except: + transport.close() + raise + + if self._debug: + logger.debug('Write pipe %r connected: (%r, %r)', + pipe.fileno(), transport, protocol) + return transport, protocol + + def _log_subprocess(self, msg, stdin, stdout, stderr): + info = [msg] + if stdin is not None: + info.append('stdin=%s' % _format_pipe(stdin)) + if stdout is not None and stderr == subprocess.STDOUT: + info.append('stdout=stderr=%s' % _format_pipe(stdout)) + else: + if stdout is not None: + info.append('stdout=%s' % _format_pipe(stdout)) + if stderr is not None: + info.append('stderr=%s' % _format_pipe(stderr)) + logger.debug(' '.join(info)) + + @coroutine + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + universal_newlines=False, shell=True, bufsize=0, + **kwargs): + if not isinstance(cmd, (bytes, str)): + raise ValueError("cmd must be a string") + if universal_newlines: + raise ValueError("universal_newlines must be False") + if not shell: + raise ValueError("shell must be True") + if bufsize != 0: + raise ValueError("bufsize must be 0") + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'run shell command %r' % cmd + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + @coroutine + def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, **kwargs): + if universal_newlines: + raise ValueError("universal_newlines must be False") + if shell: + raise ValueError("shell must be False") + if bufsize != 0: + raise ValueError("bufsize must be 0") + popen_args = (program,) + args + for arg in popen_args: + if not isinstance(arg, (str, bytes)): + raise TypeError("program arguments must be " + "a bytes or text string, not %s" + % type(arg).__name__) + protocol = protocol_factory() + if self._debug: + # don't log parameters: they may contain sensitive information + # (password) and may be too long + debug_log = 'execute program %r' % program + self._log_subprocess(debug_log, stdin, stdout, stderr) + transport = yield from self._make_subprocess_transport( + protocol, popen_args, False, stdin, stdout, stderr, + bufsize, **kwargs) + if self._debug: + logger.info('%s: %r' % (debug_log, transport)) + return transport, protocol + + def set_exception_handler(self, handler): + """Set handler as the new event loop exception handler. + + If handler is None, the default exception handler will + be set. + + If handler is a callable object, it should have a + signature matching '(loop, context)', where 'loop' + will be a reference to the active event loop, 'context' + will be a dict object (see `call_exception_handler()` + documentation for details about context). + """ + if handler is not None and not callable(handler): + raise TypeError('A callable object or None is expected, ' + 'got {!r}'.format(handler)) + self._exception_handler = handler + + def default_exception_handler(self, context): + """Default exception handler. + + This is called when an exception occurs and no exception + handler is set, and can be called by a custom exception + handler that wants to defer to the default behavior. + + The context parameter has the same meaning as in + `call_exception_handler()`. + """ + message = context.get('message') + if not message: + message = 'Unhandled exception in event loop' + + exception = context.get('exception') + if exception is not None: + exc_info = (type(exception), exception, exception.__traceback__) + else: + exc_info = False + + if (self._current_handle is not None + and self._current_handle._source_traceback): + context['handle_traceback'] = self._current_handle._source_traceback + + log_lines = [message] + for key in sorted(context): + if key in {'message', 'exception'}: + continue + value = context[key] + if key == 'source_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Object created at (most recent call last):\n' + value += tb.rstrip() + elif key == 'handle_traceback': + tb = ''.join(traceback.format_list(value)) + value = 'Handle created at (most recent call last):\n' + value += tb.rstrip() + else: + value = repr(value) + log_lines.append('{}: {}'.format(key, value)) + + logger.error('\n'.join(log_lines), exc_info=exc_info) + + def call_exception_handler(self, context): + """Call the current event loop's exception handler. + + The context argument is a dict containing the following keys: + + - 'message': Error message; + - 'exception' (optional): Exception object; + - 'future' (optional): Future instance; + - 'handle' (optional): Handle instance; + - 'protocol' (optional): Protocol instance; + - 'transport' (optional): Transport instance; + - 'socket' (optional): Socket instance. + + New keys maybe introduced in the future. + + Note: do not overload this method in an event loop subclass. + For custom exception handling, use the + `set_exception_handler()` method. + """ + if self._exception_handler is None: + try: + self.default_exception_handler(context) + except Exception: + # Second protection layer for unexpected errors + # in the default implementation, as well as for subclassed + # event loops with overloaded "default_exception_handler". + logger.error('Exception in default exception handler', + exc_info=True) + else: + try: + self._exception_handler(self, context) + except Exception as exc: + # Exception in the user set custom exception handler. + try: + # Let's try default handler. + self.default_exception_handler({ + 'message': 'Unhandled error in exception handler', + 'exception': exc, + 'context': context, + }) + except Exception: + # Guard 'default_exception_handler' in case it is + # overloaded. + logger.error('Exception in default exception handler ' + 'while handling an unexpected error ' + 'in custom exception handler', + exc_info=True) + + def _add_callback(self, handle): + """Add a Handle to _scheduled (TimerHandle) or _ready.""" + assert isinstance(handle, events.Handle), 'A Handle is required here' + if handle._cancelled: + return + assert not isinstance(handle, events.TimerHandle) + self._ready.append(handle) + + def _add_callback_signalsafe(self, handle): + """Like _add_callback() but called from a signal handler.""" + self._add_callback(handle) + self._write_to_self() + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + if handle._scheduled: + self._timer_cancelled_count += 1 + + def _run_once(self): + """Run one full iteration of the event loop. + + This calls all currently ready callbacks, polls for I/O, + schedules the resulting callbacks, and finally schedules + 'call_later' callbacks. + """ + + sched_count = len(self._scheduled) + if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and + self._timer_cancelled_count / sched_count > + _MIN_CANCELLED_TIMER_HANDLES_FRACTION): + # Remove delayed calls that were cancelled if their number + # is too high + new_scheduled = [] + for handle in self._scheduled: + if handle._cancelled: + handle._scheduled = False + else: + new_scheduled.append(handle) + + heapq.heapify(new_scheduled) + self._scheduled = new_scheduled + self._timer_cancelled_count = 0 + else: + # Remove delayed calls that were cancelled from head of queue. + while self._scheduled and self._scheduled[0]._cancelled: + self._timer_cancelled_count -= 1 + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + + timeout = None + if self._ready: + timeout = 0 + elif self._scheduled: + # Compute the desired timeout. + when = self._scheduled[0]._when + timeout = max(0, when - self.time()) + + if self._debug and timeout != 0: + t0 = self.time() + event_list = self._selector.select(timeout) + dt = self.time() - t0 + if dt >= 1.0: + level = logging.INFO + else: + level = logging.DEBUG + nevent = len(event_list) + if timeout is None: + logger.log(level, 'poll took %.3f ms: %s events', + dt * 1e3, nevent) + elif nevent: + logger.log(level, + 'poll %.3f ms took %.3f ms: %s events', + timeout * 1e3, dt * 1e3, nevent) + elif dt >= 1.0: + logger.log(level, + 'poll %.3f ms took %.3f ms: timeout', + timeout * 1e3, dt * 1e3) + else: + event_list = self._selector.select(timeout) + self._process_events(event_list) + + # Handle 'later' callbacks that are ready. + end_time = self.time() + self._clock_resolution + while self._scheduled: + handle = self._scheduled[0] + if handle._when >= end_time: + break + handle = heapq.heappop(self._scheduled) + handle._scheduled = False + self._ready.append(handle) + + # This is the only place where callbacks are actually *called*. + # All other places just add them to ready. + # Note: We run all currently scheduled callbacks, but not any + # callbacks scheduled by callbacks run this time around -- + # they will be run the next time (after another I/O poll). + # Use an idiom that is thread-safe without using locks. + ntodo = len(self._ready) + for i in range(ntodo): + handle = self._ready.popleft() + if handle._cancelled: + continue + if self._debug: + try: + self._current_handle = handle + t0 = self.time() + handle._run() + dt = self.time() - t0 + if dt >= self.slow_callback_duration: + logger.warning('Executing %s took %.3f seconds', + _format_handle(handle), dt) + finally: + self._current_handle = None + else: + handle._run() + handle = None # Needed to break cycles when an exception occurs. + + def get_debug(self): + return self._debug + + def set_debug(self, enabled): + self._debug = enabled diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py new file mode 100644 index 00000000..f5e7dfec --- /dev/null +++ b/asyncio/base_subprocess.py @@ -0,0 +1,229 @@ +import collections +import subprocess + +from . import protocols +from . import transports +from .coroutines import coroutine +from .log import logger + + +class BaseSubprocessTransport(transports.SubprocessTransport): + + def __init__(self, loop, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + super().__init__(extra) + self._protocol = protocol + self._loop = loop + self._pid = None + + self._pipes = {} + if stdin == subprocess.PIPE: + self._pipes[0] = None + if stdout == subprocess.PIPE: + self._pipes[1] = None + if stderr == subprocess.PIPE: + self._pipes[2] = None + self._pending_calls = collections.deque() + self._finished = False + self._returncode = None + self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, + stderr=stderr, bufsize=bufsize, **kwargs) + self._pid = self._proc.pid + self._extra['subprocess'] = self._proc + if self._loop.get_debug(): + if isinstance(args, (bytes, str)): + program = args + else: + program = args[0] + logger.debug('process %r created: pid %s', + program, self._pid) + + def __repr__(self): + info = [self.__class__.__name__, 'pid=%s' % self._pid] + if self._returncode is not None: + info.append('returncode=%s' % self._returncode) + + stdin = self._pipes.get(0) + if stdin is not None: + info.append('stdin=%s' % stdin.pipe) + + stdout = self._pipes.get(1) + stderr = self._pipes.get(2) + if stdout is not None and stderr is stdout: + info.append('stdout=stderr=%s' % stdout.pipe) + else: + if stdout is not None: + info.append('stdout=%s' % stdout.pipe) + if stderr is not None: + info.append('stderr=%s' % stderr.pipe) + + return '<%s>' % ' '.join(info) + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + raise NotImplementedError + + def _make_write_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def _make_read_subprocess_pipe_proto(self, fd): + raise NotImplementedError + + def close(self): + for proto in self._pipes.values(): + if proto is None: + continue + proto.pipe.close() + if self._returncode is None: + self.terminate() + + def get_pid(self): + return self._pid + + def get_returncode(self): + return self._returncode + + def get_pipe_transport(self, fd): + if fd in self._pipes: + return self._pipes[fd].pipe + else: + return None + + def send_signal(self, signal): + self._proc.send_signal(signal) + + def terminate(self): + self._proc.terminate() + + def kill(self): + self._proc.kill() + + def _kill_wait(self): + """Close pipes, kill the subprocess and read its return status. + + Function called when an exception is raised during the creation + of a subprocess. + """ + if self._loop.get_debug(): + logger.warning('Exception during subprocess creation, ' + 'kill the subprocess %r', + self, + exc_info=True) + + proc = self._proc + if proc.stdout: + proc.stdout.close() + if proc.stderr: + proc.stderr.close() + if proc.stdin: + proc.stdin.close() + try: + proc.kill() + except ProcessLookupError: + pass + self._returncode = proc.wait() + + @coroutine + def _post_init(self): + try: + proc = self._proc + loop = self._loop + if proc.stdin is not None: + _, pipe = yield from loop.connect_write_pipe( + lambda: WriteSubprocessPipeProto(self, 0), + proc.stdin) + self._pipes[0] = pipe + if proc.stdout is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 1), + proc.stdout) + self._pipes[1] = pipe + if proc.stderr is not None: + _, pipe = yield from loop.connect_read_pipe( + lambda: ReadSubprocessPipeProto(self, 2), + proc.stderr) + self._pipes[2] = pipe + + assert self._pending_calls is not None + + self._loop.call_soon(self._protocol.connection_made, self) + for callback, data in self._pending_calls: + self._loop.call_soon(callback, *data) + self._pending_calls = None + except: + self._kill_wait() + raise + + def _call(self, cb, *data): + if self._pending_calls is not None: + self._pending_calls.append((cb, data)) + else: + self._loop.call_soon(cb, *data) + + def _pipe_connection_lost(self, fd, exc): + self._call(self._protocol.pipe_connection_lost, fd, exc) + self._try_finish() + + def _pipe_data_received(self, fd, data): + self._call(self._protocol.pipe_data_received, fd, data) + + def _process_exited(self, returncode): + assert returncode is not None, returncode + assert self._returncode is None, self._returncode + if self._loop.get_debug(): + logger.info('%r exited with return code %r', + self, returncode) + self._returncode = returncode + self._call(self._protocol.process_exited) + self._try_finish() + + def _try_finish(self): + assert not self._finished + if self._returncode is None: + return + if all(p is not None and p.disconnected + for p in self._pipes.values()): + self._finished = True + self._call(self._call_connection_lost, None) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._proc = None + self._protocol = None + self._loop = None + + +class WriteSubprocessPipeProto(protocols.BaseProtocol): + + def __init__(self, proc, fd): + self.proc = proc + self.fd = fd + self.pipe = None + self.disconnected = False + + def connection_made(self, transport): + self.pipe = transport + + def __repr__(self): + return ('<%s fd=%s pipe=%r>' + % (self.__class__.__name__, self.fd, self.pipe)) + + def connection_lost(self, exc): + self.disconnected = True + self.proc._pipe_connection_lost(self.fd, exc) + self.proc = None + + def pause_writing(self): + self.proc._protocol.pause_writing() + + def resume_writing(self): + self.proc._protocol.resume_writing() + + +class ReadSubprocessPipeProto(WriteSubprocessPipeProto, + protocols.Protocol): + + def data_received(self, data): + self.proc._pipe_data_received(self.fd, data) diff --git a/asyncio/constants.py b/asyncio/constants.py new file mode 100644 index 00000000..f9e12328 --- /dev/null +++ b/asyncio/constants.py @@ -0,0 +1,7 @@ +"""Constants.""" + +# After the connection is lost, log warnings after this many write()s. +LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 + +# Seconds to wait before retrying accept(). +ACCEPT_RETRY_DELAY = 1 diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py new file mode 100644 index 00000000..a1b28751 --- /dev/null +++ b/asyncio/coroutines.py @@ -0,0 +1,199 @@ +__all__ = ['coroutine', + 'iscoroutinefunction', 'iscoroutine'] + +import functools +import inspect +import opcode +import os +import sys +import traceback +import types + +from . import events +from . import futures +from .log import logger + + +# Opcode of "yield from" instruction +_YIELD_FROM = opcode.opmap['YIELD_FROM'] + +# If you set _DEBUG to true, @coroutine will wrap the resulting +# generator objects in a CoroWrapper instance (defined below). That +# instance will log a message when the generator is never iterated +# over, which may happen when you forget to use "yield from" with a +# coroutine call. Note that the value of the _DEBUG flag is taken +# when the decorator is used, so to be of any use it must be set +# before you define your coroutines. A downside of using this feature +# is that tracebacks show entries for the CoroWrapper.__next__ method +# when _DEBUG is true. +_DEBUG = (not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + + +# Check for CPython issue #21209 +def has_yield_from_bug(): + class MyGen: + def __init__(self): + self.send_args = None + def __iter__(self): + return self + def __next__(self): + return 42 + def send(self, *what): + self.send_args = what + return None + def yield_from_gen(gen): + yield from gen + value = (1, 2, 3) + gen = MyGen() + coro = yield_from_gen(gen) + next(coro) + coro.send(value) + return gen.send_args != (value,) +_YIELD_FROM_BUG = has_yield_from_bug() +del has_yield_from_bug + + +class CoroWrapper: + # Wrapper for coroutine object in _DEBUG mode. + + def __init__(self, gen, func): + assert inspect.isgenerator(gen), gen + self.gen = gen + self.func = func + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + # __name__, __qualname__, __doc__ attributes are set by the coroutine() + # decorator + + def __repr__(self): + coro_repr = _format_coroutine(self) + if self._source_traceback: + frame = self._source_traceback[-1] + coro_repr += ', created at %s:%s' % (frame[0], frame[1]) + return '<%s %s>' % (self.__class__.__name__, coro_repr) + + def __iter__(self): + return self + + def __next__(self): + return next(self.gen) + + if _YIELD_FROM_BUG: + # For for CPython issue #21209: using "yield from" and a custom + # generator, generator.send(tuple) unpacks the tuple instead of passing + # the tuple unchanged. Check if the caller is a generator using "yield + # from" to decide if the parameter should be unpacked or not. + def send(self, *value): + frame = sys._getframe() + caller = frame.f_back + assert caller.f_lasti >= 0 + if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: + value = value[0] + return self.gen.send(value) + else: + def send(self, value): + return self.gen.send(value) + + def throw(self, exc): + return self.gen.throw(exc) + + def close(self): + return self.gen.close() + + @property + def gi_frame(self): + return self.gen.gi_frame + + @property + def gi_running(self): + return self.gen.gi_running + + @property + def gi_code(self): + return self.gen.gi_code + + def __del__(self): + # Be careful accessing self.gen.frame -- self.gen might not exist. + gen = getattr(self, 'gen', None) + frame = getattr(gen, 'gi_frame', None) + if frame is not None and frame.f_lasti == -1: + msg = '%r was never yielded from' % self + tb = getattr(self, '_source_traceback', ()) + if tb: + tb = ''.join(traceback.format_list(tb)) + msg += ('\nCoroutine object created at ' + '(most recent call last):\n') + msg += tb.rstrip() + logger.error(msg) + + +def coroutine(func): + """Decorator to mark coroutines. + + If the coroutine is not yielded from before it is destroyed, + an error message is logged. + """ + if inspect.isgeneratorfunction(func): + coro = func + else: + @functools.wraps(func) + def coro(*args, **kw): + res = func(*args, **kw) + if isinstance(res, futures.Future) or inspect.isgenerator(res): + res = yield from res + return res + + if not _DEBUG: + wrapper = coro + else: + @functools.wraps(func) + def wrapper(*args, **kwds): + w = CoroWrapper(coro(*args, **kwds), func) + if w._source_traceback: + del w._source_traceback[-1] + w.__name__ = func.__name__ + if hasattr(func, '__qualname__'): + w.__qualname__ = func.__qualname__ + w.__doc__ = func.__doc__ + return w + + wrapper._is_coroutine = True # For iscoroutinefunction(). + return wrapper + + +def iscoroutinefunction(func): + """Return True if func is a decorated coroutine function.""" + return getattr(func, '_is_coroutine', False) + + +_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) + +def iscoroutine(obj): + """Return True if obj is a coroutine object.""" + return isinstance(obj, _COROUTINE_TYPES) + + +def _format_coroutine(coro): + assert iscoroutine(coro) + coro_name = getattr(coro, '__qualname__', coro.__name__) + + filename = coro.gi_code.co_filename + if (isinstance(coro, CoroWrapper) + and not inspect.isgeneratorfunction(coro.func)): + filename, lineno = events._get_function_source(coro.func) + if coro.gi_frame is None: + coro_repr = ('%s() done, defined at %s:%s' + % (coro_name, filename, lineno)) + else: + coro_repr = ('%s() running, defined at %s:%s' + % (coro_name, filename, lineno)) + elif coro.gi_frame is not None: + lineno = coro.gi_frame.f_lineno + coro_repr = ('%s() running at %s:%s' + % (coro_name, filename, lineno)) + else: + lineno = coro.gi_code.co_firstlineno + coro_repr = ('%s() done, defined at %s:%s' + % (coro_name, filename, lineno)) + + return coro_repr diff --git a/asyncio/events.py b/asyncio/events.py new file mode 100644 index 00000000..8a7bb814 --- /dev/null +++ b/asyncio/events.py @@ -0,0 +1,597 @@ +"""Event loop and event loop policy.""" + +__all__ = ['AbstractEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + 'get_child_watcher', 'set_child_watcher', + ] + +import functools +import inspect +import reprlib +import socket +import subprocess +import sys +import threading +import traceback + + +_PY34 = sys.version_info >= (3, 4) + + +def _get_function_source(func): + if _PY34: + func = inspect.unwrap(func) + elif hasattr(func, '__wrapped__'): + func = func.__wrapped__ + if inspect.isfunction(func): + code = func.__code__ + return (code.co_filename, code.co_firstlineno) + if isinstance(func, functools.partial): + return _get_function_source(func.func) + if _PY34 and isinstance(func, functools.partialmethod): + return _get_function_source(func.func) + return None + + +def _format_args(args): + """Format function arguments. + + Special case for a single parameter: ('hello',) is formatted as ('hello'). + """ + # use reprlib to limit the length of the output + args_repr = reprlib.repr(args) + if len(args) == 1 and args_repr.endswith(',)'): + args_repr = args_repr[:-2] + ')' + return args_repr + + +def _format_callback(func, args, suffix=''): + if isinstance(func, functools.partial): + if args is not None: + suffix = _format_args(args) + suffix + return _format_callback(func.func, func.args, suffix) + + func_repr = getattr(func, '__qualname__', None) + if not func_repr: + func_repr = repr(func) + + if args is not None: + func_repr += _format_args(args) + if suffix: + func_repr += suffix + + source = _get_function_source(func) + if source: + func_repr += ' at %s:%s' % source + return func_repr + + +class Handle: + """Object returned by callback registration methods.""" + + __slots__ = ('_callback', '_args', '_cancelled', '_loop', + '_source_traceback', '_repr', '__weakref__') + + def __init__(self, callback, args, loop): + assert not isinstance(callback, Handle), 'A Handle is not a callback' + self._loop = loop + self._callback = callback + self._args = args + self._cancelled = False + self._repr = None + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + else: + self._source_traceback = None + + def _repr_info(self): + info = [self.__class__.__name__] + if self._cancelled: + info.append('cancelled') + if self._callback is not None: + info.append(_format_callback(self._callback, self._args)) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + if self._repr is not None: + return self._repr + info = self._repr_info() + return '<%s>' % ' '.join(info) + + def cancel(self): + if not self._cancelled: + self._cancelled = True + if self._loop.get_debug(): + # Keep a representation in debug mode to keep callback and + # parameters. For example, to log the warning + # "Executing took 2.5 second" + self._repr = repr(self) + self._callback = None + self._args = None + + def _run(self): + try: + self._callback(*self._args) + except Exception as exc: + cb = _format_callback(self._callback, self._args) + msg = 'Exception in callback {}'.format(cb) + context = { + 'message': msg, + 'exception': exc, + 'handle': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self = None # Needed to break cycles when an exception occurs. + + +class TimerHandle(Handle): + """Object returned by timed callback registration methods.""" + + __slots__ = ['_scheduled', '_when'] + + def __init__(self, when, callback, args, loop): + assert when is not None + super().__init__(callback, args, loop) + if self._source_traceback: + del self._source_traceback[-1] + self._when = when + self._scheduled = False + + def _repr_info(self): + info = super()._repr_info() + pos = 2 if self._cancelled else 1 + info.insert(pos, 'when=%s' % self._when) + return info + + def __hash__(self): + return hash(self._when) + + def __lt__(self, other): + return self._when < other._when + + def __le__(self, other): + if self._when < other._when: + return True + return self.__eq__(other) + + def __gt__(self, other): + return self._when > other._when + + def __ge__(self, other): + if self._when > other._when: + return True + return self.__eq__(other) + + def __eq__(self, other): + if isinstance(other, TimerHandle): + return (self._when == other._when and + self._callback == other._callback and + self._args == other._args and + self._cancelled == other._cancelled) + return NotImplemented + + def __ne__(self, other): + equal = self.__eq__(other) + return NotImplemented if equal is NotImplemented else not equal + + def cancel(self): + if not self._cancelled: + self._loop._timer_handle_cancelled(self) + super().cancel() + + +class AbstractServer: + """Abstract server returned by create_server().""" + + def close(self): + """Stop serving. This leaves existing connections open.""" + return NotImplemented + + def wait_closed(self): + """Coroutine to wait until service is closed.""" + return NotImplemented + + +class AbstractEventLoop: + """Abstract event loop.""" + + # Running and stopping the event loop. + + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError + + def run_until_complete(self, future): + """Run the event loop until a Future is done. + + Return the Future's result, or raise its exception. + """ + raise NotImplementedError + + def stop(self): + """Stop the event loop as soon as reasonable. + + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError + + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError + + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError + + def close(self): + """Close the loop. + + The loop should not be running. + + This is idempotent and irreversible. + + No other methods should be called after this one. + """ + raise NotImplementedError + + # Methods scheduling callbacks. All these return Handles. + + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError + + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) + + def call_later(self, delay, callback, *args): + raise NotImplementedError + + def call_at(self, when, callback, *args): + raise NotImplementedError + + def time(self): + raise NotImplementedError + + # Method scheduling a coroutine object: create a task. + + def create_task(self, coro): + raise NotImplementedError + + # Methods for interacting with threads. + + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError + + def run_in_executor(self, executor, callback, *args): + raise NotImplementedError + + def set_default_executor(self, executor): + raise NotImplementedError + + # Network I/O methods returning Futures. + + def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + raise NotImplementedError + + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError + + def create_connection(self, protocol_factory, host=None, port=None, *, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + raise NotImplementedError + + def create_server(self, protocol_factory, host=None, port=None, *, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. + + The return value is a Server object which can be used to stop + the service. + + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). + + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). + + flags is a bitmask for getaddrinfo(). + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError + + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError + + def create_unix_server(self, protocol_factory, path, *, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. + + The return value is a Server object, which can be used to stop + the service. + + path is a str, representing a file systsem path to bind the + server socket to. + + sock can optionally be specified in order to use a preexisting + socket object. + + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). + + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError + + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0): + raise NotImplementedError + + # Pipes and subprocesses. + + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in event loop. Set the pipe to non-blocking mode. + + protocol_factory should instantiate object with Protocol interface. + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the + ReadTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in event loop. + + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError + + def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError + + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. + + def add_reader(self, fd, callback, *args): + raise NotImplementedError + + def remove_reader(self, fd): + raise NotImplementedError + + def add_writer(self, fd, callback, *args): + raise NotImplementedError + + def remove_writer(self, fd): + raise NotImplementedError + + # Completion based I/O methods returning Futures. + + def sock_recv(self, sock, nbytes): + raise NotImplementedError + + def sock_sendall(self, sock, data): + raise NotImplementedError + + def sock_connect(self, sock, address): + raise NotImplementedError + + def sock_accept(self, sock): + raise NotImplementedError + + # Signal handling. + + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError + + def remove_signal_handler(self, sig): + raise NotImplementedError + + # Error handlers. + + def set_exception_handler(self, handler): + raise NotImplementedError + + def default_exception_handler(self, context): + raise NotImplementedError + + def call_exception_handler(self, context): + raise NotImplementedError + + # Debug flag management. + + def get_debug(self): + raise NotImplementedError + + def set_debug(self, enabled): + raise NotImplementedError + + +class AbstractEventLoopPolicy: + """Abstract policy for accessing the event loop.""" + + def get_event_loop(self): + """Get the event loop for the current context. + + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. + + It should never return None.""" + raise NotImplementedError + + def set_event_loop(self, loop): + """Set the event loop for the current context to loop.""" + raise NotImplementedError + + def new_event_loop(self): + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" + raise NotImplementedError + + # Child processes handling (Unix only). + + def get_child_watcher(self): + "Get the watcher for child processes." + raise NotImplementedError + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + raise NotImplementedError + + +class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): + """Default policy implementation for accessing the event loop. + + In this policy, each thread has its own event loop. However, we + only automatically create an event loop by default for the main + thread; other threads by default have no event loop. + + Other policies may have different rules (e.g. a single global + event loop, or automatically creating an event loop per thread, or + using some other notion of context to which an event loop is + associated). + """ + + _loop_factory = None + + class _Local(threading.local): + _loop = None + _set_called = False + + def __init__(self): + self._local = self._Local() + + def get_event_loop(self): + """Get the event loop. + + This may be None or an instance of EventLoop. + """ + if (self._local._loop is None and + not self._local._set_called and + isinstance(threading.current_thread(), threading._MainThread)): + self.set_event_loop(self.new_event_loop()) + if self._local._loop is None: + raise RuntimeError('There is no current event loop in thread %r.' + % threading.current_thread().name) + return self._local._loop + + def set_event_loop(self, loop): + """Set the event loop.""" + self._local._set_called = True + assert loop is None or isinstance(loop, AbstractEventLoop) + self._local._loop = loop + + def new_event_loop(self): + """Create a new event loop. + + You must call set_event_loop() to make this the current event + loop. + """ + return self._loop_factory() + + +# Event loop policy. The policy itself is always global, even if the +# policy's rules say that there is an event loop per thread (or other +# notion of context). The default policy is installed by the first +# call to get_event_loop_policy(). +_event_loop_policy = None + +# Lock for protecting the on-the-fly creation of the event loop policy. +_lock = threading.Lock() + + +def _init_event_loop_policy(): + global _event_loop_policy + with _lock: + if _event_loop_policy is None: # pragma: no branch + from . import DefaultEventLoopPolicy + _event_loop_policy = DefaultEventLoopPolicy() + + +def get_event_loop_policy(): + """Get the current event loop policy.""" + if _event_loop_policy is None: + _init_event_loop_policy() + return _event_loop_policy + + +def set_event_loop_policy(policy): + """Set the current event loop policy. + + If policy is None, the default policy is restored.""" + global _event_loop_policy + assert policy is None or isinstance(policy, AbstractEventLoopPolicy) + _event_loop_policy = policy + + +def get_event_loop(): + """Equivalent to calling get_event_loop_policy().get_event_loop().""" + return get_event_loop_policy().get_event_loop() + + +def set_event_loop(loop): + """Equivalent to calling get_event_loop_policy().set_event_loop(loop).""" + get_event_loop_policy().set_event_loop(loop) + + +def new_event_loop(): + """Equivalent to calling get_event_loop_policy().new_event_loop().""" + return get_event_loop_policy().new_event_loop() + + +def get_child_watcher(): + """Equivalent to calling get_event_loop_policy().get_child_watcher().""" + return get_event_loop_policy().get_child_watcher() + + +def set_child_watcher(watcher): + """Equivalent to calling + get_event_loop_policy().set_child_watcher(watcher).""" + return get_event_loop_policy().set_child_watcher(watcher) diff --git a/asyncio/futures.py b/asyncio/futures.py new file mode 100644 index 00000000..19212a94 --- /dev/null +++ b/asyncio/futures.py @@ -0,0 +1,409 @@ +"""A Future class similar to the one in PEP 3148.""" + +__all__ = ['CancelledError', 'TimeoutError', + 'InvalidStateError', + 'Future', 'wrap_future', + ] + +import concurrent.futures._base +import logging +import reprlib +import sys +import traceback + +from . import events + +# States for Future. +_PENDING = 'PENDING' +_CANCELLED = 'CANCELLED' +_FINISHED = 'FINISHED' + +_PY34 = sys.version_info >= (3, 4) + +Error = concurrent.futures._base.Error +CancelledError = concurrent.futures.CancelledError +TimeoutError = concurrent.futures.TimeoutError + +STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging + + +class InvalidStateError(Error): + """The operation is not allowed in this state.""" + + +class _TracebackLogger: + """Helper to log a traceback upon destruction if not cleared. + + This solves a nasty problem with Futures and Tasks that have an + exception set: if nobody asks for the exception, the exception is + never logged. This violates the Zen of Python: 'Errors should + never pass silently. Unless explicitly silenced.' + + However, we don't want to log the exception as soon as + set_exception() is called: if the calling code is written + properly, it will get the exception and handle it properly. But + we *do* want to log it if result() or exception() was never called + -- otherwise developers waste a lot of time wondering why their + buggy code fails silently. + + An earlier attempt added a __del__() method to the Future class + itself, but this backfired because the presence of __del__() + prevents garbage collection from breaking cycles. A way out of + this catch-22 is to avoid having a __del__() method on the Future + class itself, but instead to have a reference to a helper object + with a __del__() method that logs the traceback, where we ensure + that the helper object doesn't participate in cycles, and only the + Future has a reference to it. + + The helper object is added when set_exception() is called. When + the Future is collected, and the helper is present, the helper + object is also collected, and its __del__() method will log the + traceback. When the Future's result() or exception() method is + called (and a helper object is present), it removes the helper + object, after calling its clear() method to prevent it from + logging. + + One downside is that we do a fair amount of work to extract the + traceback from the exception, even when it is never logged. It + would seem cheaper to just store the exception object, but that + references the traceback, which references stack frames, which may + reference the Future, which references the _TracebackLogger, and + then the _TracebackLogger would be included in a cycle, which is + what we're trying to avoid! As an optimization, we don't + immediately format the exception; we only do the work when + activate() is called, which call is delayed until after all the + Future's callbacks have run. Since usually a Future has at least + one callback (typically set by 'yield from') and usually that + callback extracts the callback, thereby removing the need to + format the exception. + + PS. I don't claim credit for this solution. I first heard of it + in a discussion about closing files when they are collected. + """ + + __slots__ = ('loop', 'source_traceback', 'exc', 'tb') + + def __init__(self, future, exc): + self.loop = future._loop + self.source_traceback = future._source_traceback + self.exc = exc + self.tb = None + + def activate(self): + exc = self.exc + if exc is not None: + self.exc = None + self.tb = traceback.format_exception(exc.__class__, exc, + exc.__traceback__) + + def clear(self): + self.exc = None + self.tb = None + + def __del__(self): + if self.tb: + msg = 'Future/Task exception was never retrieved\n' + if self.source_traceback: + src = ''.join(traceback.format_list(self.source_traceback)) + msg += 'Future/Task created at (most recent call last):\n' + msg += '%s\n' % src.rstrip() + msg += ''.join(self.tb).rstrip() + self.loop.call_exception_handler({'message': msg}) + + +class Future: + """This class is *almost* compatible with concurrent.futures.Future. + + Differences: + + - result() and exception() do not take a timeout argument and + raise an exception when the future isn't done yet. + + - Callbacks registered with add_done_callback() are always called + via the event loop's call_soon_threadsafe(). + + - This class is not compatible with the wait() and as_completed() + methods in the concurrent.futures package. + + (In Python 3.4 or later we may be able to unify the implementations.) + """ + + # Class variables serving as defaults for instance variables. + _state = _PENDING + _result = None + _exception = None + _loop = None + _source_traceback = None + + _blocking = False # proper use of future (yield vs yield from) + + _log_traceback = False # Used for Python 3.4 and later + _tb_logger = None # Used for Python 3.3 only + + def __init__(self, *, loop=None): + """Initialize the future. + + The optional event_loop argument allows to explicitly set the event + loop object used by the future. If it's not provided, the future uses + the default event loop. + """ + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._callbacks = [] + if self._loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def _format_callbacks(self): + cb = self._callbacks + size = len(cb) + if not size: + cb = '' + + def format_cb(callback): + return events._format_callback(callback, ()) + + if size == 1: + cb = format_cb(cb[0]) + elif size == 2: + cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + elif size > 2: + cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) + return 'cb=[%s]' % cb + + def _repr_info(self): + info = [self._state.lower()] + if self._state == _FINISHED: + if self._exception is not None: + info.append('exception={!r}'.format(self._exception)) + else: + # use reprlib to limit the length of the output, especially + # for very long strings + result = reprlib.repr(self._result) + info.append('result={}'.format(result)) + if self._callbacks: + info.append(self._format_callbacks()) + if self._source_traceback: + frame = self._source_traceback[-1] + info.append('created at %s:%s' % (frame[0], frame[1])) + return info + + def __repr__(self): + info = self._repr_info() + return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) + + # On Python 3.3 or older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks to + # the PEP 442. + if _PY34: + def __del__(self): + if not self._log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + context = { + 'message': ('%s exception was never retrieved' + % self.__class__.__name__), + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + + def cancel(self): + """Cancel the future and schedule callbacks. + + If the future is already done or cancelled, return False. Otherwise, + change the future's state to cancelled, schedule the callbacks and + return True. + """ + if self._state != _PENDING: + return False + self._state = _CANCELLED + self._schedule_callbacks() + return True + + def _schedule_callbacks(self): + """Internal: Ask the event loop to call all callbacks. + + The callbacks are scheduled to be called as soon as possible. Also + clears the callback list. + """ + callbacks = self._callbacks[:] + if not callbacks: + return + + self._callbacks[:] = [] + for callback in callbacks: + self._loop.call_soon(callback, self) + + def cancelled(self): + """Return True if the future was cancelled.""" + return self._state == _CANCELLED + + # Don't implement running(); see http://bugs.python.org/issue18699 + + def done(self): + """Return True if the future is done. + + Done means either that a result / exception are available, or that the + future was cancelled. + """ + return self._state != _PENDING + + def result(self): + """Return the result this future represents. + + If the future has been cancelled, raises CancelledError. If the + future's result isn't yet available, raises InvalidStateError. If + the future is done and has an exception set, this exception is raised. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Result is not ready.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + if self._exception is not None: + raise self._exception + return self._result + + def exception(self): + """Return the exception that was set on this future. + + The exception (or None if no exception was set) is returned only if + the future is done. If the future has been cancelled, raises + CancelledError. If the future isn't done yet, raises + InvalidStateError. + """ + if self._state == _CANCELLED: + raise CancelledError + if self._state != _FINISHED: + raise InvalidStateError('Exception is not set.') + self._log_traceback = False + if self._tb_logger is not None: + self._tb_logger.clear() + self._tb_logger = None + return self._exception + + def add_done_callback(self, fn): + """Add a callback to be run when the future becomes done. + + The callback is called with a single argument - the future object. If + the future is already done when this is called, the callback is + scheduled with call_soon. + """ + if self._state != _PENDING: + self._loop.call_soon(fn, self) + else: + self._callbacks.append(fn) + + # New method not in PEP 3148. + + def remove_done_callback(self, fn): + """Remove all instances of a callback from the "call when done" list. + + Returns the number of callbacks removed. + """ + filtered_callbacks = [f for f in self._callbacks if f != fn] + removed_count = len(self._callbacks) - len(filtered_callbacks) + if removed_count: + self._callbacks[:] = filtered_callbacks + return removed_count + + # So-called internal methods (note: no set_running_or_notify_cancel()). + + def _set_result_unless_cancelled(self, result): + """Helper setting the result only if the future was not cancelled.""" + if self.cancelled(): + return + self.set_result(result) + + def set_result(self, result): + """Mark the future done and set its result. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + self._result = result + self._state = _FINISHED + self._schedule_callbacks() + + def set_exception(self, exception): + """Mark the future done and set an exception. + + If the future is already done when this method is called, raises + InvalidStateError. + """ + if self._state != _PENDING: + raise InvalidStateError('{}: {!r}'.format(self._state, self)) + if isinstance(exception, type): + exception = exception() + self._exception = exception + self._state = _FINISHED + self._schedule_callbacks() + if _PY34: + self._log_traceback = True + else: + self._tb_logger = _TracebackLogger(self, exception) + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + + # Truly internal methods. + + def _copy_state(self, other): + """Internal helper to copy state from another Future. + + The other Future may be a concurrent.futures.Future. + """ + assert other.done() + if self.cancelled(): + return + assert not self.done() + if other.cancelled(): + self.cancel() + else: + exception = other.exception() + if exception is not None: + self.set_exception(exception) + else: + result = other.result() + self.set_result(result) + + def __iter__(self): + if not self.done(): + self._blocking = True + yield self # This tells Task to wait for completion. + assert self.done(), "yield from wasn't used with future" + return self.result() # May raise too. + + +def wrap_future(fut, *, loop=None): + """Wrap concurrent.futures.Future object.""" + if isinstance(fut, Future): + return fut + assert isinstance(fut, concurrent.futures.Future), \ + 'concurrent.futures.Future is expected, got {!r}'.format(fut) + if loop is None: + loop = events.get_event_loop() + new_future = Future(loop=loop) + + def _check_cancel_other(f): + if f.cancelled(): + fut.cancel() + + new_future.add_done_callback(_check_cancel_other) + fut.add_done_callback( + lambda future: loop.call_soon_threadsafe( + new_future._copy_state, future)) + return new_future diff --git a/asyncio/locks.py b/asyncio/locks.py new file mode 100644 index 00000000..b943e9dd --- /dev/null +++ b/asyncio/locks.py @@ -0,0 +1,469 @@ +"""Synchronization primitives.""" + +__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] + +import collections + +from . import events +from . import futures +from .coroutines import coroutine + + +class _ContextManager: + """Context manager. + + This enables the following idiom for acquiring and releasing a + lock around a block: + + with (yield from lock): + + + while failing loudly when accidentally using: + + with lock: + + """ + + def __init__(self, lock): + self._lock = lock + + def __enter__(self): + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + def __exit__(self, *args): + try: + self._lock.release() + finally: + self._lock = None # Crudely prevent reuse. + + +class Lock: + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular coroutine when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another coroutine changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one coroutine is blocked in acquire() waiting for + the state to turn to unlocked, only one coroutine proceeds when a + release() call resets the state to unlocked; first coroutine which + is blocked in acquire() is being processed. + + acquire() is a coroutine and should be called with 'yield from'. + + Locks also support the context management protocol. '(yield from lock)' + should be used as context manager expression. + + Usage: + + lock = Lock() + ... + yield from lock + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + with (yield from lock): + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + yield from lock + else: + # lock is acquired + ... + + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._locked = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self._locked else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Return True if lock is acquired.""" + return self._locked + + @coroutine + def acquire(self): + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + if not self._waiters and not self._locked: + self._locked = True + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._locked = True + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other coroutines are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + # Wake up the first waiter who isn't cancelled. + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + break + else: + raise RuntimeError('Lock is not acquired.') + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass + + def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # + # finally: + # lock.release() + yield from self.acquire() + return _ContextManager(self) + + +class Event: + """Asynchronous equivalent to threading.Event. + + Class implementing event objects. An event manages a flag that can be set + to true with the set() method and reset to false with the clear() method. + The wait() method blocks until the flag is true. The flag is initially + false. + """ + + def __init__(self, *, loop=None): + self._waiters = collections.deque() + self._value = False + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'set' if self._value else 'unset' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def is_set(self): + """Return True if and only if the internal flag is true.""" + return self._value + + def set(self): + """Set the internal flag to true. All coroutines waiting for it to + become true are awakened. Coroutine that call wait() once the flag is + true will not block at all. + """ + if not self._value: + self._value = True + + for fut in self._waiters: + if not fut.done(): + fut.set_result(True) + + def clear(self): + """Reset the internal flag to false. Subsequently, coroutines calling + wait() will block until set() is called to set the internal flag + to true again.""" + self._value = False + + @coroutine + def wait(self): + """Block until the internal flag is true. + + If the internal flag is true on entry, return True + immediately. Otherwise, block until another coroutine calls + set() to set the flag to true, then return True. + """ + if self._value: + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + +class Condition: + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more coroutines to wait until they are notified by another + coroutine. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock=None, *, loop=None): + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + if lock is None: + lock = Lock(loop=self._loop) + elif lock._loop is not self._loop: + raise ValueError("loop argument must agree with lock") + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters = collections.deque() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked' + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + @coroutine + def wait(self): + """Wait until notified. + + If the calling coroutine has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another coroutine. Once + awakened, it re-acquires the lock and returns True. + """ + if not self.locked(): + raise RuntimeError('cannot wait on un-acquired lock') + + self.release() + try: + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + return True + finally: + self._waiters.remove(fut) + + finally: + yield from self.acquire() + + @coroutine + def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + yield from self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError('cannot notify on un-acquired lock') + + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + pass + + def __iter__(self): + # See comment in Lock.__iter__(). + yield from self.acquire() + return _ContextManager(self) + + +class Semaphore: + """A Semaphore implementation. + + A semaphore manages an internal counter which is decremented by each + acquire() call and incremented by each release() call. The counter + can never go below zero; when acquire() finds that it is zero, it blocks, + waiting until some other thread calls release(). + + Semaphores also support the context management protocol. + + The optional argument gives the initial value for the internal + counter; it defaults to 1. If the value given is less than 0, + ValueError is raised. + """ + + def __init__(self, value=1, *, loop=None): + if value < 0: + raise ValueError("Semaphore initial value must be >= 0") + self._value = value + self._waiters = collections.deque() + if loop is not None: + self._loop = loop + else: + self._loop = events.get_event_loop() + + def __repr__(self): + res = super().__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( + self._value) + if self._waiters: + extra = '{},waiters:{}'.format(extra, len(self._waiters)) + return '<{} [{}]>'.format(res[1:-1], extra) + + def locked(self): + """Returns True if semaphore can not be acquired immediately.""" + return self._value == 0 + + @coroutine + def acquire(self): + """Acquire a semaphore. + + If the internal counter is larger than zero on entry, + decrement it by one and return True immediately. If it is + zero on entry, block, waiting until some other coroutine has + called release() to make it larger than 0, and then return + True. + """ + if not self._waiters and self._value > 0: + self._value -= 1 + return True + + fut = futures.Future(loop=self._loop) + self._waiters.append(fut) + try: + yield from fut + self._value -= 1 + return True + finally: + self._waiters.remove(fut) + + def release(self): + """Release a semaphore, incrementing the internal counter by one. + When it was zero on entry and another coroutine is waiting for it to + become larger than zero again, wake up that coroutine. + """ + self._value += 1 + for waiter in self._waiters: + if not waiter.done(): + waiter.set_result(True) + break + + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + pass + + def __iter__(self): + # See comment in Lock.__iter__(). + yield from self.acquire() + return _ContextManager(self) + + +class BoundedSemaphore(Semaphore): + """A bounded semaphore implementation. + + This raises ValueError in release() if it would increase the value + above the initial value. + """ + + def __init__(self, value=1, *, loop=None): + self._bound_value = value + super().__init__(value, loop=loop) + + def release(self): + if self._value >= self._bound_value: + raise ValueError('BoundedSemaphore released too many times') + super().release() diff --git a/asyncio/log.py b/asyncio/log.py new file mode 100644 index 00000000..23a7074a --- /dev/null +++ b/asyncio/log.py @@ -0,0 +1,7 @@ +"""Logging configuration.""" + +import logging + + +# Name the logger after the package. +logger = logging.getLogger(__package__) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py new file mode 100644 index 00000000..ed170622 --- /dev/null +++ b/asyncio/proactor_events.py @@ -0,0 +1,535 @@ +"""Event loop using a proactor and related classes. + +A proactor is a "notify-on-completion" multiplexer. Currently a +proactor is only implemented on Windows with IOCP. +""" + +__all__ = ['BaseProactorEventLoop'] + +import socket + +from . import base_events +from . import constants +from . import futures +from . import sslproto +from . import transports +from .log import logger + + +class _ProactorBasePipeTransport(transports._FlowControlMixin, + transports.BaseTransport): + """Base class for pipe and socket transports.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(extra, loop) + self._set_extra(sock) + self._sock = sock + self._protocol = protocol + self._server = server + self._buffer = None # None or bytearray. + self._read_fut = None + self._write_fut = None + self._pending_write = 0 + self._conn_lost = 0 + self._closing = False # Set when close() called. + self._eof_written = False + if self._server is not None: + self._server._attach() + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + if self._sock is not None: + info.append('fd=%s' % self._sock.fileno()) + if self._read_fut is not None: + info.append('read=%s' % self._read_fut) + if self._write_fut is not None: + info.append("write=%r" % self._write_fut) + if self._buffer: + bufsize = len(self._buffer) + info.append('write_bufsize=%s' % bufsize) + if self._eof_written: + info.append('EOF written') + return '<%s>' % ' '.join(info) + + def _set_extra(self, sock): + self._extra['pipe'] = sock + + def close(self): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if not self._buffer and self._write_fut is None: + self._loop.call_soon(self._call_connection_lost, None) + if self._read_fut is not None: + self._read_fut.cancel() + self._read_fut = None + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._closing: + return + self._closing = True + self._conn_lost += 1 + if self._write_fut: + self._write_fut.cancel() + self._write_fut = None + if self._read_fut: + self._read_fut.cancel() + self._read_fut = None + self._pending_write = 0 + self._buffer = None + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + # XXX If there is a pending overlapped read on the other + # end then it may fail with ERROR_NETNAME_DELETED if we + # just close our end. First calling shutdown() seems to + # cure it, but maybe using DisconnectEx() would be better. + if hasattr(self._sock, 'shutdown'): + self._sock.shutdown(socket.SHUT_RDWR) + self._sock.close() + self._sock = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + size = self._pending_write + if self._buffer is not None: + size += len(self._buffer) + return size + + +class _ProactorReadPipeTransport(_ProactorBasePipeTransport, + transports.ReadTransport): + """Transport for read pipes.""" + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, waiter, extra, server) + self._paused = False + self._loop.call_soon(self._loop_reading) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.call_soon(self._loop_reading, self._read_fut) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _loop_reading(self, fut=None): + if self._paused: + return + data = None + + try: + if fut is not None: + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + self._read_fut = None + data = fut.result() # deliver data later in "finally" clause + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if data == b'': + # we got end-of-file so no need to reschedule a new read + return + + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 4096) + except ConnectionAbortedError as exc: + if not self._closing: + self._fatal_error(exc, 'Fatal read error on pipe transport') + elif self._loop.get_debug(): + logger.debug("Read error on pipe transport while closing", + exc_info=True) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + except futures.CancelledError: + if not self._closing: + raise + else: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.data_received(data) + elif data is not None: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if not keep_open: + self.close() + + +class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, + transports.WriteTransport): + """Transport for write pipes.""" + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof_written: + raise RuntimeError('write_eof() already called') + + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + # Observable states: + # 1. IDLE: _write_fut and _buffer both None + # 2. WRITING: _write_fut set; _buffer None + # 3. BACKED UP: _write_fut set; _buffer a bytearray + # We always copy the data, so the caller can't modify it + # while we're still waiting for the I/O to happen. + if self._write_fut is None: # IDLE -> WRITING + assert self._buffer is None + # Pass a copy, except if it's already immutable. + self._loop_writing(data=bytes(data)) + elif not self._buffer: # WRITING -> BACKED UP + # Make a mutable copy which we can extend. + self._buffer = bytearray(data) + self._maybe_pause_protocol() + else: # BACKED UP + # Append to buffer (also copies). + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _loop_writing(self, f=None, data=None): + try: + assert f is self._write_fut + self._write_fut = None + self._pending_write = 0 + if f: + f.result() + if data is None: + data = self._buffer + self._buffer = None + if not data: + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + if self._eof_written: + self._sock.shutdown(socket.SHUT_WR) + # Now that we've reduced the buffer size, tell the + # protocol to resume writing if it was paused. Note that + # we do this last since the callback is called immediately + # and it may add more data to the buffer (even causing the + # protocol to be paused again). + self._maybe_resume_protocol() + else: + self._write_fut = self._loop._proactor.send(self._sock, data) + if not self._write_fut.done(): + assert self._pending_write == 0 + self._pending_write = len(data) + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_pause_protocol() + else: + self._write_fut.add_done_callback(self._loop_writing) + except ConnectionResetError as exc: + self._force_close(exc) + except OSError as exc: + self._fatal_error(exc, 'Fatal write error on pipe transport') + + def can_write_eof(self): + return True + + def write_eof(self): + self.close() + + def abort(self): + self._force_close(None) + + +class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._read_fut = self._loop._proactor.recv(self._sock, 16) + self._read_fut.add_done_callback(self._pipe_closed) + + def _pipe_closed(self, fut): + if fut.cancelled(): + # the transport has been closed + return + assert fut.result() == b'' + if self._closing: + assert self._read_fut is None + return + assert fut is self._read_fut, (fut, self._read_fut) + self._read_fut = None + if self._write_fut is not None: + self._force_close(BrokenPipeError()) + else: + self.close() + + +class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for duplex pipes.""" + + def can_write_eof(self): + return False + + def write_eof(self): + raise NotImplementedError + + +class _ProactorSocketTransport(_ProactorReadPipeTransport, + _ProactorBaseWritePipeTransport, + transports.Transport): + """Transport for connected sockets.""" + + def _set_extra(self, sock): + self._extra['socket'] = sock + try: + self._extra['sockname'] = sock.getsockname() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getsockname() failed on %r", + sock, exc_info=True) + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except (socket.error, AttributeError): + if self._loop.get_debug(): + logger.warning("getpeername() failed on %r", + sock, exc_info=True) + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing or self._eof_written: + return + self._eof_written = True + if self._write_fut is None: + self._sock.shutdown(socket.SHUT_WR) + + +class BaseProactorEventLoop(base_events.BaseEventLoop): + + def __init__(self, proactor): + super().__init__() + logger.debug('Using proactor: %s', proactor.__class__.__name__) + self._proactor = proactor + self._selector = proactor # convenient alias + self._self_reading_future = None + self._accept_futures = {} # socket file descriptor => Future + proactor.set_loop(self) + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, + extra=None, server=None): + return _ProactorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + if not sslproto._is_sslproto_available(): + raise NotImplementedError("Proactor event loop requires Python 3.5" + " or newer (ssl.MemoryBIO) to support " + "SSL") + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _ProactorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorDuplexPipeTransport(self, + sock, protocol, waiter, extra) + + def _make_read_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + return _ProactorReadPipeTransport(self, sock, protocol, waiter, extra) + + def _make_write_pipe_transport(self, sock, protocol, waiter=None, + extra=None): + # We want connection_lost() to be called when other end closes + return _ProactorWritePipeTransport(self, + sock, protocol, waiter, extra) + + def close(self): + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + + # Call these methods before closing the event loop (before calling + # BaseEventLoop.close), because they can schedule callbacks with + # call_soon(), which is forbidden when the event loop is closed. + self._stop_accept_futures() + self._close_self_pipe() + self._proactor.close() + self._proactor = None + self._selector = None + + # Close the event loop + super().close() + + def sock_recv(self, sock, n): + return self._proactor.recv(sock, n) + + def sock_sendall(self, sock, data): + return self._proactor.send(sock, data) + + def sock_connect(self, sock, address): + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut = futures.Future(loop=self) + fut.set_exception(err) + return fut + else: + return self._proactor.connect(sock, address) + + def sock_accept(self, sock): + return self._proactor.accept(sock) + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + if self._self_reading_future is not None: + self._self_reading_future.cancel() + self._self_reading_future = None + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.call_soon(self._loop_self_reading) + + def _loop_self_reading(self, f=None): + try: + if f is not None: + f.result() # may raise + f = self._proactor.recv(self._ssock, 4096) + except futures.CancelledError: + # _close_self_pipe() has been called, stop waiting for data + return + except Exception as exc: + self.call_exception_handler({ + 'message': 'Error on reading from the event loop self pipe', + 'exception': exc, + 'loop': self, + }) + else: + self._self_reading_future = f + f.add_done_callback(self._loop_self_reading) + + def _write_to_self(self): + self._csock.send(b'\0') + + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): + + def loop(f=None): + try: + if f is not None: + conn, addr = f.result() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + protocol = protocol_factory() + if sslcontext is not None: + self._make_ssl_transport( + conn, protocol, sslcontext, server_side=True, + extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol, + extra={'peername': addr}, server=server) + if self.is_closed(): + return + f = self._proactor.accept(sock) + except OSError as exc: + if sock.fileno() != -1: + self.call_exception_handler({ + 'message': 'Accept failed on a socket', + 'exception': exc, + 'socket': sock, + }) + sock.close() + elif self._debug: + logger.debug("Accept failed on socket %r", + sock, exc_info=True) + except futures.CancelledError: + sock.close() + else: + self._accept_futures[sock.fileno()] = f + f.add_done_callback(loop) + + self.call_soon(loop) + + def _process_events(self, event_list): + # Events are processed in the IocpProactor._poll() method + pass + + def _stop_accept_futures(self): + for future in self._accept_futures.values(): + future.cancel() + self._accept_futures.clear() + + def _stop_serving(self, sock): + self._stop_accept_futures() + self._proactor._stop_serving(sock) + sock.close() diff --git a/asyncio/protocols.py b/asyncio/protocols.py new file mode 100644 index 00000000..52fc25c2 --- /dev/null +++ b/asyncio/protocols.py @@ -0,0 +1,129 @@ +"""Abstract Protocol class.""" + +__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol'] + + +class BaseProtocol: + """Common base class for protocol interfaces. + + Usually user implements protocols that derived from BaseProtocol + like Protocol or ProcessProtocol. + + The only case when BaseProtocol should be implemented directly is + write-only transport like write pipe + """ + + def connection_made(self, transport): + """Called when a connection is made. + + The argument is the transport representing the pipe connection. + To receive data, wait for data_received() calls. + When the connection is closed, connection_lost() is called. + """ + + def connection_lost(self, exc): + """Called when the connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + + def pause_writing(self): + """Called when the transport's buffer goes over the high-water mark. + + Pause and resume calls are paired -- pause_writing() is called + once when the buffer goes strictly over the high-water mark + (even if subsequent writes increases the buffer size even + more), and eventually resume_writing() is called once when the + buffer size reaches the low-water mark. + + Note that if the buffer size equals the high-water mark, + pause_writing() is not called -- it must go strictly over. + Conversely, resume_writing() is called when the buffer size is + equal or lower than the low-water mark. These end conditions + are important to ensure that things go as expected when either + mark is zero. + + NOTE: This is the only Protocol callback that is not called + through EventLoop.call_soon() -- if it were, it would have no + effect when it's most needed (when the app keeps writing + without yielding until pause_writing() is called). + """ + + def resume_writing(self): + """Called when the transport's buffer drains below the low-water mark. + + See pause_writing() for details. + """ + + +class Protocol(BaseProtocol): + """Interface for stream protocol. + + The user should implement this interface. They can inherit from + this class but don't need to. The implementations here do + nothing (they don't raise exceptions). + + When the user wants to requests a transport, they pass a protocol + factory to a utility function (e.g., EventLoop.create_connection()). + + When the connection is made successfully, connection_made() is + called with a suitable transport object. Then data_received() + will be called 0 or more times with data (bytes) received from the + transport; finally, connection_lost() will be called exactly once + with either an exception object or None as an argument. + + State machine of calls: + + start -> CM [-> DR*] [-> ER?] -> CL -> end + """ + + def data_received(self, data): + """Called when some data is received. + + The argument is a bytes object. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + +class DatagramProtocol(BaseProtocol): + """Interface for datagram protocol.""" + + def datagram_received(self, data, addr): + """Called when some datagram is received.""" + + def error_received(self, exc): + """Called when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + + +class SubprocessProtocol(BaseProtocol): + """Interface for protocol for subprocess calls.""" + + def pipe_data_received(self, fd, data): + """Called when the subprocess writes data into stdout/stderr pipe. + + fd is int file descriptor. + data is bytes object. + """ + + def pipe_connection_lost(self, fd, exc): + """Called when a file descriptor associated with the child process is + closed. + + fd is the int file descriptor that was closed. + """ + + def process_exited(self): + """Called when subprocess has exited.""" diff --git a/asyncio/queues.py b/asyncio/queues.py new file mode 100644 index 00000000..dce0d53c --- /dev/null +++ b/asyncio/queues.py @@ -0,0 +1,298 @@ +"""Queues""" + +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'JoinableQueue', + 'QueueFull', 'QueueEmpty'] + +import collections +import heapq + +from . import events +from . import futures +from . import locks +from .tasks import coroutine + + +class QueueEmpty(Exception): + 'Exception raised by Queue.get(block=0)/get_nowait().' + pass + + +class QueueFull(Exception): + 'Exception raised by Queue.put(block=0)/put_nowait().' + pass + + +class Queue: + """A queue, useful for coordinating producer and consumer coroutines. + + If maxsize is less than or equal to zero, the queue size is infinite. If it + is an integer greater than 0, then "yield from put()" will block when the + queue reaches maxsize, until an item is removed by get(). + + Unlike the standard library Queue, you can reliably know this Queue's size + with qsize(), since your single-threaded asyncio application won't be + interrupted between calling qsize() and doing an operation on the Queue. + """ + + def __init__(self, maxsize=0, *, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._maxsize = maxsize + + # Futures. + self._getters = collections.deque() + # Pairs of (item, Future). + self._putters = collections.deque() + self._init(maxsize) + + def _init(self, maxsize): + self._queue = collections.deque() + + def _get(self): + return self._queue.popleft() + + def _put(self, item): + self._queue.append(item) + + def __repr__(self): + return '<{} at {:#x} {}>'.format( + type(self).__name__, id(self), self._format()) + + def __str__(self): + return '<{} {}>'.format(type(self).__name__, self._format()) + + def _format(self): + result = 'maxsize={!r}'.format(self._maxsize) + if getattr(self, '_queue', None): + result += ' _queue={!r}'.format(list(self._queue)) + if self._getters: + result += ' _getters[{}]'.format(len(self._getters)) + if self._putters: + result += ' _putters[{}]'.format(len(self._putters)) + return result + + def _consume_done_getters(self): + # Delete waiters at the head of the get() queue who've timed out. + while self._getters and self._getters[0].done(): + self._getters.popleft() + + def _consume_done_putters(self): + # Delete waiters at the head of the put() queue who've timed out. + while self._putters and self._putters[0][1].done(): + self._putters.popleft() + + def qsize(self): + """Number of items in the queue.""" + return len(self._queue) + + @property + def maxsize(self): + """Number of items allowed in the queue.""" + return self._maxsize + + def empty(self): + """Return True if the queue is empty, False otherwise.""" + return not self._queue + + def full(self): + """Return True if there are maxsize items in the queue. + + Note: if the Queue was initialized with maxsize=0 (the default), + then full() is never True. + """ + if self._maxsize <= 0: + return False + else: + return self.qsize() >= self._maxsize + + @coroutine + def put(self, item): + """Put an item into the queue. + + Put an item into the queue. If the queue is full, wait until a free + slot is available before adding item. + + This method is a coroutine. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + + # getter cannot be cancelled, we just removed done getters + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + waiter = futures.Future(loop=self._loop) + + self._putters.append((item, waiter)) + yield from waiter + + else: + self._put(item) + + def put_nowait(self, item): + """Put an item into the queue without blocking. + + If no free slot is immediately available, raise QueueFull. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + + # Use _put and _get instead of passing item straight to getter, in + # case a subclass has logic that must run (e.g. JoinableQueue). + self._put(item) + + # getter cannot be cancelled, we just removed done getters + getter.set_result(self._get()) + + elif self._maxsize > 0 and self._maxsize <= self.qsize(): + raise QueueFull + else: + self._put(item) + + @coroutine + def get(self): + """Remove and return an item from the queue. + + If queue is empty, wait until an item is available. + + This method is a coroutine. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + + # When a getter runs and frees up a slot so this putter can + # run, we need to defer the put for a tick to ensure that + # getters and putters alternate perfectly. See + # ChannelTest.test_wait. + self._loop.call_soon(putter._set_result_unless_cancelled, None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + waiter = futures.Future(loop=self._loop) + + self._getters.append(waiter) + return (yield from waiter) + + def get_nowait(self): + """Remove and return an item from the queue. + + Return an item if one is immediately available, else raise QueueEmpty. + """ + self._consume_done_putters() + if self._putters: + assert self.full(), 'queue not full, why are putters waiting?' + item, putter = self._putters.popleft() + self._put(item) + # Wake putter on next tick. + + # getter cannot be cancelled, we just removed done putters + putter.set_result(None) + + return self._get() + + elif self.qsize(): + return self._get() + else: + raise QueueEmpty + + +class PriorityQueue(Queue): + """A subclass of Queue; retrieves entries in priority order (lowest first). + + Entries are typically tuples of the form: (priority number, data). + """ + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item, heappush=heapq.heappush): + heappush(self._queue, item) + + def _get(self, heappop=heapq.heappop): + return heappop(self._queue) + + +class LifoQueue(Queue): + """A subclass of Queue that retrieves most recently added entries first.""" + + def _init(self, maxsize): + self._queue = [] + + def _put(self, item): + self._queue.append(item) + + def _get(self): + return self._queue.pop() + + +class JoinableQueue(Queue): + """A subclass of Queue with task_done() and join() methods.""" + + def __init__(self, maxsize=0, *, loop=None): + super().__init__(maxsize=maxsize, loop=loop) + self._unfinished_tasks = 0 + self._finished = locks.Event(loop=self._loop) + self._finished.set() + + def _format(self): + result = Queue._format(self) + if self._unfinished_tasks: + result += ' tasks={}'.format(self._unfinished_tasks) + return result + + def _put(self, item): + super()._put(item) + self._unfinished_tasks += 1 + self._finished.clear() + + def task_done(self): + """Indicate that a formerly enqueued task is complete. + + Used by queue consumers. For each get() used to fetch a task, + a subsequent call to task_done() tells the queue that the processing + on the task is complete. + + If a join() is currently blocking, it will resume when all items have + been processed (meaning that a task_done() call was received for every + item that had been put() into the queue). + + Raises ValueError if called more times than there were items placed in + the queue. + """ + if self._unfinished_tasks <= 0: + raise ValueError('task_done() called too many times') + self._unfinished_tasks -= 1 + if self._unfinished_tasks == 0: + self._finished.set() + + @coroutine + def join(self): + """Block until all items in the queue have been gotten and processed. + + The count of unfinished tasks goes up whenever an item is added to the + queue. The count goes down whenever a consumer thread calls task_done() + to indicate that the item was retrieved and all work on it is complete. + When the count of unfinished tasks drops to zero, join() unblocks. + """ + if self._unfinished_tasks > 0: + yield from self._finished.wait() diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py new file mode 100644 index 00000000..24f84615 --- /dev/null +++ b/asyncio/selector_events.py @@ -0,0 +1,1007 @@ +"""Event loop using a selector and related classes. + +A selector is a "notify-when-ready" multiplexer. For a subclass which +also includes support for signal handling, see the unix_events sub-module. +""" + +__all__ = ['BaseSelectorEventLoop'] + +import collections +import errno +import functools +import socket +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import constants +from . import events +from . import futures +from . import selectors +from . import transports +from . import sslproto +from .log import logger + + +def _test_selector_event(selector, fd, event): + # Test if the selector is monitoring 'event' events + # for the file descriptor 'fd'. + try: + key = selector.get_key(fd) + except KeyError: + return False + else: + return bool(key.events & event) + + +class BaseSelectorEventLoop(base_events.BaseEventLoop): + """Selector event loop. + + See events.EventLoop for API specification. + """ + + def __init__(self, selector=None): + super().__init__() + + if selector is None: + selector = selectors.DefaultSelector() + logger.debug('Using selector: %s', selector.__class__.__name__) + self._selector = selector + self._make_self_pipe() + + def _make_socket_transport(self, sock, protocol, waiter=None, *, + extra=None, server=None): + return _SelectorSocketTransport(self, sock, protocol, waiter, + extra, server) + + def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None): + if not sslproto._is_sslproto_available(): + return self._make_legacy_ssl_transport( + rawsock, protocol, sslcontext, waiter, + server_side=server_side, server_hostname=server_hostname, + extra=extra, server=server) + + ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, + server_side, server_hostname) + _SelectorSocketTransport(self, rawsock, ssl_protocol, + extra=extra, server=server) + return ssl_protocol._app_transport + + def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, + waiter, *, + server_side=False, server_hostname=None, + extra=None, server=None): + # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used + # on Python 3.4 and older, when ssl.MemoryBIO is not available. + return _SelectorSslTransport( + self, rawsock, protocol, sslcontext, waiter, + server_side, server_hostname, extra, server) + + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _SelectorDatagramTransport(self, sock, protocol, + address, waiter, extra) + + def close(self): + if self.is_running(): + raise RuntimeError("Cannot close a running event loop") + if self.is_closed(): + return + self._close_self_pipe() + super().close() + if self._selector is not None: + self._selector.close() + self._selector = None + + def _socketpair(self): + raise NotImplementedError + + def _close_self_pipe(self): + self.remove_reader(self._ssock.fileno()) + self._ssock.close() + self._ssock = None + self._csock.close() + self._csock = None + self._internal_fds -= 1 + + def _make_self_pipe(self): + # A self-socket, really. :-) + self._ssock, self._csock = self._socketpair() + self._ssock.setblocking(False) + self._csock.setblocking(False) + self._internal_fds += 1 + self.add_reader(self._ssock.fileno(), self._read_from_self) + + def _process_self_data(self, data): + pass + + def _read_from_self(self): + while True: + try: + data = self._ssock.recv(4096) + if not data: + break + self._process_self_data(data) + except InterruptedError: + continue + except BlockingIOError: + break + + def _write_to_self(self): + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is not None: + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) + + def _start_serving(self, protocol_factory, sock, + sslcontext=None, server=None): + self.add_reader(sock.fileno(), self._accept_connection, + protocol_factory, sock, sslcontext, server) + + def _accept_connection(self, protocol_factory, sock, + sslcontext=None, server=None): + try: + conn, addr = sock.accept() + if self._debug: + logger.debug("%r got a new connection from %r: %r", + server, addr, conn) + conn.setblocking(False) + except (BlockingIOError, InterruptedError, ConnectionAbortedError): + pass # False alarm. + except OSError as exc: + # There's nowhere to send the error, so just log it. + if exc.errno in (errno.EMFILE, errno.ENFILE, + errno.ENOBUFS, errno.ENOMEM): + # Some platforms (e.g. Linux keep reporting the FD as + # ready, so we remove the read handler temporarily. + # We'll try again in a while. + self.call_exception_handler({ + 'message': 'socket.accept() out of system resource', + 'exception': exc, + 'socket': sock, + }) + self.remove_reader(sock.fileno()) + self.call_later(constants.ACCEPT_RETRY_DELAY, + self._start_serving, + protocol_factory, sock, sslcontext, server) + else: + raise # The event loop will catch, log and ignore it. + else: + protocol = protocol_factory() + if sslcontext: + self._make_ssl_transport( + conn, protocol, sslcontext, + server_side=True, extra={'peername': addr}, server=server) + else: + self._make_socket_transport( + conn, protocol , extra={'peername': addr}, + server=server) + # It's now up to the protocol to handle the connection. + + def add_reader(self, fd, callback, *args): + """Add a reader callback.""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_READ, + (handle, None)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_READ, + (handle, writer)) + if reader is not None: + reader.cancel() + + def remove_reader(self, fd): + """Remove a reader callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + mask &= ~selectors.EVENT_READ + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (None, writer)) + + if reader is not None: + reader.cancel() + return True + else: + return False + + def add_writer(self, fd, callback, *args): + """Add a writer callback..""" + self._check_closed() + handle = events.Handle(callback, args, self) + try: + key = self._selector.get_key(fd) + except KeyError: + self._selector.register(fd, selectors.EVENT_WRITE, + (None, handle)) + else: + mask, (reader, writer) = key.events, key.data + self._selector.modify(fd, mask | selectors.EVENT_WRITE, + (reader, handle)) + if writer is not None: + writer.cancel() + + def remove_writer(self, fd): + """Remove a writer callback.""" + if self.is_closed(): + return False + try: + key = self._selector.get_key(fd) + except KeyError: + return False + else: + mask, (reader, writer) = key.events, key.data + # Remove both writer and connector. + mask &= ~selectors.EVENT_WRITE + if not mask: + self._selector.unregister(fd) + else: + self._selector.modify(fd, mask, (reader, None)) + + if writer is not None: + writer.cancel() + return True + else: + return False + + def sock_recv(self, sock, n): + """Receive data from the socket. + + The return value is a bytes object representing the data received. + The maximum amount of data to be received at once is specified by + nbytes. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_recv(fut, False, sock, n) + return fut + + def _sock_recv(self, fut, registered, sock, n): + # _sock_recv() can add itself as an I/O callback if the operation can't + # be done immediately. Don't use it directly, call sock_recv(). + fd = sock.fileno() + if registered: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_reader(fd) + if fut.cancelled(): + return + try: + data = sock.recv(n) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_recv, fut, True, sock, n) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(data) + + def sock_sendall(self, sock, data): + """Send data to the socket. + + The socket must be connected to a remote socket. This method continues + to send data from data until either all data has been sent or an + error occurs. None is returned on success. On error, an exception is + raised, and there is no way to determine how much data, if any, was + successfully processed by the receiving end of the connection. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + if data: + self._sock_sendall(fut, False, sock, data) + else: + fut.set_result(None) + return fut + + def _sock_sendall(self, fut, registered, sock, data): + fd = sock.fileno() + + if registered: + self.remove_writer(fd) + if fut.cancelled(): + return + + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + fut.set_exception(exc) + return + + if n == len(data): + fut.set_result(None) + else: + if n: + data = data[n:] + self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + + def sock_connect(self, sock, address): + """Connect to a remote socket at address. + + The address must be already resolved to avoid the trap of hanging the + entire event loop when the address requires doing a DNS lookup. For + example, it must be an IP address, not an hostname, for AF_INET and + AF_INET6 address families. Use getaddrinfo() to resolve the hostname + asynchronously. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + try: + base_events._check_resolved_address(sock, address) + except ValueError as err: + fut.set_exception(err) + else: + self._sock_connect(fut, sock, address) + return fut + + def _sock_connect(self, fut, sock, address): + fd = sock.fileno() + try: + while True: + try: + sock.connect(address) + except InterruptedError: + continue + else: + break + except BlockingIOError: + fut.add_done_callback(functools.partial(self._sock_connect_done, + fd)) + self.add_writer(fd, self._sock_connect_cb, fut, sock, address) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def _sock_connect_done(self, fd, fut): + self.remove_writer(fd) + + def _sock_connect_cb(self, fut, sock, address): + if fut.cancelled(): + return + + try: + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to any except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) + except (BlockingIOError, InterruptedError): + # socket is still registered, the callback will be retried later + pass + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result(None) + + def sock_accept(self, sock): + """Accept a connection. + + The socket must be bound to an address and listening for connections. + The return value is a pair (conn, address) where conn is a new socket + object usable to send and receive data on the connection, and address + is the address bound to the socket on the other end of the connection. + + This method is a coroutine. + """ + if self.get_debug() and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + fut = futures.Future(loop=self) + self._sock_accept(fut, False, sock) + return fut + + def _sock_accept(self, fut, registered, sock): + fd = sock.fileno() + if registered: + self.remove_reader(fd) + if fut.cancelled(): + return + try: + conn, address = sock.accept() + conn.setblocking(False) + except (BlockingIOError, InterruptedError): + self.add_reader(fd, self._sock_accept, fut, True, sock) + except Exception as exc: + fut.set_exception(exc) + else: + fut.set_result((conn, address)) + + def _process_events(self, event_list): + for key, mask in event_list: + fileobj, (reader, writer) = key.fileobj, key.data + if mask & selectors.EVENT_READ and reader is not None: + if reader._cancelled: + self.remove_reader(fileobj) + else: + self._add_callback(reader) + if mask & selectors.EVENT_WRITE and writer is not None: + if writer._cancelled: + self.remove_writer(fileobj) + else: + self._add_callback(writer) + + def _stop_serving(self, sock): + self.remove_reader(sock.fileno()) + sock.close() + + +class _SelectorTransport(transports._FlowControlMixin, + transports.Transport): + + max_size = 256 * 1024 # Buffer size passed to recv(). + + _buffer_factory = bytearray # Constructs initial value for self._buffer. + + def __init__(self, loop, sock, protocol, extra, server=None): + super().__init__(extra, loop) + self._extra['socket'] = sock + self._extra['sockname'] = sock.getsockname() + if 'peername' not in self._extra: + try: + self._extra['peername'] = sock.getpeername() + except socket.error: + self._extra['peername'] = None + self._sock = sock + self._sock_fd = sock.fileno() + self._protocol = protocol + self._server = server + self._buffer = self._buffer_factory() + self._conn_lost = 0 # Set when call to connection_lost scheduled. + self._closing = False # Set when close() called. + if self._server is not None: + self._server._attach() + + def __repr__(self): + info = [self.__class__.__name__] + if self._sock is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._sock_fd) + # test if the transport was closed + if self._loop is not None: + polling = _test_selector_event(self._loop._selector, + self._sock_fd, selectors.EVENT_READ) + if polling: + info.append('read=polling') + else: + info.append('read=idle') + + polling = _test_selector_event(self._loop._selector, + self._sock_fd, + selectors.EVENT_WRITE) + if polling: + state = 'polling' + else: + state = 'idle' + + bufsize = self.get_write_buffer_size() + info.append('write=<%s, bufsize=%s>' % (state, bufsize)) + return '<%s>' % ' '.join(info) + + def abort(self): + self._force_close(None) + + def close(self): + if self._closing: + return + self._closing = True + self._loop.remove_reader(self._sock_fd) + if not self._buffer: + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, None) + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, + ConnectionResetError, ConnectionAbortedError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._force_close(exc) + + def _force_close(self, exc): + if self._conn_lost: + return + if self._buffer: + self._buffer.clear() + self._loop.remove_writer(self._sock_fd) + if not self._closing: + self._closing = True + self._loop.remove_reader(self._sock_fd) + self._conn_lost += 1 + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._sock.close() + self._sock = None + self._protocol = None + self._loop = None + server = self._server + if server is not None: + server._detach() + self._server = None + + def get_write_buffer_size(self): + return len(self._buffer) + + +class _SelectorSocketTransport(_SelectorTransport): + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, extra, server) + self._eof = False + self._paused = False + + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def pause_reading(self): + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on socket transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. + self._loop.remove_reader(self._sock_fd) + else: + self.close() + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if self._eof: + raise RuntimeError('Cannot call write() after write_eof()') + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Optimization: try to send now. + try: + n = self._sock.send(data) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on socket transport') + return + else: + data = data[n:] + if not data: + return + # Not all was written; register write handler. + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def _write_ready(self): + assert self._buffer, 'Data should not be empty' + + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError): + pass + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on socket transport') + else: + if n: + del self._buffer[:n] + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + elif self._eof: + self._sock.shutdown(socket.SHUT_WR) + + def write_eof(self): + if self._eof: + return + self._eof = True + if not self._buffer: + self._sock.shutdown(socket.SHUT_WR) + + def can_write_eof(self): + return True + + +class _SelectorSslTransport(_SelectorTransport): + + _buffer_factory = bytearray + + def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, + server_side=False, server_hostname=None, + extra=None, server=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = sslproto._create_transport_context(server_side, server_hostname) + + wrap_kwargs = { + 'server_side': server_side, + 'do_handshake_on_connect': False, + } + if server_hostname and not server_side: + wrap_kwargs['server_hostname'] = server_hostname + sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) + + super().__init__(loop, sslsock, protocol, extra, server) + + self._server_hostname = server_hostname + self._waiter = waiter + self._sslcontext = sslcontext + self._paused = False + + # SSL-specific extra info. (peercert is set later) + self._extra.update(sslcontext=sslcontext) + + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + start_time = self._loop.time() + else: + start_time = None + self._on_handshake(start_time) + + def _on_handshake(self, start_time): + try: + self._sock.do_handshake() + except ssl.SSLWantReadError: + self._loop.add_reader(self._sock_fd, + self._on_handshake, start_time) + return + except ssl.SSLWantWriteError: + self._loop.add_writer(self._sock_fd, + self._on_handshake, start_time) + return + except BaseException as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + self._sock.close() + if self._waiter is not None and not self._waiter.cancelled(): + self._waiter.set_exception(exc) + if isinstance(exc, Exception): + return + else: + raise + + self._loop.remove_reader(self._sock_fd) + self._loop.remove_writer(self._sock_fd) + + peercert = self._sock.getpeercert() + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname and + self._sslcontext.verify_mode != ssl.CERT_NONE): + try: + ssl.match_hostname(peercert, self._server_hostname) + except Exception as exc: + if self._loop.get_debug(): + logger.warning("%r: SSL handshake failed " + "on matching the hostname", + self, exc_info=True) + self._sock.close() + if (self._waiter is not None + and not self._waiter.cancelled()): + self._waiter.set_exception(exc) + return + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=self._sock.cipher(), + compression=self._sock.compression(), + ) + + self._read_wants_write = False + self._write_wants_read = False + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(self._waiter._set_result_unless_cancelled, + None) + + if self._loop.get_debug(): + dt = self._loop.time() - start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + def pause_reading(self): + # XXX This is a bit icky, given the comment at the top of + # _read_ready(). Is it possible to evoke a deadlock? I don't + # know, although it doesn't look like it; write() will still + # accept more data for the buffer and eventually the app will + # call resume_reading() again, and things will flow again. + + if self._closing: + raise RuntimeError('Cannot pause_reading() when closing') + if self._paused: + raise RuntimeError('Already paused') + self._paused = True + self._loop.remove_reader(self._sock_fd) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) + + def resume_reading(self): + if not self._paused: + raise RuntimeError('Not paused') + self._paused = False + if self._closing: + return + self._loop.add_reader(self._sock_fd, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) + + def _read_ready(self): + if self._write_wants_read: + self._write_wants_read = False + self._write_ready() + + if self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + try: + data = self._sock.recv(self.max_size) + except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + pass + except ssl.SSLWantWriteError: + self._read_wants_write = True + self._loop.remove_reader(self._sock_fd) + self._loop.add_writer(self._sock_fd, self._write_ready) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on SSL transport') + else: + if data: + self._protocol.data_received(data) + else: + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + keep_open = self._protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self.close() + + def _write_ready(self): + if self._read_wants_write: + self._read_wants_write = False + self._read_ready() + + if not (self._paused or self._closing): + self._loop.add_reader(self._sock_fd, self._read_ready) + + if self._buffer: + try: + n = self._sock.send(self._buffer) + except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): + n = 0 + except ssl.SSLWantReadError: + n = 0 + self._loop.remove_writer(self._sock_fd) + self._write_wants_read = True + except Exception as exc: + self._loop.remove_writer(self._sock_fd) + self._buffer.clear() + self._fatal_error(exc, 'Fatal write error on SSL transport') + return + + if n: + del self._buffer[:n] + + self._maybe_resume_protocol() # May append to buffer. + + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) + + def write(self, data): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._conn_lost: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + self._loop.add_writer(self._sock_fd, self._write_ready) + + # Add it to the buffer. + self._buffer.extend(data) + self._maybe_pause_protocol() + + def can_write_eof(self): + return False + + +class _SelectorDatagramTransport(_SelectorTransport): + + _buffer_factory = collections.deque + + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): + super().__init__(loop, sock, protocol, extra) + self._address = address + self._loop.add_reader(self._sock_fd, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + + def _read_ready(self): + try: + data, addr = self._sock.recvfrom(self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal read error on datagram transport') + else: + self._protocol.datagram_received(data, addr) + + def sendto(self, data, addr=None): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if not data: + return + + if self._address and addr not in (None, self._address): + raise ValueError('Invalid address: must be None or %s' % + (self._address,)) + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.send() raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + return + except (BlockingIOError, InterruptedError): + self._loop.add_writer(self._sock_fd, self._sendto_ready) + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) + self._maybe_pause_protocol() + + def _sendto_ready(self): + while self._buffer: + data, addr = self._buffer.popleft() + try: + if self._address: + self._sock.send(data) + else: + self._sock.sendto(data, addr) + except (BlockingIOError, InterruptedError): + self._buffer.appendleft((data, addr)) # Try again later. + break + except OSError as exc: + self._protocol.error_received(exc) + return + except Exception as exc: + self._fatal_error(exc, + 'Fatal write error on datagram transport') + return + + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer: + self._loop.remove_writer(self._sock_fd) + if self._closing: + self._call_connection_lost(None) diff --git a/asyncio/selectors.py b/asyncio/selectors.py new file mode 100644 index 00000000..6d569c30 --- /dev/null +++ b/asyncio/selectors.py @@ -0,0 +1,594 @@ +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. +""" + + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import math +import select +import sys + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, int): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{!r}".format(fileobj)) from None + if fd < 0: + raise ValueError("Invalid file descriptor: {}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(metaclass=ABCMeta): + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError('Selector is closed') + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{!r} (FD {}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{!r} is not registered".format(fileobj)) from None + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super().__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super().__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super().__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super().close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super().__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super().close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super().__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super().register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super().unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._kqueue.close() + super().close() + + +# Choose the best implementation, roughly: +# epoll|kqueue|devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py new file mode 100644 index 00000000..117dc565 --- /dev/null +++ b/asyncio/sslproto.py @@ -0,0 +1,646 @@ +import collections +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import protocols +from . import transports +from .log import logger + + +def _create_transport_context(server_side, server_hostname): + if server_side: + raise ValueError('Server side SSL needs a valid SSLContext') + + # Client side may pass ssl=True to use a default + # context; in that case the sslcontext passed is None. + # The default is secure for client connections. + if hasattr(ssl, 'create_default_context'): + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False + else: + # Fallback for Python 3.3. + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED + return sslcontext + + +def _is_sslproto_available(): + return hasattr(ssl, "MemoryBIO") + + +# States of an _SSLPipe. +_UNWRAPPED = "UNWRAPPED" +_DO_HANDSHAKE = "DO_HANDSHAKE" +_WRAPPED = "WRAPPED" +_SHUTDOWN = "SHUTDOWN" + + +class _SSLPipe(object): + """An SSL "Pipe". + + An SSL pipe allows you to communicate with an SSL/TLS protocol instance + through memory buffers. It can be used to implement a security layer for an + existing connection where you don't have access to the connection's file + descriptor, or for some reason you don't want to use it. + + An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode, + data is passed through untransformed. In wrapped mode, application level + data is encrypted to SSL record level data and vice versa. The SSL record + level is the lowest level in the SSL protocol suite and is what travels + as-is over the wire. + + An SslPipe initially is in "unwrapped" mode. To start SSL, call + do_handshake(). To shutdown SSL again, call unwrap(). + """ + + max_size = 256 * 1024 # Buffer size passed to read() + + def __init__(self, context, server_side, server_hostname=None): + """ + The *context* argument specifies the ssl.SSLContext to use. + + The *server_side* argument indicates whether this is a server side or + client side transport. + + The optional *server_hostname* argument can be used to specify the + hostname you are connecting to. You may only specify this parameter if + the _ssl module supports Server Name Indication (SNI). + """ + self._context = context + self._server_side = server_side + self._server_hostname = server_hostname + self._state = _UNWRAPPED + self._incoming = ssl.MemoryBIO() + self._outgoing = ssl.MemoryBIO() + self._sslobj = None + self._need_ssldata = False + self._handshake_cb = None + self._shutdown_cb = None + + @property + def context(self): + """The SSL context passed to the constructor.""" + return self._context + + @property + def ssl_object(self): + """The internal ssl.SSLObject instance. + + Return None if the pipe is not wrapped. + """ + return self._sslobj + + @property + def need_ssldata(self): + """Whether more record level data is needed to complete a handshake + that is currently in progress.""" + return self._need_ssldata + + @property + def wrapped(self): + """ + Whether a security layer is currently in effect. + + Return False during handshake. + """ + return self._state == _WRAPPED + + def do_handshake(self, callback=None): + """Start the SSL handshake. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the handshake is complete. The callback will be + called with None if successful, else an exception instance. + """ + if self._state != _UNWRAPPED: + raise RuntimeError('handshake in progress or completed') + self._sslobj = self._context.wrap_bio( + self._incoming, self._outgoing, + server_side=self._server_side, + server_hostname=self._server_hostname) + self._state = _DO_HANDSHAKE + self._handshake_cb = callback + ssldata, appdata = self.feed_ssldata(b'', only_handshake=True) + assert len(appdata) == 0 + return ssldata + + def shutdown(self, callback=None): + """Start the SSL shutdown sequence. + + Return a list of ssldata. A ssldata element is a list of buffers + + The optional *callback* argument can be used to install a callback that + will be called when the shutdown is complete. The callback will be + called without arguments. + """ + if self._state == _UNWRAPPED: + raise RuntimeError('no security layer present') + if self._state == _SHUTDOWN: + raise RuntimeError('shutdown in progress') + assert self._state in (_WRAPPED, _DO_HANDSHAKE) + self._state = _SHUTDOWN + self._shutdown_cb = callback + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + return ssldata + + def feed_eof(self): + """Send a potentially "ragged" EOF. + + This method will raise an SSL_ERROR_EOF exception if the EOF is + unexpected. + """ + self._incoming.write_eof() + ssldata, appdata = self.feed_ssldata(b'') + assert appdata == [] or appdata == [b''] + + def feed_ssldata(self, data, only_handshake=False): + """Feed SSL record level data into the pipe. + + The data must be a bytes instance. It is OK to send an empty bytes + instance. This can be used to get ssldata for a handshake initiated by + this endpoint. + + Return a (ssldata, appdata) tuple. The ssldata element is a list of + buffers containing SSL data that needs to be sent to the remote SSL. + + The appdata element is a list of buffers containing plaintext data that + needs to be forwarded to the application. The appdata list may contain + an empty buffer indicating an SSL "close_notify" alert. This alert must + be acknowledged by calling shutdown(). + """ + if self._state == _UNWRAPPED: + # If unwrapped, pass plaintext data straight through. + if data: + appdata = [data] + else: + appdata = [] + return ([], appdata) + + self._need_ssldata = False + if data: + self._incoming.write(data) + + ssldata = [] + appdata = [] + try: + if self._state == _DO_HANDSHAKE: + # Call do_handshake() until it doesn't raise anymore. + self._sslobj.do_handshake() + self._state = _WRAPPED + if self._handshake_cb: + self._handshake_cb(None) + if only_handshake: + return (ssldata, appdata) + # Handshake done: execute the wrapped block + + if self._state == _WRAPPED: + # Main state: read data from SSL until close_notify + while True: + chunk = self._sslobj.read(self.max_size) + appdata.append(chunk) + if not chunk: # close_notify + break + + elif self._state == _SHUTDOWN: + # Call shutdown() until it doesn't raise anymore. + self._sslobj.unwrap() + self._sslobj = None + self._state = _UNWRAPPED + if self._shutdown_cb: + self._shutdown_cb() + + elif self._state == _UNWRAPPED: + # Drain possible plaintext data after close_notify. + appdata.append(self._incoming.read()) + except (ssl.SSLError, ssl.CertificateError) as exc: + if getattr(exc, 'errno', None) not in ( + ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + if self._state == _DO_HANDSHAKE and self._handshake_cb: + self._handshake_cb(exc) + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # Check for record level data that needs to be sent back. + # Happens for the initial handshake and renegotiations. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + return (ssldata, appdata) + + def feed_appdata(self, data, offset=0): + """Feed plaintext data into the pipe. + + Return an (ssldata, offset) tuple. The ssldata element is a list of + buffers containing record level data that needs to be sent to the + remote SSL instance. The offset is the number of plaintext bytes that + were processed, which may be less than the length of data. + + NOTE: In case of short writes, this call MUST be retried with the SAME + buffer passed into the *data* argument (i.e. the id() must be the + same). This is an OpenSSL requirement. A further particularity is that + a short write will always have offset == 0, because the _ssl module + does not enable partial writes. And even though the offset is zero, + there will still be encrypted data in ssldata. + """ + assert 0 <= offset <= len(data) + if self._state == _UNWRAPPED: + # pass through data in unwrapped mode + if offset < len(data): + ssldata = [data[offset:]] + else: + ssldata = [] + return (ssldata, len(data)) + + ssldata = [] + view = memoryview(data) + while True: + self._need_ssldata = False + try: + if offset < len(view): + offset += self._sslobj.write(view[offset:]) + except ssl.SSLError as exc: + # It is not allowed to call write() after unwrap() until the + # close_notify is acknowledged. We return the condition to the + # caller as a short write. + if exc.reason == 'PROTOCOL_IS_SHUTDOWN': + exc.errno = ssl.SSL_ERROR_WANT_READ + if exc.errno not in (ssl.SSL_ERROR_WANT_READ, + ssl.SSL_ERROR_WANT_WRITE, + ssl.SSL_ERROR_SYSCALL): + raise + self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + + # See if there's any record level data back for us. + if self._outgoing.pending: + ssldata.append(self._outgoing.read()) + if offset == len(view) or self._need_ssldata: + break + return (ssldata, offset) + + +class _SSLProtocolTransport(transports._FlowControlMixin, + transports.Transport): + + def __init__(self, loop, ssl_protocol, app_protocol): + self._loop = loop + self._ssl_protocol = ssl_protocol + self._app_protocol = app_protocol + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._ssl_protocol._get_extra_info(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + self._ssl_protocol._start_shutdown() + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + self._ssl_protocol._transport.pause_reading() + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + self._ssl_protocol._transport.resume_reading() + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + self._ssl_protocol._transport.set_write_buffer_limits(high, low) + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + return self._ssl_protocol._transport.get_write_buffer_size() + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data: expecting a bytes-like instance, got {!r}" + .format(type(data).__name__)) + if not data: + return + self._ssl_protocol._write_appdata(data) + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + return False + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + self._ssl_protocol._abort() + + +class SSLProtocol(protocols.Protocol): + """SSL protocol. + + Implementation of SSL on top of a socket using incoming and outgoing + buffers which are ssl.MemoryBIO objects. + """ + + def __init__(self, loop, app_protocol, sslcontext, waiter, + server_side=False, server_hostname=None): + if ssl is None: + raise RuntimeError('stdlib ssl module not available') + + if not sslcontext: + sslcontext = _create_transport_context(server_side, server_hostname) + + self._server_side = server_side + if server_hostname and not server_side: + self._server_hostname = server_hostname + else: + self._server_hostname = None + self._sslcontext = sslcontext + # SSL-specific extra info. More info are set when the handshake + # completes. + self._extra = dict(sslcontext=sslcontext) + + # App data write buffering + self._write_backlog = collections.deque() + self._write_buffer_size = 0 + + self._waiter = waiter + self._closing = False + self._loop = loop + self._app_protocol = app_protocol + self._app_transport = _SSLProtocolTransport(self._loop, + self, self._app_protocol) + self._sslpipe = None + self._session_established = False + self._in_handshake = False + self._in_shutdown = False + self._transport = None + + def connection_made(self, transport): + """Called when the low-level connection is made. + + Start the SSL handshake. + """ + self._transport = transport + self._sslpipe = _SSLPipe(self._sslcontext, + self._server_side, + self._server_hostname) + self._start_handshake() + + def connection_lost(self, exc): + """Called when the low-level connection is lost or closed. + + The argument is an exception object or None (the latter + meaning a regular EOF is received or the connection was + aborted or closed). + """ + if self._session_established: + self._session_established = False + self._loop.call_soon(self._app_protocol.connection_lost, exc) + self._transport = None + self._app_transport = None + + def pause_writing(self): + """Called when the low-level transport's buffer goes over + the high-water mark. + """ + self._app_protocol.pause_writing() + + def resume_writing(self): + """Called when the low-level transport's buffer drains below + the low-water mark. + """ + self._app_protocol.resume_writing() + + def data_received(self, data): + """Called when some SSL data is received. + + The argument is a bytes object. + """ + try: + ssldata, appdata = self._sslpipe.feed_ssldata(data) + except ssl.SSLError as e: + if self._loop.get_debug(): + logger.warning('%r: SSL error %s (reason %s)', + self, e.errno, e.reason) + self._abort() + return + + for chunk in ssldata: + self._transport.write(chunk) + + for chunk in appdata: + if chunk: + self._app_protocol.data_received(chunk) + else: + self._start_shutdown() + break + + def eof_received(self): + """Called when the other end of the low-level stream + is half-closed. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + try: + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + if not self._in_handshake: + keep_open = self._app_protocol.eof_received() + if keep_open: + logger.warning('returning true from eof_received() ' + 'has no effect when using ssl') + finally: + self._transport.close() + + def _get_extra_info(self, name, default=None): + if name in self._extra: + return self._extra[name] + else: + return self._transport.get_extra_info(name, default) + + def _start_shutdown(self): + if self._in_shutdown: + return + self._in_shutdown = True + self._write_appdata(b'') + + def _write_appdata(self, data): + self._write_backlog.append((data, 0)) + self._write_buffer_size += len(data) + self._process_write_backlog() + + def _start_handshake(self): + if self._loop.get_debug(): + logger.debug("%r starts SSL handshake", self) + self._handshake_start_time = self._loop.time() + else: + self._handshake_start_time = None + self._in_handshake = True + # (b'', 1) is a special value in _process_write_backlog() to do + # the SSL handshake + self._write_backlog.append((b'', 1)) + self._loop.call_soon(self._process_write_backlog) + + def _on_handshake_complete(self, handshake_exc): + self._in_handshake = False + + sslobj = self._sslpipe.ssl_object + try: + if handshake_exc is not None: + raise handshake_exc + + peercert = sslobj.getpeercert() + if not hasattr(self._sslcontext, 'check_hostname'): + # Verify hostname if requested, Python 3.4+ uses check_hostname + # and checks the hostname in do_handshake() + if (self._server_hostname + and self._sslcontext.verify_mode != ssl.CERT_NONE): + ssl.match_hostname(peercert, self._server_hostname) + except BaseException as exc: + if self._loop.get_debug(): + if isinstance(exc, ssl.CertificateError): + logger.warning("%r: SSL handshake failed " + "on verifying the certificate", + self, exc_info=True) + else: + logger.warning("%r: SSL handshake failed", + self, exc_info=True) + self._transport.close() + if isinstance(exc, Exception): + if self._waiter is not None and not self._waiter.cancelled(): + self._waiter.set_exception(exc) + return + else: + raise + + if self._loop.get_debug(): + dt = self._loop.time() - self._handshake_start_time + logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) + + # Add extra info that becomes available after handshake. + self._extra.update(peercert=peercert, + cipher=sslobj.cipher(), + compression=sslobj.compression(), + ) + self._app_protocol.connection_made(self._app_transport) + if self._waiter is not None: + # wait until protocol.connection_made() has been called + self._waiter._set_result_unless_cancelled(None) + self._session_established = True + # In case transport.write() was already called. Don't call + # immediatly _process_write_backlog(), but schedule it: + # _on_handshake_complete() can be called indirectly from + # _process_write_backlog(), and _process_write_backlog() is not + # reentrant. + self._loop.call_soon(self._process_write_backlog) + + def _process_write_backlog(self): + # Try to make progress on the write backlog. + if self._transport is None: + return + + try: + for i in range(len(self._write_backlog)): + data, offset = self._write_backlog[0] + if data: + ssldata, offset = self._sslpipe.feed_appdata(data, offset) + elif offset: + ssldata = self._sslpipe.do_handshake(self._on_handshake_complete) + offset = 1 + else: + ssldata = self._sslpipe.shutdown(self._finalize) + offset = 1 + + for chunk in ssldata: + self._transport.write(chunk) + + if offset < len(data): + self._write_backlog[0] = (data, offset) + # A short write means that a write is blocked on a read + # We need to enable reading if it is paused! + assert self._sslpipe.need_ssldata + if self._transport._paused: + self._transport.resume_reading() + break + + # An entire chunk from the backlog was processed. We can + # delete it and reduce the outstanding buffer size. + del self._write_backlog[0] + self._write_buffer_size -= len(data) + except BaseException as exc: + if self._in_handshake: + self._on_handshake_complete(exc) + else: + self._fatal_error(exc, 'Fatal error on SSL transport') + + def _fatal_error(self, exc, message='Fatal error on transport'): + # Should be called from exception handler only. + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self._transport, + 'protocol': self, + }) + if self._transport: + self._transport._force_close(exc) + + def _finalize(self): + if self._transport is not None: + self._transport.close() + + def _abort(self): + if self._transport is not None: + try: + self._transport.abort() + finally: + self._finalize() diff --git a/asyncio/streams.py b/asyncio/streams.py new file mode 100644 index 00000000..7ff16a48 --- /dev/null +++ b/asyncio/streams.py @@ -0,0 +1,486 @@ +"""Stream-related things.""" + +__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'open_connection', 'start_server', + 'IncompleteReadError', + ] + +import socket + +if hasattr(socket, 'AF_UNIX'): + __all__.extend(['open_unix_connection', 'start_unix_server']) + +from . import coroutines +from . import events +from . import futures +from . import protocols +from .coroutines import coroutine +from .log import logger + + +_DEFAULT_LIMIT = 2**16 + + +class IncompleteReadError(EOFError): + """ + Incomplete read error. Attributes: + + - partial: read bytes string before the end of stream was reached + - expected: total number of expected bytes + """ + def __init__(self, partial, expected): + EOFError.__init__(self, "%s bytes read on a total of %s expected bytes" + % (len(partial), expected)) + self.partial = partial + self.expected = expected + + +@coroutine +def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """A wrapper for create_connection() returning a (reader, writer) pair. + + The reader returned is a StreamReader instance; the writer is a + StreamWriter instance. + + The arguments are all the usual arguments to create_connection() + except protocol_factory; most common are positional host and port, + with various optional keyword arguments following. + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + (If you want to customize the StreamReader and/or + StreamReaderProtocol classes, just copy the code -- there's + really nothing special here except some convenience.) + """ + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_connection( + lambda: protocol, host, port, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + +@coroutine +def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Start a socket server, call back for each client connected. + + The first parameter, `client_connected_cb`, takes two parameters: + client_reader, client_writer. client_reader is a StreamReader + object, while client_writer is a StreamWriter object. This + parameter can either be a plain callback function or a coroutine; + if it is a coroutine, it will be automatically converted into a + Task. + + The rest of the arguments are all the usual arguments to + loop.create_server() except protocol_factory; most common are + positional host and port, with various optional keyword arguments + following. The return value is the same as loop.create_server(). + + Additional optional keyword arguments are loop (to set the event loop + instance to use) and limit (to set the buffer limit passed to the + StreamReader). + + The return value is the same as loop.create_server(), i.e. a + Server object which can be used to stop the service. + """ + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_server(factory, host, port, **kwds)) + + +if hasattr(socket, 'AF_UNIX'): + # UNIX Domain Sockets are supported on this platform + + @coroutine + def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `open_connection` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.create_unix_connection( + lambda: protocol, path, **kwds) + writer = StreamWriter(transport, protocol, reader, loop) + return reader, writer + + + @coroutine + def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): + """Similar to `start_server` but works with UNIX Domain Sockets.""" + if loop is None: + loop = events.get_event_loop() + + def factory(): + reader = StreamReader(limit=limit, loop=loop) + protocol = StreamReaderProtocol(reader, client_connected_cb, + loop=loop) + return protocol + + return (yield from loop.create_unix_server(factory, path, **kwds)) + + +class FlowControlMixin(protocols.Protocol): + """Reusable flow control logic for StreamWriter.drain(). + + This implements the protocol methods pause_writing(), + resume_reading() and connection_lost(). If the subclass overrides + these it must call the super methods. + + StreamWriter.drain() must wait for _drain_helper() coroutine. + """ + + def __init__(self, loop=None): + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + if self._loop.get_debug(): + logger.debug("%r pauses writing", self) + + def resume_writing(self): + assert self._paused + self._paused = False + if self._loop.get_debug(): + logger.debug("%r resumes writing", self) + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + @coroutine + def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError('Connection lost') + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = futures.Future(loop=self._loop) + self._drain_waiter = waiter + yield from waiter + + +class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): + """Helper class to adapt between Protocol and StreamReader. + + (This is a helper class instead of making StreamReader itself a + Protocol subclass, because the StreamReader has other potential + uses, and to prevent the user of the StreamReader to accidentally + call inappropriate methods of the protocol.) + """ + + def __init__(self, stream_reader, client_connected_cb=None, loop=None): + super().__init__(loop=loop) + self._stream_reader = stream_reader + self._stream_writer = None + self._client_connected_cb = client_connected_cb + + def connection_made(self, transport): + self._stream_reader.set_transport(transport) + if self._client_connected_cb is not None: + self._stream_writer = StreamWriter(transport, self, + self._stream_reader, + self._loop) + res = self._client_connected_cb(self._stream_reader, + self._stream_writer) + if coroutines.iscoroutine(res): + self._loop.create_task(res) + + def connection_lost(self, exc): + if exc is None: + self._stream_reader.feed_eof() + else: + self._stream_reader.set_exception(exc) + super().connection_lost(exc) + + def data_received(self, data): + self._stream_reader.feed_data(data) + + def eof_received(self): + self._stream_reader.feed_eof() + + +class StreamWriter: + """Wraps a Transport. + + This exposes write(), writelines(), [can_]write_eof(), + get_extra_info() and close(). It adds drain() which returns an + optional Future on which you can wait for flow control. It also + adds a transport property which references the Transport + directly. + """ + + def __init__(self, transport, protocol, reader, loop): + self._transport = transport + self._protocol = protocol + # drain() expects that the reader has a exception() method + assert reader is None or isinstance(reader, StreamReader) + self._reader = reader + self._loop = loop + + def __repr__(self): + info = [self.__class__.__name__, 'transport=%r' % self._transport] + if self._reader is not None: + info.append('reader=%r' % self._reader) + return '<%s>' % ' '.join(info) + + @property + def transport(self): + return self._transport + + def write(self, data): + self._transport.write(data) + + def writelines(self, data): + self._transport.writelines(data) + + def write_eof(self): + return self._transport.write_eof() + + def can_write_eof(self): + return self._transport.can_write_eof() + + def close(self): + return self._transport.close() + + def get_extra_info(self, name, default=None): + return self._transport.get_extra_info(name, default) + + @coroutine + def drain(self): + """Flush the write buffer. + + The intended use is to write + + w.write(data) + yield from w.drain() + """ + if self._reader is not None: + exc = self._reader.exception() + if exc is not None: + raise exc + yield from self._protocol._drain_helper() + + +class StreamReader: + + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): + # The line length limit is a security feature; + # it also doubles as half the buffer limit. + self._limit = limit + if loop is None: + self._loop = events.get_event_loop() + else: + self._loop = loop + self._buffer = bytearray() + self._eof = False # Whether we're done. + self._waiter = None # A future used by _wait_for_data() + self._exception = None + self._transport = None + self._paused = False + + def exception(self): + return self._exception + + def set_exception(self, exc): + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_exception(exc) + + def _wakeup_waiter(self): + """Wakeup read() or readline() function waiting for data or EOF.""" + waiter = self._waiter + if waiter is not None: + self._waiter = None + if not waiter.cancelled(): + waiter.set_result(None) + + def set_transport(self, transport): + assert self._transport is None, 'Transport already set' + self._transport = transport + + def _maybe_resume_transport(self): + if self._paused and len(self._buffer) <= self._limit: + self._paused = False + self._transport.resume_reading() + + def feed_eof(self): + self._eof = True + self._wakeup_waiter() + + def at_eof(self): + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + def feed_data(self, data): + assert not self._eof, 'feed_data after feed_eof' + + if not data: + return + + self._buffer.extend(data) + self._wakeup_waiter() + + if (self._transport is not None and + not self._paused and + len(self._buffer) > 2*self._limit): + try: + self._transport.pause_reading() + except NotImplementedError: + # The transport can't be paused. + # We'll just have to buffer all data. + # Forget the transport so we don't keep trying. + self._transport = None + else: + self._paused = True + + def _wait_for_data(self, func_name): + """Wait until feed_data() or feed_eof() is called.""" + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError('%s() called while another coroutine is ' + 'already waiting for incoming data' % func_name) + + self._waiter = futures.Future(loop=self._loop) + try: + yield from self._waiter + finally: + self._waiter = None + + @coroutine + def readline(self): + if self._exception is not None: + raise self._exception + + line = bytearray() + not_enough = True + + while not_enough: + while self._buffer and not_enough: + ichar = self._buffer.find(b'\n') + if ichar < 0: + line.extend(self._buffer) + self._buffer.clear() + else: + ichar += 1 + line.extend(self._buffer[:ichar]) + del self._buffer[:ichar] + not_enough = False + + if len(line) > self._limit: + self._maybe_resume_transport() + raise ValueError('Line is too long') + + if self._eof: + break + + if not_enough: + yield from self._wait_for_data('readline') + + self._maybe_resume_transport() + return bytes(line) + + @coroutine + def read(self, n=-1): + if self._exception is not None: + raise self._exception + + if not n: + return b'' + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.read(self._limit) until EOF. + blocks = [] + while True: + block = yield from self.read(self._limit) + if not block: + break + blocks.append(block) + return b''.join(blocks) + else: + if not self._buffer and not self._eof: + yield from self._wait_for_data('read') + + if n < 0 or len(self._buffer) <= n: + data = bytes(self._buffer) + self._buffer.clear() + else: + # n > 0 and len(self._buffer) > n + data = bytes(self._buffer[:n]) + del self._buffer[:n] + + self._maybe_resume_transport() + return data + + @coroutine + def readexactly(self, n): + if self._exception is not None: + raise self._exception + + # There used to be "optimized" code here. It created its own + # Future and waited until self._buffer had at least the n + # bytes, then called read(n). Unfortunately, this could pause + # the transport if the argument was larger than the pause + # limit (which is twice self._limit). So now we just read() + # into a local buffer. + + blocks = [] + while n > 0: + block = yield from self.read(n) + if not block: + partial = b''.join(blocks) + raise IncompleteReadError(partial, len(partial) + n) + blocks.append(block) + n -= len(block) + + return b''.join(blocks) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py new file mode 100644 index 00000000..c848a21a --- /dev/null +++ b/asyncio/subprocess.py @@ -0,0 +1,249 @@ +__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] + +import collections +import subprocess + +from . import events +from . import futures +from . import protocols +from . import streams +from . import tasks +from .coroutines import coroutine +from .log import logger + + +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +DEVNULL = subprocess.DEVNULL + + +class SubprocessStreamProtocol(streams.FlowControlMixin, + protocols.SubprocessProtocol): + """Like StreamReaderProtocol, but for a subprocess.""" + + def __init__(self, limit, loop): + super().__init__(loop=loop) + self._limit = limit + self.stdin = self.stdout = self.stderr = None + self.waiter = futures.Future(loop=loop) + self._waiters = collections.deque() + self._transport = None + + def __repr__(self): + info = [self.__class__.__name__] + if self.stdin is not None: + info.append('stdin=%r' % self.stdin) + if self.stdout is not None: + info.append('stdout=%r' % self.stdout) + if self.stderr is not None: + info.append('stderr=%r' % self.stderr) + return '<%s>' % ' '.join(info) + + def connection_made(self, transport): + self._transport = transport + + stdout_transport = transport.get_pipe_transport(1) + if stdout_transport is not None: + self.stdout = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stdout.set_transport(stdout_transport) + + stderr_transport = transport.get_pipe_transport(2) + if stderr_transport is not None: + self.stderr = streams.StreamReader(limit=self._limit, + loop=self._loop) + self.stderr.set_transport(stderr_transport) + + stdin_transport = transport.get_pipe_transport(0) + if stdin_transport is not None: + self.stdin = streams.StreamWriter(stdin_transport, + protocol=self, + reader=None, + loop=self._loop) + + if not self.waiter.cancelled(): + self.waiter.set_result(None) + + def pipe_data_received(self, fd, data): + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader is not None: + reader.feed_data(data) + + def pipe_connection_lost(self, fd, exc): + if fd == 0: + pipe = self.stdin + if pipe is not None: + pipe.close() + self.connection_lost(exc) + return + if fd == 1: + reader = self.stdout + elif fd == 2: + reader = self.stderr + else: + reader = None + if reader != None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + + def process_exited(self): + returncode = self._transport.get_returncode() + self._transport.close() + self._transport = None + + # wake up futures waiting for wait() + while self._waiters: + waiter = self._waiters.popleft() + if not waiter.cancelled(): + waiter.set_result(returncode) + + +class Process: + def __init__(self, transport, protocol, loop): + self._transport = transport + self._protocol = protocol + self._loop = loop + self.stdin = protocol.stdin + self.stdout = protocol.stdout + self.stderr = protocol.stderr + self.pid = transport.get_pid() + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.pid) + + @property + def returncode(self): + return self._transport.get_returncode() + + @coroutine + def wait(self): + """Wait until the process exit and return the process return code.""" + returncode = self._transport.get_returncode() + if returncode is not None: + return returncode + + waiter = futures.Future(loop=self._loop) + self._protocol._waiters.append(waiter) + yield from waiter + return waiter.result() + + def _check_alive(self): + if self._transport.get_returncode() is not None: + raise ProcessLookupError() + + def send_signal(self, signal): + self._check_alive() + self._transport.send_signal(signal) + + def terminate(self): + self._check_alive() + self._transport.terminate() + + def kill(self): + self._check_alive() + self._transport.kill() + + @coroutine + def _feed_stdin(self, input): + debug = self._loop.get_debug() + self.stdin.write(input) + if debug: + logger.debug('%r communicate: feed stdin (%s bytes)', + self, len(input)) + try: + yield from self.stdin.drain() + except (BrokenPipeError, ConnectionResetError) as exc: + # communicate() ignores BrokenPipeError and ConnectionResetError + if debug: + logger.debug('%r communicate: stdin got %r', self, exc) + + if debug: + logger.debug('%r communicate: close stdin', self) + self.stdin.close() + + @coroutine + def _noop(self): + return None + + @coroutine + def _read_stream(self, fd): + transport = self._transport.get_pipe_transport(fd) + if fd == 2: + stream = self.stderr + else: + assert fd == 1 + stream = self.stdout + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: read %s', self, name) + output = yield from stream.read() + if self._loop.get_debug(): + name = 'stdout' if fd == 1 else 'stderr' + logger.debug('%r communicate: close %s', self, name) + transport.close() + return output + + @coroutine + def communicate(self, input=None): + if input: + stdin = self._feed_stdin(input) + else: + stdin = self._noop() + if self.stdout is not None: + stdout = self._read_stream(1) + else: + stdout = self._noop() + if self.stderr is not None: + stderr = self._read_stream(2) + else: + stderr = self._noop() + stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, + loop=self._loop) + yield from self.wait() + return (stdout, stderr) + + +@coroutine +def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, + loop=None, limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_shell( + protocol_factory, + cmd, stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + try: + yield from protocol.waiter + except: + transport._kill_wait() + raise + return Process(transport, protocol, loop) + +@coroutine +def create_subprocess_exec(program, *args, stdin=None, stdout=None, + stderr=None, loop=None, + limit=streams._DEFAULT_LIMIT, **kwds): + if loop is None: + loop = events.get_event_loop() + protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, + loop=loop) + transport, protocol = yield from loop.subprocess_exec( + protocol_factory, + program, *args, + stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) + try: + yield from protocol.waiter + except: + transport._kill_wait() + raise + return Process(transport, protocol, loop) diff --git a/asyncio/tasks.py b/asyncio/tasks.py new file mode 100644 index 00000000..63412a97 --- /dev/null +++ b/asyncio/tasks.py @@ -0,0 +1,667 @@ +"""Support for tasks, coroutines and the scheduler.""" + +__all__ = ['Task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', 'async', + 'gather', 'shield', + ] + +import concurrent.futures +import functools +import inspect +import linecache +import sys +import traceback +import weakref + +from . import coroutines +from . import events +from . import futures +from .coroutines import coroutine + +_PY34 = (sys.version_info >= (3, 4)) + + +class Task(futures.Future): + """A coroutine wrapped in a Future.""" + + # An important invariant maintained while a Task not done: + # + # - Either _fut_waiter is None, and _step() is scheduled; + # - or _fut_waiter is some Future, and _step() is *not* scheduled. + # + # The only transition from the latter to the former is through + # _wakeup(). When _fut_waiter is not None, one of its callbacks + # must be _wakeup(). + + # Weak set containing all tasks alive. + _all_tasks = weakref.WeakSet() + + # Dictionary containing tasks that are currently active in + # all running event loops. {EventLoop: Task} + _current_tasks = {} + + # If False, don't log a message if the task is destroyed whereas its + # status is still pending + _log_destroy_pending = True + + @classmethod + def current_task(cls, loop=None): + """Return the currently running task in an event loop or None. + + By default the current task for the current event loop is returned. + + None is returned when called not in the context of a Task. + """ + if loop is None: + loop = events.get_event_loop() + return cls._current_tasks.get(loop) + + @classmethod + def all_tasks(cls, loop=None): + """Return a set of all tasks for an event loop. + + By default all tasks for the current event loop are returned. + """ + if loop is None: + loop = events.get_event_loop() + return {t for t in cls._all_tasks if t._loop is loop} + + def __init__(self, coro, *, loop=None): + assert coroutines.iscoroutine(coro), repr(coro) + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._coro = iter(coro) # Use the iterator just in case. + self._fut_waiter = None + self._must_cancel = False + self._loop.call_soon(self._step) + self.__class__._all_tasks.add(self) + + # On Python 3.3 or older, objects with a destructor that are part of a + # reference cycle are never destroyed. That's not the case any more on + # Python 3.4 thanks to the PEP 442. + if _PY34: + def __del__(self): + if self._state == futures._PENDING and self._log_destroy_pending: + context = { + 'task': self, + 'message': 'Task was destroyed but it is pending!', + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + futures.Future.__del__(self) + + def _repr_info(self): + info = super()._repr_info() + + if self._must_cancel: + # replace status + info[0] = 'cancelling' + + coro = coroutines._format_coroutine(self._coro) + info.insert(1, 'coro=<%s>' % coro) + + if self._fut_waiter is not None: + info.insert(2, 'wait_for=%r' % self._fut_waiter) + return info + + def get_stack(self, *, limit=None): + """Return the list of stack frames for this task's coroutine. + + If the coroutine is not done, this returns the stack where it is + suspended. If the coroutine has completed successfully or was + cancelled, this returns an empty list. If the coroutine was + terminated by an exception, this returns the list of traceback + frames. + + The frames are always ordered from oldest to newest. + + The optional limit gives the maximum number of frames to + return; by default all available frames are returned. Its + meaning differs depending on whether a stack or a traceback is + returned: the newest frames of a stack are returned, but the + oldest frames of a traceback are returned. (This matches the + behavior of the traceback module.) + + For reasons beyond our control, only one stack frame is + returned for a suspended coroutine. + """ + frames = [] + f = self._coro.gi_frame + if f is not None: + while f is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(f) + f = f.f_back + frames.reverse() + elif self._exception is not None: + tb = self._exception.__traceback__ + while tb is not None: + if limit is not None: + if limit <= 0: + break + limit -= 1 + frames.append(tb.tb_frame) + tb = tb.tb_next + return frames + + def print_stack(self, *, limit=None, file=None): + """Print the stack or traceback for this task's coroutine. + + This produces output similar to that of the traceback module, + for the frames retrieved by get_stack(). The limit argument + is passed to get_stack(). The file argument is an I/O stream + to which the output is written; by default output is written + to sys.stderr. + """ + extracted_list = [] + checked = set() + for f in self.get_stack(limit=limit): + lineno = f.f_lineno + co = f.f_code + filename = co.co_filename + name = co.co_name + if filename not in checked: + checked.add(filename) + linecache.checkcache(filename) + line = linecache.getline(filename, lineno, f.f_globals) + extracted_list.append((filename, lineno, name, line)) + exc = self._exception + if not extracted_list: + print('No stack for %r' % self, file=file) + elif exc is not None: + print('Traceback for %r (most recent call last):' % self, + file=file) + else: + print('Stack for %r (most recent call last):' % self, + file=file) + traceback.print_list(extracted_list, file=file) + if exc is not None: + for line in traceback.format_exception_only(exc.__class__, exc): + print(line, file=file, end='') + + def cancel(self): + """Request that this task cancel itself. + + This arranges for a CancelledError to be thrown into the + wrapped coroutine on the next cycle through the event loop. + The coroutine then has a chance to clean up or even deny + the request using try/except/finally. + + Unlike Future.cancel, this does not guarantee that the + task will be cancelled: the exception might be caught and + acted upon, delaying cancellation of the task or preventing + cancellation completely. The task may also return a value or + raise a different exception. + + Immediately after this method is called, Task.cancelled() will + not return True (unless the task was already cancelled). A + task will be marked as cancelled when the wrapped coroutine + terminates with a CancelledError exception (even if cancel() + was not called). + """ + if self.done(): + return False + if self._fut_waiter is not None: + if self._fut_waiter.cancel(): + # Leave self._fut_waiter; it may be a Task that + # catches and ignores the cancellation so we may have + # to cancel it again later. + return True + # It must be the case that self._step is already scheduled. + self._must_cancel = True + return True + + def _step(self, value=None, exc=None): + assert not self.done(), \ + '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + if self._must_cancel: + if not isinstance(exc, futures.CancelledError): + exc = futures.CancelledError() + self._must_cancel = False + coro = self._coro + self._fut_waiter = None + + self.__class__._current_tasks[self._loop] = self + # Call either coro.throw(exc) or coro.send(value). + try: + if exc is not None: + result = coro.throw(exc) + elif value is not None: + result = coro.send(value) + else: + result = next(coro) + except StopIteration as exc: + self.set_result(exc.value) + except futures.CancelledError as exc: + super().cancel() # I.e., Future.cancel(self). + except Exception as exc: + self.set_exception(exc) + except BaseException as exc: + self.set_exception(exc) + raise + else: + if isinstance(result, futures.Future): + # Yielded Future must come from Future.__iter__(). + if result._blocking: + result._blocking = False + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False + else: + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from ' + 'in task {!r} with {!r}'.format(self, result))) + elif result is None: + # Bare yield relinquishes control for one event loop iteration. + self._loop.call_soon(self._step) + elif inspect.isgenerator(result): + # Yielding a generator is just wrong. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'yield was used instead of yield from for ' + 'generator in task {!r} with {}'.format( + self, result))) + else: + # Yielding something else is an error. + self._loop.call_soon( + self._step, None, + RuntimeError( + 'Task got bad yield: {!r}'.format(result))) + finally: + self.__class__._current_tasks.pop(self._loop) + self = None # Needed to break cycles when an exception occurs. + + def _wakeup(self, future): + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) + self = None # Needed to break cycles when an exception occurs. + + +# wait() and as_completed() similar to those in PEP 3148. + +FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED +FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION +ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + +@coroutine +def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): + """Wait for the Futures and coroutines given by fs to complete. + + The sequence futures must not be empty. + + Coroutines will be wrapped in Tasks. + + Returns two sets of Future: (done, pending). + + Usage: + + done, pending = yield from asyncio.wait(fs) + + Note: This does not raise TimeoutError! Futures that aren't done + when the timeout occurs are returned in the second set. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + if not fs: + raise ValueError('Set of coroutines/Futures is empty.') + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError('Invalid return_when value: {}'.format(return_when)) + + if loop is None: + loop = events.get_event_loop() + + fs = {async(f, loop=loop) for f in set(fs)} + + return (yield from _wait(fs, timeout, return_when, loop)) + + +def _release_waiter(waiter, *args): + if not waiter.done(): + waiter.set_result(None) + + +@coroutine +def wait_for(fut, timeout, *, loop=None): + """Wait for the single Future or coroutine to complete, with timeout. + + Coroutine will be wrapped in Task. + + Returns result of the Future or coroutine. When a timeout occurs, + it cancels the task and raises TimeoutError. To avoid the task + cancellation, wrap it in shield(). + + If the wait is cancelled, the task is also cancelled. + + This function is a coroutine. + """ + if loop is None: + loop = events.get_event_loop() + + if timeout is None: + return (yield from fut) + + waiter = futures.Future(loop=loop) + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + cb = functools.partial(_release_waiter, waiter) + + fut = async(fut, loop=loop) + fut.add_done_callback(cb) + + try: + # wait until the future completes or the timeout + try: + yield from waiter + except futures.CancelledError: + fut.remove_done_callback(cb) + fut.cancel() + raise + + if fut.done(): + return fut.result() + else: + fut.remove_done_callback(cb) + fut.cancel() + raise futures.TimeoutError() + finally: + timeout_handle.cancel() + + +@coroutine +def _wait(fs, timeout, return_when, loop): + """Internal helper for wait() and _wait_for(). + + The fs argument must be a collection of Futures. + """ + assert fs, 'Set of Futures is empty.' + waiter = futures.Future(loop=loop) + timeout_handle = None + if timeout is not None: + timeout_handle = loop.call_later(timeout, _release_waiter, waiter) + counter = len(fs) + + def _on_completion(f): + nonlocal counter + counter -= 1 + if (counter <= 0 or + return_when == FIRST_COMPLETED or + return_when == FIRST_EXCEPTION and (not f.cancelled() and + f.exception() is not None)): + if timeout_handle is not None: + timeout_handle.cancel() + if not waiter.done(): + waiter.set_result(None) + + for f in fs: + f.add_done_callback(_on_completion) + + try: + yield from waiter + finally: + if timeout_handle is not None: + timeout_handle.cancel() + + done, pending = set(), set() + for f in fs: + f.remove_done_callback(_on_completion) + if f.done(): + done.add(f) + else: + pending.add(f) + return done, pending + + +# This is *not* a @coroutine! It is just an iterator (yielding Futures). +def as_completed(fs, *, loop=None, timeout=None): + """Return an iterator whose values are coroutines. + + When waiting for the yielded coroutines you'll get the results (or + exceptions!) of the original Futures (or coroutines), in the order + in which and as soon as they complete. + + This differs from PEP 3148; the proper way to use this is: + + for f in as_completed(fs): + result = yield from f # The 'yield from' may raise. + # Use result. + + If a timeout is specified, the 'yield from' will raise + TimeoutError when the timeout occurs before all Futures are done. + + Note: The futures 'f' are not necessarily members of fs. + """ + if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + loop = loop if loop is not None else events.get_event_loop() + todo = {async(f, loop=loop) for f in set(fs)} + from .queues import Queue # Import here to avoid circular import problem. + done = Queue(loop=loop) + timeout_handle = None + + def _on_timeout(): + for f in todo: + f.remove_done_callback(_on_completion) + done.put_nowait(None) # Queue a dummy value for _wait_for_one(). + todo.clear() # Can't do todo.remove(f) in the loop. + + def _on_completion(f): + if not todo: + return # _on_timeout() was here first. + todo.remove(f) + done.put_nowait(f) + if not todo and timeout_handle is not None: + timeout_handle.cancel() + + @coroutine + def _wait_for_one(): + f = yield from done.get() + if f is None: + # Dummy value from _on_timeout(). + raise futures.TimeoutError + return f.result() # May raise f.exception(). + + for f in todo: + f.add_done_callback(_on_completion) + if todo and timeout is not None: + timeout_handle = loop.call_later(timeout, _on_timeout) + for _ in range(len(todo)): + yield _wait_for_one() + + +@coroutine +def sleep(delay, result=None, *, loop=None): + """Coroutine that completes after a given time (in seconds).""" + future = futures.Future(loop=loop) + h = future._loop.call_later(delay, + future._set_result_unless_cancelled, result) + try: + return (yield from future) + finally: + h.cancel() + + +def async(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + + If the argument is a Future, it is returned directly. + """ + if isinstance(coro_or_future, futures.Future): + if loop is not None and loop is not coro_or_future._loop: + raise ValueError('loop argument must agree with Future') + return coro_or_future + elif coroutines.iscoroutine(coro_or_future): + if loop is None: + loop = events.get_event_loop() + task = loop.create_task(coro_or_future) + if task._source_traceback: + del task._source_traceback[-1] + return task + else: + raise TypeError('A Future or coroutine is required') + + +class _GatheringFuture(futures.Future): + """Helper for gather(). + + This overrides cancel() to cancel all the children and act more + like Task.cancel(), which doesn't immediately mark itself as + cancelled. + """ + + def __init__(self, children, *, loop=None): + super().__init__(loop=loop) + self._children = children + + def cancel(self): + if self.done(): + return False + for child in self._children: + child.cancel() + return True + + +def gather(*coros_or_futures, loop=None, return_exceptions=False): + """Return a future aggregating results from the given coroutines + or futures. + + All futures must share the same event loop. If all the tasks are + done successfully, the returned future's result is the list of + results (in the order of the original sequence, not necessarily + the order of results arrival). If *return_exceptions* is True, + exceptions in the tasks are treated the same as successful + results, and gathered in the result list; otherwise, the first + raised exception will be immediately propagated to the returned + future. + + Cancellation: if the outer Future is cancelled, all children (that + have not completed yet) are also cancelled. If any child is + cancelled, this is treated as if it raised CancelledError -- + the outer Future is *not* cancelled in this case. (This is to + prevent the cancellation of one child to cause other children to + be cancelled.) + """ + if not coros_or_futures: + outer = futures.Future(loop=loop) + outer.set_result([]) + return outer + + arg_to_fut = {} + for arg in set(coros_or_futures): + if not isinstance(arg, futures.Future): + fut = async(arg, loop=loop) + if loop is None: + loop = fut._loop + # The caller cannot control this future, the "destroy pending task" + # warning should not be emitted. + fut._log_destroy_pending = False + else: + fut = arg + if loop is None: + loop = fut._loop + elif fut._loop is not loop: + raise ValueError("futures are tied to different event loops") + arg_to_fut[arg] = fut + + children = [arg_to_fut[arg] for arg in coros_or_futures] + nchildren = len(children) + outer = _GatheringFuture(children, loop=loop) + nfinished = 0 + results = [None] * nchildren + + def _done_callback(i, fut): + nonlocal nfinished + if outer.done(): + if not fut.cancelled(): + # Mark exception retrieved. + fut.exception() + return + + if fut._state == futures._CANCELLED: + res = futures.CancelledError() + if not return_exceptions: + outer.set_exception(res) + return + elif fut._exception is not None: + res = fut.exception() # Mark exception retrieved. + if not return_exceptions: + outer.set_exception(res) + return + else: + res = fut._result + results[i] = res + nfinished += 1 + if nfinished == nchildren: + outer.set_result(results) + + for i, fut in enumerate(children): + fut.add_done_callback(functools.partial(_done_callback, i)) + return outer + + +def shield(arg, *, loop=None): + """Wait for a future, shielding it from cancellation. + + The statement + + res = yield from shield(something()) + + is exactly equivalent to the statement + + res = yield from something() + + *except* that if the coroutine containing it is cancelled, the + task running in something() is not cancelled. From the POV of + something(), the cancellation did not happen. But its caller is + still cancelled, so the yield-from expression still raises + CancelledError. Note: If something() is cancelled by other means + this will still cancel shield(). + + If you want to completely ignore cancellation (not recommended) + you can combine shield() with a try/except clause, as follows: + + try: + res = yield from shield(something()) + except CancelledError: + res = None + """ + inner = async(arg, loop=loop) + if inner.done(): + # Shortcut. + return inner + loop = inner._loop + outer = futures.Future(loop=loop) + + def _done_callback(inner): + if outer.cancelled(): + if not inner.cancelled(): + # Mark inner's result as retrieved. + inner.exception() + return + + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + inner.add_done_callback(_done_callback) + return outer diff --git a/asyncio/test_support.py b/asyncio/test_support.py new file mode 100644 index 00000000..3da47558 --- /dev/null +++ b/asyncio/test_support.py @@ -0,0 +1,305 @@ +# Subset of test.support from CPython 3.5, just what we need to run asyncio +# test suite. The code is copied from CPython 3.5 to not depend on the test +# module because it is rarely installed. + +# Ignore symbol TEST_HOME_DIR: test_events works without it + +import functools +import gc +import os +import platform +import re +import socket +import subprocess +import sys +import time + + +# A constant likely larger than the underlying OS pipe buffer size, to +# make writes blocking. +# Windows limit seems to be around 512 B, and many Unix kernels have a +# 64 KiB pipe buffer size or 16 * PAGE_SIZE: take a few megs to be sure. +# (see issue #17835 for a discussion of this number). +PIPE_MAX_SIZE = 4 * 1024 * 1024 + 1 + +def strip_python_stderr(stderr): + """Strip the stderr of a Python process from potential debug output + emitted by the interpreter. + + This will typically be run on the result of the communicate() method + of a subprocess.Popen object. + """ + stderr = re.sub(br"\[\d+ refs, \d+ blocks\]\r?\n?", b"", stderr).strip() + return stderr + + +# Executing the interpreter in a subprocess +def _assert_python(expected_success, *args, **env_vars): + if '__isolated' in env_vars: + isolated = env_vars.pop('__isolated') + else: + isolated = not env_vars + cmd_line = [sys.executable, '-X', 'faulthandler'] + if isolated and sys.version_info >= (3, 4): + # isolated mode: ignore Python environment variables, ignore user + # site-packages, and don't add the current directory to sys.path + cmd_line.append('-I') + elif not env_vars: + # ignore Python environment variables + cmd_line.append('-E') + # Need to preserve the original environment, for in-place testing of + # shared library builds. + env = os.environ.copy() + # But a special flag that can be set to override -- in this case, the + # caller is responsible to pass the full environment. + if env_vars.pop('__cleanenv', None): + env = {} + env.update(env_vars) + cmd_line.extend(args) + p = subprocess.Popen(cmd_line, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + env=env) + try: + out, err = p.communicate() + finally: + subprocess._cleanup() + p.stdout.close() + p.stderr.close() + rc = p.returncode + err = strip_python_stderr(err) + if (rc and expected_success) or (not rc and not expected_success): + raise AssertionError( + "Process return code is %d, " + "stderr follows:\n%s" % (rc, err.decode('ascii', 'ignore'))) + return rc, out, err + + +def assert_python_ok(*args, **env_vars): + """ + Assert that running the interpreter with `args` and optional environment + variables `env_vars` succeeds (rc == 0) and return a (return code, stdout, + stderr) tuple. + + If the __cleanenv keyword is set, env_vars is used a fresh environment. + + Python is started in isolated mode (command line option -I), + except if the __isolated keyword is set to False. + """ + return _assert_python(True, *args, **env_vars) + + +is_jython = sys.platform.startswith('java') + +def gc_collect(): + """Force as many objects as possible to be collected. + + In non-CPython implementations of Python, this is needed because timely + deallocation is not guaranteed by the garbage collector. (Even in CPython + this can be the case in case of reference cycles.) This means that __del__ + methods may be called later than expected and weakrefs may remain alive for + longer than expected. This function tries its best to force all garbage + objects to disappear. + """ + gc.collect() + if is_jython: + time.sleep(0.1) + gc.collect() + gc.collect() + + +HOST = "127.0.0.1" +HOSTv6 = "::1" + + +def _is_ipv6_enabled(): + """Check whether IPv6 is enabled on this host.""" + if socket.has_ipv6: + sock = None + try: + sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + sock.bind((HOSTv6, 0)) + return True + except OSError: + pass + finally: + if sock: + sock.close() + return False + +IPV6_ENABLED = _is_ipv6_enabled() + + +def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM): + """Returns an unused port that should be suitable for binding. This is + achieved by creating a temporary socket with the same family and type as + the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to + the specified host address (defaults to 0.0.0.0) with the port set to 0, + eliciting an unused ephemeral port from the OS. The temporary socket is + then closed and deleted, and the ephemeral port is returned. + + Either this method or bind_port() should be used for any tests where a + server socket needs to be bound to a particular port for the duration of + the test. Which one to use depends on whether the calling code is creating + a python socket, or if an unused port needs to be provided in a constructor + or passed to an external program (i.e. the -accept argument to openssl's + s_server mode). Always prefer bind_port() over find_unused_port() where + possible. Hard coded ports should *NEVER* be used. As soon as a server + socket is bound to a hard coded port, the ability to run multiple instances + of the test simultaneously on the same host is compromised, which makes the + test a ticking time bomb in a buildbot environment. On Unix buildbots, this + may simply manifest as a failed test, which can be recovered from without + intervention in most cases, but on Windows, the entire python process can + completely and utterly wedge, requiring someone to log in to the buildbot + and manually kill the affected process. + + (This is easy to reproduce on Windows, unfortunately, and can be traced to + the SO_REUSEADDR socket option having different semantics on Windows versus + Unix/Linux. On Unix, you can't have two AF_INET SOCK_STREAM sockets bind, + listen and then accept connections on identical host/ports. An EADDRINUSE + OSError will be raised at some point (depending on the platform and + the order bind and listen were called on each socket). + + However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE + will ever be raised when attempting to bind two identical host/ports. When + accept() is called on each socket, the second caller's process will steal + the port from the first caller, leaving them both in an awkwardly wedged + state where they'll no longer respond to any signals or graceful kills, and + must be forcibly killed via OpenProcess()/TerminateProcess(). + + The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option + instead of SO_REUSEADDR, which effectively affords the same semantics as + SO_REUSEADDR on Unix. Given the propensity of Unix developers in the Open + Source world compared to Windows ones, this is a common mistake. A quick + look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when + openssl.exe is called with the 's_server' option, for example. See + http://bugs.python.org/issue2550 for more info. The following site also + has a very thorough description about the implications of both REUSEADDR + and EXCLUSIVEADDRUSE on Windows: + http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx) + + XXX: although this approach is a vast improvement on previous attempts to + elicit unused ports, it rests heavily on the assumption that the ephemeral + port returned to us by the OS won't immediately be dished back out to some + other process when we close and delete our temporary socket but before our + calling code has a chance to bind the returned port. We can deal with this + issue if/when we come across it. + """ + + tempsock = socket.socket(family, socktype) + port = bind_port(tempsock) + tempsock.close() + del tempsock + return port + +def bind_port(sock, host=HOST): + """Bind the socket to a free port and return the port number. Relies on + ephemeral ports in order to ensure we are using an unbound port. This is + important as many tests may be running simultaneously, especially in a + buildbot environment. This method raises an exception if the sock.family + is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR + or SO_REUSEPORT set on it. Tests should *never* set these socket options + for TCP/IP sockets. The only case for setting these options is testing + multicasting via multiple UDP sockets. + + Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. + on Windows), it will be set on the socket. This will prevent anyone else + from bind()'ing to our host/port for the duration of the test. + """ + + if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: + if hasattr(socket, 'SO_REUSEADDR'): + if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: + raise TestFailed("tests should never set the SO_REUSEADDR " + "socket option on TCP/IP sockets!") + if hasattr(socket, 'SO_REUSEPORT'): + try: + reuse = sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) + if reuse == 1: + raise TestFailed("tests should never set the SO_REUSEPORT " + "socket option on TCP/IP sockets!") + except OSError: + # Python's socket module was compiled using modern headers + # thus defining SO_REUSEPORT but this process is running + # under an older kernel that does not support SO_REUSEPORT. + pass + if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) + + sock.bind((host, 0)) + port = sock.getsockname()[1] + return port + +def requires_mac_ver(*min_version): + """Decorator raising SkipTest if the OS is Mac OS X and the OS X + version if less than min_version. + + For example, @requires_mac_ver(10, 5) raises SkipTest if the OS X version + is lesser than 10.5. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if sys.platform == 'darwin': + version_txt = platform.mac_ver()[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "Mac OS X %s or higher required, not %s" + % (min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def _requires_unix_version(sysname, min_version): + """Decorator raising SkipTest if the OS is `sysname` and the version is + less than `min_version`. + + For example, @_requires_unix_version('FreeBSD', (7, 2)) raises SkipTest if + the FreeBSD version is less than 7.2. + """ + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kw): + if platform.system() == sysname: + version_txt = platform.release().split('-', 1)[0] + try: + version = tuple(map(int, version_txt.split('.'))) + except ValueError: + pass + else: + if version < min_version: + min_version_txt = '.'.join(map(str, min_version)) + raise unittest.SkipTest( + "%s version %s or higher required, not %s" + % (sysname, min_version_txt, version_txt)) + return func(*args, **kw) + wrapper.min_version = min_version + return wrapper + return decorator + +def requires_freebsd_version(*min_version): + """Decorator raising SkipTest if the OS is FreeBSD and the FreeBSD version + is less than `min_version`. + + For example, @requires_freebsd_version(7, 2) raises SkipTest if the FreeBSD + version is less than 7.2. + """ + return _requires_unix_version('FreeBSD', min_version) + +# Use test.support if available +try: + from test.support import * +except ImportError: + pass + +# Use test.script_helper if available +try: + from test.script_helper import assert_python_ok +except ImportError: + pass diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py new file mode 100644 index 00000000..6eedc583 --- /dev/null +++ b/asyncio/test_utils.py @@ -0,0 +1,442 @@ +"""Utilities shared by tests.""" + +import collections +import contextlib +import io +import logging +import os +import re +import socket +import socketserver +import sys +import tempfile +import threading +import time +import unittest +from unittest import mock + +from http.server import HTTPServer +from wsgiref.simple_server import WSGIRequestHandler, WSGIServer + +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import base_events +from . import events +from . import futures +from . import selectors +from . import tasks +from .coroutines import coroutine +from .log import logger + + +if sys.platform == 'win32': # pragma: no cover + from .windows_utils import socketpair +else: + from socket import socketpair # pragma: no cover + + +def dummy_ssl_context(): + if ssl is None: + return None + else: + return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + +def run_briefly(loop): + @coroutine + def once(): + pass + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() + + +def run_until(loop, pred, timeout=30): + deadline = time.time() + timeout + while not pred(): + if timeout is not None: + timeout = deadline - time.time() + if timeout <= 0: + raise futures.TimeoutError() + loop.run_until_complete(tasks.sleep(0.001, loop=loop)) + + +def run_once(loop): + """loop.stop() schedules _raise_stop_error() + and run_forever() runs until _raise_stop_error() callback. + this wont work if test waits for some IO events, because + _raise_stop_error() runs before any of io events callbacks. + """ + loop.stop() + loop.run_forever() + + +class SilentWSGIRequestHandler(WSGIRequestHandler): + + def get_stderr(self): + return io.StringIO() + + def log_message(self, format, *args): + pass + + +class SilentWSGIServer(WSGIServer): + + request_timeout = 2 + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + return request, client_addr + + def handle_error(self, request, client_address): + pass + + +class SSLWSGIServerMixin: + + def finish_request(self, request, client_address): + # The relative location of our test directory (which + # contains the ssl key and certificate files) differs + # between the stdlib and stand-alone asyncio. + # Prefer our own if we can find it. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + if not os.path.isdir(here): + here = os.path.join(os.path.dirname(os.__file__), + 'test', 'test_asyncio') + keyfile = os.path.join(here, 'ssl_key.pem') + certfile = os.path.join(here, 'ssl_cert.pem') + ssock = ssl.wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) + try: + self.RequestHandlerClass(ssock, client_address, self) + ssock.close() + except OSError: + # maybe socket has been closed by peer + pass + + +class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): + pass + + +def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): + + def app(environ, start_response): + status = '200 OK' + headers = [('Content-type', 'text/plain')] + start_response(status, headers) + return [b'Test message'] + + # Run the test WSGI server in a separate thread in order not to + # interfere with event handling in the main thread + server_class = server_ssl_cls if use_ssl else server_cls + httpd = server_class(address, SilentWSGIRequestHandler) + httpd.set_app(app) + httpd.address = httpd.server_address + server_thread = threading.Thread( + target=lambda: httpd.serve_forever(poll_interval=0.05)) + server_thread.start() + try: + yield httpd + finally: + httpd.shutdown() + httpd.server_close() + server_thread.join() + + +if hasattr(socket, 'AF_UNIX'): + + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + + def server_bind(self): + socketserver.UnixStreamServer.server_bind(self) + self.server_name = '127.0.0.1' + self.server_port = 80 + + + class UnixWSGIServer(UnixHTTPServer, WSGIServer): + + request_timeout = 2 + + def server_bind(self): + UnixHTTPServer.server_bind(self) + self.setup_environ() + + def get_request(self): + request, client_addr = super().get_request() + request.settimeout(self.request_timeout) + # Code in the stdlib expects that get_request + # will return a socket and a tuple (host, port). + # However, this isn't true for UNIX sockets, + # as the second return value will be a path; + # hence we return some fake data sufficient + # to get the tests going + return request, ('127.0.0.1', '') + + + class SilentUnixWSGIServer(UnixWSGIServer): + + def handle_error(self, request, client_address): + pass + + + class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): + pass + + + def gen_unix_socket_path(): + with tempfile.NamedTemporaryFile() as file: + return file.name + + + @contextlib.contextmanager + def unix_socket_path(): + path = gen_unix_socket_path() + try: + yield path + finally: + try: + os.unlink(path) + except OSError: + pass + + + @contextlib.contextmanager + def run_test_unix_server(*, use_ssl=False): + with unix_socket_path() as path: + yield from _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer) + + +@contextlib.contextmanager +def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): + yield from _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer) + + +def make_test_protocol(base): + dct = {} + for name in dir(base): + if name.startswith('__') and name.endswith('__'): + # skip magic names + continue + dct[name] = MockCallback(return_value=None) + return type('TestProtocol', (base,) + base.__bases__, dct)() + + +class TestSelector(selectors.BaseSelector): + + def __init__(self): + self.keys = {} + + def register(self, fileobj, events, data=None): + key = selectors.SelectorKey(fileobj, 0, events, data) + self.keys[fileobj] = key + return key + + def unregister(self, fileobj): + return self.keys.pop(fileobj) + + def select(self, timeout): + return [] + + def get_map(self): + return self.keys + + +class TestLoop(base_events.BaseEventLoop): + """Loop for unittests. + + It manages self time directly. + If something scheduled to be executed later then + on next loop iteration after all ready handlers done + generator passed to __init__ is calling. + + Generator should be like this: + + def gen(): + ... + when = yield ... + ... = yield time_advance + + Value returned by yield is absolute time of next scheduled handler. + Value passed to yield is time advance to move loop's time forward. + """ + + def __init__(self, gen=None): + super().__init__() + + if gen is None: + def gen(): + yield + self._check_on_close = False + else: + self._check_on_close = True + + self._gen = gen() + next(self._gen) + self._time = 0 + self._clock_resolution = 1e-9 + self._timers = [] + self._selector = TestSelector() + + self.readers = {} + self.writers = {} + self.reset_counters() + + def time(self): + return self._time + + def advance_time(self, advance): + """Move test time forward.""" + if advance: + self._time += advance + + def close(self): + super().close() + if self._check_on_close: + try: + self._gen.send(0) + except StopIteration: + pass + else: # pragma: no cover + raise AssertionError("Time generator is not finished") + + def add_reader(self, fd, callback, *args): + self.readers[fd] = events.Handle(callback, args, self) + + def remove_reader(self, fd): + self.remove_reader_count[fd] += 1 + if fd in self.readers: + del self.readers[fd] + return True + else: + return False + + def assert_reader(self, fd, callback, *args): + assert fd in self.readers, 'fd {} is not registered'.format(fd) + handle = self.readers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def add_writer(self, fd, callback, *args): + self.writers[fd] = events.Handle(callback, args, self) + + def remove_writer(self, fd): + self.remove_writer_count[fd] += 1 + if fd in self.writers: + del self.writers[fd] + return True + else: + return False + + def assert_writer(self, fd, callback, *args): + assert fd in self.writers, 'fd {} is not registered'.format(fd) + handle = self.writers[fd] + assert handle._callback == callback, '{!r} != {!r}'.format( + handle._callback, callback) + assert handle._args == args, '{!r} != {!r}'.format( + handle._args, args) + + def reset_counters(self): + self.remove_reader_count = collections.defaultdict(int) + self.remove_writer_count = collections.defaultdict(int) + + def _run_once(self): + super()._run_once() + for when in self._timers: + advance = self._gen.send(when) + self.advance_time(advance) + self._timers = [] + + def call_at(self, when, callback, *args): + self._timers.append(when) + return super().call_at(when, callback, *args) + + def _process_events(self, event_list): + return + + def _write_to_self(self): + pass + + +def MockCallback(**kwargs): + return mock.Mock(spec=['__call__'], **kwargs) + + +class MockPattern(str): + """A regex based str with a fuzzy __eq__. + + Use this helper with 'mock.assert_called_with', or anywhere + where a regex comparison between strings is needed. + + For instance: + mock_call.assert_called_with(MockPattern('spam.*ham')) + """ + def __eq__(self, other): + return bool(re.search(str(self), other, re.S)) + + +def get_function_source(func): + source = events._get_function_source(func) + if source is None: + raise ValueError("unable to get the source of %r" % (func,)) + return source + + +class TestCase(unittest.TestCase): + def set_event_loop(self, loop, *, cleanup=True): + assert loop is not None + # ensure that the event loop is passed explicitly in asyncio + events.set_event_loop(None) + if cleanup: + self.addCleanup(loop.close) + + def new_test_loop(self, gen=None): + loop = TestLoop(gen) + self.set_event_loop(loop) + return loop + + def tearDown(self): + events.set_event_loop(None) + + +@contextlib.contextmanager +def disable_logger(): + """Context manager to disable asyncio logger. + + For example, it can be used to ignore warnings in debug mode. + """ + old_level = logger.level + try: + logger.setLevel(logging.CRITICAL+1) + yield + finally: + logger.setLevel(old_level) + +def mock_nonblocking_socket(): + """Create a mock of a non-blocking socket.""" + sock = mock.Mock(socket.socket) + sock.gettimeout.return_value = 0.0 + return sock + + +def force_legacy_ssl_support(): + return mock.patch('asyncio.sslproto._is_sslproto_available', + return_value=False) diff --git a/asyncio/transports.py b/asyncio/transports.py new file mode 100644 index 00000000..22df3c7a --- /dev/null +++ b/asyncio/transports.py @@ -0,0 +1,300 @@ +"""Abstract Transport class.""" + +import sys + +_PY34 = sys.version_info >= (3, 4) + +__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', + ] + + +class BaseTransport: + """Base class for transports.""" + + def __init__(self, extra=None): + if extra is None: + extra = {} + self._extra = extra + + def get_extra_info(self, name, default=None): + """Get optional transport information.""" + return self._extra.get(name, default) + + def close(self): + """Close the transport. + + Buffered data will be flushed asynchronously. No more data + will be received. After all buffered data is flushed, the + protocol's connection_lost() method will (eventually) called + with None as its argument. + """ + raise NotImplementedError + + +class ReadTransport(BaseTransport): + """Interface for read-only transports.""" + + def pause_reading(self): + """Pause the receiving end. + + No data will be passed to the protocol's data_received() + method until resume_reading() is called. + """ + raise NotImplementedError + + def resume_reading(self): + """Resume the receiving end. + + Data received will once again be passed to the protocol's + data_received() method. + """ + raise NotImplementedError + + +class WriteTransport(BaseTransport): + """Interface for write-only transports.""" + + def set_write_buffer_limits(self, high=None, low=None): + """Set the high- and low-water limits for write flow control. + + These two values control when to call the protocol's + pause_writing() and resume_writing() methods. If specified, + the low-water limit must be less than or equal to the + high-water limit. Neither value can be negative. + + The defaults are implementation-specific. If only the + high-water limit is given, the low-water limit defaults to a + implementation-specific value less than or equal to the + high-water limit. Setting high to zero forces low to zero as + well, and causes pause_writing() to be called whenever the + buffer becomes non-empty. Setting low to zero causes + resume_writing() to be called only once the buffer is empty. + Use of zero for either limit is generally sub-optimal as it + reduces opportunities for doing I/O and computation + concurrently. + """ + raise NotImplementedError + + def get_write_buffer_size(self): + """Return the current size of the write buffer.""" + raise NotImplementedError + + def write(self, data): + """Write some data bytes to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + """ + raise NotImplementedError + + def writelines(self, list_of_data): + """Write a list (or any iterable) of data bytes to the transport. + + The default implementation concatenates the arguments and + calls write() on the result. + """ + if not _PY34: + # In Python 3.3, bytes.join() doesn't handle memoryview. + list_of_data = ( + bytes(data) if isinstance(data, memoryview) else data + for data in list_of_data) + self.write(b''.join(list_of_data)) + + def write_eof(self): + """Close the write end after flushing buffered data. + + (This is like typing ^D into a UNIX program reading from stdin.) + + Data may still be received. + """ + raise NotImplementedError + + def can_write_eof(self): + """Return True if this transport supports write_eof(), False if not.""" + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class Transport(ReadTransport, WriteTransport): + """Interface representing a bidirectional transport. + + There may be several implementations, but typically, the user does + not implement new transports; rather, the platform provides some + useful transports that are implemented using the platform's best + practices. + + The user never instantiates a transport directly; they call a + utility function, passing it a protocol factory and other + information necessary to create the transport and protocol. (E.g. + EventLoop.create_connection() or EventLoop.create_server().) + + The utility function will asynchronously create a transport and a + protocol and hook them up by calling the protocol's + connection_made() method, passing it the transport. + + The implementation here raises NotImplemented for every method + except writelines(), which calls write() in a loop. + """ + + +class DatagramTransport(BaseTransport): + """Interface for datagram (UDP) transports.""" + + def sendto(self, data, addr=None): + """Send data to the transport. + + This does not block; it buffers the data and arranges for it + to be sent out asynchronously. + addr is target socket address. + If addr is None use target address pointed on transport creation. + """ + raise NotImplementedError + + def abort(self): + """Close the transport immediately. + + Buffered data will be lost. No more data will be received. + The protocol's connection_lost() method will (eventually) be + called with None as its argument. + """ + raise NotImplementedError + + +class SubprocessTransport(BaseTransport): + + def get_pid(self): + """Get subprocess id.""" + raise NotImplementedError + + def get_returncode(self): + """Get subprocess returncode. + + See also + http://docs.python.org/3/library/subprocess#subprocess.Popen.returncode + """ + raise NotImplementedError + + def get_pipe_transport(self, fd): + """Get transport for pipe with number fd.""" + raise NotImplementedError + + def send_signal(self, signal): + """Send signal to subprocess. + + See also: + docs.python.org/3/library/subprocess#subprocess.Popen.send_signal + """ + raise NotImplementedError + + def terminate(self): + """Stop the subprocess. + + Alias for close() method. + + On Posix OSs the method sends SIGTERM to the subprocess. + On Windows the Win32 API function TerminateProcess() + is called to stop the subprocess. + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.terminate + """ + raise NotImplementedError + + def kill(self): + """Kill the subprocess. + + On Posix OSs the function sends SIGKILL to the subprocess. + On Windows kill() is an alias for terminate(). + + See also: + http://docs.python.org/3/library/subprocess#subprocess.Popen.kill + """ + raise NotImplementedError + + +class _FlowControlMixin(Transport): + """All the logic for (write) flow control in a mix-in base class. + + The subclass must implement get_write_buffer_size(). It must call + _maybe_pause_protocol() whenever the write buffer size increases, + and _maybe_resume_protocol() whenever it decreases. It may also + override set_write_buffer_limits() (e.g. to specify different + defaults). + + The subclass constructor must call super().__init__(extra). This + will call set_write_buffer_limits(). + + The user may call set_write_buffer_limits() and + get_write_buffer_size(), and their protocol's pause_writing() and + resume_writing() may be called. + """ + + def __init__(self, extra=None, loop=None): + super().__init__(extra) + assert loop is not None + self._loop = loop + self._protocol_paused = False + self._set_write_buffer_limits() + + def _maybe_pause_protocol(self): + size = self.get_write_buffer_size() + if size <= self._high_water: + return + if not self._protocol_paused: + self._protocol_paused = True + try: + self._protocol.pause_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.pause_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def _maybe_resume_protocol(self): + if (self._protocol_paused and + self.get_write_buffer_size() <= self._low_water): + self._protocol_paused = False + try: + self._protocol.resume_writing() + except Exception as exc: + self._loop.call_exception_handler({ + 'message': 'protocol.resume_writing() failed', + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + + def get_write_buffer_limits(self): + return (self._low_water, self._high_water) + + def _set_write_buffer_limits(self, high=None, low=None): + if high is None: + if low is None: + high = 64*1024 + else: + high = 4*low + if low is None: + low = high // 4 + if not high >= low >= 0: + raise ValueError('high (%r) must be >= low (%r) must be >= 0' % + (high, low)) + self._high_water = high + self._low_water = low + + def set_write_buffer_limits(self, high=None, low=None): + self._set_write_buffer_limits(high=high, low=low) + self._maybe_pause_protocol() + + def get_write_buffer_size(self): + raise NotImplementedError diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py new file mode 100644 index 00000000..97f9addd --- /dev/null +++ b/asyncio/unix_events.py @@ -0,0 +1,961 @@ +"""Selector event loop for Unix with signal handling.""" + +import errno +import os +import signal +import socket +import stat +import subprocess +import sys +import threading + + +from . import base_events +from . import base_subprocess +from . import constants +from . import coroutines +from . import events +from . import selector_events +from . import selectors +from . import transports +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'DefaultEventLoopPolicy', + ] + +if sys.platform == 'win32': # pragma: no cover + raise ImportError('Signals are not really supported on Windows') + + +def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass + + +class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Unix event loop. + + Adds signal handling and UNIX Domain Socket support to SelectorEventLoop. + """ + + def __init__(self, selector=None): + super().__init__(selector) + self._signal_handlers = {} + + def _socketpair(self): + return socket.socketpair() + + def close(self): + super().close() + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) + + def add_signal_handler(self, sig, callback, *args): + """Add a handler for a signal. UNIX only. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if (coroutines.iscoroutine(callback) + or coroutines.iscoroutinefunction(callback)): + raise TypeError("coroutines cannot be used " + "with add_signal_handler()") + self._check_signal(sig) + self._check_closed() + try: + # set_wakeup_fd() raises ValueError if this is not the + # main thread. By calling it early we ensure that an + # event loop running in another thread cannot add a signal + # handler. + signal.set_wakeup_fd(self._csock.fileno()) + except (ValueError, OSError) as exc: + raise RuntimeError(str(exc)) + + handle = events.Handle(callback, args, self) + self._signal_handlers[sig] = handle + + try: + # Register a dummy signal handler to ask Python to write the signal + # number in the wakup file descriptor. _process_self_data() will + # read signal numbers from this file descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + + # Set SA_RESTART to limit EINTR occurrences. + signal.siginterrupt(sig, False) + except OSError as exc: + del self._signal_handlers[sig] + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as nexc: + logger.info('set_wakeup_fd(-1) failed: %s', nexc) + + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + def _handle_signal(self, sig): + """Internal helper that is the actual signal handler.""" + handle = self._signal_handlers.get(sig) + if handle is None: + return # Assume it's some race condition. + if handle._cancelled: + self.remove_signal_handler(sig) # Remove it properly. + else: + self._add_callback_signalsafe(handle) + + def remove_signal_handler(self, sig): + """Remove a handler for a signal. UNIX only. + + Return True if a signal handler was removed, False if not. + """ + self._check_signal(sig) + try: + del self._signal_handlers[sig] + except KeyError: + return False + + if sig == signal.SIGINT: + handler = signal.default_int_handler + else: + handler = signal.SIG_DFL + + try: + signal.signal(sig, handler) + except OSError as exc: + if exc.errno == errno.EINVAL: + raise RuntimeError('sig {} cannot be caught'.format(sig)) + else: + raise + + if not self._signal_handlers: + try: + signal.set_wakeup_fd(-1) + except (ValueError, OSError) as exc: + logger.info('set_wakeup_fd(-1) failed: %s', exc) + + return True + + def _check_signal(self, sig): + """Internal helper to validate a signal. + + Raise ValueError if the signal number is invalid or uncatchable. + Raise RuntimeError if there is a problem setting up the handler. + """ + if not isinstance(sig, int): + raise TypeError('sig must be an int, not {!r}'.format(sig)) + + if not (1 <= sig < signal.NSIG): + raise ValueError( + 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + + def _make_read_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixReadPipeTransport(self, pipe, protocol, waiter, extra) + + def _make_write_pipe_transport(self, pipe, protocol, waiter=None, + extra=None): + return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + with events.get_child_watcher() as watcher: + transp = _UnixSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=extra, **kwargs) + try: + yield from transp._post_init() + except: + transp.close() + raise + watcher.add_child_handler(transp.get_pid(), + self._child_watcher_callback, transp) + + return transp + + def _child_watcher_callback(self, pid, returncode, transp): + self.call_soon_threadsafe(transp._process_exited, returncode) + + @coroutine + def create_unix_connection(self, protocol_factory, path, *, + ssl=None, sock=None, + server_hostname=None): + assert server_hostname is None or isinstance(server_hostname, str) + if ssl: + if server_hostname is None: + raise ValueError( + 'you have to pass server_hostname when using ssl') + else: + if server_hostname is not None: + raise ValueError('server_hostname is only meaningful with ssl') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) + try: + sock.setblocking(False) + yield from self.sock_connect(sock, path) + except: + sock.close() + raise + + else: + if sock is None: + raise ValueError('no path and sock were specified') + sock.setblocking(False) + + transport, protocol = yield from self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname) + return transport, protocol + + @coroutine + def create_unix_server(self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None): + if isinstance(ssl, bool): + raise TypeError('ssl argument must be an SSLContext or None') + + if path is not None: + if sock is not None: + raise ValueError( + 'path and sock can not be specified at the same time') + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + + try: + sock.bind(path) + except OSError as exc: + sock.close() + if exc.errno == errno.EADDRINUSE: + # Let's improve the error message by adding + # with what exact address it occurs. + msg = 'Address {!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) from None + else: + raise + except: + sock.close() + raise + else: + if sock is None: + raise ValueError( + 'path was not specified, and no sock specified') + + if sock.family != socket.AF_UNIX: + raise ValueError( + 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + + server = base_events.Server(self, [sock]) + sock.listen(backlog) + sock.setblocking(False) + self._start_serving(protocol_factory, sock, ssl, server) + return server + + +if hasattr(os, 'set_blocking'): + def _set_nonblocking(fd): + os.set_blocking(fd, False) +else: + import fcntl + + def _set_nonblocking(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + flags = flags | os.O_NONBLOCK + fcntl.fcntl(fd, fcntl.F_SETFL, flags) + + +class _UnixReadPipeTransport(transports.ReadTransport): + + max_size = 256 * 1024 # max bytes we read in one event loop iteration + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra) + self._extra['pipe'] = pipe + self._loop = loop + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + if not (stat.S_ISFIFO(mode) or + stat.S_ISSOCK(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is for pipes/sockets only.") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._closing = False + self._loop.add_reader(self._fileno, self._read_ready) + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_READ) + if polling: + info.append('polling') + else: + info.append('idle') + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def _read_ready(self): + try: + data = os.read(self._fileno, self.max_size) + except (BlockingIOError, InterruptedError): + pass + except OSError as exc: + self._fatal_error(exc, 'Fatal read error on pipe transport') + else: + if data: + self._protocol.data_received(data) + else: + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._protocol.eof_received) + self._loop.call_soon(self._call_connection_lost, None) + + def pause_reading(self): + self._loop.remove_reader(self._fileno) + + def resume_reading(self): + self._loop.add_reader(self._fileno, self._read_ready) + + def close(self): + if not self._closing: + self._close(None) + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if (isinstance(exc, OSError) and exc.errno == errno.EIO): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc): + self._closing = True + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +class _UnixWritePipeTransport(transports._FlowControlMixin, + transports.WriteTransport): + + def __init__(self, loop, pipe, protocol, waiter=None, extra=None): + super().__init__(extra, loop) + self._extra['pipe'] = pipe + self._pipe = pipe + self._fileno = pipe.fileno() + mode = os.fstat(self._fileno).st_mode + is_socket = stat.S_ISSOCK(mode) + if not (is_socket or + stat.S_ISFIFO(mode) or + stat.S_ISCHR(mode)): + raise ValueError("Pipe transport is only for " + "pipes, sockets and character devices") + _set_nonblocking(self._fileno) + self._protocol = protocol + self._buffer = [] + self._conn_lost = 0 + self._closing = False # Set when close() or write_eof() called. + + # On AIX, the reader trick only works for sockets. + # On other platforms it works for pipes and sockets. + # (Exception: OS X 10.4? Issue #19294.) + if is_socket or not sys.platform.startswith("aix"): + self._loop.add_reader(self._fileno, self._read_ready) + + self._loop.call_soon(self._protocol.connection_made, self) + if waiter is not None: + # wait until protocol.connection_made() has been called + self._loop.call_soon(waiter._set_result_unless_cancelled, None) + + def __repr__(self): + info = [self.__class__.__name__] + if self._pipe is None: + info.append('closed') + elif self._closing: + info.append('closing') + info.append('fd=%s' % self._fileno) + if self._pipe is not None: + polling = selector_events._test_selector_event( + self._loop._selector, + self._fileno, selectors.EVENT_WRITE) + if polling: + info.append('polling') + else: + info.append('idle') + + bufsize = self.get_write_buffer_size() + info.append('bufsize=%s' % bufsize) + else: + info.append('closed') + return '<%s>' % ' '.join(info) + + def get_write_buffer_size(self): + return sum(len(data) for data in self._buffer) + + def _read_ready(self): + # Pipe was closed by peer. + if self._loop.get_debug(): + logger.info("%r was closed by peer", self) + if self._buffer: + self._close(BrokenPipeError()) + else: + self._close() + + def write(self, data): + assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) + if isinstance(data, bytearray): + data = memoryview(data) + if not data: + return + + if self._conn_lost or self._closing: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('pipe closed by peer or ' + 'os.write(pipe, data) raised exception.') + self._conn_lost += 1 + return + + if not self._buffer: + # Attempt to send it right away first. + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + n = 0 + except Exception as exc: + self._conn_lost += 1 + self._fatal_error(exc, 'Fatal write error on pipe transport') + return + if n == len(data): + return + elif n > 0: + data = data[n:] + self._loop.add_writer(self._fileno, self._write_ready) + + self._buffer.append(data) + self._maybe_pause_protocol() + + def _write_ready(self): + data = b''.join(self._buffer) + assert data, 'Data should not be empty' + + self._buffer.clear() + try: + n = os.write(self._fileno, data) + except (BlockingIOError, InterruptedError): + self._buffer.append(data) + except Exception as exc: + self._conn_lost += 1 + # Remove writer here, _fatal_error() doesn't it + # because _buffer is empty. + self._loop.remove_writer(self._fileno) + self._fatal_error(exc, 'Fatal write error on pipe transport') + else: + if n == len(data): + self._loop.remove_writer(self._fileno) + self._maybe_resume_protocol() # May append to buffer. + if not self._buffer and self._closing: + self._loop.remove_reader(self._fileno) + self._call_connection_lost(None) + return + elif n > 0: + data = data[n:] + + self._buffer.append(data) # Try again later. + + def can_write_eof(self): + return True + + def write_eof(self): + if self._closing: + return + assert self._pipe + self._closing = True + if not self._buffer: + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, None) + + def close(self): + if self._pipe is not None and not self._closing: + # write_eof is all what we needed to close the write pipe + self.write_eof() + + def abort(self): + self._close(None) + + def _fatal_error(self, exc, message='Fatal error on pipe transport'): + # should be called by exception handler only + if isinstance(exc, (BrokenPipeError, ConnectionResetError)): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + self._close(exc) + + def _close(self, exc=None): + self._closing = True + if self._buffer: + self._loop.remove_writer(self._fileno) + self._buffer.clear() + self._loop.remove_reader(self._fileno) + self._loop.call_soon(self._call_connection_lost, exc) + + def _call_connection_lost(self, exc): + try: + self._protocol.connection_lost(exc) + finally: + self._pipe.close() + self._pipe = None + self._protocol = None + self._loop = None + + +if hasattr(os, 'set_inheritable'): + # Python 3.4 and newer + _set_inheritable = os.set_inheritable +else: + import fcntl + + def _set_inheritable(fd, inheritable): + cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) + + old = fcntl.fcntl(fd, fcntl.F_GETFD) + if not inheritable: + fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) + else: + fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) + + +class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + stdin_w = None + if stdin == subprocess.PIPE: + # Use a socket pair for stdin, since not all platforms + # support selecting read events on the write end of a + # socket (which we use in order to detect closing of the + # other end). Notably this is needed on AIX, and works + # just fine on other platforms. + stdin, stdin_w = self._loop._socketpair() + + # Mark the write end of the stdin pipe as non-inheritable, + # needed by close_fds=False on Python 3.3 and older + # (Python 3.4 implements the PEP 446, socketpair returns + # non-inheritable sockets) + _set_inheritable(stdin_w.fileno(), False) + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + + +class AbstractChildWatcher: + """Abstract base class for monitoring child processes. + + Objects derived from this class monitor a collection of subprocesses and + report their termination or interruption by a signal. + + New callbacks are registered with .add_child_handler(). Starting a new + process must be done within a 'with' block to allow the watcher to suspend + its activity until the new process if fully registered (this is needed to + prevent a race condition in some implementations). + + Example: + with watcher: + proc = subprocess.Popen("sleep 1") + watcher.add_child_handler(proc.pid, callback) + + Notes: + Implementations of this class must be thread-safe. + + Since child watcher objects may catch the SIGCHLD signal and call + waitpid(-1), there should be only one active object per process. + """ + + def add_child_handler(self, pid, callback, *args): + """Register a new child handler. + + Arrange for callback(pid, returncode, *args) to be called when + process 'pid' terminates. Specifying another callback for the same + process replaces the previous handler. + + Note: callback() must be thread-safe. + """ + raise NotImplementedError() + + def remove_child_handler(self, pid): + """Removes the handler for process 'pid'. + + The function returns True if the handler was successfully removed, + False if there was nothing to remove.""" + + raise NotImplementedError() + + def attach_loop(self, loop): + """Attach the watcher to an event loop. + + If the watcher was previously attached to an event loop, then it is + first detached before attaching to the new loop. + + Note: loop may be None. + """ + raise NotImplementedError() + + def close(self): + """Close the watcher. + + This must be called to make sure that any underlying resource is freed. + """ + raise NotImplementedError() + + def __enter__(self): + """Enter the watcher's context and allow starting new processes + + This function must return self""" + raise NotImplementedError() + + def __exit__(self, a, b, c): + """Exit the watcher's context""" + raise NotImplementedError() + + +class BaseChildWatcher(AbstractChildWatcher): + + def __init__(self): + self._loop = None + + def close(self): + self.attach_loop(None) + + def _do_waitpid(self, expected_pid): + raise NotImplementedError() + + def _do_waitpid_all(self): + raise NotImplementedError() + + def attach_loop(self, loop): + assert loop is None or isinstance(loop, events.AbstractEventLoop) + + if self._loop is not None: + self._loop.remove_signal_handler(signal.SIGCHLD) + + self._loop = loop + if loop is not None: + loop.add_signal_handler(signal.SIGCHLD, self._sig_chld) + + # Prevent a race condition in case a child terminated + # during the switch. + self._do_waitpid_all() + + def _sig_chld(self): + try: + self._do_waitpid_all() + except Exception as exc: + # self._loop should always be available here + # as '_sig_chld' is added as a signal handler + # in 'attach_loop' + self._loop.call_exception_handler({ + 'message': 'Unknown exception in SIGCHLD handler', + 'exception': exc, + }) + + def _compute_returncode(self, status): + if os.WIFSIGNALED(status): + # The child process died because of a signal. + return -os.WTERMSIG(status) + elif os.WIFEXITED(status): + # The child process exited (e.g sys.exit()). + return os.WEXITSTATUS(status) + else: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status + + +class SafeChildWatcher(BaseChildWatcher): + """'Safe' child watcher implementation. + + This implementation avoids disrupting other code spawning processes by + polling explicitly each process in the SIGCHLD handler instead of calling + os.waitpid(-1). + + This is a safe solution but it has a significant overhead when handling a + big number of children (O(n) each time SIGCHLD is raised) + """ + + def __init__(self): + super().__init__() + self._callbacks = {} + + def close(self): + self._callbacks.clear() + super().close() + + def __enter__(self): + return self + + def __exit__(self, a, b, c): + pass + + def add_child_handler(self, pid, callback, *args): + self._callbacks[pid] = callback, args + + # Prevent a race condition in case the child is already terminated. + self._do_waitpid(pid) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + + for pid in list(self._callbacks): + self._do_waitpid(pid) + + def _do_waitpid(self, expected_pid): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, os.WNOHANG) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + if pid == 0: + # The child process is still alive. + return + + returncode = self._compute_returncode(status) + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + try: + callback, args = self._callbacks.pop(pid) + except KeyError: # pragma: no cover + # May happen if .remove_child_handler() is called + # after os.waitpid() returns. + if self._loop.get_debug(): + logger.warning("Child watcher got an unexpected pid: %r", + pid, exc_info=True) + else: + callback(pid, returncode, *args) + + +class FastChildWatcher(BaseChildWatcher): + """'Fast' child watcher implementation. + + This implementation reaps every terminated processes by calling + os.waitpid(-1) directly, possibly breaking other code spawning processes + and waiting for their termination. + + There is no noticeable overhead when handling a big number of children + (O(1) each time a child terminates). + """ + def __init__(self): + super().__init__() + self._callbacks = {} + self._lock = threading.Lock() + self._zombies = {} + self._forks = 0 + + def close(self): + self._callbacks.clear() + self._zombies.clear() + super().close() + + def __enter__(self): + with self._lock: + self._forks += 1 + + return self + + def __exit__(self, a, b, c): + with self._lock: + self._forks -= 1 + + if self._forks or not self._zombies: + return + + collateral_victims = str(self._zombies) + self._zombies.clear() + + logger.warning( + "Caught subprocesses termination from unknown pids: %s", + collateral_victims) + + def add_child_handler(self, pid, callback, *args): + assert self._forks, "Must use the context manager" + with self._lock: + try: + returncode = self._zombies.pop(pid) + except KeyError: + # The child is running. + self._callbacks[pid] = callback, args + return + + # The child is dead already. We can fire the callback. + callback(pid, returncode, *args) + + def remove_child_handler(self, pid): + try: + del self._callbacks[pid] + return True + except KeyError: + return False + + def _do_waitpid_all(self): + # Because of signal coalescing, we must keep calling waitpid() as + # long as we're able to reap a child. + while True: + try: + pid, status = os.waitpid(-1, os.WNOHANG) + except ChildProcessError: + # No more child processes exist. + return + else: + if pid == 0: + # A child process is still alive. + return + + returncode = self._compute_returncode(status) + + with self._lock: + try: + callback, args = self._callbacks.pop(pid) + except KeyError: + # unknown child + if self._forks: + # It may not be registered yet. + self._zombies[pid] = returncode + if self._loop.get_debug(): + logger.debug('unknown process %s exited ' + 'with returncode %s', + pid, returncode) + continue + callback = None + else: + if self._loop.get_debug(): + logger.debug('process %s exited with returncode %s', + pid, returncode) + + if callback is None: + logger.warning( + "Caught subprocess termination from unknown pid: " + "%d -> %d", pid, returncode) + else: + callback(pid, returncode, *args) + + +class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + """UNIX event loop policy with a watcher for child processes.""" + _loop_factory = _UnixSelectorEventLoop + + def __init__(self): + super().__init__() + self._watcher = None + + def _init_watcher(self): + with events._lock: + if self._watcher is None: # pragma: no branch + self._watcher = SafeChildWatcher() + if isinstance(threading.current_thread(), + threading._MainThread): + self._watcher.attach_loop(self._local._loop) + + def set_event_loop(self, loop): + """Set the event loop. + + As a side effect, if a child watcher was set before, then calling + .set_event_loop() from the main thread will call .attach_loop(loop) on + the child watcher. + """ + + super().set_event_loop(loop) + + if self._watcher is not None and \ + isinstance(threading.current_thread(), threading._MainThread): + self._watcher.attach_loop(loop) + + def get_child_watcher(self): + """Get the watcher for child processes. + + If not yet set, a SafeChildWatcher object is automatically created. + """ + if self._watcher is None: + self._init_watcher() + + return self._watcher + + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + + assert watcher is None or isinstance(watcher, AbstractChildWatcher) + + if self._watcher is not None: + self._watcher.close() + + self._watcher = watcher + +SelectorEventLoop = _UnixSelectorEventLoop +DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py new file mode 100644 index 00000000..6c7e0580 --- /dev/null +++ b/asyncio/windows_events.py @@ -0,0 +1,752 @@ +"""Selector and proactor event loops for Windows.""" + +import _winapi +import errno +import math +import socket +import struct +import weakref + +from . import events +from . import base_subprocess +from . import futures +from . import proactor_events +from . import selector_events +from . import tasks +from . import windows_utils +from . import _overlapped +from .coroutines import coroutine +from .log import logger + + +__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', + ] + + +NULL = 0 +INFINITE = 0xffffffff +ERROR_CONNECTION_REFUSED = 1225 +ERROR_CONNECTION_ABORTED = 1236 + +# Initial delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_INIT_DELAY = 0.001 + +# Maximum delay in seconds for connect_pipe() before retrying to connect +CONNECT_PIPE_MAX_DELAY = 0.100 + + +class _OverlappedFuture(futures.Future): + """Subclass of Future which represents an overlapped operation. + + Cancelling it will immediately cancel the overlapped operation. + """ + + def __init__(self, ov, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + self._ov = ov + + def _repr_info(self): + info = super()._repr_info() + if self._ov is not None: + state = 'pending' if self._ov.pending else 'completed' + info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + return info + + def _cancel_overlapped(self): + if self._ov is None: + return + try: + self._ov.cancel() + except OSError as exc: + context = { + 'message': 'Cancelling an overlapped future failed', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + self._ov = None + + def cancel(self): + self._cancel_overlapped() + return super().cancel() + + def set_exception(self, exception): + super().set_exception(exception) + self._cancel_overlapped() + + def set_result(self, result): + super().set_result(result) + self._ov = None + + +class _BaseWaitHandleFuture(futures.Future): + """Subclass of Future which represents a wait handle.""" + + def __init__(self, ov, handle, wait_handle, *, loop=None): + super().__init__(loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + # Keep a reference to the Overlapped object to keep it alive until the + # wait is unregistered + self._ov = ov + self._handle = handle + self._wait_handle = wait_handle + + # Should we call UnregisterWaitEx() if the wait completes + # or is cancelled? + self._registered = True + + def _poll(self): + # non-blocking wait: use a timeout of 0 millisecond + return (_winapi.WaitForSingleObject(self._handle, 0) == + _winapi.WAIT_OBJECT_0) + + def _repr_info(self): + info = super()._repr_info() + info.append('handle=%#x' % self._handle) + if self._handle is not None: + state = 'signaled' if self._poll() else 'waiting' + info.append(state) + if self._wait_handle is not None: + info.append('wait_handle=%#x' % self._wait_handle) + return info + + def _unregister_wait_cb(self, fut): + # The wait was unregistered: it's not safe to destroy the Overlapped + # object + self._ov = None + + def _unregister_wait(self): + if not self._registered: + return + self._registered = False + + try: + _overlapped.UnregisterWait(self._wait_handle) + except OSError as exc: + self._wait_handle = None + if exc.winerror == _overlapped.ERROR_IO_PENDING: + # ERROR_IO_PENDING is not an error, the wait was unregistered + self._unregister_wait_cb(None) + elif exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + else: + self._wait_handle = None + self._unregister_wait_cb(None) + + def cancel(self): + self._unregister_wait() + return super().cancel() + + def set_exception(self, exception): + self._unregister_wait() + super().set_exception(exception) + + def set_result(self, result): + self._unregister_wait() + super().set_result(result) + + +class _WaitCancelFuture(_BaseWaitHandleFuture): + """Subclass of Future which represents a wait for the cancellation of a + _WaitHandleFuture using an event. + """ + + def __init__(self, ov, event, wait_handle, *, loop=None): + super().__init__(ov, event, wait_handle, loop=loop) + + self._done_callback = None + + def cancel(self): + raise RuntimeError("_WaitCancelFuture must not be cancelled") + + def _schedule_callbacks(self): + super(_WaitCancelFuture, self)._schedule_callbacks() + if self._done_callback is not None: + self._done_callback(self) + + +class _WaitHandleFuture(_BaseWaitHandleFuture): + def __init__(self, ov, handle, wait_handle, proactor, *, loop=None): + super().__init__(ov, handle, wait_handle, loop=loop) + self._proactor = proactor + self._unregister_proactor = True + self._event = _overlapped.CreateEvent(None, True, False, None) + self._event_fut = None + + def _unregister_wait_cb(self, fut): + if self._event is not None: + _winapi.CloseHandle(self._event) + self._event = None + self._event_fut = None + + # If the wait was cancelled, the wait may never be signalled, so + # it's required to unregister it. Otherwise, IocpProactor.close() will + # wait forever for an event which will never come. + # + # If the IocpProactor already received the event, it's safe to call + # _unregister() because we kept a reference to the Overlapped object + # which is used as an unique key. + self._proactor._unregister(self._ov) + self._proactor = None + + super()._unregister_wait_cb(fut) + + def _unregister_wait(self): + if not self._registered: + return + self._registered = False + + try: + _overlapped.UnregisterWaitEx(self._wait_handle, self._event) + except OSError as exc: + self._wait_handle = None + if exc.winerror == _overlapped.ERROR_IO_PENDING: + # ERROR_IO_PENDING is not an error, the wait was unregistered + self._unregister_wait_cb(None) + elif exc.winerror != _overlapped.ERROR_IO_PENDING: + context = { + 'message': 'Failed to unregister the wait handle', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + else: + self._wait_handle = None + self._event_fut = self._proactor._wait_cancel( + self._event, + self._unregister_wait_cb) + + +class PipeServer(object): + """Class representing a pipe server. + + This is much like a bound, listening socket. + """ + def __init__(self, address): + self._address = address + self._free_instances = weakref.WeakSet() + # initialize the pipe attribute before calling _server_pipe_handle() + # because this function can raise an exception and the destructor calls + # the close() method + self._pipe = None + self._accept_pipe_future = None + self._pipe = self._server_pipe_handle(True) + + def _get_unconnected_pipe(self): + # Create new instance and return previous one. This ensures + # that (until the server is closed) there is always at least + # one pipe handle for address. Therefore if a client attempt + # to connect it will not fail with FileNotFoundError. + tmp, self._pipe = self._pipe, self._server_pipe_handle(False) + return tmp + + def _server_pipe_handle(self, first): + # Return a wrapper for a new pipe handle. + if self._address is None: + return None + flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED + if first: + flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + h = _winapi.CreateNamedPipe( + self._address, flags, + _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | + _winapi.PIPE_WAIT, + _winapi.PIPE_UNLIMITED_INSTANCES, + windows_utils.BUFSIZE, windows_utils.BUFSIZE, + _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + pipe = windows_utils.PipeHandle(h) + self._free_instances.add(pipe) + return pipe + + def close(self): + if self._accept_pipe_future is not None: + self._accept_pipe_future.cancel() + self._accept_pipe_future = None + # Close all instances which have not been connected to by a client. + if self._address is not None: + for pipe in self._free_instances: + pipe.close() + self._pipe = None + self._address = None + self._free_instances.clear() + + __del__ = close + + +class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): + """Windows version of selector event loop.""" + + def _socketpair(self): + return windows_utils.socketpair() + + +class ProactorEventLoop(proactor_events.BaseProactorEventLoop): + """Windows version of proactor event loop using IOCP.""" + + def __init__(self, proactor=None): + if proactor is None: + proactor = IocpProactor() + super().__init__(proactor) + + def _socketpair(self): + return windows_utils.socketpair() + + @coroutine + def create_pipe_connection(self, protocol_factory, address): + f = self._proactor.connect_pipe(address) + pipe = yield from f + protocol = protocol_factory() + trans = self._make_duplex_pipe_transport(pipe, protocol, + extra={'addr': address}) + return trans, protocol + + @coroutine + def start_serving_pipe(self, protocol_factory, address): + server = PipeServer(address) + + def loop_accept_pipe(f=None): + pipe = None + try: + if f: + pipe = f.result() + server._free_instances.discard(pipe) + protocol = protocol_factory() + self._make_duplex_pipe_transport( + pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() + if pipe is None: + return + f = self._proactor.accept_pipe(pipe) + except OSError as exc: + if pipe and pipe.fileno() != -1: + self.call_exception_handler({ + 'message': 'Pipe accept failed', + 'exception': exc, + 'pipe': pipe, + }) + pipe.close() + elif self._debug: + logger.warning("Accept pipe failed on pipe %r", + pipe, exc_info=True) + except futures.CancelledError: + if pipe: + pipe.close() + else: + server._accept_pipe_future = f + f.add_done_callback(loop_accept_pipe) + + self.call_soon(loop_accept_pipe) + return [server] + + @coroutine + def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): + transp = _WindowsSubprocessTransport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=extra, **kwargs) + try: + yield from transp._post_init() + except: + transp.close() + raise + + return transp + + +class IocpProactor: + """Proactor implementation using IOCP.""" + + def __init__(self, concurrency=0xffffffff): + self._loop = None + self._results = [] + self._iocp = _overlapped.CreateIoCompletionPort( + _overlapped.INVALID_HANDLE_VALUE, NULL, 0, concurrency) + self._cache = {} + self._registered = weakref.WeakSet() + self._unregistered = [] + self._stopped_serving = weakref.WeakSet() + + def __repr__(self): + return ('<%s overlapped#=%s result#=%s>' + % (self.__class__.__name__, len(self._cache), + len(self._results))) + + def set_loop(self, loop): + self._loop = loop + + def select(self, timeout=None): + if not self._results: + self._poll(timeout) + tmp = self._results + self._results = [] + return tmp + + def recv(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + + def finish_recv(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_recv) + + def send(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + if isinstance(conn, socket.socket): + ov.WSASend(conn.fileno(), buf, flags) + else: + ov.WriteFile(conn.fileno(), buf) + + def finish_send(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_send) + + def accept(self, listener): + self._register_with_iocp(listener) + conn = self._get_accept_socket(listener.family) + ov = _overlapped.Overlapped(NULL) + ov.AcceptEx(listener.fileno(), conn.fileno()) + + def finish_accept(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. + buf = struct.pack('@P', listener.fileno()) + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_ACCEPT_CONTEXT, buf) + conn.settimeout(listener.gettimeout()) + return conn, conn.getpeername() + + @coroutine + def accept_coro(future, conn): + # Coroutine closing the accept socket if the future is cancelled + try: + yield from future + except futures.CancelledError: + conn.close() + raise + + future = self._register(ov, listener, finish_accept) + coro = accept_coro(future, conn) + tasks.async(coro, loop=self._loop) + return future + + def connect(self, conn, address): + self._register_with_iocp(conn) + # The socket needs to be locally bound before we call ConnectEx(). + try: + _overlapped.BindLocal(conn.fileno(), conn.family) + except OSError as e: + if e.winerror != errno.WSAEINVAL: + raise + # Probably already locally bound; check using getsockname(). + if conn.getsockname()[1] == 0: + raise + ov = _overlapped.Overlapped(NULL) + ov.ConnectEx(conn.fileno(), address) + + def finish_connect(trans, key, ov): + ov.getresult() + # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. + conn.setsockopt(socket.SOL_SOCKET, + _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) + return conn + + return self._register(ov, conn, finish_connect) + + def accept_pipe(self, pipe): + self._register_with_iocp(pipe) + ov = _overlapped.Overlapped(NULL) + connected = ov.ConnectNamedPipe(pipe.fileno()) + + if connected: + # ConnectNamePipe() failed with ERROR_PIPE_CONNECTED which means + # that the pipe is connected. There is no need to wait for the + # completion of the connection. + f = futures.Future(loop=self._loop) + f.set_result(pipe) + return f + + def finish_accept_pipe(trans, key, ov): + ov.getresult() + return pipe + + return self._register(ov, pipe, finish_accept_pipe) + + def _connect_pipe(self, fut, address, delay): + # Unfortunately there is no way to do an overlapped connect to a pipe. + # Call CreateFile() in a loop until it doesn't fail with + # ERROR_PIPE_BUSY + try: + handle = _overlapped.ConnectPipe(address) + except OSError as exc: + if exc.winerror == _overlapped.ERROR_PIPE_BUSY: + # Polling: retry later + delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) + self._loop.call_later(delay, + self._connect_pipe, fut, address, delay) + else: + fut.set_exception(exc) + else: + pipe = windows_utils.PipeHandle(handle) + fut.set_result(pipe) + + def connect_pipe(self, address): + fut = futures.Future(loop=self._loop) + self._connect_pipe(fut, address, CONNECT_PIPE_INIT_DELAY) + return fut + + def wait_for_handle(self, handle, timeout=None): + """Wait for a handle. + + Return a Future object. The result of the future is True if the wait + completed, or False if the wait did not complete (on timeout). + """ + return self._wait_for_handle(handle, timeout, False) + + def _wait_cancel(self, event, done_callback): + fut = self._wait_for_handle(event, None, True) + # add_done_callback() cannot be used because the wait may only complete + # in IocpProactor.close(), while the event loop is not running. + fut._done_callback = done_callback + return fut + + def _wait_for_handle(self, handle, timeout, _is_cancel): + if timeout is None: + ms = _winapi.INFINITE + else: + # RegisterWaitForSingleObject() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + + # We only create ov so we can use ov.address as a key for the cache. + ov = _overlapped.Overlapped(NULL) + wait_handle = _overlapped.RegisterWaitWithQueue( + handle, self._iocp, ov.address, ms) + if _is_cancel: + f = _WaitCancelFuture(ov, handle, wait_handle, loop=self._loop) + else: + f = _WaitHandleFuture(ov, handle, wait_handle, self, + loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + + def finish_wait_for_handle(trans, key, ov): + # Note that this second wait means that we should only use + # this with handles types where a successful wait has no + # effect. So events or processes are all right, but locks + # or semaphores are not. Also note if the handle is + # signalled and then quickly reset, then we may return + # False even though we have not timed out. + return f._poll() + + self._cache[ov.address] = (f, ov, 0, finish_wait_for_handle) + return f + + def _register_with_iocp(self, obj): + # To get notifications of finished ops on this objects sent to the + # completion port, were must register the handle. + if obj not in self._registered: + self._registered.add(obj) + _overlapped.CreateIoCompletionPort(obj.fileno(), self._iocp, 0, 0) + # XXX We could also use SetFileCompletionNotificationModes() + # to avoid sending notifications to completion port of ops + # that succeed immediately. + + def _register(self, ov, obj, callback): + # Return a future which will be set with the result of the + # operation when it completes. The future's value is actually + # the value returned by callback(). + f = _OverlappedFuture(ov, loop=self._loop) + if f._source_traceback: + del f._source_traceback[-1] + if not ov.pending: + # The operation has completed, so no need to postpone the + # work. We cannot take this short cut if we need the + # NumberOfBytes, CompletionKey values returned by + # PostQueuedCompletionStatus(). + try: + value = callback(None, None, ov) + except OSError as e: + f.set_exception(e) + else: + f.set_result(value) + # Even if GetOverlappedResult() was called, we have to wait for the + # notification of the completion in GetQueuedCompletionStatus(). + # Register the overlapped operation to keep a reference to the + # OVERLAPPED object, otherwise the memory is freed and Windows may + # read uninitialized memory. + + # Register the overlapped operation for later. Note that + # we only store obj to prevent it from being garbage + # collected too early. + self._cache[ov.address] = (f, ov, obj, callback) + return f + + def _unregister(self, ov): + """Unregister an overlapped object. + + Call this method when its future has been cancelled. The event can + already be signalled (pending in the proactor event queue). It is also + safe if the event is never signalled (because it was cancelled). + """ + self._unregistered.append(ov) + + def _get_accept_socket(self, family): + s = socket.socket(family) + s.settimeout(0) + return s + + def _poll(self, timeout=None): + if timeout is None: + ms = INFINITE + elif timeout < 0: + raise ValueError("negative timeout") + else: + # GetQueuedCompletionStatus() has a resolution of 1 millisecond, + # round away from zero to wait *at least* timeout seconds. + ms = math.ceil(timeout * 1e3) + if ms >= INFINITE: + raise ValueError("timeout too big") + + while True: + status = _overlapped.GetQueuedCompletionStatus(self._iocp, ms) + if status is None: + break + ms = 0 + + err, transferred, key, address = status + try: + f, ov, obj, callback = self._cache.pop(address) + except KeyError: + if self._loop.get_debug(): + self._loop.call_exception_handler({ + 'message': ('GetQueuedCompletionStatus() returned an ' + 'unexpected event'), + 'status': ('err=%s transferred=%s key=%#x address=%#x' + % (err, transferred, key, address)), + }) + + # key is either zero, or it is used to return a pipe + # handle which should be closed to avoid a leak. + if key not in (0, _overlapped.INVALID_HANDLE_VALUE): + _winapi.CloseHandle(key) + continue + + if obj in self._stopped_serving: + f.cancel() + # Don't call the callback if _register() already read the result or + # if the overlapped has been cancelled + elif not f.done(): + try: + value = callback(transferred, key, ov) + except OSError as e: + f.set_exception(e) + self._results.append(f) + else: + f.set_result(value) + self._results.append(f) + + # Remove unregisted futures + for ov in self._unregistered: + self._cache.pop(ov.address, None) + self._unregistered.clear() + + def _stop_serving(self, obj): + # obj is a socket or pipe handle. It will be closed in + # BaseProactorEventLoop._stop_serving() which will make any + # pending operations fail quickly. + self._stopped_serving.add(obj) + + def close(self): + # Cancel remaining registered operations. + for address, (fut, ov, obj, callback) in list(self._cache.items()): + if fut.cancelled(): + # Nothing to do with cancelled futures + pass + elif isinstance(fut, _WaitCancelFuture): + # _WaitCancelFuture must not be cancelled + pass + else: + try: + fut.cancel() + except OSError as exc: + if self._loop is not None: + context = { + 'message': 'Cancelling a future failed', + 'exception': exc, + 'future': fut, + } + if fut._source_traceback: + context['source_traceback'] = fut._source_traceback + self._loop.call_exception_handler(context) + + while self._cache: + if not self._poll(1): + logger.debug('taking long time to close proactor') + + self._results = [] + if self._iocp is not None: + _winapi.CloseHandle(self._iocp) + self._iocp = None + + def __del__(self): + self.close() + + +class _WindowsSubprocessTransport(base_subprocess.BaseSubprocessTransport): + + def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): + self._proc = windows_utils.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + bufsize=bufsize, **kwargs) + + def callback(f): + returncode = self._proc.poll() + self._process_exited(returncode) + + f = self._loop._proactor.wait_for_handle(int(self._proc._handle)) + f.add_done_callback(callback) + + +SelectorEventLoop = _WindowsSelectorEventLoop + + +class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = SelectorEventLoop + + +DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py new file mode 100644 index 00000000..e6642960 --- /dev/null +++ b/asyncio/windows_utils.py @@ -0,0 +1,217 @@ +""" +Various Windows specific bits and pieces +""" + +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import _winapi +import itertools +import msvcrt +import os +import socket +import subprocess +import tempfile + + +__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] + + +# Constants/globals + + +BUFSIZE = 8192 +PIPE = subprocess.PIPE +STDOUT = subprocess.STDOUT +_mmap_counter = itertools.count() + + +if hasattr(socket, 'socketpair'): + # Since Python 3.5, socket.socketpair() is now also available on Windows + socketpair = socket.socketpair +else: + # Replacement for socket.socketpair() + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + """A socket pair usable as a self-pipe, for Windows. + + Origin: https://gist.github.com/4325783, by Geert Jansen. + Public domain. + """ + if family == socket.AF_INET: + host = '127.0.0.1' + elif family == socket.AF_INET6: + host = '::1' + else: + raise ValueError("Only AF_INET and AF_INET6 socket address " + "families are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with setblocking(0) + # that prevents us from having to create a thread. + lsock = socket.socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen(1) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) + try: + csock.setblocking(False) + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + csock.setblocking(True) + ssock, _ = lsock.accept() + except: + csock.close() + raise + finally: + lsock.close() + return (ssock, csock) + + +# Replacement for os.pipe() using handles instead of fds + + +def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): + """Like os.pipe() but with overlapped support and using handles not fds.""" + address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % + (os.getpid(), next(_mmap_counter))) + + if duplex: + openmode = _winapi.PIPE_ACCESS_DUPLEX + access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE + obsize, ibsize = bufsize, bufsize + else: + openmode = _winapi.PIPE_ACCESS_INBOUND + access = _winapi.GENERIC_WRITE + obsize, ibsize = 0, bufsize + + openmode |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE + + if overlapped[0]: + openmode |= _winapi.FILE_FLAG_OVERLAPPED + + if overlapped[1]: + flags_and_attribs = _winapi.FILE_FLAG_OVERLAPPED + else: + flags_and_attribs = 0 + + h1 = h2 = None + try: + h1 = _winapi.CreateNamedPipe( + address, openmode, _winapi.PIPE_WAIT, + 1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL) + + h2 = _winapi.CreateFile( + address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING, + flags_and_attribs, _winapi.NULL) + + ov = _winapi.ConnectNamedPipe(h1, overlapped=True) + ov.GetOverlappedResult(True) + return h1, h2 + except: + if h1 is not None: + _winapi.CloseHandle(h1) + if h2 is not None: + _winapi.CloseHandle(h2) + raise + + +# Wrapper for a pipe handle + + +class PipeHandle: + """Wrapper for an overlapped pipe handle which is vaguely file-object like. + + The IOCP event loop can use these instead of socket objects. + """ + def __init__(self, handle): + self._handle = handle + + def __repr__(self): + if self._handle is not None: + handle = 'handle=%r' % self._handle + else: + handle = 'closed' + return '<%s %s>' % (self.__class__.__name__, handle) + + @property + def handle(self): + return self._handle + + def fileno(self): + return self._handle + + def close(self, *, CloseHandle=_winapi.CloseHandle): + if self._handle is not None: + CloseHandle(self._handle) + self._handle = None + + __del__ = close + + def __enter__(self): + return self + + def __exit__(self, t, v, tb): + self.close() + + +# Replacement for subprocess.Popen using overlapped pipe handles + + +class Popen(subprocess.Popen): + """Replacement for subprocess.Popen using overlapped pipe handles. + + The stdin, stdout, stderr are None or instances of PipeHandle. + """ + def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): + assert not kwds.get('universal_newlines') + assert kwds.get('bufsize', 0) == 0 + stdin_rfd = stdout_wfd = stderr_wfd = None + stdin_wh = stdout_rh = stderr_rh = None + if stdin == PIPE: + stdin_rh, stdin_wh = pipe(overlapped=(False, True), duplex=True) + stdin_rfd = msvcrt.open_osfhandle(stdin_rh, os.O_RDONLY) + else: + stdin_rfd = stdin + if stdout == PIPE: + stdout_rh, stdout_wh = pipe(overlapped=(True, False)) + stdout_wfd = msvcrt.open_osfhandle(stdout_wh, 0) + else: + stdout_wfd = stdout + if stderr == PIPE: + stderr_rh, stderr_wh = pipe(overlapped=(True, False)) + stderr_wfd = msvcrt.open_osfhandle(stderr_wh, 0) + elif stderr == STDOUT: + stderr_wfd = stdout_wfd + else: + stderr_wfd = stderr + try: + super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, + stderr=stderr_wfd, **kwds) + except: + for h in (stdin_wh, stdout_rh, stderr_rh): + if h is not None: + _winapi.CloseHandle(h) + raise + else: + if stdin_wh is not None: + self.stdin = PipeHandle(stdin_wh) + if stdout_rh is not None: + self.stdout = PipeHandle(stdout_rh) + if stderr_rh is not None: + self.stderr = PipeHandle(stderr_rh) + finally: + if stdin == PIPE: + os.close(stdin_rfd) + if stdout == PIPE: + os.close(stdout_wfd) + if stderr == PIPE: + os.close(stderr_wfd) diff --git a/check.py b/check.py new file mode 100644 index 00000000..6db82d64 --- /dev/null +++ b/check.py @@ -0,0 +1,45 @@ +"""Search for lines >= 80 chars or with trailing whitespace.""" + +import os +import sys + + +def main(): + args = sys.argv[1:] or os.curdir + for arg in args: + if os.path.isdir(arg): + for dn, dirs, files in os.walk(arg): + for fn in sorted(files): + if fn.endswith('.py'): + process(os.path.join(dn, fn)) + dirs[:] = [d for d in dirs if d[0] != '.'] + dirs.sort() + else: + process(arg) + + +def isascii(x): + try: + x.encode('ascii') + return True + except UnicodeError: + return False + + +def process(fn): + try: + f = open(fn) + except IOError as err: + print(err) + return + try: + for i, line in enumerate(f): + line = line.rstrip('\n') + sline = line.rstrip() + if len(line) >= 80 or line != sline or not isascii(line): + print('{}:{:d}:{}{}'.format( + fn, i+1, sline, '_' * (len(line) - len(sline)))) + finally: + f.close() + +main() diff --git a/examples/cacheclt.py b/examples/cacheclt.py new file mode 100644 index 00000000..b11a4d1a --- /dev/null +++ b/examples/cacheclt.py @@ -0,0 +1,213 @@ +"""Client for cache server. + +See cachesvr.py for protocol description. +""" + +import argparse +import asyncio +from asyncio import test_utils +import json +import logging + +ARGS = argparse.ArgumentParser(description='Cache client example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--max_backoff', action='store', dest='max_backoff', + default=5, type=float, help='Max backoff on reconnect') +ARGS.add_argument( + '--ntasks', action='store', dest='ntasks', + default=10, type=int, help='Number of tester tasks') +ARGS.add_argument( + '--ntries', action='store', dest='ntries', + default=5, type=int, help='Number of request tries before giving up') + + +args = ARGS.parse_args() + + +class CacheClient: + """Multiplexing cache client. + + This wraps a single connection to the cache client. The + connection is automatically re-opened when an error occurs. + + Multiple tasks may share this object; the requests will be + serialized. + + The public API is get(), set(), delete() (all are coroutines). + """ + + def __init__(self, host, port, sslctx=None, loop=None): + self.host = host + self.port = port + self.sslctx = sslctx + self.loop = loop + self.todo = set() + self.initialized = False + self.task = asyncio.Task(self.activity(), loop=self.loop) + + @asyncio.coroutine + def get(self, key): + resp = yield from self.request('get', key) + if resp is None: + return None + return resp.get('value') + + @asyncio.coroutine + def set(self, key, value): + resp = yield from self.request('set', key, value) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def delete(self, key): + resp = yield from self.request('delete', key) + if resp is None: + return False + return resp.get('status') == 'ok' + + @asyncio.coroutine + def request(self, type, key, value=None): + assert not self.task.done() + data = {'type': type, 'key': key} + if value is not None: + data['value'] = value + payload = json.dumps(data).encode('utf8') + waiter = asyncio.Future(loop=self.loop) + if self.initialized: + try: + yield from self.send(payload, waiter) + except IOError: + self.todo.add((payload, waiter)) + else: + self.todo.add((payload, waiter)) + return (yield from waiter) + + @asyncio.coroutine + def activity(self): + backoff = 0 + while True: + try: + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.sslctx, loop=self.loop) + except Exception as exc: + backoff = min(args.max_backoff, backoff + (backoff//2) + 1) + logging.info('Error connecting: %r; sleep %s', exc, backoff) + yield from asyncio.sleep(backoff, loop=self.loop) + continue + backoff = 0 + self.next_id = 0 + self.pending = {} + self. initialized = True + try: + while self.todo: + payload, waiter = self.todo.pop() + if not waiter.done(): + yield from self.send(payload, waiter) + while True: + resp_id, resp = yield from self.process() + if resp_id in self.pending: + payload, waiter = self.pending.pop(resp_id) + if not waiter.done(): + waiter.set_result(resp) + except Exception as exc: + self.initialized = False + self.writer.close() + while self.pending: + req_id, pair = self.pending.popitem() + payload, waiter = pair + if not waiter.done(): + self.todo.add(pair) + logging.info('Error processing: %r', exc) + + @asyncio.coroutine + def send(self, payload, waiter): + self.next_id += 1 + req_id = self.next_id + frame = 'request %d %d\n' % (req_id, len(payload)) + self.writer.write(frame.encode('ascii')) + self.writer.write(payload) + self.pending[req_id] = payload, waiter + yield from self.writer.drain() + + @asyncio.coroutine + def process(self): + frame = yield from self.reader.readline() + if not frame: + raise EOFError() + head, tail = frame.split(None, 1) + if head == b'error': + raise IOError('OOB error: %r' % tail) + if head != b'response': + raise IOError('Bad frame: %r' % frame) + resp_id, resp_size = map(int, tail.split()) + data = yield from self.reader.readexactly(resp_size) + if len(data) != resp_size: + raise EOFError() + resp = json.loads(data.decode('utf8')) + return resp_id, resp + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + cache = CacheClient(args.host, args.port, sslctx=sslctx, loop=loop) + try: + loop.run_until_complete( + asyncio.gather( + *[testing(i, cache, loop) for i in range(args.ntasks)], + loop=loop)) + finally: + loop.close() + + +@asyncio.coroutine +def testing(label, cache, loop): + + def w(g): + return asyncio.wait_for(g, args.timeout, loop=loop) + + key = 'foo-%s' % label + while True: + logging.info('%s %s', label, '-'*20) + try: + ret = yield from w(cache.set(key, 'hello-%s-world' % label)) + logging.info('%s set %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get %s', label, ret) + ret = yield from w(cache.delete(key)) + logging.info('%s del %s', label, ret) + ret = yield from w(cache.get(key)) + logging.info('%s get2 %s', label, ret) + except asyncio.TimeoutError: + logging.warn('%s Timeout', label) + except Exception as exc: + logging.exception('%s Client exception: %r', label, exc) + break + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/cachesvr.py b/examples/cachesvr.py new file mode 100644 index 00000000..053f9c21 --- /dev/null +++ b/examples/cachesvr.py @@ -0,0 +1,249 @@ +"""A simple memcache-like server. + +The basic data structure maintained is a single in-memory dictionary +mapping string keys to string values, with operations get, set and +delete. (Both keys and values may contain Unicode.) + +This is a TCP server listening on port 54321. There is no +authentication. + +Requests provide an operation and return a response. A connection may +be used for multiple requests. The connection is closed when a client +sends a bad request. + +If a client is idle for over 5 seconds (i.e., it does not send another +request, or fails to read the whole response, within this time), it is +disconnected. + +Framing of requests and responses within a connection uses a +line-based protocol. The first line of a request is the frame header +and contains three whitespace-delimited token followed by LF or CRLF: + +- the keyword 'request' +- a decimal request ID; the first request is '1', the second '2', etc. +- a decimal byte count giving the size of the rest of the request + +Note that the requests ID *must* be consecutive and start at '1' for +each connection. + +Response frames look the same except the keyword is 'response'. The +response ID matches the request ID. There should be exactly one +response to each request and responses should be seen in the same +order as the requests. + +After the frame, individual requests and responses are JSON encoded. + +If the frame header or the JSON request body cannot be parsed, an +unframed error message (always starting with 'error') is written back +and the connection is closed. + +JSON-encoded requests can be: + +- {"type": "get", "key": } +- {"type": "set", "key": , "value": } +- {"type": "delete", "key": } + +Responses are also JSON-encoded: + +- {"status": "ok", "value": } # Successful get request +- {"status": "ok"} # Successful set or delete request +- {"status": "notfound"} # Key not found for get or delete request + +If the request is valid JSON but cannot be handled (e.g., the type or +key field is absent or invalid), an error response of the following +form is returned, but the connection is not closed: + +- {"error": } +""" + +import argparse +import asyncio +import json +import logging +import os +import random + +ARGS = argparse.ArgumentParser(description='Cache server example.') +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='localhost', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=54321, type=int, help='Port number') +ARGS.add_argument( + '--timeout', action='store', dest='timeout', + default=5, type=float, help='Timeout') +ARGS.add_argument( + '--random_failure_percent', action='store', dest='fail_percent', + default=0, type=float, help='Fail randomly N percent of the time') +ARGS.add_argument( + '--random_failure_sleep', action='store', dest='fail_sleep', + default=0, type=float, help='Sleep time when randomly failing') +ARGS.add_argument( + '--random_response_sleep', action='store', dest='resp_sleep', + default=0, type=float, help='Sleep time before responding') + +args = ARGS.parse_args() + + +class Cache: + + def __init__(self, loop): + self.loop = loop + self.table = {} + + @asyncio.coroutine + def handle_client(self, reader, writer): + # Wrapper to log stuff and close writer (i.e., transport). + peer = writer.get_extra_info('socket').getpeername() + logging.info('got a connection from %s', peer) + try: + yield from self.frame_parser(reader, writer) + except Exception as exc: + logging.error('error %r from %s', exc, peer) + else: + logging.info('end connection from %s', peer) + finally: + writer.close() + + @asyncio.coroutine + def frame_parser(self, reader, writer): + # This takes care of the framing. + last_request_id = 0 + while True: + # Read the frame header, parse it, read the data. + # NOTE: The readline() and readexactly() calls will hang + # if the client doesn't send enough data but doesn't + # disconnect either. We add a timeout to each. (But the + # timeout should really be implemented by StreamReader.) + framing_b = yield from asyncio.wait_for( + reader.readline(), + timeout=args.timeout, loop=self.loop) + if random.random()*100 < args.fail_percent: + logging.warn('Inserting random failure') + yield from asyncio.sleep(args.fail_sleep*random.random(), + loop=self.loop) + writer.write(b'error random failure\r\n') + break + logging.debug('framing_b = %r', framing_b) + if not framing_b: + break # Clean close. + try: + frame_keyword, request_id_b, byte_count_b = framing_b.split() + except ValueError: + writer.write(b'error unparseable frame\r\n') + break + if frame_keyword != b'request': + writer.write(b'error frame does not start with request\r\n') + break + try: + request_id, byte_count = int(request_id_b), int(byte_count_b) + except ValueError: + writer.write(b'error unparsable frame parameters\r\n') + break + if request_id != last_request_id + 1 or byte_count < 2: + writer.write(b'error invalid frame parameters\r\n') + break + last_request_id = request_id + request_b = yield from asyncio.wait_for( + reader.readexactly(byte_count), + timeout=args.timeout, loop=self.loop) + try: + request = json.loads(request_b.decode('utf8')) + except ValueError: + writer.write(b'error unparsable json\r\n') + break + response = self.handle_request(request) # Not a coroutine. + if response is None: + writer.write(b'error unhandlable request\r\n') + break + response_b = json.dumps(response).encode('utf8') + b'\r\n' + byte_count = len(response_b) + framing_s = 'response {} {}\r\n'.format(request_id, byte_count) + writer.write(framing_s.encode('ascii')) + yield from asyncio.sleep(args.resp_sleep*random.random(), + loop=self.loop) + writer.write(response_b) + + def handle_request(self, request): + # This parses one request and farms it out to a specific handler. + # Return None for all errors. + if not isinstance(request, dict): + return {'error': 'request is not a dict'} + request_type = request.get('type') + if request_type is None: + return {'error': 'no type in request'} + if request_type not in {'get', 'set', 'delete'}: + return {'error': 'unknown request type'} + key = request.get('key') + if not isinstance(key, str): + return {'error': 'key is not a string'} + if request_type == 'get': + return self.handle_get(key) + if request_type == 'set': + value = request.get('value') + if not isinstance(value, str): + return {'error': 'value is not a string'} + return self.handle_set(key, value) + if request_type == 'delete': + return self.handle_delete(key) + assert False, 'bad request type' # Should have been caught above. + + def handle_get(self, key): + value = self.table.get(key) + if value is None: + return {'status': 'notfound'} + else: + return {'status': 'ok', 'value': value} + + def handle_set(self, key, value): + self.table[key] = value + return {'status': 'ok'} + + def handle_delete(self, key): + if key not in self.table: + return {'status': 'notfound'} + else: + del self.table[key] + return {'status': 'ok'} + + +def main(): + asyncio.set_event_loop(None) + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + else: + loop = asyncio.new_event_loop() + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) + cache = Cache(loop) + task = asyncio.streams.start_server(cache.handle_client, + args.host, args.port, + ssl=sslctx, loop=loop) + svr = loop.run_until_complete(task) + for sock in svr.sockets: + logging.info('socket %s', sock.getsockname()) + try: + loop.run_forever() + finally: + loop.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/child_process.py b/examples/child_process.py new file mode 100644 index 00000000..3fac175e --- /dev/null +++ b/examples/child_process.py @@ -0,0 +1,128 @@ +""" +Example of asynchronous interaction with a child python process. + +This example shows how to attach an existing Popen object and use the low level +transport-protocol API. See shell.py and subprocess_shell.py for higher level +examples. +""" + +import os +import sys + +try: + import asyncio +except ImportError: + # asyncio is not installed + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + import asyncio + +if sys.platform == 'win32': + from asyncio.windows_utils import Popen, PIPE + from asyncio.windows_events import ProactorEventLoop +else: + from subprocess import Popen, PIPE + +# +# Return a write-only transport wrapping a writable pipe +# + +@asyncio.coroutine +def connect_write_pipe(file): + loop = asyncio.get_event_loop() + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, file) + return transport + +# +# Wrap a readable pipe in a stream +# + +@asyncio.coroutine +def connect_read_pipe(file): + loop = asyncio.get_event_loop() + stream_reader = asyncio.StreamReader(loop=loop) + def factory(): + return asyncio.StreamReaderProtocol(stream_reader) + transport, _ = yield from loop.connect_read_pipe(factory, file) + return stream_reader, transport + + +# +# Example +# + +@asyncio.coroutine +def main(loop): + # program which prints evaluation of each expression from stdin + code = r'''if 1: + import os + def writeall(fd, buf): + while buf: + n = os.write(fd, buf) + buf = buf[n:] + while True: + s = os.read(0, 1024) + if not s: + break + s = s.decode('ascii') + s = repr(eval(s)) + '\n' + s = s.encode('ascii') + writeall(1, s) + ''' + + # commands to send to input + commands = iter([b"1+1\n", + b"2**16\n", + b"1/3\n", + b"'x'*50", + b"1/0\n"]) + + # start subprocess and wrap stdin, stdout, stderr + p = Popen([sys.executable, '-c', code], + stdin=PIPE, stdout=PIPE, stderr=PIPE) + + stdin = yield from connect_write_pipe(p.stdin) + stdout, stdout_transport = yield from connect_read_pipe(p.stdout) + stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + + # interact with subprocess + name = {stdout:'OUT', stderr:'ERR'} + registered = {asyncio.Task(stderr.readline()): stderr, + asyncio.Task(stdout.readline()): stdout} + while registered: + # write command + cmd = next(commands, None) + if cmd is None: + stdin.close() + else: + print('>>>', cmd.decode('ascii').rstrip()) + stdin.write(cmd) + + # get and print lines from stdout, stderr + timeout = None + while registered: + done, pending = yield from asyncio.wait( + registered, timeout=timeout, + return_when=asyncio.FIRST_COMPLETED) + if not done: + break + for f in done: + stream = registered.pop(f) + res = f.result() + print(name[stream], res.decode('ascii').rstrip()) + if res != b'': + registered[asyncio.Task(stream.readline())] = stream + timeout = 0.0 + + stdout_transport.close() + stderr_transport.close() + +if __name__ == '__main__': + if sys.platform == 'win32': + loop = ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main(loop)) + finally: + loop.close() diff --git a/examples/crawl.py b/examples/crawl.py new file mode 100644 index 00000000..4bb0b4ea --- /dev/null +++ b/examples/crawl.py @@ -0,0 +1,863 @@ +#!/usr/bin/env python3.4 + +"""A simple web crawler.""" + +# TODO: +# - More organized logging (with task ID or URL?). +# - Use logging module for Logger. +# - KeyboardInterrupt in HTML parsing may hang or report unretrieved error. +# - Support gzip encoding. +# - Close connection if HTTP/1.0 response. +# - Add timeouts. (E.g. when switching networks, all seems to hang.) +# - Add arguments to specify TLS settings (e.g. cert/key files). +# - Skip reading large non-text/html files? +# - Use ETag and If-Modified-Since? +# - Handle out of file descriptors directly? (How?) + +import argparse +import asyncio +import asyncio.locks +import cgi +from http.client import BadStatusLine +import logging +import re +import sys +import time +import urllib.parse + + +ARGS = argparse.ArgumentParser(description="Web crawler") +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--select', action='store_true', dest='select', + default=False, help='Use Select event loop instead of default') +ARGS.add_argument( + 'roots', nargs='*', + default=[], help='Root URL (may be repeated)') +ARGS.add_argument( + '--max_redirect', action='store', type=int, metavar='N', + default=10, help='Limit redirection chains (for 301, 302 etc.)') +ARGS.add_argument( + '--max_tries', action='store', type=int, metavar='N', + default=4, help='Limit retries on network errors') +ARGS.add_argument( + '--max_tasks', action='store', type=int, metavar='N', + default=100, help='Limit concurrent connections') +ARGS.add_argument( + '--max_pool', action='store', type=int, metavar='N', + default=100, help='Limit connection pool size') +ARGS.add_argument( + '--exclude', action='store', metavar='REGEX', + help='Exclude matching URLs') +ARGS.add_argument( + '--strict', action='store_true', + default=True, help='Strict host matching (default)') +ARGS.add_argument( + '--lenient', action='store_false', dest='strict', + default=False, help='Lenient host matching') +ARGS.add_argument( + '-v', '--verbose', action='count', dest='level', + default=1, help='Verbose logging (repeat for more verbose)') +ARGS.add_argument( + '-q', '--quiet', action='store_const', const=0, dest='level', + default=1, help='Quiet logging (opposite of --verbose)') + + +ESCAPES = [('quot', '"'), + ('gt', '>'), + ('lt', '<'), + ('amp', '&') # Must be last. + ] + + +def unescape(url): + """Turn & into &, and so on. + + This is the inverse of cgi.escape(). + """ + for name, char in ESCAPES: + url = url.replace('&' + name + ';', char) + return url + + +def fix_url(url): + """Prefix a schema-less URL with http://.""" + if '://' not in url: + url = 'http://' + url + return url + + +class Logger: + + def __init__(self, level): + self.level = level + + def _log(self, n, args): + if self.level >= n: + print(*args, file=sys.stderr, flush=True) + + def log(self, n, *args): + self._log(n, args) + + def __call__(self, n, *args): + self._log(n, args) + + +class ConnectionPool: + """A connection pool. + + To open a connection, use reserve(). To recycle it, use unreserve(). + + The pool is mostly just a mapping from (host, port, ssl) tuples to + lists of Connections. The currently active connections are *not* + in the data structure; get_connection() takes the connection out, + and recycle_connection() puts it back in. To recycle a + connection, call conn.close(recycle=True). + + There are limits to both the overall pool and the per-key pool. + """ + + def __init__(self, log, max_pool=10, max_tasks=5): + self.log = log + self.max_pool = max_pool # Overall limit. + self.max_tasks = max_tasks # Per-key limit. + self.loop = asyncio.get_event_loop() + self.connections = {} # {(host, port, ssl): [Connection, ...], ...} + self.queue = [] # [Connection, ...] + + def close(self): + """Close all connections available for reuse.""" + for conns in self.connections.values(): + for conn in conns: + conn.close() + self.connections.clear() + self.queue.clear() + + @asyncio.coroutine + def get_connection(self, host, port, ssl): + """Create or reuse a connection.""" + port = port or (443 if ssl else 80) + try: + ipaddrs = yield from self.loop.getaddrinfo(host, port) + except Exception as exc: + self.log(0, 'Exception %r for (%r, %r)' % (exc, host, port)) + raise + self.log(1, '* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs))) + + # Look for a reusable connection. + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = None + conns = self.connections.get(key) + while conns: + conn = conns.pop(0) + self.queue.remove(conn) + if not conns: + del self.connections[key] + if conn.stale(): + self.log(1, 'closing stale connection for', key) + conn.close() # Just in case. + else: + self.log(1, '* Reusing pooled connection', key, + 'FD =', conn.fileno()) + return conn + + # Create a new connection. + conn = Connection(self.log, self, host, port, ssl) + yield from conn.connect() + self.log(1, '* New connection', conn.key, 'FD =', conn.fileno()) + return conn + + def recycle_connection(self, conn): + """Make a connection available for reuse. + + This also prunes the pool if it exceeds the size limits. + """ + if conn.stale(): + conn.close() + return + + key = conn.key + conns = self.connections.setdefault(key, []) + conns.append(conn) + self.queue.append(conn) + + if len(conns) <= self.max_tasks and len(self.queue) <= self.max_pool: + return + + # Prune the queue. + + # Close stale connections for this key first. + stale = [conn for conn in conns if conn.stale()] + if stale: + for conn in stale: + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + if not conns: + del self.connections[key] + + # Close oldest connection(s) for this key if limit reached. + while len(conns) > self.max_tasks: + conn = conns.pop(0) + self.queue.remove(conn) + self.log(1, 'closing oldest connection for', key) + conn.close() + + if len(self.queue) <= self.max_pool: + return + + # Close overall stale connections. + stale = [conn for conn in self.queue if conn.stale()] + if stale: + for conn in stale: + conns = self.connections.get(conn.key) + conns.remove(conn) + self.queue.remove(conn) + self.log(1, 'closing stale connection for', key) + conn.close() + + # Close oldest overall connection(s) if limit reached. + while len(self.queue) > self.max_pool: + conn = self.queue.pop(0) + conns = self.connections.get(conn.key) + c = conns.pop(0) + assert conn == c, (conn.key, conn, c, conns) + self.log(1, 'closing overall oldest connection for', conn.key) + conn.close() + + +class Connection: + + def __init__(self, log, pool, host, port, ssl): + self.log = log + self.pool = pool + self.host = host + self.port = port + self.ssl = ssl + self.reader = None + self.writer = None + self.key = None + + def stale(self): + return self.reader is None or self.reader.at_eof() + + def fileno(self): + writer = self.writer + if writer is not None: + transport = writer.transport + if transport is not None: + sock = transport.get_extra_info('socket') + if sock is not None: + return sock.fileno() + return None + + @asyncio.coroutine + def connect(self): + self.reader, self.writer = yield from asyncio.open_connection( + self.host, self.port, ssl=self.ssl) + peername = self.writer.get_extra_info('peername') + if peername: + self.host, self.port = peername[:2] + else: + self.log(1, 'NO PEERNAME???', self.host, self.port, self.ssl) + self.key = self.host, self.port, self.ssl + + def close(self, recycle=False): + if recycle and not self.stale(): + self.pool.recycle_connection(self) + else: + self.writer.close() + self.pool = self.reader = self.writer = None + + +class Request: + """HTTP request. + + Use connect() to open a connection; send_request() to send the + request; get_response() to receive the response headers. + """ + + def __init__(self, log, url, pool): + self.log = log + self.url = url + self.pool = pool + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.conn = None + + @asyncio.coroutine + def connect(self): + """Open a connection to the server.""" + self.log(1, '* Connecting to %s:%s using %s for %s' % + (self.hostname, self.port, + 'ssl' if self.ssl else 'tcp', + self.url)) + self.conn = yield from self.pool.get_connection(self.hostname, + self.port, self.ssl) + + def close(self, recycle=False): + """Close the connection, recycle if requested.""" + if self.conn is not None: + if not recycle: + self.log(1, 'closing connection for', self.conn.key) + self.conn.close(recycle) + self.conn = None + + @asyncio.coroutine + def putline(self, line): + """Write a line to the connection. + + Used for the request line and headers. + """ + self.log(2, '>', line) + self.conn.writer.write(line.encode('latin-1') + b'\r\n') + + @asyncio.coroutine + def send_request(self): + """Send the request.""" + request_line = '%s %s %s' % (self.method, self.full_path, + self.http_version) + yield from self.putline(request_line) + # TODO: What if a header is already set? + self.headers.append(('User-Agent', 'asyncio-example-crawl/0.0')) + self.headers.append(('Host', self.netloc)) + self.headers.append(('Accept', '*/*')) + ##self.headers.append(('Accept-Encoding', 'gzip')) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @asyncio.coroutine + def get_response(self): + """Receive the response.""" + response = Response(self.log, self.conn.reader) + yield from response.read_headers() + return response + + +class Response: + """HTTP response. + + Call read_headers() to receive the request headers. Then check + the status attribute and call get_header() to inspect the headers. + Finally call read() to receive the body. + """ + + def __init__(self, log, reader): + self.log = log + self.reader = reader + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @asyncio.coroutine + def getline(self): + """Read one line from the connection.""" + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.log(2, '<', line) + return line + + @asyncio.coroutine + def read_headers(self): + """Read the response status and the request headers.""" + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + self.log(0, 'bad status_line', repr(status_line)) + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=''): + """Inspect the status and return the redirect url if appropriate.""" + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=''): + """Get one header value, using a case insensitive header name.""" + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @asyncio.coroutine + def read(self): + """Read the response body. + + This honors Content-Length and Transfer-Encoding: chunked. + """ + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding').lower() == 'chunked': + self.log(2, 'parsing chunked response') + blocks = [] + while True: + size_header = yield from self.reader.readline() + if not size_header: + self.log(0, 'premature end of chunked response') + break + self.log(3, 'size_header =', repr(size_header)) + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + self.log(3, 'reading chunk of', size, 'bytes') + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + if not size: + break + body = b''.join(blocks) + self.log(1, 'chunked response had', len(body), + 'bytes in', len(blocks), 'blocks') + else: + self.log(3, 'reading until EOF') + body = yield from self.reader.read() + # TODO: Should make sure not to recycle the connection + # in this case. + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +class Fetcher: + """Logic and state for one URL. + + When found in crawler.busy, this represents a URL to be fetched or + in the process of being fetched; when found in crawler.done, this + holds the results from fetching it. + + This is usually associated with a task. This references the + crawler for the connection pool and to add more URLs to its todo + list. + + Call fetch() to do the fetching, then report() to print the results. + """ + + def __init__(self, log, url, crawler, max_redirect=10, max_tries=4): + self.log = log + self.url = url + self.crawler = crawler + # We don't loop resolving redirects here -- we just use this + # to decide whether to add the redirect URL to crawler.todo. + self.max_redirect = max_redirect + # But we do loop to retry on errors a few times. + self.max_tries = max_tries + # Everything we collect from the response goes here. + self.task = None + self.exceptions = [] + self.tries = 0 + self.request = None + self.response = None + self.body = None + self.next_url = None + self.ctype = None + self.pdict = None + self.encoding = None + self.urls = None + self.new_urls = None + + @asyncio.coroutine + def fetch(self): + """Attempt to fetch the contents of the URL. + + If successful, and the data is HTML, extract further links and + add them to the crawler. Redirects are also added back there. + """ + while self.tries < self.max_tries: + self.tries += 1 + self.request = None + try: + self.request = Request(self.log, self.url, self.crawler.pool) + yield from self.request.connect() + yield from self.request.send_request() + self.response = yield from self.request.get_response() + self.body = yield from self.response.read() + h_conn = self.response.get_header('connection').lower() + if h_conn != 'close': + self.request.close(recycle=True) + self.request = None + if self.tries > 1: + self.log(1, 'try', self.tries, 'for', self.url, 'success') + break + except (BadStatusLine, OSError) as exc: + self.exceptions.append(exc) + self.log(1, 'try', self.tries, 'for', self.url, + 'raised', repr(exc)) + ##import pdb; pdb.set_trace() + # Don't reuse the connection in this case. + finally: + if self.request is not None: + self.request.close() + else: + # We never broke out of the while loop, i.e. all tries failed. + self.log(0, 'no success for', self.url, + 'in', self.max_tries, 'tries') + return + next_url = self.response.get_redirect_url() + if next_url: + self.next_url = urllib.parse.urljoin(self.url, next_url) + if self.max_redirect > 0: + self.log(1, 'redirect to', self.next_url, 'from', self.url) + self.crawler.add_url(self.next_url, self.max_redirect-1) + else: + self.log(0, 'redirect limit reached for', self.next_url, + 'from', self.url) + else: + if self.response.status == 200: + self.ctype = self.response.get_header('content-type') + self.pdict = {} + if self.ctype: + self.ctype, self.pdict = cgi.parse_header(self.ctype) + self.encoding = self.pdict.get('charset', 'utf-8') + if self.ctype == 'text/html': + body = self.body.decode(self.encoding, 'replace') + # Replace href with (?:href|src) to follow image links. + self.urls = set(re.findall(r'(?i)href=["\']?([^\s"\'<>]+)', + body)) + if self.urls: + self.log(1, 'got', len(self.urls), + 'distinct urls from', self.url) + self.new_urls = set() + for url in self.urls: + url = unescape(url) + url = urllib.parse.urljoin(self.url, url) + url, frag = urllib.parse.urldefrag(url) + if self.crawler.add_url(url): + self.new_urls.add(url) + + def report(self, stats, file=None): + """Print a report on the state for this URL. + + Also update the Stats instance. + """ + if self.task is not None: + if not self.task.done(): + stats.add('pending') + print(self.url, 'pending', file=file) + return + elif self.task.cancelled(): + stats.add('cancelled') + print(self.url, 'cancelled', file=file) + return + elif self.task.exception(): + stats.add('exception') + exc = self.task.exception() + stats.add('exception_' + exc.__class__.__name__) + print(self.url, exc, file=file) + return + if len(self.exceptions) == self.tries: + stats.add('fail') + exc = self.exceptions[-1] + stats.add('fail_' + str(exc.__class__.__name__)) + print(self.url, 'error', exc, file=file) + elif self.next_url: + stats.add('redirect') + print(self.url, self.response.status, 'redirect', self.next_url, + file=file) + elif self.ctype == 'text/html': + stats.add('html') + size = len(self.body or b'') + stats.add('html_bytes', size) + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + '%d/%d' % (len(self.new_urls or ()), len(self.urls or ())), + file=file) + elif self.response is None: + print(self.url, 'no response object') + else: + size = len(self.body or b'') + if self.response.status == 200: + stats.add('other') + stats.add('other_bytes', size) + else: + stats.add('error') + stats.add('error_bytes', size) + stats.add('status_%s' % self.response.status) + print(self.url, self.response.status, + self.ctype, self.encoding, + size, + file=file) + + +class Stats: + """Record stats of various sorts.""" + + def __init__(self): + self.stats = {} + + def add(self, key, count=1): + self.stats[key] = self.stats.get(key, 0) + count + + def report(self, file=None): + for key, count in sorted(self.stats.items()): + print('%10d' % count, key, file=file) + + +class Crawler: + """Crawl a set of URLs. + + This manages three disjoint sets of URLs (todo, busy, done). The + data structures actually store dicts -- the values in todo give + the redirect limit, while the values in busy and done are Fetcher + instances. + """ + def __init__(self, log, + roots, exclude=None, strict=True, # What to crawl. + max_redirect=10, max_tries=4, # Per-url limits. + max_tasks=10, max_pool=10, # Global limits. + ): + self.log = log + self.roots = roots + self.exclude = exclude + self.strict = strict + self.max_redirect = max_redirect + self.max_tries = max_tries + self.max_tasks = max_tasks + self.max_pool = max_pool + self.todo = {} + self.busy = {} + self.done = {} + self.pool = ConnectionPool(self.log, max_pool, max_tasks) + self.root_domains = set() + for root in roots: + parts = urllib.parse.urlparse(root) + host, port = urllib.parse.splitport(parts.netloc) + if not host: + continue + if re.match(r'\A[\d\.]*\Z', host): + self.root_domains.add(host) + else: + host = host.lower() + if self.strict: + self.root_domains.add(host) + if host.startswith('www.'): + self.root_domains.add(host[4:]) + else: + self.root_domains.add('www.' + host) + else: + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + self.root_domains.add(host) + for root in roots: + self.add_url(root) + self.governor = asyncio.locks.Semaphore(max_tasks) + self.termination = asyncio.locks.Condition() + self.t0 = time.time() + self.t1 = None + + def close(self): + """Close resources (currently only the pool).""" + self.pool.close() + + def host_okay(self, host): + """Check if a host should be crawled. + + A literal match (after lowercasing) is always good. For hosts + that don't look like IP addresses, some approximate matches + are okay depending on the strict flag. + """ + host = host.lower() + if host in self.root_domains: + return True + if re.match(r'\A[\d\.]*\Z', host): + return False + if self.strict: + return self._host_okay_strictish(host) + else: + return self._host_okay_lenient(host) + + def _host_okay_strictish(self, host): + """Check if a host should be crawled, strict-ish version. + + This checks for equality modulo an initial 'www.' component. + """ + if host.startswith('www.'): + if host[4:] in self.root_domains: + return True + else: + if 'www.' + host in self.root_domains: + return True + return False + + def _host_okay_lenient(self, host): + """Check if a host should be crawled, lenient version. + + This compares the last two components of the host. + """ + parts = host.split('.') + if len(parts) > 2: + host = '.'.join(parts[-2:]) + return host in self.root_domains + + def add_url(self, url, max_redirect=None): + """Add a URL to the todo list if not seen before.""" + if self.exclude and re.search(self.exclude, url): + return False + parts = urllib.parse.urlparse(url) + if parts.scheme not in ('http', 'https'): + self.log(2, 'skipping non-http scheme in', url) + return False + host, port = urllib.parse.splitport(parts.netloc) + if not self.host_okay(host): + self.log(2, 'skipping non-root host in', url) + return False + if max_redirect is None: + max_redirect = self.max_redirect + if url in self.todo or url in self.busy or url in self.done: + return False + self.log(1, 'adding', url, max_redirect) + self.todo[url] = max_redirect + return True + + @asyncio.coroutine + def crawl(self): + """Run the crawler until all finished.""" + with (yield from self.termination): + while self.todo or self.busy: + if self.todo: + url, max_redirect = self.todo.popitem() + fetcher = Fetcher(self.log, url, + crawler=self, + max_redirect=max_redirect, + max_tries=self.max_tries, + ) + self.busy[url] = fetcher + fetcher.task = asyncio.Task(self.fetch(fetcher)) + else: + yield from self.termination.wait() + self.t1 = time.time() + + @asyncio.coroutine + def fetch(self, fetcher): + """Call the Fetcher's fetch(), with a limit on concurrency. + + Once this returns, move the fetcher from busy to done. + """ + url = fetcher.url + with (yield from self.governor): + try: + yield from fetcher.fetch() # Fetcher gonna fetch. + finally: + # Force GC of the task, so the error is logged. + fetcher.task = None + with (yield from self.termination): + self.done[url] = fetcher + del self.busy[url] + self.termination.notify() + + def report(self, file=None): + """Print a report on all completed URLs.""" + if self.t1 is None: + self.t1 = time.time() + dt = self.t1 - self.t0 + if dt and self.max_tasks: + speed = len(self.done) / dt / self.max_tasks + else: + speed = 0 + stats = Stats() + print('*** Report ***', file=file) + try: + show = [] + show.extend(self.done.items()) + show.extend(self.busy.items()) + show.sort() + for url, fetcher in show: + fetcher.report(stats, file=file) + except KeyboardInterrupt: + print('\nInterrupted', file=file) + print('Finished', len(self.done), + 'urls in %.3f secs' % dt, + '(max_tasks=%d)' % self.max_tasks, + '(%.3f urls/sec/task)' % speed, + file=file) + stats.report(file=file) + print('Todo:', len(self.todo), file=file) + print('Busy:', len(self.busy), file=file) + print('Done:', len(self.done), file=file) + print('Date:', time.ctime(), 'local time', file=file) + + +def main(): + """Main program. + + Parse arguments, set up event loop, run crawler, print report. + """ + args = ARGS.parse_args() + if not args.roots: + print('Use --help for command line help') + return + + log = Logger(args.level) + + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + asyncio.set_event_loop(loop) + elif args.select: + loop = asyncio.SelectorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + + roots = {fix_url(root) for root in args.roots} + + crawler = Crawler(log, + roots, exclude=args.exclude, + strict=args.strict, + max_redirect=args.max_redirect, + max_tries=args.max_tries, + max_tasks=args.max_tasks, + max_pool=args.max_pool, + ) + try: + loop.run_until_complete(crawler.crawl()) # Crawler gonna crawl. + except KeyboardInterrupt: + sys.stderr.flush() + print('\nInterrupted\n') + finally: + crawler.report() + crawler.close() + loop.close() + + +if __name__ == '__main__': + logging.basicConfig(level=logging.INFO) + main() diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py new file mode 100644 index 00000000..88124efe --- /dev/null +++ b/examples/echo_client_tulip.py @@ -0,0 +1,20 @@ +import asyncio + +END = b'Bye-bye!\n' + +@asyncio.coroutine +def echo_client(): + reader, writer = yield from asyncio.open_connection('localhost', 8000) + writer.write(b'Hello, world\n') + writer.write(b'What a fine day it is.\n') + writer.write(END) + while True: + line = yield from reader.readline() + print('received:', line) + if line == END or not line: + break + writer.close() + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_client()) +loop.close() diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py new file mode 100644 index 00000000..8167e540 --- /dev/null +++ b/examples/echo_server_tulip.py @@ -0,0 +1,20 @@ +import asyncio + +@asyncio.coroutine +def echo_server(): + yield from asyncio.start_server(handle_connection, 'localhost', 8000) + +@asyncio.coroutine +def handle_connection(reader, writer): + while True: + data = yield from reader.read(8192) + if not data: + break + writer.write(data) + +loop = asyncio.get_event_loop() +loop.run_until_complete(echo_server()) +try: + loop.run_forever() +finally: + loop.close() diff --git a/examples/fetch0.py b/examples/fetch0.py new file mode 100644 index 00000000..180fcf26 --- /dev/null +++ b/examples/fetch0.py @@ -0,0 +1,35 @@ +"""Simplest possible HTTP client.""" + +import sys + +from asyncio import * + + +@coroutine +def fetch(): + r, w = yield from open_connection('python.org', 80) + request = 'GET / HTTP/1.0\r\n\r\n' + print('>', request, file=sys.stderr) + w.write(request.encode('latin-1')) + while True: + line = yield from r.readline() + line = line.decode('latin-1').rstrip() + if not line: + break + print('<', line, file=sys.stderr) + print(file=sys.stderr) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch()) + finally: + loop.close() + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch1.py b/examples/fetch1.py new file mode 100644 index 00000000..8dbb6e47 --- /dev/null +++ b/examples/fetch1.py @@ -0,0 +1,78 @@ +"""Fetch one URL and write its content to stdout. + +This version adds URL parsing (including SSL) and a Response object. +""" + +import sys +import urllib.parse + +from asyncio import * + + +class Response: + + def __init__(self, verbose=True): + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def read(self, reader): + @coroutine + def getline(): + return (yield from reader.readline()).decode('latin-1').rstrip() + status_line = yield from getline() + if self.verbose: print('<', status_line, file=sys.stderr) + self.http_version, status, self.reason = status_line.split(None, 2) + self.status = int(status) + while True: + header_line = yield from getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + +@coroutine +def fetch(url, verbose=True): + parts = urllib.parse.urlparse(url) + if parts.scheme == 'http': + ssl = False + elif parts.scheme == 'https': + ssl = True + else: + print('URL must use http or https.') + sys.exit(1) + port = parts.port + if port is None: + port = 443 if ssl else 80 + path = parts.path or '/' + if parts.query: + path += '?' + parts.query + request = 'GET %s HTTP/1.0\r\n\r\n' % path + if verbose: + print('>', request, file=sys.stderr, end='') + r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + w.write(request.encode('latin-1')) + response = Response(verbose) + yield from response.read(r) + body = yield from r.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + print(body.decode('latin-1'), end='') + + +if __name__ == '__main__': + main() diff --git a/examples/fetch2.py b/examples/fetch2.py new file mode 100644 index 00000000..7617b59b --- /dev/null +++ b/examples/fetch2.py @@ -0,0 +1,141 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a Request object. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from asyncio import * + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + @coroutine + def connect(self): + if self.verbose: + print('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), + file=sys.stderr) + self.reader, self.writer = yield from open_connection(self.hostname, + self.port, + ssl=self.ssl) + if self.verbose: + print('* Connected to %s' % + (self.writer.get_extra_info('peername'),), + file=sys.stderr) + + def putline(self, line): + self.writer.write(line.encode('latin-1') + b'\r\n') + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + if self.verbose: print('>', request, file=sys.stderr) + self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + if self.verbose: print('>', line, file=sys.stderr) + self.putline(line) + self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + @coroutine + def getline(self): + return (yield from self.reader.readline()).decode('latin-1').rstrip() + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + if self.verbose: print('<', status_line, file=sys.stderr) + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + if self.verbose: print('<', header_line, file=sys.stderr) + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + if self.verbose: print(file=sys.stderr) + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True): + request = Request(url, verbose) + yield from request.connect() + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + return body + + +def main(): + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fetch3.py b/examples/fetch3.py new file mode 100644 index 00000000..9419afd2 --- /dev/null +++ b/examples/fetch3.py @@ -0,0 +1,230 @@ +"""Fetch one URL and write its content to stdout. + +This version adds a primitive connection pool, redirect following and +chunked transfer-encoding. It also supports a --iocp flag. +""" + +import sys +import urllib.parse +from http.client import BadStatusLine + +from asyncio import * + + +class ConnectionPool: + # TODO: Locking? Close idle connections? + + def __init__(self, verbose=False): + self.verbose = verbose + self.connections = {} # {(host, port, ssl): (reader, writer)} + + def close(self): + for _, writer in self.connections.values(): + writer.close() + + @coroutine + def open_connection(self, host, port, ssl): + port = port or (443 if ssl else 80) + ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + if self.verbose: + print('* %s resolves to %s' % + (host, ', '.join(ip[4][0] for ip in ipaddrs)), + file=sys.stderr) + for _, _, _, _, (h, p, *_) in ipaddrs: + key = h, p, ssl + conn = self.connections.get(key) + if conn: + reader, writer = conn + if reader._eof: + self.connections.pop(key) + continue + if self.verbose: + print('* Reusing pooled connection', key, file=sys.stderr) + return conn + reader, writer = yield from open_connection(host, port, ssl=ssl) + host, port, *_ = writer.get_extra_info('peername') + key = host, port, ssl + self.connections[key] = reader, writer + if self.verbose: + print('* New connection', key, file=sys.stderr) + return reader, writer + + +class Request: + + def __init__(self, url, verbose=True): + self.url = url + self.verbose = verbose + self.parts = urllib.parse.urlparse(self.url) + self.scheme = self.parts.scheme + assert self.scheme in ('http', 'https'), repr(url) + self.ssl = self.parts.scheme == 'https' + self.netloc = self.parts.netloc + self.hostname = self.parts.hostname + self.port = self.parts.port or (443 if self.ssl else 80) + self.path = (self.parts.path or '/') + self.query = self.parts.query + if self.query: + self.full_path = '%s?%s' % (self.path, self.query) + else: + self.full_path = self.path + self.http_version = 'HTTP/1.1' + self.method = 'GET' + self.headers = [] + self.reader = None + self.writer = None + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def connect(self, pool): + self.vprint('* Connecting to %s:%s using %s' % + (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) + self.reader, self.writer = \ + yield from pool.open_connection(self.hostname, + self.port, + ssl=self.ssl) + self.vprint('* Connected to %s' % + (self.writer.get_extra_info('peername'),)) + + @coroutine + def putline(self, line): + self.vprint('>', line) + self.writer.write(line.encode('latin-1') + b'\r\n') + ##yield from self.writer.drain() + + @coroutine + def send_request(self): + request = '%s %s %s' % (self.method, self.full_path, self.http_version) + yield from self.putline(request) + if 'host' not in {key.lower() for key, _ in self.headers}: + self.headers.insert(0, ('Host', self.netloc)) + for key, value in self.headers: + line = '%s: %s' % (key, value) + yield from self.putline(line) + yield from self.putline('') + + @coroutine + def get_response(self): + response = Response(self.reader, self.verbose) + yield from response.read_headers() + return response + + +class Response: + + def __init__(self, reader, verbose=True): + self.reader = reader + self.verbose = verbose + self.http_version = None # 'HTTP/1.1' + self.status = None # 200 + self.reason = None # 'Ok' + self.headers = [] # [('Content-Type', 'text/html')] + + def vprint(self, *args): + if self.verbose: + print(*args, file=sys.stderr) + + @coroutine + def getline(self): + line = (yield from self.reader.readline()).decode('latin-1').rstrip() + self.vprint('<', line) + return line + + @coroutine + def read_headers(self): + status_line = yield from self.getline() + status_parts = status_line.split(None, 2) + if len(status_parts) != 3: + raise BadStatusLine(status_line) + self.http_version, status, self.reason = status_parts + self.status = int(status) + while True: + header_line = yield from self.getline() + if not header_line: + break + # TODO: Continuation lines. + key, value = header_line.split(':', 1) + self.headers.append((key, value.strip())) + + def get_redirect_url(self, default=None): + if self.status not in (300, 301, 302, 303, 307): + return default + return self.get_header('Location', default) + + def get_header(self, key, default=None): + key = key.lower() + for k, v in self.headers: + if k.lower() == key: + return v + return default + + @coroutine + def read(self): + nbytes = None + for key, value in self.headers: + if key.lower() == 'content-length': + nbytes = int(value) + break + if nbytes is None: + if self.get_header('transfer-encoding', '').lower() == 'chunked': + blocks = [] + size = -1 + while size: + size_header = yield from self.reader.readline() + if not size_header: + break + parts = size_header.split(b';') + size = int(parts[0], 16) + if size: + block = yield from self.reader.readexactly(size) + assert len(block) == size, (len(block), size) + blocks.append(block) + crlf = yield from self.reader.readline() + assert crlf == b'\r\n', repr(crlf) + body = b''.join(blocks) + else: + body = yield from self.reader.read() + else: + body = yield from self.reader.readexactly(nbytes) + return body + + +@coroutine +def fetch(url, verbose=True, max_redirect=10): + pool = ConnectionPool(verbose) + try: + for _ in range(max_redirect): + request = Request(url, verbose) + yield from request.connect(pool) + yield from request.send_request() + response = yield from request.get_response() + body = yield from response.read() + next_url = response.get_redirect_url() + if not next_url: + break + url = urllib.parse.urljoin(url, next_url) + print('redirect to', url, file=sys.stderr) + return body + finally: + pool.close() + + +def main(): + if '--iocp' in sys.argv: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) + finally: + loop.close() + sys.stdout.buffer.write(body) + + +if __name__ == '__main__': + main() diff --git a/examples/fuzz_as_completed.py b/examples/fuzz_as_completed.py new file mode 100644 index 00000000..123fbf1b --- /dev/null +++ b/examples/fuzz_as_completed.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +"""Fuzz tester for as_completed(), by Glenn Langford.""" + +import asyncio +import itertools +import random +import sys + +@asyncio.coroutine +def sleeper(time): + yield from asyncio.sleep(time) + return time + +@asyncio.coroutine +def watcher(tasks,delay=False): + res = [] + for t in asyncio.as_completed(tasks): + r = yield from t + res.append(r) + if delay: + # simulate processing delay + process_time = random.random() / 10 + yield from asyncio.sleep(process_time) + #print(res) + #assert(sorted(res) == res) + if sorted(res) != res: + print('FAIL', res) + print('------------') + else: + print('.', end='') + sys.stdout.flush() + +loop = asyncio.get_event_loop() + +print('Pass 1') +# All permutations of discrete task running times must be returned +# by as_completed in the correct order. +task_times = [0, 0.1, 0.2, 0.3, 0.4 ] # 120 permutations +for times in itertools.permutations(task_times): + tasks = [ asyncio.Task(sleeper(t)) for t in times ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 2') +# Longer task times, with randomized duplicates. 100 tasks each time. +longer_task_times = [x/10 for x in range(30)] +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:500]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:100] ] + loop.run_until_complete(asyncio.Task(watcher(tasks))) + +print() +print('Pass 3') +# Same as pass 2, but with a random processing delay (0 - 0.1s) after +# retrieving each future from as_completed and 200 tasks. This tests whether +# the order that callbacks are triggered is preserved through to the +# as_completed caller. +for i in range(20): + task_times = longer_task_times * 10 + random.shuffle(task_times) + #print('Times', task_times[:200]) + tasks = [ asyncio.Task(sleeper(t)) for t in task_times[:200] ] + loop.run_until_complete(asyncio.Task(watcher(tasks, delay=True))) + +print() +loop.close() diff --git a/examples/hello_callback.py b/examples/hello_callback.py new file mode 100644 index 00000000..7ccbea1e --- /dev/null +++ b/examples/hello_callback.py @@ -0,0 +1,17 @@ +"""Print 'Hello World' every two seconds, using a callback.""" + +import asyncio + + +def print_and_repeat(loop): + print('Hello World') + loop.call_later(2, print_and_repeat, loop) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + print_and_repeat(loop) + try: + loop.run_forever() + finally: + loop.close() diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py new file mode 100644 index 00000000..b9347aa8 --- /dev/null +++ b/examples/hello_coroutine.py @@ -0,0 +1,18 @@ +"""Print 'Hello World' every two seconds, using a coroutine.""" + +import asyncio + + +@asyncio.coroutine +def greet_every_two_seconds(): + while True: + print('Hello World') + yield from asyncio.sleep(2) + + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(greet_every_two_seconds()) + finally: + loop.close() diff --git a/examples/shell.py b/examples/shell.py new file mode 100644 index 00000000..7dc7caf3 --- /dev/null +++ b/examples/shell.py @@ -0,0 +1,50 @@ +"""Examples using create_subprocess_exec() and create_subprocess_shell().""" + +import asyncio +import signal +from asyncio.subprocess import PIPE + +@asyncio.coroutine +def cat(loop): + proc = yield from asyncio.create_subprocess_shell("cat", + stdin=PIPE, + stdout=PIPE) + print("pid: %s" % proc.pid) + + message = "Hello World!" + print("cat write: %r" % message) + + stdout, stderr = yield from proc.communicate(message.encode('ascii')) + print("cat read: %r" % stdout.decode('ascii')) + + exitcode = yield from proc.wait() + print("(exit code %s)" % exitcode) + +@asyncio.coroutine +def ls(loop): + proc = yield from asyncio.create_subprocess_exec("ls", + stdout=PIPE) + while True: + line = yield from proc.stdout.readline() + if not line: + break + print("ls>>", line.decode('ascii').rstrip()) + try: + proc.send_signal(signal.SIGINT) + except ProcessLookupError: + pass + +@asyncio.coroutine +def test_call(*args, timeout=None): + try: + proc = yield from asyncio.create_subprocess_exec(*args) + exitcode = yield from asyncio.wait_for(proc.wait(), timeout) + print("%s: exit code %s" % (' '.join(args), exitcode)) + except asyncio.TimeoutError: + print("timeout! (%.1f sec)" % timeout) + +loop = asyncio.get_event_loop() +loop.run_until_complete(cat(loop)) +loop.run_until_complete(ls(loop)) +loop.run_until_complete(test_call("bash", "-c", "sleep 3", timeout=1.0)) +loop.close() diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py new file mode 100644 index 00000000..5f874ffc --- /dev/null +++ b/examples/simple_tcp_server.py @@ -0,0 +1,154 @@ +""" +Example of a simple TCP server that is written in (mostly) coroutine +style and uses asyncio.streams.start_server() and +asyncio.streams.open_connection(). + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 12345 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format(idx+1, msg) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 12345. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + send("repeat 5 hello") + msg = yield from recv() + assert msg == 'begin' + while True: + msg = yield from recv() + if msg == 'end': + break + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + try: + loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/sink.py b/examples/sink.py new file mode 100644 index 00000000..d362cbb2 --- /dev/null +++ b/examples/sink.py @@ -0,0 +1,94 @@ +"""Test service that accepts connections and reads all data off them.""" + +import argparse +import os +import sys + +from asyncio import * + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS with a self-signed cert') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--maxsize', action='store', dest='maxsize', + default=16*1024*1024, type=int, help='Max total data size') + +server = None +args = None + + +def dprint(*args): + print('sink:', *args, file=sys.stderr) + + +class Service(Protocol): + + def connection_made(self, tr): + dprint('connection from', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.total = 0 + + def data_received(self, data): + if data == b'stop': + dprint('stopping server') + server.close() + self.tr.close() + return + self.total += len(data) + dprint('received', len(data), 'bytes; total', self.total) + if self.total > args.maxsize: + dprint('closing due to too much data') + self.tr.close() + + def connection_lost(self, how): + dprint('closed', repr(how)) + + +@coroutine +def start(loop, host, port): + global server + sslctx = None + if args.tls: + import ssl + # TODO: take cert/key from args as well. + here = os.path.join(os.path.dirname(__file__), '..', 'tests') + sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx.options |= ssl.OP_NO_SSLv2 + sslctx.load_cert_chain( + certfile=os.path.join(here, 'ssl_cert.pem'), + keyfile=os.path.join(here, 'ssl_key.pem')) + + server = yield from loop.create_server(Service, host, port, ssl=sslctx) + dprint('serving TLS' if sslctx else 'serving', + [s.getsockname() for s in server.sockets]) + yield from server.wait_closed() + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/source.py b/examples/source.py new file mode 100644 index 00000000..7fd11fb0 --- /dev/null +++ b/examples/source.py @@ -0,0 +1,100 @@ +"""Test client that connects and sends infinite data.""" + +import argparse +import sys + +from asyncio import * +from asyncio import test_utils + + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + +args = None + + +def dprint(*args): + print('source:', *args, file=sys.stderr) + + +class Client(Protocol): + + total = 0 + + def connection_made(self, tr): + dprint('connecting to', tr.get_extra_info('peername')) + dprint('my socket is', tr.get_extra_info('sockname')) + self.tr = tr + self.lost = False + self.loop = get_event_loop() + self.waiter = Future() + if args.stop: + self.tr.write(b'stop') + self.tr.close() + else: + self.data = b'x'*args.size + self.write_some_data() + + def write_some_data(self): + if self.lost: + dprint('lost already') + return + data = self.data + size = len(data) + self.total += size + dprint('writing', size, 'bytes; total', self.total) + self.tr.write(data) + self.loop.call_soon(self.write_some_data) + + def connection_lost(self, exc): + dprint('lost connection', repr(exc)) + self.lost = True + self.waiter.set_result(None) + + +@coroutine +def start(loop, host, port): + sslctx = None + if args.tls: + sslctx = test_utils.dummy_ssl_context() + tr, pr = yield from loop.create_connection(Client, host, port, + ssl=sslctx) + dprint('tr =', tr) + dprint('pr =', pr) + yield from pr.waiter + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args.host, args.port)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/source1.py b/examples/source1.py new file mode 100644 index 00000000..6802e963 --- /dev/null +++ b/examples/source1.py @@ -0,0 +1,98 @@ +"""Like source.py, but uses streams.""" + +import argparse +import sys + +from asyncio import * +from asyncio import test_utils + +ARGS = argparse.ArgumentParser(description="TCP data sink example.") +ARGS.add_argument( + '--tls', action='store_true', dest='tls', + default=False, help='Use TLS') +ARGS.add_argument( + '--iocp', action='store_true', dest='iocp', + default=False, help='Use IOCP event loop (Windows only)') +ARGS.add_argument( + '--stop', action='store_true', dest='stop', + default=False, help='Stop the server by sending it b"stop" as data') +ARGS.add_argument( + '--host', action='store', dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action='store', dest='port', + default=1111, type=int, help='Port number') +ARGS.add_argument( + '--size', action='store', dest='size', + default=16*1024, type=int, help='Data size') + + +class Debug: + """A clever little class that suppresses repetitive messages.""" + + overwriting = False + label = 'stream1:' + + def print(self, *args): + if self.overwriting: + print(file=sys.stderr) + self.overwriting = 0 + print(self.label, *args, file=sys.stderr) + + def oprint(self, *args): + self.overwriting += 1 + end = '\n' + if self.overwriting >= 3: + if self.overwriting == 3: + print(self.label, '[...]', file=sys.stderr) + end = '\r' + print(self.label, *args, file=sys.stderr, end=end, flush=True) + + +@coroutine +def start(loop, args): + d = Debug() + total = 0 + sslctx = None + if args.tls: + d.print('using dummy SSLContext') + sslctx = test_utils.dummy_ssl_context() + r, w = yield from open_connection(args.host, args.port, ssl=sslctx) + d.print('r =', r) + d.print('w =', w) + if args.stop: + w.write(b'stop') + w.close() + else: + size = args.size + data = b'x'*size + try: + while True: + total += size + d.oprint('writing', size, 'bytes; total', total) + w.write(data) + f = w.drain() + if f: + d.print('pausing') + yield from f + except (ConnectionResetError, BrokenPipeError) as exc: + d.print('caught', repr(exc)) + + +def main(): + global args + args = ARGS.parse_args() + if args.iocp: + from asyncio.windows_events import ProactorEventLoop + loop = ProactorEventLoop() + set_event_loop(loop) + else: + loop = get_event_loop() + try: + loop.run_until_complete(start(loop, args)) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/stacks.py b/examples/stacks.py new file mode 100644 index 00000000..0b7e0b2c --- /dev/null +++ b/examples/stacks.py @@ -0,0 +1,44 @@ +"""Crude demo for print_stack().""" + + +from asyncio import * + + +@coroutine +def helper(r): + print('--- helper ---') + for t in Task.all_tasks(): + t.print_stack() + print('--- end helper ---') + line = yield from r.readline() + 1/0 + return line + +def doit(): + l = get_event_loop() + lr = l.run_until_complete + r, w = lr(open_connection('python.org', 80)) + t1 = async(helper(r)) + for t in Task.all_tasks(): t.print_stack() + print('---') + l._run_once() + for t in Task.all_tasks(): t.print_stack() + print('---') + w.write(b'GET /\r\n') + w.write_eof() + try: + lr(t1) + except Exception as e: + print('catching', e) + finally: + for t in Task.all_tasks(): + t.print_stack() + l.close() + + +def main(): + doit() + + +if __name__ == '__main__': + main() diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py new file mode 100644 index 00000000..d8a62420 --- /dev/null +++ b/examples/subprocess_attach_read_pipe.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a read pipe to a subprocess.""" +import asyncio +import os, sys + +code = """ +import os, sys +fd = int(sys.argv[1]) +os.write(fd, b'data') +os.close(fd) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(wfd)] + + pipe = open(rfd, 'rb', 0) + reader = asyncio.StreamReader(loop=loop) + protocol = asyncio.StreamReaderProtocol(reader, loop=loop) + transport, _ = yield from loop.connect_read_pipe(lambda: protocol, pipe) + + proc = yield from asyncio.create_subprocess_exec(*args, pass_fds={wfd}) + yield from proc.wait() + + os.close(wfd) + data = yield from reader.read() + print("read = %r" % data.decode()) + +loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py new file mode 100644 index 00000000..86148774 --- /dev/null +++ b/examples/subprocess_attach_write_pipe.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 +"""Example showing how to attach a write pipe to a subprocess.""" +import asyncio +import os, sys +from asyncio import subprocess + +code = """ +import os, sys +fd = int(sys.argv[1]) +data = os.read(fd, 1024) +sys.stdout.buffer.write(data) +""" + +loop = asyncio.get_event_loop() + +@asyncio.coroutine +def task(): + rfd, wfd = os.pipe() + args = [sys.executable, '-c', code, str(rfd)] + proc = yield from asyncio.create_subprocess_exec( + *args, + pass_fds={rfd}, + stdout=subprocess.PIPE) + + pipe = open(wfd, 'wb', 0) + transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, + pipe) + transport.write(b'data') + + stdout, stderr = yield from proc.communicate() + print("stdout = %r" % stdout.decode()) + pipe.close() + +loop.run_until_complete(task()) +loop.close() diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py new file mode 100644 index 00000000..745cb646 --- /dev/null +++ b/examples/subprocess_shell.py @@ -0,0 +1,87 @@ +"""Example writing to and reading from a subprocess at the same time using +tasks.""" + +import asyncio +import os +from asyncio.subprocess import PIPE + + +@asyncio.coroutine +def send_input(writer, input): + try: + for line in input: + print('sending', len(line), 'bytes') + writer.write(line) + d = writer.drain() + if d: + print('pause writing') + yield from d + print('resume writing') + writer.close() + except BrokenPipeError: + print('stdin: broken pipe error') + except ConnectionResetError: + print('stdin: connection reset error') + +@asyncio.coroutine +def log_errors(reader): + while True: + line = yield from reader.readline() + if not line: + break + print('ERROR', repr(line)) + +@asyncio.coroutine +def read_stdout(stdout): + while True: + line = yield from stdout.readline() + print('received', repr(line)) + if not line: + break + +@asyncio.coroutine +def start(cmd, input=None, **kwds): + kwds['stdout'] = PIPE + kwds['stderr'] = PIPE + if input is None and 'stdin' not in kwds: + kwds['stdin'] = None + else: + kwds['stdin'] = PIPE + proc = yield from asyncio.create_subprocess_shell(cmd, **kwds) + + tasks = [] + if input is not None: + tasks.append(send_input(proc.stdin, input)) + else: + print('No stdin') + if proc.stderr is not None: + tasks.append(log_errors(proc.stderr)) + else: + print('No stderr') + if proc.stdout is not None: + tasks.append(read_stdout(proc.stdout)) + else: + print('No stdout') + + if tasks: + # feed stdin while consuming stdout to avoid hang + # when stdin pipe is full + yield from asyncio.wait(tasks) + + exitcode = yield from proc.wait() + print("exit code: %s" % exitcode) + + +def main(): + if os.name == 'nt': + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + loop.run_until_complete(start( + 'sleep 2; wc', input=[b'foo bar baz\n'*300 for i in range(100)])) + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py new file mode 100755 index 00000000..d743242a --- /dev/null +++ b/examples/tcp_echo.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +"""TCP echo server example.""" +import argparse +import asyncio +import sys +try: + import signal +except ImportError: + signal = None + + +class EchoServer(asyncio.Protocol): + + TIMEOUT = 5.0 + + def timeout(self): + print('connection timeout, closing.') + self.transport.close() + + def connection_made(self, transport): + print('connection made') + self.transport = transport + + # start 5 seconds timeout timer + self.h_timeout = asyncio.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def data_received(self, data): + print('data received: ', data.decode()) + self.transport.write(b'Re: ' + data) + + # restart timeout timer + self.h_timeout.cancel() + self.h_timeout = asyncio.get_event_loop().call_later( + self.TIMEOUT, self.timeout) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + self.h_timeout.cancel() + + +class EchoClient(asyncio.Protocol): + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + self.transport.write(self.message.encode()) + print('data sent:', self.message) + + def data_received(self, data): + print('data received:', data) + + # disconnect after 10 seconds + asyncio.get_event_loop().call_later(10.0, self.transport.close) + + def eof_received(self): + pass + + def connection_lost(self, exc): + print('connection lost:', exc) + asyncio.get_event_loop().stop() + + +def start_client(loop, host, port): + t = asyncio.Task(loop.create_connection(EchoClient, host, port)) + loop.run_until_complete(t) + + +def start_server(loop, host, port): + f = loop.create_server(EchoServer, host, port) + return loop.run_until_complete(f) + + +ARGS = argparse.ArgumentParser(description="TCP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run tcp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run tcp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') +ARGS.add_argument( + '--iocp', action="store_true", dest='iocp', + default=False, help='Use IOCP event loop') + + +if __name__ == '__main__': + args = ARGS.parse_args() + + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + if args.iocp: + from asyncio import windows_events + loop = windows_events.ProactorEventLoop() + asyncio.set_event_loop(loop) + else: + loop = asyncio.get_event_loop() + print ('Using backend: {0}'.format(loop.__class__.__name__)) + + if signal is not None and sys.platform != 'win32': + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if args.server: + server = start_server(loop, args.host, args.port) + else: + start_client(loop, args.host, args.port) + + try: + loop.run_forever() + finally: + if args.server: + server.close() + loop.close() diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py new file mode 100644 index 00000000..883ce6d3 --- /dev/null +++ b/examples/timing_tcp_server.py @@ -0,0 +1,168 @@ +""" +A variant of simple_tcp_server.py that measures the time it takes to +send N messages for a range of N. (This was O(N**2) in a previous +version of Tulip.) + +Note that running this example starts both the TCP server and client +in the same process. It listens on port 1234 on 127.0.0.1, so it will +fail if this port is currently in use. +""" + +import sys +import time +import random + +import asyncio +import asyncio.streams + + +class MyServer: + """ + This is just an example of how a TCP server might be potentially + structured. This class has basically 3 methods: start the server, + handle a client, and stop the server. + + Note that you don't have to follow this structure, it is really + just an example or possible starting point. + """ + + def __init__(self): + self.server = None # encapsulates the server sockets + + # this keeps track of all the clients that connected to our + # server. It can be useful in some cases, for instance to + # kill client connections or to broadcast some data to all + # clients... + self.clients = {} # task -> (reader, writer) + + def _accept_client(self, client_reader, client_writer): + """ + This method accepts a new client connection and creates a Task + to handle this client. self.clients is updated to keep track + of the new client. + """ + + # start a new Task to handle this specific client connection + task = asyncio.Task(self._handle_client(client_reader, client_writer)) + self.clients[task] = (client_reader, client_writer) + + def client_done(task): + print("client task done:", task, file=sys.stderr) + del self.clients[task] + + task.add_done_callback(client_done) + + @asyncio.coroutine + def _handle_client(self, client_reader, client_writer): + """ + This method actually does the work to handle the requests for + a specific client. The protocol is line oriented, so there is + a main loop that reads a line with a request and then sends + out one or more lines back to the client with the result. + """ + while True: + data = (yield from client_reader.readline()).decode("utf-8") + if not data: # an empty string means the client disconnected + break + cmd, *args = data.rstrip().split(' ') + if cmd == 'add': + arg1 = float(args[0]) + arg2 = float(args[1]) + retval = arg1 + arg2 + client_writer.write("{!r}\n".format(retval).encode("utf-8")) + elif cmd == 'repeat': + times = int(args[0]) + msg = args[1] + client_writer.write("begin\n".encode("utf-8")) + for idx in range(times): + client_writer.write("{}. {}\n".format( + idx+1, msg + 'x'*random.randint(10, 50)) + .encode("utf-8")) + client_writer.write("end\n".encode("utf-8")) + else: + print("Bad command {!r}".format(data), file=sys.stderr) + + # This enables us to have flow control in our connection. + yield from client_writer.drain() + + def start(self, loop): + """ + Starts the TCP server, so that it listens on port 1234. + + For each client that connects, the accept_client method gets + called. This method runs the loop until the server sockets + are ready to accept connections. + """ + self.server = loop.run_until_complete( + asyncio.streams.start_server(self._accept_client, + '127.0.0.1', 12345, + loop=loop)) + + def stop(self, loop): + """ + Stops the TCP server, i.e. closes the listening socket(s). + + This method runs the loop until the server sockets are closed. + """ + if self.server is not None: + self.server.close() + loop.run_until_complete(self.server.wait_closed()) + self.server = None + + +def main(): + loop = asyncio.get_event_loop() + + # creates a server and starts listening to TCP connections + server = MyServer() + server.start(loop) + + @asyncio.coroutine + def client(): + reader, writer = yield from asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop) + + def send(msg): + print("> " + msg) + writer.write((msg + '\n').encode("utf-8")) + + def recv(): + msgback = (yield from reader.readline()).decode("utf-8").rstrip() + print("< " + msgback) + return msgback + + # send a line + send("add 1 2") + msg = yield from recv() + + Ns = list(range(100, 100000, 10000)) + times = [] + + for N in Ns: + t0 = time.time() + send("repeat {} hello world ".format(N)) + msg = yield from recv() + assert msg == 'begin' + while True: + msg = (yield from reader.readline()).decode("utf-8").rstrip() + if msg == 'end': + break + t1 = time.time() + dt = t1 - t0 + print("Time taken: {:.3f} seconds ({:.6f} per repetition)" + .format(dt, dt/N)) + times.append(dt) + + writer.close() + yield from asyncio.sleep(0.5) + + # creates a client and connects to our server + try: + loop.run_until_complete(client()) + server.stop(loop) + finally: + loop.close() + + +if __name__ == '__main__': + main() diff --git a/examples/udp_echo.py b/examples/udp_echo.py new file mode 100755 index 00000000..93ac7e6b --- /dev/null +++ b/examples/udp_echo.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +"""UDP echo example.""" +import argparse +import sys +import asyncio +try: + import signal +except ImportError: + signal = None + + +class MyServerUdpEchoProtocol: + + def connection_made(self, transport): + print('start', transport) + self.transport = transport + + def datagram_received(self, data, addr): + print('Data received:', data, addr) + self.transport.sendto(data, addr) + + def error_received(self, exc): + print('Error received:', exc) + + def connection_lost(self, exc): + print('stop', exc) + + +class MyClientUdpEchoProtocol: + + message = 'This is the message. It will be echoed.' + + def connection_made(self, transport): + self.transport = transport + print('sending "{}"'.format(self.message)) + self.transport.sendto(self.message.encode()) + print('waiting to receive') + + def datagram_received(self, data, addr): + print('received "{}"'.format(data.decode())) + self.transport.close() + + def error_received(self, exc): + print('Error received:', exc) + + def connection_lost(self, exc): + print('closing transport', exc) + loop = asyncio.get_event_loop() + loop.stop() + + +def start_server(loop, addr): + t = asyncio.Task(loop.create_datagram_endpoint( + MyServerUdpEchoProtocol, local_addr=addr)) + transport, server = loop.run_until_complete(t) + return transport + + +def start_client(loop, addr): + t = asyncio.Task(loop.create_datagram_endpoint( + MyClientUdpEchoProtocol, remote_addr=addr)) + loop.run_until_complete(t) + + +ARGS = argparse.ArgumentParser(description="UDP Echo example.") +ARGS.add_argument( + '--server', action="store_true", dest='server', + default=False, help='Run udp server') +ARGS.add_argument( + '--client', action="store_true", dest='client', + default=False, help='Run udp client') +ARGS.add_argument( + '--host', action="store", dest='host', + default='127.0.0.1', help='Host name') +ARGS.add_argument( + '--port', action="store", dest='port', + default=9999, type=int, help='Port number') + + +if __name__ == '__main__': + args = ARGS.parse_args() + if ':' in args.host: + args.host, port = args.host.split(':', 1) + args.port = int(port) + + if (not (args.server or args.client)) or (args.server and args.client): + print('Please specify --server or --client\n') + ARGS.print_help() + else: + loop = asyncio.get_event_loop() + if signal is not None: + loop.add_signal_handler(signal.SIGINT, loop.stop) + + if '--server' in sys.argv: + server = start_server(loop, (args.host, args.port)) + else: + start_client(loop, (args.host, args.port)) + + try: + loop.run_forever() + finally: + if '--server' in sys.argv: + server.close() + loop.close() diff --git a/overlapped.c b/overlapped.c new file mode 100644 index 00000000..4661152d --- /dev/null +++ b/overlapped.c @@ -0,0 +1,1334 @@ +/* + * Support for overlapped IO + * + * Some code borrowed from Modules/_winapi.c of CPython + */ + +/* XXX check overflow and DWORD <-> Py_ssize_t conversions + Check itemsize */ + +#include "Python.h" +#include "structmember.h" + +#define WINDOWS_LEAN_AND_MEAN +#include +#include +#include + +#if defined(MS_WIN32) && !defined(MS_WIN64) +# define F_POINTER "k" +# define T_POINTER T_ULONG +#else +# define F_POINTER "K" +# define T_POINTER T_ULONGLONG +#endif + +#define F_HANDLE F_POINTER +#define F_ULONG_PTR F_POINTER +#define F_DWORD "k" +#define F_BOOL "i" +#define F_UINT "I" + +#define T_HANDLE T_POINTER + +enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, + TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, + TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + +typedef struct { + PyObject_HEAD + OVERLAPPED overlapped; + /* For convenience, we store the file handle too */ + HANDLE handle; + /* Error returned by last method call */ + DWORD error; + /* Type of operation */ + DWORD type; + union { + /* Buffer used for reading: TYPE_READ and TYPE_ACCEPT */ + PyObject *read_buffer; + /* Buffer used for writing: TYPE_WRITE */ + Py_buffer write_buffer; + }; +} OverlappedObject; + +/* + * Map Windows error codes to subclasses of OSError + */ + +static PyObject * +SetFromWindowsErr(DWORD err) +{ + PyObject *exception_type; + + if (err == 0) + err = GetLastError(); + switch (err) { + case ERROR_CONNECTION_REFUSED: + exception_type = PyExc_ConnectionRefusedError; + break; + case ERROR_CONNECTION_ABORTED: + exception_type = PyExc_ConnectionAbortedError; + break; + default: + exception_type = PyExc_OSError; + } + return PyErr_SetExcFromWindowsErr(exception_type, err); +} + +/* + * Some functions should be loaded at runtime + */ + +static LPFN_ACCEPTEX Py_AcceptEx = NULL; +static LPFN_CONNECTEX Py_ConnectEx = NULL; +static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; + +#define GET_WSA_POINTER(s, x) \ + (SOCKET_ERROR != WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, \ + &Guid##x, sizeof(Guid##x), &Py_##x, \ + sizeof(Py_##x), &dwBytes, NULL, NULL)) + +static int +initialize_function_pointers(void) +{ + GUID GuidAcceptEx = WSAID_ACCEPTEX; + GUID GuidConnectEx = WSAID_CONNECTEX; + GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + HINSTANCE hKernel32; + SOCKET s; + DWORD dwBytes; + + s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (s == INVALID_SOCKET) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + if (!GET_WSA_POINTER(s, AcceptEx) || + !GET_WSA_POINTER(s, ConnectEx) || + !GET_WSA_POINTER(s, DisconnectEx)) + { + closesocket(s); + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + + closesocket(s); + + /* On WinXP we will have Py_CancelIoEx == NULL */ + hKernel32 = GetModuleHandle("KERNEL32"); + *(FARPROC *)&Py_CancelIoEx = GetProcAddress(hKernel32, "CancelIoEx"); + return 0; +} + +/* + * Completion port stuff + */ + +PyDoc_STRVAR( + CreateIoCompletionPort_doc, + "CreateIoCompletionPort(handle, port, key, concurrency) -> port\n\n" + "Create a completion port or register a handle with a port."); + +static PyObject * +overlapped_CreateIoCompletionPort(PyObject *self, PyObject *args) +{ + HANDLE FileHandle; + HANDLE ExistingCompletionPort; + ULONG_PTR CompletionKey; + DWORD NumberOfConcurrentThreads; + HANDLE ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_ULONG_PTR F_DWORD, + &FileHandle, &ExistingCompletionPort, &CompletionKey, + &NumberOfConcurrentThreads)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = CreateIoCompletionPort(FileHandle, ExistingCompletionPort, + CompletionKey, NumberOfConcurrentThreads); + Py_END_ALLOW_THREADS + + if (ret == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, ret); +} + +PyDoc_STRVAR( + GetQueuedCompletionStatus_doc, + "GetQueuedCompletionStatus(port, msecs) -> (err, bytes, key, address)\n\n" + "Get a message from completion port. Wait for up to msecs milliseconds."); + +static PyObject * +overlapped_GetQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort = NULL; + DWORD NumberOfBytes = 0; + ULONG_PTR CompletionKey = 0; + OVERLAPPED *Overlapped = NULL; + DWORD Milliseconds; + DWORD err; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, + &CompletionPort, &Milliseconds)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = GetQueuedCompletionStatus(CompletionPort, &NumberOfBytes, + &CompletionKey, &Overlapped, Milliseconds); + Py_END_ALLOW_THREADS + + err = ret ? ERROR_SUCCESS : GetLastError(); + if (Overlapped == NULL) { + if (err == WAIT_TIMEOUT) + Py_RETURN_NONE; + else + return SetFromWindowsErr(err); + } + return Py_BuildValue(F_DWORD F_DWORD F_ULONG_PTR F_POINTER, + err, NumberOfBytes, CompletionKey, Overlapped); +} + +PyDoc_STRVAR( + PostQueuedCompletionStatus_doc, + "PostQueuedCompletionStatus(port, bytes, key, address) -> None\n\n" + "Post a message to completion port."); + +static PyObject * +overlapped_PostQueuedCompletionStatus(PyObject *self, PyObject *args) +{ + HANDLE CompletionPort; + DWORD NumberOfBytes; + ULONG_PTR CompletionKey; + OVERLAPPED *Overlapped; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD F_ULONG_PTR F_POINTER, + &CompletionPort, &NumberOfBytes, &CompletionKey, + &Overlapped)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = PostQueuedCompletionStatus(CompletionPort, NumberOfBytes, + CompletionKey, Overlapped); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Wait for a handle + */ + +struct PostCallbackData { + HANDLE CompletionPort; + LPOVERLAPPED Overlapped; +}; + +static VOID CALLBACK +PostToQueueCallback(PVOID lpParameter, BOOL TimerOrWaitFired) +{ + struct PostCallbackData *p = (struct PostCallbackData*) lpParameter; + + PostQueuedCompletionStatus(p->CompletionPort, TimerOrWaitFired, + 0, p->Overlapped); + /* ignore possible error! */ + PyMem_Free(p); +} + +PyDoc_STRVAR( + RegisterWaitWithQueue_doc, + "RegisterWaitWithQueue(Object, CompletionPort, Overlapped, Timeout)\n" + " -> WaitHandle\n\n" + "Register wait for Object; when complete CompletionPort is notified.\n"); + +static PyObject * +overlapped_RegisterWaitWithQueue(PyObject *self, PyObject *args) +{ + HANDLE NewWaitObject; + HANDLE Object; + ULONG Milliseconds; + struct PostCallbackData data, *pdata; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE F_POINTER F_DWORD, + &Object, + &data.CompletionPort, + &data.Overlapped, + &Milliseconds)) + return NULL; + + pdata = PyMem_Malloc(sizeof(struct PostCallbackData)); + if (pdata == NULL) + return SetFromWindowsErr(0); + + *pdata = data; + + if (!RegisterWaitForSingleObject( + &NewWaitObject, Object, (WAITORTIMERCALLBACK)PostToQueueCallback, + pdata, Milliseconds, + WT_EXECUTEINWAITTHREAD | WT_EXECUTEONLYONCE)) + { + PyMem_Free(pdata); + return SetFromWindowsErr(0); + } + + return Py_BuildValue(F_HANDLE, NewWaitObject); +} + +PyDoc_STRVAR( + UnregisterWait_doc, + "UnregisterWait(WaitHandle) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWait(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &WaitHandle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWait(WaitHandle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + UnregisterWaitEx_doc, + "UnregisterWaitEx(WaitHandle, Event) -> None\n\n" + "Unregister wait handle.\n"); + +static PyObject * +overlapped_UnregisterWaitEx(PyObject *self, PyObject *args) +{ + HANDLE WaitHandle, Event; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, &WaitHandle, &Event)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = UnregisterWaitEx(WaitHandle, Event); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Event functions -- currently only used by tests + */ + +PyDoc_STRVAR( + CreateEvent_doc, + "CreateEvent(EventAttributes, ManualReset, InitialState, Name)" + " -> Handle\n\n" + "Create an event. EventAttributes must be None.\n"); + +static PyObject * +overlapped_CreateEvent(PyObject *self, PyObject *args) +{ + PyObject *EventAttributes; + BOOL ManualReset; + BOOL InitialState; + Py_UNICODE *Name; + HANDLE Event; + + if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "Z", + &EventAttributes, &ManualReset, + &InitialState, &Name)) + return NULL; + + if (EventAttributes != Py_None) { + PyErr_SetString(PyExc_ValueError, "EventAttributes must be None"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + Event = CreateEventW(NULL, ManualReset, InitialState, Name); + Py_END_ALLOW_THREADS + + if (Event == NULL) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, Event); +} + +PyDoc_STRVAR( + SetEvent_doc, + "SetEvent(Handle) -> None\n\n" + "Set event.\n"); + +static PyObject * +overlapped_SetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = SetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + ResetEvent_doc, + "ResetEvent(Handle) -> None\n\n" + "Reset event.\n"); + +static PyObject * +overlapped_ResetEvent(PyObject *self, PyObject *args) +{ + HANDLE Handle; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Handle)) + return NULL; + + Py_BEGIN_ALLOW_THREADS + ret = ResetEvent(Handle); + Py_END_ALLOW_THREADS + + if (!ret) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +/* + * Bind socket handle to local port without doing slow getaddrinfo() + */ + +PyDoc_STRVAR( + BindLocal_doc, + "BindLocal(handle, family) -> None\n\n" + "Bind a socket handle to an arbitrary local port.\n" + "family should AF_INET or AF_INET6.\n"); + +static PyObject * +overlapped_BindLocal(PyObject *self, PyObject *args) +{ + SOCKET Socket; + int Family; + BOOL ret; + + if (!PyArg_ParseTuple(args, F_HANDLE "i", &Socket, &Family)) + return NULL; + + if (Family == AF_INET) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = 0; + addr.sin_addr.S_un.S_addr = INADDR_ANY; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else if (Family == AF_INET6) { + struct sockaddr_in6 addr; + memset(&addr, 0, sizeof(addr)); + addr.sin6_family = AF_INET6; + addr.sin6_port = 0; + addr.sin6_addr = in6addr_any; + ret = bind(Socket, (SOCKADDR*)&addr, sizeof(addr)) != SOCKET_ERROR; + } else { + PyErr_SetString(PyExc_ValueError, "expected tuple of length 2 or 4"); + return NULL; + } + + if (!ret) + return SetFromWindowsErr(WSAGetLastError()); + Py_RETURN_NONE; +} + +/* + * Windows equivalent of os.strerror() -- compare _ctypes/callproc.c + */ + +PyDoc_STRVAR( + FormatMessage_doc, + "FormatMessage(error_code) -> error_message\n\n" + "Return error message for an error code."); + +static PyObject * +overlapped_FormatMessage(PyObject *ignore, PyObject *args) +{ + DWORD code, n; + WCHAR *lpMsgBuf; + PyObject *res; + + if (!PyArg_ParseTuple(args, F_DWORD, &code)) + return NULL; + + n = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, + code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR) &lpMsgBuf, + 0, + NULL); + if (n) { + while (iswspace(lpMsgBuf[n-1])) + --n; + lpMsgBuf[n] = L'\0'; + res = Py_BuildValue("u", lpMsgBuf); + } else { + res = PyUnicode_FromFormat("unknown error code %u", code); + } + LocalFree(lpMsgBuf); + return res; +} + + +/* + * Mark operation as completed - used when reading produces ERROR_BROKEN_PIPE + */ + +static void +mark_as_completed(OVERLAPPED *ov) +{ + ov->Internal = 0; + if (ov->hEvent != NULL) + SetEvent(ov->hEvent); +} + +/* + * A Python object wrapping an OVERLAPPED structure and other useful data + * for overlapped I/O + */ + +PyDoc_STRVAR( + Overlapped_doc, + "Overlapped object"); + +static PyObject * +Overlapped_new(PyTypeObject *type, PyObject *args, PyObject *kwds) +{ + OverlappedObject *self; + HANDLE event = INVALID_HANDLE_VALUE; + static char *kwlist[] = {"event", NULL}; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|" F_HANDLE, kwlist, &event)) + return NULL; + + if (event == INVALID_HANDLE_VALUE) { + event = CreateEvent(NULL, TRUE, FALSE, NULL); + if (event == NULL) + return SetFromWindowsErr(0); + } + + self = PyObject_New(OverlappedObject, type); + if (self == NULL) { + if (event != NULL) + CloseHandle(event); + return NULL; + } + + self->handle = NULL; + self->error = 0; + self->type = TYPE_NONE; + self->read_buffer = NULL; + memset(&self->overlapped, 0, sizeof(OVERLAPPED)); + memset(&self->write_buffer, 0, sizeof(Py_buffer)); + if (event) + self->overlapped.hEvent = event; + return (PyObject *)self; +} + +static void +Overlapped_dealloc(OverlappedObject *self) +{ + DWORD bytes; + DWORD olderr = GetLastError(); + BOOL wait = FALSE; + BOOL ret; + + if (!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED) + { + if (Py_CancelIoEx && Py_CancelIoEx(self->handle, &self->overlapped)) + wait = TRUE; + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, + &bytes, wait); + Py_END_ALLOW_THREADS + + switch (ret ? ERROR_SUCCESS : GetLastError()) { + case ERROR_SUCCESS: + case ERROR_NOT_FOUND: + case ERROR_OPERATION_ABORTED: + break; + default: + PyErr_Format( + PyExc_RuntimeError, + "%R still has pending operation at " + "deallocation, the process may crash", self); + PyErr_WriteUnraisable(NULL); + } + } + + if (self->overlapped.hEvent != NULL) + CloseHandle(self->overlapped.hEvent); + + switch (self->type) { + case TYPE_READ: + case TYPE_ACCEPT: + Py_CLEAR(self->read_buffer); + break; + case TYPE_WRITE: + if (self->write_buffer.obj) + PyBuffer_Release(&self->write_buffer); + break; + } + PyObject_Del(self); + SetLastError(olderr); +} + +PyDoc_STRVAR( + Overlapped_cancel_doc, + "cancel() -> None\n\n" + "Cancel overlapped operation"); + +static PyObject * +Overlapped_cancel(OverlappedObject *self) +{ + BOOL ret = TRUE; + + if (self->type == TYPE_NOT_STARTED + || self->type == TYPE_WAIT_NAMED_PIPE_AND_CONNECT) + Py_RETURN_NONE; + + if (!HasOverlappedIoCompleted(&self->overlapped)) { + Py_BEGIN_ALLOW_THREADS + if (Py_CancelIoEx) + ret = Py_CancelIoEx(self->handle, &self->overlapped); + else + ret = CancelIo(self->handle); + Py_END_ALLOW_THREADS + } + + /* CancelIoEx returns ERROR_NOT_FOUND if the I/O completed in-between */ + if (!ret && GetLastError() != ERROR_NOT_FOUND) + return SetFromWindowsErr(0); + Py_RETURN_NONE; +} + +PyDoc_STRVAR( + Overlapped_getresult_doc, + "getresult(wait=False) -> result\n\n" + "Retrieve result of operation. If wait is true then it blocks\n" + "until the operation is finished. If wait is false and the\n" + "operation is still pending then an error is raised."); + +static PyObject * +Overlapped_getresult(OverlappedObject *self, PyObject *args) +{ + BOOL wait = FALSE; + DWORD transferred = 0; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, "|" F_BOOL, &wait)) + return NULL; + + if (self->type == TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation not yet attempted"); + return NULL; + } + + if (self->type == TYPE_NOT_STARTED) { + PyErr_SetString(PyExc_ValueError, "operation failed to start"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + ret = GetOverlappedResult(self->handle, &self->overlapped, &transferred, + wait); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + break; + case ERROR_BROKEN_PIPE: + if ((self->type == TYPE_READ || self->type == TYPE_ACCEPT) && self->read_buffer != NULL) + break; + /* fall through */ + default: + return SetFromWindowsErr(err); + } + + switch (self->type) { + case TYPE_READ: + assert(PyBytes_CheckExact(self->read_buffer)); + if (transferred != PyBytes_GET_SIZE(self->read_buffer) && + _PyBytes_Resize(&self->read_buffer, transferred)) + return NULL; + Py_INCREF(self->read_buffer); + return self->read_buffer; + default: + return PyLong_FromUnsignedLong((unsigned long) transferred); + } +} + +PyDoc_STRVAR( + Overlapped_ReadFile_doc, + "ReadFile(handle, size) -> Overlapped[message]\n\n" + "Start overlapped read"); + +static PyObject * +Overlapped_ReadFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD nread; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &handle, &size)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = ReadFile(handle, PyBytes_AS_STRING(buf), size, &nread, + &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSARecv_doc, + "RecvFile(handle, size, flags) -> Overlapped[message]\n\n" + "Start overlapped receive"); + +static PyObject * +Overlapped_WSARecv(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + DWORD size; + DWORD flags = 0; + DWORD nread; + PyObject *buf; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD "|" F_DWORD, + &handle, &size, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + +#if SIZEOF_SIZE_T <= SIZEOF_LONG + size = Py_MIN(size, (DWORD)PY_SSIZE_T_MAX); +#endif + buf = PyBytes_FromStringAndSize(NULL, Py_MAX(size, 1)); + if (buf == NULL) + return NULL; + + self->type = TYPE_READ; + self->handle = handle; + self->read_buffer = buf; + wsabuf.len = size; + wsabuf.buf = PyBytes_AS_STRING(buf); + + Py_BEGIN_ALLOW_THREADS + ret = WSARecv((SOCKET)handle, &wsabuf, 1, &nread, &flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_BROKEN_PIPE: + mark_as_completed(&self->overlapped); + Py_RETURN_NONE; + case ERROR_SUCCESS: + case ERROR_MORE_DATA: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WriteFile_doc, + "WriteFile(handle, buf) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped write"); + +static PyObject * +Overlapped_WriteFile(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD written; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &handle, &bufobj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + + Py_BEGIN_ALLOW_THREADS + ret = WriteFile(handle, self->write_buffer.buf, + (DWORD)self->write_buffer.len, + &written, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_WSASend_doc, + "WSASend(handle, buf, flags) -> Overlapped[bytes_transferred]\n\n" + "Start overlapped send"); + +static PyObject * +Overlapped_WSASend(OverlappedObject *self, PyObject *args) +{ + HANDLE handle; + PyObject *bufobj; + DWORD flags; + DWORD written; + WSABUF wsabuf; + int ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O" F_DWORD, + &handle, &bufobj, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) + return NULL; + +#if SIZEOF_SIZE_T > SIZEOF_LONG + if (self->write_buffer.len > (Py_ssize_t)ULONG_MAX) { + PyBuffer_Release(&self->write_buffer); + PyErr_SetString(PyExc_ValueError, "buffer to large"); + return NULL; + } +#endif + + self->type = TYPE_WRITE; + self->handle = handle; + wsabuf.len = (DWORD)self->write_buffer.len; + wsabuf.buf = self->write_buffer.buf; + + Py_BEGIN_ALLOW_THREADS + ret = WSASend((SOCKET)handle, &wsabuf, 1, &written, flags, + &self->overlapped, NULL); + Py_END_ALLOW_THREADS + + self->error = err = (ret < 0 ? WSAGetLastError() : ERROR_SUCCESS); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_AcceptEx_doc, + "AcceptEx(listen_handle, accept_handle) -> Overlapped[address_as_bytes]\n\n" + "Start overlapped wait for client to connect"); + +static PyObject * +Overlapped_AcceptEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ListenSocket; + SOCKET AcceptSocket; + DWORD BytesReceived; + DWORD size; + PyObject *buf; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_HANDLE, + &ListenSocket, &AcceptSocket)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + size = sizeof(struct sockaddr_in6) + 16; + buf = PyBytes_FromStringAndSize(NULL, size*2); + if (!buf) + return NULL; + + self->type = TYPE_ACCEPT; + self->handle = (HANDLE)ListenSocket; + self->read_buffer = buf; + + Py_BEGIN_ALLOW_THREADS + ret = Py_AcceptEx(ListenSocket, AcceptSocket, PyBytes_AS_STRING(buf), + 0, size, size, &BytesReceived, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + + +static int +parse_address(PyObject *obj, SOCKADDR *Address, int Length) +{ + char *Host; + unsigned short Port; + unsigned long FlowInfo; + unsigned long ScopeId; + + memset(Address, 0, Length); + + if (PyArg_ParseTuple(obj, "sH", &Host, &Port)) + { + Address->sa_family = AF_INET; + if (WSAStringToAddressA(Host, AF_INET, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN*)Address)->sin_port = htons(Port); + return Length; + } + else if (PyArg_ParseTuple(obj, "sHkk", &Host, &Port, &FlowInfo, &ScopeId)) + { + PyErr_Clear(); + Address->sa_family = AF_INET6; + if (WSAStringToAddressA(Host, AF_INET6, NULL, Address, &Length) < 0) { + SetFromWindowsErr(WSAGetLastError()); + return -1; + } + ((SOCKADDR_IN6*)Address)->sin6_port = htons(Port); + ((SOCKADDR_IN6*)Address)->sin6_flowinfo = FlowInfo; + ((SOCKADDR_IN6*)Address)->sin6_scope_id = ScopeId; + return Length; + } + + return -1; +} + + +PyDoc_STRVAR( + Overlapped_ConnectEx_doc, + "ConnectEx(client_handle, address_as_bytes) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_ConnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET ConnectSocket; + PyObject *AddressObj; + char AddressBuf[sizeof(struct sockaddr_in6)]; + SOCKADDR *Address = (SOCKADDR*)AddressBuf; + int Length; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE "O", &ConnectSocket, &AddressObj)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + Length = sizeof(AddressBuf); + Length = parse_address(AddressObj, Address, Length); + if (Length < 0) + return NULL; + + self->type = TYPE_CONNECT; + self->handle = (HANDLE)ConnectSocket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_ConnectEx(ConnectSocket, Address, Length, + NULL, 0, NULL, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_DisconnectEx_doc, + "DisconnectEx(handle, flags) -> Overlapped[None]\n\n" + "Start overlapped connect. client_handle should be unbound."); + +static PyObject * +Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE F_DWORD, &Socket, &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_DISCONNECT; + self->handle = (HANDLE)Socket; + + Py_BEGIN_ALLOW_THREADS + ret = Py_DisconnectEx(Socket, &self->overlapped, flags, 0); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + Overlapped_ConnectNamedPipe_doc, + "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" + "Start overlapped wait for a client to connect."); + +static PyObject * +Overlapped_ConnectNamedPipe(OverlappedObject *self, PyObject *args) +{ + HANDLE Pipe; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, F_HANDLE, &Pipe)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_CONNECT_NAMED_PIPE; + self->handle = Pipe; + + Py_BEGIN_ALLOW_THREADS + ret = ConnectNamedPipe(Pipe, &self->overlapped); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : GetLastError(); + switch (err) { + case ERROR_PIPE_CONNECTED: + mark_as_completed(&self->overlapped); + Py_RETURN_TRUE; + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_FALSE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + +PyDoc_STRVAR( + ConnectPipe_doc, + "ConnectPipe(addr) -> pipe_handle\n\n" + "Connect to the pipe for asynchronous I/O (overlapped)."); + +static PyObject * +ConnectPipe(OverlappedObject *self, PyObject *args) +{ + PyObject *AddressObj; + wchar_t *Address; + HANDLE PipeHandle; + + if (!PyArg_ParseTuple(args, "U", &AddressObj)) + return NULL; + + Address = PyUnicode_AsWideCharString(AddressObj, NULL); + if (Address == NULL) + return NULL; + + PipeHandle = CreateFileW(Address, + GENERIC_READ | GENERIC_WRITE, + 0, NULL, OPEN_EXISTING, + FILE_FLAG_OVERLAPPED, NULL); + PyMem_Free(Address); + if (PipeHandle == INVALID_HANDLE_VALUE) + return SetFromWindowsErr(0); + return Py_BuildValue(F_HANDLE, PipeHandle); +} + +static PyObject* +Overlapped_getaddress(OverlappedObject *self) +{ + return PyLong_FromVoidPtr(&self->overlapped); +} + +static PyObject* +Overlapped_getpending(OverlappedObject *self) +{ + return PyBool_FromLong(!HasOverlappedIoCompleted(&self->overlapped) && + self->type != TYPE_NOT_STARTED); +} + +static PyMethodDef Overlapped_methods[] = { + {"getresult", (PyCFunction) Overlapped_getresult, + METH_VARARGS, Overlapped_getresult_doc}, + {"cancel", (PyCFunction) Overlapped_cancel, + METH_NOARGS, Overlapped_cancel_doc}, + {"ReadFile", (PyCFunction) Overlapped_ReadFile, + METH_VARARGS, Overlapped_ReadFile_doc}, + {"WSARecv", (PyCFunction) Overlapped_WSARecv, + METH_VARARGS, Overlapped_WSARecv_doc}, + {"WriteFile", (PyCFunction) Overlapped_WriteFile, + METH_VARARGS, Overlapped_WriteFile_doc}, + {"WSASend", (PyCFunction) Overlapped_WSASend, + METH_VARARGS, Overlapped_WSASend_doc}, + {"AcceptEx", (PyCFunction) Overlapped_AcceptEx, + METH_VARARGS, Overlapped_AcceptEx_doc}, + {"ConnectEx", (PyCFunction) Overlapped_ConnectEx, + METH_VARARGS, Overlapped_ConnectEx_doc}, + {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, + METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, + METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, + {NULL} +}; + +static PyMemberDef Overlapped_members[] = { + {"error", T_ULONG, + offsetof(OverlappedObject, error), + READONLY, "Error from last operation"}, + {"event", T_HANDLE, + offsetof(OverlappedObject, overlapped) + offsetof(OVERLAPPED, hEvent), + READONLY, "Overlapped event handle"}, + {NULL} +}; + +static PyGetSetDef Overlapped_getsets[] = { + {"address", (getter)Overlapped_getaddress, NULL, + "Address of overlapped structure"}, + {"pending", (getter)Overlapped_getpending, NULL, + "Whether the operation is pending"}, + {NULL}, +}; + +PyTypeObject OverlappedType = { + PyVarObject_HEAD_INIT(NULL, 0) + /* tp_name */ "_overlapped.Overlapped", + /* tp_basicsize */ sizeof(OverlappedObject), + /* tp_itemsize */ 0, + /* tp_dealloc */ (destructor) Overlapped_dealloc, + /* tp_print */ 0, + /* tp_getattr */ 0, + /* tp_setattr */ 0, + /* tp_reserved */ 0, + /* tp_repr */ 0, + /* tp_as_number */ 0, + /* tp_as_sequence */ 0, + /* tp_as_mapping */ 0, + /* tp_hash */ 0, + /* tp_call */ 0, + /* tp_str */ 0, + /* tp_getattro */ 0, + /* tp_setattro */ 0, + /* tp_as_buffer */ 0, + /* tp_flags */ Py_TPFLAGS_DEFAULT, + /* tp_doc */ "OVERLAPPED structure wrapper", + /* tp_traverse */ 0, + /* tp_clear */ 0, + /* tp_richcompare */ 0, + /* tp_weaklistoffset */ 0, + /* tp_iter */ 0, + /* tp_iternext */ 0, + /* tp_methods */ Overlapped_methods, + /* tp_members */ Overlapped_members, + /* tp_getset */ Overlapped_getsets, + /* tp_base */ 0, + /* tp_dict */ 0, + /* tp_descr_get */ 0, + /* tp_descr_set */ 0, + /* tp_dictoffset */ 0, + /* tp_init */ 0, + /* tp_alloc */ 0, + /* tp_new */ Overlapped_new, +}; + +static PyMethodDef overlapped_functions[] = { + {"CreateIoCompletionPort", overlapped_CreateIoCompletionPort, + METH_VARARGS, CreateIoCompletionPort_doc}, + {"GetQueuedCompletionStatus", overlapped_GetQueuedCompletionStatus, + METH_VARARGS, GetQueuedCompletionStatus_doc}, + {"PostQueuedCompletionStatus", overlapped_PostQueuedCompletionStatus, + METH_VARARGS, PostQueuedCompletionStatus_doc}, + {"FormatMessage", overlapped_FormatMessage, + METH_VARARGS, FormatMessage_doc}, + {"BindLocal", overlapped_BindLocal, + METH_VARARGS, BindLocal_doc}, + {"RegisterWaitWithQueue", overlapped_RegisterWaitWithQueue, + METH_VARARGS, RegisterWaitWithQueue_doc}, + {"UnregisterWait", overlapped_UnregisterWait, + METH_VARARGS, UnregisterWait_doc}, + {"UnregisterWaitEx", overlapped_UnregisterWaitEx, + METH_VARARGS, UnregisterWaitEx_doc}, + {"CreateEvent", overlapped_CreateEvent, + METH_VARARGS, CreateEvent_doc}, + {"SetEvent", overlapped_SetEvent, + METH_VARARGS, SetEvent_doc}, + {"ResetEvent", overlapped_ResetEvent, + METH_VARARGS, ResetEvent_doc}, + {"ConnectPipe", + (PyCFunction) ConnectPipe, + METH_VARARGS, ConnectPipe_doc}, + {NULL} +}; + +static struct PyModuleDef overlapped_module = { + PyModuleDef_HEAD_INIT, + "_overlapped", + NULL, + -1, + overlapped_functions, + NULL, + NULL, + NULL, + NULL +}; + +#define WINAPI_CONSTANT(fmt, con) \ + PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) + +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + PyObject *m, *d; + + /* Ensure WSAStartup() called before initializing function pointers */ + m = PyImport_ImportModule("_socket"); + if (!m) + return NULL; + Py_DECREF(m); + + if (initialize_function_pointers() < 0) + return NULL; + + if (PyType_Ready(&OverlappedType) < 0) + return NULL; + + m = PyModule_Create(&overlapped_module); + if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) + return NULL; + + d = PyModule_GetDict(m); + + WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); + WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); + WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); + WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY); + WINAPI_CONSTANT(F_DWORD, INFINITE); + WINAPI_CONSTANT(F_HANDLE, INVALID_HANDLE_VALUE); + WINAPI_CONSTANT(F_HANDLE, NULL); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); + WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + + return m; +} diff --git a/pypi.bat b/pypi.bat new file mode 100644 index 00000000..5218ace3 --- /dev/null +++ b/pypi.bat @@ -0,0 +1 @@ +c:\Python33\python.exe setup.py bdist_wheel upload diff --git a/release.py b/release.py new file mode 100755 index 00000000..a5acbc88 --- /dev/null +++ b/release.py @@ -0,0 +1,517 @@ +#!/usr/bin/env python3 +""" +Script to upload 32 bits and 64 bits wheel packages for Python 3.3 on Windows. + +Usage: "python release.py HG_TAG" where HG_TAG is a Mercurial tag, usually +a version number like "3.4.2". + +Requirements: + +- Python 3.3 and newer requires the Windows SDK 7.1 to build wheel packages +- Python 2.7 requires the Windows SDK 7.0 +- the aiotest module is required to run aiotest tests +""" +import contextlib +import optparse +import os +import platform +import re +import shutil +import subprocess +import sys +import tempfile +import textwrap + +PROJECT = 'asyncio' +DEBUG_ENV_VAR = 'PYTHONASYNCIODEBUG' +PYTHON_VERSIONS = ( + (3, 3), +) +PY3 = (sys.version_info >= (3,)) +HG = 'hg' +SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" +BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" +WINDOWS = (sys.platform == 'win32') + + +def get_architecture_bits(): + arch = platform.architecture()[0] + return int(arch[:2]) + + +class PythonVersion: + def __init__(self, major, minor, bits): + self.major = major + self.minor = minor + self.bits = bits + self._executable = None + + @staticmethod + def running(): + bits = get_architecture_bits() + pyver = PythonVersion(sys.version_info.major, + sys.version_info.minor, + bits) + pyver._executable = sys.executable + return pyver + + def _get_executable_windows(self, app): + if self.bits == 32: + executable = 'c:\\Python%s%s_32bit\\python.exe' + else: + executable = 'c:\\Python%s%s\\python.exe' + executable = executable % (self.major, self.minor) + if not os.path.exists(executable): + print("Unable to find python %s" % self) + print("%s does not exists" % executable) + sys.exit(1) + return executable + + def _get_executable_unix(self, app): + return 'python%s.%s' % (self.major, self.minor) + + def get_executable(self, app): + if self._executable: + return self._executable + + if WINDOWS: + executable = self._get_executable_windows(app) + else: + executable = self._get_executable_unix(app) + + code = ( + 'import platform, sys; ' + 'print("{ver.major}.{ver.minor} {bits}".format(' + 'ver=sys.version_info, ' + 'bits=platform.architecture()[0]))' + ) + try: + exitcode, stdout = app.get_output(executable, '-c', code, + ignore_stderr=True) + except OSError as exc: + print("Error while checking %s:" % self) + print(str(exc)) + print("Executable: %s" % executable) + sys.exit(1) + else: + stdout = stdout.rstrip() + expected = "%s.%s %sbit" % (self.major, self.minor, self.bits) + if stdout != expected: + print("Python version or architecture doesn't match") + print("got %r, expected %r" % (stdout, expected)) + print("Executable: %s" % executable) + sys.exit(1) + + self._executable = executable + return executable + + def __str__(self): + return 'Python %s.%s (%s bits)' % (self.major, self.minor, self.bits) + + +class Release(object): + def __init__(self): + root = os.path.dirname(__file__) + self.root = os.path.realpath(root) + # Set these attributes to True to run also register sdist upload + self.wheel = False + self.test = False + self.register = False + self.sdist = False + self.aiotest = False + self.verbose = False + self.upload = False + # Release mode: enable more tests + self.release = False + self.python_versions = [] + if WINDOWS: + supported_archs = (32, 64) + else: + bits = get_architecture_bits() + supported_archs = (bits,) + for major, minor in PYTHON_VERSIONS: + for bits in supported_archs: + pyver = PythonVersion(major, minor, bits) + self.python_versions.append(pyver) + + @contextlib.contextmanager + def _popen(self, args, **kw): + verbose = kw.pop('verbose', True) + if self.verbose and verbose: + print('+ ' + ' '.join(args)) + if PY3: + kw['universal_newlines'] = True + proc = subprocess.Popen(args, **kw) + try: + yield proc + except: + proc.kill() + proc.wait() + raise + + def get_output(self, *args, **kw): + kw['stdout'] = subprocess.PIPE + ignore_stderr = kw.pop('ignore_stderr', False) + if ignore_stderr: + devnull = open(os.path.devnull, 'wb') + kw['stderr'] = devnull + else: + kw['stderr'] = subprocess.STDOUT + try: + with self._popen(args, **kw) as proc: + stdout, stderr = proc.communicate() + return proc.returncode, stdout + finally: + if ignore_stderr: + devnull.close() + + def check_output(self, *args, **kw): + exitcode, output = self.get_output(*args, **kw) + if exitcode: + sys.stdout.write(output) + sys.stdout.flush() + sys.exit(1) + return output + + def run_command(self, *args, **kw): + with self._popen(args, **kw) as proc: + exitcode = proc.wait() + if exitcode: + sys.exit(exitcode) + + def get_local_changes(self): + status = self.check_output(HG, 'status') + return [line for line in status.splitlines() + if not line.startswith("?")] + + def remove_directory(self, name): + path = os.path.join(self.root, name) + if os.path.exists(path): + if self.verbose: + print("Remove directory: %s" % name) + shutil.rmtree(path) + + def remove_file(self, name): + path = os.path.join(self.root, name) + if os.path.exists(path): + if self.verbose: + print("Remove file: %s" % name) + os.unlink(path) + + def windows_sdk_setenv(self, pyver): + if (pyver.major, pyver.minor) >= (3, 3): + path = "v7.1" + sdkver = (7, 1) + else: + path = "v7.0" + sdkver = (7, 0) + setenv = os.path.join(SDK_ROOT, path, 'Bin', 'SetEnv.cmd') + if not os.path.exists(setenv): + print("Unable to find Windows SDK %s.%s for %s" + % (sdkver[0], sdkver[1], pyver)) + print("Please download and install it") + print("%s does not exists" % setenv) + sys.exit(1) + if pyver.bits == 64: + arch = '/x64' + else: + arch = '/x86' + cmd = ["CALL", setenv, "/release", arch] + return (cmd, sdkver) + + def quote(self, arg): + if not re.search("[ '\"]", arg): + return arg + # FIXME: should we escape "? + return '"%s"' % arg + + def quote_args(self, args): + return ' '.join(self.quote(arg) for arg in args) + + def cleanup(self): + if self.verbose: + print("Cleanup") + self.remove_directory('build') + self.remove_directory('dist') + self.remove_file('_overlapped.pyd') + self.remove_file(os.path.join(PROJECT, '_overlapped.pyd')) + + def sdist_upload(self): + self.cleanup() + self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') + + def build_inplace(self, pyver): + print("Build for %s" % pyver) + self.build(pyver, 'build') + + if WINDOWS: + if pyver.bits == 64: + arch = 'win-amd64' + else: + arch = 'win32' + build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) + src = os.path.join(self.root, 'build', build_dir, + PROJECT, '_overlapped.pyd') + dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') + shutil.copyfile(src, dst) + + def runtests(self, pyver): + print("Run tests on %s" % pyver) + + if WINDOWS and not self.options.no_compile: + self.build_inplace(pyver) + + release_env = dict(os.environ) + release_env.pop(DEBUG_ENV_VAR, None) + + dbg_env = dict(os.environ) + dbg_env[DEBUG_ENV_VAR] = '1' + + python = pyver.get_executable(self) + args = (python, 'runtests.py', '-r') + + if self.release: + print("Run runtests.py in release mode on %s" % pyver) + self.run_command(*args, env=release_env) + + print("Run runtests.py in debug mode on %s" % pyver) + self.run_command(*args, env=dbg_env) + + if self.aiotest: + args = (python, 'run_aiotest.py') + + if self.release: + print("Run aiotest in release mode on %s" % pyver) + self.run_command(*args, env=release_env) + + print("Run aiotest in debug mode on %s" % pyver) + self.run_command(*args, env=dbg_env) + print("") + + def _build_windows(self, pyver, cmd): + setenv, sdkver = self.windows_sdk_setenv(pyver) + + temp = tempfile.NamedTemporaryFile(mode="w", suffix=".bat", + delete=False) + with temp: + temp.write("SETLOCAL EnableDelayedExpansion\n") + temp.write(self.quote_args(setenv) + "\n") + temp.write(BATCH_FAIL_ON_ERROR + "\n") + # Restore console colors: lightgrey on black + temp.write("COLOR 07\n") + temp.write("\n") + temp.write("SET DISTUTILS_USE_SDK=1\n") + temp.write("SET MSSDK=1\n") + temp.write("CD %s\n" % self.quote(self.root)) + temp.write(self.quote_args(cmd) + "\n") + temp.write(BATCH_FAIL_ON_ERROR + "\n") + + try: + if self.verbose: + print("Setup Windows SDK %s.%s" % sdkver) + print("+ " + ' '.join(cmd)) + # SDK 7.1 uses the COLOR command which makes SetEnv.cmd failing + # if the stdout is not a TTY (if we redirect stdout into a file) + if self.verbose or sdkver >= (7, 1): + self.run_command(temp.name, verbose=False) + else: + self.check_output(temp.name, verbose=False) + finally: + os.unlink(temp.name) + + def _build_unix(self, pyver, cmd): + self.check_output(*cmd) + + def build(self, pyver, *cmds): + self.cleanup() + + python = pyver.get_executable(self) + cmd = [python, 'setup.py'] + list(cmds) + + if WINDOWS: + self._build_windows(pyver, cmd) + else: + self._build_unix(pyver, cmd) + + def test_wheel(self, pyver): + print("Test building wheel package for %s" % pyver) + self.build(pyver, 'bdist_wheel') + + def publish_wheel(self, pyver): + print("Build and publish wheel package for %s" % pyver) + self.build(pyver, 'bdist_wheel', 'upload') + + def parse_options(self): + parser = optparse.OptionParser( + description="Run all unittests.", + usage="%prog [options] command") + parser.add_option( + '-v', '--verbose', action="store_true", dest='verbose', + default=0, help='verbose') + parser.add_option( + '-t', '--tag', type="str", + help='Mercurial tag or revision, required to release') + parser.add_option( + '-p', '--python', type="str", + help='Only build/test one specific Python version, ex: "2.7:32"') + parser.add_option( + '-C', "--no-compile", action="store_true", + help="Don't compile the module, this options implies --running", + default=False) + parser.add_option( + '-r', "--running", action="store_true", + help='Only use the running Python version', + default=False) + parser.add_option( + '--ignore', action="store_true", + help='Ignore local changes', + default=False) + self.options, args = parser.parse_args() + if len(args) == 1: + command = args[0] + else: + command = None + + if self.options.no_compile: + self.options.running = True + + if command == 'clean': + self.options.verbose = True + elif command == 'build': + self.options.running = True + elif command == 'test_wheel': + self.wheel = True + elif command == 'test': + self.test = True + elif command == 'release': + if not self.options.tag: + print("The release command requires the --tag option") + sys.exit(1) + + self.release = True + self.wheel = True + self.test = True + self.upload = True + else: + if command: + print("Invalid command: %s" % command) + else: + parser.print_help() + print("") + + print("Available commands:") + print("- build: build asyncio in place, imply --running") + print("- test: run tests") + print("- test_wheel: test building wheel packages") + print("- release: run tests and publish wheel packages,") + print(" require the --tag option") + print("- clean: cleanup the project") + sys.exit(1) + + if self.options.python and self.options.running: + print("--python and --running options are exclusive") + sys.exit(1) + + python = self.options.python + if python: + match = re.match("^([23])\.([0-9])/(32|64)$", python) + if not match: + print("Invalid Python version: %s" % python) + print('Format of a Python version: "x.y/bits"') + print("Example: 2.7/32") + sys.exit(1) + major = int(match.group(1)) + minor = int(match.group(2)) + bits = int(match.group(3)) + self.python_versions = [PythonVersion(major, minor, bits)] + + if self.options.running: + self.python_versions = [PythonVersion.running()] + + self.verbose = self.options.verbose + self.command = command + + def main(self): + self.parse_options() + + print("Directory: %s" % self.root) + os.chdir(self.root) + + if self.command == "clean": + self.cleanup() + sys.exit(1) + + if self.command == "build": + if len(self.python_versions) != 1: + print("build command requires one specific Python version") + print("Use the --python command line option") + sys.exit(1) + pyver = self.python_versions[0] + self.build_inplace(pyver) + + if (self.register or self.upload) and (not self.options.ignore): + lines = self.get_local_changes() + else: + lines = () + if lines: + print("ERROR: Found local changes") + for line in lines: + print(line) + print("") + print("Revert local changes") + print("or use the --ignore command line option") + sys.exit(1) + + hg_tag = self.options.tag + if hg_tag: + print("Update repository to revision %s" % hg_tag) + self.check_output(HG, 'update', hg_tag) + + hg_rev = self.check_output(HG, 'id').rstrip() + + if self.wheel: + for pyver in self.python_versions: + self.test_wheel(pyver) + + if self.test: + for pyver in self.python_versions: + self.runtests(pyver) + + if self.register: + self.run_command(sys.executable, 'setup.py', 'register') + + if self.sdist: + self.sdist_upload() + + if self.upload: + for pyver in self.python_versions: + self.publish_wheel(pyver) + + hg_rev2 = self.check_output(HG, 'id').rstrip() + if hg_rev != hg_rev2: + print("ERROR: The Mercurial revision changed") + print("Before: %s" % hg_rev) + print("After: %s" % hg_rev2) + sys.exit(1) + + print("") + print("Mercurial revision: %s" % hg_rev) + if self.command == 'build': + print("Inplace compilation done") + if self.wheel: + print("Compilation of wheel packages succeeded") + if self.test: + print("Tests succeeded") + if self.register: + print("Project registered on the Python cheeseshop (PyPI)") + if self.sdist: + print("Project source code uploaded to the Python " + "cheeseshop (PyPI)") + if self.upload: + print("Wheel packages uploaded to the Python cheeseshop (PyPI)") + for pyver in self.python_versions: + print("- %s" % pyver) + + +if __name__ == "__main__": + Release().main() diff --git a/run_aiotest.py b/run_aiotest.py new file mode 100644 index 00000000..8d6fa293 --- /dev/null +++ b/run_aiotest.py @@ -0,0 +1,14 @@ +import aiotest.run +import asyncio +import sys +if sys.platform == 'win32': + from asyncio.windows_utils import socketpair +else: + from socket import socketpair + +config = aiotest.TestConfig() +config.asyncio = asyncio +config.socketpair = socketpair +config.new_event_pool_policy = asyncio.DefaultEventLoopPolicy +config.call_soon_check_closed = True +aiotest.run.main(config) diff --git a/runtests.py b/runtests.py new file mode 100644 index 00000000..c38b0c18 --- /dev/null +++ b/runtests.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +"""Run Tulip unittests. + +Usage: + python3 runtests.py [flags] [pattern] ... + +Patterns are matched against the fully qualified name of the test, +including package, module, class and method, +e.g. 'tests.test_events.PolicyTests.testPolicy'. + +For full help, try --help. + +runtests.py --coverage is equivalent of: + + $(COVERAGE) run --branch runtests.py -v + $(COVERAGE) html $(list of files) + $(COVERAGE) report -m $(list of files) + +""" + +# Originally written by Beech Horn (for NDB). + +import argparse +import gc +import logging +import os +import random +import re +import sys +import unittest +import textwrap +import importlib.machinery +try: + import coverage +except ImportError: + coverage = None + +from unittest.signals import installHandler + +assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' + +ARGS = argparse.ArgumentParser(description="Run all unittests.") +ARGS.add_argument( + '-v', action="store", dest='verbose', + nargs='?', const=1, type=int, default=0, help='verbose') +ARGS.add_argument( + '-x', action="store_true", dest='exclude', help='exclude tests') +ARGS.add_argument( + '-f', '--failfast', action="store_true", default=False, + dest='failfast', help='Stop on first fail or error') +ARGS.add_argument( + '-c', '--catch', action="store_true", default=False, + dest='catchbreak', help='Catch control-C and display results') +ARGS.add_argument( + '--forever', action="store_true", dest='forever', default=False, + help='run tests forever to catch sporadic errors') +ARGS.add_argument( + '--findleaks', action='store_true', dest='findleaks', + help='detect tests that leak memory') +ARGS.add_argument('-r', '--randomize', action='store_true', + help='randomize test execution order.') +ARGS.add_argument('--seed', type=int, + help='random seed to reproduce a previous random run') +ARGS.add_argument( + '-q', action="store_true", dest='quiet', help='quiet') +ARGS.add_argument( + '--tests', action="store", dest='testsdir', default='tests', + help='tests directory') +ARGS.add_argument( + '--coverage', action="store_true", dest='coverage', + help='enable html coverage report') +ARGS.add_argument( + 'pattern', action="store", nargs="*", + help='optional regex patterns to match test ids (default all tests)') + +COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") +COV_ARGS.add_argument( + '--coverage', action="store", dest='coverage', nargs='?', const='', + help='enable coverage report and provide python files directory') + + +def load_modules(basedir, suffix='.py'): + def list_dir(prefix, dir): + files = [] + + modpath = os.path.join(dir, '__init__.py') + if os.path.isfile(modpath): + mod = os.path.split(dir)[-1] + files.append(('{}{}'.format(prefix, mod), modpath)) + + prefix = '{}{}.'.format(prefix, mod) + + for name in os.listdir(dir): + path = os.path.join(dir, name) + + if os.path.isdir(path): + files.extend(list_dir('{}{}.'.format(prefix, name), path)) + else: + if (name != '__init__.py' and + name.endswith(suffix) and + not name.startswith(('.', '_'))): + files.append(('{}{}'.format(prefix, name[:-3]), path)) + + return files + + mods = [] + for modname, sourcefile in list_dir('', basedir): + if modname == 'runtests': + continue + try: + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + mods.append((loader.load_module(), sourcefile)) + except SyntaxError: + raise + except unittest.SkipTest as err: + print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + + return mods + + +def randomize_tests(tests, seed): + if seed is None: + seed = random.randrange(10000000) + random.seed(seed) + print("Randomize test execution order (seed: %s)" % seed) + random.shuffle(tests._tests) + + +class TestsFinder: + + def __init__(self, testsdir, includes=(), excludes=()): + self._testsdir = testsdir + self._includes = includes + self._excludes = excludes + self.find_available_tests() + + def find_available_tests(self): + """ + Find available test classes without instantiating them. + """ + self._test_factories = [] + mods = [mod for mod, _ in load_modules(self._testsdir)] + for mod in mods: + for name in set(dir(mod)): + if name.endswith('Tests'): + self._test_factories.append(getattr(mod, name)) + + def load_tests(self): + """ + Load test cases from the available test classes and apply + optional include / exclude filters. + """ + loader = unittest.TestLoader() + suite = unittest.TestSuite() + for test_factory in self._test_factories: + tests = loader.loadTestsFromTestCase(test_factory) + if self._includes: + tests = [test + for test in tests + if any(re.search(pat, test.id()) + for pat in self._includes)] + if self._excludes: + tests = [test + for test in tests + if not any(re.search(pat, test.id()) + for pat in self._excludes)] + suite.addTests(tests) + return suite + + +class TestResult(unittest.TextTestResult): + + def __init__(self, stream, descriptions, verbosity): + super().__init__(stream, descriptions, verbosity) + self.leaks = [] + + def startTest(self, test): + super().startTest(test) + gc.collect() + + def addSuccess(self, test): + super().addSuccess(test) + gc.collect() + if gc.garbage: + if self.showAll: + self.stream.writeln( + " Warning: test created {} uncollectable " + "object(s).".format(len(gc.garbage))) + # move the uncollectable objects somewhere so we don't see + # them again + self.leaks.append((self.getDescription(test), gc.garbage[:])) + del gc.garbage[:] + + +class TestRunner(unittest.TextTestRunner): + resultclass = TestResult + + def run(self, test): + result = super().run(test) + if result.leaks: + self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + for name, leaks in result.leaks: + self.stream.writeln(' '*4 + name + ':') + for leak in leaks: + self.stream.writeln(' '*8 + repr(leak)) + return result + + +def _runtests(args, tests): + v = 0 if args.quiet else args.verbose + 1 + runner_factory = TestRunner if args.findleaks else unittest.TextTestRunner + if args.randomize: + randomize_tests(tests, args.seed) + runner = runner_factory(verbosity=v, failfast=args.failfast) + sys.stdout.flush() + sys.stderr.flush() + return runner.run(tests) + + +def runtests(): + args = ARGS.parse_args() + + if args.coverage and coverage is None: + URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" + print(textwrap.dedent(""" + coverage package is not installed. + + To install coverage3 for Python 3, you need: + - Setuptools (https://pypi.python.org/pypi/setuptools) + + What worked for me: + - download {0} + * curl -O https://{0} + - python3 ez_setup.py + - python3 -m easy_install coverage + """.format(URL)).strip()) + sys.exit(1) + + testsdir = os.path.abspath(args.testsdir) + if not os.path.isdir(testsdir): + print("Tests directory is not found: {}\n".format(testsdir)) + ARGS.print_help() + return + + excludes = includes = [] + if args.exclude: + excludes = args.pattern + else: + includes = args.pattern + + v = 0 if args.quiet else args.verbose + 1 + failfast = args.failfast + + if args.coverage: + cov = coverage.coverage(branch=True, + source=['asyncio'], + ) + cov.start() + + logger = logging.getLogger() + if v == 0: + level = logging.CRITICAL + elif v == 1: + level = logging.ERROR + elif v == 2: + level = logging.WARNING + elif v == 3: + level = logging.INFO + elif v >= 4: + level = logging.DEBUG + logging.basicConfig(level=level) + + finder = TestsFinder(args.testsdir, includes, excludes) + if args.catchbreak: + installHandler() + import asyncio.coroutines + if asyncio.coroutines._DEBUG: + print("Run tests in debug mode") + else: + print("Run tests in release mode") + try: + tests = finder.load_tests() + if args.forever: + while True: + result = _runtests(args, tests) + if not result.wasSuccessful(): + sys.exit(1) + else: + result = _runtests(args, tests) + sys.exit(not result.wasSuccessful()) + finally: + if args.coverage: + cov.stop() + cov.save() + cov.html_report(directory='htmlcov') + print("\nCoverage report:") + cov.report(show_missing=False) + here = os.path.dirname(os.path.abspath(__file__)) + print("\nFor html report:") + print("open file://{}/htmlcov/index.html".format(here)) + + +if __name__ == '__main__': + runtests() diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..2581bfda --- /dev/null +++ b/setup.py @@ -0,0 +1,49 @@ +# Release procedure: +# - run tox (to run runtests.py and run_aiotest.py) +# - maybe test examples +# - update version in setup.py +# - hg ci +# - hg tag VERSION +# - hg push +# - run on Linux: python setup.py register sdist upload +# - run on Windows: python release.py VERSION +# - increment version in setup.py +# - hg ci && hg push + +import os +try: + from setuptools import setup, Extension +except ImportError: + # Use distutils.core as a fallback. + # We won't be able to build the Wheel file on Windows. + from distutils.core import setup, Extension + +extensions = [] +if os.name == 'nt': + ext = Extension( + 'asyncio._overlapped', ['overlapped.c'], libraries=['ws2_32'], + ) + extensions.append(ext) + +with open("README") as fp: + long_description = fp.read() + +setup( + name="asyncio", + version="3.4.3", + + description="reference implementation of PEP 3156", + long_description=long_description, + url="http://www.python.org/dev/peps/pep-3156/", + + classifiers=[ + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.3", + ], + + packages=["asyncio"], + test_suite="runtests.runtests", + + ext_modules=extensions, +) diff --git a/tests/echo.py b/tests/echo.py new file mode 100644 index 00000000..006364bb --- /dev/null +++ b/tests/echo.py @@ -0,0 +1,8 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + os.write(1, buf) diff --git a/tests/echo2.py b/tests/echo2.py new file mode 100644 index 00000000..e83ca09f --- /dev/null +++ b/tests/echo2.py @@ -0,0 +1,6 @@ +import os + +if __name__ == '__main__': + buf = os.read(0, 1024) + os.write(1, b'OUT:'+buf) + os.write(2, b'ERR:'+buf) diff --git a/tests/echo3.py b/tests/echo3.py new file mode 100644 index 00000000..06449673 --- /dev/null +++ b/tests/echo3.py @@ -0,0 +1,11 @@ +import os + +if __name__ == '__main__': + while True: + buf = os.read(0, 1024) + if not buf: + break + try: + os.write(1, b'OUT:'+buf) + except OSError as ex: + os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/keycert3.pem b/tests/keycert3.pem new file mode 100644 index 00000000..5bfa62c4 --- /dev/null +++ b/tests/keycert3.pem @@ -0,0 +1,73 @@ +-----BEGIN PRIVATE KEY----- +MIICdgIBADANBgkqhkiG9w0BAQEFAASCAmAwggJcAgEAAoGBAMLgD0kAKDb5cFyP +jbwNfR5CtewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM +9z2j1OlaN+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZ +aggEdkj1TsSsv1zWIYKlPIjlvhuxAgMBAAECgYA0aH+T2Vf3WOPv8KdkcJg6gCRe +yJKXOWgWRcicx/CUzOEsTxmFIDPLxqAWA3k7v0B+3vjGw5Y9lycV/5XqXNoQI14j +y09iNsumds13u5AKkGdTJnZhQ7UKdoVHfuP44ZdOv/rJ5/VD6F4zWywpe90pcbK+ +AWDVtusgGQBSieEl1QJBAOyVrUG5l2yoUBtd2zr/kiGm/DYyXlIthQO/A3/LngDW +5/ydGxVsT7lAVOgCsoT+0L4efTh90PjzW8LPQrPBWVMCQQDS3h/FtYYd5lfz+FNL +9CEe1F1w9l8P749uNUD0g317zv1tatIqVCsQWHfVHNdVvfQ+vSFw38OORO00Xqs9 +1GJrAkBkoXXEkxCZoy4PteheO/8IWWLGGr6L7di6MzFl1lIqwT6D8L9oaV2vynFT +DnKop0pa09Unhjyw57KMNmSE2SUJAkEArloTEzpgRmCq4IK2/NpCeGdHS5uqRlbh +1VIa/xGps7EWQl5Mn8swQDel/YP3WGHTjfx7pgSegQfkyaRtGpZ9OQJAa9Vumj8m +JAAtI0Bnga8hgQx7BhTQY4CadDxyiRGOGYhwUzYVCqkb2sbVRH9HnwUaJT7cWBY3 +RnJdHOMXWem7/w== +-----END PRIVATE KEY----- +Certificate: + Data: + Version: 1 (0x0) + Serial Number: 12723342612721443281 (0xb09264b1f2da21d1) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Nov 13 19:47:07 2022 GMT + Subject: C=XY, L=Castle Anthrax, O=Python Software Foundation, CN=localhost + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (1024 bit) + Modulus: + 00:c2:e0:0f:49:00:28:36:f9:70:5c:8f:8d:bc:0d: + 7d:1e:42:b5:ec:1d:5c:2f:a4:31:70:16:0f:c0:cb: + c6:24:d3:be:13:16:ee:a5:67:97:03:a6:df:a9:99: + 96:cc:c7:2a:fb:11:7f:4e:65:4f:8a:5e:82:21:4c: + f7:3d:a3:d4:e9:5a:37:e7:22:fd:7e:cd:53:6d:93: + 34:de:9c:ad:84:a2:37:be:c5:8d:82:4f:e3:ae:23: + f3:be:a7:75:2c:72:0f:ea:f3:ca:cd:fc:e9:3f:b5: + af:56:99:6a:08:04:76:48:f5:4e:c4:ac:bf:5c:d6: + 21:82:a5:3c:88:e5:be:1b:b1 + Exponent: 65537 (0x10001) + Signature Algorithm: sha1WithRSAEncryption + 2f:42:5f:a3:09:2c:fa:51:88:c7:37:7f:ea:0e:63:f0:a2:9a: + e5:5a:e2:c8:20:f0:3f:60:bc:c8:0f:b6:c6:76:ce:db:83:93: + f5:a3:33:67:01:8e:04:cd:00:9a:73:fd:f3:35:86:fa:d7:13: + e2:46:c6:9d:c0:29:53:d4:a9:90:b8:77:4b:e6:83:76:e4:92: + d6:9c:50:cf:43:d0:c6:01:77:61:9a:de:9b:70:f7:72:cd:59: + 00:31:69:d9:b4:ca:06:9c:6d:c3:c7:80:8c:68:e6:b5:a2:f8: + ef:1d:bb:16:9f:77:77:ef:87:62:22:9b:4d:69:a4:3a:1a:f1: + 21:5e:8c:32:ac:92:fd:15:6b:18:c2:7f:15:0d:98:30:ca:75: + 8f:1a:71:df:da:1d:b2:ef:9a:e8:2d:2e:02:fd:4a:3c:aa:96: + 0b:06:5d:35:b3:3d:24:87:4b:e0:b0:58:60:2f:45:ac:2e:48: + 8a:b0:99:10:65:27:ff:cc:b1:d8:fd:bd:26:6b:b9:0c:05:2a: + f4:45:63:35:51:07:ed:83:85:fe:6f:69:cb:bb:40:a8:ae:b6: + 3b:56:4a:2d:a4:ed:6d:11:2c:4d:ed:17:24:fd:47:bc:d3:41: + a2:d3:06:fe:0c:90:d8:d8:94:26:c4:ff:cc:a1:d8:42:77:eb: + fc:a9:94:71 +-----BEGIN CERTIFICATE----- +MIICpDCCAYwCCQCwkmSx8toh0TANBgkqhkiG9w0BAQUFADBNMQswCQYDVQQGEwJY +WTEmMCQGA1UECgwdUHl0aG9uIFNvZnR3YXJlIEZvdW5kYXRpb24gQ0ExFjAUBgNV +BAMMDW91ci1jYS1zZXJ2ZXIwHhcNMTMwMTA0MTk0NzA3WhcNMjIxMTEzMTk0NzA3 +WjBfMQswCQYDVQQGEwJYWTEXMBUGA1UEBxMOQ2FzdGxlIEFudGhyYXgxIzAhBgNV +BAoTGlB5dGhvbiBTb2Z0d2FyZSBGb3VuZGF0aW9uMRIwEAYDVQQDEwlsb2NhbGhv +c3QwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAMLgD0kAKDb5cFyPjbwNfR5C +tewdXC+kMXAWD8DLxiTTvhMW7qVnlwOm36mZlszHKvsRf05lT4pegiFM9z2j1Ola +N+ci/X7NU22TNN6crYSiN77FjYJP464j876ndSxyD+rzys386T+1r1aZaggEdkj1 +TsSsv1zWIYKlPIjlvhuxAgMBAAEwDQYJKoZIhvcNAQEFBQADggEBAC9CX6MJLPpR +iMc3f+oOY/CimuVa4sgg8D9gvMgPtsZ2ztuDk/WjM2cBjgTNAJpz/fM1hvrXE+JG +xp3AKVPUqZC4d0vmg3bkktacUM9D0MYBd2Ga3ptw93LNWQAxadm0ygacbcPHgIxo +5rWi+O8duxafd3fvh2Iim01ppDoa8SFejDKskv0VaxjCfxUNmDDKdY8acd/aHbLv +mugtLgL9SjyqlgsGXTWzPSSHS+CwWGAvRawuSIqwmRBlJ//Msdj9vSZruQwFKvRF +YzVRB+2Dhf5vacu7QKiutjtWSi2k7W0RLE3tFyT9R7zTQaLTBv4MkNjYlCbE/8yh +2EJ36/yplHE= +-----END CERTIFICATE----- diff --git a/tests/pycacert.pem b/tests/pycacert.pem new file mode 100644 index 00000000..09b1f3e0 --- /dev/null +++ b/tests/pycacert.pem @@ -0,0 +1,78 @@ +Certificate: + Data: + Version: 3 (0x2) + Serial Number: 12723342612721443280 (0xb09264b1f2da21d0) + Signature Algorithm: sha1WithRSAEncryption + Issuer: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Validity + Not Before: Jan 4 19:47:07 2013 GMT + Not After : Jan 2 19:47:07 2023 GMT + Subject: C=XY, O=Python Software Foundation CA, CN=our-ca-server + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + Public-Key: (2048 bit) + Modulus: + 00:e7:de:e9:e3:0c:9f:00:b6:a1:fd:2b:5b:96:d2: + 6f:cc:e0:be:86:b9:20:5e:ec:03:7a:55:ab:ea:a4: + e9:f9:49:85:d2:66:d5:ed:c7:7a:ea:56:8e:2d:8f: + e7:42:e2:62:28:a9:9f:d6:1b:8e:eb:b5:b4:9c:9f: + 14:ab:df:e6:94:8b:76:1d:3e:6d:24:61:ed:0c:bf: + 00:8a:61:0c:df:5c:c8:36:73:16:00:cd:47:ba:6d: + a4:a4:74:88:83:23:0a:19:fc:09:a7:3c:4a:4b:d3: + e7:1d:2d:e4:ea:4c:54:21:f3:26:db:89:37:18:d4: + 02:bb:40:32:5f:a4:ff:2d:1c:f7:d4:bb:ec:8e:cf: + 5c:82:ac:e6:7c:08:6c:48:85:61:07:7f:25:e0:5c: + e0:bc:34:5f:e0:b9:04:47:75:c8:47:0b:8d:bc:d6: + c8:68:5f:33:83:62:d2:20:44:35:b1:ad:81:1a:8a: + cd:bc:35:b0:5c:8b:47:d6:18:e9:9c:18:97:cc:01: + 3c:29:cc:e8:1e:e4:e4:c1:b8:de:e7:c2:11:18:87: + 5a:93:34:d8:a6:25:f7:14:71:eb:e4:21:a2:d2:0f: + 2e:2e:d4:62:00:35:d3:d6:ef:5c:60:4b:4c:a9:14: + e2:dd:15:58:46:37:33:26:b7:e7:2e:5d:ed:42:e4: + c5:4d + Exponent: 65537 (0x10001) + X509v3 extensions: + X509v3 Subject Key Identifier: + BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + X509v3 Authority Key Identifier: + keyid:BC:DD:62:D9:76:DA:1B:D2:54:6B:CF:E0:66:9B:1E:1E:7B:56:0C:0B + + X509v3 Basic Constraints: + CA:TRUE + Signature Algorithm: sha1WithRSAEncryption + 7d:0a:f5:cb:8d:d3:5d:bd:99:8e:f8:2b:0f:ba:eb:c2:d9:a6: + 27:4f:2e:7b:2f:0e:64:d8:1c:35:50:4e:ee:fc:90:b9:8d:6d: + a8:c5:c6:06:b0:af:f3:2d:bf:3b:b8:42:07:dd:18:7d:6d:95: + 54:57:85:18:60:47:2f:eb:78:1b:f9:e8:17:fd:5a:0d:87:17: + 28:ac:4c:6a:e6:bc:29:f4:f4:55:70:29:42:de:85:ea:ab:6c: + 23:06:64:30:75:02:8e:53:bc:5e:01:33:37:cc:1e:cd:b8:a4: + fd:ca:e4:5f:65:3b:83:1c:86:f1:55:02:a0:3a:8f:db:91:b7: + 40:14:b4:e7:8d:d2:ee:73:ba:e3:e5:34:2d:bc:94:6f:4e:24: + 06:f7:5f:8b:0e:a7:8e:6b:de:5e:75:f4:32:9a:50:b1:44:33: + 9a:d0:05:e2:78:82:ff:db:da:8a:63:eb:a9:dd:d1:bf:a0:61: + ad:e3:9e:8a:24:5d:62:0e:e7:4c:91:7f:ef:df:34:36:3b:2f: + 5d:f5:84:b2:2f:c4:6d:93:96:1a:6f:30:28:f1:da:12:9a:64: + b4:40:33:1d:bd:de:2b:53:a8:ea:be:d6:bc:4e:96:f5:44:fb: + 32:18:ae:d5:1f:f6:69:af:b6:4e:7b:1d:58:ec:3b:a9:53:a3: + 5e:58:c8:9e +-----BEGIN CERTIFICATE----- +MIIDbTCCAlWgAwIBAgIJALCSZLHy2iHQMA0GCSqGSIb3DQEBBQUAME0xCzAJBgNV +BAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUgRm91bmRhdGlvbiBDQTEW +MBQGA1UEAwwNb3VyLWNhLXNlcnZlcjAeFw0xMzAxMDQxOTQ3MDdaFw0yMzAxMDIx +OTQ3MDdaME0xCzAJBgNVBAYTAlhZMSYwJAYDVQQKDB1QeXRob24gU29mdHdhcmUg +Rm91bmRhdGlvbiBDQTEWMBQGA1UEAwwNb3VyLWNhLXNlcnZlcjCCASIwDQYJKoZI +hvcNAQEBBQADggEPADCCAQoCggEBAOfe6eMMnwC2of0rW5bSb8zgvoa5IF7sA3pV +q+qk6flJhdJm1e3HeupWji2P50LiYiipn9Ybjuu1tJyfFKvf5pSLdh0+bSRh7Qy/ +AIphDN9cyDZzFgDNR7ptpKR0iIMjChn8Cac8SkvT5x0t5OpMVCHzJtuJNxjUArtA +Ml+k/y0c99S77I7PXIKs5nwIbEiFYQd/JeBc4Lw0X+C5BEd1yEcLjbzWyGhfM4Ni +0iBENbGtgRqKzbw1sFyLR9YY6ZwYl8wBPCnM6B7k5MG43ufCERiHWpM02KYl9xRx +6+QhotIPLi7UYgA109bvXGBLTKkU4t0VWEY3Mya35y5d7ULkxU0CAwEAAaNQME4w +HQYDVR0OBBYEFLzdYtl22hvSVGvP4GabHh57VgwLMB8GA1UdIwQYMBaAFLzdYtl2 +2hvSVGvP4GabHh57VgwLMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEFBQADggEB +AH0K9cuN0129mY74Kw+668LZpidPLnsvDmTYHDVQTu78kLmNbajFxgawr/Mtvzu4 +QgfdGH1tlVRXhRhgRy/reBv56Bf9Wg2HFyisTGrmvCn09FVwKULeheqrbCMGZDB1 +Ao5TvF4BMzfMHs24pP3K5F9lO4MchvFVAqA6j9uRt0AUtOeN0u5zuuPlNC28lG9O +JAb3X4sOp45r3l519DKaULFEM5rQBeJ4gv/b2opj66nd0b+gYa3jnookXWIO50yR +f+/fNDY7L131hLIvxG2TlhpvMCjx2hKaZLRAMx293itTqOq+1rxOlvVE+zIYrtUf +9mmvtk57HVjsO6lTo15YyJ4= +-----END CERTIFICATE----- diff --git a/tests/sample.crt b/tests/sample.crt new file mode 100644 index 00000000..6a1e3f3c --- /dev/null +++ b/tests/sample.crt @@ -0,0 +1,14 @@ +-----BEGIN CERTIFICATE----- +MIICMzCCAZwCCQDFl4ys0fU7iTANBgkqhkiG9w0BAQUFADBeMQswCQYDVQQGEwJV +UzETMBEGA1UECAwKQ2FsaWZvcm5pYTEWMBQGA1UEBwwNU2FuLUZyYW5jaXNjbzEi +MCAGA1UECgwZUHl0aG9uIFNvZnR3YXJlIEZvbmRhdGlvbjAeFw0xMzAzMTgyMDA3 +MjhaFw0yMzAzMTYyMDA3MjhaMF4xCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApDYWxp +Zm9ybmlhMRYwFAYDVQQHDA1TYW4tRnJhbmNpc2NvMSIwIAYDVQQKDBlQeXRob24g +U29mdHdhcmUgRm9uZGF0aW9uMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQCn +t3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx/47Vc5TZSaO11uO7 +gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIiqusnLfpqR8cIAavg +Z06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQABMA0GCSqGSIb3DQEB +BQUAA4GBAE9PknG6pv72+5z/gsDGYy8sK5UNkbWSNr4i4e5lxVsF03+/M71H+3AB +MxVX4+A+Vlk2fmU+BrdHIIUE0r1dDcO3josQ9hc9OJpp5VLSQFP8VeuJCmzYPp9I +I8WbW93cnXnChTrYQVdgVoFdv7GE9YgU7NYkrGIM0nZl1/f/bHPB +-----END CERTIFICATE----- diff --git a/tests/sample.key b/tests/sample.key new file mode 100644 index 00000000..edfea8dc --- /dev/null +++ b/tests/sample.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXQIBAAKBgQCnt3s+J7L0xP/YdAQOacpPi9phlrzKZhcXL3XMu2LCUg2fNJpx +/47Vc5TZSaO11uO7gdwVz3Z7Q2epAgwo59JLffLt5fia8+a/SlPweI/j4+wcIIIi +qusnLfpqR8cIAavgZ06cLYCDvb9wMlheIvSJY12skc1nnphWS2YJ0Xm6uQIDAQAB +AoGABfm8k19Yue3W68BecKEGS0VBV57GRTPT+MiBGvVGNIQ15gk6w3sGfMZsdD1y +bsUkQgcDb2d/4i5poBTpl/+Cd41V+c20IC/sSl5X1IEreHMKSLhy/uyjyiyfXlP1 +iXhToFCgLWwENWc8LzfUV8vuAV5WG6oL9bnudWzZxeqx8V0CQQDR7xwVj6LN70Eb +DUhSKLkusmFw5Gk9NJ/7wZ4eHg4B8c9KNVvSlLCLhcsVTQXuqYeFpOqytI45SneP +lr0vrvsDAkEAzITYiXu6ox5huDCG7imX2W9CAYuX638urLxBqBXMS7GqBzojD6RL +21Q8oPwJWJquERa3HDScq1deiQbM9uKIkwJBAIa1PLslGN216Xv3UPHPScyKD/aF +ynXIv+OnANPoiyp6RH4ksQ/18zcEGiVH8EeNpvV9tlAHhb+DZibQHgNr74sCQQC0 +zhToplu/bVKSlUQUNO0rqrI9z30FErDewKeCw5KSsIRSU1E/uM3fHr9iyq4wiL6u +GNjUtKZ0y46lsT9uW6LFAkB5eqeEQnshAdr3X5GykWHJ8DDGBXPPn6Rce1NX4RSq +V9khG2z1bFyfo+hMqpYnF2k32hVq3E54RS8YYnwBsVof +-----END RSA PRIVATE KEY----- diff --git a/tests/ssl_cert.pem b/tests/ssl_cert.pem new file mode 100644 index 00000000..47a7d7e3 --- /dev/null +++ b/tests/ssl_cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICVDCCAb2gAwIBAgIJANfHOBkZr8JOMA0GCSqGSIb3DQEBBQUAMF8xCzAJBgNV +BAYTAlhZMRcwFQYDVQQHEw5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9u +IFNvZnR3YXJlIEZvdW5kYXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDAeFw0xMDEw +MDgyMzAxNTZaFw0yMDEwMDUyMzAxNTZaMF8xCzAJBgNVBAYTAlhZMRcwFQYDVQQH +Ew5DYXN0bGUgQW50aHJheDEjMCEGA1UEChMaUHl0aG9uIFNvZnR3YXJlIEZvdW5k +YXRpb24xEjAQBgNVBAMTCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAw +gYkCgYEA21vT5isq7F68amYuuNpSFlKDPrMUCa4YWYqZRt2OZ+/3NKaZ2xAiSwr7 +6MrQF70t5nLbSPpqE5+5VrS58SY+g/sXLiFd6AplH1wJZwh78DofbFYXUggktFMt +pTyiX8jtP66bkcPkDADA089RI1TQR6Ca+n7HFa7c1fabVV6i3zkCAwEAAaMYMBYw +FAYDVR0RBA0wC4IJbG9jYWxob3N0MA0GCSqGSIb3DQEBBQUAA4GBAHPctQBEQ4wd +BJ6+JcpIraopLn8BGhbjNWj40mmRqWB/NAWF6M5ne7KpGAu7tLeG4hb1zLaldK8G +lxy2GPSRF6LFS48dpEj2HbMv2nvv6xxalDMJ9+DicWgAKTQ6bcX2j3GUkCR0g/T1 +CRlNBAAlvhKzO7Clpf9l0YKBEfraJByX +-----END CERTIFICATE----- diff --git a/tests/ssl_key.pem b/tests/ssl_key.pem new file mode 100644 index 00000000..3fd3bbd5 --- /dev/null +++ b/tests/ssl_key.pem @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdwIBADANBgkqhkiG9w0BAQEFAASCAmEwggJdAgEAAoGBANtb0+YrKuxevGpm +LrjaUhZSgz6zFAmuGFmKmUbdjmfv9zSmmdsQIksK++jK0Be9LeZy20j6ahOfuVa0 +ufEmPoP7Fy4hXegKZR9cCWcIe/A6H2xWF1IIJLRTLaU8ol/I7T+um5HD5AwAwNPP +USNU0Eegmvp+xxWu3NX2m1Veot85AgMBAAECgYA3ZdZ673X0oexFlq7AAmrutkHt +CL7LvwrpOiaBjhyTxTeSNWzvtQBkIU8DOI0bIazA4UreAFffwtvEuPmonDb3F+Iq +SMAu42XcGyVZEl+gHlTPU9XRX7nTOXVt+MlRRRxL6t9GkGfUAXI3XxJDXW3c0vBK +UL9xqD8cORXOfE06rQJBAP8mEX1ERkR64Ptsoe4281vjTlNfIbs7NMPkUnrn9N/Y +BLhjNIfQ3HFZG8BTMLfX7kCS9D593DW5tV4Z9BP/c6cCQQDcFzCcVArNh2JSywOQ +ZfTfRbJg/Z5Lt9Fkngv1meeGNPgIMLN8Sg679pAOOWmzdMO3V706rNPzSVMME7E5 +oPIfAkEA8pDddarP5tCvTTgUpmTFbakm0KoTZm2+FzHcnA4jRh+XNTjTOv98Y6Ik +eO5d1ZnKXseWvkZncQgxfdnMqqpj5wJAcNq/RVne1DbYlwWchT2Si65MYmmJ8t+F +0mcsULqjOnEMwf5e+ptq5LzwbyrHZYq5FNk7ocufPv/ZQrcSSC+cFwJBAKvOJByS +x56qyGeZLOQlWS2JS3KJo59XuLFGqcbgN9Om9xFa41Yb4N9NvplFivsvZdw3m1Q/ +SPIXQuT8RMPDVNQ= +-----END PRIVATE KEY----- diff --git a/tests/test_base_events.py b/tests/test_base_events.py new file mode 100644 index 00000000..9e7c50cc --- /dev/null +++ b/tests/test_base_events.py @@ -0,0 +1,1236 @@ +"""Tests for base_events.py""" + +import errno +import logging +import math +import socket +import sys +import threading +import time +import unittest +from unittest import mock + +import asyncio +from asyncio import base_events +from asyncio import constants +from asyncio import test_utils +try: + from test import support + from test.script_helper import assert_python_ok +except ImportError: + from asyncio import test_support as support + from asyncio.test_support import assert_python_ok + + +MOCK_ANY = mock.ANY +PY34 = sys.version_info >= (3, 4) + + +class BaseEventLoopTests(test_utils.TestCase): + + def setUp(self): + self.loop = base_events.BaseEventLoop() + self.loop._selector = mock.Mock() + self.loop._selector.select.return_value = () + self.set_event_loop(self.loop) + + def test_not_implemented(self): + m = mock.Mock() + self.assertRaises( + NotImplementedError, + self.loop._make_socket_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_ssl_transport, m, m, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_datagram_transport, m, m) + self.assertRaises( + NotImplementedError, self.loop._process_events, []) + self.assertRaises( + NotImplementedError, self.loop._write_to_self) + self.assertRaises( + NotImplementedError, + self.loop._make_read_pipe_transport, m, m) + self.assertRaises( + NotImplementedError, + self.loop._make_write_pipe_transport, m, m) + gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + self.assertRaises(NotImplementedError, next, iter(gen)) + + def test_close(self): + self.assertFalse(self.loop.is_closed()) + self.loop.close() + self.assertTrue(self.loop.is_closed()) + + # it should be possible to call close() more than once + self.loop.close() + self.loop.close() + + # operation blocked when the loop is closed + f = asyncio.Future(loop=self.loop) + self.assertRaises(RuntimeError, self.loop.run_forever) + self.assertRaises(RuntimeError, self.loop.run_until_complete, f) + + def test__add_callback_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertIn(h, self.loop._ready) + + def test__add_callback_cancelled_handle(self): + h = asyncio.Handle(lambda: False, (), self.loop) + h.cancel() + + self.loop._add_callback(h) + self.assertFalse(self.loop._scheduled) + self.assertFalse(self.loop._ready) + + def test_set_default_executor(self): + executor = mock.Mock() + self.loop.set_default_executor(executor) + self.assertIs(executor, self.loop._default_executor) + + def test_getnameinfo(self): + sockaddr = mock.Mock() + self.loop.run_in_executor = mock.Mock() + self.loop.getnameinfo(sockaddr) + self.assertEqual( + (None, socket.getnameinfo, sockaddr, 0), + self.loop.run_in_executor.call_args[0]) + + def test_call_soon(self): + def cb(): + pass + + h = self.loop.call_soon(cb) + self.assertEqual(h._callback, cb) + self.assertIsInstance(h, asyncio.Handle) + self.assertIn(h, self.loop._ready) + + def test_call_later(self): + def cb(): + pass + + h = self.loop.call_later(10.0, cb) + self.assertIsInstance(h, asyncio.TimerHandle) + self.assertIn(h, self.loop._scheduled) + self.assertNotIn(h, self.loop._ready) + + def test_call_later_negative_delays(self): + calls = [] + + def cb(arg): + calls.append(arg) + + self.loop._process_events = mock.Mock() + self.loop.call_later(-1, cb, 'a') + self.loop.call_later(-2, cb, 'b') + test_utils.run_briefly(self.loop) + self.assertEqual(calls, ['b', 'a']) + + def test_time_and_call_at(self): + def cb(): + self.loop.stop() + + self.loop._process_events = mock.Mock() + delay = 0.1 + + when = self.loop.time() + delay + self.loop.call_at(when, cb) + t0 = self.loop.time() + self.loop.run_forever() + dt = self.loop.time() - t0 + + # 50 ms: maximum granularity of the event loop + self.assertGreaterEqual(dt, delay - 0.050, dt) + # tolerate a difference of +800 ms because some Python buildbots + # are really slow + self.assertLessEqual(dt, 0.9, dt) + + def check_thread(self, loop, debug): + def cb(): + pass + + loop.set_debug(debug) + if debug: + msg = ("Non-thread-safe operation invoked on an event loop other " + "than the current one") + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_soon(cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_later(60, cb) + with self.assertRaisesRegex(RuntimeError, msg): + loop.call_at(loop.time() + 60, cb) + else: + loop.call_soon(cb) + loop.call_later(60, cb) + loop.call_at(loop.time() + 60, cb) + + def test_check_thread(self): + def check_in_thread(loop, event, debug, create_loop, fut): + # wait until the event loop is running + event.wait() + + try: + if create_loop: + loop2 = base_events.BaseEventLoop() + try: + asyncio.set_event_loop(loop2) + self.check_thread(loop, debug) + finally: + asyncio.set_event_loop(None) + loop2.close() + else: + self.check_thread(loop, debug) + except Exception as exc: + loop.call_soon_threadsafe(fut.set_exception, exc) + else: + loop.call_soon_threadsafe(fut.set_result, None) + + def test_thread(loop, debug, create_loop=False): + event = threading.Event() + fut = asyncio.Future(loop=loop) + loop.call_soon(event.set) + args = (loop, event, debug, create_loop, fut) + thread = threading.Thread(target=check_in_thread, args=args) + thread.start() + loop.run_until_complete(fut) + thread.join() + + self.loop._process_events = mock.Mock() + self.loop._write_to_self = mock.Mock() + + # raise RuntimeError if the thread has no event loop + test_thread(self.loop, True) + + # check disabled if debug mode is disabled + test_thread(self.loop, False) + + # raise RuntimeError if the event loop of the thread is not the called + # event loop + test_thread(self.loop, True, create_loop=True) + + # check disabled if debug mode is disabled + test_thread(self.loop, False, create_loop=True) + + def test_run_once_in_executor_handle(self): + def cb(): + pass + + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.Handle(cb, (), self.loop), ('',)) + self.assertRaises( + AssertionError, self.loop.run_in_executor, + None, asyncio.TimerHandle(10, cb, (), self.loop)) + + def test_run_once_in_executor_cancelled(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + h.cancel() + + f = self.loop.run_in_executor(None, h) + self.assertIsInstance(f, asyncio.Future) + self.assertTrue(f.done()) + self.assertIsNone(f.result()) + + def test_run_once_in_executor_plain(self): + def cb(): + pass + h = asyncio.Handle(cb, (), self.loop) + f = asyncio.Future(loop=self.loop) + executor = mock.Mock() + executor.submit.return_value = f + + self.loop.set_default_executor(executor) + + res = self.loop.run_in_executor(None, h) + self.assertIs(f, res) + + executor = mock.Mock() + executor.submit.return_value = f + res = self.loop.run_in_executor(executor, h) + self.assertIs(f, res) + self.assertTrue(executor.submit.called) + + f.cancel() # Don't complain about abandoned Future. + + def test__run_once(self): + h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + self.loop) + h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + self.loop) + + h1.cancel() + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h1) + self.loop._scheduled.append(h2) + self.loop._run_once() + + t = self.loop._selector.select.call_args[0][0] + self.assertTrue(9.5 < t < 10.5, t) + self.assertEqual([h2], self.loop._scheduled) + self.assertTrue(self.loop._process_events.called) + + def test_set_debug(self): + self.loop.set_debug(True) + self.assertTrue(self.loop.get_debug()) + self.loop.set_debug(False) + self.assertFalse(self.loop.get_debug()) + + @mock.patch('asyncio.base_events.logger') + def test__run_once_logging(self, m_logger): + def slow_select(timeout): + # Sleep a bit longer than a second to avoid timer resolution + # issues. + time.sleep(1.1) + return [] + + # logging needs debug flag + self.loop.set_debug(True) + + # Log to INFO level if timeout > 1.0 sec. + self.loop._selector.select = slow_select + self.loop._process_events = mock.Mock() + self.loop._run_once() + self.assertEqual(logging.INFO, m_logger.log.call_args[0][0]) + + def fast_select(timeout): + time.sleep(0.001) + return [] + + self.loop._selector.select = fast_select + self.loop._run_once() + self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) + + def test__run_once_schedule_handle(self): + handle = None + processed = False + + def cb(loop): + nonlocal processed, handle + processed = True + handle = loop.call_soon(lambda: True) + + h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + self.loop) + + self.loop._process_events = mock.Mock() + self.loop._scheduled.append(h) + self.loop._run_once() + + self.assertTrue(processed) + self.assertEqual([handle], list(self.loop._ready)) + + def test__run_once_cancelled_event_cleanup(self): + self.loop._process_events = mock.Mock() + + self.assertTrue( + 0 < base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION < 1.0) + + def cb(): + pass + + # Set up one "blocking" event that will not be cancelled to + # ensure later cancelled events do not make it to the head + # of the queue and get cleaned. + not_cancelled_count = 1 + self.loop.call_later(3000, cb) + + # Add less than threshold (base_events._MIN_SCHEDULED_TIMER_HANDLES) + # cancelled handles, ensure they aren't removed + + cancelled_count = 2 + for x in range(2): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Add some cancelled events that will be at head and removed + cancelled_count += 2 + for x in range(2): + h = self.loop.call_later(100, cb) + h.cancel() + + # This test is invalid if _MIN_SCHEDULED_TIMER_HANDLES is too low + self.assertLessEqual(cancelled_count + not_cancelled_count, + base_events._MIN_SCHEDULED_TIMER_HANDLES) + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.loop._run_once() + + cancelled_count -= 2 + + self.assertEqual(self.loop._timer_cancelled_count, cancelled_count) + + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + # Need enough events to pass _MIN_CANCELLED_TIMER_HANDLES_FRACTION + # so that deletion of cancelled events will occur on next _run_once + add_cancel_count = int(math.ceil( + base_events._MIN_SCHEDULED_TIMER_HANDLES * + base_events._MIN_CANCELLED_TIMER_HANDLES_FRACTION)) + 1 + + add_not_cancel_count = max(base_events._MIN_SCHEDULED_TIMER_HANDLES - + add_cancel_count, 0) + + # Add some events that will not be cancelled + not_cancelled_count += add_not_cancel_count + for x in range(add_not_cancel_count): + self.loop.call_later(3600, cb) + + # Add enough cancelled events + cancelled_count += add_cancel_count + for x in range(add_cancel_count): + h = self.loop.call_later(3600, cb) + h.cancel() + + # Ensure all handles are still scheduled + self.assertEqual(len(self.loop._scheduled), + cancelled_count + not_cancelled_count) + + self.loop._run_once() + + # Ensure cancelled events were removed + self.assertEqual(len(self.loop._scheduled), not_cancelled_count) + + # Ensure only uncancelled events remain scheduled + self.assertTrue(all([not x._cancelled for x in self.loop._scheduled])) + + def test_run_until_complete_type_error(self): + self.assertRaises(TypeError, + self.loop.run_until_complete, 'blah') + + def test_run_until_complete_loop(self): + task = asyncio.Future(loop=self.loop) + other_loop = self.new_test_loop() + self.addCleanup(other_loop.close) + self.assertRaises(ValueError, + other_loop.run_until_complete, task) + + def test_subprocess_exec_invalid_args(self): + args = [sys.executable, '-c', 'pass'] + + # missing program parameter (empty args) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol) + + # expected multiple arguments, not a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, args) + + # program arguments must be strings, not int + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, sys.executable, 123) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_exec, + asyncio.SubprocessProtocol, *args, bufsize=4096) + + def test_subprocess_shell_invalid_args(self): + # expected a string, not an int or a list + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 123) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, [sys.executable, '-c', 'pass']) + + # universal_newlines, shell, bufsize must not be set + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', universal_newlines=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', shell=True) + self.assertRaises(TypeError, + self.loop.run_until_complete, self.loop.subprocess_shell, + asyncio.SubprocessProtocol, 'exit 0', bufsize=4096) + + def test_default_exc_handler_callback(self): + self.loop._process_events = mock.Mock() + + def zero_error(fut): + fut.set_result(True) + 1/0 + + # Test call_soon (events.Handle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_soon(zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + # Test call_later (events.TimerHandle) + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.Future(loop=self.loop) + self.loop.call_later(0.01, zero_error, fut) + fut.add_done_callback(lambda fut: self.loop.stop()) + self.loop.run_forever() + log.error.assert_called_with( + test_utils.MockPattern('Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_coro(self): + self.loop._process_events = mock.Mock() + + @asyncio.coroutine + def zero_error_coro(): + yield from asyncio.sleep(0.01, loop=self.loop) + 1/0 + + # Test Future.__del__ + with mock.patch('asyncio.base_events.logger') as log: + fut = asyncio.async(zero_error_coro(), loop=self.loop) + fut.add_done_callback(lambda *args: self.loop.stop()) + self.loop.run_forever() + fut = None # Trigger Future.__del__ or futures._TracebackLogger + if PY34: + # Future.__del__ in Python 3.4 logs error with + # an actual exception context + log.error.assert_called_with( + test_utils.MockPattern('.*exception was never retrieved'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + else: + # futures._TracebackLogger logs only textual traceback + log.error.assert_called_with( + test_utils.MockPattern( + '.*exception was never retrieved.*ZeroDiv'), + exc_info=False) + + def test_set_exc_handler_invalid(self): + with self.assertRaisesRegex(TypeError, 'A callable object or None'): + self.loop.set_exception_handler('spam') + + def test_set_exc_handler_custom(self): + def zero_error(): + 1/0 + + def run_loop(): + handle = self.loop.call_soon(zero_error) + self.loop._run_once() + return handle + + self.loop.set_debug(True) + self.loop._process_events = mock.Mock() + + mock_handler = mock.Mock() + self.loop.set_exception_handler(mock_handler) + handle = run_loop() + mock_handler.assert_called_with(self.loop, { + 'exception': MOCK_ANY, + 'message': test_utils.MockPattern( + 'Exception in callback.*zero_error'), + 'handle': handle, + 'source_traceback': handle._source_traceback, + }) + mock_handler.reset_mock() + + self.loop.set_exception_handler(None) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Exception in callback.*zero'), + exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) + + assert not mock_handler.called + + def test_set_exc_handler_broken(self): + def run_loop(): + def zero_error(): + 1/0 + self.loop.call_soon(zero_error) + self.loop._run_once() + + def handler(loop, context): + raise AttributeError('spam') + + self.loop._process_events = mock.Mock() + + self.loop.set_exception_handler(handler) + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern( + 'Unhandled error in exception handler'), + exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) + + def test_default_exc_handler_broken(self): + _context = None + + class Loop(base_events.BaseEventLoop): + + _selector = mock.Mock() + _process_events = mock.Mock() + + def default_exception_handler(self, context): + nonlocal _context + _context = context + # Simulates custom buggy "default_exception_handler" + raise ValueError('spam') + + loop = Loop() + self.addCleanup(loop.close) + asyncio.set_event_loop(loop) + + def run_loop(): + def zero_error(): + 1/0 + loop.call_soon(zero_error) + loop._run_once() + + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + 'Exception in default exception handler', + exc_info=True) + + def custom_handler(loop, context): + raise ValueError('ham') + + _context = None + loop.set_exception_handler(custom_handler) + with mock.patch('asyncio.base_events.logger') as log: + run_loop() + log.error.assert_called_with( + test_utils.MockPattern('Exception in default exception.*' + 'while handling.*in custom'), + exc_info=True) + + # Check that original context was passed to default + # exception handler. + self.assertIn('context', _context) + self.assertIs(type(_context['context']['exception']), + ZeroDivisionError) + + def test_env_var_debug(self): + code = '\n'.join(( + 'import asyncio', + 'loop = asyncio.get_event_loop()', + 'print(loop.get_debug())')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='') + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1') + self.assertEqual(stdout.rstrip(), b'False') + + def test_create_task(self): + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def test(): + pass + + class EventLoop(base_events.BaseEventLoop): + def create_task(self, coro): + return MyTask(coro, loop=loop) + + loop = EventLoop() + self.set_event_loop(loop) + + coro = test() + task = asyncio.async(coro, loop=loop) + self.assertIsInstance(task, MyTask) + + # make warnings quiet + task._log_destroy_pending = False + coro.close() + + def test_run_forever_keyboard_interrupt(self): + # Python issue #22601: ensure that the temporary task created by + # run_forever() consumes the KeyboardInterrupt and so don't log + # a warning + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + self.loop.close() + support.gc_collect() + + self.assertFalse(self.loop.call_exception_handler.called) + + def test_run_until_complete_baseexception(self): + # Python issue #22429: run_until_complete() must not schedule a pending + # call to stop() if the future raised a BaseException + @asyncio.coroutine + def raise_keyboard_interrupt(): + raise KeyboardInterrupt + + self.loop._process_events = mock.Mock() + + try: + self.loop.run_until_complete(raise_keyboard_interrupt()) + except KeyboardInterrupt: + pass + + def func(): + self.loop.stop() + func.called = True + func.called = False + try: + self.loop.call_soon(func) + self.loop.run_forever() + except KeyboardInterrupt: + pass + self.assertTrue(func.called) + + +class MyProto(asyncio.Protocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, create_future=False): + self.state = 'INITIAL' + self.nbytes = 0 + if create_future: + self.done = asyncio.Future() + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class BaseEventLoopWithSelectorTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.new_event_loop() + self.set_event_loop(self.loop) + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors(self, m_socket): + + class MyProto(asyncio.Protocol): + pass + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + idx = -1 + errors = ['err1', 'err2'] + + def _socket(*args, **kw): + nonlocal idx, errors + idx += 1 + raise OSError(errors[idx]) + + m_socket.socket = _socket + + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_timeout(self, m_socket): + # Ensure that the socket is closed on timeout + sock = mock.Mock() + m_socket.socket.return_value = sock + + def getaddrinfo(*args, **kw): + fut = asyncio.Future(loop=self.loop) + addr = (socket.AF_INET, socket.SOCK_STREAM, 0, '', + ('127.0.0.1', 80)) + fut.set_result([addr]) + return fut + self.loop.getaddrinfo = getaddrinfo + + with mock.patch.object(self.loop, 'sock_connect', + side_effect=asyncio.TimeoutError): + coro = self.loop.create_connection(MyProto, '127.0.0.1', 80) + with self.assertRaises(asyncio.TimeoutError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + def test_create_connection_host_port_sock(self): + coro = self.loop.create_connection( + MyProto, 'example.com', 80, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_host_port_sock(self): + coro = self.loop.create_connection(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_no_getaddrinfo(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_connect_err(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + yield from [] + return [(2, 1, 6, '', ('107.6.106.82', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection(MyProto, 'example.com', 80) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_multiple(self): + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET) + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_connection_multiple_errors_local_addr(self, m_socket): + + def bind(addr): + if addr[0] == '0.0.0.1': + err = OSError('Err') + err.strerror = 'Err' + raise err + + m_socket.socket.return_value.bind = bind + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + return [(2, 1, 6, '', ('0.0.0.1', 80)), + (2, 1, 6, '', ('0.0.0.2', 80))] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError('Err2') + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(coro) + + self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) + self.assertTrue(m_socket.socket.return_value.close.called) + + def test_create_connection_no_local_addr(self): + @asyncio.coroutine + def getaddrinfo(host, *args, **kw): + if host == 'example.com': + return [(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))] + else: + return [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + self.loop.getaddrinfo = getaddrinfo_task + + coro = self.loop.create_connection( + MyProto, 'example.com', 80, family=socket.AF_INET, + local_addr=(None, 8080)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_default(self): + self.loop.getaddrinfo = mock.Mock() + + def mock_getaddrinfo(*args, **kwds): + f = asyncio.Future(loop=self.loop) + f.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.SOL_TCP, '', ('1.2.3.4', 80))]) + return f + + self.loop.getaddrinfo.side_effect = mock_getaddrinfo + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.return_value = () + self.loop._make_ssl_transport = mock.Mock() + + class _SelectorTransportMock: + _sock = None + + def get_extra_info(self, key): + return mock.Mock() + + def close(self): + self._sock.close() + + def mock_make_ssl_transport(sock, protocol, sslcontext, waiter, + **kwds): + waiter.set_result(None) + transport = _SelectorTransportMock() + transport._sock = sock + return transport + + self.loop._make_ssl_transport.side_effect = mock_make_ssl_transport + ANY = mock.ANY + # First try the default server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True) + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='python.org') + # Next try an explicit server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='perl.com') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with( + ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='perl.com') + # Finally try an explicit empty server_hostname. + self.loop._make_ssl_transport.reset_mock() + coro = self.loop.create_connection(MyProto, 'python.org', 80, ssl=True, + server_hostname='') + transport, _ = self.loop.run_until_complete(coro) + transport.close() + self.loop._make_ssl_transport.assert_called_with(ANY, ANY, ANY, ANY, + server_side=False, + server_hostname='') + + def test_create_connection_no_ssl_server_hostname_errors(self): + # When not using ssl, server_hostname must be None. + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, 'python.org', 80, + server_hostname='python.org') + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_connection_ssl_server_hostname_errors(self): + # When using ssl, server_hostname may be None if host is non-empty. + coro = self.loop.create_connection(MyProto, '', 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + coro = self.loop.create_connection(MyProto, None, 80, ssl=True) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + sock = socket.socket() + coro = self.loop.create_connection(MyProto, None, None, + ssl=True, sock=sock) + self.addCleanup(sock.close) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + def test_create_server_empty_host(self): + # if host is empty string use None instead + host = object() + + @asyncio.coroutine + def getaddrinfo(*args, **kw): + nonlocal host + host = args[0] + yield from [] + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task + fut = self.loop.create_server(MyProto, '', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertIsNone(host) + + def test_create_server_host_port_sock(self): + fut = self.loop.create_server( + MyProto, '0.0.0.0', 0, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_host_port_sock(self): + fut = self.loop.create_server(MyProto) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_server_no_getaddrinfo(self): + getaddrinfo = self.loop.getaddrinfo = mock.Mock() + getaddrinfo.return_value = [] + + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, f) + + @mock.patch('asyncio.base_events.socket') + def test_create_server_cant_bind(self, m_socket): + + class Err(OSError): + strerror = 'error' + + m_socket.getaddrinfo.return_value = [ + (2, 1, 6, '', ('127.0.0.1', 10100))] + m_socket.getaddrinfo._is_coroutine = False + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_server(MyProto, '0.0.0.0', 0) + self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.getaddrinfo.return_value = [] + m_socket.getaddrinfo._is_coroutine = False + + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_addr_error(self): + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr='localhost') + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + coro = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('localhost', 1, 2, 3)) + self.assertRaises( + AssertionError, self.loop.run_until_complete, coro) + + def test_create_datagram_endpoint_connect_err(self): + self.loop.sock_connect = mock.Mock() + self.loop.sock_connect.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.socket.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_datagram_endpoint_no_matching_family(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, + remote_addr=('127.0.0.1', 0), local_addr=('::1', 0)) + self.assertRaises( + ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_setblk_err(self, m_socket): + m_socket.socket.return_value.setblocking.side_effect = OSError + + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol, family=socket.AF_INET) + self.assertRaises( + OSError, self.loop.run_until_complete, coro) + self.assertTrue( + m_socket.socket.return_value.close.called) + + def test_create_datagram_endpoint_noaddr_nofamily(self): + coro = self.loop.create_datagram_endpoint( + asyncio.DatagramProtocol) + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_cant_bind(self, m_socket): + class Err(OSError): + pass + + m_socket.AF_INET6 = socket.AF_INET6 + m_socket.getaddrinfo = socket.getaddrinfo + m_sock = m_socket.socket.return_value = mock.Mock() + m_sock.bind.side_effect = Err + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, + local_addr=('127.0.0.1', 0), family=socket.AF_INET) + self.assertRaises(Err, self.loop.run_until_complete, fut) + self.assertTrue(m_sock.close.called) + + def test_accept_connection_retry(self): + sock = mock.Mock() + sock.accept.side_effect = BlockingIOError() + + self.loop._accept_connection(MyProto, sock) + self.assertFalse(sock.close.called) + + @mock.patch('asyncio.base_events.logger') + def test_accept_connection_exception(self, m_log): + sock = mock.Mock() + sock.fileno.return_value = 10 + sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + self.loop.remove_reader = mock.Mock() + self.loop.call_later = mock.Mock() + + self.loop._accept_connection(MyProto, sock) + self.assertTrue(m_log.error.called) + self.assertFalse(sock.close.called) + self.loop.remove_reader.assert_called_with(10) + self.loop.call_later.assert_called_with(constants.ACCEPT_RETRY_DELAY, + # self.loop._start_serving + mock.ANY, + MyProto, sock, None, None) + + def test_call_coroutine(self): + @asyncio.coroutine + def simple_coroutine(): + pass + + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + with self.assertRaises(TypeError): + self.loop.call_soon(func) + with self.assertRaises(TypeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(TypeError): + self.loop.call_later(60, func) + with self.assertRaises(TypeError): + self.loop.call_at(self.loop.time() + 60, func) + with self.assertRaises(TypeError): + self.loop.run_in_executor(None, func) + + @mock.patch('asyncio.base_events.logger') + def test_log_slow_callbacks(self, m_logger): + def stop_loop_cb(loop): + loop.stop() + + @asyncio.coroutine + def stop_loop_coro(loop): + yield from () + loop.stop() + + asyncio.set_event_loop(self.loop) + self.loop.set_debug(True) + self.loop.slow_callback_duration = 0.0 + + # slow callback + self.loop.call_soon(stop_loop_cb, self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing " + "took .* seconds$") + + # slow task + asyncio.async(stop_loop_coro(self.loop), loop=self.loop) + self.loop.run_forever() + fmt, *args = m_logger.warning.call_args[0] + self.assertRegex(fmt % tuple(args), + "^Executing " + "took .* seconds$") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 00000000..a38c90eb --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,2369 @@ +"""Tests for events.py.""" + +import functools +import gc +import io +import os +import platform +import re +import signal +import socket +try: + import ssl +except ImportError: + ssl = None +import subprocess +import sys +import threading +import time +import errno +import unittest +from unittest import mock +import weakref + + +import asyncio +from asyncio import proactor_events +from asyncio import selector_events +from asyncio import sslproto +from asyncio import test_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support + + +def data_file(filename): + if hasattr(support, 'TEST_HOME_DIR'): + fullname = os.path.join(support.TEST_HOME_DIR, filename) + if os.path.isfile(fullname): + return fullname + fullname = os.path.join(os.path.dirname(__file__), filename) + if os.path.isfile(fullname): + return fullname + raise FileNotFoundError(filename) + + +def osx_tiger(): + """Return True if the platform is Mac OS 10.4 or older.""" + if sys.platform != 'darwin': + return False + version = platform.mac_ver()[0] + version = tuple(map(int, version.split('.'))) + return version < (10, 5) + + +ONLYCERT = data_file('ssl_cert.pem') +ONLYKEY = data_file('ssl_key.pem') +SIGNED_CERTFILE = data_file('keycert3.pem') +SIGNING_CA = data_file('pycacert.pem') + + +class MyBaseProto(asyncio.Protocol): + connected = None + done = None + + def __init__(self, loop=None): + self.transport = None + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.connected = asyncio.Future(loop=loop) + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + if self.connected: + self.connected.set_result(None) + + def data_received(self, data): + assert self.state == 'CONNECTED', self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == 'CONNECTED', self.state + self.state = 'EOF' + + def connection_lost(self, exc): + assert self.state in ('CONNECTED', 'EOF'), self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyProto(MyBaseProto): + def connection_made(self, transport): + super().connection_made(transport) + transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') + + +class MyDatagramProto(asyncio.DatagramProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.nbytes = 0 + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'INITIALIZED' + + def datagram_received(self, data, addr): + assert self.state == 'INITIALIZED', self.state + self.nbytes += len(data) + + def error_received(self, exc): + assert self.state == 'INITIALIZED', self.state + + def connection_lost(self, exc): + assert self.state == 'INITIALIZED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MyReadPipeProto(asyncio.Protocol): + done = None + + def __init__(self, loop=None): + self.state = ['INITIAL'] + self.nbytes = 0 + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == ['INITIAL'], self.state + self.state.append('CONNECTED') + + def data_received(self, data): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.nbytes += len(data) + + def eof_received(self): + assert self.state == ['INITIAL', 'CONNECTED'], self.state + self.state.append('EOF') + + def connection_lost(self, exc): + if 'EOF' not in self.state: + self.state.append('EOF') # It is okay if EOF is missed. + assert self.state == ['INITIAL', 'CONNECTED', 'EOF'], self.state + self.state.append('CLOSED') + if self.done: + self.done.set_result(None) + + +class MyWritePipeProto(asyncio.BaseProtocol): + done = None + + def __init__(self, loop=None): + self.state = 'INITIAL' + self.transport = None + if loop is not None: + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + if self.done: + self.done.set_result(None) + + +class MySubprocessProtocol(asyncio.SubprocessProtocol): + + def __init__(self, loop): + self.state = 'INITIAL' + self.transport = None + self.connected = asyncio.Future(loop=loop) + self.completed = asyncio.Future(loop=loop) + self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)} + self.data = {1: b'', 2: b''} + self.returncode = None + self.got_data = {1: asyncio.Event(loop=loop), + 2: asyncio.Event(loop=loop)} + + def connection_made(self, transport): + self.transport = transport + assert self.state == 'INITIAL', self.state + self.state = 'CONNECTED' + self.connected.set_result(None) + + def connection_lost(self, exc): + assert self.state == 'CONNECTED', self.state + self.state = 'CLOSED' + self.completed.set_result(None) + + def pipe_data_received(self, fd, data): + assert self.state == 'CONNECTED', self.state + self.data[fd] += data + self.got_data[fd].set() + + def pipe_connection_lost(self, fd, exc): + assert self.state == 'CONNECTED', self.state + if exc: + self.disconnects[fd].set_exception(exc) + else: + self.disconnects[fd].set_result(exc) + + def process_exited(self): + assert self.state == 'CONNECTED', self.state + self.returncode = self.transport.get_returncode() + + +class EventLoopTestsMixin: + + def setUp(self): + super().setUp() + self.loop = self.create_event_loop() + self.set_event_loop(self.loop) + + def tearDown(self): + # just in case if we have transport close callbacks + if not self.loop.is_closed(): + test_utils.run_briefly(self.loop) + + self.loop.close() + gc.collect() + super().tearDown() + + def test_run_until_complete_nesting(self): + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + self.assertTrue(self.loop.is_running()) + self.loop.run_until_complete(coro1()) + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, coro2()) + + # Note: because of the default Windows timing granularity of + # 15.6 msec, we use fairly long sleep times here (~100 msec). + + def test_run_until_complete(self): + t0 = self.loop.time() + self.loop.run_until_complete(asyncio.sleep(0.1, loop=self.loop)) + t1 = self.loop.time() + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_run_until_complete_stopped(self): + @asyncio.coroutine + def cb(): + self.loop.stop() + yield from asyncio.sleep(0.1, loop=self.loop) + task = cb() + self.assertRaises(RuntimeError, + self.loop.run_until_complete, task) + + def test_call_later(self): + results = [] + + def callback(arg): + results.append(arg) + self.loop.stop() + + self.loop.call_later(0.1, callback, 'hello world') + t0 = time.monotonic() + self.loop.run_forever() + t1 = time.monotonic() + self.assertEqual(results, ['hello world']) + self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) + + def test_call_soon(self): + results = [] + + def callback(arg1, arg2): + results.append((arg1, arg2)) + self.loop.stop() + + self.loop.call_soon(callback, 'hello', 'world') + self.loop.run_forever() + self.assertEqual(results, [('hello', 'world')]) + + def test_call_soon_threadsafe(self): + results = [] + lock = threading.Lock() + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + def run_in_thread(): + self.loop.call_soon_threadsafe(callback, 'hello') + lock.release() + + lock.acquire() + t = threading.Thread(target=run_in_thread) + t.start() + + with lock: + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + t.join() + self.assertEqual(results, ['hello', 'world']) + + def test_call_soon_threadsafe_same_thread(self): + results = [] + + def callback(arg): + results.append(arg) + if len(results) >= 2: + self.loop.stop() + + self.loop.call_soon_threadsafe(callback, 'hello') + self.loop.call_soon(callback, 'world') + self.loop.run_forever() + self.assertEqual(results, ['hello', 'world']) + + def test_run_in_executor(self): + def run(arg): + return (arg, threading.get_ident()) + f2 = self.loop.run_in_executor(None, run, 'yo') + res, thread_id = self.loop.run_until_complete(f2) + self.assertEqual(res, 'yo') + self.assertNotEqual(thread_id, threading.get_ident()) + + def test_reader_callback(self): + r, w = test_utils.socketpair() + r.setblocking(False) + bytes_read = bytearray() + + def reader(): + try: + data = r.recv(1024) + except BlockingIOError: + # Spurious readiness notifications are possible + # at least on Linux -- see man select. + return + if data: + bytes_read.extend(data) + else: + self.assertTrue(self.loop.remove_reader(r.fileno())) + r.close() + + self.loop.add_reader(r.fileno(), reader) + self.loop.call_soon(w.send, b'abc') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 3) + self.loop.call_soon(w.send, b'def') + test_utils.run_until(self.loop, lambda: len(bytes_read) >= 6) + self.loop.call_soon(w.close) + self.loop.call_soon(self.loop.stop) + self.loop.run_forever() + self.assertEqual(bytes_read, b'abcdef') + + def test_writer_callback(self): + r, w = test_utils.socketpair() + w.setblocking(False) + + def writer(data): + w.send(data) + self.loop.stop() + + data = b'x' * 1024 + self.loop.add_writer(w.fileno(), writer, data) + self.loop.run_forever() + + self.assertTrue(self.loop.remove_writer(w.fileno())) + self.assertFalse(self.loop.remove_writer(w.fileno())) + + w.close() + read = r.recv(len(data) * 2) + r.close() + self.assertEqual(read, data) + + def _basetest_sock_client_ops(self, httpd, sock): + if not isinstance(self.loop, proactor_events.BaseProactorEventLoop): + # in debug mode, socket operations must fail + # if the socket is not in blocking mode + self.loop.set_debug(True) + sock.setblocking(True) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + with self.assertRaises(ValueError): + self.loop.run_until_complete( + self.loop.sock_accept(sock)) + + # test in non-blocking mode + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, httpd.address)) + self.loop.run_until_complete( + self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')) + data = self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + # consume data + self.loop.run_until_complete( + self.loop.sock_recv(sock, 1024)) + sock.close() + self.assertTrue(data.startswith(b'HTTP/1.0 200 OK')) + + def test_sock_client_ops(self): + with test_utils.run_test_server() as httpd: + sock = socket.socket() + self._basetest_sock_client_ops(httpd, sock) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_unix_sock_client_ops(self): + with test_utils.run_test_unix_server() as httpd: + sock = socket.socket(socket.AF_UNIX) + self._basetest_sock_client_ops(httpd, sock) + + def test_sock_client_fail(self): + # Make sure that we will get an unused port + address = None + try: + s = socket.socket() + s.bind(('127.0.0.1', 0)) + address = s.getsockname() + finally: + s.close() + + sock = socket.socket() + sock.setblocking(False) + with self.assertRaises(ConnectionRefusedError): + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + sock.close() + + def test_sock_accept(self): + listener = socket.socket() + listener.setblocking(False) + listener.bind(('127.0.0.1', 0)) + listener.listen(1) + client = socket.socket() + client.connect(listener.getsockname()) + + f = self.loop.sock_accept(listener) + conn, addr = self.loop.run_until_complete(f) + self.assertEqual(conn.gettimeout(), 0) + self.assertEqual(addr, client.getsockname()) + self.assertEqual(client.getpeername(), listener.getsockname()) + client.close() + conn.close() + listener.close() + + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + def test_add_signal_handler(self): + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + + # Check error behavior first. + self.assertRaises( + TypeError, self.loop.add_signal_handler, 'boom', my_handler) + self.assertRaises( + TypeError, self.loop.remove_signal_handler, 'boom') + self.assertRaises( + ValueError, self.loop.add_signal_handler, signal.NSIG+1, + my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, signal.NSIG+1) + self.assertRaises( + ValueError, self.loop.add_signal_handler, 0, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, 0) + self.assertRaises( + ValueError, self.loop.add_signal_handler, -1, my_handler) + self.assertRaises( + ValueError, self.loop.remove_signal_handler, -1) + self.assertRaises( + RuntimeError, self.loop.add_signal_handler, signal.SIGKILL, + my_handler) + # Removing SIGKILL doesn't raise, since we don't call signal(). + self.assertFalse(self.loop.remove_signal_handler(signal.SIGKILL)) + # Now set a handler and handle it. + self.loop.add_signal_handler(signal.SIGINT, my_handler) + + os.kill(os.getpid(), signal.SIGINT) + test_utils.run_until(self.loop, lambda: caught) + + # Removing it should restore the default handler. + self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) + self.assertEqual(signal.getsignal(signal.SIGINT), + signal.default_int_handler) + # Removing again returns False. + self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_while_selecting(self): + # Test with a signal actually arriving during a select() call. + caught = 0 + + def my_handler(): + nonlocal caught + caught += 1 + self.loop.stop() + + self.loop.add_signal_handler(signal.SIGALRM, my_handler) + + signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. + self.loop.run_forever() + self.assertEqual(caught, 1) + + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + def test_signal_handling_args(self): + some_args = (42,) + caught = 0 + + def my_handler(*args): + nonlocal caught + caught += 1 + self.assertEqual(args, some_args) + + self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) + + signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. + self.loop.call_later(0.5, self.loop.stop) + self.loop.run_forever() + self.assertEqual(caught, 1) + + def _basetest_create_connection(self, connection_fut, check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertIs(pr.transport, tr) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def test_create_connection(self): + with test_utils.run_test_server() as httpd: + conn_fut = self.loop.create_connection( + lambda: MyProto(loop=self.loop), *httpd.address) + self._basetest_create_connection(conn_fut) + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server() as httpd: + conn_fut = self.loop.create_unix_connection( + lambda: MyProto(loop=self.loop), httpd.address) + self._basetest_create_connection(conn_fut, check_sockname) + + def test_create_connection_sock(self): + with test_utils.run_test_server() as httpd: + sock = None + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *httpd.address, type=socket.SOCK_STREAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + self.loop.run_until_complete( + self.loop.sock_connect(sock, address)) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _basetest_create_ssl_connection(self, connection_fut, + check_sockname=True): + tr, pr = self.loop.run_until_complete(connection_fut) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, asyncio.Protocol) + self.assertTrue('ssl' in tr.__class__.__name__.lower()) + if check_sockname: + self.assertIsNotNone(tr.get_extra_info('sockname')) + self.loop.run_until_complete(pr.done) + self.assertGreater(pr.nbytes, 0) + tr.close() + + def _test_create_ssl_connection(self, httpd, create_connection, + check_sockname=True): + conn_fut = create_connection(ssl=test_utils.dummy_ssl_context()) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + # ssl.Purpose was introduced in Python 3.4 + if hasattr(ssl, 'Purpose'): + def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, + cafile=None, capath=None, + cadata=None): + """ + A ssl.create_default_context() replacement that doesn't enable + cert validation. + """ + self.assertEqual(purpose, ssl.Purpose.SERVER_AUTH) + return test_utils.dummy_ssl_context() + + # With ssl=True, ssl.create_default_context() should be called + with mock.patch('ssl.create_default_context', + side_effect=_dummy_ssl_create_context) as m: + conn_fut = create_connection(ssl=True) + self._basetest_create_ssl_connection(conn_fut, check_sockname) + self.assertEqual(m.call_count, 1) + + # With the real ssl.create_default_context(), certificate + # validation will fail + with self.assertRaises(ssl.SSLError) as cm: + conn_fut = create_connection(ssl=True) + # Ignore the "SSL handshake failed" log in debug mode + with test_utils.disable_logger(): + self._basetest_create_ssl_connection(conn_fut, check_sockname) + + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_ssl_connection(self): + with test_utils.run_test_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_connection, + lambda: MyProto(loop=self.loop), + *httpd.address) + self._test_create_ssl_connection(httpd, create_connection) + + def test_legacy_create_ssl_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_connection() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_ssl_unix_connection(self): + # Issue #20682: On Mac OS X Tiger, getsockname() returns a + # zero-length address for UNIX socket. + check_sockname = not osx_tiger() + + with test_utils.run_test_unix_server(use_ssl=True) as httpd: + create_connection = functools.partial( + self.loop.create_unix_connection, + lambda: MyProto(loop=self.loop), httpd.address, + server_hostname='127.0.0.1') + + self._test_create_ssl_connection(httpd, create_connection, + check_sockname) + + def test_legacy_create_ssl_unix_connection(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_ssl_unix_connection() + + def test_create_connection_local_addr(self): + with test_utils.run_test_server() as httpd: + port = support.find_unused_port() + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=(httpd.address[0], port)) + tr, pr = self.loop.run_until_complete(f) + expected = pr.transport.get_extra_info('sockname')[1] + self.assertEqual(port, expected) + tr.close() + + def test_create_connection_local_addr_in_use(self): + with test_utils.run_test_server() as httpd: + f = self.loop.create_connection( + lambda: MyProto(loop=self.loop), + *httpd.address, local_addr=httpd.address) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + self.assertIn(str(httpd.address), cm.exception.strerror) + + def test_create_server(self): + proto = MyProto(self.loop) + f = self.loop.create_server(lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + def _make_unix_server(self, factory, **kwargs): + path = test_utils.gen_unix_socket_path() + self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) + + f = self.loop.create_unix_server(factory, path, **kwargs) + server = self.loop.run_until_complete(f) + + return server, path + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server(self): + proto = MyProto(loop=self.loop) + server, path = self._make_unix_server(lambda: proto) + self.assertEqual(len(server.sockets), 1) + + client = socket.socket(socket.AF_UNIX) + client.connect(path) + client.sendall(b'xxx') + + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # close server + server.close() + + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_path_socket_error(self): + proto = MyProto(loop=self.loop) + sock = socket.socket() + with sock: + f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock) + with self.assertRaisesRegex(ValueError, + 'path and sock can not be specified ' + 'at the same time'): + self.loop.run_until_complete(f) + + def _create_ssl_context(self, certfile, keyfile=None): + sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.load_cert_chain(certfile, keyfile) + return sslcontext + + def _make_ssl_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + + f = self.loop.create_server(factory, '127.0.0.1', 0, ssl=sslcontext) + server = self.loop.run_until_complete(f) + + sock = server.sockets[0] + host, port = sock.getsockname() + self.assertEqual(host, '127.0.0.1') + return server, host, port + + def _make_ssl_unix_server(self, factory, certfile, keyfile=None): + sslcontext = self._create_ssl_context(certfile, keyfile) + return self._make_unix_server(factory, ssl=sslcontext) + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_connection(MyBaseProto, host, port, + ssl=test_utils.dummy_ssl_context()) + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('sockname')) + self.assertEqual('127.0.0.1', + proto.transport.get_extra_info('peername')[0]) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_legacy_create_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, ONLYCERT, ONLYKEY) + + f_c = self.loop.create_unix_connection( + MyBaseProto, path, ssl=test_utils.dummy_ssl_context(), + server_hostname='') + + client, pr = self.loop.run_until_complete(f_c) + + client.write(b'xxx') + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + test_utils.run_until(self.loop, lambda: proto.nbytes > 0) + self.assertEqual(3, proto.nbytes) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + # the client socket must be closed after to avoid ECONNRESET upon + # recv()/send() on the serving socket + client.close() + + # stop serving + server.close() + + def test_legacy_create_unix_server_ssl(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + def test_legacy_create_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verify_failed() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verify_failed(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # no CA loaded + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='invalid') + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # close connection + self.assertIsNone(proto.transport) + server.close() + + def test_legacy_create_unix_server_ssl_verify_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verify_failed() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_match_failed(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # incorrect server_hostname + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with test_utils.disable_logger(): + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + server.close() + + def test_legacy_create_server_ssl_match_failed(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_match_failed() + + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + def test_create_unix_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, path = self._make_ssl_unix_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_unix_connection(MyProto, path, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_unix_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_unix_server_ssl_verified() + + @unittest.skipIf(ssl is None, 'No ssl module') + def test_create_server_ssl_verified(self): + proto = MyProto(loop=self.loop) + server, host, port = self._make_ssl_server( + lambda: proto, SIGNED_CERTFILE) + + sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + if hasattr(sslcontext_client, 'check_hostname'): + sslcontext_client.check_hostname = True + + # Connection succeeds with correct CA and server hostname. + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client, + server_hostname='localhost') + client, pr = self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() + client.close() + server.close() + self.loop.run_until_complete(proto.done) + + def test_legacy_create_server_ssl_verified(self): + with test_utils.force_legacy_ssl_support(): + self.test_create_server_ssl_verified() + + def test_create_server_sock(self): + proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + proto.set_result(self) + + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(TestMyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + self.assertIs(sock, sock_ob) + + host, port = sock.getsockname() + self.assertEqual(host, '0.0.0.0') + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + server.close() + + def test_create_server_addr_in_use(self): + sock_ob = socket.socket(type=socket.SOCK_STREAM) + sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock_ob.bind(('0.0.0.0', 0)) + + f = self.loop.create_server(MyProto, sock=sock_ob) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + f = self.loop.create_server(MyProto, host=host, port=port) + with self.assertRaises(OSError) as cm: + self.loop.run_until_complete(f) + self.assertEqual(cm.exception.errno, errno.EADDRINUSE) + + server.close() + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_create_server_dual_stack(self): + f_proto = asyncio.Future(loop=self.loop) + + class TestMyProto(MyProto): + def connection_made(self, transport): + super().connection_made(transport) + f_proto.set_result(self) + + try_count = 0 + while True: + try: + port = support.find_unused_port() + f = self.loop.create_server(TestMyProto, host=None, port=port) + server = self.loop.run_until_complete(f) + except OSError as ex: + if ex.errno == errno.EADDRINUSE: + try_count += 1 + self.assertGreaterEqual(5, try_count) + continue + else: + raise + else: + break + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + f_proto = asyncio.Future(loop=self.loop) + client = socket.socket(socket.AF_INET6) + client.connect(('::1', port)) + client.send(b'xxx') + proto = self.loop.run_until_complete(f_proto) + proto.transport.close() + client.close() + + server.close() + + def test_server_close(self): + f = self.loop.create_server(MyProto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + sock = server.sockets[0] + host, port = sock.getsockname() + + client = socket.socket() + client.connect(('127.0.0.1', port)) + client.send(b'xxx') + client.close() + + server.close() + + client = socket.socket() + self.assertRaises( + ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + client.close() + + def test_create_datagram_endpoint(self): + class TestMyDatagramProto(MyDatagramProto): + def __init__(inner_self): + super().__init__(loop=self.loop) + + def datagram_received(self, data, addr): + super().datagram_received(data, addr) + self.transport.sendto(b'resp:'+data, addr) + + coro = self.loop.create_datagram_endpoint( + TestMyDatagramProto, local_addr=('127.0.0.1', 0)) + s_transport, server = self.loop.run_until_complete(coro) + host, port = s_transport.get_extra_info('sockname') + + self.assertIsInstance(s_transport, asyncio.Transport) + self.assertIsInstance(server, TestMyDatagramProto) + self.assertEqual('INITIALIZED', server.state) + self.assertIs(server.transport, s_transport) + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + remote_addr=(host, port)) + transport, client = self.loop.run_until_complete(coro) + + self.assertIsInstance(transport, asyncio.Transport) + self.assertIsInstance(client, MyDatagramProto) + self.assertEqual('INITIALIZED', client.state) + self.assertIs(client.transport, transport) + + transport.sendto(b'xxx') + test_utils.run_until(self.loop, lambda: server.nbytes) + self.assertEqual(3, server.nbytes) + test_utils.run_until(self.loop, lambda: client.nbytes) + + # received + self.assertEqual(8, client.nbytes) + + # extra info is available + self.assertIsNotNone(transport.get_extra_info('sockname')) + + # close connection + transport.close() + self.loop.run_until_complete(client.done) + self.assertEqual('CLOSED', client.state) + server.transport.close() + + def test_internal_fds(self): + loop = self.create_event_loop() + if not isinstance(loop, selector_events.BaseSelectorEventLoop): + loop.close() + self.skipTest('loop is not a BaseSelectorEventLoop') + + self.assertEqual(1, loop._internal_fds) + loop.close() + self.assertEqual(0, loop._internal_fds) + self.assertIsNone(loop._csock) + self.assertIsNone(loop._ssock) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_read_pipe(self): + proto = MyReadPipeProto(loop=self.loop) + + rpipe, wpipe = os.pipe() + pipeobj = io.open(rpipe, 'rb', 1024) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe( + lambda: proto, pipeobj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(wpipe, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 1) + self.assertEqual(1, proto.nbytes) + + os.write(wpipe, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(wpipe) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + # Issue #20495: The test hangs on FreeBSD 7.2 but pass on FreeBSD 9 + @support.requires_freebsd_version(8) + def test_read_pty_output(self): + proto = MyReadPipeProto(loop=self.loop) + + master, slave = os.openpty() + master_read_obj = io.open(master, 'rb', 0) + + @asyncio.coroutine + def connect(): + t, p = yield from self.loop.connect_read_pipe(lambda: proto, + master_read_obj) + self.assertIs(p, proto) + self.assertIs(t, proto.transport) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(0, proto.nbytes) + + self.loop.run_until_complete(connect()) + + os.write(slave, b'1') + test_utils.run_until(self.loop, lambda: proto.nbytes) + self.assertEqual(1, proto.nbytes) + + os.write(slave, b'2345') + test_utils.run_until(self.loop, lambda: proto.nbytes >= 5) + self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) + self.assertEqual(5, proto.nbytes) + + os.close(slave) + self.loop.run_until_complete(proto.done) + self.assertEqual( + ['INITIAL', 'CONNECTED', 'EOF', 'CLOSED'], proto.state) + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe(self): + rpipe, wpipe = os.pipe() + pipeobj = io.open(wpipe, 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + + data = bytearray() + def reader(data): + chunk = os.read(rpipe, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_until(self.loop, lambda: reader(data) >= 5) + self.assertEqual(b'12345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(rpipe) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + def test_write_pipe_disconnect_on_close(self): + rsock, wsock = test_utils.socketpair() + rsock.setblocking(False) + pipeobj = io.open(wsock.detach(), 'wb', 1024) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + data = self.loop.run_until_complete(self.loop.sock_recv(rsock, 1024)) + self.assertEqual(b'1', data) + + rsock.close() + + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") + # select, poll and kqueue don't support character devices (PTY) on Mac OS X + # older than 10.6 (Snow Leopard) + @support.requires_mac_ver(10, 6) + def test_write_pty(self): + master, slave = os.openpty() + slave_write_obj = io.open(slave, 'wb', 0) + + proto = MyWritePipeProto(loop=self.loop) + connect = self.loop.connect_write_pipe(lambda: proto, slave_write_obj) + transport, p = self.loop.run_until_complete(connect) + self.assertIs(p, proto) + self.assertIs(transport, proto.transport) + self.assertEqual('CONNECTED', proto.state) + + transport.write(b'1') + + data = bytearray() + def reader(data): + chunk = os.read(master, 1024) + data += chunk + return len(data) + + test_utils.run_until(self.loop, lambda: reader(data) >= 1, + timeout=10) + self.assertEqual(b'1', data) + + transport.write(b'2345') + test_utils.run_until(self.loop, lambda: reader(data) >= 5, + timeout=10) + self.assertEqual(b'12345', data) + self.assertEqual('CONNECTED', proto.state) + + os.close(master) + + # extra info is available + self.assertIsNotNone(proto.transport.get_extra_info('pipe')) + + # close connection + proto.transport.close() + self.loop.run_until_complete(proto.done) + self.assertEqual('CLOSED', proto.state) + + def test_prompt_cancellation(self): + r, w = test_utils.socketpair() + r.setblocking(False) + f = self.loop.sock_recv(r, 1) + ov = getattr(f, 'ov', None) + if ov is not None: + self.assertTrue(ov.pending) + + @asyncio.coroutine + def main(): + try: + self.loop.call_soon(f.cancel) + yield from f + except asyncio.CancelledError: + res = 'cancelled' + else: + res = None + finally: + self.loop.stop() + return res + + start = time.monotonic() + t = asyncio.Task(main(), loop=self.loop) + self.loop.run_forever() + elapsed = time.monotonic() - start + + self.assertLess(elapsed, 0.1) + self.assertEqual(t.result(), 'cancelled') + self.assertRaises(asyncio.CancelledError, f.result) + if ov is not None: + self.assertFalse(ov.pending) + self.loop._stop_serving(r) + + r.close() + w.close() + + def test_timeout_rounding(self): + def _run_once(): + self.loop._run_once_counter += 1 + orig_run_once() + + orig_run_once = self.loop._run_once + self.loop._run_once_counter = 0 + self.loop._run_once = _run_once + + @asyncio.coroutine + def wait(): + loop = self.loop + yield from asyncio.sleep(1e-2, loop=loop) + yield from asyncio.sleep(1e-4, loop=loop) + yield from asyncio.sleep(1e-6, loop=loop) + yield from asyncio.sleep(1e-8, loop=loop) + yield from asyncio.sleep(1e-10, loop=loop) + + self.loop.run_until_complete(wait()) + # The ideal number of call is 12, but on some platforms, the selector + # may sleep at little bit less than timeout depending on the resolution + # of the clock used by the kernel. Tolerate a few useless calls on + # these platforms. + self.assertLessEqual(self.loop._run_once_counter, 20, + {'clock_resolution': self.loop._clock_resolution, + 'selector': self.loop._selector.__class__.__name__}) + + def test_sock_connect_address(self): + addresses = [(socket.AF_INET, ('www.python.org', 80))] + if support.IPV6_ENABLED: + addresses.extend(( + (socket.AF_INET6, ('www.python.org', 80)), + (socket.AF_INET6, ('www.python.org', 80, 0, 0)), + )) + + for family, address in addresses: + for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): + sock = socket.socket(family, sock_type) + with sock: + sock.setblocking(False) + connect = self.loop.sock_connect(sock, address) + with self.assertRaises(ValueError) as cm: + self.loop.run_until_complete(connect) + self.assertIn('address must be resolved', + str(cm.exception)) + + def test_remove_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) + loop.add_reader(r, callback) + loop.add_writer(w, callback) + loop.close() + self.assertFalse(loop.remove_reader(r)) + self.assertFalse(loop.remove_writer(w)) + + def test_add_fds_after_closing(self): + loop = self.create_event_loop() + callback = lambda: None + r, w = test_utils.socketpair() + self.addCleanup(r.close) + self.addCleanup(w.close) + loop.close() + with self.assertRaises(RuntimeError): + loop.add_reader(r, callback) + with self.assertRaises(RuntimeError): + loop.add_writer(w, callback) + + def test_close_running_event_loop(self): + @asyncio.coroutine + def close_loop(loop): + self.loop.close() + + coro = close_loop(self.loop) + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(coro) + + def test_close(self): + self.loop.close() + + @asyncio.coroutine + def test(): + pass + + func = lambda: False + coro = test() + self.addCleanup(coro.close) + + # operation blocked when the loop is closed + with self.assertRaises(RuntimeError): + self.loop.run_forever() + with self.assertRaises(RuntimeError): + fut = asyncio.Future(loop=self.loop) + self.loop.run_until_complete(fut) + with self.assertRaises(RuntimeError): + self.loop.call_soon(func) + with self.assertRaises(RuntimeError): + self.loop.call_soon_threadsafe(func) + with self.assertRaises(RuntimeError): + self.loop.call_later(1.0, func) + with self.assertRaises(RuntimeError): + self.loop.call_at(self.loop.time() + .0, func) + with self.assertRaises(RuntimeError): + self.loop.run_in_executor(None, func) + with self.assertRaises(RuntimeError): + self.loop.create_task(coro) + with self.assertRaises(RuntimeError): + self.loop.add_signal_handler(signal.SIGTERM, func) + + +class SubprocessTestsMixin: + + def check_terminated(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGTERM, returncode) + + def check_killed(self, returncode): + if sys.platform == 'win32': + self.assertIsInstance(returncode, int) + # expect 1 but sometimes get 0 + else: + self.assertEqual(-signal.SIGKILL, returncode) + + def test_subprocess_exec(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + self.assertEqual(b'Python The Winner', proto.data[1]) + + def test_subprocess_interactive(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + self.assertEqual('CONNECTED', proto.state) + + try: + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + finally: + transp.close() + + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_shell(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'echo Python') + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.get_pipe_transport(0).close() + self.loop.run_until_complete(proto.completed) + self.assertEqual(0, proto.returncode) + self.assertTrue(all(f.done() for f in proto.disconnects.values())) + self.assertEqual(proto.data[1].rstrip(b'\r\n'), b'Python') + self.assertEqual(proto.data[2], b'') + transp.close() + + def test_subprocess_exitcode(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + transp.close() + + def test_subprocess_close_after_finish(self): + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.assertIsNone(transp.get_pipe_transport(0)) + self.assertIsNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + self.assertIsNone(transp.close()) + + def test_subprocess_kill(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.kill() + self.loop.run_until_complete(proto.completed) + self.check_killed(proto.returncode) + transp.close() + + def test_subprocess_terminate(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.terminate() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + transp.close() + + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + def test_subprocess_send_signal(self): + prog = os.path.join(os.path.dirname(__file__), 'echo.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + transp.send_signal(signal.SIGHUP) + self.loop.run_until_complete(proto.completed) + self.assertEqual(-signal.SIGHUP, proto.returncode) + transp.close() + + def test_subprocess_stderr(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdin.write(b'test') + + self.loop.run_until_complete(proto.completed) + + transp.close() + self.assertEqual(b'OUT:test', proto.data[1]) + self.assertTrue(proto.data[2].startswith(b'ERR:test'), proto.data[2]) + self.assertEqual(0, proto.returncode) + + def test_subprocess_stderr_redirect_to_stdout(self): + prog = os.path.join(os.path.dirname(__file__), 'echo2.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog, stderr=subprocess.STDOUT) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + self.assertIsNotNone(transp.get_pipe_transport(1)) + self.assertIsNone(transp.get_pipe_transport(2)) + + stdin.write(b'test') + self.loop.run_until_complete(proto.completed) + self.assertTrue(proto.data[1].startswith(b'OUT:testERR:test'), + proto.data[1]) + self.assertEqual(b'', proto.data[2]) + + transp.close() + self.assertEqual(0, proto.returncode) + + def test_subprocess_close_client_stream(self): + prog = os.path.join(os.path.dirname(__file__), 'echo3.py') + + connect = self.loop.subprocess_exec( + functools.partial(MySubprocessProtocol, self.loop), + sys.executable, prog) + transp, proto = self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.connected) + + stdin = transp.get_pipe_transport(0) + stdout = transp.get_pipe_transport(1) + stdin.write(b'test') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'OUT:test', proto.data[1]) + + stdout.close() + self.loop.run_until_complete(proto.disconnects[1]) + stdin.write(b'xxx') + self.loop.run_until_complete(proto.got_data[2].wait()) + if sys.platform != 'win32': + self.assertEqual(b'ERR:BrokenPipeError', proto.data[2]) + else: + # After closing the read-end of a pipe, writing to the + # write-end using os.write() fails with errno==EINVAL and + # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using + # WriteFile() we get ERROR_BROKEN_PIPE as expected.) + self.assertEqual(b'ERR:OSError', proto.data[2]) + transp.close() + self.loop.run_until_complete(proto.completed) + self.check_terminated(proto.returncode) + + def test_subprocess_wait_no_same_group(self): + # start the new process in a new session + connect = self.loop.subprocess_shell( + functools.partial(MySubprocessProtocol, self.loop), + 'exit 7', stdin=None, stdout=None, stderr=None, + start_new_session=True) + _, proto = yield self.loop.run_until_complete(connect) + self.assertIsInstance(proto, MySubprocessProtocol) + self.loop.run_until_complete(proto.completed) + self.assertEqual(7, proto.returncode) + + def test_subprocess_exec_invalid_args(self): + @asyncio.coroutine + def connect(**kwds): + yield from self.loop.subprocess_exec( + asyncio.SubprocessProtocol, + 'pwd', **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=True)) + + def test_subprocess_shell_invalid_args(self): + @asyncio.coroutine + def connect(cmd=None, **kwds): + if not cmd: + cmd = 'pwd' + yield from self.loop.subprocess_shell( + asyncio.SubprocessProtocol, + cmd, **kwds) + + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(['ls', '-l'])) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(universal_newlines=True)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(bufsize=4096)) + with self.assertRaises(ValueError): + self.loop.run_until_complete(connect(shell=False)) + + +if sys.platform == 'win32': + + class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop() + + class ProactorEventLoopTests(EventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.ProactorEventLoop() + + if not sslproto._is_sslproto_available(): + def test_create_ssl_connection(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_match_failed(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_create_server_ssl_verified(self): + raise unittest.SkipTest("need python 3.5 (ssl.MemoryBIO)") + + def test_legacy_create_ssl_connection(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") + + def test_legacy_create_server_ssl(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") + + def test_legacy_create_server_ssl_verify_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") + + def test_legacy_create_server_ssl_match_failed(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") + + def test_legacy_create_server_ssl_verified(self): + raise unittest.SkipTest("IocpEventLoop incompatible with legacy SSL") + + def test_reader_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_reader_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") + + def test_writer_callback(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_writer_callback_cancel(self): + raise unittest.SkipTest("IocpEventLoop does not have add_writer()") + + def test_create_datagram_endpoint(self): + raise unittest.SkipTest( + "IocpEventLoop does not have create_datagram_endpoint()") + + def test_remove_fds_after_closing(self): + raise unittest.SkipTest("IocpEventLoop does not have add_reader()") +else: + from asyncio import selectors + + class UnixEventLoopTestsMixin(EventLoopTestsMixin): + def setUp(self): + super().setUp() + watcher = asyncio.SafeChildWatcher() + watcher.attach_loop(self.loop) + asyncio.set_child_watcher(watcher) + + def tearDown(self): + asyncio.set_child_watcher(None) + super().tearDown() + + if hasattr(selectors, 'KqueueSelector'): + class KqueueEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop( + selectors.KqueueSelector()) + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + # Issue #20667: KqueueEventLoopTests.test_read_pty_output() + # hangs on OpenBSD 5.5 + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') + def test_read_pty_output(self): + super().test_read_pty_output() + + # kqueue doesn't support character devices (PTY) on Mac OS X older + # than 10.9 (Maverick) + @support.requires_mac_ver(10, 9) + def test_write_pty(self): + super().test_write_pty() + + if hasattr(selectors, 'EpollSelector'): + class EPollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.EpollSelector()) + + if hasattr(selectors, 'PollSelector'): + class PollEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.PollSelector()) + + # Should always exist. + class SelectEventLoopTests(UnixEventLoopTestsMixin, + SubprocessTestsMixin, + test_utils.TestCase): + + def create_event_loop(self): + return asyncio.SelectorEventLoop(selectors.SelectSelector()) + + +def noop(*args): + pass + + +class HandleTests(test_utils.TestCase): + + def setUp(self): + self.loop = mock.Mock() + self.loop.get_debug.return_value = True + + def test_handle(self): + def callback(*args): + return args + + args = () + h = asyncio.Handle(callback, args, self.loop) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + h.cancel() + self.assertTrue(h._cancelled) + + def test_handle_from_handle(self): + def callback(*args): + return args + h1 = asyncio.Handle(callback, (), loop=self.loop) + self.assertRaises( + AssertionError, asyncio.Handle, h1, (), self.loop) + + def test_callback_with_exception(self): + def callback(): + raise ValueError() + + self.loop = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + + h = asyncio.Handle(callback, (), self.loop) + h._run() + + self.loop.call_exception_handler.assert_called_with({ + 'message': test_utils.MockPattern('Exception in callback.*'), + 'exception': mock.ANY, + 'handle': h, + 'source_traceback': h._source_traceback, + }) + + def test_handle_weakref(self): + wd = weakref.WeakValueDictionary() + h = asyncio.Handle(lambda: None, (), self.loop) + wd['h'] = h # Would fail without __weakref__ slot. + + def test_handle_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + # decorated function + cb = asyncio.coroutine(noop) + h = asyncio.Handle(cb, (), self.loop) + self.assertEqual(repr(h), + '' + % (filename, lineno)) + + # partial function + cb = functools.partial(noop, 1, 2) + h = asyncio.Handle(cb, (3,), self.loop) + regex = (r'^$' + % (re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + # partial method + if sys.version_info >= (3, 4): + method = HandleTests.test_handle_repr + cb = functools.partialmethod(method) + filename, lineno = test_utils.get_function_source(method) + h = asyncio.Handle(cb, (), self.loop) + + cb_regex = r'' + cb_regex = (r'functools.partialmethod\(%s, , \)\(\)' % cb_regex) + regex = (r'^$' + % (cb_regex, re.escape(filename), lineno)) + self.assertRegex(repr(h), regex) + + def test_handle_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.Handle(noop, (1, 2), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # double cancellation won't overwrite _repr + h.cancel() + self.assertEqual( + repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + def test_handle_source_traceback(self): + loop = asyncio.get_event_loop_policy().new_event_loop() + loop.set_debug(True) + self.set_event_loop(loop) + + def check_source_traceback(h): + lineno = sys._getframe(1).f_lineno - 1 + self.assertIsInstance(h._source_traceback, list) + self.assertEqual(h._source_traceback[-1][:3], + (__file__, + lineno, + 'test_handle_source_traceback')) + + # call_soon + h = loop.call_soon(noop) + check_source_traceback(h) + + # call_soon_threadsafe + h = loop.call_soon_threadsafe(noop) + check_source_traceback(h) + + # call_later + h = loop.call_later(0, noop) + check_source_traceback(h) + + # call_at + h = loop.call_later(0, noop) + check_source_traceback(h) + + +class TimerTests(unittest.TestCase): + + def setUp(self): + self.loop = mock.Mock() + + def test_hash(self): + when = time.monotonic() + h = asyncio.TimerHandle(when, lambda: False, (), + mock.Mock()) + self.assertEqual(hash(h), hash(when)) + + def test_timer(self): + def callback(*args): + return args + + args = (1, 2, 3) + when = time.monotonic() + h = asyncio.TimerHandle(when, callback, args, mock.Mock()) + self.assertIs(h._callback, callback) + self.assertIs(h._args, args) + self.assertFalse(h._cancelled) + + # cancel + h.cancel() + self.assertTrue(h._cancelled) + self.assertIsNone(h._callback) + self.assertIsNone(h._args) + + # when cannot be None + self.assertRaises(AssertionError, + asyncio.TimerHandle, None, callback, args, + self.loop) + + def test_timer_repr(self): + self.loop.get_debug.return_value = False + + # simple function + h = asyncio.TimerHandle(123, noop, (), self.loop) + src = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' % src) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '') + + def test_timer_repr_debug(self): + self.loop.get_debug.return_value = True + + # simple function + create_filename = __file__ + create_lineno = sys._getframe().f_lineno + 1 + h = asyncio.TimerHandle(123, noop, (), self.loop) + filename, lineno = test_utils.get_function_source(noop) + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + # cancelled handle + h.cancel() + self.assertEqual(repr(h), + '' + % (filename, lineno, create_filename, create_lineno)) + + + def test_timer_comparison(self): + def callback(*args): + return args + + when = time.monotonic() + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when, callback, (), self.loop) + # TODO: Use assertLess etc. + self.assertFalse(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertTrue(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertFalse(h2 > h1) + self.assertTrue(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertTrue(h1 == h2) + self.assertFalse(h1 != h2) + + h2.cancel() + self.assertFalse(h1 == h2) + + h1 = asyncio.TimerHandle(when, callback, (), self.loop) + h2 = asyncio.TimerHandle(when + 10.0, callback, (), self.loop) + self.assertTrue(h1 < h2) + self.assertFalse(h2 < h1) + self.assertTrue(h1 <= h2) + self.assertFalse(h2 <= h1) + self.assertFalse(h1 > h2) + self.assertTrue(h2 > h1) + self.assertFalse(h1 >= h2) + self.assertTrue(h2 >= h1) + self.assertFalse(h1 == h2) + self.assertTrue(h1 != h2) + + h3 = asyncio.Handle(callback, (), self.loop) + self.assertIs(NotImplemented, h1.__eq__(h3)) + self.assertIs(NotImplemented, h1.__ne__(h3)) + + +class AbstractEventLoopTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + loop = asyncio.AbstractEventLoop() + self.assertRaises( + NotImplementedError, loop.run_forever) + self.assertRaises( + NotImplementedError, loop.run_until_complete, None) + self.assertRaises( + NotImplementedError, loop.stop) + self.assertRaises( + NotImplementedError, loop.is_running) + self.assertRaises( + NotImplementedError, loop.is_closed) + self.assertRaises( + NotImplementedError, loop.close) + self.assertRaises( + NotImplementedError, loop.create_task, None) + self.assertRaises( + NotImplementedError, loop.call_later, None, None) + self.assertRaises( + NotImplementedError, loop.call_at, f, f) + self.assertRaises( + NotImplementedError, loop.call_soon, None) + self.assertRaises( + NotImplementedError, loop.time) + self.assertRaises( + NotImplementedError, loop.call_soon_threadsafe, None) + self.assertRaises( + NotImplementedError, loop.run_in_executor, f, f) + self.assertRaises( + NotImplementedError, loop.set_default_executor, f) + self.assertRaises( + NotImplementedError, loop.getaddrinfo, 'localhost', 8080) + self.assertRaises( + NotImplementedError, loop.getnameinfo, ('localhost', 8080)) + self.assertRaises( + NotImplementedError, loop.create_connection, f) + self.assertRaises( + NotImplementedError, loop.create_server, f) + self.assertRaises( + NotImplementedError, loop.create_datagram_endpoint, f) + self.assertRaises( + NotImplementedError, loop.add_reader, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_reader, 1) + self.assertRaises( + NotImplementedError, loop.add_writer, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_writer, 1) + self.assertRaises( + NotImplementedError, loop.sock_recv, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_sendall, f, 10) + self.assertRaises( + NotImplementedError, loop.sock_connect, f, f) + self.assertRaises( + NotImplementedError, loop.sock_accept, f) + self.assertRaises( + NotImplementedError, loop.add_signal_handler, 1, f) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.remove_signal_handler, 1) + self.assertRaises( + NotImplementedError, loop.connect_read_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.connect_write_pipe, f, + mock.sentinel.pipe) + self.assertRaises( + NotImplementedError, loop.subprocess_shell, f, + mock.sentinel) + self.assertRaises( + NotImplementedError, loop.subprocess_exec, f) + self.assertRaises( + NotImplementedError, loop.set_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.default_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.call_exception_handler, f) + self.assertRaises( + NotImplementedError, loop.get_debug) + self.assertRaises( + NotImplementedError, loop.set_debug, f) + + +class ProtocolsAbsTests(unittest.TestCase): + + def test_empty(self): + f = mock.Mock() + p = asyncio.Protocol() + self.assertIsNone(p.connection_made(f)) + self.assertIsNone(p.connection_lost(f)) + self.assertIsNone(p.data_received(f)) + self.assertIsNone(p.eof_received()) + + dp = asyncio.DatagramProtocol() + self.assertIsNone(dp.connection_made(f)) + self.assertIsNone(dp.connection_lost(f)) + self.assertIsNone(dp.error_received(f)) + self.assertIsNone(dp.datagram_received(f, f)) + + sp = asyncio.SubprocessProtocol() + self.assertIsNone(sp.connection_made(f)) + self.assertIsNone(sp.connection_lost(f)) + self.assertIsNone(sp.pipe_data_received(1, f)) + self.assertIsNone(sp.pipe_connection_lost(1, f)) + self.assertIsNone(sp.process_exited()) + + +class PolicyTests(unittest.TestCase): + + def test_event_loop_policy(self): + policy = asyncio.AbstractEventLoopPolicy() + self.assertRaises(NotImplementedError, policy.get_event_loop) + self.assertRaises(NotImplementedError, policy.set_event_loop, object()) + self.assertRaises(NotImplementedError, policy.new_event_loop) + self.assertRaises(NotImplementedError, policy.get_child_watcher) + self.assertRaises(NotImplementedError, policy.set_child_watcher, + object()) + + def test_get_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + self.assertIsNone(policy._local._loop) + + loop = policy.get_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + + self.assertIs(policy._local._loop, loop) + self.assertIs(loop, policy.get_event_loop()) + loop.close() + + def test_get_event_loop_calls_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + with mock.patch.object( + policy, "set_event_loop", + wraps=policy.set_event_loop) as m_set_event_loop: + + loop = policy.get_event_loop() + + # policy._local._loop must be set through .set_event_loop() + # (the unix DefaultEventLoopPolicy needs this call to attach + # the child watcher correctly) + m_set_event_loop.assert_called_with(loop) + + loop.close() + + def test_get_event_loop_after_set_none(self): + policy = asyncio.DefaultEventLoopPolicy() + policy.set_event_loop(None) + self.assertRaises(RuntimeError, policy.get_event_loop) + + @mock.patch('asyncio.events.threading.current_thread') + def test_get_event_loop_thread(self, m_current_thread): + + def f(): + policy = asyncio.DefaultEventLoopPolicy() + self.assertRaises(RuntimeError, policy.get_event_loop) + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_new_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + + loop = policy.new_event_loop() + self.assertIsInstance(loop, asyncio.AbstractEventLoop) + loop.close() + + def test_set_event_loop(self): + policy = asyncio.DefaultEventLoopPolicy() + old_loop = policy.get_event_loop() + + self.assertRaises(AssertionError, policy.set_event_loop, object()) + + loop = policy.new_event_loop() + policy.set_event_loop(loop) + self.assertIs(loop, policy.get_event_loop()) + self.assertIsNot(old_loop, policy.get_event_loop()) + loop.close() + old_loop.close() + + def test_get_event_loop_policy(self): + policy = asyncio.get_event_loop_policy() + self.assertIsInstance(policy, asyncio.AbstractEventLoopPolicy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + + def test_set_event_loop_policy(self): + self.assertRaises( + AssertionError, asyncio.set_event_loop_policy, object()) + + old_policy = asyncio.get_event_loop_policy() + + policy = asyncio.DefaultEventLoopPolicy() + asyncio.set_event_loop_policy(policy) + self.assertIs(policy, asyncio.get_event_loop_policy()) + self.assertIsNot(policy, old_policy) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_futures.py b/tests/test_futures.py new file mode 100644 index 00000000..c8b6829f --- /dev/null +++ b/tests/test_futures.py @@ -0,0 +1,473 @@ +"""Tests for futures.py.""" + +import concurrent.futures +import re +import sys +import threading +import unittest +from unittest import mock + +import asyncio +from asyncio import test_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support + + +def _fakefunc(f): + return f + +def first_cb(): + pass + +def last_cb(): + pass + + +class FutureTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + + def test_initial_state(self): + f = asyncio.Future(loop=self.loop) + self.assertFalse(f.cancelled()) + self.assertFalse(f.done()) + f.cancel() + self.assertTrue(f.cancelled()) + + def test_init_constructor_default_loop(self): + asyncio.set_event_loop(self.loop) + f = asyncio.Future() + self.assertIs(f._loop, self.loop) + + def test_constructor_positional(self): + # Make sure Future doesn't accept a positional argument + self.assertRaises(TypeError, asyncio.Future, 42) + + def test_cancel(self): + f = asyncio.Future(loop=self.loop) + self.assertTrue(f.cancel()) + self.assertTrue(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(asyncio.CancelledError, f.result) + self.assertRaises(asyncio.CancelledError, f.exception) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_result(self): + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.result) + + f.set_result(42) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 42) + self.assertEqual(f.exception(), None) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception(self): + exc = RuntimeError() + f = asyncio.Future(loop=self.loop) + self.assertRaises(asyncio.InvalidStateError, f.exception) + + f.set_exception(exc) + self.assertFalse(f.cancelled()) + self.assertTrue(f.done()) + self.assertRaises(RuntimeError, f.result) + self.assertEqual(f.exception(), exc) + self.assertRaises(asyncio.InvalidStateError, f.set_result, None) + self.assertRaises(asyncio.InvalidStateError, f.set_exception, None) + self.assertFalse(f.cancel()) + + def test_exception_class(self): + f = asyncio.Future(loop=self.loop) + f.set_exception(RuntimeError) + self.assertIsInstance(f.exception(), RuntimeError) + + def test_yield_from_twice(self): + f = asyncio.Future(loop=self.loop) + + def fixture(): + yield 'A' + x = yield from f + yield 'B', x + y = yield from f + yield 'C', y + + g = fixture() + self.assertEqual(next(g), 'A') # yield 'A'. + self.assertEqual(next(g), f) # First yield from f. + f.set_result(42) + self.assertEqual(next(g), ('B', 42)) # yield 'B', x. + # The second "yield from f" does not yield f. + self.assertEqual(next(g), ('C', 42)) # yield 'C', y. + + def test_future_repr(self): + self.loop.set_debug(True) + f_pending_debug = asyncio.Future(loop=self.loop) + frame = f_pending_debug._source_traceback[-1] + self.assertEqual(repr(f_pending_debug), + '' + % (frame[0], frame[1])) + f_pending_debug.cancel() + + self.loop.set_debug(False) + f_pending = asyncio.Future(loop=self.loop) + self.assertEqual(repr(f_pending), '') + f_pending.cancel() + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + self.assertEqual(repr(f_cancelled), '') + + f_result = asyncio.Future(loop=self.loop) + f_result.set_result(4) + self.assertEqual(repr(f_result), '') + self.assertEqual(f_result.result(), 4) + + exc = RuntimeError() + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(exc) + self.assertEqual(repr(f_exception), + '') + self.assertIs(f_exception.exception(), exc) + + def func_repr(func): + filename, lineno = test_utils.get_function_source(func) + text = '%s() at %s:%s' % (func.__qualname__, filename, lineno) + return re.escape(text) + + f_one_callbacks = asyncio.Future(loop=self.loop) + f_one_callbacks.add_done_callback(_fakefunc) + fake_repr = func_repr(_fakefunc) + self.assertRegex(repr(f_one_callbacks), + r'' % fake_repr) + f_one_callbacks.cancel() + self.assertEqual(repr(f_one_callbacks), + '') + + f_two_callbacks = asyncio.Future(loop=self.loop) + f_two_callbacks.add_done_callback(first_cb) + f_two_callbacks.add_done_callback(last_cb) + first_repr = func_repr(first_cb) + last_repr = func_repr(last_cb) + self.assertRegex(repr(f_two_callbacks), + r'' + % (first_repr, last_repr)) + + f_many_callbacks = asyncio.Future(loop=self.loop) + f_many_callbacks.add_done_callback(first_cb) + for i in range(8): + f_many_callbacks.add_done_callback(_fakefunc) + f_many_callbacks.add_done_callback(last_cb) + cb_regex = r'%s, <8 more>, %s' % (first_repr, last_repr) + self.assertRegex(repr(f_many_callbacks), + r'' % cb_regex) + f_many_callbacks.cancel() + self.assertEqual(repr(f_many_callbacks), + '') + + def test_copy_state(self): + # Test the internal _copy_state method since it's being directly + # invoked in other modules. + f = asyncio.Future(loop=self.loop) + f.set_result(10) + + newf = asyncio.Future(loop=self.loop) + newf._copy_state(f) + self.assertTrue(newf.done()) + self.assertEqual(newf.result(), 10) + + f_exception = asyncio.Future(loop=self.loop) + f_exception.set_exception(RuntimeError()) + + newf_exception = asyncio.Future(loop=self.loop) + newf_exception._copy_state(f_exception) + self.assertTrue(newf_exception.done()) + self.assertRaises(RuntimeError, newf_exception.result) + + f_cancelled = asyncio.Future(loop=self.loop) + f_cancelled.cancel() + + newf_cancelled = asyncio.Future(loop=self.loop) + newf_cancelled._copy_state(f_cancelled) + self.assertTrue(newf_cancelled.cancelled()) + + def test_iter(self): + fut = asyncio.Future(loop=self.loop) + + def coro(): + yield from fut + + def test(): + arg1, arg2 = coro() + + self.assertRaises(AssertionError, test) + fut.cancel() + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_abandoned(self, m_log): + fut = asyncio.Future(loop=self.loop) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + fut.result() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_unretrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + del fut + test_utils.run_briefly(self.loop) + self.assertTrue(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + fut.exception() + del fut + self.assertFalse(m_log.error.called) + + @mock.patch('asyncio.base_events.logger') + def test_tb_logger_exception_result_retrieved(self, m_log): + fut = asyncio.Future(loop=self.loop) + fut.set_exception(RuntimeError('boom')) + self.assertRaises(RuntimeError, fut.result) + del fut + self.assertFalse(m_log.error.called) + + def test_wrap_future(self): + + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1, loop=self.loop) + res, ident = self.loop.run_until_complete(f2) + self.assertIsInstance(f2, asyncio.Future) + self.assertEqual(res, 'oi') + self.assertNotEqual(ident, threading.get_ident()) + + def test_wrap_future_future(self): + f1 = asyncio.Future(loop=self.loop) + f2 = asyncio.wrap_future(f1) + self.assertIs(f1, f2) + + @mock.patch('asyncio.futures.events') + def test_wrap_future_use_global_loop(self, m_events): + def run(arg): + return (arg, threading.get_ident()) + ex = concurrent.futures.ThreadPoolExecutor(1) + f1 = ex.submit(run, 'oi') + f2 = asyncio.wrap_future(f1) + self.assertIs(m_events.get_event_loop.return_value, f2._loop) + + def test_wrap_future_cancel(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(f1.cancelled()) + self.assertTrue(f2.cancelled()) + + def test_wrap_future_cancel2(self): + f1 = concurrent.futures.Future() + f2 = asyncio.wrap_future(f1, loop=self.loop) + f1.set_result(42) + f2.cancel() + test_utils.run_briefly(self.loop) + self.assertFalse(f1.cancelled()) + self.assertEqual(f1.result(), 42) + self.assertTrue(f2.cancelled()) + + def test_future_source_traceback(self): + self.loop.set_debug(True) + + future = asyncio.Future(loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(future._source_traceback, list) + self.assertEqual(future._source_traceback[-1][:3], + (__file__, + lineno, + 'test_future_source_traceback')) + + @mock.patch('asyncio.base_events.logger') + def check_future_exception_never_retrieved(self, debug, m_log): + self.loop.set_debug(debug) + + def memory_error(): + try: + raise MemoryError() + except BaseException as exc: + return exc + exc = memory_error() + + future = asyncio.Future(loop=self.loop) + if debug: + source_traceback = future._source_traceback + future.set_exception(exc) + future = None + test_utils.run_briefly(self.loop) + support.gc_collect() + + if sys.version_info >= (3, 4): + if debug: + frame = source_traceback[-1] + regex = (r'^Future exception was never retrieved\n' + r'future: \n' + r'source_traceback: Object ' + r'created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, ' + r'in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)$' + ).format(filename=re.escape(frame[0]), + lineno=frame[1]) + else: + regex = (r'^Future exception was never retrieved\n' + r'future: ' + r'$' + ) + exc_info = (type(exc), exc, exc.__traceback__) + m_log.error.assert_called_once_with(mock.ANY, exc_info=exc_info) + else: + if debug: + frame = source_traceback[-1] + regex = (r'^Future/Task exception was never retrieved\n' + r'Future/Task created at \(most recent call last\):\n' + r' File' + r'.*\n' + r' File "{filename}", line {lineno}, ' + r'in check_future_exception_never_retrieved\n' + r' future = asyncio\.Future\(loop=self\.loop\)\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ).format(filename=re.escape(frame[0]), + lineno=frame[1]) + else: + regex = (r'^Future/Task exception was never retrieved\n' + r'Traceback \(most recent call last\):\n' + r'.*\n' + r'MemoryError$' + ) + m_log.error.assert_called_once_with(mock.ANY, exc_info=False) + message = m_log.error.call_args[0][0] + self.assertRegex(message, re.compile(regex, re.DOTALL)) + + def test_future_exception_never_retrieved(self): + self.check_future_exception_never_retrieved(False) + + def test_future_exception_never_retrieved_debug(self): + self.check_future_exception_never_retrieved(True) + + def test_set_result_unless_cancelled(self): + fut = asyncio.Future(loop=self.loop) + fut.cancel() + fut._set_result_unless_cancelled(2) + self.assertTrue(fut.cancelled()) + + +class FutureDoneCallbackTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def run_briefly(self): + test_utils.run_briefly(self.loop) + + def _make_callback(self, bag, thing): + # Create a callback function that appends thing to bag. + def bag_appender(future): + bag.append(thing) + return bag_appender + + def _new_future(self): + return asyncio.Future(loop=self.loop) + + def test_callbacks_invoked_on_set_result(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 42)) + f.add_done_callback(self._make_callback(bag, 17)) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [42, 17]) + self.assertEqual(f.result(), 'foo') + + def test_callbacks_invoked_on_set_exception(self): + bag = [] + f = self._new_future() + f.add_done_callback(self._make_callback(bag, 100)) + + self.assertEqual(bag, []) + exc = RuntimeError() + f.set_exception(exc) + + self.run_briefly() + + self.assertEqual(bag, [100]) + self.assertEqual(f.exception(), exc) + + def test_remove_done_callback(self): + bag = [] + f = self._new_future() + cb1 = self._make_callback(bag, 1) + cb2 = self._make_callback(bag, 2) + cb3 = self._make_callback(bag, 3) + + # Add one cb1 and one cb2. + f.add_done_callback(cb1) + f.add_done_callback(cb2) + + # One instance of cb2 removed. Now there's only one cb1. + self.assertEqual(f.remove_done_callback(cb2), 1) + + # Never had any cb3 in there. + self.assertEqual(f.remove_done_callback(cb3), 0) + + # After this there will be 6 instances of cb1 and one of cb2. + f.add_done_callback(cb2) + for i in range(5): + f.add_done_callback(cb1) + + # Remove all instances of cb1. One cb2 remains. + self.assertEqual(f.remove_done_callback(cb1), 6) + + self.assertEqual(bag, []) + f.set_result('foo') + + self.run_briefly() + + self.assertEqual(bag, [2]) + self.assertEqual(f.result(), 'foo') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_locks.py b/tests/test_locks.py new file mode 100644 index 00000000..dda4577a --- /dev/null +++ b/tests/test_locks.py @@ -0,0 +1,858 @@ +"""Tests for lock.py""" + +import unittest +from unittest import mock +import re + +import asyncio +from asyncio import test_utils + + +STR_RGX_REPR = ( + r'^<(?P.*?) object at (?P
.*?)' + r'\[(?P' + r'(set|unset|locked|unlocked)(,value:\d)?(,waiters:\d+)?' + r')\]>\Z' +) +RGX_REPR = re.compile(STR_RGX_REPR) + + +class LockTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + lock = asyncio.Lock(loop=loop) + self.assertIs(lock._loop, loop) + + lock = asyncio.Lock(loop=self.loop) + self.assertIs(lock._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + lock = asyncio.Lock() + self.assertIs(lock._loop, self.loop) + + def test_repr(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(repr(lock).endswith('[unlocked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + @asyncio.coroutine + def acquire_lock(): + yield from lock + + self.loop.run_until_complete(acquire_lock()) + self.assertTrue(repr(lock).endswith('[locked]>')) + self.assertTrue(RGX_REPR.match(repr(lock))) + + def test_lock(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_acquire(self): + lock = asyncio.Lock(loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + @asyncio.coroutine + def c1(result): + if (yield from lock.acquire()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + if (yield from lock.acquire()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + if (yield from lock.acquire()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + lock.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_acquire_cancel(self): + lock = asyncio.Lock(loop=self.loop) + self.assertTrue(self.loop.run_until_complete(lock.acquire())) + + task = asyncio.Task(lock.acquire(), loop=self.loop) + self.loop.call_soon(task.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, task) + self.assertFalse(lock._waiters) + + def test_cancel_race(self): + # Several tasks: + # - A acquires the lock + # - B is blocked in aqcuire() + # - C is blocked in aqcuire() + # + # Now, concurrently: + # - B is cancelled + # - A releases the lock + # + # If B's waiter is marked cancelled but not yet removed from + # _waiters, A's release() call will crash when trying to set + # B's waiter; instead, it should move on to C's waiter. + + # Setup: A has the lock, b and c are waiting. + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def lockit(name, blocker): + yield from lock.acquire() + try: + if blocker is not None: + yield from blocker + finally: + lock.release() + + fa = asyncio.Future(loop=self.loop) + ta = asyncio.Task(lockit('A', fa), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(lock.locked()) + tb = asyncio.Task(lockit('B', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 1) + tc = asyncio.Task(lockit('C', None), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual(len(lock._waiters), 2) + + # Create the race and check. + # Without the fix this failed at the last assert. + fa.set_result(None) + tb.cancel() + self.assertTrue(lock._waiters[0].cancelled()) + test_utils.run_briefly(self.loop) + self.assertFalse(lock.locked()) + self.assertTrue(ta.done()) + self.assertTrue(tb.cancelled()) + self.assertTrue(tc.done()) + + def test_release_not_acquired(self): + lock = asyncio.Lock(loop=self.loop) + + self.assertRaises(RuntimeError, lock.release) + + def test_release_no_waiters(self): + lock = asyncio.Lock(loop=self.loop) + self.loop.run_until_complete(lock.acquire()) + self.assertTrue(lock.locked()) + + lock.release() + self.assertFalse(lock.locked()) + + def test_context_manager(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + def test_context_manager_cant_reuse(self): + lock = asyncio.Lock(loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from lock) + + # This spells "yield from lock" outside a generator. + cm = self.loop.run_until_complete(acquire_lock()) + with cm: + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + with self.assertRaises(AttributeError): + with cm: + pass + + def test_context_manager_no_yield(self): + lock = asyncio.Lock(loop=self.loop) + + try: + with lock: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(lock.locked()) + + +class EventTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + ev = asyncio.Event(loop=loop) + self.assertIs(ev._loop, loop) + + ev = asyncio.Event(loop=self.loop) + self.assertIs(ev._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + ev = asyncio.Event() + self.assertIs(ev._loop, self.loop) + + def test_repr(self): + ev = asyncio.Event(loop=self.loop) + self.assertTrue(repr(ev).endswith('[unset]>')) + match = RGX_REPR.match(repr(ev)) + self.assertEqual(match.group('extras'), 'unset') + + ev.set() + self.assertTrue(repr(ev).endswith('[set]>')) + self.assertTrue(RGX_REPR.match(repr(ev))) + + ev._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(ev)) + self.assertTrue(RGX_REPR.match(repr(ev))) + + def test_wait(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + + @asyncio.coroutine + def c2(result): + if (yield from ev.wait()): + result.append(2) + + @asyncio.coroutine + def c3(result): + if (yield from ev.wait()): + result.append(3) + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + t3 = asyncio.Task(c3(result), loop=self.loop) + + ev.set() + test_utils.run_briefly(self.loop) + self.assertEqual([3, 1, 2], result) + + self.assertTrue(t1.done()) + self.assertIsNone(t1.result()) + self.assertTrue(t2.done()) + self.assertIsNone(t2.result()) + self.assertTrue(t3.done()) + self.assertIsNone(t3.result()) + + def test_wait_on_set(self): + ev = asyncio.Event(loop=self.loop) + ev.set() + + res = self.loop.run_until_complete(ev.wait()) + self.assertTrue(res) + + def test_wait_cancel(self): + ev = asyncio.Event(loop=self.loop) + + wait = asyncio.Task(ev.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(ev._waiters) + + def test_clear(self): + ev = asyncio.Event(loop=self.loop) + self.assertFalse(ev.is_set()) + + ev.set() + self.assertTrue(ev.is_set()) + + ev.clear() + self.assertFalse(ev.is_set()) + + def test_clear_with_waiters(self): + ev = asyncio.Event(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + if (yield from ev.wait()): + result.append(1) + return True + + t = asyncio.Task(c1(result), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + ev.set() + ev.clear() + self.assertFalse(ev.is_set()) + + ev.set() + ev.set() + self.assertEqual(1, len(ev._waiters)) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertEqual(0, len(ev._waiters)) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + +class ConditionTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + cond = asyncio.Condition(loop=loop) + self.assertIs(cond._loop, loop) + + cond = asyncio.Condition(loop=self.loop) + self.assertIs(cond._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + cond = asyncio.Condition() + self.assertIs(cond._loop, self.loop) + + def test_wait(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(self.loop.run_until_complete(cond.acquire())) + cond.notify() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_wait_cancel(self): + cond = asyncio.Condition(loop=self.loop) + self.loop.run_until_complete(cond.acquire()) + + wait = asyncio.Task(cond.wait(), loop=self.loop) + self.loop.call_soon(wait.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, wait) + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + def test_wait_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, cond.wait()) + + def test_wait_for(self): + cond = asyncio.Condition(loop=self.loop) + presult = False + + def predicate(): + return presult + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait_for(predicate)): + result.append(1) + cond.release() + return True + + t = asyncio.Task(c1(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + presult = True + self.loop.run_until_complete(cond.acquire()) + cond.notify() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + def test_wait_for_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + + # predicate can return true immediately + res = self.loop.run_until_complete(cond.wait_for(lambda: [1, 2, 3])) + self.assertEqual([1, 2, 3], res) + + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, + cond.wait_for(lambda: False)) + + def test_notify(self): + cond = asyncio.Condition(loop=self.loop) + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + @asyncio.coroutine + def c3(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(3) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify(1) + cond.notify(2048) + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + def test_notify_all(self): + cond = asyncio.Condition(loop=self.loop) + + result = [] + + @asyncio.coroutine + def c1(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(1) + cond.release() + return True + + @asyncio.coroutine + def c2(result): + yield from cond.acquire() + if (yield from cond.wait()): + result.append(2) + cond.release() + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([], result) + + self.loop.run_until_complete(cond.acquire()) + cond.notify_all() + cond.release() + test_utils.run_briefly(self.loop) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + def test_notify_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify) + + def test_notify_all_unacquired(self): + cond = asyncio.Condition(loop=self.loop) + self.assertRaises(RuntimeError, cond.notify_all) + + def test_repr(self): + cond = asyncio.Condition(loop=self.loop) + self.assertTrue('unlocked' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + self.loop.run_until_complete(cond.acquire()) + self.assertTrue('locked' in repr(cond)) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + cond._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(cond)) + self.assertTrue(RGX_REPR.match(repr(cond))) + + def test_context_manager(self): + cond = asyncio.Condition(loop=self.loop) + + @asyncio.coroutine + def acquire_cond(): + return (yield from cond) + + with self.loop.run_until_complete(acquire_cond()): + self.assertTrue(cond.locked()) + + self.assertFalse(cond.locked()) + + def test_context_manager_no_yield(self): + cond = asyncio.Condition(loop=self.loop) + + try: + with cond: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertFalse(cond.locked()) + + def test_explicit_lock(self): + lock = asyncio.Lock(loop=self.loop) + cond = asyncio.Condition(lock, loop=self.loop) + + self.assertIs(cond._lock, lock) + self.assertIs(cond._loop, lock._loop) + + def test_ambiguous_loops(self): + loop = self.new_test_loop() + self.addCleanup(loop.close) + + lock = asyncio.Lock(loop=self.loop) + with self.assertRaises(ValueError): + asyncio.Condition(lock, loop=loop) + + +class SemaphoreTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_ctor_loop(self): + loop = mock.Mock() + sem = asyncio.Semaphore(loop=loop) + self.assertIs(sem._loop, loop) + + sem = asyncio.Semaphore(loop=self.loop) + self.assertIs(sem._loop, self.loop) + + def test_ctor_noloop(self): + asyncio.set_event_loop(self.loop) + sem = asyncio.Semaphore() + self.assertIs(sem._loop, self.loop) + + def test_initial_value_zero(self): + sem = asyncio.Semaphore(0, loop=self.loop) + self.assertTrue(sem.locked()) + + def test_repr(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertTrue(repr(sem).endswith('[unlocked,value:1]>')) + self.assertTrue(RGX_REPR.match(repr(sem))) + + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(repr(sem).endswith('[locked]>')) + self.assertTrue('waiters' not in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:1' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + sem._waiters.append(mock.Mock()) + self.assertTrue('waiters:2' in repr(sem)) + self.assertTrue(RGX_REPR.match(repr(sem))) + + def test_semaphore(self): + sem = asyncio.Semaphore(loop=self.loop) + self.assertEqual(1, sem._value) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + res = self.loop.run_until_complete(acquire_lock()) + + self.assertTrue(res) + self.assertTrue(sem.locked()) + self.assertEqual(0, sem._value) + + sem.release() + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + def test_semaphore_value(self): + self.assertRaises(ValueError, asyncio.Semaphore, -1) + + def test_acquire(self): + sem = asyncio.Semaphore(3, loop=self.loop) + result = [] + + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertTrue(self.loop.run_until_complete(sem.acquire())) + self.assertFalse(sem.locked()) + + @asyncio.coroutine + def c1(result): + yield from sem.acquire() + result.append(1) + return True + + @asyncio.coroutine + def c2(result): + yield from sem.acquire() + result.append(2) + return True + + @asyncio.coroutine + def c3(result): + yield from sem.acquire() + result.append(3) + return True + + @asyncio.coroutine + def c4(result): + yield from sem.acquire() + result.append(4) + return True + + t1 = asyncio.Task(c1(result), loop=self.loop) + t2 = asyncio.Task(c2(result), loop=self.loop) + t3 = asyncio.Task(c3(result), loop=self.loop) + + test_utils.run_briefly(self.loop) + self.assertEqual([1], result) + self.assertTrue(sem.locked()) + self.assertEqual(2, len(sem._waiters)) + self.assertEqual(0, sem._value) + + t4 = asyncio.Task(c4(result), loop=self.loop) + + sem.release() + sem.release() + self.assertEqual(2, sem._value) + + test_utils.run_briefly(self.loop) + self.assertEqual(0, sem._value) + self.assertEqual([1, 2, 3], result) + self.assertTrue(sem.locked()) + self.assertEqual(1, len(sem._waiters)) + self.assertEqual(0, sem._value) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + self.assertFalse(t4.done()) + + # cleanup locked semaphore + sem.release() + self.loop.run_until_complete(t4) + + def test_acquire_cancel(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + + acquire = asyncio.Task(sem.acquire(), loop=self.loop) + self.loop.call_soon(acquire.cancel) + self.assertRaises( + asyncio.CancelledError, + self.loop.run_until_complete, acquire) + self.assertFalse(sem._waiters) + + def test_release_not_acquired(self): + sem = asyncio.BoundedSemaphore(loop=self.loop) + + self.assertRaises(ValueError, sem.release) + + def test_release_no_waiters(self): + sem = asyncio.Semaphore(loop=self.loop) + self.loop.run_until_complete(sem.acquire()) + self.assertTrue(sem.locked()) + + sem.release() + self.assertFalse(sem.locked()) + + def test_context_manager(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + @asyncio.coroutine + def acquire_lock(): + return (yield from sem) + + with self.loop.run_until_complete(acquire_lock()): + self.assertFalse(sem.locked()) + self.assertEqual(1, sem._value) + + with self.loop.run_until_complete(acquire_lock()): + self.assertTrue(sem.locked()) + + self.assertEqual(2, sem._value) + + def test_context_manager_no_yield(self): + sem = asyncio.Semaphore(2, loop=self.loop) + + try: + with sem: + self.fail('RuntimeError is not raised in with expression') + except RuntimeError as err: + self.assertEqual( + str(err), + '"yield from" should be used as context manager expression') + + self.assertEqual(2, sem._value) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py new file mode 100644 index 00000000..33a8a671 --- /dev/null +++ b/tests/test_proactor_events.py @@ -0,0 +1,587 @@ +"""Tests for proactor_events.py""" + +import socket +import unittest +from unittest import mock + +import asyncio +from asyncio.proactor_events import BaseProactorEventLoop +from asyncio.proactor_events import _ProactorSocketTransport +from asyncio.proactor_events import _ProactorWritePipeTransport +from asyncio.proactor_events import _ProactorDuplexPipeTransport +from asyncio import test_utils + + +def close_transport(transport): + # Don't call transport.close() because the event loop and the IOCP proactor + # are mocked + if transport._sock is None: + return + transport._sock.close() + transport._sock = None + + +class ProactorSocketTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.addCleanup(self.loop.close) + self.proactor = mock.Mock() + self.loop._proactor = self.proactor + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.sock = mock.Mock(socket.socket) + + def socket_transport(self, waiter=None): + transport = _ProactorSocketTransport(self.loop, self.sock, + self.protocol, waiter=waiter) + self.addCleanup(close_transport, transport) + return transport + + def test_ctor(self): + fut = asyncio.Future(loop=self.loop) + tr = self.socket_transport(waiter=fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + self.protocol.connection_made(tr) + self.proactor.recv.assert_called_with(self.sock, 4096) + + def test_loop_reading(self): + tr = self.socket_transport() + tr._loop_reading() + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.assertFalse(self.protocol.data_received.called) + self.assertFalse(self.protocol.eof_received.called) + + def test_loop_reading_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = self.socket_transport() + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 4096) + self.protocol.data_received.assert_called_with(b'data') + + def test_loop_reading_no_data(self): + res = asyncio.Future(loop=self.loop) + res.set_result(b'') + + tr = self.socket_transport() + self.assertRaises(AssertionError, tr._loop_reading, res) + + tr.close = mock.Mock() + tr._read_fut = res + tr._loop_reading(res) + self.assertFalse(self.loop._proactor.recv.called) + self.assertTrue(self.protocol.eof_received.called) + self.assertTrue(tr.close.called) + + def test_loop_reading_aborted(self): + err = self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_loop_reading_aborted_closing(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + + tr = self.socket_transport() + tr._closing = True + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + + def test_loop_reading_aborted_is_fatal(self): + self.loop._proactor.recv.side_effect = ConnectionAbortedError() + tr = self.socket_transport() + tr._closing = False + tr._fatal_error = mock.Mock() + tr._loop_reading() + self.assertTrue(tr._fatal_error.called) + + def test_loop_reading_conn_reset_lost(self): + err = self.loop._proactor.recv.side_effect = ConnectionResetError() + + tr = self.socket_transport() + tr._closing = False + tr._fatal_error = mock.Mock() + tr._force_close = mock.Mock() + tr._loop_reading() + self.assertFalse(tr._fatal_error.called) + tr._force_close.assert_called_with(err) + + def test_loop_reading_exception(self): + err = self.loop._proactor.recv.side_effect = (OSError()) + + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._loop_reading() + tr._fatal_error.assert_called_with( + err, + 'Fatal read error on pipe transport') + + def test_write(self): + tr = self.socket_transport() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, None) + tr._loop_writing.assert_called_with(data=b'data') + + def test_write_no_data(self): + tr = self.socket_transport() + tr.write(b'') + self.assertFalse(tr._buffer) + + def test_write_more(self): + tr = self.socket_transport() + tr._write_fut = mock.Mock() + tr._loop_writing = mock.Mock() + tr.write(b'data') + self.assertEqual(tr._buffer, b'data') + self.assertFalse(tr._loop_writing.called) + + def test_loop_writing(self): + tr = self.socket_transport() + tr._buffer = bytearray(b'data') + tr._loop_writing() + self.loop._proactor.send.assert_called_with(self.sock, b'data') + self.loop._proactor.send.return_value.add_done_callback.\ + assert_called_with(tr._loop_writing) + + @mock.patch('asyncio.proactor_events.logger') + def test_loop_writing_err(self, m_log): + err = self.loop._proactor.send.side_effect = OSError() + tr = self.socket_transport() + tr._fatal_error = mock.Mock() + tr._buffer = [b'da', b'ta'] + tr._loop_writing() + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + tr._conn_lost = 1 + + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + self.assertEqual(tr._buffer, None) + m_log.warning.assert_called_with('socket.send() raised exception.') + + def test_loop_writing_stop(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(b'data') + + tr = self.socket_transport() + tr._write_fut = fut + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + + def test_loop_writing_closing(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(1) + + tr = self.socket_transport() + tr._write_fut = fut + tr.close() + tr._loop_writing(fut) + self.assertIsNone(tr._write_fut) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_abort(self): + tr = self.socket_transport() + tr._force_close = mock.Mock() + tr.abort() + tr._force_close.assert_called_with(None) + + def test_close(self): + tr = self.socket_transport() + tr.close() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertTrue(tr._closing) + self.assertEqual(tr._conn_lost, 1) + + self.protocol.connection_lost.reset_mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_write_fut(self): + tr = self.socket_transport() + tr._write_fut = mock.Mock() + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_close_buffer(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + tr.close() + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + @mock.patch('asyncio.base_events.logger') + def test_fatal_error(self, m_logging): + tr = self.socket_transport() + tr._force_close = mock.Mock() + tr._fatal_error(None) + self.assertTrue(tr._force_close.called) + self.assertTrue(m_logging.error.called) + + def test_force_close(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + read_fut = tr._read_fut = mock.Mock() + write_fut = tr._write_fut = mock.Mock() + tr._force_close(None) + + read_fut.cancel.assert_called_with() + write_fut.cancel.assert_called_with() + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + self.assertEqual(tr._conn_lost, 1) + + def test_force_close_idempotent(self): + tr = self.socket_transport() + tr._closing = True + tr._force_close(None) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.connection_lost.called) + + def test_fatal_error_2(self): + tr = self.socket_transport() + tr._buffer = [b'data'] + tr._force_close(None) + + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + self.assertEqual(None, tr._buffer) + + def test_call_connection_lost(self): + tr = self.socket_transport() + tr._call_connection_lost(None) + self.assertTrue(self.protocol.connection_lost.called) + self.assertTrue(self.sock.close.called) + + def test_write_eof(self): + tr = self.socket_transport() + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.write_eof() + self.assertEqual(self.sock.shutdown.call_count, 1) + tr.close() + + def test_write_eof_buffer(self): + tr = self.socket_transport() + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._eof_written) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.sock.shutdown.assert_called_with(socket.SHUT_WR) + tr.close() + + def test_write_eof_write_pipe(self): + tr = _ProactorWritePipeTransport( + self.loop, self.sock, self.protocol) + self.assertTrue(tr.can_write_eof()) + tr.write_eof() + self.assertTrue(tr._closing) + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_buffer_write_pipe(self): + tr = _ProactorWritePipeTransport(self.loop, self.sock, self.protocol) + f = asyncio.Future(loop=self.loop) + tr._loop._proactor.send.return_value = f + tr.write(b'data') + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.sock.shutdown.called) + tr._loop._proactor.send.assert_called_with(self.sock, b'data') + f.set_result(4) + self.loop._run_once() + self.loop._run_once() + self.assertTrue(self.sock.close.called) + tr.close() + + def test_write_eof_duplex_pipe(self): + tr = _ProactorDuplexPipeTransport( + self.loop, self.sock, self.protocol) + self.assertFalse(tr.can_write_eof()) + with self.assertRaises(NotImplementedError): + tr.write_eof() + close_transport(tr) + + def test_pause_resume_reading(self): + tr = self.socket_transport() + futures = [] + for msg in [b'data1', b'data2', b'data3', b'data4', b'']: + f = asyncio.Future(loop=self.loop) + f.set_result(msg) + futures.append(f) + self.loop._proactor.recv.side_effect = futures + self.loop._run_once() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data1') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.pause_reading() + self.assertTrue(tr._paused) + for i in range(10): + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data2') + tr.resume_reading() + self.assertFalse(tr._paused) + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data3') + self.loop._run_once() + self.protocol.data_received.assert_called_with(b'data4') + tr.close() + + + def pause_writing_transport(self, high): + tr = self.socket_transport() + tr.set_write_buffer_limits(high=high) + + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + self.assertFalse(self.protocol.resume_writing.called) + return tr + + def test_pause_resume_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk, must pause writing + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'large data') + self.loop._run_once() + self.assertTrue(self.protocol.pause_writing.called) + + # flush the buffer + fut.set_result(None) + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertTrue(self.protocol.resume_writing.called) + + def test_pause_writing_2write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (3 <= 4) + fut1 = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut1 + tr.write(b'123') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_pause_writing_3write(self): + tr = self.pause_writing_transport(high=4) + + # first short write, the buffer is not full (1 <= 4) + fut = asyncio.Future(loop=self.loop) + self.loop._proactor.send.return_value = fut + tr.write(b'1') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 1) + self.assertFalse(self.protocol.pause_writing.called) + + # second short write, the buffer is not full (3 <= 4) + tr.write(b'23') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 3) + self.assertFalse(self.protocol.pause_writing.called) + + # fill the buffer, must pause writing (6 > 4) + tr.write(b'abc') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 6) + self.assertTrue(self.protocol.pause_writing.called) + + def test_dont_pause_writing(self): + tr = self.pause_writing_transport(high=4) + + # write a large chunk which completes immedialty, + # it should not pause writing + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + self.loop._proactor.send.return_value = fut + tr.write(b'very large data') + self.loop._run_once() + self.assertEqual(tr.get_write_buffer_size(), 0) + self.assertFalse(self.protocol.pause_writing.called) + + +class BaseProactorEventLoopTests(test_utils.TestCase): + + def setUp(self): + self.sock = mock.Mock(socket.socket) + self.proactor = mock.Mock() + + self.ssock, self.csock = mock.Mock(), mock.Mock() + + class EventLoop(BaseProactorEventLoop): + def _socketpair(s): + return (self.ssock, self.csock) + + self.loop = EventLoop(self.proactor) + self.set_event_loop(self.loop) + + @mock.patch.object(BaseProactorEventLoop, 'call_soon') + @mock.patch.object(BaseProactorEventLoop, '_socketpair') + def test_ctor(self, socketpair, call_soon): + ssock, csock = socketpair.return_value = ( + mock.Mock(), mock.Mock()) + loop = BaseProactorEventLoop(self.proactor) + self.assertIs(loop._ssock, ssock) + self.assertIs(loop._csock, csock) + self.assertEqual(loop._internal_fds, 1) + call_soon.assert_called_with(loop._loop_self_reading) + loop.close() + + def test_close_self_pipe(self): + self.loop._close_self_pipe() + self.assertEqual(self.loop._internal_fds, 0) + self.assertTrue(self.ssock.close.called) + self.assertTrue(self.csock.close.called) + self.assertIsNone(self.loop._ssock) + self.assertIsNone(self.loop._csock) + + # Don't call close(): _close_self_pipe() cannot be called twice + self.loop._closed = True + + def test_close(self): + self.loop._close_self_pipe = mock.Mock() + self.loop.close() + self.assertTrue(self.loop._close_self_pipe.called) + self.assertTrue(self.proactor.close.called) + self.assertIsNone(self.loop._proactor) + + self.loop._close_self_pipe.reset_mock() + self.loop.close() + self.assertFalse(self.loop._close_self_pipe.called) + + def test_sock_recv(self): + self.loop.sock_recv(self.sock, 1024) + self.proactor.recv.assert_called_with(self.sock, 1024) + + def test_sock_sendall(self): + self.loop.sock_sendall(self.sock, b'data') + self.proactor.send.assert_called_with(self.sock, b'data') + + def test_sock_connect(self): + self.loop.sock_connect(self.sock, 123) + self.proactor.connect.assert_called_with(self.sock, 123) + + def test_sock_accept(self): + self.loop.sock_accept(self.sock) + self.proactor.accept.assert_called_with(self.sock) + + def test_socketpair(self): + self.assertRaises( + NotImplementedError, BaseProactorEventLoop, self.proactor) + + def test_make_socket_transport(self): + tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) + self.assertIsInstance(tr, _ProactorSocketTransport) + close_transport(tr) + + def test_loop_self_reading(self): + self.loop._loop_self_reading() + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_fut(self): + fut = mock.Mock() + self.loop._loop_self_reading(fut) + self.assertTrue(fut.result.called) + self.proactor.recv.assert_called_with(self.ssock, 4096) + self.proactor.recv.return_value.add_done_callback.assert_called_with( + self.loop._loop_self_reading) + + def test_loop_self_reading_exception(self): + self.loop.close = mock.Mock() + self.loop.call_exception_handler = mock.Mock() + self.proactor.recv.side_effect = OSError() + self.loop._loop_self_reading() + self.assertTrue(self.loop.call_exception_handler.called) + + def test_write_to_self(self): + self.loop._write_to_self() + self.csock.send.assert_called_with(b'\0') + + def test_process_events(self): + self.loop._process_events([]) + + @mock.patch('asyncio.base_events.logger') + def test_create_server(self, m_log): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + self.assertTrue(call_soon.called) + + # callback + loop = call_soon.call_args[0][0] + loop() + self.proactor.accept.assert_called_with(self.sock) + + # conn + fut = mock.Mock() + fut.result.return_value = (mock.Mock(), mock.Mock()) + + make_tr = self.loop._make_socket_transport = mock.Mock() + loop(fut) + self.assertTrue(fut.result.called) + self.assertTrue(make_tr.called) + + # exception + fut.result.side_effect = OSError() + loop(fut) + self.assertTrue(self.sock.close.called) + self.assertTrue(m_log.error.called) + + def test_create_server_cancel(self): + pf = mock.Mock() + call_soon = self.loop.call_soon = mock.Mock() + + self.loop._start_serving(pf, self.sock) + loop = call_soon.call_args[0][0] + + # cancelled + fut = asyncio.Future(loop=self.loop) + fut.cancel() + loop(fut) + self.assertTrue(self.sock.close.called) + + def test_stop_serving(self): + sock = mock.Mock() + self.loop._stop_serving(sock) + self.assertTrue(sock.close.called) + self.proactor._stop_serving.assert_called_with(sock) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_queues.py b/tests/test_queues.py new file mode 100644 index 00000000..3d4ac51d --- /dev/null +++ b/tests/test_queues.py @@ -0,0 +1,476 @@ +"""Tests for queues.py""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import test_utils + + +class _QueueTestBase(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + +class QueueBasicTests(_QueueTestBase): + + def _test_repr_or_str(self, fn, expect_id): + """Test Queue's repr or str. + + fn is repr or str. expect_id is True if we expect the Queue's id to + appear in fn(Queue()). + """ + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + q = asyncio.Queue(loop=loop) + self.assertTrue(fn(q).startswith('= (3, 4)) +PY35 = (sys.version_info >= (3, 5)) + + +@asyncio.coroutine +def coroutine_function(): + pass + + +def format_coroutine(qualname, state, src, source_traceback, generator=False): + if generator: + state = '%s' % state + else: + state = '%s, defined' % state + if source_traceback is not None: + frame = source_traceback[-1] + return ('coro=<%s() %s at %s> created at %s:%s' + % (qualname, state, src, frame[0], frame[1])) + else: + return 'coro=<%s() %s at %s>' % (qualname, state, src) + + +class Dummy: + + def __repr__(self): + return '' + + def __call__(self, *args): + pass + + +class TaskTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + + def test_task_class(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + t = asyncio.Task(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.run_until_complete(t) + loop.close() + + def test_async_coroutine(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t = asyncio.async(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t._loop, self.loop) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + t = asyncio.async(notmuch(), loop=loop) + self.assertIs(t._loop, loop) + loop.run_until_complete(t) + loop.close() + + def test_async_future(self): + f_orig = asyncio.Future(loop=self.loop) + f_orig.set_result('ko') + + f = asyncio.async(f_orig) + self.loop.run_until_complete(f) + self.assertTrue(f.done()) + self.assertEqual(f.result(), 'ko') + self.assertIs(f, f_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + f = asyncio.async(f_orig, loop=loop) + + loop.close() + + f = asyncio.async(f_orig, loop=self.loop) + self.assertIs(f, f_orig) + + def test_async_task(self): + @asyncio.coroutine + def notmuch(): + return 'ok' + t_orig = asyncio.Task(notmuch(), loop=self.loop) + t = asyncio.async(t_orig) + self.loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'ok') + self.assertIs(t, t_orig) + + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + with self.assertRaises(ValueError): + t = asyncio.async(t_orig, loop=loop) + + loop.close() + + t = asyncio.async(t_orig, loop=self.loop) + self.assertIs(t, t_orig) + + def test_async_neither(self): + with self.assertRaises(TypeError): + asyncio.async('ok') + + def test_task_repr(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def notmuch(): + yield from [] + return 'abc' + + # test coroutine function + self.assertEqual(notmuch.__name__, 'notmuch') + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr..notmuch') + self.assertEqual(notmuch.__module__, __name__) + + filename, lineno = test_utils.get_function_source(notmuch) + src = "%s:%s" % (filename, lineno) + + # test coroutine object + gen = notmuch() + if coroutines._DEBUG or PY35: + coro_qualname = 'TaskTests.test_task_repr..notmuch' + else: + coro_qualname = 'notmuch' + self.assertEqual(gen.__name__, 'notmuch') + if PY35: + self.assertEqual(gen.__qualname__, + coro_qualname) + + # test pending Task + t = asyncio.Task(gen, loop=self.loop) + t.add_done_callback(Dummy()) + + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, generator=True) + self.assertEqual(repr(t), + '()]>' % coro) + + # test cancelling Task + t.cancel() # Does not take immediate effect! + self.assertEqual(repr(t), + '()]>' % coro) + + # test cancelled Task + self.assertRaises(asyncio.CancelledError, + self.loop.run_until_complete, t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + '' % coro) + + # test finished Task + t = asyncio.Task(notmuch(), loop=self.loop) + self.loop.run_until_complete(t) + coro = format_coroutine(coro_qualname, 'done', src, + t._source_traceback) + self.assertEqual(repr(t), + "" % coro) + + def test_task_repr_coro_decorator(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def notmuch(): + # notmuch() function doesn't use yield from: it will be wrapped by + # @coroutine decorator + return 123 + + # test coroutine function + self.assertEqual(notmuch.__name__, 'notmuch') + if PY35: + self.assertEqual(notmuch.__qualname__, + 'TaskTests.test_task_repr_coro_decorator' + '..notmuch') + self.assertEqual(notmuch.__module__, __name__) + + # test coroutine object + gen = notmuch() + if coroutines._DEBUG or PY35: + # On Python >= 3.5, generators now inherit the name of the + # function, as expected, and have a qualified name (__qualname__ + # attribute). + coro_name = 'notmuch' + coro_qualname = ('TaskTests.test_task_repr_coro_decorator' + '..notmuch') + else: + # On Python < 3.5, generators inherit the name of the code, not of + # the function. See: http://bugs.python.org/issue21205 + coro_name = coro_qualname = 'coro' + self.assertEqual(gen.__name__, coro_name) + if PY35: + self.assertEqual(gen.__qualname__, coro_qualname) + + # test repr(CoroWrapper) + if coroutines._DEBUG: + # format the coroutine object + if coroutines._DEBUG: + filename, lineno = test_utils.get_function_source(notmuch) + frame = gen._source_traceback[-1] + coro = ('%s() running, defined at %s:%s, created at %s:%s' + % (coro_qualname, filename, lineno, + frame[0], frame[1])) + else: + code = gen.gi_code + coro = ('%s() running at %s:%s' + % (coro_qualname, code.co_filename, + code.co_firstlineno)) + + self.assertEqual(repr(gen), '' % coro) + + # test pending Task + t = asyncio.Task(gen, loop=self.loop) + t.add_done_callback(Dummy()) + + # format the coroutine object + if coroutines._DEBUG: + src = '%s:%s' % test_utils.get_function_source(notmuch) + else: + code = gen.gi_code + src = '%s:%s' % (code.co_filename, code.co_firstlineno) + coro = format_coroutine(coro_qualname, 'running', src, + t._source_traceback, + generator=not coroutines._DEBUG) + self.assertEqual(repr(t), + '()]>' % coro) + self.loop.run_until_complete(t) + + def test_task_repr_wait_for(self): + self.loop.set_debug(False) + + @asyncio.coroutine + def wait_for(fut): + return (yield from fut) + + fut = asyncio.Future(loop=self.loop) + task = asyncio.Task(wait_for(fut), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertRegex(repr(task), + '' % re.escape(repr(fut))) + + fut.set_result(None) + self.loop.run_until_complete(task) + + def test_task_basics(self): + @asyncio.coroutine + def outer(): + a = yield from inner1() + b = yield from inner2() + return a+b + + @asyncio.coroutine + def inner1(): + return 42 + + @asyncio.coroutine + def inner2(): + return 1000 + + t = outer() + self.assertEqual(self.loop.run_until_complete(t), 1042) + + def test_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def task(): + yield from asyncio.sleep(10.0, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + loop.call_soon(t.cancel) + with self.assertRaises(asyncio.CancelledError): + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_yield(self): + @asyncio.coroutine + def task(): + yield + yield + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start coro + t.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertTrue(t.cancelled()) + self.assertFalse(t.cancel()) + + def test_cancel_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) # start task + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_both_task_and_inner_future(self): + f = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from f + return 12 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + + f.cancel() + t.cancel() + + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(t) + + self.assertTrue(t.done()) + self.assertTrue(f.cancelled()) + self.assertTrue(t.cancelled()) + + def test_cancel_task_catching(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + return 42 + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(t.cancelled()) + + def test_cancel_task_ignoring(self): + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + fut3 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def task(): + yield from fut1 + try: + yield from fut2 + except asyncio.CancelledError: + pass + res = yield from fut3 + return res + + t = asyncio.Task(task(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut1) # White-box test. + fut1.set_result(None) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut2) # White-box test. + t.cancel() + self.assertTrue(fut2.cancelled()) + test_utils.run_briefly(self.loop) + self.assertIs(t._fut_waiter, fut3) # White-box test. + fut3.set_result(42) + res = self.loop.run_until_complete(t) + self.assertEqual(res, 42) + self.assertFalse(fut3.cancelled()) + self.assertFalse(t.cancelled()) + + def test_cancel_current_task(self): + loop = asyncio.new_event_loop() + self.set_event_loop(loop) + + @asyncio.coroutine + def task(): + t.cancel() + self.assertTrue(t._must_cancel) # White-box test. + # The sleep should be cancelled immediately. + yield from asyncio.sleep(100, loop=loop) + return 12 + + t = asyncio.Task(task(), loop=loop) + self.assertRaises( + asyncio.CancelledError, loop.run_until_complete, t) + self.assertTrue(t.done()) + self.assertFalse(t._must_cancel) # White-box test. + self.assertFalse(t.cancel()) + + def test_stop_while_run_in_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + self.assertAlmostEqual(0.2, when) + when = yield 0.1 + self.assertAlmostEqual(0.3, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + x = 0 + waiters = [] + + @asyncio.coroutine + def task(): + nonlocal x + while x < 10: + waiters.append(asyncio.sleep(0.1, loop=loop)) + yield from waiters[-1] + x += 1 + if x == 2: + loop.stop() + + t = asyncio.Task(task(), loop=loop) + with self.assertRaises(RuntimeError) as cm: + loop.run_until_complete(t) + self.assertEqual(str(cm.exception), + 'Event loop stopped before Future completed.') + self.assertFalse(t.done()) + self.assertEqual(x, 2) + self.assertAlmostEqual(0.3, loop.time()) + + # close generators + for w in waiters: + w.close() + t.cancel() + self.assertRaises(asyncio.CancelledError, loop.run_until_complete, t) + + def test_wait_for(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + when = yield 0.1 + + loop = self.new_test_loop(gen) + + foo_running = None + + @asyncio.coroutine + def foo(): + nonlocal foo_running + foo_running = True + try: + yield from asyncio.sleep(0.2, loop=loop) + finally: + foo_running = False + return 'done' + + fut = asyncio.Task(foo(), loop=loop) + + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop)) + self.assertTrue(fut.done()) + # it should have been cancelled due to the timeout + self.assertTrue(fut.cancelled()) + self.assertAlmostEqual(0.1, loop.time()) + self.assertEqual(foo_running, False) + + def test_wait_for_blocking(self): + loop = self.new_test_loop() + + @asyncio.coroutine + def coro(): + return 'done' + + res = loop.run_until_complete(asyncio.wait_for(coro(), + timeout=None, + loop=loop)) + self.assertEqual(res, 'done') + + def test_wait_for_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.2, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def foo(): + yield from asyncio.sleep(0.2, loop=loop) + return 'done' + + asyncio.set_event_loop(loop) + try: + fut = asyncio.Task(foo(), loop=loop) + with self.assertRaises(asyncio.TimeoutError): + loop.run_until_complete(asyncio.wait_for(fut, 0.01)) + finally: + asyncio.set_event_loop(None) + + self.assertAlmostEqual(0.01, loop.time()) + self.assertTrue(fut.done()) + self.assertTrue(fut.cancelled()) + + def test_wait_for_race_condition(self): + + def gen(): + yield 0.1 + yield 0.1 + yield 0.1 + + loop = self.new_test_loop(gen) + + fut = asyncio.Future(loop=loop) + task = asyncio.wait_for(fut, timeout=0.2, loop=loop) + loop.call_later(0.1, fut.set_result, "ok") + res = loop.run_until_complete(task) + self.assertEqual(res, "ok") + + def test_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(res, 42) + self.assertAlmostEqual(0.15, loop.time()) + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertEqual(res, 42) + + def test_wait_with_global_loop(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.01, when) + when = yield 0 + self.assertAlmostEqual(0.015, when) + yield 0.015 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.01, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.015, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a]) + self.assertEqual(done, set([a, b])) + self.assertEqual(pending, set()) + return 42 + + asyncio.set_event_loop(loop) + res = loop.run_until_complete( + asyncio.Task(foo(), loop=loop)) + + self.assertEqual(res, 42) + + def test_wait_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('test') + + task = asyncio.Task( + asyncio.wait([c, c, coro('spam')], loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + + self.assertFalse(pending) + self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + + def test_wait_errors(self): + self.assertRaises( + ValueError, self.loop.run_until_complete, + asyncio.wait(set(), loop=self.loop)) + + # -1 is an invalid return_when value + sleep_coro = asyncio.sleep(10.0, loop=self.loop) + wait_coro = asyncio.wait([sleep_coro], return_when=-1, loop=self.loop) + self.assertRaises(ValueError, + self.loop.run_until_complete, wait_coro) + + sleep_coro.close() + + def test_wait_first_completed(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertFalse(a.done()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_really_done(self): + # there is possibility that some tasks in the pending list + # became done but their callbacks haven't all been called yet + + @asyncio.coroutine + def coro1(): + yield + + @asyncio.coroutine + def coro2(): + yield + yield + + a = asyncio.Task(coro1(), loop=self.loop) + b = asyncio.Task(coro2(), loop=self.loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_COMPLETED, + loop=self.loop), + loop=self.loop) + + done, pending = self.loop.run_until_complete(task) + self.assertEqual({a, b}, done) + self.assertTrue(a.done()) + self.assertIsNone(a.result()) + self.assertTrue(b.done()) + self.assertIsNone(b.result()) + + def test_wait_first_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + # first_exception, task already has exception + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.Task( + asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop), + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_first_exception_in_wait(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + when = yield 0 + self.assertAlmostEqual(0.01, when) + yield 0.01 + + loop = self.new_test_loop(gen) + + # first_exception, exception during waiting + a = asyncio.Task(asyncio.sleep(10.0, loop=loop), loop=loop) + + @asyncio.coroutine + def exc(): + yield from asyncio.sleep(0.01, loop=loop) + raise ZeroDivisionError('err') + + b = asyncio.Task(exc(), loop=loop) + task = asyncio.wait([b, a], return_when=asyncio.FIRST_EXCEPTION, + loop=loop) + + done, pending = loop.run_until_complete(task) + self.assertEqual({b}, done) + self.assertEqual({a}, pending) + self.assertAlmostEqual(0.01, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_with_exception(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + yield 0.15 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(0.15, loop=loop) + raise ZeroDivisionError('really') + + b = asyncio.Task(sleeper(), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], loop=loop) + self.assertEqual(len(done), 2) + self.assertEqual(pending, set()) + errors = set(f for f in done if f.exception() is not None) + self.assertEqual(len(errors), 1) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_wait_with_timeout(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.11, when) + yield 0.11 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + @asyncio.coroutine + def foo(): + done, pending = yield from asyncio.wait([b, a], timeout=0.11, + loop=loop) + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.11, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_wait_concurrent_complete(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(0.15, when) + when = yield 0 + self.assertAlmostEqual(0.1, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.Task(asyncio.sleep(0.1, loop=loop), loop=loop) + b = asyncio.Task(asyncio.sleep(0.15, loop=loop), loop=loop) + + done, pending = loop.run_until_complete( + asyncio.wait([b, a], timeout=0.1, loop=loop)) + + self.assertEqual(done, set([a])) + self.assertEqual(pending, set([b])) + self.assertAlmostEqual(0.1, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed(self): + + def gen(): + yield 0 + yield 0 + yield 0.01 + yield 0 + + loop = self.new_test_loop(gen) + # disable "slow callback" warning + loop.slow_callback_duration = 1.0 + completed = set() + time_shifted = False + + @asyncio.coroutine + def sleeper(dt, x): + nonlocal time_shifted + yield from asyncio.sleep(dt, loop=loop) + completed.add(x) + if not time_shifted and 'a' in completed and 'b' in completed: + time_shifted = True + loop.advance_time(0.14) + return x + + a = sleeper(0.01, 'a') + b = sleeper(0.01, 'b') + c = sleeper(0.15, 'c') + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([b, c, a], loop=loop): + values.append((yield from f)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + self.assertTrue('a' in res[:2]) + self.assertTrue('b' in res[:2]) + self.assertEqual(res[2], 'c') + + # Doing it again should take no time and exercise a different path. + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertAlmostEqual(0.15, loop.time()) + + def test_as_completed_with_timeout(self): + + def gen(): + yield + yield 0 + yield 0 + yield 0.1 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.1, 'a', loop=loop) + b = asyncio.sleep(0.15, 'b', loop=loop) + + @asyncio.coroutine + def foo(): + values = [] + for f in asyncio.as_completed([a, b], timeout=0.12, loop=loop): + if values: + loop.advance_time(0.02) + try: + v = yield from f + values.append((1, v)) + except asyncio.TimeoutError as exc: + values.append((2, exc)) + return values + + res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + self.assertEqual(len(res), 2, res) + self.assertEqual(res[0], (1, 'a')) + self.assertEqual(res[1][0], 2) + self.assertIsInstance(res[1][1], asyncio.TimeoutError) + self.assertAlmostEqual(0.12, loop.time()) + + # move forward to close generator + loop.advance_time(10) + loop.run_until_complete(asyncio.wait([a, b], loop=loop)) + + def test_as_completed_with_unused_timeout(self): + + def gen(): + yield + yield 0 + yield 0.01 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.01, 'a', loop=loop) + + @asyncio.coroutine + def foo(): + for f in asyncio.as_completed([a], timeout=1, loop=loop): + v = yield from f + self.assertEqual(v, 'a') + + loop.run_until_complete(asyncio.Task(foo(), loop=loop)) + + def test_as_completed_reverse_wait(self): + + def gen(): + yield 0 + yield 0.05 + yield 0 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.10, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + + x = loop.run_until_complete(futs[1]) + self.assertEqual(x, 'a') + self.assertAlmostEqual(0.05, loop.time()) + loop.advance_time(0.05) + y = loop.run_until_complete(futs[0]) + self.assertEqual(y, 'b') + self.assertAlmostEqual(0.10, loop.time()) + + def test_as_completed_concurrent(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0 + self.assertAlmostEqual(0.05, when) + yield 0.05 + + loop = self.new_test_loop(gen) + + a = asyncio.sleep(0.05, 'a', loop=loop) + b = asyncio.sleep(0.05, 'b', loop=loop) + fs = {a, b} + futs = list(asyncio.as_completed(fs, loop=loop)) + self.assertEqual(len(futs), 2) + waiter = asyncio.wait(futs, loop=loop) + done, pending = loop.run_until_complete(waiter) + self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + + def test_as_completed_duplicate_coroutines(self): + + @asyncio.coroutine + def coro(s): + return s + + @asyncio.coroutine + def runner(): + result = [] + c = coro('ham') + for f in asyncio.as_completed([c, c, coro('spam')], + loop=self.loop): + result.append((yield from f)) + return result + + fut = asyncio.Task(runner(), loop=self.loop) + self.loop.run_until_complete(fut) + result = fut.result() + self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(len(result), 2) + + def test_sleep(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.05, when) + when = yield 0.05 + self.assertAlmostEqual(0.1, when) + yield 0.05 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleeper(dt, arg): + yield from asyncio.sleep(dt/2, loop=loop) + res = yield from asyncio.sleep(dt/2, arg, loop=loop) + return res + + t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop) + loop.run_until_complete(t) + self.assertTrue(t.done()) + self.assertEqual(t.result(), 'yeah') + self.assertAlmostEqual(0.1, loop.time()) + + def test_sleep_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), + loop=loop) + + handle = None + orig_call_later = loop.call_later + + def call_later(delay, callback, *args): + nonlocal handle + handle = orig_call_later(delay, callback, *args) + return handle + + loop.call_later = call_later + test_utils.run_briefly(loop) + + self.assertFalse(handle._cancelled) + + t.cancel() + test_utils.run_briefly(loop) + self.assertTrue(handle._cancelled) + + def test_task_cancel_sleeping_task(self): + + def gen(): + when = yield + self.assertAlmostEqual(0.1, when) + when = yield 0 + self.assertAlmostEqual(5000, when) + yield 0.1 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleep(dt): + yield from asyncio.sleep(dt, loop=loop) + + @asyncio.coroutine + def doit(): + sleeper = asyncio.Task(sleep(5000), loop=loop) + loop.call_later(0.1, sleeper.cancel) + try: + yield from sleeper + except asyncio.CancelledError: + return 'cancelled' + else: + return 'slept in' + + doer = doit() + self.assertEqual(loop.run_until_complete(doer), 'cancelled') + self.assertAlmostEqual(0.1, loop.time()) + + def test_task_cancel_waiter_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro(): + yield from fut + + task = asyncio.Task(coro(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertIs(task._fut_waiter, fut) + + task.cancel() + test_utils.run_briefly(self.loop) + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, task) + self.assertIsNone(task._fut_waiter) + self.assertTrue(fut.cancelled()) + + def test_step_in_completed_task(self): + @asyncio.coroutine + def notmuch(): + return 'ko' + + gen = notmuch() + task = asyncio.Task(gen, loop=self.loop) + task.set_result('ok') + + self.assertRaises(AssertionError, task._step) + gen.close() + + def test_step_result(self): + @asyncio.coroutine + def notmuch(): + yield None + yield 1 + return 'ko' + + self.assertRaises( + RuntimeError, self.loop.run_until_complete, notmuch()) + + def test_step_result_future(self): + # If coroutine returns future, task waits on this future. + + class Fut(asyncio.Future): + def __init__(self, *args, **kwds): + self.cb_added = False + super().__init__(*args, **kwds) + + def add_done_callback(self, fn): + self.cb_added = True + super().add_done_callback(fn) + + fut = Fut(loop=self.loop) + result = None + + @asyncio.coroutine + def wait_for_future(): + nonlocal result + result = yield from fut + + t = asyncio.Task(wait_for_future(), loop=self.loop) + test_utils.run_briefly(self.loop) + self.assertTrue(fut.cb_added) + + res = object() + fut.set_result(res) + test_utils.run_briefly(self.loop) + self.assertIs(res, result) + self.assertTrue(t.done()) + self.assertIsNone(t.result()) + + def test_step_with_baseexception(self): + @asyncio.coroutine + def notmutch(): + raise BaseException() + + task = asyncio.Task(notmutch(), loop=self.loop) + self.assertRaises(BaseException, task._step) + + self.assertTrue(task.done()) + self.assertIsInstance(task.exception(), BaseException) + + def test_baseexception_during_cancel(self): + + def gen(): + when = yield + self.assertAlmostEqual(10.0, when) + yield 0 + + loop = self.new_test_loop(gen) + + @asyncio.coroutine + def sleeper(): + yield from asyncio.sleep(10, loop=loop) + + base_exc = BaseException() + + @asyncio.coroutine + def notmutch(): + try: + yield from sleeper() + except asyncio.CancelledError: + raise base_exc + + task = asyncio.Task(notmutch(), loop=loop) + test_utils.run_briefly(loop) + + task.cancel() + self.assertFalse(task.done()) + + self.assertRaises(BaseException, test_utils.run_briefly, loop) + + self.assertTrue(task.done()) + self.assertFalse(task.cancelled()) + self.assertIs(task.exception(), base_exc) + + def test_iscoroutinefunction(self): + def fn(): + pass + + self.assertFalse(asyncio.iscoroutinefunction(fn)) + + def fn1(): + yield + self.assertFalse(asyncio.iscoroutinefunction(fn1)) + + @asyncio.coroutine + def fn2(): + yield + self.assertTrue(asyncio.iscoroutinefunction(fn2)) + + def test_yield_vs_yield_from(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def wait_for_future(): + yield fut + + task = wait_for_future() + with self.assertRaises(RuntimeError): + self.loop.run_until_complete(task) + + self.assertFalse(fut.done()) + + def test_yield_vs_yield_from_generator(self): + @asyncio.coroutine + def coro(): + yield + + @asyncio.coroutine + def wait_for_future(): + gen = coro() + try: + yield gen + finally: + gen.close() + + task = wait_for_future() + self.assertRaises( + RuntimeError, + self.loop.run_until_complete, task) + + def test_coroutine_non_gen_function(self): + @asyncio.coroutine + def func(): + return 'test' + + self.assertTrue(asyncio.iscoroutinefunction(func)) + + coro = func() + self.assertTrue(asyncio.iscoroutine(coro)) + + res = self.loop.run_until_complete(coro) + self.assertEqual(res, 'test') + + def test_coroutine_non_gen_function_return_future(self): + fut = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def func(): + return fut + + @asyncio.coroutine + def coro(): + fut.set_result('test') + + t1 = asyncio.Task(func(), loop=self.loop) + t2 = asyncio.Task(coro(), loop=self.loop) + res = self.loop.run_until_complete(t1) + self.assertEqual(res, 'test') + self.assertIsNone(t2.result()) + + def test_current_task(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + @asyncio.coroutine + def coro(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task) + + task = asyncio.Task(coro(self.loop), loop=self.loop) + self.loop.run_until_complete(task) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + def test_current_task_with_interleaving_tasks(self): + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + fut1 = asyncio.Future(loop=self.loop) + fut2 = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def coro1(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + yield from fut1 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) + fut2.set_result(True) + + @asyncio.coroutine + def coro2(loop): + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + fut1.set_result(True) + yield from fut2 + self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) + + task1 = asyncio.Task(coro1(self.loop), loop=self.loop) + task2 = asyncio.Task(coro2(self.loop), loop=self.loop) + + self.loop.run_until_complete(asyncio.wait((task1, task2), + loop=self.loop)) + self.assertIsNone(asyncio.Task.current_task(loop=self.loop)) + + # Some thorough tests for cancellation propagation through + # coroutines, tasks and wait(). + + def test_yield_future_passes_cancel(self): + # Cancelling outer() cancels inner() cancels waiter. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + try: + yield from waiter + except asyncio.CancelledError: + proof += 1 + raise + else: + self.fail('got past sleep() in inner()') + + @asyncio.coroutine + def outer(): + nonlocal proof + try: + yield from inner() + except asyncio.CancelledError: + proof += 100 # Expect this path. + else: + proof += 10 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.loop.run_until_complete(f) + self.assertEqual(proof, 101) + self.assertTrue(waiter.cancelled()) + + def test_yield_wait_does_not_shield_cancel(self): + # Cancelling outer() makes wait() return early, leaves inner() + # running. + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + d, p = yield from asyncio.wait([inner()], loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + self.assertRaises( + asyncio.CancelledError, self.loop.run_until_complete, f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_result(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + inner.set_result(42) + res = self.loop.run_until_complete(outer) + self.assertEqual(res, 42) + + def test_shield_exception(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + exc = RuntimeError('expected') + inner.set_exception(exc) + test_utils.run_briefly(self.loop) + self.assertIs(outer.exception(), exc) + + def test_shield_cancel(self): + inner = asyncio.Future(loop=self.loop) + outer = asyncio.shield(inner) + test_utils.run_briefly(self.loop) + inner.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + + def test_shield_shortcut(self): + fut = asyncio.Future(loop=self.loop) + fut.set_result(42) + res = self.loop.run_until_complete(asyncio.shield(fut)) + self.assertEqual(res, 42) + + def test_shield_effect(self): + # Cancelling outer() does not affect inner(). + proof = 0 + waiter = asyncio.Future(loop=self.loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + @asyncio.coroutine + def outer(): + nonlocal proof + yield from asyncio.shield(inner(), loop=self.loop) + proof += 100 + + f = asyncio.async(outer(), loop=self.loop) + test_utils.run_briefly(self.loop) + f.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(f) + waiter.set_result(None) + test_utils.run_briefly(self.loop) + self.assertEqual(proof, 1) + + def test_shield_gather(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + parent = asyncio.gather(child1, child2, loop=self.loop) + outer = asyncio.shield(parent, loop=self.loop) + test_utils.run_briefly(self.loop) + outer.cancel() + test_utils.run_briefly(self.loop) + self.assertTrue(outer.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + self.assertEqual(parent.result(), [1, 2]) + + def test_gather_shield(self): + child1 = asyncio.Future(loop=self.loop) + child2 = asyncio.Future(loop=self.loop) + inner1 = asyncio.shield(child1, loop=self.loop) + inner2 = asyncio.shield(child2, loop=self.loop) + parent = asyncio.gather(inner1, inner2, loop=self.loop) + test_utils.run_briefly(self.loop) + parent.cancel() + # This should cancel inner1 and inner2 but bot child1 and child2. + test_utils.run_briefly(self.loop) + self.assertIsInstance(parent.exception(), asyncio.CancelledError) + self.assertTrue(inner1.cancelled()) + self.assertTrue(inner2.cancelled()) + child1.set_result(1) + child2.set_result(2) + test_utils.run_briefly(self.loop) + + def test_as_completed_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # as_completed() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(fut, loop=self.loop)) + coro = coroutine_function() + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.as_completed(coro, loop=self.loop)) + coro.close() + + def test_wait_invalid_args(self): + fut = asyncio.Future(loop=self.loop) + + # wait() expects a list of futures, not a future instance + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(fut, loop=self.loop)) + coro = coroutine_function() + self.assertRaises(TypeError, self.loop.run_until_complete, + asyncio.wait(coro, loop=self.loop)) + coro.close() + + # wait() expects at least a future + self.assertRaises(ValueError, self.loop.run_until_complete, + asyncio.wait([], loop=self.loop)) + + def test_corowrapper_mocks_generator(self): + + def check(): + # A function that asserts various things. + # Called twice, with different debug flag values. + + @asyncio.coroutine + def coro(): + # The actual coroutine. + self.assertTrue(gen.gi_running) + yield from fut + + # A completed Future used to run the coroutine. + fut = asyncio.Future(loop=self.loop) + fut.set_result(None) + + # Call the coroutine. + gen = coro() + + # Check some properties. + self.assertTrue(asyncio.iscoroutine(gen)) + self.assertIsInstance(gen.gi_frame, types.FrameType) + self.assertFalse(gen.gi_running) + self.assertIsInstance(gen.gi_code, types.CodeType) + + # Run it. + self.loop.run_until_complete(gen) + + # The frame should have changed. + self.assertIsNone(gen.gi_frame) + + # Save debug flag. + old_debug = asyncio.coroutines._DEBUG + try: + # Test with debug flag cleared. + asyncio.coroutines._DEBUG = False + check() + + # Test with debug flag set. + asyncio.coroutines._DEBUG = True + check() + + finally: + # Restore original debug flag. + asyncio.coroutines._DEBUG = old_debug + + def test_yield_from_corowrapper(self): + old_debug = asyncio.coroutines._DEBUG + asyncio.coroutines._DEBUG = True + try: + @asyncio.coroutine + def t1(): + return (yield from t2()) + + @asyncio.coroutine + def t2(): + f = asyncio.Future(loop=self.loop) + asyncio.Task(t3(f), loop=self.loop) + return (yield from f) + + @asyncio.coroutine + def t3(f): + f.set_result((1, 2, 3)) + + task = asyncio.Task(t1(), loop=self.loop) + val = self.loop.run_until_complete(task) + self.assertEqual(val, (1, 2, 3)) + finally: + asyncio.coroutines._DEBUG = old_debug + + def test_yield_from_corowrapper_send(self): + def foo(): + a = yield + return a + + def call(arg): + cw = asyncio.coroutines.CoroWrapper(foo(), foo) + cw.send(None) + try: + cw.send(arg) + except StopIteration as ex: + return ex.args[0] + else: + raise AssertionError('StopIteration was expected') + + self.assertEqual(call((1, 2)), (1, 2)) + self.assertEqual(call('spam'), 'spam') + + def test_corowrapper_weakref(self): + wd = weakref.WeakValueDictionary() + def foo(): yield from [] + cw = asyncio.coroutines.CoroWrapper(foo(), foo) + wd['cw'] = cw # Would fail without __weakref__ slot. + cw.gen = None # Suppress warning from __del__. + + @unittest.skipUnless(PY34, + 'need python 3.4 or later') + def test_log_destroyed_pending_task(self): + @asyncio.coroutine + def kill_me(loop): + future = asyncio.Future(loop=loop) + yield from future + # at this point, the only reference to kill_me() task is + # the Task._wakeup() method in future._callbacks + raise Exception("code never reached") + + mock_handler = mock.Mock() + self.loop.set_debug(True) + self.loop.set_exception_handler(mock_handler) + + # schedule the task + coro = kill_me(self.loop) + task = asyncio.async(coro, loop=self.loop) + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task}) + + # execute the task so it waits for future + self.loop._run_once() + self.assertEqual(len(self.loop._ready), 0) + + # remove the future used in kill_me(), and references to the task + del coro.gi_frame.f_locals['future'] + coro = None + source_traceback = task._source_traceback + task = None + + # no more reference to kill_me() task: the task is destroyed by the GC + support.gc_collect() + + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), set()) + + mock_handler.assert_called_with(self.loop, { + 'message': 'Task was destroyed but it is pending!', + 'task': mock.ANY, + 'source_traceback': source_traceback, + }) + mock_handler.reset_mock() + + @mock.patch('asyncio.coroutines.logger') + def test_coroutine_never_yielded(self, m_log): + debug = asyncio.coroutines._DEBUG + try: + asyncio.coroutines._DEBUG = True + @asyncio.coroutine + def coro_noop(): + pass + finally: + asyncio.coroutines._DEBUG = debug + + tb_filename = __file__ + tb_lineno = sys._getframe().f_lineno + 2 + # create a coroutine object but don't use it + coro_noop() + support.gc_collect() + + self.assertTrue(m_log.error.called) + message = m_log.error.call_args[0][0] + func_filename, func_lineno = test_utils.get_function_source(coro_noop) + regex = (r'^ ' + r'was never yielded from\n' + r'Coroutine object created at \(most recent call last\):\n' + r'.*\n' + r' File "%s", line %s, in test_coroutine_never_yielded\n' + r' coro_noop\(\)$' + % (re.escape(coro_noop.__qualname__), + re.escape(func_filename), func_lineno, + re.escape(tb_filename), tb_lineno)) + + self.assertRegex(message, re.compile(regex, re.DOTALL)) + + def test_task_source_traceback(self): + self.loop.set_debug(True) + + task = asyncio.Task(coroutine_function(), loop=self.loop) + lineno = sys._getframe().f_lineno - 1 + self.assertIsInstance(task._source_traceback, list) + self.assertEqual(task._source_traceback[-1][:3], + (__file__, + lineno, + 'test_task_source_traceback')) + self.loop.run_until_complete(task) + + def _test_cancel_wait_for(self, timeout): + loop = asyncio.new_event_loop() + self.addCleanup(loop.close) + + @asyncio.coroutine + def blocking_coroutine(): + fut = asyncio.Future(loop=loop) + # Block: fut result is never set + yield from fut + + task = loop.create_task(blocking_coroutine()) + + wait = loop.create_task(asyncio.wait_for(task, timeout, loop=loop)) + loop.call_soon(wait.cancel) + + self.assertRaises(asyncio.CancelledError, + loop.run_until_complete, wait) + + # Python issue #23219: cancelling the wait must also cancel the task + self.assertTrue(task.cancelled()) + + def test_cancel_blocking_wait_for(self): + self._test_cancel_wait_for(None) + + def test_cancel_wait_for(self): + self._test_cancel_wait_for(60.0) + + +class GatherTestsBase: + + def setUp(self): + self.one_loop = self.new_test_loop() + self.other_loop = self.new_test_loop() + self.set_event_loop(self.one_loop, cleanup=False) + + def _run_loop(self, loop): + while loop._ready: + test_utils.run_briefly(loop) + + def _check_success(self, **kwargs): + a, b, c = [asyncio.Future(loop=self.one_loop) for i in range(3)] + fut = asyncio.gather(*self.wrap_futures(a, b, c), **kwargs) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + b.set_result(1) + a.set_result(2) + self._run_loop(self.one_loop) + self.assertEqual(cb.called, False) + self.assertFalse(fut.done()) + c.set_result(3) + self._run_loop(self.one_loop) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [2, 1, 3]) + + def test_success(self): + self._check_success() + self._check_success(return_exceptions=False) + + def test_result_exception_success(self): + self._check_success(return_exceptions=True) + + def test_one_exception(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d, e)) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + a.set_result(1) + b.set_exception(exc) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertIs(fut.exception(), exc) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_return_exceptions(self): + a, b, c, d = [asyncio.Future(loop=self.one_loop) for i in range(4)] + fut = asyncio.gather(*self.wrap_futures(a, b, c, d), + return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + exc = ZeroDivisionError() + exc2 = RuntimeError() + b.set_result(1) + c.set_exception(exc) + a.set_result(3) + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_exception(exc2) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertEqual(fut.result(), [3, 1, exc, exc2]) + + def test_env_var_debug(self): + aio_path = os.path.dirname(os.path.dirname(asyncio.__file__)) + + code = '\n'.join(( + 'import asyncio.coroutines', + 'print(asyncio.coroutines._DEBUG)')) + + # Test with -E to not fail if the unit test was run with + # PYTHONASYNCIODEBUG set to a non-empty string + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + sts, stdout, stderr = assert_python_ok('-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'True') + + sts, stdout, stderr = assert_python_ok('-E', '-c', code, + PYTHONASYNCIODEBUG='1', + PYTHONPATH=aio_path) + self.assertEqual(stdout.rstrip(), b'False') + + +class FutureGatherTests(GatherTestsBase, test_utils.TestCase): + + def wrap_futures(self, *futures): + return futures + + def _check_empty_sequence(self, seq_or_iter): + asyncio.set_event_loop(self.one_loop) + self.addCleanup(asyncio.set_event_loop, None) + fut = asyncio.gather(*seq_or_iter) + self.assertIsInstance(fut, asyncio.Future) + self.assertIs(fut._loop, self.one_loop) + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + self.assertEqual(fut.result(), []) + fut = asyncio.gather(*seq_or_iter, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + + def test_constructor_empty_sequence(self): + self._check_empty_sequence([]) + self._check_empty_sequence(()) + self._check_empty_sequence(set()) + self._check_empty_sequence(iter("")) + + def test_constructor_heterogenous_futures(self): + fut1 = asyncio.Future(loop=self.one_loop) + fut2 = asyncio.Future(loop=self.other_loop) + with self.assertRaises(ValueError): + asyncio.gather(fut1, fut2) + with self.assertRaises(ValueError): + asyncio.gather(fut1, loop=self.other_loop) + + def test_constructor_homogenous_futures(self): + children = [asyncio.Future(loop=self.other_loop) for i in range(3)] + fut = asyncio.gather(*children) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + fut = asyncio.gather(*children, loop=self.other_loop) + self.assertIs(fut._loop, self.other_loop) + self._run_loop(self.other_loop) + self.assertFalse(fut.done()) + + def test_one_cancellation(self): + a, b, c, d, e = [asyncio.Future(loop=self.one_loop) for i in range(5)] + fut = asyncio.gather(a, b, c, d, e) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + b.cancel() + self._run_loop(self.one_loop) + self.assertTrue(fut.done()) + cb.assert_called_once_with(fut) + self.assertFalse(fut.cancelled()) + self.assertIsInstance(fut.exception(), asyncio.CancelledError) + # Does nothing + c.set_result(3) + d.cancel() + e.set_exception(RuntimeError()) + e.exception() + + def test_result_exception_one_cancellation(self): + a, b, c, d, e, f = [asyncio.Future(loop=self.one_loop) + for i in range(6)] + fut = asyncio.gather(a, b, c, d, e, f, return_exceptions=True) + cb = test_utils.MockCallback() + fut.add_done_callback(cb) + a.set_result(1) + zde = ZeroDivisionError() + b.set_exception(zde) + c.cancel() + self._run_loop(self.one_loop) + self.assertFalse(fut.done()) + d.set_result(3) + e.cancel() + rte = RuntimeError() + f.set_exception(rte) + res = self.one_loop.run_until_complete(fut) + self.assertIsInstance(res[2], asyncio.CancelledError) + self.assertIsInstance(res[4], asyncio.CancelledError) + res[2] = res[4] = None + self.assertEqual(res, [1, zde, None, 3, None, rte]) + cb.assert_called_once_with(fut) + + +class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase): + + def setUp(self): + super().setUp() + asyncio.set_event_loop(self.one_loop) + + def wrap_futures(self, *futures): + coros = [] + for fut in futures: + @asyncio.coroutine + def coro(fut=fut): + return (yield from fut) + coros.append(coro()) + return coros + + def test_constructor_loop_selection(self): + @asyncio.coroutine + def coro(): + return 'abc' + gen1 = coro() + gen2 = coro() + fut = asyncio.gather(gen1, gen2) + self.assertIs(fut._loop, self.one_loop) + self.one_loop.run_until_complete(fut) + + self.set_event_loop(self.other_loop, cleanup=False) + gen3 = coro() + gen4 = coro() + fut2 = asyncio.gather(gen3, gen4, loop=self.other_loop) + self.assertIs(fut2._loop, self.other_loop) + self.other_loop.run_until_complete(fut2) + + def test_duplicate_coroutines(self): + @asyncio.coroutine + def coro(s): + return s + c = coro('abc') + fut = asyncio.gather(c, c, coro('def'), c, loop=self.one_loop) + self._run_loop(self.one_loop) + self.assertEqual(fut.result(), ['abc', 'abc', 'def', 'abc']) + + def test_cancellation_broadcast(self): + # Cancelling outer() cancels all children. + proof = 0 + waiter = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def inner(): + nonlocal proof + yield from waiter + proof += 1 + + child1 = asyncio.async(inner(), loop=self.one_loop) + child2 = asyncio.async(inner(), loop=self.one_loop) + gatherer = None + + @asyncio.coroutine + def outer(): + nonlocal proof, gatherer + gatherer = asyncio.gather(child1, child2, loop=self.one_loop) + yield from gatherer + proof += 100 + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + self.assertTrue(f.cancel()) + with self.assertRaises(asyncio.CancelledError): + self.one_loop.run_until_complete(f) + self.assertFalse(gatherer.cancel()) + self.assertTrue(waiter.cancelled()) + self.assertTrue(child1.cancelled()) + self.assertTrue(child2.cancelled()) + test_utils.run_briefly(self.one_loop) + self.assertEqual(proof, 0) + + def test_exception_marking(self): + # Test for the first line marked "Mark exception retrieved." + + @asyncio.coroutine + def inner(f): + yield from f + raise RuntimeError('should not be ignored') + + a = asyncio.Future(loop=self.one_loop) + b = asyncio.Future(loop=self.one_loop) + + @asyncio.coroutine + def outer(): + yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) + + f = asyncio.async(outer(), loop=self.one_loop) + test_utils.run_briefly(self.one_loop) + a.set_result(None) + test_utils.run_briefly(self.one_loop) + b.set_result(None) + test_utils.run_briefly(self.one_loop) + self.assertIsInstance(f.exception(), RuntimeError) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_transports.py b/tests/test_transports.py new file mode 100644 index 00000000..3b6e3d67 --- /dev/null +++ b/tests/test_transports.py @@ -0,0 +1,91 @@ +"""Tests for transports.py.""" + +import unittest +from unittest import mock + +import asyncio +from asyncio import transports + + +class TransportTests(unittest.TestCase): + + def test_ctor_extra_is_none(self): + transport = asyncio.Transport() + self.assertEqual(transport._extra, {}) + + def test_get_extra_info(self): + transport = asyncio.Transport({'extra': 'info'}) + self.assertEqual('info', transport.get_extra_info('extra')) + self.assertIsNone(transport.get_extra_info('unknown')) + + default = object() + self.assertIs(default, transport.get_extra_info('unknown', default)) + + def test_writelines(self): + transport = asyncio.Transport() + transport.write = mock.Mock() + + transport.writelines([b'line1', + bytearray(b'line2'), + memoryview(b'line3')]) + self.assertEqual(1, transport.write.call_count) + transport.write.assert_called_with(b'line1line2line3') + + def test_not_implemented(self): + transport = asyncio.Transport() + + self.assertRaises(NotImplementedError, + transport.set_write_buffer_limits) + self.assertRaises(NotImplementedError, transport.get_write_buffer_size) + self.assertRaises(NotImplementedError, transport.write, 'data') + self.assertRaises(NotImplementedError, transport.write_eof) + self.assertRaises(NotImplementedError, transport.can_write_eof) + self.assertRaises(NotImplementedError, transport.pause_reading) + self.assertRaises(NotImplementedError, transport.resume_reading) + self.assertRaises(NotImplementedError, transport.close) + self.assertRaises(NotImplementedError, transport.abort) + + def test_dgram_not_implemented(self): + transport = asyncio.DatagramTransport() + + self.assertRaises(NotImplementedError, transport.sendto, 'data') + self.assertRaises(NotImplementedError, transport.abort) + + def test_subprocess_transport_not_implemented(self): + transport = asyncio.SubprocessTransport() + + self.assertRaises(NotImplementedError, transport.get_pid) + self.assertRaises(NotImplementedError, transport.get_returncode) + self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1) + self.assertRaises(NotImplementedError, transport.send_signal, 1) + self.assertRaises(NotImplementedError, transport.terminate) + self.assertRaises(NotImplementedError, transport.kill) + + def test_flowcontrol_mixin_set_write_limits(self): + + class MyTransport(transports._FlowControlMixin, + transports.Transport): + + def get_write_buffer_size(self): + return 512 + + loop = mock.Mock() + transport = MyTransport(loop=loop) + transport._protocol = mock.Mock() + + self.assertFalse(transport._protocol_paused) + + with self.assertRaisesRegex(ValueError, 'high.*must be >= low'): + transport.set_write_buffer_limits(high=0, low=1) + + transport.set_write_buffer_limits(high=1024, low=128) + self.assertFalse(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 1024)) + + transport.set_write_buffer_limits(high=256, low=128) + self.assertTrue(transport._protocol_paused) + self.assertEqual(transport.get_write_buffer_limits(), (128, 256)) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py new file mode 100644 index 00000000..126196da --- /dev/null +++ b/tests/test_unix_events.py @@ -0,0 +1,1568 @@ +"""Tests for unix_events.py.""" + +import collections +import errno +import io +import os +import signal +import socket +import stat +import sys +import tempfile +import threading +import unittest +from unittest import mock + +if sys.platform == 'win32': + raise unittest.SkipTest('UNIX only') + + +import asyncio +from asyncio import log +from asyncio import test_utils +from asyncio import unix_events + + +MOCK_ANY = mock.ANY + + +def close_pipe_transport(transport): + # Don't call transport.close() because the event loop and the selector + # are mocked + if transport._pipe is None: + return + transport._pipe.close() + transport._pipe = None + + +@unittest.skipUnless(signal, 'Signals are not supported') +class SelectorEventLoopSignalTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + def test_check_signal(self): + self.assertRaises( + TypeError, self.loop._check_signal, '1') + self.assertRaises( + ValueError, self.loop._check_signal, signal.NSIG + 1) + + def test_handle_signal_no_handler(self): + self.loop._handle_signal(signal.NSIG + 1) + + def test_handle_signal_cancelled_handler(self): + h = asyncio.Handle(mock.Mock(), (), + loop=mock.Mock()) + h.cancel() + self.loop._signal_handlers[signal.NSIG + 1] = h + self.loop.remove_signal_handler = mock.Mock() + self.loop._handle_signal(signal.NSIG + 1) + self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_setup_error(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.set_wakeup_fd.side_effect = ValueError + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_coroutine_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + @asyncio.coroutine + def simple_coroutine(): + yield from [] + + # callback must not be a coroutine function + coro_func = simple_coroutine + coro_obj = coro_func() + self.addCleanup(coro_obj.close) + for func in (coro_func, coro_obj): + self.assertRaisesRegex( + TypeError, 'coroutines cannot be used with add_signal_handler', + self.loop.add_signal_handler, + signal.SIGINT, func) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + cb = lambda: True + self.loop.add_signal_handler(signal.SIGHUP, cb) + h = self.loop._signal_handlers.get(signal.SIGHUP) + self.assertIsInstance(h, asyncio.Handle) + self.assertEqual(h._callback, cb) + + @mock.patch('asyncio.unix_events.signal') + def test_add_signal_handler_install_error(self, m_signal): + m_signal.NSIG = signal.NSIG + + def set_wakeup_fd(fd): + if fd == -1: + raise ValueError() + m_signal.set_wakeup_fd = set_wakeup_fd + + class Err(OSError): + errno = errno.EFAULT + m_signal.signal.side_effect = Err + + self.assertRaises( + Err, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error2(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.loop._signal_handlers[signal.SIGHUP] = lambda: True + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(1, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_add_signal_handler_install_error3(self, m_logging, m_signal): + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + m_signal.NSIG = signal.NSIG + + self.assertRaises( + RuntimeError, + self.loop.add_signal_handler, + signal.SIGINT, lambda: True) + self.assertFalse(m_logging.info.called) + self.assertEqual(2, m_signal.set_wakeup_fd.call_count) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGHUP)) + self.assertTrue(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_2(self, m_signal): + m_signal.NSIG = signal.NSIG + m_signal.SIGINT = signal.SIGINT + + self.loop.add_signal_handler(signal.SIGINT, lambda: True) + self.loop._signal_handlers[signal.SIGHUP] = object() + m_signal.set_wakeup_fd.reset_mock() + + self.assertTrue( + self.loop.remove_signal_handler(signal.SIGINT)) + self.assertFalse(m_signal.set_wakeup_fd.called) + self.assertTrue(m_signal.signal.called) + self.assertEqual( + (signal.SIGINT, m_signal.default_int_handler), + m_signal.signal.call_args[0]) + + @mock.patch('asyncio.unix_events.signal') + @mock.patch('asyncio.base_events.logger') + def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.set_wakeup_fd.side_effect = ValueError + + self.loop.remove_signal_handler(signal.SIGHUP) + self.assertTrue(m_logging.info) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + m_signal.signal.side_effect = OSError + + self.assertRaises( + OSError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_remove_signal_handler_error2(self, m_signal): + m_signal.NSIG = signal.NSIG + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + + class Err(OSError): + errno = errno.EINVAL + m_signal.signal.side_effect = Err + + self.assertRaises( + RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) + + @mock.patch('asyncio.unix_events.signal') + def test_close(self, m_signal): + m_signal.NSIG = signal.NSIG + + self.loop.add_signal_handler(signal.SIGHUP, lambda: True) + self.loop.add_signal_handler(signal.SIGCHLD, lambda: True) + + self.assertEqual(len(self.loop._signal_handlers), 2) + + m_signal.set_wakeup_fd.reset_mock() + + self.loop.close() + + self.assertEqual(len(self.loop._signal_handlers), 0) + m_signal.set_wakeup_fd.assert_called_once_with(-1) + + +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') +class SelectorEventLoopUnixSocketTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.SelectorEventLoop() + self.set_event_loop(self.loop) + + def test_create_unix_server_existing_path_sock(self): + with test_utils.unix_socket_path() as path: + sock = socket.socket(socket.AF_UNIX) + sock.bind(path) + with sock: + coro = self.loop.create_unix_server(lambda: None, path) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_existing_path_nonsock(self): + with tempfile.NamedTemporaryFile() as file: + coro = self.loop.create_unix_server(lambda: None, file.name) + with self.assertRaisesRegex(OSError, + 'Address.*is already in use'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_ssl_bool(self): + coro = self.loop.create_unix_server(lambda: None, path='spam', + ssl=True) + with self.assertRaisesRegex(TypeError, + 'ssl argument must be an SSLContext'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_nopath_nosock(self): + coro = self.loop.create_unix_server(lambda: None, path=None) + with self.assertRaisesRegex(ValueError, + 'path was not specified, and no sock'): + self.loop.run_until_complete(coro) + + def test_create_unix_server_path_inetsock(self): + sock = socket.socket() + with sock: + coro = self.loop.create_unix_server(lambda: None, path=None, + sock=sock) + with self.assertRaisesRegex(ValueError, + 'A UNIX Domain Socket was expected'): + self.loop.run_until_complete(coro) + + @mock.patch('asyncio.unix_events.socket') + def test_create_unix_server_bind_error(self, m_socket): + # Ensure that the socket is closed on any bind error + sock = mock.Mock() + m_socket.socket.return_value = sock + + sock.bind.side_effect = OSError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(OSError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + sock.bind.side_effect = MemoryError + coro = self.loop.create_unix_server(lambda: None, path="/test") + with self.assertRaises(MemoryError): + self.loop.run_until_complete(coro) + self.assertTrue(sock.close.called) + + def test_create_unix_connection_path_sock(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', sock=object()) + with self.assertRaisesRegex(ValueError, 'path and sock can not be'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nopath_nosock(self): + coro = self.loop.create_unix_connection( + lambda: None, None) + with self.assertRaisesRegex(ValueError, + 'no path and sock were specified'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_nossl_serverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', server_hostname='spam') + with self.assertRaisesRegex(ValueError, + 'server_hostname is only meaningful'): + self.loop.run_until_complete(coro) + + def test_create_unix_connection_ssl_noserverhost(self): + coro = self.loop.create_unix_connection( + lambda: None, '/dev/null', ssl=True) + + with self.assertRaisesRegex( + ValueError, 'you have to pass server_hostname when using ssl'): + + self.loop.run_until_complete(coro) + + +class UnixReadPipeTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFIFO + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def read_pipe_transport(self, waiter=None): + transport = unix_events._UnixReadPipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + + def test_ctor(self): + tr = self.read_pipe_transport() + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + tr = self.read_pipe_transport(waiter=fut) + test_utils.run_briefly(self.loop) + self.assertIsNone(fut.result()) + + @mock.patch('os.read') + def test__read_ready(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'data' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.protocol.data_received.assert_called_with(b'data') + + @mock.patch('os.read') + def test__read_ready_eof(self, m_read): + tr = self.read_pipe_transport() + m_read.return_value = b'' + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.eof_received.assert_called_with() + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.read') + def test__read_ready_blocked(self, m_read): + tr = self.read_pipe_transport() + m_read.side_effect = BlockingIOError + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + test_utils.run_briefly(self.loop) + self.assertFalse(self.protocol.data_received.called) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.read') + def test__read_ready_error(self, m_read, m_logexc): + tr = self.read_pipe_transport() + err = OSError() + m_read.side_effect = err + tr._close = mock.Mock() + tr._read_ready() + + m_read.assert_called_with(5, tr.max_size) + tr._close.assert_called_with(err) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal read error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + + @mock.patch('os.read') + def test_pause_reading(self, m_read): + tr = self.read_pipe_transport() + m = mock.Mock() + self.loop.add_reader(5, m) + tr.pause_reading() + self.assertFalse(self.loop.readers) + + @mock.patch('os.read') + def test_resume_reading(self, m_read): + tr = self.read_pipe_transport() + tr.resume_reading() + self.loop.assert_reader(5, tr._read_ready) + + @mock.patch('os.read') + def test_close(self, m_read): + tr = self.read_pipe_transport() + tr._close = mock.Mock() + tr.close() + tr._close.assert_called_with(None) + + @mock.patch('os.read') + def test_close_already_closing(self, m_read): + tr = self.read_pipe_transport() + tr._closing = True + tr._close = mock.Mock() + tr.close() + self.assertFalse(tr._close.called) + + @mock.patch('os.read') + def test__close(self, m_read): + tr = self.read_pipe_transport() + err = object() + tr._close(err) + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + def test__call_connection_lost(self): + tr = self.read_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = self.read_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + +class UnixWritePipeTransportTests(test_utils.TestCase): + + def setUp(self): + self.loop = self.new_test_loop() + self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol) + self.pipe = mock.Mock(spec_set=io.RawIOBase) + self.pipe.fileno.return_value = 5 + + blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher.start() + self.addCleanup(blocking_patcher.stop) + + fstat_patcher = mock.patch('os.fstat') + m_fstat = fstat_patcher.start() + st = mock.Mock() + st.st_mode = stat.S_IFSOCK + m_fstat.return_value = st + self.addCleanup(fstat_patcher.stop) + + def write_pipe_transport(self, waiter=None): + transport = unix_events._UnixWritePipeTransport(self.loop, self.pipe, + self.protocol, + waiter=waiter) + self.addCleanup(close_pipe_transport, transport) + return transport + + def test_ctor(self): + tr = self.write_pipe_transport() + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.protocol.connection_made.assert_called_with(tr) + + def test_ctor_with_waiter(self): + fut = asyncio.Future(loop=self.loop) + tr = self.write_pipe_transport(waiter=fut) + self.loop.assert_reader(5, tr._read_ready) + test_utils.run_briefly(self.loop) + self.assertEqual(None, fut.result()) + + def test_can_write_eof(self): + tr = self.write_pipe_transport() + self.assertTrue(tr.can_write_eof()) + + @mock.patch('os.write') + def test_write(self, m_write): + tr = self.write_pipe_transport() + m_write.return_value = 4 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_no_data(self, m_write): + tr = self.write_pipe_transport() + tr.write(b'') + self.assertFalse(m_write.called) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test_write_partial(self, m_write): + tr = self.write_pipe_transport() + m_write.return_value = 2 + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'ta'], tr._buffer) + + @mock.patch('os.write') + def test_write_buffer(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'previous'] + tr.write(b'data') + self.assertFalse(m_write.called) + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'previous', b'data'], tr._buffer) + + @mock.patch('os.write') + def test_write_again(self, m_write): + tr = self.write_pipe_transport() + m_write.side_effect = BlockingIOError() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.unix_events.logger') + @mock.patch('os.write') + def test_write_err(self, m_write, m_log): + tr = self.write_pipe_transport() + err = OSError() + m_write.side_effect = err + tr._fatal_error = mock.Mock() + tr.write(b'data') + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + tr._fatal_error.assert_called_with( + err, + 'Fatal write error on pipe transport') + self.assertEqual(1, tr._conn_lost) + + tr.write(b'data') + self.assertEqual(2, tr._conn_lost) + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + tr.write(b'data') + # This is a bit overspecified. :-( + m_log.warning.assert_called_with( + 'pipe closed by peer or os.write(pipe, data) raised exception.') + tr.close() + + @mock.patch('os.write') + def test_write_close(self, m_write): + tr = self.write_pipe_transport() + tr._read_ready() # pipe was closed by peer + + tr.write(b'data') + self.assertEqual(tr._conn_lost, 1) + tr.write(b'data') + self.assertEqual(tr._conn_lost, 2) + + def test__read_ready(self): + tr = self.write_pipe_transport() + tr._read_ready() + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + @mock.patch('os.write') + def test__write_ready(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_partial(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 3 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'a'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_again(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = BlockingIOError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('os.write') + def test__write_ready_empty(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.return_value = 0 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.loop.assert_writer(5, tr._write_ready) + self.assertEqual([b'data'], tr._buffer) + + @mock.patch('asyncio.log.logger.error') + @mock.patch('os.write') + def test__write_ready_err(self, m_write, m_logexc): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._buffer = [b'da', b'ta'] + m_write.side_effect = err = OSError() + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + m_logexc.assert_called_with( + test_utils.MockPattern( + 'Fatal write error on pipe transport' + '\nprotocol:.*\ntransport:.*'), + exc_info=(OSError, MOCK_ANY, MOCK_ANY)) + self.assertEqual(1, tr._conn_lost) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(err) + + @mock.patch('os.write') + def test__write_ready_closing(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + tr._closing = True + tr._buffer = [b'da', b'ta'] + m_write.return_value = 4 + tr._write_ready() + m_write.assert_called_with(5, b'data') + self.assertFalse(self.loop.writers) + self.assertFalse(self.loop.readers) + self.assertEqual([], tr._buffer) + self.protocol.connection_lost.assert_called_with(None) + self.pipe.close.assert_called_with() + + @mock.patch('os.write') + def test_abort(self, m_write): + tr = self.write_pipe_transport() + self.loop.add_writer(5, tr._write_ready) + self.loop.add_reader(5, tr._read_ready) + tr._buffer = [b'da', b'ta'] + tr.abort() + self.assertFalse(m_write.called) + self.assertFalse(self.loop.readers) + self.assertFalse(self.loop.writers) + self.assertEqual([], tr._buffer) + self.assertTrue(tr._closing) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test__call_connection_lost(self): + tr = self.write_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = None + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test__call_connection_lost_with_err(self): + tr = self.write_pipe_transport() + self.assertIsNotNone(tr._protocol) + self.assertIsNotNone(tr._loop) + + err = OSError() + tr._call_connection_lost(err) + self.protocol.connection_lost.assert_called_with(err) + self.pipe.close.assert_called_with() + + self.assertIsNone(tr._protocol) + self.assertIsNone(tr._loop) + + def test_close(self): + tr = self.write_pipe_transport() + tr.write_eof = mock.Mock() + tr.close() + tr.write_eof.assert_called_with() + + # closing the transport twice must not fail + tr.close() + + def test_close_closing(self): + tr = self.write_pipe_transport() + tr.write_eof = mock.Mock() + tr._closing = True + tr.close() + self.assertFalse(tr.write_eof.called) + + def test_write_eof(self): + tr = self.write_pipe_transport() + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.loop.readers) + test_utils.run_briefly(self.loop) + self.protocol.connection_lost.assert_called_with(None) + + def test_write_eof_pending(self): + tr = self.write_pipe_transport() + tr._buffer = [b'data'] + tr.write_eof() + self.assertTrue(tr._closing) + self.assertFalse(self.protocol.connection_lost.called) + + +class AbstractChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = asyncio.AbstractChildWatcher() + self.assertRaises( + NotImplementedError, watcher.add_child_handler, f, f) + self.assertRaises( + NotImplementedError, watcher.remove_child_handler, f) + self.assertRaises( + NotImplementedError, watcher.attach_loop, f) + self.assertRaises( + NotImplementedError, watcher.close) + self.assertRaises( + NotImplementedError, watcher.__enter__) + self.assertRaises( + NotImplementedError, watcher.__exit__, f, f, f) + + +class BaseChildWatcherTests(unittest.TestCase): + + def test_not_implemented(self): + f = mock.Mock() + watcher = unix_events.BaseChildWatcher() + self.assertRaises( + NotImplementedError, watcher._do_waitpid, f) + + +WaitPidMocks = collections.namedtuple("WaitPidMocks", + ("waitpid", + "WIFEXITED", + "WIFSIGNALED", + "WEXITSTATUS", + "WTERMSIG", + )) + + +class ChildWatcherTestsMixin: + + ignore_warnings = mock.patch.object(log.logger, "warning") + + def setUp(self): + self.loop = self.new_test_loop() + self.running = False + self.zombies = {} + + with mock.patch.object( + self.loop, "add_signal_handler") as self.m_add_signal_handler: + self.watcher = self.create_watcher() + self.watcher.attach_loop(self.loop) + + def waitpid(self, pid, flags): + if isinstance(self.watcher, asyncio.SafeChildWatcher) or pid != -1: + self.assertGreater(pid, 0) + try: + if pid < 0: + return self.zombies.popitem() + else: + return pid, self.zombies.pop(pid) + except KeyError: + pass + if self.running: + return 0, 0 + else: + raise ChildProcessError() + + def add_zombie(self, pid, returncode): + self.zombies[pid] = returncode + 32768 + + def WIFEXITED(self, status): + return status >= 32768 + + def WIFSIGNALED(self, status): + return 32700 < status < 32768 + + def WEXITSTATUS(self, status): + self.assertTrue(self.WIFEXITED(status)) + return status - 32768 + + def WTERMSIG(self, status): + self.assertTrue(self.WIFSIGNALED(status)) + return 32768 - status + + def test_create_watcher(self): + self.m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + def waitpid_mocks(func): + def wrapped_func(self): + def patch(target, wrapper): + return mock.patch(target, wraps=wrapper, + new_callable=mock.Mock) + + with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ + patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ + patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \ + patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \ + patch('os.waitpid', self.waitpid) as m_waitpid: + func(self, WaitPidMocks(m_waitpid, + m_WIFEXITED, m_WIFSIGNALED, + m_WEXITSTATUS, m_WTERMSIG, + )) + return wrapped_func + + @waitpid_mocks + def test_sigchld(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(42, callback, 9, 10, 14) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child is running + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (returncode 12) + self.running = False + self.add_zombie(42, 12) + self.watcher._sig_chld() + + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + callback.assert_called_once_with(42, 12, 9, 10, 14) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(42, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(43, callback1, 7, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(44, callback2, 147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (signal 3) + self.add_zombie(43, -3) + self.watcher._sig_chld() + + callback1.assert_called_once_with(43, -3, 7, 8) + self.assertFalse(callback2.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback1.reset_mock() + + # child 2 still running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 2 terminates (code 108) + self.add_zombie(44, 108) + self.running = False + self.watcher._sig_chld() + + callback2.assert_called_once_with(44, 108, 147, 18) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(43, 14) + self.add_zombie(44, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WEXITSTATUS.reset_mock() + + # sigchld called again + self.zombies.clear() + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_two_children_terminating_together(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register child 1 + with self.watcher: + self.running = True + self.watcher.add_child_handler(45, callback1, 17, 8) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register child 2 + with self.watcher: + self.watcher.add_child_handler(46, callback2, 1147, 18) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # children are running + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child 1 terminates (code 78) + # child 2 terminates (signal 5) + self.add_zombie(45, 78) + self.add_zombie(46, -5) + self.running = False + self.watcher._sig_chld() + + callback1.assert_called_once_with(45, 78, 17, 8) + callback2.assert_called_once_with(46, -5, 1147, 18) + self.assertTrue(m.WIFSIGNALED.called) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + m.WEXITSTATUS.reset_mock() + callback1.reset_mock() + callback2.reset_mock() + + # ensure that the children are effectively reaped + self.add_zombie(45, 14) + self.add_zombie(46, 15) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_race_condition(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + # child terminates before being registered + self.add_zombie(50, 4) + self.watcher._sig_chld() + + self.watcher.add_child_handler(50, callback, 1, 12) + + callback.assert_called_once_with(50, 4, 1, 12) + callback.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(50, -1) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_replace_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(51, callback1, 19) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # register the same child again + with self.watcher: + self.watcher.add_child_handler(51, callback2, 21) + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (signal 8) + self.running = False + self.add_zombie(51, -8) + self.watcher._sig_chld() + + callback2.assert_called_once_with(51, -8, 21) + self.assertFalse(callback1.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertTrue(m.WTERMSIG.called) + + m.WIFSIGNALED.reset_mock() + m.WIFEXITED.reset_mock() + m.WTERMSIG.reset_mock() + callback2.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(51, 13) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(m.WTERMSIG.called) + + @waitpid_mocks + def test_sigchld_remove_handler(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(52, callback, 1984) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # unregister the child + self.watcher.remove_child_handler(52) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates (code 99) + self.running = False + self.add_zombie(52, 99) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_sigchld_unknown_status(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(53, callback, -19) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # terminate with unknown status + self.zombies[53] = 1178 + self.running = False + self.watcher._sig_chld() + + callback.assert_called_once_with(53, 1178, -19) + self.assertTrue(m.WIFEXITED.called) + self.assertTrue(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + callback.reset_mock() + m.WIFEXITED.reset_mock() + m.WIFSIGNALED.reset_mock() + + # ensure that the child is effectively reaped + self.add_zombie(53, 101) + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback.called) + + @waitpid_mocks + def test_remove_child_handler(self, m): + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + # register children + with self.watcher: + self.running = True + self.watcher.add_child_handler(54, callback1, 1) + self.watcher.add_child_handler(55, callback2, 2) + self.watcher.add_child_handler(56, callback3, 3) + + # remove child handler 1 + self.assertTrue(self.watcher.remove_child_handler(54)) + + # remove child handler 2 multiple times + self.assertTrue(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + self.assertFalse(self.watcher.remove_child_handler(55)) + + # all children terminate + self.add_zombie(54, 0) + self.add_zombie(55, 1) + self.add_zombie(56, 2) + self.running = False + with self.ignore_warnings: + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(56, 2, 3) + + @waitpid_mocks + def test_sigchld_unhandled_exception(self, m): + callback = mock.Mock() + + # register a child + with self.watcher: + self.running = True + self.watcher.add_child_handler(57, callback) + + # raise an exception + m.waitpid.side_effect = ValueError + + with mock.patch.object(log.logger, + 'error') as m_error: + + self.assertEqual(self.watcher._sig_chld(), None) + self.assertTrue(m_error.called) + + @waitpid_mocks + def test_sigchld_child_reaped_elsewhere(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(58, callback) + + self.assertFalse(callback.called) + self.assertFalse(m.WIFEXITED.called) + self.assertFalse(m.WIFSIGNALED.called) + self.assertFalse(m.WEXITSTATUS.called) + self.assertFalse(m.WTERMSIG.called) + + # child terminates + self.running = False + self.add_zombie(58, 4) + + # waitpid is called elsewhere + os.waitpid(58, os.WNOHANG) + + m.waitpid.reset_mock() + + # sigchld + with self.ignore_warnings: + self.watcher._sig_chld() + + if isinstance(self.watcher, asyncio.FastChildWatcher): + # here the FastChildWatche enters a deadlock + # (there is no way to prevent it) + self.assertFalse(callback.called) + else: + callback.assert_called_once_with(58, 255) + + @waitpid_mocks + def test_sigchld_unknown_pid_during_registration(self, m): + # register two children + callback1 = mock.Mock() + callback2 = mock.Mock() + + with self.ignore_warnings, self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, 7) + # an unknown child terminates + self.add_zombie(593, 17) + + self.watcher._sig_chld() + + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) + + callback1.assert_called_once_with(591, 7) + self.assertFalse(callback2.called) + + @waitpid_mocks + def test_set_loop(self, m): + # register a child + callback = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(60, callback) + + # attach a new loop + old_loop = self.loop + self.loop = self.new_test_loop() + patch = mock.patch.object + + with patch(old_loop, "remove_signal_handler") as m_old_remove, \ + patch(self.loop, "add_signal_handler") as m_new_add: + + self.watcher.attach_loop(self.loop) + + m_old_remove.assert_called_once_with( + signal.SIGCHLD) + m_new_add.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + + # child terminates + self.running = False + self.add_zombie(60, 9) + self.watcher._sig_chld() + + callback.assert_called_once_with(60, 9) + + @waitpid_mocks + def test_set_loop_race_condition(self, m): + # register 3 children + callback1 = mock.Mock() + callback2 = mock.Mock() + callback3 = mock.Mock() + + with self.watcher: + self.running = True + self.watcher.add_child_handler(61, callback1) + self.watcher.add_child_handler(62, callback2) + self.watcher.add_child_handler(622, callback3) + + # detach the loop + old_loop = self.loop + self.loop = None + + with mock.patch.object( + old_loop, "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.attach_loop(None) + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + + # child 1 & 2 terminate + self.add_zombie(61, 11) + self.add_zombie(62, -5) + + # SIGCHLD was not caught + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + self.assertFalse(callback3.called) + + # attach a new loop + self.loop = self.new_test_loop() + + with mock.patch.object( + self.loop, "add_signal_handler") as m_add_signal_handler: + + self.watcher.attach_loop(self.loop) + + m_add_signal_handler.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) + callback1.assert_called_once_with(61, 11) # race condition! + callback2.assert_called_once_with(62, -5) # race condition! + self.assertFalse(callback3.called) + + callback1.reset_mock() + callback2.reset_mock() + + # child 3 terminates + self.running = False + self.add_zombie(622, 19) + self.watcher._sig_chld() + + self.assertFalse(callback1.called) + self.assertFalse(callback2.called) + callback3.assert_called_once_with(622, 19) + + @waitpid_mocks + def test_close(self, m): + # register two children + callback1 = mock.Mock() + + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(63, 9) + # other child terminates + self.add_zombie(65, 18) + self.watcher._sig_chld() + + self.watcher.add_child_handler(63, callback1) + self.watcher.add_child_handler(64, callback1) + + self.assertEqual(len(self.watcher._callbacks), 1) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertEqual(len(self.watcher._zombies), 1) + + with mock.patch.object( + self.loop, + "remove_signal_handler") as m_remove_signal_handler: + + self.watcher.close() + + m_remove_signal_handler.assert_called_once_with( + signal.SIGCHLD) + self.assertFalse(self.watcher._callbacks) + if isinstance(self.watcher, asyncio.FastChildWatcher): + self.assertFalse(self.watcher._zombies) + + +class SafeChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + return asyncio.SafeChildWatcher() + + +class FastChildWatcherTests (ChildWatcherTestsMixin, test_utils.TestCase): + def create_watcher(self): + return asyncio.FastChildWatcher() + + +class PolicyTests(unittest.TestCase): + + def create_policy(self): + return asyncio.DefaultEventLoopPolicy() + + def test_get_child_watcher(self): + policy = self.create_policy() + self.assertIsNone(policy._watcher) + + watcher = policy.get_child_watcher() + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + + self.assertIs(policy._watcher, watcher) + + self.assertIs(watcher, policy.get_child_watcher()) + self.assertIsNone(watcher._loop) + + def test_get_child_watcher_after_set(self): + policy = self.create_policy() + watcher = asyncio.FastChildWatcher() + + policy.set_child_watcher(watcher) + self.assertIs(policy._watcher, watcher) + self.assertIs(watcher, policy.get_child_watcher()) + + def test_get_child_watcher_with_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + self.assertIsNone(policy._watcher) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIs(watcher._loop, loop) + + loop.close() + + def test_get_child_watcher_thread(self): + + def f(): + policy.set_event_loop(policy.new_event_loop()) + + self.assertIsInstance(policy.get_event_loop(), + asyncio.AbstractEventLoop) + watcher = policy.get_child_watcher() + + self.assertIsInstance(watcher, asyncio.SafeChildWatcher) + self.assertIsNone(watcher._loop) + + policy.get_event_loop().close() + + policy = self.create_policy() + + th = threading.Thread(target=f) + th.start() + th.join() + + def test_child_watcher_replace_mainloop_existing(self): + policy = self.create_policy() + loop = policy.get_event_loop() + + watcher = policy.get_child_watcher() + + self.assertIs(watcher._loop, loop) + + new_loop = policy.new_event_loop() + policy.set_event_loop(new_loop) + + self.assertIs(watcher._loop, new_loop) + + policy.set_event_loop(None) + + self.assertIs(watcher._loop, None) + + loop.close() + new_loop.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py new file mode 100644 index 00000000..f9b3dd15 --- /dev/null +++ b/tests/test_windows_events.py @@ -0,0 +1,146 @@ +import os +import sys +import unittest + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +import asyncio +from asyncio import _overlapped +from asyncio import test_utils +from asyncio import windows_events + + +class UpperProto(asyncio.Protocol): + def __init__(self): + self.buf = [] + + def connection_made(self, trans): + self.trans = trans + + def data_received(self, data): + self.buf.append(data) + if b'\n' in data: + self.trans.write(b''.join(self.buf).upper()) + self.trans.close() + + +class ProactorTests(test_utils.TestCase): + + def setUp(self): + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + + def test_close(self): + a, b = self.loop._socketpair() + trans = self.loop._make_socket_transport(a, asyncio.Protocol()) + f = asyncio.async(self.loop.sock_recv(b, 100)) + trans.close() + self.loop.run_until_complete(f) + self.assertEqual(f.result(), b'') + b.close() + + def test_double_bind(self): + ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid() + server1 = windows_events.PipeServer(ADDRESS) + with self.assertRaises(PermissionError): + windows_events.PipeServer(ADDRESS) + server1.close() + + def test_pipe(self): + res = self.loop.run_until_complete(self._test_pipe()) + self.assertEqual(res, 'done') + + def _test_pipe(self): + ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + [server] = yield from self.loop.start_serving_pipe( + UpperProto, ADDRESS) + self.assertIsInstance(server, windows_events.PipeServer) + + clients = [] + for i in range(5): + stream_reader = asyncio.StreamReader(loop=self.loop) + protocol = asyncio.StreamReaderProtocol(stream_reader, + loop=self.loop) + trans, proto = yield from self.loop.create_pipe_connection( + lambda: protocol, ADDRESS) + self.assertIsInstance(trans, asyncio.Transport) + self.assertEqual(protocol, proto) + clients.append((stream_reader, trans)) + + for i, (r, w) in enumerate(clients): + w.write('lower-{}\n'.format(i).encode()) + + for i, (r, w) in enumerate(clients): + response = yield from r.readline() + self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + w.close() + + server.close() + + with self.assertRaises(FileNotFoundError): + yield from self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS) + + return 'done' + + def test_wait_for_handle(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with 0.5s timeout; + # result should be False at timeout + fut = self.loop._proactor.wait_for_handle(event, 0.5) + start = self.loop.time() + done = self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + + self.assertEqual(done, False) + self.assertFalse(fut.result()) + self.assertTrue(0.48 < elapsed < 0.9, elapsed) + + _overlapped.SetEvent(event) + + # Wait for set event; + # result should be True immediately + fut = self.loop._proactor.wait_for_handle(event, 10) + start = self.loop.time() + done = self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + + self.assertEqual(done, True) + self.assertTrue(fut.result()) + self.assertTrue(0 <= elapsed < 0.3, elapsed) + + # Tulip issue #195: cancelling a done _WaitHandleFuture must not crash + fut.cancel() + + def test_wait_for_handle_cancel(self): + event = _overlapped.CreateEvent(None, True, False, None) + self.addCleanup(_winapi.CloseHandle, event) + + # Wait for unset event with a cancelled future; + # CancelledError should be raised immediately + fut = self.loop._proactor.wait_for_handle(event, 10) + fut.cancel() + start = self.loop.time() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(fut) + elapsed = self.loop.time() - start + self.assertTrue(0 <= elapsed < 0.1, elapsed) + + # Tulip issue #195: cancelling a _WaitHandleFuture twice must not crash + fut = self.loop._proactor.wait_for_handle(event) + fut.cancel() + fut.cancel() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py new file mode 100644 index 00000000..d48b8bcb --- /dev/null +++ b/tests/test_windows_utils.py @@ -0,0 +1,182 @@ +"""Tests for window_utils""" + +import socket +import sys +import unittest +import warnings +from unittest import mock + +if sys.platform != 'win32': + raise unittest.SkipTest('Windows only') + +import _winapi + +from asyncio import _overlapped +from asyncio import windows_utils +try: + from test import support +except ImportError: + from asyncio import test_support as support + + +class WinsocketpairTests(unittest.TestCase): + + def check_winsocketpair(self, ssock, csock): + csock.send(b'xxx') + self.assertEqual(b'xxx', ssock.recv(1024)) + csock.close() + ssock.close() + + def test_winsocketpair(self): + ssock, csock = windows_utils.socketpair() + self.check_winsocketpair(ssock, csock) + + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + def test_winsocketpair_ipv6(self): + ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) + self.check_winsocketpair(ssock, csock) + + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_exc(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM + m_socket.socket.return_value.getsockname.return_value = ('', 12345) + m_socket.socket.return_value.accept.return_value = object(), object() + m_socket.socket.return_value.connect.side_effect = OSError() + + self.assertRaises(OSError, windows_utils.socketpair) + + def test_winsocketpair_invalid_args(self): + self.assertRaises(ValueError, + windows_utils.socketpair, family=socket.AF_UNSPEC) + self.assertRaises(ValueError, + windows_utils.socketpair, type=socket.SOCK_DGRAM) + self.assertRaises(ValueError, + windows_utils.socketpair, proto=1) + + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') + @mock.patch('asyncio.windows_utils.socket') + def test_winsocketpair_close(self, m_socket): + m_socket.AF_INET = socket.AF_INET + m_socket.SOCK_STREAM = socket.SOCK_STREAM + sock = mock.Mock() + m_socket.socket.return_value = sock + sock.bind.side_effect = OSError + self.assertRaises(OSError, windows_utils.socketpair) + self.assertTrue(sock.close.called) + + +class PipeTests(unittest.TestCase): + + def test_pipe_overlapped(self): + h1, h2 = windows_utils.pipe(overlapped=(True, True)) + try: + ov1 = _overlapped.Overlapped() + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, 0) + + ov1.ReadFile(h1, 100) + self.assertTrue(ov1.pending) + self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING) + ERROR_IO_INCOMPLETE = 996 + try: + ov1.getresult() + except OSError as e: + self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) + else: + raise RuntimeError('expected ERROR_IO_INCOMPLETE') + + ov2 = _overlapped.Overlapped() + self.assertFalse(ov2.pending) + self.assertEqual(ov2.error, 0) + + ov2.WriteFile(h2, b"hello") + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + + res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + + self.assertFalse(ov1.pending) + self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) + self.assertFalse(ov2.pending) + self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertEqual(ov1.getresult(), b"hello") + finally: + _winapi.CloseHandle(h1) + _winapi.CloseHandle(h2) + + def test_pipe_handle(self): + h, _ = windows_utils.pipe(overlapped=(True, True)) + _winapi.CloseHandle(_) + p = windows_utils.PipeHandle(h) + self.assertEqual(p.fileno(), h) + self.assertEqual(p.handle, h) + + # check garbage collection of p closes handle + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "", ResourceWarning) + del p + support.gc_collect() + try: + _winapi.CloseHandle(h) + except OSError as e: + self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE + else: + raise RuntimeError('expected ERROR_INVALID_HANDLE') + + +class PopenTests(unittest.TestCase): + + def test_popen(self): + command = r"""if 1: + import sys + s = sys.stdin.readline() + sys.stdout.write(s.upper()) + sys.stderr.write('stderr') + """ + msg = b"blah\n" + + p = windows_utils.Popen([sys.executable, '-c', command], + stdin=windows_utils.PIPE, + stdout=windows_utils.PIPE, + stderr=windows_utils.PIPE) + + for f in [p.stdin, p.stdout, p.stderr]: + self.assertIsInstance(f, windows_utils.PipeHandle) + + ovin = _overlapped.Overlapped() + ovout = _overlapped.Overlapped() + overr = _overlapped.Overlapped() + + ovin.WriteFile(p.stdin.handle, msg) + ovout.ReadFile(p.stdout.handle, 100) + overr.ReadFile(p.stderr.handle, 100) + + events = [ovin.event, ovout.event, overr.event] + # Super-long timeout for slow buildbots. + res = _winapi.WaitForMultipleObjects(events, True, 10000) + self.assertEqual(res, _winapi.WAIT_OBJECT_0) + self.assertFalse(ovout.pending) + self.assertFalse(overr.pending) + self.assertFalse(ovin.pending) + + self.assertEqual(ovin.getresult(), len(msg)) + out = ovout.getresult().rstrip() + err = overr.getresult().rstrip() + + self.assertGreater(len(out), 0) + self.assertGreater(len(err), 0) + # allow for partial reads... + self.assertTrue(msg.upper().rstrip().startswith(out)) + self.assertTrue(b"stderr".startswith(err)) + + # The context manager calls wait() and closes resources + with p: + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..6209ff4e --- /dev/null +++ b/tox.ini @@ -0,0 +1,21 @@ +[tox] +envlist = py33,py34,py3_release + +[testenv] +deps= + aiotest +# Run tests in debug mode +setenv = + PYTHONASYNCIODEBUG = 1 +commands= + python runtests.py -r {posargs} + python run_aiotest.py -r {posargs} + +[testenv:py3_release] +# Run tests in release mode +setenv = + PYTHONASYNCIODEBUG = +basepython = python3 + +[testenv:py35] +basepython = python3.5 diff --git a/update_stdlib.sh b/update_stdlib.sh new file mode 100755 index 00000000..0cdbb1bd --- /dev/null +++ b/update_stdlib.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Script to copy asyncio files to the standard library tree. +# Optional argument is the root of the Python 3.4 tree. +# Assumes you have already created Lib/asyncio and +# Lib/test/test_asyncio in the destination tree. + +CPYTHON=${1-$HOME/cpython} + +if [ ! -d $CPYTHON ] +then + echo Bad destination $CPYTHON + exit 1 +fi + +if [ ! -f asyncio/__init__.py ] +then + echo Bad current directory + exit 1 +fi + +maybe_copy() +{ + SRC=$1 + DST=$CPYTHON/$2 + if cmp $DST $SRC + then + return + fi + echo ======== $SRC === $DST ======== + diff -u $DST $SRC + echo -n "Copy $SRC? [y/N/back] " + read X + case $X in + [yY]*) echo Copying $SRC; cp $SRC $DST;; + back) echo Copying TO $SRC; cp $DST $SRC;; + *) echo Not copying $SRC;; + esac +} + +for i in `(cd asyncio && ls *.py)` +do + if [ $i == test_support.py ] + then + continue + fi + + if [ $i == selectors.py ] + then + if [ "`(cd $CPYTHON; hg branch)`" == "3.4" ] + then + echo "Destination is 3.4 branch -- ignoring selectors.py" + else + maybe_copy asyncio/$i Lib/$i + fi + else + maybe_copy asyncio/$i Lib/asyncio/$i + fi +done + +for i in `(cd tests && ls *.py *.pem)` +do + if [ $i == test_selectors.py ] + then + continue + fi + maybe_copy tests/$i Lib/test/test_asyncio/$i +done + +maybe_copy overlapped.c Modules/overlapped.c From febe4cb812b334b205458440c5999c430c74f698 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 14:32:13 +0100 Subject: [PATCH 1312/1502] Fix ProactorEventLoop.start_serving_pipe() If a client connected before the server was closed: drop the client (close the pipe) and exit --- asyncio/windows_events.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 6c7e0580..109f5d3f 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -257,7 +257,7 @@ def _get_unconnected_pipe(self): def _server_pipe_handle(self, first): # Return a wrapper for a new pipe handle. - if self._address is None: + if self.closed(): return None flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED if first: @@ -273,6 +273,9 @@ def _server_pipe_handle(self, first): self._free_instances.add(pipe) return pipe + def closed(self): + return (self._address is None) + def close(self): if self._accept_pipe_future is not None: self._accept_pipe_future.cancel() @@ -325,12 +328,21 @@ def loop_accept_pipe(f=None): if f: pipe = f.result() server._free_instances.discard(pipe) + + if server.closed(): + # A client connected before the server was closed: + # drop the client (close the pipe) and exit + pipe.close() + return + protocol = protocol_factory() self._make_duplex_pipe_transport( pipe, protocol, extra={'addr': address}) + pipe = server._get_unconnected_pipe() if pipe is None: return + f = self._proactor.accept_pipe(pipe) except OSError as exc: if pipe and pipe.fileno() != -1: From 4945c1a5f68f6e15fc1f8bed5932cf177a946ad9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 14:34:45 +0100 Subject: [PATCH 1313/1502] PipeHandle.fileno() now raises an exception if the pipe is closed --- asyncio/windows_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index e6642960..5f8327eb 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -147,6 +147,8 @@ def handle(self): return self._handle def fileno(self): + if self._handle is None: + raise ValueError("I/O operatioon on closed pipe") return self._handle def close(self, *, CloseHandle=_winapi.CloseHandle): From 34214f46d8cae766aff88d1d917a2af2640beafb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 14:58:16 +0100 Subject: [PATCH 1314/1502] Python issue #23293: Rewrite IocpProactor.connect_pipe() as a coroutine Use a coroutine with asyncio.sleep() instead of call_later() to ensure that the schedule call is cancelled. Add also a unit test cancelling connect_pipe(). --- asyncio/windows_events.py | 39 +++++++++++++++++------------------- tests/test_windows_events.py | 13 ++++++++++++ 2 files changed, 31 insertions(+), 21 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 109f5d3f..c9ba7850 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -518,28 +518,25 @@ def finish_accept_pipe(trans, key, ov): return self._register(ov, pipe, finish_accept_pipe) - def _connect_pipe(self, fut, address, delay): - # Unfortunately there is no way to do an overlapped connect to a pipe. - # Call CreateFile() in a loop until it doesn't fail with - # ERROR_PIPE_BUSY - try: - handle = _overlapped.ConnectPipe(address) - except OSError as exc: - if exc.winerror == _overlapped.ERROR_PIPE_BUSY: - # Polling: retry later - delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - self._loop.call_later(delay, - self._connect_pipe, fut, address, delay) - else: - fut.set_exception(exc) - else: - pipe = windows_utils.PipeHandle(handle) - fut.set_result(pipe) - + @coroutine def connect_pipe(self, address): - fut = futures.Future(loop=self._loop) - self._connect_pipe(fut, address, CONNECT_PIPE_INIT_DELAY) - return fut + delay = CONNECT_PIPE_INIT_DELAY + while True: + # Unfortunately there is no way to do an overlapped connect to a pipe. + # Call CreateFile() in a loop until it doesn't fail with + # ERROR_PIPE_BUSY + try: + handle = _overlapped.ConnectPipe(address) + break + except OSError as exc: + if exc.winerror != _overlapped.ERROR_PIPE_BUSY: + raise + + # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later + delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) + yield from tasks.sleep(delay, loop=self._loop) + + return windows_utils.PipeHandle(handle) def wait_for_handle(self, handle, timeout=None): """Wait for a handle. diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index f9b3dd15..73d8fcdb 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,6 +1,7 @@ import os import sys import unittest +from unittest import mock if sys.platform != 'win32': raise unittest.SkipTest('Windows only') @@ -91,6 +92,18 @@ def _test_pipe(self): return 'done' + def test_connect_pipe_cancel(self): + exc = OSError() + exc.winerror = _overlapped.ERROR_PIPE_BUSY + with mock.patch.object(_overlapped, 'ConnectPipe', side_effect=exc) as connect: + coro = self.loop._proactor.connect_pipe('pipe_address') + task = self.loop.create_task(coro) + + # check that it's possible to cancel connect_pipe() + task.cancel() + with self.assertRaises(asyncio.CancelledError): + self.loop.run_until_complete(task) + def test_wait_for_handle(self): event = _overlapped.CreateEvent(None, True, False, None) self.addCleanup(_winapi.CloseHandle, event) From 09f7de12e5571b09962e4a23fdf86c56ba4df5ec Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 22:13:20 +0100 Subject: [PATCH 1315/1502] Python issue #23095: Fix _WaitHandleFuture.cancel() If UnregisterWaitEx() fais with ERROR_IO_PENDING, it doesn't mean that the wait is unregistered yet. We still have to wait until the wait is cancelled. --- asyncio/windows_events.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index c9ba7850..8f1d9d2a 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -126,14 +126,12 @@ def _unregister_wait(self): return self._registered = False + wait_handle = self._wait_handle + self._wait_handle = None try: - _overlapped.UnregisterWait(self._wait_handle) + _overlapped.UnregisterWait(wait_handle) except OSError as exc: - self._wait_handle = None - if exc.winerror == _overlapped.ERROR_IO_PENDING: - # ERROR_IO_PENDING is not an error, the wait was unregistered - self._unregister_wait_cb(None) - elif exc.winerror != _overlapped.ERROR_IO_PENDING: + if exc.winerror != _overlapped.ERROR_IO_PENDING: context = { 'message': 'Failed to unregister the wait handle', 'exception': exc, @@ -142,9 +140,10 @@ def _unregister_wait(self): if self._source_traceback: context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) - else: - self._wait_handle = None - self._unregister_wait_cb(None) + return + # ERROR_IO_PENDING means that the unregister is pending + + self._unregister_wait_cb(None) def cancel(self): self._unregister_wait() @@ -209,14 +208,12 @@ def _unregister_wait(self): return self._registered = False + wait_handle = self._wait_handle + self._wait_handle = None try: - _overlapped.UnregisterWaitEx(self._wait_handle, self._event) + _overlapped.UnregisterWaitEx(wait_handle, self._event) except OSError as exc: - self._wait_handle = None - if exc.winerror == _overlapped.ERROR_IO_PENDING: - # ERROR_IO_PENDING is not an error, the wait was unregistered - self._unregister_wait_cb(None) - elif exc.winerror != _overlapped.ERROR_IO_PENDING: + if exc.winerror != _overlapped.ERROR_IO_PENDING: context = { 'message': 'Failed to unregister the wait handle', 'exception': exc, @@ -225,11 +222,11 @@ def _unregister_wait(self): if self._source_traceback: context['source_traceback'] = self._source_traceback self._loop.call_exception_handler(context) - else: - self._wait_handle = None - self._event_fut = self._proactor._wait_cancel( - self._event, - self._unregister_wait_cb) + return + # ERROR_IO_PENDING is not an error, the wait was unregistered + + self._event_fut = self._proactor._wait_cancel(self._event, + self._unregister_wait_cb) class PipeServer(object): From 29487bbe12e6dbbe5698077eed3903f87e8f873f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 22:07:16 +0100 Subject: [PATCH 1316/1502] Tulip issue #204: Fix IocpProactor.recv() If ReadFile() fails with ERROR_BROKEN_PIPE, the operation is not pending: don't register the overlapped. I don't know if WSARecv() can fail with ERROR_BROKEN_PIPE. Since Overlapped.WSARecv() already handled ERROR_BROKEN_PIPE, let me guess that it has the same behaviour than ReadFile(). --- asyncio/windows_events.py | 20 +++++++++++++------- overlapped.c | 4 ++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 8f1d9d2a..94aafb6f 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -406,13 +406,21 @@ def select(self, timeout=None): self._results = [] return tmp + def _result(self, value): + fut = futures.Future(loop=self._loop) + fut.set_result(value) + return fut + def recv(self, conn, nbytes, flags=0): self._register_with_iocp(conn) ov = _overlapped.Overlapped(NULL) - if isinstance(conn, socket.socket): - ov.WSARecv(conn.fileno(), nbytes, flags) - else: - ov.ReadFile(conn.fileno(), nbytes) + try: + if isinstance(conn, socket.socket): + ov.WSARecv(conn.fileno(), nbytes, flags) + else: + ov.ReadFile(conn.fileno(), nbytes) + except BrokenPipeError: + return self._result(b'') def finish_recv(trans, key, ov): try: @@ -505,9 +513,7 @@ def accept_pipe(self, pipe): # ConnectNamePipe() failed with ERROR_PIPE_CONNECTED which means # that the pipe is connected. There is no need to wait for the # completion of the connection. - f = futures.Future(loop=self._loop) - f.set_result(pipe) - return f + return self._result(pipe) def finish_accept_pipe(trans, key, ov): ov.getresult() diff --git a/overlapped.c b/overlapped.c index 4661152d..1a081ecb 100644 --- a/overlapped.c +++ b/overlapped.c @@ -730,7 +730,7 @@ Overlapped_ReadFile(OverlappedObject *self, PyObject *args) switch (err) { case ERROR_BROKEN_PIPE: mark_as_completed(&self->overlapped); - Py_RETURN_NONE; + return SetFromWindowsErr(err); case ERROR_SUCCESS: case ERROR_MORE_DATA: case ERROR_IO_PENDING: @@ -789,7 +789,7 @@ Overlapped_WSARecv(OverlappedObject *self, PyObject *args) switch (err) { case ERROR_BROKEN_PIPE: mark_as_completed(&self->overlapped); - Py_RETURN_NONE; + return SetFromWindowsErr(err); case ERROR_SUCCESS: case ERROR_MORE_DATA: case ERROR_IO_PENDING: From 54faf33fc4b9eac9389bf25fc2e9b4adfc3c0838 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 22:42:10 +0100 Subject: [PATCH 1317/1502] _overlapped.ConnectPipe(): release the GIL --- overlapped.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/overlapped.c b/overlapped.c index 1a081ecb..ef77c887 100644 --- a/overlapped.c +++ b/overlapped.c @@ -1146,10 +1146,13 @@ ConnectPipe(OverlappedObject *self, PyObject *args) if (Address == NULL) return NULL; + Py_BEGIN_ALLOW_THREADS PipeHandle = CreateFileW(Address, GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); + Py_END_ALLOW_THREADS + PyMem_Free(Address); if (PipeHandle == INVALID_HANDLE_VALUE) return SetFromWindowsErr(0); From f1774fc9e2b9a603e3f2b92d4f20cac0336e25e9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 26 Jan 2015 15:11:17 +0100 Subject: [PATCH 1318/1502] Python issue #23208: Don't use the traceback of the current handle if we already know the traceback of the source. The handle may be more revelant, but having 3 tracebacks (handle, source, exception) becomes more difficult to read. The handle may be preferred later but it requires more work to make this choice. --- asyncio/base_events.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 1c51a7cf..e40d3ad5 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -956,7 +956,8 @@ def default_exception_handler(self, context): else: exc_info = False - if (self._current_handle is not None + if ('source_traceback' not in context + and self._current_handle is not None and self._current_handle._source_traceback): context['handle_traceback'] = self._current_handle._source_traceback From 8343733ce97951a47522b11ed854e5445b186a79 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 27 Jan 2015 11:07:10 +0100 Subject: [PATCH 1319/1502] test_sslproto: skip test if ssl module is missing --- tests/test_sslproto.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 812dedbe..b1a61c48 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -2,6 +2,10 @@ import unittest from unittest import mock +try: + import ssl +except ImportError: + ssl = None import asyncio from asyncio import sslproto @@ -14,6 +18,7 @@ def setUp(self): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) + @unittest.skipIf(ssl is None, 'No ssl module') def test_cancel_handshake(self): # Python issue #23197: cancelling an handshake must not raise an # exception or log an error, even if the handshake failed From fdd3d9b9d7a87baddc781b11e2a653ea9cf896c5 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 28 Jan 2015 00:27:25 +0100 Subject: [PATCH 1320/1502] Remove unused SSLProtocol._closing attribute --- asyncio/sslproto.py | 1 - 1 file changed, 1 deletion(-) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index 117dc565..f2b856c4 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -408,7 +408,6 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._write_buffer_size = 0 self._waiter = waiter - self._closing = False self._loop = loop self._app_protocol = app_protocol self._app_transport = _SSLProtocolTransport(self._loop, From 63d8fc1d7d93218931b99b0dc9b81cbee5ce10ec Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 28 Jan 2015 23:52:07 +0100 Subject: [PATCH 1321/1502] Fix SSLProtocol.eof_received() Wake-up the waiter if it is not done yet. --- asyncio/sslproto.py | 4 ++++ tests/test_sslproto.py | 40 +++++++++++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index f2b856c4..26937c84 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -489,6 +489,10 @@ def eof_received(self): try: if self._loop.get_debug(): logger.debug("%r received EOF", self) + + if self._waiter is not None and not self._waiter.done(): + self._waiter.set_exception(ConnectionResetError()) + if not self._in_handshake: keep_open = self._app_protocol.eof_received() if keep_open: diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index b1a61c48..148e30df 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -12,21 +12,36 @@ from asyncio import test_utils +@unittest.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): def setUp(self): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) - @unittest.skipIf(ssl is None, 'No ssl module') + def ssl_protocol(self, waiter=None): + sslcontext = test_utils.dummy_ssl_context() + app_proto = asyncio.Protocol() + return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + + def connection_made(self, ssl_proto, do_handshake=None): + transport = mock.Mock() + sslpipe = mock.Mock() + sslpipe.shutdown.return_value = b'' + if do_handshake: + sslpipe.do_handshake.side_effect = do_handshake + else: + def mock_handshake(callback): + return [] + sslpipe.do_handshake.side_effect = mock_handshake + with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + ssl_proto.connection_made(transport) + def test_cancel_handshake(self): # Python issue #23197: cancelling an handshake must not raise an # exception or log an error, even if the handshake failed - sslcontext = test_utils.dummy_ssl_context() - app_proto = asyncio.Protocol() waiter = asyncio.Future(loop=self.loop) - ssl_proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, - waiter) + ssl_proto = self.ssl_protocol(waiter) handshake_fut = asyncio.Future(loop=self.loop) def do_handshake(callback): @@ -36,12 +51,7 @@ def do_handshake(callback): return [] waiter.cancel() - transport = mock.Mock() - sslpipe = mock.Mock() - sslpipe.shutdown.return_value = b'' - sslpipe.do_handshake.side_effect = do_handshake - with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): - ssl_proto.connection_made(transport) + self.connection_made(ssl_proto, do_handshake) with test_utils.disable_logger(): self.loop.run_until_complete(handshake_fut) @@ -49,6 +59,14 @@ def do_handshake(callback): # Close the transport ssl_proto._app_transport.close() + def test_eof_received_waiter(self): + waiter = asyncio.Future(loop=self.loop) + ssl_proto = self.ssl_protocol(waiter) + self.connection_made(ssl_proto) + ssl_proto.eof_received() + test_utils.run_briefly(self.loop) + self.assertIsInstance(waiter.exception(), ConnectionResetError) + if __name__ == '__main__': unittest.main() From aed248b83183b25e9ded057dba4c41f44099b9a4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 00:12:09 +0100 Subject: [PATCH 1322/1502] SSL transports now clear their reference to the waiter * Rephrase also the comment explaining why the waiter is not awaken immediatly. * SSLProtocol.eof_received() doesn't instanciate ConnectionResetError exception directly, it will be done by Future.set_exception(). The exception is not used if the waiter was cancelled or if there is no waiter. --- asyncio/proactor_events.py | 2 +- asyncio/selector_events.py | 27 ++++++++++++++++----------- asyncio/sslproto.py | 20 +++++++++++++------- asyncio/unix_events.py | 4 ++-- 4 files changed, 32 insertions(+), 21 deletions(-) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index ed170622..0f533a5e 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -38,7 +38,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._server._attach() self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 24f84615..42d88f5d 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -581,7 +581,7 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def pause_reading(self): @@ -732,6 +732,16 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, start_time = None self._on_handshake(start_time) + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def _on_handshake(self, start_time): try: self._sock.do_handshake() @@ -750,8 +760,7 @@ def _on_handshake(self, start_time): self._loop.remove_reader(self._sock_fd) self._loop.remove_writer(self._sock_fd) self._sock.close() - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) if isinstance(exc, Exception): return else: @@ -774,9 +783,7 @@ def _on_handshake(self, start_time): "on matching the hostname", self, exc_info=True) self._sock.close() - if (self._waiter is not None - and not self._waiter.cancelled()): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return # Add extra info that becomes available after handshake. @@ -789,10 +796,8 @@ def _on_handshake(self, start_time): self._write_wants_read = False self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._loop.call_soon(self._waiter._set_result_unless_cancelled, - None) + # only wake up the waiter when connection_made() has been called + self._loop.call_soon(self._wakeup_waiter) if self._loop.get_debug(): dt = self._loop.time() - start_time @@ -924,7 +929,7 @@ def __init__(self, loop, sock, protocol, address=None, self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def get_write_buffer_size(self): diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index 26937c84..fc809b98 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -418,6 +418,16 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._in_shutdown = False self._transport = None + def _wakeup_waiter(self, exc=None): + if self._waiter is None: + return + if not self._waiter.cancelled(): + if exc is not None: + self._waiter.set_exception(exc) + else: + self._waiter.set_result(None) + self._waiter = None + def connection_made(self, transport): """Called when the low-level connection is made. @@ -490,8 +500,7 @@ def eof_received(self): if self._loop.get_debug(): logger.debug("%r received EOF", self) - if self._waiter is not None and not self._waiter.done(): - self._waiter.set_exception(ConnectionResetError()) + self._wakeup_waiter(ConnectionResetError) if not self._in_handshake: keep_open = self._app_protocol.eof_received() @@ -556,8 +565,7 @@ def _on_handshake_complete(self, handshake_exc): self, exc_info=True) self._transport.close() if isinstance(exc, Exception): - if self._waiter is not None and not self._waiter.cancelled(): - self._waiter.set_exception(exc) + self._wakeup_waiter(exc) return else: raise @@ -572,9 +580,7 @@ def _on_handshake_complete(self, handshake_exc): compression=sslobj.compression(), ) self._app_protocol.connection_made(self._app_transport) - if self._waiter is not None: - # wait until protocol.connection_made() has been called - self._waiter._set_result_unless_cancelled(None) + self._wakeup_waiter() self._session_established = True # In case transport.write() was already called. Don't call # immediatly _process_write_backlog(), but schedule it: diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 97f9addd..67973f14 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -301,7 +301,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): @@ -409,7 +409,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: - # wait until protocol.connection_made() has been called + # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) def __repr__(self): From c5600035ca9a46c55c43e7bebc014d9a4e0f4362 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 00:18:18 +0100 Subject: [PATCH 1323/1502] Fix _SelectorSocketTransport constructor Only start reading when connection_made() has been called: protocol.data_received() must not be called before protocol.connection_made(). --- asyncio/selector_events.py | 4 +++- tests/test_selector_events.py | 16 +++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 42d88f5d..f4996293 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -578,8 +578,10 @@ def __init__(self, loop, sock, protocol, waiter=None, self._eof = False self._paused = False - self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index ad86ada3..51526163 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -59,6 +59,7 @@ def setUp(self): def test_make_socket_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) close_transport(transport) @@ -67,6 +68,7 @@ def test_make_socket_transport(self): def test_make_ssl_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop.add_reader._is_coroutine = False self.loop.add_writer = mock.Mock() self.loop.remove_reader = mock.Mock() self.loop.remove_writer = mock.Mock() @@ -770,20 +772,24 @@ def socket_transport(self, waiter=None): return transport def test_ctor(self): - tr = self.socket_transport() + waiter = asyncio.Future(loop=self.loop) + tr = self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) + self.loop.assert_reader(7, tr._read_ready) test_utils.run_briefly(self.loop) self.protocol.connection_made.assert_called_with(tr) def test_ctor_with_waiter(self): - fut = asyncio.Future(loop=self.loop) + waiter = asyncio.Future(loop=self.loop) + self.socket_transport(waiter=waiter) + self.loop.run_until_complete(waiter) - self.socket_transport(waiter=fut) - test_utils.run_briefly(self.loop) - self.assertIsNone(fut.result()) + self.assertIsNone(waiter.result()) def test_pause_resume_reading(self): tr = self.socket_transport() + test_utils.run_briefly(self.loop) self.assertFalse(tr._paused) self.loop.assert_reader(7, tr._read_ready) tr.pause_reading() From 74a3c25a80dbe9d06283f0476482f19273f1e4c4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 01:59:57 +0100 Subject: [PATCH 1324/1502] BaseSubprocessTransport._kill_wait() now also call close() close() closes pipes, which is not None yet by _kill_wait(). --- asyncio/base_subprocess.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index f5e7dfec..81c6f1a7 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -117,12 +117,15 @@ def _kill_wait(self): proc.stderr.close() if proc.stdin: proc.stdin.close() + try: proc.kill() except ProcessLookupError: pass self._returncode = proc.wait() + self.close() + @coroutine def _post_init(self): try: From 2c89c715f433ad106be41ee92d970c61338fa074 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 02:38:01 +0100 Subject: [PATCH 1325/1502] Fix _SelectorDatagramTransport constructor Only start reading after connection_made() has been called. --- asyncio/selector_events.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index f4996293..5fe46e5f 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -928,8 +928,10 @@ def __init__(self, loop, sock, protocol, address=None, waiter=None, extra=None): super().__init__(loop, sock, protocol, extra) self._address = address - self._loop.add_reader(self._sock_fd, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) From 1162bf22e62544d5fdc3057ba23d2fc0fc244d07 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 02:38:46 +0100 Subject: [PATCH 1326/1502] _SelectorTransport constructor: extra parameter is now optional --- asyncio/selector_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 5fe46e5f..d046eb2a 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -467,7 +467,7 @@ class _SelectorTransport(transports._FlowControlMixin, _buffer_factory = bytearray # Constructs initial value for self._buffer. - def __init__(self, loop, sock, protocol, extra, server=None): + def __init__(self, loop, sock, protocol, extra=None, server=None): super().__init__(extra, loop) self._extra['socket'] = sock self._extra['sockname'] = sock.getsockname() From 7a46edabd8a2aa665c211cd53d36fa45b4ca856c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 02:52:57 +0100 Subject: [PATCH 1327/1502] Fix _SelectorSslTransport.close() Don't call protocol.connection_lost() if protocol.connection_made() was not called yet: if the SSL handshake failed or is still in progress. The close() method can be called if the creation of the connection is cancelled, by a timeout for example. --- asyncio/selector_events.py | 7 ++++++- tests/test_selector_events.py | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index d046eb2a..3195f622 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -479,6 +479,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._sock = sock self._sock_fd = sock.fileno() self._protocol = protocol + self._protocol_connected = True self._server = server self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. @@ -555,7 +556,8 @@ def _force_close(self, exc): def _call_connection_lost(self, exc): try: - self._protocol.connection_lost(exc) + if self._protocol_connected: + self._protocol.connection_lost(exc) finally: self._sock.close() self._sock = None @@ -718,6 +720,8 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) super().__init__(loop, sslsock, protocol, extra, server) + # the protocol connection is only made after the SSL handshake + self._protocol_connected = False self._server_hostname = server_hostname self._waiter = waiter @@ -797,6 +801,7 @@ def _on_handshake(self, start_time): self._read_wants_write = False self._write_wants_read = False self._loop.add_reader(self._sock_fd, self._read_ready) + self._protocol_connected = True self._loop.call_soon(self._protocol.connection_made, self) # only wake up the waiter when connection_made() has been called self._loop.call_soon(self._wakeup_waiter) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 51526163..f64e40da 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -1427,7 +1427,7 @@ def test_write_eof(self): self.assertFalse(tr.can_write_eof()) self.assertRaises(NotImplementedError, tr.write_eof) - def test_close(self): + def check_close(self): tr = self._make_one() tr.close() @@ -1439,6 +1439,19 @@ def test_close(self): self.assertEqual(tr._conn_lost, 1) self.assertEqual(1, self.loop.remove_reader_count[1]) + test_utils.run_briefly(self.loop) + + def test_close(self): + self.check_close() + self.assertTrue(self.protocol.connection_made.called) + self.assertTrue(self.protocol.connection_lost.called) + + def test_close_not_connected(self): + self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.check_close() + self.assertFalse(self.protocol.connection_made.called) + self.assertFalse(self.protocol.connection_lost.called) + @unittest.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): self.ssl_transport(server_hostname='localhost') From c66282230e3966e45ff22374096bc3a1761e3236 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 10:00:08 +0100 Subject: [PATCH 1328/1502] Cleanup gather(): use cancelled() method instead of using private Future attribute --- asyncio/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 63412a97..4f19a252 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -592,7 +592,7 @@ def _done_callback(i, fut): fut.exception() return - if fut._state == futures._CANCELLED: + if fut.cancelled(): res = futures.CancelledError() if not return_exceptions: outer.set_exception(res) From 5fe382d0e916c63d7f2e72a9b5361058501f5b2b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 13:40:58 +0100 Subject: [PATCH 1329/1502] Fix _UnixReadPipeTransport and _UnixWritePipeTransport Only start reading when connection_made() has been called. --- asyncio/unix_events.py | 17 +++++++++++------ tests/test_unix_events.py | 29 +++++++++++------------------ 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 67973f14..7e1265a0 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -298,8 +298,10 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): _set_nonblocking(self._fileno) self._protocol = protocol self._closing = False - self._loop.add_reader(self._fileno, self._read_ready) self._loop.call_soon(self._protocol.connection_made, self) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._fileno, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) @@ -401,13 +403,16 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._conn_lost = 0 self._closing = False # Set when close() or write_eof() called. - # On AIX, the reader trick only works for sockets. - # On other platforms it works for pipes and sockets. - # (Exception: OS X 10.4? Issue #19294.) + self._loop.call_soon(self._protocol.connection_made, self) + + # On AIX, the reader trick (to be notified when the read end of the + # socket is closed) only works for sockets. On other platforms it + # works for pipes and sockets. (Exception: OS X 10.4? Issue #19294.) if is_socket or not sys.platform.startswith("aix"): - self._loop.add_reader(self._fileno, self._read_ready) + # only start reading when connection_made() has been called + self._loop.call_soon(self._loop.add_reader, + self._fileno, self._read_ready) - self._loop.call_soon(self._protocol.connection_made, self) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(waiter._set_result_unless_cancelled, None) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 126196da..41249ff0 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -350,16 +350,13 @@ def read_pipe_transport(self, waiter=None): return transport def test_ctor(self): - tr = self.read_pipe_transport() - self.loop.assert_reader(5, tr._read_ready) - test_utils.run_briefly(self.loop) - self.protocol.connection_made.assert_called_with(tr) + waiter = asyncio.Future(loop=self.loop) + tr = self.read_pipe_transport(waiter=waiter) + self.loop.run_until_complete(waiter) - def test_ctor_with_waiter(self): - fut = asyncio.Future(loop=self.loop) - tr = self.read_pipe_transport(waiter=fut) - test_utils.run_briefly(self.loop) - self.assertIsNone(fut.result()) + self.protocol.connection_made.assert_called_with(tr) + self.loop.assert_reader(5, tr._read_ready) + self.assertIsNone(waiter.result()) @mock.patch('os.read') def test__read_ready(self, m_read): @@ -502,17 +499,13 @@ def write_pipe_transport(self, waiter=None): return transport def test_ctor(self): - tr = self.write_pipe_transport() - self.loop.assert_reader(5, tr._read_ready) - test_utils.run_briefly(self.loop) - self.protocol.connection_made.assert_called_with(tr) + waiter = asyncio.Future(loop=self.loop) + tr = self.write_pipe_transport(waiter=waiter) + self.loop.run_until_complete(waiter) - def test_ctor_with_waiter(self): - fut = asyncio.Future(loop=self.loop) - tr = self.write_pipe_transport(waiter=fut) + self.protocol.connection_made.assert_called_with(tr) self.loop.assert_reader(5, tr._read_ready) - test_utils.run_briefly(self.loop) - self.assertEqual(None, fut.result()) + self.assertEqual(None, waiter.result()) def test_can_write_eof(self): tr = self.write_pipe_transport() From 59fcee2f2dd9d1853f3cf67d0d718bec2ad0057d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 14:09:43 +0100 Subject: [PATCH 1330/1502] Fix BaseSelectorEventLoop._accept_connection() * Close the transport on error * In debug mode, log errors using call_exception_handler() --- asyncio/selector_events.py | 44 ++++++++++++++++++++++++++++++++------ tests/test_events.py | 37 +++++++++++++++++++++----------- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 3195f622..91478326 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -22,6 +22,7 @@ from . import selectors from . import transports from . import sslproto +from .coroutines import coroutine from .log import logger @@ -181,16 +182,47 @@ def _accept_connection(self, protocol_factory, sock, else: raise # The event loop will catch, log and ignore it. else: + extra = {'peername': addr} + accept = self._accept_connection2(protocol_factory, conn, extra, + sslcontext, server) + self.create_task(accept) + + @coroutine + def _accept_connection2(self, protocol_factory, conn, extra, + sslcontext=None, server=None): + protocol = None + transport = None + try: protocol = protocol_factory() + waiter = futures.Future(loop=self) if sslcontext: - self._make_ssl_transport( - conn, protocol, sslcontext, - server_side=True, extra={'peername': addr}, server=server) + transport = self._make_ssl_transport( + conn, protocol, sslcontext, waiter=waiter, + server_side=True, extra=extra, server=server) else: - self._make_socket_transport( - conn, protocol , extra={'peername': addr}, + transport = self._make_socket_transport( + conn, protocol, waiter=waiter, extra=extra, server=server) - # It's now up to the protocol to handle the connection. + + try: + yield from waiter + except: + transport.close() + raise + + # It's now up to the protocol to handle the connection. + except Exception as exc: + if self.get_debug(): + context = { + 'message': ('Error on transport creation ' + 'for incoming connection'), + 'exception': exc, + } + if protocol is not None: + context['protocol'] = protocol + if transport is not None: + context['transport'] = transport + self.call_exception_handler(context) def add_reader(self, fd, callback, *args): """Add a reader callback.""" diff --git a/tests/test_events.py b/tests/test_events.py index a38c90eb..12af62b2 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -886,13 +886,18 @@ def test_create_server_ssl_verify_failed(self): if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True + # no CA loaded f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) - with test_utils.disable_logger(): - with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): - self.loop.run_until_complete(f_c) + with mock.patch.object(self.loop, 'call_exception_handler'): + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # execute the loop to log the connection error + test_utils.run_briefly(self.loop) # close connection self.assertIsNone(proto.transport) @@ -919,15 +924,20 @@ def test_create_unix_server_ssl_verify_failed(self): f_c = self.loop.create_unix_connection(MyProto, path, ssl=sslcontext_client, server_hostname='invalid') - with test_utils.disable_logger(): - with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): - self.loop.run_until_complete(f_c) + with mock.patch.object(self.loop, 'call_exception_handler'): + with test_utils.disable_logger(): + with self.assertRaisesRegex(ssl.SSLError, + 'certificate verify failed '): + self.loop.run_until_complete(f_c) + + # execute the loop to log the connection error + test_utils.run_briefly(self.loop) # close connection self.assertIsNone(proto.transport) server.close() + def test_legacy_create_unix_server_ssl_verify_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl_verify_failed() @@ -949,11 +959,12 @@ def test_create_server_ssl_match_failed(self): # incorrect server_hostname f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) - with test_utils.disable_logger(): - with self.assertRaisesRegex( - ssl.CertificateError, - "hostname '127.0.0.1' doesn't match 'localhost'"): - self.loop.run_until_complete(f_c) + with mock.patch.object(self.loop, 'call_exception_handler'): + with test_utils.disable_logger(): + with self.assertRaisesRegex( + ssl.CertificateError, + "hostname '127.0.0.1' doesn't match 'localhost'"): + self.loop.run_until_complete(f_c) # close connection proto.transport.close() From 93a4cf4c5dae923cd1f8f43e360e5705bb24d910 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 14:14:07 +0100 Subject: [PATCH 1331/1502] Document Protocol state machine --- asyncio/protocols.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/asyncio/protocols.py b/asyncio/protocols.py index 52fc25c2..80fcac9a 100644 --- a/asyncio/protocols.py +++ b/asyncio/protocols.py @@ -78,6 +78,11 @@ class Protocol(BaseProtocol): State machine of calls: start -> CM [-> DR*] [-> ER?] -> CL -> end + + * CM: connection_made() + * DR: data_received() + * ER: eof_received() + * CL: connection_lost() """ def data_received(self, data): From 2b0f2747531f0eee45c86e5111b15598e0b2613a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 16:17:08 +0100 Subject: [PATCH 1332/1502] Python issue #23243: test_sslproto: Close explicitly transports --- tests/test_sslproto.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 148e30df..a72967ea 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -22,7 +22,9 @@ def setUp(self): def ssl_protocol(self, waiter=None): sslcontext = test_utils.dummy_ssl_context() app_proto = asyncio.Protocol() - return sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter) + self.addCleanup(proto._app_transport.close) + return proto def connection_made(self, ssl_proto, do_handshake=None): transport = mock.Mock() @@ -56,9 +58,6 @@ def do_handshake(callback): with test_utils.disable_logger(): self.loop.run_until_complete(handshake_fut) - # Close the transport - ssl_proto._app_transport.close() - def test_eof_received_waiter(self): waiter = asyncio.Future(loop=self.loop) ssl_proto = self.ssl_protocol(waiter) From 241c71030cb79217bd6be6f6dfe31e87bc5f6cbf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 17:32:39 +0100 Subject: [PATCH 1333/1502] Python issue #23243: On Python 3.4 and newer, emit a ResourceWarning when an event loop or a transport is not explicitly closed --- asyncio/base_events.py | 11 +++++++++++ asyncio/base_subprocess.py | 19 ++++++++++++++++++- asyncio/futures.py | 6 +++--- asyncio/proactor_events.py | 11 +++++++++++ asyncio/selector_events.py | 16 ++++++++++++++++ asyncio/sslproto.py | 13 +++++++++++++ asyncio/unix_events.py | 19 +++++++++++++++++++ asyncio/windows_utils.py | 6 +++++- tests/test_proactor_events.py | 6 +++++- 9 files changed, 101 insertions(+), 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index e40d3ad5..7108f251 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -26,6 +26,7 @@ import time import traceback import sys +import warnings from . import coroutines from . import events @@ -333,6 +334,16 @@ def is_closed(self): """Returns True if the event loop was closed.""" return self._closed + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self.is_closed(): + warnings.warn("unclosed event loop %r" % self, ResourceWarning) + if not self.is_running(): + self.close() + def is_running(self): """Returns True if the event loop is running.""" return (self._owner is not None) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 81c6f1a7..651a9a29 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -1,5 +1,7 @@ import collections import subprocess +import sys +import warnings from . import protocols from . import transports @@ -13,6 +15,7 @@ def __init__(self, loop, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): super().__init__(extra) + self._closed = False self._protocol = protocol self._loop = loop self._pid = None @@ -40,7 +43,10 @@ def __init__(self, loop, protocol, args, shell, program, self._pid) def __repr__(self): - info = [self.__class__.__name__, 'pid=%s' % self._pid] + info = [self.__class__.__name__] + if self._closed: + info.append('closed') + info.append('pid=%s' % self._pid) if self._returncode is not None: info.append('returncode=%s' % self._returncode) @@ -70,6 +76,7 @@ def _make_read_subprocess_pipe_proto(self, fd): raise NotImplementedError def close(self): + self._closed = True for proto in self._pipes.values(): if proto is None: continue @@ -77,6 +84,15 @@ def close(self): if self._returncode is None: self.terminate() + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self._closed: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + def get_pid(self): return self._pid @@ -104,6 +120,7 @@ def _kill_wait(self): Function called when an exception is raised during the creation of a subprocess. """ + self._closed = True if self._loop.get_debug(): logger.warning('Exception during subprocess creation, ' 'kill the subprocess %r', diff --git a/asyncio/futures.py b/asyncio/futures.py index 19212a94..2c741fd4 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -195,9 +195,9 @@ def __repr__(self): info = self._repr_info() return '<%s %s>' % (self.__class__.__name__, ' '.join(info)) - # On Python 3.3 or older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks to - # the PEP 442. + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. if _PY34: def __del__(self): if not self._log_traceback: diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 0f533a5e..65de926b 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -7,6 +7,8 @@ __all__ = ['BaseProactorEventLoop'] import socket +import sys +import warnings from . import base_events from . import constants @@ -74,6 +76,15 @@ def close(self): self._read_fut.cancel() self._read_fut = None + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._sock is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + def _fatal_error(self, exc, message='Fatal error on pipe transport'): if isinstance(exc, (BrokenPipeError, ConnectionResetError)): if self._loop.get_debug(): diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 91478326..4bd6dc8d 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -10,6 +10,8 @@ import errno import functools import socket +import sys +import warnings try: import ssl except ImportError: # pragma: no cover @@ -499,6 +501,11 @@ class _SelectorTransport(transports._FlowControlMixin, _buffer_factory = bytearray # Constructs initial value for self._buffer. + # Attribute used in the destructor: it must be set even if the constructor + # is not called (see _SelectorSslTransport which may start by raising an + # exception) + _sock = None + def __init__(self, loop, sock, protocol, extra=None, server=None): super().__init__(extra, loop) self._extra['socket'] = sock @@ -559,6 +566,15 @@ def close(self): self._conn_lost += 1 self._loop.call_soon(self._call_connection_lost, None) + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._sock is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._sock.close() + def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. if isinstance(exc, (BrokenPipeError, diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index fc809b98..235855e2 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -1,4 +1,6 @@ import collections +import sys +import warnings try: import ssl except ImportError: # pragma: no cover @@ -295,6 +297,7 @@ def __init__(self, loop, ssl_protocol, app_protocol): self._loop = loop self._ssl_protocol = ssl_protocol self._app_protocol = app_protocol + self._closed = False def get_extra_info(self, name, default=None): """Get optional transport information.""" @@ -308,8 +311,18 @@ def close(self): protocol's connection_lost() method will (eventually) called with None as its argument. """ + self._closed = True self._ssl_protocol._start_shutdown() + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if not self._closed: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self.close() + def pause_reading(self): """Pause the receiving end. diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 7e1265a0..b06f1b23 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -8,6 +8,7 @@ import subprocess import sys import threading +import warnings from . import base_events @@ -353,6 +354,15 @@ def close(self): if not self._closing: self._close(None) + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._pipe is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._pipe.close() + def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only if (isinstance(exc, OSError) and exc.errno == errno.EIO): @@ -529,6 +539,15 @@ def close(self): # write_eof is all what we needed to close the write pipe self.write_eof() + # On Python 3.3 and older, objects with a destructor part of a reference + # cycle are never destroyed. It's not more the case on Python 3.4 thanks + # to the PEP 442. + if sys.version_info >= (3, 4): + def __del__(self): + if self._pipe is not None: + warnings.warn("unclosed transport %r" % self, ResourceWarning) + self._pipe.close() + def abort(self): self._close(None) diff --git a/asyncio/windows_utils.py b/asyncio/windows_utils.py index 5f8327eb..870cd13a 100644 --- a/asyncio/windows_utils.py +++ b/asyncio/windows_utils.py @@ -14,6 +14,7 @@ import socket import subprocess import tempfile +import warnings __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] @@ -156,7 +157,10 @@ def close(self, *, CloseHandle=_winapi.CloseHandle): CloseHandle(self._handle) self._handle = None - __del__ = close + def __del__(self): + if self._handle is not None: + warnings.warn("unclosed %r" % self, ResourceWarning) + self.close() def __enter__(self): return self diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 33a8a671..fcd9ab1e 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -499,8 +499,12 @@ def test_sock_accept(self): self.proactor.accept.assert_called_with(self.sock) def test_socketpair(self): + class EventLoop(BaseProactorEventLoop): + # override the destructor to not log a ResourceWarning + def __del__(self): + pass self.assertRaises( - NotImplementedError, BaseProactorEventLoop, self.proactor) + NotImplementedError, EventLoop, self.proactor) def test_make_socket_transport(self): tr = self.loop._make_socket_transport(self.sock, asyncio.Protocol()) From c92da156217a769bb147f1f0883f368f9ba57d3b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 23:53:11 +0100 Subject: [PATCH 1334/1502] Python issue #23347: Refactor creation of subprocess transports Changes on BaseSubprocessTransport: * Add a wait() method to wait until the child process exit * The constructor now accepts an optional waiter parameter. The _post_init() coroutine must not be called explicitly anymore. It makes subprocess transports closer to other transports, and it gives more freedom if we want later to change completly how subprocess transports are created. * close() now kills the process instead of kindly terminate it: the child process may ignore SIGTERM and continue to run. Call explicitly terminate() and wait() if you want to kindly terminate the child process. * close() now logs a warning in debug mode if the process is still running and needs to be killed * _make_subprocess_transport() is now fully asynchronous again: if the creation of the transport failed, wait asynchronously for the process eixt. Before the wait was synchronous. This change requires close() to *kill*, and not terminate, the child process. * Remove the _kill_wait() method, replaced with a more agressive close() method. It fixes _make_subprocess_transport() on error. BaseSubprocessTransport.close() calls the close() method of pipe transports, whereas _kill_wait() closed directly pipes of the subprocess.Popen object without unregistering file descriptors from the selector (which caused severe bugs). These changes simplifies the code of subprocess.py. --- asyncio/base_subprocess.py | 99 +++++++++++++++++++++----------------- asyncio/subprocess.py | 33 ++----------- asyncio/unix_events.py | 15 ++++-- asyncio/windows_events.py | 7 ++- tests/test_events.py | 35 +++++++------- 5 files changed, 92 insertions(+), 97 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 651a9a29..afeda139 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -3,6 +3,7 @@ import sys import warnings +from . import futures from . import protocols from . import transports from .coroutines import coroutine @@ -13,27 +14,32 @@ class BaseSubprocessTransport(transports.SubprocessTransport): def __init__(self, loop, protocol, args, shell, stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + waiter=None, extra=None, **kwargs): super().__init__(extra) self._closed = False self._protocol = protocol self._loop = loop + self._proc = None self._pid = None - + self._returncode = None + self._exit_waiters = [] + self._pending_calls = collections.deque() self._pipes = {} + self._finished = False + if stdin == subprocess.PIPE: self._pipes[0] = None if stdout == subprocess.PIPE: self._pipes[1] = None if stderr == subprocess.PIPE: self._pipes[2] = None - self._pending_calls = collections.deque() - self._finished = False - self._returncode = None + + # Create the child process: set the _proc attribute self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, **kwargs) self._pid = self._proc.pid self._extra['subprocess'] = self._proc + if self._loop.get_debug(): if isinstance(args, (bytes, str)): program = args @@ -42,6 +48,8 @@ def __init__(self, loop, protocol, args, shell, logger.debug('process %r created: pid %s', program, self._pid) + self._loop.create_task(self._connect_pipes(waiter)) + def __repr__(self): info = [self.__class__.__name__] if self._closed: @@ -77,12 +85,23 @@ def _make_read_subprocess_pipe_proto(self, fd): def close(self): self._closed = True + for proto in self._pipes.values(): if proto is None: continue proto.pipe.close() - if self._returncode is None: - self.terminate() + + if self._proc is not None and self._returncode is None: + if self._loop.get_debug(): + logger.warning('Close running child process: kill %r', self) + + try: + self._proc.kill() + except ProcessLookupError: + pass + + # Don't clear the _proc reference yet because _post_init() may + # still run # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks @@ -114,50 +133,24 @@ def terminate(self): def kill(self): self._proc.kill() - def _kill_wait(self): - """Close pipes, kill the subprocess and read its return status. - - Function called when an exception is raised during the creation - of a subprocess. - """ - self._closed = True - if self._loop.get_debug(): - logger.warning('Exception during subprocess creation, ' - 'kill the subprocess %r', - self, - exc_info=True) - - proc = self._proc - if proc.stdout: - proc.stdout.close() - if proc.stderr: - proc.stderr.close() - if proc.stdin: - proc.stdin.close() - - try: - proc.kill() - except ProcessLookupError: - pass - self._returncode = proc.wait() - - self.close() - @coroutine - def _post_init(self): + def _connect_pipes(self, waiter): try: proc = self._proc loop = self._loop + if proc.stdin is not None: _, pipe = yield from loop.connect_write_pipe( lambda: WriteSubprocessPipeProto(self, 0), proc.stdin) self._pipes[0] = pipe + if proc.stdout is not None: _, pipe = yield from loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 1), proc.stdout) self._pipes[1] = pipe + if proc.stderr is not None: _, pipe = yield from loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 2), @@ -166,13 +159,16 @@ def _post_init(self): assert self._pending_calls is not None - self._loop.call_soon(self._protocol.connection_made, self) + loop.call_soon(self._protocol.connection_made, self) for callback, data in self._pending_calls: - self._loop.call_soon(callback, *data) + loop.call_soon(callback, *data) self._pending_calls = None - except: - self._kill_wait() - raise + except Exception as exc: + if waiter is not None and not waiter.cancelled(): + waiter.set_exception(exc) + else: + if waiter is not None and not waiter.cancelled(): + waiter.set_result(None) def _call(self, cb, *data): if self._pending_calls is not None: @@ -197,6 +193,23 @@ def _process_exited(self, returncode): self._call(self._protocol.process_exited) self._try_finish() + # wake up futures waiting for wait() + for waiter in self._exit_waiters: + if not waiter.cancelled(): + waiter.set_result(returncode) + self._exit_waiters = None + + def wait(self): + """Wait until the process exit and return the process return code. + + This method is a coroutine.""" + if self._returncode is not None: + return self._returncode + + waiter = futures.Future(loop=self._loop) + self._exit_waiters.append(waiter) + return (yield from waiter) + def _try_finish(self): assert not self._finished if self._returncode is None: @@ -210,9 +223,9 @@ def _call_connection_lost(self, exc): try: self._protocol.connection_lost(exc) finally: + self._loop = None self._proc = None self._protocol = None - self._loop = None class WriteSubprocessPipeProto(protocols.BaseProtocol): diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index c848a21a..49b2d7c8 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -25,8 +25,6 @@ def __init__(self, limit, loop): super().__init__(loop=loop) self._limit = limit self.stdin = self.stdout = self.stderr = None - self.waiter = futures.Future(loop=loop) - self._waiters = collections.deque() self._transport = None def __repr__(self): @@ -61,9 +59,6 @@ def connection_made(self, transport): reader=None, loop=self._loop) - if not self.waiter.cancelled(): - self.waiter.set_result(None) - def pipe_data_received(self, fd, data): if fd == 1: reader = self.stdout @@ -94,16 +89,9 @@ def pipe_connection_lost(self, fd, exc): reader.set_exception(exc) def process_exited(self): - returncode = self._transport.get_returncode() self._transport.close() self._transport = None - # wake up futures waiting for wait() - while self._waiters: - waiter = self._waiters.popleft() - if not waiter.cancelled(): - waiter.set_result(returncode) - class Process: def __init__(self, transport, protocol, loop): @@ -124,15 +112,10 @@ def returncode(self): @coroutine def wait(self): - """Wait until the process exit and return the process return code.""" - returncode = self._transport.get_returncode() - if returncode is not None: - return returncode + """Wait until the process exit and return the process return code. - waiter = futures.Future(loop=self._loop) - self._protocol._waiters.append(waiter) - yield from waiter - return waiter.result() + This method is a coroutine.""" + return (yield from self._transport.wait()) def _check_alive(self): if self._transport.get_returncode() is not None: @@ -221,11 +204,6 @@ def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, protocol_factory, cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwds) - try: - yield from protocol.waiter - except: - transport._kill_wait() - raise return Process(transport, protocol, loop) @coroutine @@ -241,9 +219,4 @@ def create_subprocess_exec(program, *args, stdin=None, stdout=None, program, *args, stdin=stdin, stdout=stdout, stderr=stderr, **kwds) - try: - yield from protocol.waiter - except: - transport._kill_wait() - raise return Process(transport, protocol, loop) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index b06f1b23..3ecdfd2e 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -16,6 +16,7 @@ from . import constants from . import coroutines from . import events +from . import futures from . import selector_events from . import selectors from . import transports @@ -175,16 +176,20 @@ def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): with events.get_child_watcher() as watcher: + waiter = futures.Future(loop=self) transp = _UnixSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, - extra=extra, **kwargs) + waiter=waiter, extra=extra, + **kwargs) + + watcher.add_child_handler(transp.get_pid(), + self._child_watcher_callback, transp) try: - yield from transp._post_init() + yield from waiter except: transp.close() + yield from transp.wait() raise - watcher.add_child_handler(transp.get_pid(), - self._child_watcher_callback, transp) return transp @@ -774,7 +779,7 @@ def __exit__(self, a, b, c): pass def add_child_handler(self, pid, callback, *args): - self._callbacks[pid] = callback, args + self._callbacks[pid] = (callback, args) # Prevent a race condition in case the child is already terminated. self._do_waitpid(pid) diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 94aafb6f..437eb0ac 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -366,13 +366,16 @@ def loop_accept_pipe(f=None): def _make_subprocess_transport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, extra=None, **kwargs): + waiter = futures.Future(loop=self) transp = _WindowsSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, - extra=extra, **kwargs) + waiter=waiter, extra=extra, + **kwargs) try: - yield from transp._post_init() + yield from waiter except: transp.close() + yield from transp.wait() raise return transp diff --git a/tests/test_events.py b/tests/test_events.py index 12af62b2..4b957d8f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1551,9 +1551,10 @@ def test_subprocess_exec(self): stdin = transp.get_pipe_transport(0) stdin.write(b'Python The Winner') self.loop.run_until_complete(proto.got_data[1].wait()) - transp.close() + with test_utils.disable_logger(): + transp.close() self.loop.run_until_complete(proto.completed) - self.check_terminated(proto.returncode) + self.check_killed(proto.returncode) self.assertEqual(b'Python The Winner', proto.data[1]) def test_subprocess_interactive(self): @@ -1567,21 +1568,20 @@ def test_subprocess_interactive(self): self.loop.run_until_complete(proto.connected) self.assertEqual('CONNECTED', proto.state) - try: - stdin = transp.get_pipe_transport(0) - stdin.write(b'Python ') - self.loop.run_until_complete(proto.got_data[1].wait()) - proto.got_data[1].clear() - self.assertEqual(b'Python ', proto.data[1]) - - stdin.write(b'The Winner') - self.loop.run_until_complete(proto.got_data[1].wait()) - self.assertEqual(b'Python The Winner', proto.data[1]) - finally: - transp.close() + stdin = transp.get_pipe_transport(0) + stdin.write(b'Python ') + self.loop.run_until_complete(proto.got_data[1].wait()) + proto.got_data[1].clear() + self.assertEqual(b'Python ', proto.data[1]) + stdin.write(b'The Winner') + self.loop.run_until_complete(proto.got_data[1].wait()) + self.assertEqual(b'Python The Winner', proto.data[1]) + + with test_utils.disable_logger(): + transp.close() self.loop.run_until_complete(proto.completed) - self.check_terminated(proto.returncode) + self.check_killed(proto.returncode) def test_subprocess_shell(self): connect = self.loop.subprocess_shell( @@ -1739,9 +1739,10 @@ def test_subprocess_close_client_stream(self): # GetLastError()==ERROR_INVALID_NAME on Windows!?! (Using # WriteFile() we get ERROR_BROKEN_PIPE as expected.) self.assertEqual(b'ERR:OSError', proto.data[2]) - transp.close() + with test_utils.disable_logger(): + transp.close() self.loop.run_until_complete(proto.completed) - self.check_terminated(proto.returncode) + self.check_killed(proto.returncode) def test_subprocess_wait_no_same_group(self): # start the new process in a new session From 2e187bbe06a6a7f1caf7c687e9f9d5e6eed0c592 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 29 Jan 2015 23:34:37 +0100 Subject: [PATCH 1335/1502] Python issue #23347: send_signal(), kill() and terminate() methods of BaseSubprocessTransport now check if the transport was closed and if the process exited. --- asyncio/base_subprocess.py | 9 ++++++ asyncio/subprocess.py | 7 ---- tests/test_subprocess.py | 65 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 7 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index afeda139..001f9b8c 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -124,13 +124,22 @@ def get_pipe_transport(self, fd): else: return None + def _check_proc(self): + if self._closed: + raise ValueError("operation on closed transport") + if self._proc is None: + raise ProcessLookupError() + def send_signal(self, signal): + self._check_proc() self._proc.send_signal(signal) def terminate(self): + self._check_proc() self._proc.terminate() def kill(self): + self._check_proc() self._proc.kill() @coroutine diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 49b2d7c8..d0c9779c 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -117,20 +117,13 @@ def wait(self): This method is a coroutine.""" return (yield from self._transport.wait()) - def _check_alive(self): - if self._transport.get_returncode() is not None: - raise ProcessLookupError() - def send_signal(self, signal): - self._check_alive() self._transport.send_signal(signal) def terminate(self): - self._check_alive() self._transport.terminate() def kill(self): - self._check_alive() self._transport.kill() @coroutine diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index ecc2c9d8..4f197f39 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -4,6 +4,7 @@ from unittest import mock import asyncio +from asyncio import base_subprocess from asyncio import subprocess from asyncio import test_utils try: @@ -23,6 +24,70 @@ 'data = sys.stdin.buffer.read()', 'sys.stdout.buffer.write(data)'))] +class TestSubprocessTransport(base_subprocess.BaseSubprocessTransport): + def _start(self, *args, **kwargs): + self._proc = mock.Mock() + self._proc.stdin = None + self._proc.stdout = None + self._proc.stderr = None + + +class SubprocessTransportTests(test_utils.TestCase): + def setUp(self): + self.loop = self.new_test_loop() + self.set_event_loop(self.loop) + + + def create_transport(self, waiter=None): + protocol = mock.Mock() + protocol.connection_made._is_coroutine = False + protocol.process_exited._is_coroutine = False + transport = TestSubprocessTransport( + self.loop, protocol, ['test'], False, + None, None, None, 0, waiter=waiter) + return (transport, protocol) + + def test_close(self): + waiter = asyncio.Future(loop=self.loop) + transport, protocol = self.create_transport(waiter) + transport._process_exited(0) + transport.close() + + # The loop didn't run yet + self.assertFalse(protocol.connection_made.called) + + # methods must raise ProcessLookupError if the transport was closed + self.assertRaises(ValueError, transport.send_signal, signal.SIGTERM) + self.assertRaises(ValueError, transport.terminate) + self.assertRaises(ValueError, transport.kill) + + self.loop.run_until_complete(waiter) + + def test_proc_exited(self): + waiter = asyncio.Future(loop=self.loop) + transport, protocol = self.create_transport(waiter) + transport._process_exited(6) + self.loop.run_until_complete(waiter) + + self.assertEqual(transport.get_returncode(), 6) + + self.assertTrue(protocol.connection_made.called) + self.assertTrue(protocol.process_exited.called) + self.assertTrue(protocol.connection_lost.called) + self.assertEqual(protocol.connection_lost.call_args[0], (None,)) + + self.assertFalse(transport._closed) + self.assertIsNone(transport._loop) + self.assertIsNone(transport._proc) + self.assertIsNone(transport._protocol) + + # methods must raise ProcessLookupError if the process exited + self.assertRaises(ProcessLookupError, + transport.send_signal, signal.SIGTERM) + self.assertRaises(ProcessLookupError, transport.terminate) + self.assertRaises(ProcessLookupError, transport.kill) + + class SubprocessMixin: def test_stdin_stdout(self): From cffe67de25372fa85ee1799788217bea84ab58cd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 00:09:48 +0100 Subject: [PATCH 1336/1502] tox.ini: enable ResourceWarning warnings --- tox.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 6209ff4e..3030441f 100644 --- a/tox.ini +++ b/tox.ini @@ -8,8 +8,8 @@ deps= setenv = PYTHONASYNCIODEBUG = 1 commands= - python runtests.py -r {posargs} - python run_aiotest.py -r {posargs} + python -Wd runtests.py -r {posargs} + python -Wd run_aiotest.py -r {posargs} [testenv:py3_release] # Run tests in release mode From 787b4945c6fde8d24b4038b493dd6aef2eb0c888 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 00:11:25 +0100 Subject: [PATCH 1337/1502] Fix ResourceWarning in test_subprocess.test_proc_exit() --- tests/test_subprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 4f197f39..d4b71b7a 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -87,6 +87,8 @@ def test_proc_exited(self): self.assertRaises(ProcessLookupError, transport.terminate) self.assertRaises(ProcessLookupError, transport.kill) + transport.close() + class SubprocessMixin: From d37a906ecc49e12b84431d78b4911534040fd503 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 00:15:54 +0100 Subject: [PATCH 1338/1502] Python issue #23347: Make BaseSubprocessTransport.wait() private --- asyncio/base_subprocess.py | 2 +- asyncio/subprocess.py | 2 +- asyncio/unix_events.py | 2 +- asyncio/windows_events.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 001f9b8c..70676ab3 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -208,7 +208,7 @@ def _process_exited(self, returncode): waiter.set_result(returncode) self._exit_waiters = None - def wait(self): + def _wait(self): """Wait until the process exit and return the process return code. This method is a coroutine.""" diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index d0c9779c..4600a9f4 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -115,7 +115,7 @@ def wait(self): """Wait until the process exit and return the process return code. This method is a coroutine.""" - return (yield from self._transport.wait()) + return (yield from self._transport._wait()) def send_signal(self, signal): self._transport.send_signal(signal) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 3ecdfd2e..1fc39abe 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -188,7 +188,7 @@ def _make_subprocess_transport(self, protocol, args, shell, yield from waiter except: transp.close() - yield from transp.wait() + yield from transp._wait() raise return transp diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index 437eb0ac..c4bffc47 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -375,7 +375,7 @@ def _make_subprocess_transport(self, protocol, args, shell, yield from waiter except: transp.close() - yield from transp.wait() + yield from transp._wait() raise return transp From 6c7490e724d1bd64dafa7d4b773302fa58971632 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 01:12:00 +0100 Subject: [PATCH 1339/1502] Python issue #23347: send_signal(), terminate(), kill() don't check if the transport was closed. The check broken a Tulip example and this limitation is arbitrary. Check if _proc is None should be enough. Enhance also close(): do nothing when called the second time. --- asyncio/base_subprocess.py | 7 +++---- tests/test_subprocess.py | 16 ---------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 70676ab3..02b9e89f 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -84,6 +84,8 @@ def _make_read_subprocess_pipe_proto(self, fd): raise NotImplementedError def close(self): + if self._closed: + return self._closed = True for proto in self._pipes.values(): @@ -100,8 +102,7 @@ def close(self): except ProcessLookupError: pass - # Don't clear the _proc reference yet because _post_init() may - # still run + # Don't clear the _proc reference yet: _post_init() may still run # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks @@ -125,8 +126,6 @@ def get_pipe_transport(self, fd): return None def _check_proc(self): - if self._closed: - raise ValueError("operation on closed transport") if self._proc is None: raise ProcessLookupError() diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index d4b71b7a..b467b04f 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -47,22 +47,6 @@ def create_transport(self, waiter=None): None, None, None, 0, waiter=waiter) return (transport, protocol) - def test_close(self): - waiter = asyncio.Future(loop=self.loop) - transport, protocol = self.create_transport(waiter) - transport._process_exited(0) - transport.close() - - # The loop didn't run yet - self.assertFalse(protocol.connection_made.called) - - # methods must raise ProcessLookupError if the transport was closed - self.assertRaises(ValueError, transport.send_signal, signal.SIGTERM) - self.assertRaises(ValueError, transport.terminate) - self.assertRaises(ValueError, transport.kill) - - self.loop.run_until_complete(waiter) - def test_proc_exited(self): waiter = asyncio.Future(loop=self.loop) transport, protocol = self.create_transport(waiter) From ab5d83ca3344320f9292da3192cb0729ba144b13 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 01:16:07 +0100 Subject: [PATCH 1340/1502] Fix a ResourceWarning in the shell example Kill the process on timeout: don't keep a running process in the backgroud! --- examples/shell.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/shell.py b/examples/shell.py index 7dc7caf3..f9343256 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -36,12 +36,14 @@ def ls(loop): @asyncio.coroutine def test_call(*args, timeout=None): + proc = yield from asyncio.create_subprocess_exec(*args) try: - proc = yield from asyncio.create_subprocess_exec(*args) exitcode = yield from asyncio.wait_for(proc.wait(), timeout) print("%s: exit code %s" % (' '.join(args), exitcode)) except asyncio.TimeoutError: print("timeout! (%.1f sec)" % timeout) + proc.kill() + yield from proc.wait() loop = asyncio.get_event_loop() loop.run_until_complete(cat(loop)) From eec5196ad887c529216485f9846f393ea835d629 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 30 Jan 2015 01:20:08 +0100 Subject: [PATCH 1341/1502] Fix subprocess_attach_write_pipe example Close the transport, not directly the pipe. --- examples/subprocess_attach_write_pipe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py index 86148774..c4e099f6 100644 --- a/examples/subprocess_attach_write_pipe.py +++ b/examples/subprocess_attach_write_pipe.py @@ -29,7 +29,7 @@ def task(): stdout, stderr = yield from proc.communicate() print("stdout = %r" % stdout.decode()) - pipe.close() + transport.close() loop.run_until_complete(task()) loop.close() From aeb1824eb0a135d345fcd30d8d2b78e047562fc7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 2 Feb 2015 17:44:23 +0100 Subject: [PATCH 1342/1502] Workaround CPython bug #23353 Don't use yield/yield-from in an except block of a generator. Store the exception and handle it outside the except block. --- asyncio/test_utils.py | 4 ++++ asyncio/unix_events.py | 12 ++++++++++-- asyncio/windows_events.py | 11 +++++++++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/asyncio/test_utils.py b/asyncio/test_utils.py index 6eedc583..8cee95b8 100644 --- a/asyncio/test_utils.py +++ b/asyncio/test_utils.py @@ -416,6 +416,10 @@ def new_test_loop(self, gen=None): def tearDown(self): events.set_event_loop(None) + # Detect CPython bug #23353: ensure that yield/yield-from is not used + # in an except block of a generator + self.assertEqual(sys.exc_info(), (None, None, None)) + @contextlib.contextmanager def disable_logger(): diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 1fc39abe..75e7c9cc 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -186,10 +186,18 @@ def _make_subprocess_transport(self, protocol, args, shell, self._child_watcher_callback, transp) try: yield from waiter - except: + except Exception as exc: + # Workaround CPython bug #23353: using yield/yield-from in an + # except block of a generator doesn't clear properly + # sys.exc_info() + err = exc + else: + err = None + + if err is not None: transp.close() yield from transp._wait() - raise + raise err return transp diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index c4bffc47..f311e463 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -373,10 +373,17 @@ def _make_subprocess_transport(self, protocol, args, shell, **kwargs) try: yield from waiter - except: + except Exception as exc: + # Workaround CPython bug #23353: using yield/yield-from in an + # except block of a generator doesn't clear properly sys.exc_info() + err = exc + else: + err = None + + if err is not None: transp.close() yield from transp._wait() - raise + raise err return transp From b9d37994e3ae66035e06d0a89607794b601bf7e1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 3 Feb 2015 15:07:25 +0100 Subject: [PATCH 1343/1502] Tulip issue #221: Fix docstring of QueueEmpty and QueueFull --- asyncio/queues.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index dce0d53c..4aeb6c45 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -13,12 +13,16 @@ class QueueEmpty(Exception): - 'Exception raised by Queue.get(block=0)/get_nowait().' + """Exception raised when Queue.get_nowait() is called on a Queue object + which is empty. + """ pass class QueueFull(Exception): - 'Exception raised by Queue.put(block=0)/put_nowait().' + """Exception raised when the Queue.put_nowait() method is called on a Queue + object which is full. + """ pass From d3dbdf8530ec52d42145610cac4e2f981f72b1fa Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Feb 2015 11:25:51 +0100 Subject: [PATCH 1344/1502] BaseSelectorEventLoop uses directly the private _debug attribute Just try to be consistent: _debug was already used in some places, and always used in BaseProactorEventLoop. --- asyncio/selector_events.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 4bd6dc8d..7cbd4fd1 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -214,7 +214,7 @@ def _accept_connection2(self, protocol_factory, conn, extra, # It's now up to the protocol to handle the connection. except Exception as exc: - if self.get_debug(): + if self._debug: context = { 'message': ('Error on transport creation ' 'for incoming connection'), @@ -312,7 +312,7 @@ def sock_recv(self, sock, n): This method is a coroutine. """ - if self.get_debug() and sock.gettimeout() != 0: + if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_recv(fut, False, sock, n) @@ -350,7 +350,7 @@ def sock_sendall(self, sock, data): This method is a coroutine. """ - if self.get_debug() and sock.gettimeout() != 0: + if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) if data: @@ -393,7 +393,7 @@ def sock_connect(self, sock, address): This method is a coroutine. """ - if self.get_debug() and sock.gettimeout() != 0: + if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) try: @@ -453,7 +453,7 @@ def sock_accept(self, sock): This method is a coroutine. """ - if self.get_debug() and sock.gettimeout() != 0: + if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) self._sock_accept(fut, False, sock) From 0105a8ff79590054e9c30948a1ef7cdb0faa4cef Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Feb 2015 11:34:44 +0100 Subject: [PATCH 1345/1502] Only call _check_resolved_address() in debug mode * _check_resolved_address() is implemented with getaddrinfo() which is slow * If available, use socket.inet_pton() instead of socket.getaddrinfo(), because it is much faster Microbenchmark (timeit) on Fedora 21 (Python 3.4, Linux 3.17, glibc 2.20) to validate the IPV4 address "127.0.0.1" or the IPv6 address "::1": * getaddrinfo() 10.4 usec per loop * inet_pton(): 0.285 usec per loop On glibc older than 2.14, getaddrinfo() always requests the list of all local IP addresses to the kernel (using a NETLINK socket). getaddrinfo() has other known issues, it's better to avoid it when it is possible. --- asyncio/base_events.py | 48 +++++++++++++++++++++++++------------- asyncio/proactor_events.py | 3 ++- asyncio/selector_events.py | 3 ++- tests/test_events.py | 4 ++++ 4 files changed, 40 insertions(+), 18 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 7108f251..5c397543 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -75,7 +75,11 @@ class _StopError(BaseException): def _check_resolved_address(sock, address): # Ensure that the address is already resolved to avoid the trap of hanging # the entire event loop when the address requires doing a DNS lookup. + # + # getaddrinfo() is slow (around 10 us per call): this function should only + # be called in debug mode family = sock.family + if family == socket.AF_INET: host, port = address elif family == socket.AF_INET6: @@ -83,22 +87,34 @@ def _check_resolved_address(sock, address): else: return - type_mask = 0 - if hasattr(socket, 'SOCK_NONBLOCK'): - type_mask |= socket.SOCK_NONBLOCK - if hasattr(socket, 'SOCK_CLOEXEC'): - type_mask |= socket.SOCK_CLOEXEC - # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is - # already resolved. - try: - socket.getaddrinfo(host, port, - family=family, - type=(sock.type & ~type_mask), - proto=sock.proto, - flags=socket.AI_NUMERICHOST) - except socket.gaierror as err: - raise ValueError("address must be resolved (IP address), got %r: %s" - % (address, err)) + # On Windows, socket.inet_pton() is only available since Python 3.4 + if hasattr(socket, 'inet_pton'): + # getaddrinfo() is slow and has known issue: prefer inet_pton() + # if available + try: + socket.inet_pton(family, host) + except OSError as exc: + raise ValueError("address must be resolved (IP address), " + "got host %r: %s" + % (host, exc)) + else: + # Use getaddrinfo(flags=AI_NUMERICHOST) to ensure that the address is + # already resolved. + type_mask = 0 + if hasattr(socket, 'SOCK_NONBLOCK'): + type_mask |= socket.SOCK_NONBLOCK + if hasattr(socket, 'SOCK_CLOEXEC'): + type_mask |= socket.SOCK_CLOEXEC + try: + socket.getaddrinfo(host, port, + family=family, + type=(sock.type & ~type_mask), + proto=sock.proto, + flags=socket.AI_NUMERICHOST) + except socket.gaierror as err: + raise ValueError("address must be resolved (IP address), " + "got host %r: %s" + % (host, err)) def _raise_stop_error(*args): raise _StopError diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 65de926b..9c2b8f15 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -437,7 +437,8 @@ def sock_sendall(self, sock, data): def sock_connect(self, sock, address): try: - base_events._check_resolved_address(sock, address) + if self._debug: + base_events._check_resolved_address(sock, address) except ValueError as err: fut = futures.Future(loop=self) fut.set_exception(err) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7cbd4fd1..a38ed1ce 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -397,7 +397,8 @@ def sock_connect(self, sock, address): raise ValueError("the socket must be non-blocking") fut = futures.Future(loop=self) try: - base_events._check_resolved_address(sock, address) + if self._debug: + base_events._check_resolved_address(sock, address) except ValueError as err: fut.set_exception(err) else: diff --git a/tests/test_events.py b/tests/test_events.py index 4b957d8f..8fbba8fe 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1437,6 +1437,10 @@ def wait(): 'selector': self.loop._selector.__class__.__name__}) def test_sock_connect_address(self): + # In debug mode, sock_connect() must ensure that the address is already + # resolved (call _check_resolved_address()) + self.loop.set_debug(True) + addresses = [(socket.AF_INET, ('www.python.org', 80))] if support.IPV6_ENABLED: addresses.extend(( From 7b2d8abfce1d7ef18ef516f9b1b7032172630375 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 4 Feb 2015 22:08:17 +0100 Subject: [PATCH 1346/1502] BaseEventLoop: rename _owner to _thread_id --- asyncio/base_events.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5c397543..eb867cd5 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -188,7 +188,7 @@ def __init__(self): self._internal_fds = 0 # Identifier of the thread running the event loop, or None if the # event loop is not running - self._owner = None + self._thread_id = None self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None self._debug = (not sys.flags.ignore_environment @@ -269,7 +269,7 @@ def run_forever(self): self._check_closed() if self.is_running(): raise RuntimeError('Event loop is running.') - self._owner = threading.get_ident() + self._thread_id = threading.get_ident() try: while True: try: @@ -277,7 +277,7 @@ def run_forever(self): except _StopError: break finally: - self._owner = None + self._thread_id = None def run_until_complete(self, future): """Run until the Future is done. @@ -362,7 +362,7 @@ def __del__(self): def is_running(self): """Returns True if the event loop is running.""" - return (self._owner is not None) + return (self._thread_id is not None) def time(self): """Return the time according to the event loop's clock. @@ -449,10 +449,10 @@ def _check_thread(self): Should only be called when (self._debug == True). The caller is responsible for checking this condition for performance reasons. """ - if self._owner is None: + if self._thread_id is None: return thread_id = threading.get_ident() - if thread_id != self._owner: + if thread_id != self._thread_id: raise RuntimeError( "Non-thread-safe operation invoked on an event loop other " "than the current one") From d496a7e1ca3c029c2f349f90b9ea84f7d5ae9381 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Feb 2015 14:37:37 +0100 Subject: [PATCH 1347/1502] BaseSubprocessTransport.close() doesn't try to kill the process if it already finished --- asyncio/base_subprocess.py | 7 ++++- tests/test_subprocess.py | 55 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 02b9e89f..5458ab15 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -93,7 +93,12 @@ def close(self): continue proto.pipe.close() - if self._proc is not None and self._returncode is None: + if (self._proc is not None + # the child process finished? + and self._returncode is None + # the child process finished but the transport was not notified yet? + and self._proc.poll() is None + ): if self._loop.get_debug(): logger.warning('Close running child process: kill %r', self) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b467b04f..de0b08af 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -349,6 +349,61 @@ def cancel_make_transport(): self.loop.run_until_complete(cancel_make_transport()) test_utils.run_briefly(self.loop) + def test_close_kill_running(self): + @asyncio.coroutine + def kill_running(): + create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + transport, protocol = yield from create + proc = transport.get_extra_info('subprocess') + proc.kill = mock.Mock() + returncode = transport.get_returncode() + transport.close() + return (returncode, proc.kill.called) + + # Ignore "Close running child process: kill ..." log + with test_utils.disable_logger(): + returncode, killed = self.loop.run_until_complete(kill_running()) + self.assertIsNone(returncode) + + # transport.close() must kill the process if it is still running + self.assertTrue(killed) + test_utils.run_briefly(self.loop) + + def test_close_dont_kill_finished(self): + @asyncio.coroutine + def kill_running(): + create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, + *PROGRAM_BLOCKED) + transport, protocol = yield from create + proc = transport.get_extra_info('subprocess') + + # kill the process (but asyncio is not notified immediatly) + proc.kill() + proc.wait() + + proc.kill = mock.Mock() + proc_returncode = proc.poll() + transport_returncode = transport.get_returncode() + transport.close() + return (proc_returncode, transport_returncode, proc.kill.called) + + # Ignore "Unknown child process pid ..." log of SafeChildWatcher, + # emitted because the test already consumes the exit status: + # proc.wait() + with test_utils.disable_logger(): + result = self.loop.run_until_complete(kill_running()) + test_utils.run_briefly(self.loop) + + proc_returncode, transport_return_code, killed = result + + self.assertIsNotNone(proc_returncode) + self.assertIsNone(transport_return_code) + + # transport.close() must not kill the process if it finished, even if + # the transport was not notified yet + self.assertFalse(killed) + if sys.platform != 'win32': # Unix From 10a91fec4dcbd4b7e2be1331184bc8b92274582c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 12 Feb 2015 16:15:20 +0100 Subject: [PATCH 1348/1502] BaseSubprocessTransport: repr() mentions when the child process is running --- asyncio/base_subprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index 5458ab15..f56873fb 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -57,6 +57,8 @@ def __repr__(self): info.append('pid=%s' % self._pid) if self._returncode is not None: info.append('returncode=%s' % self._returncode) + else: + info.append('running') stdin = self._pipes.get(0) if stdin is not None: From d9c50ebe146984cd3e14d27863133fe7d8d4cc6f Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sat, 14 Feb 2015 09:38:17 -0500 Subject: [PATCH 1349/1502] Tulip issue #220: Restore JoinableQueue as a deprecated alias for Queue. To more closely match the standard Queue, asyncio.Queue has "join" and "task_done". JoinableQueue remains as a deprecated alias for Queue to avoid needlessly breaking too much code that depended on it. --- asyncio/queues.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index b0fb3873..8680d58c 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -1,6 +1,7 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty', + 'JoinableQueue'] import collections import heapq @@ -275,3 +276,7 @@ def _put(self, item): def _get(self): return self._queue.pop() + + +JoinableQueue = Queue +"""Deprecated alias for Queue.""" From b60bad4ec86069d0a8edd0209b033a9afcc5d25a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 17 Feb 2015 22:48:49 +0100 Subject: [PATCH 1350/1502] tests: Use os.devnull instead of hardcoded '/dev/null'. Patch written by Serhiy Storchaka . --- tests/test_unix_events.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 41249ff0..dc0835c5 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -295,7 +295,7 @@ def test_create_unix_server_bind_error(self, m_socket): def test_create_unix_connection_path_sock(self): coro = self.loop.create_unix_connection( - lambda: None, '/dev/null', sock=object()) + lambda: None, os.devnull, sock=object()) with self.assertRaisesRegex(ValueError, 'path and sock can not be'): self.loop.run_until_complete(coro) @@ -308,14 +308,14 @@ def test_create_unix_connection_nopath_nosock(self): def test_create_unix_connection_nossl_serverhost(self): coro = self.loop.create_unix_connection( - lambda: None, '/dev/null', server_hostname='spam') + lambda: None, os.devnull, server_hostname='spam') with self.assertRaisesRegex(ValueError, 'server_hostname is only meaningful'): self.loop.run_until_complete(coro) def test_create_unix_connection_ssl_noserverhost(self): coro = self.loop.create_unix_connection( - lambda: None, '/dev/null', ssl=True) + lambda: None, os.devnull, ssl=True) with self.assertRaisesRegex( ValueError, 'you have to pass server_hostname when using ssl'): From c18c4d06bb6501e7192aa93dbe8af7de44c95386 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 17 Feb 2015 22:45:35 +0100 Subject: [PATCH 1351/1502] Python issue #23475: Fix test_close_kill_running() Really kill the child process, don't mock completly the Popen.kill() method. This change fix memory leaks and reference leaks. --- tests/test_subprocess.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index de0b08af..92bf1b45 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -355,11 +355,19 @@ def kill_running(): create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, *PROGRAM_BLOCKED) transport, protocol = yield from create + + kill_called = False + def kill(): + nonlocal kill_called + kill_called = True + orig_kill() + proc = transport.get_extra_info('subprocess') - proc.kill = mock.Mock() + orig_kill = proc.kill + proc.kill = kill returncode = transport.get_returncode() transport.close() - return (returncode, proc.kill.called) + return (returncode, kill_called) # Ignore "Close running child process: kill ..." log with test_utils.disable_logger(): From e147802bca02ca103264c5a10fb84457cd12abbf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 17 Feb 2015 23:34:16 +0100 Subject: [PATCH 1352/1502] Fix warning in test_close_kill_running() Read process exit status to avoid the "Caught subprocess termination from unknown pid" message. --- tests/test_subprocess.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 92bf1b45..5ccdafb1 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -367,6 +367,7 @@ def kill(): proc.kill = kill returncode = transport.get_returncode() transport.close() + yield from transport._wait() return (returncode, kill_called) # Ignore "Close running child process: kill ..." log From 804962515bb9e0ae6257194743f0a9eea2982107 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Feb 2015 17:45:04 +0100 Subject: [PATCH 1353/1502] Python issue #23537: Remove 2 unused private methods of BaseSubprocessTransport Methods only raise NotImplementedError and are never used. --- asyncio/base_subprocess.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index f56873fb..c1cdfda7 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -79,12 +79,6 @@ def __repr__(self): def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): raise NotImplementedError - def _make_write_subprocess_pipe_proto(self, fd): - raise NotImplementedError - - def _make_read_subprocess_pipe_proto(self, fd): - raise NotImplementedError - def close(self): if self._closed: return From 22a9854f02928049672ed615fbdaeb89c3b16132 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Mar 2015 15:03:57 +0100 Subject: [PATCH 1354/1502] Added tag 3.4.3 for changeset 122233297cfd From f04f680979942b99c104d6eeea6b9b43320fb592 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Mar 2015 15:12:58 +0100 Subject: [PATCH 1355/1502] Set version to 3.4.4 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 2581bfda..3e06b5b7 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ setup( name="asyncio", - version="3.4.3", + version="3.4.4", description="reference implementation of PEP 3156", long_description=long_description, From e0bf949b1202883e4bd525fbdd9595558b3a634f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Mar 2015 16:23:11 +0100 Subject: [PATCH 1356/1502] Write Tulip 3.4.3 changelog --- ChangeLog | 147 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 147 insertions(+) diff --git a/ChangeLog b/ChangeLog index e483fb2d..25155a98 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,149 @@ +2015-02-04: Tulip 3.4.3 +======================= + +Major changes +------------- + +* New SSL implementation using ssl.MemoryBIO. The new implementation requires + Python 3.5 and newer, otherwise the legacy implementation is used. +* On Python 3.5 and newer usable, the ProactorEventLoop now supports SSL + thanks to the new SSL implementation. +* Fix multiple resource leaks: close sockets on error, explicitly clear + references, emit ResourceWarning when event loops and transports are not + closed explicitly, etc. +* The proactor event loop is now much more reliable (no more known race + condition). +* Enhance handling of task cancellation. + +Changes of the asyncio API +-------------------------- + +* Export BaseEventLoop symbol in the asyncio namespace +* create_task(), call_soon(), call_soon_threadsafe(), call_later(), + call_at() and run_in_executor() methods of BaseEventLoop now raise an + exception if the event loop is closed. +* call_soon(), call_soon_threadsafe(), call_later(), call_at() and + run_in_executor() methods of BaseEventLoop now raise an exception if the + callback is a coroutine object. +* BaseEventLoopPolicy.get_event_loop() now always raises a RuntimeError + if there is no event loop in the curren thread, instead of using an + assertion (which can be disabld at runtime) and so raises an AssertionError. +* selectors: Selector.get_key() now raises an exception if the selector is + closed. +* If wait_for() is cancelled, the waited task is also cancelled. +* _UnixSelectorEventLoop.add_signal_handler() now raises an exception if + the callback is a coroutine object or a coroutine function. It also raises + an exception if the event loop is closed. + +Performances +------------ + +* sock_connect() doesn't check if the address is already resolved anymore. + The check is only done in debug mode. Moreover, the check uses inet_pton() + instead of getaddrinfo(), if inet_pton() is available, because getaddrinfo() + is slow (around 10 us per call). + +Debug +----- + +* Better repr() of _ProactorBasePipeTransport, _SelectorTransport, + _UnixReadPipeTransport and _UnixWritePipeTransport: add closed/closing + status and the file descriptor +* Add repr(PipeHandle) +* PipeHandle destructor now emits a ResourceWarning is the pipe is not closed + explicitly. +* In debug mode, call_at() method of BaseEventLoop now raises an exception + if called from the wrong thread (not from the thread running the event + loop). Before, it only raised an exception if current thread had an event + loop. +* A ResourceWarning is now emitted when event loops and transports are + destroyed before being closed. +* BaseEventLoop.call_exception_handler() now logs the traceback where + the current handle was created (if no source_traceback was specified). +* BaseSubprocessTransport.close() now logs a warning when the child process is + still running and the method kills it. + +Bug fixes +--------- + +* windows_utils.socketpair() now reuses socket.socketpair() if available + (Python 3.5 or newer). +* Fix IocpProactor.accept_pipe(): handle ERROR_PIPE_CONNECTED, it means + that the pipe is connected. _overlapped.Overlapped.ConnectNamedPipe() now + returns True on ERROR_PIPE_CONNECTED. +* Rewrite IocpProactor.connect_pipe() using polling to avoid tricky bugs + if the connection is cancelled, instead of using QueueUserWorkItem() to run + blocking code. +* Fix IocpProactor.recv(): handle BrokenPipeError, set the result to an empty + string. +* Fix ProactorEventLoop.start_serving_pipe(): if a client connected while the + server is closing, drop the client connection. +* Fix a tricky race condition when IocpProactor.wait_for_handle() is + cancelled: wait until the wait is really cancelled before destroying the + overlapped object. Unregister also the overlapped operation to not block + in IocpProactor.close() before the wait will never complete. +* Fix _UnixSubprocessTransport._start(): make the write end of the stdin pipe + non-inheritable. +* Set more attributes in the body of classes to avoid attribute errors in + destructors if an error occurred in the constructor. +* Fix SubprocessStreamProtocol.process_exited(): close the transport + and clear its reference to the transport. +* Fix SubprocessStreamProtocol.connection_made(): set the transport of + stdout and stderr streams to respect reader buffer limits (stop reading when + the buffer is full). +* Fix FlowControlMixin constructor: if the loop parameter is None, get the + current event loop. +* Fix selectors.EpollSelector.select(): don't fail anymore if no file + descriptor is registered. +* Fix _SelectorTransport: don't wakeup the waiter if it was cancelled +* Fix _SelectorTransport._call_connection_lost(): only call connection_lost() + if connection_made() was already called. +* Fix BaseSelectorEventLoop._accept_connection(): close the transport on + error. In debug mode, log errors (ex: SSL handshake failure) on the creation + of the transport for incoming connection. +* Fix BaseProactorEventLoop.close(): stop the proactor before closing the + event loop because stopping the proactor may schedule new callbacks, which + is now forbidden when the event loop is closed. +* Fix wrap_future() to not use a free variable and so not keep a frame alive + too long. +* Fix formatting of the "Future/Task exception was never retrieved" log: add + a newline before the traceback. +* WriteSubprocessPipeProto.connection_lost() now clears its reference to the + subprocess.Popen object. +* If the creation of a subprocess transport fails, the child process is killed + and the event loop waits asynchronously for its completion. +* BaseEventLoop.run_until_complete() now consumes the exception to not log a + warning when a BaseException like KeyboardInterrupt is raised and + run_until_complete() is not a future (but a coroutine object). +* create_connection(), create_datagram_endpoint(), connect_read_pipe() and + connect_write_pipe() methods of BaseEventLoop now close the transport on + error. + +Other changes +------------- + +* Add tox.ini to run tests using tox. +* _FlowControlMixin constructor now requires an event loop. +* Embed asyncio/test_support.py to not depend on test.support of the system + Python. For example, test.support is not installed by default on Windows. +* selectors.Selector.close() now clears its reference to the mapping object. +* _SelectorTransport and _UnixWritePipeTransport now only starts listening for + read events after protocol.connection_made() has been called +* _SelectorTransport._fatal_error() now only logs ConnectionAbortedError + in debug mode. +* BaseProactorEventLoop._loop_self_reading() now handles correctly + CancelledError (just exit) and logs an error for other exceptions. +* _ProactorBasePipeTransport now clears explicitly references to read and + write future and to the socket +* BaseSubprocessTransport constructor now calls the internal _connect_pipes() + method (previously called _post_init()). The constructor now accepts an + optional waiter parameter to notify when the transport is ready. +* send_signal(), terminate() and kill() methods of BaseSubprocessTransport now + raise a ProcessLookupError if the process already exited. +* Add run_aiotest.py to run the aiotest test suite +* Add release.py script to build wheel packages on Windows and run unit tests + + 2014-09-30: Tulip 3.4.2 ======================= @@ -156,6 +302,7 @@ Misc changes * runtests.py now mention if tests are running in release or debug mode. + 2014-05-19: Tulip 3.4.1 ======================= From 26c6fedd884e7b61ac397105d65e2d27984406f4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 10 Mar 2015 16:29:12 +0100 Subject: [PATCH 1357/1502] Fix repr(BaseSubprocessTransport) if it didn't start yet Replace "running" with "not started" and don't show the pid if the subprocess didn't start yet. --- asyncio/base_subprocess.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index c1cdfda7..d18f3e8f 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -54,11 +54,14 @@ def __repr__(self): info = [self.__class__.__name__] if self._closed: info.append('closed') - info.append('pid=%s' % self._pid) + if self._pid is not None: + info.append('pid=%s' % self._pid) if self._returncode is not None: info.append('returncode=%s' % self._returncode) - else: + elif self._pid is not None: info.append('running') + else: + info.append('not started') stdin = self._pipes.get(0) if stdin is not None: From 16b312187af2a3dfab2e5030076e8b8836b2044d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 18 Mar 2015 11:36:45 +0100 Subject: [PATCH 1358/1502] Python issue #23456: Add missing @coroutine decorators --- asyncio/base_subprocess.py | 1 + asyncio/locks.py | 3 +++ asyncio/streams.py | 1 + 3 files changed, 5 insertions(+) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index d18f3e8f..c1477b82 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -211,6 +211,7 @@ def _process_exited(self, returncode): waiter.set_result(returncode) self._exit_waiters = None + @coroutine def _wait(self): """Wait until the process exit and return the process return code. diff --git a/asyncio/locks.py b/asyncio/locks.py index b943e9dd..41a68c6c 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -162,6 +162,7 @@ def __exit__(self, *args): # always raises; that's how the with-statement works. pass + @coroutine def __iter__(self): # This is not a coroutine. It is meant to enable the idiom: # @@ -362,6 +363,7 @@ def __enter__(self): def __exit__(self, *args): pass + @coroutine def __iter__(self): # See comment in Lock.__iter__(). yield from self.acquire() @@ -446,6 +448,7 @@ def __enter__(self): def __exit__(self, *args): pass + @coroutine def __iter__(self): # See comment in Lock.__iter__(). yield from self.acquire() diff --git a/asyncio/streams.py b/asyncio/streams.py index 7ff16a48..64ff3d2e 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -378,6 +378,7 @@ def feed_data(self, data): else: self._paused = True + @coroutine def _wait_for_data(self, func_name): """Wait until feed_data() or feed_eof() is called.""" # StreamReader uses a future to link the protocol feed_data() method From 15a569d29b6ea60ebf9716117702e22d87ae638d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 27 Mar 2015 15:19:22 +0100 Subject: [PATCH 1359/1502] Fix _SelectorTransport.__repr__() if the event loop is closed --- asyncio/selector_events.py | 2 +- tests/test_selector_events.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index a38ed1ce..68e9415e 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -535,7 +535,7 @@ def __repr__(self): info.append('closing') info.append('fd=%s' % self._sock_fd) # test if the transport was closed - if self._loop is not None: + if self._loop is not None and not self._loop.is_closed(): polling = _test_selector_event(self._loop._selector, self._sock_fd, selectors.EVENT_READ) if polling: diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index f64e40da..9478b954 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -62,6 +62,11 @@ def test_make_socket_transport(self): self.loop.add_reader._is_coroutine = False transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) + + # Calling repr() must not fail when the event loop is closed + self.loop.close() + repr(transport) + close_transport(transport) @unittest.skipIf(ssl is None, 'No ssl module') From b26e6ae21989993e48168f7c5a5db57d496114d2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Apr 2015 21:36:27 +0200 Subject: [PATCH 1360/1502] Python issue #23879: SelectorEventLoop.sock_connect() must not call connect() again if the first call to connect() raises an InterruptedError. When the C function connect() fails with EINTR, the connection runs in background. We have to wait until the socket becomes writable to be notified when the connection succeed or fails. --- asyncio/selector_events.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 68e9415e..7c5b9b5b 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -408,14 +408,12 @@ def sock_connect(self, sock, address): def _sock_connect(self, fut, sock, address): fd = sock.fileno() try: - while True: - try: - sock.connect(address) - except InterruptedError: - continue - else: - break - except BlockingIOError: + sock.connect(address) + except (BlockingIOError, InterruptedError): + # Issue #23618: When the C function connect() fails with EINTR, the + # connection runs in background. We have to wait until the socket + # becomes writable to be notified when the connection succeed or + # fails. fut.add_done_callback(functools.partial(self._sock_connect_done, fd)) self.add_writer(fd, self._sock_connect_cb, fut, sock, address) From bf4b2cea4f2f8195b658ce29279d8f5c5bcc9e2e Mon Sep 17 00:00:00 2001 From: "Ludovic Gasc (GMLudo)" Date: Sat, 11 Apr 2015 10:08:27 -0400 Subject: [PATCH 1361/1502] Rename README file to have rst render on Github --- README => README.rst | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename README => README.rst (96%) diff --git a/README b/README.rst similarity index 96% rename from README rename to README.rst index 2f3150a2..f20c1c0c 100644 --- a/README +++ b/README.rst @@ -2,7 +2,7 @@ Tulip is the codename for my reference implementation of PEP 3156. PEP 3156: http://www.python.org/dev/peps/pep-3156/ -*** This requires Python 3.3 or later! *** +**This requires Python 3.3 or later!** Copyright/license: Open source, Apache 2.0. Enjoy. diff --git a/setup.py b/setup.py index 3e06b5b7..93cacdd3 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ ) extensions.append(ext) -with open("README") as fp: +with open("README.rst") as fp: long_description = fp.read() setup( From 1888b1dd776ef3914fc1bb5035216a6f98b9b721 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 11 Apr 2015 10:30:29 -0400 Subject: [PATCH 1362/1502] Switch hgignore and hgeol to git equivalents --- .gitattributes | 2 ++ .gitignore | 15 +++++++++++++++ .hgeol | 4 ---- .hgignore | 15 --------------- 4 files changed, 17 insertions(+), 19 deletions(-) create mode 100644 .gitattributes create mode 100644 .gitignore delete mode 100644 .hgeol delete mode 100644 .hgignore diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..648632ce --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +* text=auto +*.py text diff=python diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..2f917d8d --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +*\.py[co] +*~ +*\.orig +*\#.* +*@.* +.coverage +htmlcov +.DS_Store +venv +distribute_setup.py +distribute-*.tar.gz +build +dist +*.egg-info +.tox diff --git a/.hgeol b/.hgeol deleted file mode 100644 index 8233b6dd..00000000 --- a/.hgeol +++ /dev/null @@ -1,4 +0,0 @@ -[patterns] -** = native -.hgignore = native -.hgeol = native diff --git a/.hgignore b/.hgignore deleted file mode 100644 index 736c7fdf..00000000 --- a/.hgignore +++ /dev/null @@ -1,15 +0,0 @@ -.*\.py[co]$ -.*~$ -.*\.orig$ -.*\#.*$ -.*@.*$ -\.coverage$ -htmlcov$ -\.DS_Store$ -venv$ -distribute_setup.py$ -distribute-\d+.\d+.\d+.tar.gz$ -build$ -dist$ -.*\.egg-info$ -\.tox$ From 30f4788525cce647c99aa06b2805d4c6b98692c6 Mon Sep 17 00:00:00 2001 From: "Ludovic Gasc (GMLudo)" Date: Sat, 11 Apr 2015 11:30:49 -0400 Subject: [PATCH 1363/1502] add in .gitignore pyvenv and Pycharm files --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 2f917d8d..16fae2b7 100644 --- a/.gitignore +++ b/.gitignore @@ -7,9 +7,12 @@ htmlcov .DS_Store venv +pyvenv distribute_setup.py distribute-*.tar.gz build dist *.egg-info .tox +.idea/ +*.iml \ No newline at end of file From 173ff866346d8210acfa268239aa99d907fc38a2 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Mon, 13 Apr 2015 19:36:50 -0400 Subject: [PATCH 1364/1502] Update README.rst Update repo reference from google code to github --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index f20c1c0c..329abddd 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,7 @@ PEP 3156: http://www.python.org/dev/peps/pep-3156/ Copyright/license: Open source, Apache 2.0. Enjoy. -Master Mercurial repo: http://code.google.com/p/tulip/ +Master GitHub repo: https://github.com/python/tulip The actual code lives in the 'asyncio' subdirectory. Tests are in the 'tests' subdirectory. From b08ee4008dc939fff3aa86f040da02deb602f5db Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 14 Apr 2015 16:38:17 +0200 Subject: [PATCH 1365/1502] Fix @coroutine for functions without __name__ Issue #222: Fix the @coroutine decorator for functions without __name__ attribute like functools.partial(). Enhance also the representation of a CoroWrapper if the coroutine function is a functools.partial(). --- asyncio/coroutines.py | 18 ++++++++---- asyncio/events.py | 14 +++++++--- asyncio/futures.py | 2 +- tests/test_tasks.py | 65 +++++++++++++++++++++++++++++-------------- 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index a1b28751..c6394610 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -151,7 +151,8 @@ def wrapper(*args, **kwds): w = CoroWrapper(coro(*args, **kwds), func) if w._source_traceback: del w._source_traceback[-1] - w.__name__ = func.__name__ + if hasattr(func, '__name__'): + w.__name__ = func.__name__ if hasattr(func, '__qualname__'): w.__qualname__ = func.__qualname__ w.__doc__ = func.__doc__ @@ -175,25 +176,30 @@ def iscoroutine(obj): def _format_coroutine(coro): assert iscoroutine(coro) - coro_name = getattr(coro, '__qualname__', coro.__name__) + + if isinstance(coro, CoroWrapper): + func = coro.func + else: + func = coro + coro_name = events._format_callback(func, ()) filename = coro.gi_code.co_filename if (isinstance(coro, CoroWrapper) and not inspect.isgeneratorfunction(coro.func)): filename, lineno = events._get_function_source(coro.func) if coro.gi_frame is None: - coro_repr = ('%s() done, defined at %s:%s' + coro_repr = ('%s done, defined at %s:%s' % (coro_name, filename, lineno)) else: - coro_repr = ('%s() running, defined at %s:%s' + coro_repr = ('%s running, defined at %s:%s' % (coro_name, filename, lineno)) elif coro.gi_frame is not None: lineno = coro.gi_frame.f_lineno - coro_repr = ('%s() running at %s:%s' + coro_repr = ('%s running at %s:%s' % (coro_name, filename, lineno)) else: lineno = coro.gi_code.co_firstlineno - coro_repr = ('%s() done, defined at %s:%s' + coro_repr = ('%s done, defined at %s:%s' % (coro_name, filename, lineno)) return coro_repr diff --git a/asyncio/events.py b/asyncio/events.py index 8a7bb814..3b907c6d 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -54,15 +54,21 @@ def _format_callback(func, args, suffix=''): suffix = _format_args(args) + suffix return _format_callback(func.func, func.args, suffix) - func_repr = getattr(func, '__qualname__', None) - if not func_repr: + if hasattr(func, '__qualname__'): + func_repr = getattr(func, '__qualname__') + elif hasattr(func, '__name__'): + func_repr = getattr(func, '__name__') + else: func_repr = repr(func) if args is not None: func_repr += _format_args(args) if suffix: func_repr += suffix + return func_repr +def _format_callback_source(func, args): + func_repr = _format_callback(func, args) source = _get_function_source(func) if source: func_repr += ' at %s:%s' % source @@ -92,7 +98,7 @@ def _repr_info(self): if self._cancelled: info.append('cancelled') if self._callback is not None: - info.append(_format_callback(self._callback, self._args)) + info.append(_format_callback_source(self._callback, self._args)) if self._source_traceback: frame = self._source_traceback[-1] info.append('created at %s:%s' % (frame[0], frame[1])) @@ -119,7 +125,7 @@ def _run(self): try: self._callback(*self._args) except Exception as exc: - cb = _format_callback(self._callback, self._args) + cb = _format_callback_source(self._callback, self._args) msg = 'Exception in callback {}'.format(cb) context = { 'message': msg, diff --git a/asyncio/futures.py b/asyncio/futures.py index 2c741fd4..74a99ba0 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -162,7 +162,7 @@ def _format_callbacks(self): cb = '' def format_cb(callback): - return events._format_callback(callback, ()) + return events._format_callback_source(callback, ()) if size == 1: cb = format_cb(cb[0]) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 06447d77..ab614621 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,5 +1,7 @@ """Tests for tasks.py.""" +import contextlib +import functools import os import re import sys @@ -28,6 +30,19 @@ def coroutine_function(): pass +@contextlib.contextmanager +def set_coroutine_debug(enabled): + coroutines = asyncio.coroutines + + old_debug = coroutines._DEBUG + try: + coroutines._DEBUG = enabled + yield + finally: + coroutines._DEBUG = old_debug + + + def format_coroutine(qualname, state, src, source_traceback, generator=False): if generator: state = '%s' % state @@ -279,6 +294,29 @@ def wait_for(fut): fut.set_result(None) self.loop.run_until_complete(task) + def test_task_repr_partial_corowrapper(self): + # Issue #222: repr(CoroWrapper) must not fail in debug mode if the + # coroutine is a partial function + with set_coroutine_debug(True): + self.loop.set_debug(True) + + @asyncio.coroutine + def func(x, y): + yield from asyncio.sleep(0) + + partial_func = asyncio.coroutine(functools.partial(func, 1)) + task = self.loop.create_task(partial_func(2)) + + # make warnings quiet + task._log_destroy_pending = False + self.addCleanup(task._coro.close) + + coro_repr = repr(task._coro) + expected = ('.func(1)() running, ') + self.assertTrue(coro_repr.startswith(expected), + coro_repr) + def test_task_basics(self): @asyncio.coroutine def outer(): @@ -1555,25 +1593,16 @@ def coro(): # The frame should have changed. self.assertIsNone(gen.gi_frame) - # Save debug flag. - old_debug = asyncio.coroutines._DEBUG - try: - # Test with debug flag cleared. - asyncio.coroutines._DEBUG = False + # Test with debug flag cleared. + with set_coroutine_debug(False): check() - # Test with debug flag set. - asyncio.coroutines._DEBUG = True + # Test with debug flag set. + with set_coroutine_debug(True): check() - finally: - # Restore original debug flag. - asyncio.coroutines._DEBUG = old_debug - def test_yield_from_corowrapper(self): - old_debug = asyncio.coroutines._DEBUG - asyncio.coroutines._DEBUG = True - try: + with set_coroutine_debug(True): @asyncio.coroutine def t1(): return (yield from t2()) @@ -1591,8 +1620,6 @@ def t3(f): task = asyncio.Task(t1(), loop=self.loop) val = self.loop.run_until_complete(task) self.assertEqual(val, (1, 2, 3)) - finally: - asyncio.coroutines._DEBUG = old_debug def test_yield_from_corowrapper_send(self): def foo(): @@ -1663,14 +1690,10 @@ def kill_me(loop): @mock.patch('asyncio.coroutines.logger') def test_coroutine_never_yielded(self, m_log): - debug = asyncio.coroutines._DEBUG - try: - asyncio.coroutines._DEBUG = True + with set_coroutine_debug(True): @asyncio.coroutine def coro_noop(): pass - finally: - asyncio.coroutines._DEBUG = debug tb_filename = __file__ tb_lineno = sys._getframe().f_lineno + 2 From 7718675eb9d3acf4b45b8f99f7e64468caa7a2a3 Mon Sep 17 00:00:00 2001 From: "Ludovic Gasc (GMLudo)" Date: Tue, 14 Apr 2015 16:04:09 -0400 Subject: [PATCH 1366/1502] #230: Change official URL from tulip to asyncio in README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 329abddd..036bd5dd 100644 --- a/README.rst +++ b/README.rst @@ -6,7 +6,7 @@ PEP 3156: http://www.python.org/dev/peps/pep-3156/ Copyright/license: Open source, Apache 2.0. Enjoy. -Master GitHub repo: https://github.com/python/tulip +Master GitHub repo: https://github.com/python/asyncio The actual code lives in the 'asyncio' subdirectory. Tests are in the 'tests' subdirectory. From a943b49d6485caddde0b6129d31685737743123a Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sun, 19 Apr 2015 22:47:11 -0400 Subject: [PATCH 1367/1502] Test LifoQueue's and PriorityQueue's put() and task_done(). --- tests/test_queues.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/tests/test_queues.py b/tests/test_queues.py index a73539d1..6edb9f2a 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -408,14 +408,16 @@ def test_order(self): self.assertEqual([1, 2, 3], items) -class QueueJoinTests(_QueueTestBase): +class _QueueJoinTestBase(_QueueTestBase): + + q_class = None def test_task_done_underflow(self): - q = asyncio.Queue(loop=self.loop) + q = self.q_class(loop=self.loop) self.assertRaises(ValueError, q.task_done) def test_task_done(self): - q = asyncio.Queue(loop=self.loop) + q = self.q_class(loop=self.loop) for i in range(100): q.put_nowait(i) @@ -452,7 +454,7 @@ def test(): self.loop.run_until_complete(asyncio.wait(tasks, loop=self.loop)) def test_join_empty_queue(self): - q = asyncio.Queue(loop=self.loop) + q = self.q_class(loop=self.loop) # Test that a queue join()s successfully, and before anything else # (done twice for insurance). @@ -465,12 +467,24 @@ def join(): self.loop.run_until_complete(join()) def test_format(self): - q = asyncio.Queue(loop=self.loop) + q = self.q_class(loop=self.loop) self.assertEqual(q._format(), 'maxsize=0') q._unfinished_tasks = 2 self.assertEqual(q._format(), 'maxsize=0 tasks=2') +class QueueJoinTests(_QueueJoinTestBase): + q_class = asyncio.Queue + + +class LifoQueueJoinTests(_QueueJoinTestBase): + q_class = asyncio.LifoQueue + + +class PriorityQueueJoinTests(_QueueJoinTestBase): + q_class = asyncio.PriorityQueue + + if __name__ == '__main__': unittest.main() From e496c7c8e05ef1fe7f49cce51d40e0b669142c02 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Sun, 19 Apr 2015 22:48:49 -0400 Subject: [PATCH 1368/1502] Fix LifoQueue's and PriorityQueue's put() and task_done(). --- asyncio/queues.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index 84cdabcf..ed116620 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -54,6 +54,8 @@ def __init__(self, maxsize=0, *, loop=None): self._finished.set() self._init(maxsize) + # These three are overridable in subclasses. + def _init(self, maxsize): self._queue = collections.deque() @@ -62,6 +64,11 @@ def _get(self): def _put(self, item): self._queue.append(item) + + # End of the overridable methods. + + def __put_internal(self, item): + self._put(item) self._unfinished_tasks += 1 self._finished.clear() @@ -133,7 +140,7 @@ def put(self, item): 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - self._put(item) + self.__put_internal(item) # getter cannot be cancelled, we just removed done getters getter.set_result(self._get()) @@ -145,7 +152,7 @@ def put(self, item): yield from waiter else: - self._put(item) + self.__put_internal(item) def put_nowait(self, item): """Put an item into the queue without blocking. @@ -158,7 +165,7 @@ def put_nowait(self, item): 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - self._put(item) + self.__put_internal(item) # getter cannot be cancelled, we just removed done getters getter.set_result(self._get()) @@ -166,7 +173,7 @@ def put_nowait(self, item): elif self._maxsize > 0 and self._maxsize <= self.qsize(): raise QueueFull else: - self._put(item) + self.__put_internal(item) @coroutine def get(self): @@ -180,7 +187,7 @@ def get(self): if self._putters: assert self.full(), 'queue not full, why are putters waiting?' item, putter = self._putters.popleft() - self._put(item) + self.__put_internal(item) # When a getter runs and frees up a slot so this putter can # run, we need to defer the put for a tick to ensure that @@ -207,7 +214,7 @@ def get_nowait(self): if self._putters: assert self.full(), 'queue not full, why are putters waiting?' item, putter = self._putters.popleft() - self._put(item) + self.__put_internal(item) # Wake putter on next tick. # getter cannot be cancelled, we just removed done putters From 70d885692864a3bef520bbec5ecafb2267a99c7c Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Mon, 20 Apr 2015 02:26:06 -0400 Subject: [PATCH 1369/1502] Fix queue join tests for CPython's test runner. "python -m test" was running tests in the base class, which failed. --- tests/test_queues.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_queues.py b/tests/test_queues.py index 6edb9f2a..88b4f075 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -408,7 +408,7 @@ def test_order(self): self.assertEqual([1, 2, 3], items) -class _QueueJoinTestBase(_QueueTestBase): +class _QueueJoinTestMixin: q_class = None @@ -474,15 +474,15 @@ def test_format(self): self.assertEqual(q._format(), 'maxsize=0 tasks=2') -class QueueJoinTests(_QueueJoinTestBase): +class QueueJoinTests(_QueueJoinTestMixin, _QueueTestBase): q_class = asyncio.Queue -class LifoQueueJoinTests(_QueueJoinTestBase): +class LifoQueueJoinTests(_QueueJoinTestMixin, _QueueTestBase): q_class = asyncio.LifoQueue -class PriorityQueueJoinTests(_QueueJoinTestBase): +class PriorityQueueJoinTests(_QueueJoinTestMixin, _QueueTestBase): q_class = asyncio.PriorityQueue From 2798fb43af22c966a0c7ba15258a073cf651a3c2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 5 May 2015 14:19:23 -0700 Subject: [PATCH 1370/1502] Rename the function arg to run_in_executor() to "func" to avoid confusion. --- asyncio/base_events.py | 16 ++++++++-------- asyncio/events.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index eb867cd5..bfa435ca 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -465,25 +465,25 @@ def call_soon_threadsafe(self, callback, *args): self._write_to_self() return handle - def run_in_executor(self, executor, callback, *args): - if (coroutines.iscoroutine(callback) - or coroutines.iscoroutinefunction(callback)): + def run_in_executor(self, executor, func, *args): + if (coroutines.iscoroutine(func) + or coroutines.iscoroutinefunction(func)): raise TypeError("coroutines cannot be used with run_in_executor()") self._check_closed() - if isinstance(callback, events.Handle): + if isinstance(func, events.Handle): assert not args - assert not isinstance(callback, events.TimerHandle) - if callback._cancelled: + assert not isinstance(func, events.TimerHandle) + if func._cancelled: f = futures.Future(loop=self) f.set_result(None) return f - callback, args = callback._callback, callback._args + func, args = func._callback, func._args if executor is None: executor = self._default_executor if executor is None: executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) self._default_executor = executor - return futures.wrap_future(executor.submit(callback, *args), loop=self) + return futures.wrap_future(executor.submit(func, *args), loop=self) def set_default_executor(self, executor): self._default_executor = executor diff --git a/asyncio/events.py b/asyncio/events.py index 3b907c6d..99e12e66 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -277,7 +277,7 @@ def create_task(self, coro): def call_soon_threadsafe(self, callback, *args): raise NotImplementedError - def run_in_executor(self, executor, callback, *args): + def run_in_executor(self, executor, func, *args): raise NotImplementedError def set_default_executor(self, executor): From 71f7c249efc8a97e7e06d25c65ae96a2c001a6b3 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 12:30:48 -0400 Subject: [PATCH 1371/1502] Add new loop APIs: set_task_factory() and get_task_factory() --- asyncio/base_events.py | 28 +++++++++++++++++++++++++--- asyncio/events.py | 8 ++++++++ tests/test_base_events.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 3 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index bfa435ca..efbb9f40 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -197,6 +197,7 @@ def __init__(self): # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 self._current_handle = None + self._task_factory = None def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' @@ -209,11 +210,32 @@ def create_task(self, coro): Return a task object. """ self._check_closed() - task = tasks.Task(coro, loop=self) - if task._source_traceback: - del task._source_traceback[-1] + if self._task_factory is None: + task = tasks.Task(coro, loop=self) + if task._source_traceback: + del task._source_traceback[-1] + else: + task = self._task_factory(self, coro) return task + def set_task_factory(self, factory): + """Set a task factory that will be used by loop.create_task(). + + If factory is None the default task factory will be set. + + If factory is a callable, it should have a signature matching + '(loop, coro)', where 'loop' will be a reference to the active + event loop, 'coro' will be a coroutine object. The callable + must return a Future. + """ + if factory is not None and not callable(factory): + raise TypeError('task factory must be a callable or None') + self._task_factory = factory + + def get_task_factory(self): + """Return a task factory, or None if the default one is in use.""" + return self._task_factory + def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): """Create socket transport.""" diff --git a/asyncio/events.py b/asyncio/events.py index 99e12e66..496075ba 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -438,6 +438,14 @@ def add_signal_handler(self, sig, callback, *args): def remove_signal_handler(self, sig): raise NotImplementedError + # Task factory. + + def set_task_factory(self, factory): + raise NotImplementedError + + def get_task_factory(self): + raise NotImplementedError + # Error handlers. def set_exception_handler(self, handler): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 9e7c50cc..af6a4c3f 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -623,6 +623,42 @@ def custom_handler(loop, context): self.assertIs(type(_context['context']['exception']), ZeroDivisionError) + def test_set_task_factory_invalid(self): + with self.assertRaisesRegex( + TypeError, 'task factory must be a callable or None'): + + self.loop.set_task_factory(1) + + self.assertIsNone(self.loop.get_task_factory()) + + def test_set_task_factory(self): + self.loop._process_events = mock.Mock() + + class MyTask(asyncio.Task): + pass + + @asyncio.coroutine + def coro(): + pass + + factory = lambda loop, coro: MyTask(coro, loop=loop) + + self.assertIsNone(self.loop.get_task_factory()) + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + + self.loop.set_task_factory(None) + self.assertIsNone(self.loop.get_task_factory()) + + task = self.loop.create_task(coro()) + self.assertTrue(isinstance(task, asyncio.Task)) + self.assertFalse(isinstance(task, MyTask)) + self.loop.run_until_complete(task) + def test_env_var_debug(self): code = '\n'.join(( 'import asyncio', From 0f50393e51e89c000baa3644862c9381f2ee5e0b Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 13:45:50 -0400 Subject: [PATCH 1372/1502] Sync script_helper import with CPython --- asyncio/test_support.py | 2 +- tests/test_base_events.py | 2 +- tests/test_tasks.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index 3da47558..543b27a7 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -300,6 +300,6 @@ def requires_freebsd_version(*min_version): # Use test.script_helper if available try: - from test.script_helper import assert_python_ok + from test.support.script_helper import assert_python_ok except ImportError: pass diff --git a/tests/test_base_events.py b/tests/test_base_events.py index af6a4c3f..aaa8e67b 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -16,7 +16,7 @@ from asyncio import test_utils try: from test import support - from test.script_helper import assert_python_ok + from test.support.script_helper import assert_python_ok except ImportError: from asyncio import test_support as support from asyncio.test_support import assert_python_ok diff --git a/tests/test_tasks.py b/tests/test_tasks.py index ab614621..e47a668e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -15,7 +15,7 @@ from asyncio import test_utils try: from test import support - from test.script_helper import assert_python_ok + from test.support.script_helper import assert_python_ok except ImportError: from asyncio import test_support as support from asyncio.test_support import assert_python_ok From 26595991285ee44427cfa00adb3f9fff70953166 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 13:50:15 -0400 Subject: [PATCH 1373/1502] Make sure that CPython 3.4 and older will import script_helper --- asyncio/test_support.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/asyncio/test_support.py b/asyncio/test_support.py index 543b27a7..0fadfad9 100644 --- a/asyncio/test_support.py +++ b/asyncio/test_support.py @@ -302,4 +302,7 @@ def requires_freebsd_version(*min_version): try: from test.support.script_helper import assert_python_ok except ImportError: - pass + try: + from test.script_helper import assert_python_ok + except ImportError: + pass From 1181bc193ad7636d79267fc4b3f03ea238a3f345 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 14:30:04 -0400 Subject: [PATCH 1374/1502] Fix script_helper imports for 3.4 --- tests/test_base_events.py | 9 +++++++-- tests/test_tasks.py | 9 +++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index aaa8e67b..fd864ce6 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -16,10 +16,15 @@ from asyncio import test_utils try: from test import support - from test.support.script_helper import assert_python_ok except ImportError: from asyncio import test_support as support - from asyncio.test_support import assert_python_ok +try: + from test.support.script_helper import assert_python_ok +except ImportError: + try: + from test.script_helper import assert_python_ok + except ImportError: + from asyncio.test_support import assert_python_ok MOCK_ANY = mock.ANY diff --git a/tests/test_tasks.py b/tests/test_tasks.py index e47a668e..5b49e768 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -15,10 +15,15 @@ from asyncio import test_utils try: from test import support - from test.support.script_helper import assert_python_ok except ImportError: from asyncio import test_support as support - from asyncio.test_support import assert_python_ok +try: + from test.support.script_helper import assert_python_ok +except ImportError: + try: + from test.script_helper import assert_python_ok + except ImportError: + from asyncio.test_support import assert_python_ok PY34 = (sys.version_info >= (3, 4)) From 53edb85e6e08765008595a39eb28b4c544baa326 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 11:56:36 -0400 Subject: [PATCH 1375/1502] Deprecate async() function in favour of ensure_future() --- asyncio/base_events.py | 2 +- asyncio/tasks.py | 27 +++++++++++++++----- asyncio/windows_events.py | 2 +- tests/test_base_events.py | 6 ++--- tests/test_tasks.py | 48 ++++++++++++++++++++---------------- tests/test_windows_events.py | 2 +- 6 files changed, 54 insertions(+), 33 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index efbb9f40..98aadaf1 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -315,7 +315,7 @@ def run_until_complete(self, future): self._check_closed() new_task = not isinstance(future, futures.Future) - future = tasks.async(future, loop=self) + future = tasks.ensure_future(future, loop=self) if new_task: # An exception is raised if the future didn't complete, so there # is no need to log the "destroy pending task" message diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 4f19a252..5840d2a4 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -3,7 +3,7 @@ __all__ = ['Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', 'wait', 'wait_for', 'as_completed', 'sleep', 'async', - 'gather', 'shield', + 'gather', 'shield', 'ensure_future', ] import concurrent.futures @@ -12,6 +12,7 @@ import linecache import sys import traceback +import warnings import weakref from . import coroutines @@ -327,7 +328,7 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): if loop is None: loop = events.get_event_loop() - fs = {async(f, loop=loop) for f in set(fs)} + fs = {ensure_future(f, loop=loop) for f in set(fs)} return (yield from _wait(fs, timeout, return_when, loop)) @@ -361,7 +362,7 @@ def wait_for(fut, timeout, *, loop=None): timeout_handle = loop.call_later(timeout, _release_waiter, waiter) cb = functools.partial(_release_waiter, waiter) - fut = async(fut, loop=loop) + fut = ensure_future(fut, loop=loop) fut.add_done_callback(cb) try: @@ -449,7 +450,7 @@ def as_completed(fs, *, loop=None, timeout=None): if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() - todo = {async(f, loop=loop) for f in set(fs)} + todo = {ensure_future(f, loop=loop) for f in set(fs)} from .queues import Queue # Import here to avoid circular import problem. done = Queue(loop=loop) timeout_handle = None @@ -499,6 +500,20 @@ def sleep(delay, result=None, *, loop=None): def async(coro_or_future, *, loop=None): """Wrap a coroutine in a future. + If the argument is a Future, it is returned directly. + + This function is deprecated in 3.5. Use asyncio.ensure_future() instead. + """ + + warnings.warn("asyncio.async() function is deprecated, use ensure_future()", + RuntimeWarning) + + return ensure_future(coro_or_future, loop=loop) + + +def ensure_future(coro_or_future, *, loop=None): + """Wrap a coroutine in a future. + If the argument is a Future, it is returned directly. """ if isinstance(coro_or_future, futures.Future): @@ -564,7 +579,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): arg_to_fut = {} for arg in set(coros_or_futures): if not isinstance(arg, futures.Future): - fut = async(arg, loop=loop) + fut = ensure_future(arg, loop=loop) if loop is None: loop = fut._loop # The caller cannot control this future, the "destroy pending task" @@ -640,7 +655,7 @@ def shield(arg, *, loop=None): except CancelledError: res = None """ - inner = async(arg, loop=loop) + inner = ensure_future(arg, loop=loop) if inner.done(): # Shortcut. return inner diff --git a/asyncio/windows_events.py b/asyncio/windows_events.py index f311e463..922594f1 100644 --- a/asyncio/windows_events.py +++ b/asyncio/windows_events.py @@ -488,7 +488,7 @@ def accept_coro(future, conn): future = self._register(ov, listener, finish_accept) coro = accept_coro(future, conn) - tasks.async(coro, loop=self._loop) + tasks.ensure_future(coro, loop=self._loop) return future def connect(self, conn, address): diff --git a/tests/test_base_events.py b/tests/test_base_events.py index fd864ce6..8c4498cf 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -504,7 +504,7 @@ def zero_error_coro(): # Test Future.__del__ with mock.patch('asyncio.base_events.logger') as log: - fut = asyncio.async(zero_error_coro(), loop=self.loop) + fut = asyncio.ensure_future(zero_error_coro(), loop=self.loop) fut.add_done_callback(lambda *args: self.loop.stop()) self.loop.run_forever() fut = None # Trigger Future.__del__ or futures._TracebackLogger @@ -703,7 +703,7 @@ def create_task(self, coro): self.set_event_loop(loop) coro = test() - task = asyncio.async(coro, loop=loop) + task = asyncio.ensure_future(coro, loop=loop) self.assertIsInstance(task, MyTask) # make warnings quiet @@ -1265,7 +1265,7 @@ def stop_loop_coro(loop): "took .* seconds$") # slow task - asyncio.async(stop_loop_coro(self.loop), loop=self.loop) + asyncio.ensure_future(stop_loop_coro(self.loop), loop=self.loop) self.loop.run_forever() fmt, *args = m_logger.warning.call_args[0] self.assertRegex(fmt % tuple(args), diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 5b49e768..8c799833 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -92,11 +92,11 @@ def notmuch(): loop.run_until_complete(t) loop.close() - def test_async_coroutine(self): + def test_ensure_future_coroutine(self): @asyncio.coroutine def notmuch(): return 'ok' - t = asyncio.async(notmuch(), loop=self.loop) + t = asyncio.ensure_future(notmuch(), loop=self.loop) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') @@ -104,16 +104,16 @@ def notmuch(): loop = asyncio.new_event_loop() self.set_event_loop(loop) - t = asyncio.async(notmuch(), loop=loop) + t = asyncio.ensure_future(notmuch(), loop=loop) self.assertIs(t._loop, loop) loop.run_until_complete(t) loop.close() - def test_async_future(self): + def test_ensure_future_future(self): f_orig = asyncio.Future(loop=self.loop) f_orig.set_result('ko') - f = asyncio.async(f_orig) + f = asyncio.ensure_future(f_orig) self.loop.run_until_complete(f) self.assertTrue(f.done()) self.assertEqual(f.result(), 'ko') @@ -123,19 +123,19 @@ def test_async_future(self): self.set_event_loop(loop) with self.assertRaises(ValueError): - f = asyncio.async(f_orig, loop=loop) + f = asyncio.ensure_future(f_orig, loop=loop) loop.close() - f = asyncio.async(f_orig, loop=self.loop) + f = asyncio.ensure_future(f_orig, loop=self.loop) self.assertIs(f, f_orig) - def test_async_task(self): + def test_ensure_future_task(self): @asyncio.coroutine def notmuch(): return 'ok' t_orig = asyncio.Task(notmuch(), loop=self.loop) - t = asyncio.async(t_orig) + t = asyncio.ensure_future(t_orig) self.loop.run_until_complete(t) self.assertTrue(t.done()) self.assertEqual(t.result(), 'ok') @@ -145,16 +145,22 @@ def notmuch(): self.set_event_loop(loop) with self.assertRaises(ValueError): - t = asyncio.async(t_orig, loop=loop) + t = asyncio.ensure_future(t_orig, loop=loop) loop.close() - t = asyncio.async(t_orig, loop=self.loop) + t = asyncio.ensure_future(t_orig, loop=self.loop) self.assertIs(t, t_orig) - def test_async_neither(self): + def test_ensure_future_neither(self): with self.assertRaises(TypeError): - asyncio.async('ok') + asyncio.ensure_future('ok') + + def test_async_warning(self): + f = asyncio.Future(loop=self.loop) + with self.assertWarnsRegex(RuntimeWarning, + 'function is deprecated, use ensure_'): + self.assertIs(f, asyncio.async(f)) def test_task_repr(self): self.loop.set_debug(False) @@ -1420,7 +1426,7 @@ def outer(): else: proof += 10 - f = asyncio.async(outer(), loop=self.loop) + f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() self.loop.run_until_complete(f) @@ -1445,7 +1451,7 @@ def outer(): d, p = yield from asyncio.wait([inner()], loop=self.loop) proof += 100 - f = asyncio.async(outer(), loop=self.loop) + f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() self.assertRaises( @@ -1501,7 +1507,7 @@ def outer(): yield from asyncio.shield(inner(), loop=self.loop) proof += 100 - f = asyncio.async(outer(), loop=self.loop) + f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() with self.assertRaises(asyncio.CancelledError): @@ -1668,7 +1674,7 @@ def kill_me(loop): # schedule the task coro = kill_me(self.loop) - task = asyncio.async(coro, loop=self.loop) + task = asyncio.ensure_future(coro, loop=self.loop) self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task}) # execute the task so it waits for future @@ -1996,8 +2002,8 @@ def inner(): yield from waiter proof += 1 - child1 = asyncio.async(inner(), loop=self.one_loop) - child2 = asyncio.async(inner(), loop=self.one_loop) + child1 = asyncio.ensure_future(inner(), loop=self.one_loop) + child2 = asyncio.ensure_future(inner(), loop=self.one_loop) gatherer = None @asyncio.coroutine @@ -2007,7 +2013,7 @@ def outer(): yield from gatherer proof += 100 - f = asyncio.async(outer(), loop=self.one_loop) + f = asyncio.ensure_future(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) self.assertTrue(f.cancel()) with self.assertRaises(asyncio.CancelledError): @@ -2034,7 +2040,7 @@ def inner(f): def outer(): yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) - f = asyncio.async(outer(), loop=self.one_loop) + f = asyncio.ensure_future(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) a.set_result(None) test_utils.run_briefly(self.one_loop) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 73d8fcdb..657a4274 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -37,7 +37,7 @@ def setUp(self): def test_close(self): a, b = self.loop._socketpair() trans = self.loop._make_socket_transport(a, asyncio.Protocol()) - f = asyncio.async(self.loop.sock_recv(b, 100)) + f = asyncio.ensure_future(self.loop.sock_recv(b, 100)) trans.close() self.loop.run_until_complete(f) self.assertEqual(f.result(), b'') From 6c6a5e4a069229f808266fdc21ec10b31c41423d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 14:13:53 -0400 Subject: [PATCH 1376/1502] Use DeprecationWarning for async() --- asyncio/tasks.py | 2 +- tests/test_tasks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 5840d2a4..f617b62b 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -506,7 +506,7 @@ def async(coro_or_future, *, loop=None): """ warnings.warn("asyncio.async() function is deprecated, use ensure_future()", - RuntimeWarning) + DeprecationWarning) return ensure_future(coro_or_future, loop=loop) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 8c799833..4119085d 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -158,7 +158,7 @@ def test_ensure_future_neither(self): def test_async_warning(self): f = asyncio.Future(loop=self.loop) - with self.assertWarnsRegex(RuntimeWarning, + with self.assertWarnsRegex(DeprecationWarning, 'function is deprecated, use ensure_'): self.assertIs(f, asyncio.async(f)) From e3216b80438bf42abb76c99b034eb03b2768018d Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 15:03:35 -0400 Subject: [PATCH 1377/1502] Enable Travis-CI integration --- .travis.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 .travis.yml diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..2c5838b8 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,17 @@ +language: python + +os: + - linux + - osx + +python: + - 3.3 + - 3.4 + +install: + - pip install asyncio + - python setup.py install + +script: + - python runtests.py + - PYTHONASYNCIODEBUG=1 python runtests.py From 36e714117f1e1e88d2358dccbea19947752d9d5f Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Mon, 11 May 2015 22:24:00 -0400 Subject: [PATCH 1378/1502] Support PEP 492 native coroutines. --- asyncio/base_events.py | 47 +++++++++++++++++++----- asyncio/coroutines.py | 77 ++++++++++++++++++++++++++++++++------- asyncio/futures.py | 4 ++ asyncio/tasks.py | 8 +++- tests/test_base_events.py | 3 +- tests/test_tasks.py | 4 +- 6 files changed, 116 insertions(+), 27 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 98aadaf1..38344a77 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -191,8 +191,8 @@ def __init__(self): self._thread_id = None self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self._debug = (not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG'))) + self.set_debug((not sys.flags.ignore_environment + and bool(os.environ.get('PYTHONASYNCIODEBUG')))) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 @@ -360,13 +360,18 @@ def close(self): return if self._debug: logger.debug("Close %r", self) - self._closed = True - self._ready.clear() - self._scheduled.clear() - executor = self._default_executor - if executor is not None: - self._default_executor = None - executor.shutdown(wait=False) + try: + self._closed = True + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) + finally: + # It is important to unregister "sys.coroutine_wrapper" + # if it was registered. + self.set_debug(False) def is_closed(self): """Returns True if the event loop was closed.""" @@ -1199,3 +1204,27 @@ def get_debug(self): def set_debug(self, enabled): self._debug = enabled + wrapper = coroutines.debug_wrapper + + try: + set_wrapper = sys.set_coroutine_wrapper + except AttributeError: + pass + else: + current_wrapper = sys.get_coroutine_wrapper() + if enabled: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(True): cannot set debug coroutine " + "wrapper; another wrapper is already set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(wrapper) + else: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(False): cannot unset debug coroutine " + "wrapper; another wrapper was set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(None) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index c6394610..20c45798 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -14,6 +14,9 @@ from .log import logger +_PY35 = sys.version_info >= (3, 5) + + # Opcode of "yield from" instruction _YIELD_FROM = opcode.opmap['YIELD_FROM'] @@ -30,6 +33,27 @@ and bool(os.environ.get('PYTHONASYNCIODEBUG'))) +try: + types.coroutine +except AttributeError: + native_coroutine_support = False +else: + native_coroutine_support = True + +try: + _iscoroutinefunction = inspect.iscoroutinefunction +except AttributeError: + _iscoroutinefunction = lambda func: False + +try: + inspect.CO_COROUTINE +except AttributeError: + _is_native_coro_code = lambda code: False +else: + _is_native_coro_code = lambda code: (code.co_flags & + inspect.CO_COROUTINE) + + # Check for CPython issue #21209 def has_yield_from_bug(): class MyGen: @@ -54,16 +78,27 @@ def yield_from_gen(gen): del has_yield_from_bug +def debug_wrapper(gen): + # This function is called from 'sys.set_coroutine_wrapper'. + # We only wrap here coroutines defined via 'async def' syntax. + # Generator-based coroutines are wrapped in @coroutine + # decorator. + if _is_native_coro_code(gen.gi_code): + return CoroWrapper(gen, None) + else: + return gen + + class CoroWrapper: # Wrapper for coroutine object in _DEBUG mode. - def __init__(self, gen, func): - assert inspect.isgenerator(gen), gen + def __init__(self, gen, func=None): + assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen self.gen = gen - self.func = func + self.func = func # Used to unwrap @coroutine decorator self._source_traceback = traceback.extract_stack(sys._getframe(1)) - # __name__, __qualname__, __doc__ attributes are set by the coroutine() - # decorator + self.__name__ = getattr(gen, '__name__', None) + self.__qualname__ = getattr(gen, '__qualname__', None) def __repr__(self): coro_repr = _format_coroutine(self) @@ -75,6 +110,9 @@ def __repr__(self): def __iter__(self): return self + if _PY35: + __await__ = __iter__ # make compatible with 'await' expression + def __next__(self): return next(self.gen) @@ -133,6 +171,14 @@ def coroutine(func): If the coroutine is not yielded from before it is destroyed, an error message is logged. """ + is_coroutine = _iscoroutinefunction(func) + if is_coroutine and _is_native_coro_code(func.__code__): + # In Python 3.5 that's all we need to do for coroutines + # defiend with "async def". + # Wrapping in CoroWrapper will happen via + # 'sys.set_coroutine_wrapper' function. + return func + if inspect.isgeneratorfunction(func): coro = func else: @@ -144,18 +190,22 @@ def coro(*args, **kw): return res if not _DEBUG: - wrapper = coro + if native_coroutine_support: + wrapper = types.coroutine(coro) + else: + wrapper = coro else: @functools.wraps(func) def wrapper(*args, **kwds): - w = CoroWrapper(coro(*args, **kwds), func) + w = CoroWrapper(coro(*args, **kwds), func=func) if w._source_traceback: del w._source_traceback[-1] - if hasattr(func, '__name__'): - w.__name__ = func.__name__ - if hasattr(func, '__qualname__'): - w.__qualname__ = func.__qualname__ - w.__doc__ = func.__doc__ + # Python < 3.5 does not implement __qualname__ + # on generator objects, so we set it manually. + # We use getattr as some callables (such as + # functools.partial may lack __qualname__). + w.__name__ = getattr(func, '__name__', None) + w.__qualname__ = getattr(func, '__qualname__', None) return w wrapper._is_coroutine = True # For iscoroutinefunction(). @@ -164,7 +214,8 @@ def wrapper(*args, **kwds): def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" - return getattr(func, '_is_coroutine', False) + return (getattr(func, '_is_coroutine', False) or + _iscoroutinefunction(func)) _COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) diff --git a/asyncio/futures.py b/asyncio/futures.py index 74a99ba0..d06828a6 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -19,6 +19,7 @@ _FINISHED = 'FINISHED' _PY34 = sys.version_info >= (3, 4) +_PY35 = sys.version_info >= (3, 5) Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError @@ -387,6 +388,9 @@ def __iter__(self): assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. + if _PY35: + __await__ = __iter__ # make compatible with 'await' expression + def wrap_future(fut, *, loop=None): """Wrap concurrent.futures.Future object.""" diff --git a/asyncio/tasks.py b/asyncio/tasks.py index f617b62b..fcb38338 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -11,6 +11,7 @@ import inspect import linecache import sys +import types import traceback import warnings import weakref @@ -73,7 +74,10 @@ def __init__(self, coro, *, loop=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] - self._coro = iter(coro) # Use the iterator just in case. + if coro.__class__ is types.GeneratorType: + self._coro = coro + else: + self._coro = iter(coro) # Use the iterator just in case. self._fut_waiter = None self._must_cancel = False self._loop.call_soon(self._step) @@ -236,7 +240,7 @@ def _step(self, value=None, exc=None): elif value is not None: result = coro.send(value) else: - result = next(coro) + result = coro.send(None) except StopIteration as exc: self.set_result(exc.value) except futures.CancelledError as exc: diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 8c4498cf..b1f1e56c 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -61,7 +61,8 @@ def test_not_implemented(self): NotImplementedError, self.loop._make_write_pipe_transport, m, m) gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) - self.assertRaises(NotImplementedError, next, iter(gen)) + with self.assertRaises(NotImplementedError): + gen.send(None) def test_close(self): self.assertFalse(self.loop.is_closed()) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 4119085d..6541df75 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1638,7 +1638,7 @@ def foo(): return a def call(arg): - cw = asyncio.coroutines.CoroWrapper(foo(), foo) + cw = asyncio.coroutines.CoroWrapper(foo()) cw.send(None) try: cw.send(arg) @@ -1653,7 +1653,7 @@ def call(arg): def test_corowrapper_weakref(self): wd = weakref.WeakValueDictionary() def foo(): yield from [] - cw = asyncio.coroutines.CoroWrapper(foo(), foo) + cw = asyncio.coroutines.CoroWrapper(foo()) wd['cw'] = cw # Would fail without __weakref__ slot. cw.gen = None # Suppress warning from __del__. From f24ba3833f3ec6c747f1b7eb8fd3c821dc2c7daf Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 12 May 2015 11:32:46 -0400 Subject: [PATCH 1379/1502] Make sure sys.set_coroutine_wrapper is called *only* when loop is running. Previous approach of installing coroutine wrapper in loop.set_debug() and uninstalling it in loop.close() was very fragile. Most of asyncio tests do not call loop.close() at all. Since coroutine wrapper is a global setting, we have to make sure that it's only set when the loop is running, and is automatically unset when it stops running. --- asyncio/base_events.py | 80 ++++++++++++++++++++++++------------------ 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 38344a77..5a536a22 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -198,6 +198,7 @@ def __init__(self): self.slow_callback_duration = 0.1 self._current_handle = None self._task_factory = None + self._coroutine_wrapper_set = False def __repr__(self): return ('<%s running=%s closed=%s debug=%s>' @@ -291,6 +292,7 @@ def run_forever(self): self._check_closed() if self.is_running(): raise RuntimeError('Event loop is running.') + self._set_coroutine_wrapper(self._debug) self._thread_id = threading.get_ident() try: while True: @@ -300,6 +302,7 @@ def run_forever(self): break finally: self._thread_id = None + self._set_coroutine_wrapper(False) def run_until_complete(self, future): """Run until the Future is done. @@ -360,18 +363,13 @@ def close(self): return if self._debug: logger.debug("Close %r", self) - try: - self._closed = True - self._ready.clear() - self._scheduled.clear() - executor = self._default_executor - if executor is not None: - self._default_executor = None - executor.shutdown(wait=False) - finally: - # It is important to unregister "sys.coroutine_wrapper" - # if it was registered. - self.set_debug(False) + self._closed = True + self._ready.clear() + self._scheduled.clear() + executor = self._default_executor + if executor is not None: + self._default_executor = None + executor.shutdown(wait=False) def is_closed(self): """Returns True if the event loop was closed.""" @@ -1199,32 +1197,44 @@ def _run_once(self): handle._run() handle = None # Needed to break cycles when an exception occurs. + def _set_coroutine_wrapper(self, enabled): + try: + set_wrapper = sys.set_coroutine_wrapper + get_wrapper = sys.get_coroutine_wrapper + except AttributeError: + return + + enabled = bool(enabled) + if self._coroutine_wrapper_set is enabled: + return + + wrapper = coroutines.debug_wrapper + current_wrapper = get_wrapper() + + if enabled: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(True): cannot set debug coroutine " + "wrapper; another wrapper is already set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(wrapper) + self._coroutine_wrapper_set = True + else: + if current_wrapper not in (None, wrapper): + warnings.warn( + "loop.set_debug(False): cannot unset debug coroutine " + "wrapper; another wrapper was set %r" % + current_wrapper, RuntimeWarning) + else: + set_wrapper(None) + self._coroutine_wrapper_set = False + def get_debug(self): return self._debug def set_debug(self, enabled): self._debug = enabled - wrapper = coroutines.debug_wrapper - try: - set_wrapper = sys.set_coroutine_wrapper - except AttributeError: - pass - else: - current_wrapper = sys.get_coroutine_wrapper() - if enabled: - if current_wrapper not in (None, wrapper): - warnings.warn( - "loop.set_debug(True): cannot set debug coroutine " - "wrapper; another wrapper is already set %r" % - current_wrapper, RuntimeWarning) - else: - set_wrapper(wrapper) - else: - if current_wrapper not in (None, wrapper): - warnings.warn( - "loop.set_debug(False): cannot unset debug coroutine " - "wrapper; another wrapper was set %r" % - current_wrapper, RuntimeWarning) - else: - set_wrapper(None) + if self.is_running(): + self._set_coroutine_wrapper(enabled) From 6ac55b2ae73fcd639c49523266aa4b85482a8cc3 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 13 May 2015 15:17:07 -0400 Subject: [PATCH 1380/1502] Enable 'async for' for StreamReader, 'async with' for locks in 3.5. --- asyncio/locks.py | 108 ++++++++++++++++++++++----------------------- asyncio/streams.py | 14 ++++++ 2 files changed, 67 insertions(+), 55 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index 41a68c6c..b2e516b5 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -3,12 +3,16 @@ __all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] import collections +import sys from . import events from . import futures from .coroutines import coroutine +_PY35 = sys.version_info >= (3, 5) + + class _ContextManager: """Context manager. @@ -39,7 +43,53 @@ def __exit__(self, *args): self._lock = None # Crudely prevent reuse. -class Lock: +class _ContextManagerMixin: + def __enter__(self): + raise RuntimeError( + '"yield from" should be used as context manager expression') + + def __exit__(self, *args): + # This must exist because __enter__ exists, even though that + # always raises; that's how the with-statement works. + pass + + @coroutine + def __iter__(self): + # This is not a coroutine. It is meant to enable the idiom: + # + # with (yield from lock): + # + # + # as an alternative to: + # + # yield from lock.acquire() + # try: + # + # finally: + # lock.release() + yield from self.acquire() + return _ContextManager(self) + + if _PY35: + + def __await__(self): + # To make "with await lock" work. + yield from self.acquire() + return _ContextManager(self) + + @coroutine + def __aenter__(self): + yield from self.acquire() + # We have no use for the "as ..." clause in the with + # statement for locks. + return None + + @coroutine + def __aexit__(self, exc_type, exc, tb): + self.release() + + +class Lock(_ContextManagerMixin): """Primitive lock objects. A primitive lock is a synchronization primitive that is not owned @@ -153,32 +203,6 @@ def release(self): else: raise RuntimeError('Lock is not acquired.') - def __enter__(self): - raise RuntimeError( - '"yield from" should be used as context manager expression') - - def __exit__(self, *args): - # This must exist because __enter__ exists, even though that - # always raises; that's how the with-statement works. - pass - - @coroutine - def __iter__(self): - # This is not a coroutine. It is meant to enable the idiom: - # - # with (yield from lock): - # - # - # as an alternative to: - # - # yield from lock.acquire() - # try: - # - # finally: - # lock.release() - yield from self.acquire() - return _ContextManager(self) - class Event: """Asynchronous equivalent to threading.Event. @@ -246,7 +270,7 @@ def wait(self): self._waiters.remove(fut) -class Condition: +class Condition(_ContextManagerMixin): """Asynchronous equivalent to threading.Condition. This class implements condition variable objects. A condition variable @@ -356,21 +380,8 @@ def notify_all(self): """ self.notify(len(self._waiters)) - def __enter__(self): - raise RuntimeError( - '"yield from" should be used as context manager expression') - - def __exit__(self, *args): - pass - @coroutine - def __iter__(self): - # See comment in Lock.__iter__(). - yield from self.acquire() - return _ContextManager(self) - - -class Semaphore: +class Semaphore(_ContextManagerMixin): """A Semaphore implementation. A semaphore manages an internal counter which is decremented by each @@ -441,19 +452,6 @@ def release(self): waiter.set_result(True) break - def __enter__(self): - raise RuntimeError( - '"yield from" should be used as context manager expression') - - def __exit__(self, *args): - pass - - @coroutine - def __iter__(self): - # See comment in Lock.__iter__(). - yield from self.acquire() - return _ContextManager(self) - class BoundedSemaphore(Semaphore): """A bounded semaphore implementation. diff --git a/asyncio/streams.py b/asyncio/streams.py index 64ff3d2e..176c65e3 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -6,6 +6,7 @@ ] import socket +import sys if hasattr(socket, 'AF_UNIX'): __all__.extend(['open_unix_connection', 'start_unix_server']) @@ -19,6 +20,7 @@ _DEFAULT_LIMIT = 2**16 +_PY35 = sys.version_info >= (3, 5) class IncompleteReadError(EOFError): @@ -485,3 +487,15 @@ def readexactly(self, n): n -= len(block) return b''.join(blocks) + + if _PY35: + @coroutine + def __aiter__(self): + return self + + @coroutine + def __anext__(self): + val = yield from self.readline() + if val == b'': + raise StopAsyncIteration + return val From 3a09a93277afc2cdb43badf92a2c85c2789813f6 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 12 May 2015 13:38:57 -0400 Subject: [PATCH 1381/1502] Use collections.abc.Coroutine for asyncio.iscoroutine() when available. --- asyncio/coroutines.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 20c45798..1e0a7049 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -53,6 +53,11 @@ _is_native_coro_code = lambda code: (code.co_flags & inspect.CO_COROUTINE) +try: + from collections.abc import Coroutine as CoroutineABC +except ImportError: + CoroutineABC = None + # Check for CPython issue #21209 def has_yield_from_bug(): @@ -219,6 +224,9 @@ def iscoroutinefunction(func): _COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) +if CoroutineABC is not None: + _COROUTINE_TYPES += (CoroutineABC,) + def iscoroutine(obj): """Return True if obj is a coroutine object.""" From fed8618d3085c81115a0d9d8b409fe3ccc7f454e Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 28 May 2015 10:48:22 -0400 Subject: [PATCH 1382/1502] tasks: Drop useless code. --- asyncio/tasks.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index fcb38338..d8193ba4 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -74,10 +74,7 @@ def __init__(self, coro, *, loop=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] - if coro.__class__ is types.GeneratorType: - self._coro = coro - else: - self._coro = iter(coro) # Use the iterator just in case. + self._coro = coro self._fut_waiter = None self._must_cancel = False self._loop.call_soon(self._step) @@ -237,10 +234,8 @@ def _step(self, value=None, exc=None): try: if exc is not None: result = coro.throw(exc) - elif value is not None: - result = coro.send(value) else: - result = coro.send(None) + result = coro.send(value) except StopIteration as exc: self.set_result(exc.value) except futures.CancelledError as exc: From 755806f019991ea1c6317146bb2c3d5665ed759f Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sat, 30 May 2015 21:00:05 -0400 Subject: [PATCH 1383/1502] Support Awaitables (pep 492) in @coroutine decorator --- asyncio/coroutines.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 1e0a7049..4933cf83 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -54,9 +54,10 @@ inspect.CO_COROUTINE) try: - from collections.abc import Coroutine as CoroutineABC + from collections.abc import Coroutine as CoroutineABC, \ + Awaitable as AwaitableABC except ImportError: - CoroutineABC = None + CoroutineABC = AwaitableABC = None # Check for CPython issue #21209 @@ -192,6 +193,16 @@ def coro(*args, **kw): res = func(*args, **kw) if isinstance(res, futures.Future) or inspect.isgenerator(res): res = yield from res + elif AwaitableABC is not None: + # If 'func' returns an Awaitable (new in 3.5) we + # want to run it. + try: + await_meth = res.__await__ + except AttributeError: + pass + else: + if isinstance(res, AwaitableABC): + res = yield from await_meth() return res if not _DEBUG: From 7baef9de28bd0a8e5bb14ad962059dced9f23bb3 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sun, 31 May 2015 21:19:13 -0400 Subject: [PATCH 1384/1502] coroutines: Fix CoroWrapper to support native coroutines --- asyncio/coroutines.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 4933cf83..edb68062 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -120,7 +120,7 @@ def __iter__(self): __await__ = __iter__ # make compatible with 'await' expression def __next__(self): - return next(self.gen) + return self.gen.send(None) if _YIELD_FROM_BUG: # For for CPython issue #21209: using "yield from" and a custom From eff46724a9a9fb6d36c23907518ffde29eabb66b Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Jun 2015 09:39:51 -0400 Subject: [PATCH 1385/1502] Sync with CPython. Specifically, changes from http://bugs.python.org/issue24400 --- asyncio/coroutines.py | 48 +++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index edb68062..4fc46a55 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -34,30 +34,20 @@ try: - types.coroutine + _types_coroutine = types.coroutine except AttributeError: - native_coroutine_support = False -else: - native_coroutine_support = True + _types_coroutine = None try: - _iscoroutinefunction = inspect.iscoroutinefunction + _inspect_iscoroutinefunction = inspect.iscoroutinefunction except AttributeError: - _iscoroutinefunction = lambda func: False + _inspect_iscoroutinefunction = lambda func: False try: - inspect.CO_COROUTINE -except AttributeError: - _is_native_coro_code = lambda code: False -else: - _is_native_coro_code = lambda code: (code.co_flags & - inspect.CO_COROUTINE) - -try: - from collections.abc import Coroutine as CoroutineABC, \ - Awaitable as AwaitableABC + from collections.abc import Coroutine as _CoroutineABC, \ + Awaitable as _AwaitableABC except ImportError: - CoroutineABC = AwaitableABC = None + _CoroutineABC = _AwaitableABC = None # Check for CPython issue #21209 @@ -89,10 +79,7 @@ def debug_wrapper(gen): # We only wrap here coroutines defined via 'async def' syntax. # Generator-based coroutines are wrapped in @coroutine # decorator. - if _is_native_coro_code(gen.gi_code): - return CoroWrapper(gen, None) - else: - return gen + return CoroWrapper(gen, None) class CoroWrapper: @@ -177,8 +164,7 @@ def coroutine(func): If the coroutine is not yielded from before it is destroyed, an error message is logged. """ - is_coroutine = _iscoroutinefunction(func) - if is_coroutine and _is_native_coro_code(func.__code__): + if _inspect_iscoroutinefunction(func): # In Python 3.5 that's all we need to do for coroutines # defiend with "async def". # Wrapping in CoroWrapper will happen via @@ -193,7 +179,7 @@ def coro(*args, **kw): res = func(*args, **kw) if isinstance(res, futures.Future) or inspect.isgenerator(res): res = yield from res - elif AwaitableABC is not None: + elif _AwaitableABC is not None: # If 'func' returns an Awaitable (new in 3.5) we # want to run it. try: @@ -201,15 +187,15 @@ def coro(*args, **kw): except AttributeError: pass else: - if isinstance(res, AwaitableABC): + if isinstance(res, _AwaitableABC): res = yield from await_meth() return res if not _DEBUG: - if native_coroutine_support: - wrapper = types.coroutine(coro) - else: + if _types_coroutine is None: wrapper = coro + else: + wrapper = _types_coroutine(coro) else: @functools.wraps(func) def wrapper(*args, **kwds): @@ -231,12 +217,12 @@ def wrapper(*args, **kwds): def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" return (getattr(func, '_is_coroutine', False) or - _iscoroutinefunction(func)) + _inspect_iscoroutinefunction(func)) _COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) -if CoroutineABC is not None: - _COROUTINE_TYPES += (CoroutineABC,) +if _CoroutineABC is not None: + _COROUTINE_TYPES += (_CoroutineABC,) def iscoroutine(obj): From 83ac3b846ff08b3b0ee19e6f29f7d623f7bb6146 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Jun 2015 10:25:07 -0400 Subject: [PATCH 1386/1502] Fix CoroWrapper for 'async def' coroutines --- asyncio/coroutines.py | 50 ++++++++++++++++++++++++++++++++++--------- tests/test_tasks.py | 3 ++- 2 files changed, 42 insertions(+), 11 deletions(-) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 4fc46a55..896cc560 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -103,9 +103,6 @@ def __repr__(self): def __iter__(self): return self - if _PY35: - __await__ = __iter__ # make compatible with 'await' expression - def __next__(self): return self.gen.send(None) @@ -143,10 +140,28 @@ def gi_running(self): def gi_code(self): return self.gen.gi_code + if _PY35: + + __await__ = __iter__ # make compatible with 'await' expression + + @property + def cr_running(self): + return self.gen.cr_running + + @property + def cr_code(self): + return self.gen.cr_code + + @property + def cr_frame(self): + return self.gen.cr_frame + def __del__(self): # Be careful accessing self.gen.frame -- self.gen might not exist. gen = getattr(self, 'gen', None) frame = getattr(gen, 'gi_frame', None) + if frame is None: + frame = getattr(gen, 'cr_frame', None) if frame is not None and frame.f_lasti == -1: msg = '%r was never yielded from' % self tb = getattr(self, '_source_traceback', ()) @@ -233,28 +248,43 @@ def iscoroutine(obj): def _format_coroutine(coro): assert iscoroutine(coro) + coro_name = None if isinstance(coro, CoroWrapper): func = coro.func + coro_name = coro.__qualname__ else: func = coro - coro_name = events._format_callback(func, ()) - filename = coro.gi_code.co_filename + if coro_name is None: + coro_name = events._format_callback(func, ()) + + try: + coro_code = coro.gi_code + except AttributeError: + coro_code = coro.cr_code + + try: + coro_frame = coro.gi_frame + except AttributeError: + coro_frame = coro.cr_frame + + filename = coro_code.co_filename if (isinstance(coro, CoroWrapper) - and not inspect.isgeneratorfunction(coro.func)): + and not inspect.isgeneratorfunction(coro.func) + and coro.func is not None): filename, lineno = events._get_function_source(coro.func) - if coro.gi_frame is None: + if coro_frame is None: coro_repr = ('%s done, defined at %s:%s' % (coro_name, filename, lineno)) else: coro_repr = ('%s running, defined at %s:%s' % (coro_name, filename, lineno)) - elif coro.gi_frame is not None: - lineno = coro.gi_frame.f_lineno + elif coro_frame is not None: + lineno = coro_frame.f_lineno coro_repr = ('%s running at %s:%s' % (coro_name, filename, lineno)) else: - lineno = coro.gi_code.co_firstlineno + lineno = coro_code.co_firstlineno coro_repr = ('%s done, defined at %s:%s' % (coro_name, filename, lineno)) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6541df75..251192ac 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1715,7 +1715,8 @@ def coro_noop(): self.assertTrue(m_log.error.called) message = m_log.error.call_args[0][0] func_filename, func_lineno = test_utils.get_function_source(coro_noop) - regex = (r'^ ' + + regex = (r'^ ' r'was never yielded from\n' r'Coroutine object created at \(most recent call last\):\n' r'.*\n' From 32a303ee0049e034a1eddcd7dcd85da9732db85f Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 24 Jun 2015 10:42:55 -0400 Subject: [PATCH 1387/1502] Fix regression in 83ac3b8 (failed unittests in debug mode) --- asyncio/coroutines.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 896cc560..a70eb1dd 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -252,6 +252,8 @@ def _format_coroutine(coro): if isinstance(coro, CoroWrapper): func = coro.func coro_name = coro.__qualname__ + if coro_name is not None: + coro_name = '{}()'.format(coro_name) else: func = coro From e802f173aa273d44153583d21508abf998128211 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Fri, 3 Jul 2015 00:39:40 -0400 Subject: [PATCH 1388/1502] coroutines: Proxy cr_await and gi_yieldfrom in CoroWrapper --- asyncio/coroutines.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index a70eb1dd..15475f23 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -144,6 +144,14 @@ def gi_code(self): __await__ = __iter__ # make compatible with 'await' expression + @property + def gi_yieldfrom(self): + return self.gen.gi_yieldfrom + + @property + def cr_await(self): + return self.gen.cr_await + @property def cr_running(self): return self.gen.cr_running From 728a91239123913b9161e8ca174cf86288f8c7d7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 6 Jul 2015 22:03:43 +0200 Subject: [PATCH 1389/1502] Rerite README * copy introduction from asyncio doc: explain what is asyncio * link to mailing list, IRC * explain how to install asyncio --- README.rst | 73 +++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 12 deletions(-) diff --git a/README.rst b/README.rst index 036bd5dd..629368e2 100644 --- a/README.rst +++ b/README.rst @@ -1,21 +1,71 @@ -Tulip is the codename for my reference implementation of PEP 3156. +The asyncio module provides infrastructure for writing single-threaded +concurrent code using coroutines, multiplexing I/O access over sockets and +other resources, running network clients and servers, and other related +primitives. Here is a more detailed list of the package contents: -PEP 3156: http://www.python.org/dev/peps/pep-3156/ +* a pluggable event loop with various system-specific implementations; -**This requires Python 3.3 or later!** +* transport and protocol abstractions (similar to those in Twisted); -Copyright/license: Open source, Apache 2.0. Enjoy. +* concrete support for TCP, UDP, SSL, subprocess pipes, delayed calls, and + others (some may be system-dependent); -Master GitHub repo: https://github.com/python/asyncio +* a Future class that mimics the one in the concurrent.futures module, but + adapted for use with the event loop; -The actual code lives in the 'asyncio' subdirectory. -Tests are in the 'tests' subdirectory. +* coroutines and tasks based on ``yield from`` (PEP 380), to help write + concurrent code in a sequential fashion; -To run tests: - - make test +* cancellation support for Futures and coroutines; -To run coverage (coverage package is required): - - make coverage +* synchronization primitives for use between coroutines in a single thread, + mimicking those in the threading module; + +* an interface for passing work off to a threadpool, for times when you + absolutely, positively have to use a library that makes blocking I/O calls. + + +Installation +============ + +To install asyncio, type:: + + pip install asyncio + +asyncio requires Python 3.3 or later! The asyncio module is part of the Python +standard library since Python 3.4. + +asyncio is a free software distributed under the Apache license version 2.0. + + +Websites +======== + +* `asyncio project at GitHub `_: source + code, bug tracker +* `asyncio documentation `_ +* Mailing list: `python-tulip Google Group + `_ +* IRC: join the ``#asyncio`` channel on the Freenode network + + +Development +=========== + +The actual code lives in the 'asyncio' subdirectory. Tests are in the 'tests' +subdirectory. + +To run tests, run:: + + tox + +Or use the Makefile:: + + make test + +To run coverage (coverage package is required):: + + make coverage On Windows, things are a little more complicated. Assume 'P' is your Python binary (for example C:\Python33\python.exe). @@ -41,4 +91,3 @@ And coverage as follows: C> P runtests.py --coverage ---Guido van Rossum From 1975461f4ccf0d9ffd9e20e4dd4a4650ad6a0c18 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 6 Jul 2015 22:11:55 +0200 Subject: [PATCH 1390/1502] rename asyncio/ directory to trollius/ --- {asyncio => trollius}/__init__.py | 0 {asyncio => trollius}/base_events.py | 0 {asyncio => trollius}/base_subprocess.py | 0 {asyncio => trollius}/constants.py | 0 {asyncio => trollius}/coroutines.py | 0 {asyncio => trollius}/events.py | 0 {asyncio => trollius}/futures.py | 0 {asyncio => trollius}/locks.py | 0 {asyncio => trollius}/log.py | 0 {asyncio => trollius}/proactor_events.py | 0 {asyncio => trollius}/protocols.py | 0 {asyncio => trollius}/queues.py | 0 {asyncio => trollius}/selector_events.py | 0 {asyncio => trollius}/selectors.py | 0 {asyncio => trollius}/sslproto.py | 0 {asyncio => trollius}/streams.py | 0 {asyncio => trollius}/subprocess.py | 0 {asyncio => trollius}/tasks.py | 0 {asyncio => trollius}/test_support.py | 0 {asyncio => trollius}/test_utils.py | 0 {asyncio => trollius}/transports.py | 0 {asyncio => trollius}/unix_events.py | 0 {asyncio => trollius}/windows_events.py | 0 {asyncio => trollius}/windows_utils.py | 0 24 files changed, 0 insertions(+), 0 deletions(-) rename {asyncio => trollius}/__init__.py (100%) rename {asyncio => trollius}/base_events.py (100%) rename {asyncio => trollius}/base_subprocess.py (100%) rename {asyncio => trollius}/constants.py (100%) rename {asyncio => trollius}/coroutines.py (100%) rename {asyncio => trollius}/events.py (100%) rename {asyncio => trollius}/futures.py (100%) rename {asyncio => trollius}/locks.py (100%) rename {asyncio => trollius}/log.py (100%) rename {asyncio => trollius}/proactor_events.py (100%) rename {asyncio => trollius}/protocols.py (100%) rename {asyncio => trollius}/queues.py (100%) rename {asyncio => trollius}/selector_events.py (100%) rename {asyncio => trollius}/selectors.py (100%) rename {asyncio => trollius}/sslproto.py (100%) rename {asyncio => trollius}/streams.py (100%) rename {asyncio => trollius}/subprocess.py (100%) rename {asyncio => trollius}/tasks.py (100%) rename {asyncio => trollius}/test_support.py (100%) rename {asyncio => trollius}/test_utils.py (100%) rename {asyncio => trollius}/transports.py (100%) rename {asyncio => trollius}/unix_events.py (100%) rename {asyncio => trollius}/windows_events.py (100%) rename {asyncio => trollius}/windows_utils.py (100%) diff --git a/asyncio/__init__.py b/trollius/__init__.py similarity index 100% rename from asyncio/__init__.py rename to trollius/__init__.py diff --git a/asyncio/base_events.py b/trollius/base_events.py similarity index 100% rename from asyncio/base_events.py rename to trollius/base_events.py diff --git a/asyncio/base_subprocess.py b/trollius/base_subprocess.py similarity index 100% rename from asyncio/base_subprocess.py rename to trollius/base_subprocess.py diff --git a/asyncio/constants.py b/trollius/constants.py similarity index 100% rename from asyncio/constants.py rename to trollius/constants.py diff --git a/asyncio/coroutines.py b/trollius/coroutines.py similarity index 100% rename from asyncio/coroutines.py rename to trollius/coroutines.py diff --git a/asyncio/events.py b/trollius/events.py similarity index 100% rename from asyncio/events.py rename to trollius/events.py diff --git a/asyncio/futures.py b/trollius/futures.py similarity index 100% rename from asyncio/futures.py rename to trollius/futures.py diff --git a/asyncio/locks.py b/trollius/locks.py similarity index 100% rename from asyncio/locks.py rename to trollius/locks.py diff --git a/asyncio/log.py b/trollius/log.py similarity index 100% rename from asyncio/log.py rename to trollius/log.py diff --git a/asyncio/proactor_events.py b/trollius/proactor_events.py similarity index 100% rename from asyncio/proactor_events.py rename to trollius/proactor_events.py diff --git a/asyncio/protocols.py b/trollius/protocols.py similarity index 100% rename from asyncio/protocols.py rename to trollius/protocols.py diff --git a/asyncio/queues.py b/trollius/queues.py similarity index 100% rename from asyncio/queues.py rename to trollius/queues.py diff --git a/asyncio/selector_events.py b/trollius/selector_events.py similarity index 100% rename from asyncio/selector_events.py rename to trollius/selector_events.py diff --git a/asyncio/selectors.py b/trollius/selectors.py similarity index 100% rename from asyncio/selectors.py rename to trollius/selectors.py diff --git a/asyncio/sslproto.py b/trollius/sslproto.py similarity index 100% rename from asyncio/sslproto.py rename to trollius/sslproto.py diff --git a/asyncio/streams.py b/trollius/streams.py similarity index 100% rename from asyncio/streams.py rename to trollius/streams.py diff --git a/asyncio/subprocess.py b/trollius/subprocess.py similarity index 100% rename from asyncio/subprocess.py rename to trollius/subprocess.py diff --git a/asyncio/tasks.py b/trollius/tasks.py similarity index 100% rename from asyncio/tasks.py rename to trollius/tasks.py diff --git a/asyncio/test_support.py b/trollius/test_support.py similarity index 100% rename from asyncio/test_support.py rename to trollius/test_support.py diff --git a/asyncio/test_utils.py b/trollius/test_utils.py similarity index 100% rename from asyncio/test_utils.py rename to trollius/test_utils.py diff --git a/asyncio/transports.py b/trollius/transports.py similarity index 100% rename from asyncio/transports.py rename to trollius/transports.py diff --git a/asyncio/unix_events.py b/trollius/unix_events.py similarity index 100% rename from asyncio/unix_events.py rename to trollius/unix_events.py diff --git a/asyncio/windows_events.py b/trollius/windows_events.py similarity index 100% rename from asyncio/windows_events.py rename to trollius/windows_events.py diff --git a/asyncio/windows_utils.py b/trollius/windows_utils.py similarity index 100% rename from asyncio/windows_utils.py rename to trollius/windows_utils.py From ccafbc04e342961fa9579c907e9da42dc2ceb8dd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 6 Jul 2015 22:13:28 +0200 Subject: [PATCH 1391/1502] Replace asyncio with trollius in code and config --- examples/cacheclt.py | 6 +-- examples/cachesvr.py | 4 +- examples/child_process.py | 8 ++-- examples/crawl.py | 4 +- examples/echo_client_tulip.py | 2 +- examples/echo_server_tulip.py | 2 +- examples/fetch0.py | 2 +- examples/fetch1.py | 2 +- examples/fetch2.py | 2 +- examples/fetch3.py | 4 +- examples/fuzz_as_completed.py | 2 +- examples/hello_callback.py | 2 +- examples/hello_coroutine.py | 2 +- examples/shell.py | 4 +- examples/simple_tcp_server.py | 2 +- examples/sink.py | 4 +- examples/source.py | 6 +-- examples/source1.py | 6 +-- examples/stacks.py | 2 +- examples/subprocess_attach_read_pipe.py | 2 +- examples/subprocess_attach_write_pipe.py | 4 +- examples/subprocess_shell.py | 4 +- examples/tcp_echo.py | 4 +- examples/timing_tcp_server.py | 2 +- examples/udp_echo.py | 2 +- run_aiotest.py | 8 ++-- setup.py | 6 +-- tests/test_base_events.py | 48 +++++++++++------------ tests/test_events.py | 16 ++++---- tests/test_futures.py | 22 +++++------ tests/test_locks.py | 4 +- tests/test_proactor_events.py | 18 ++++----- tests/test_queues.py | 4 +- tests/test_selector_events.py | 34 ++++++++-------- tests/test_selectors.py | 6 +-- tests/test_sslproto.py | 8 ++-- tests/test_streams.py | 6 +-- tests/test_subprocess.py | 12 +++--- tests/test_tasks.py | 12 +++--- tests/test_transports.py | 4 +- tests/test_unix_events.py | 50 ++++++++++++------------ tests/test_windows_events.py | 8 ++-- tests/test_windows_utils.py | 10 ++--- trollius/test_utils.py | 2 +- 44 files changed, 181 insertions(+), 181 deletions(-) diff --git a/examples/cacheclt.py b/examples/cacheclt.py index b11a4d1a..3e9de31a 100644 --- a/examples/cacheclt.py +++ b/examples/cacheclt.py @@ -4,8 +4,8 @@ """ import argparse -import asyncio -from asyncio import test_utils +import trollius as asyncio +from trollius import test_utils import json import logging @@ -166,7 +166,7 @@ def process(self): def main(): asyncio.set_event_loop(None) if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() else: loop = asyncio.new_event_loop() diff --git a/examples/cachesvr.py b/examples/cachesvr.py index 053f9c21..27ce6c30 100644 --- a/examples/cachesvr.py +++ b/examples/cachesvr.py @@ -57,7 +57,7 @@ """ import argparse -import asyncio +import trollius as asyncio import json import logging import os @@ -217,7 +217,7 @@ def handle_delete(self, key): def main(): asyncio.set_event_loop(None) if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() else: loop = asyncio.new_event_loop() diff --git a/examples/child_process.py b/examples/child_process.py index 3fac175e..915e358f 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -10,15 +10,15 @@ import sys try: - import asyncio + import trollius as asyncio except ImportError: # asyncio is not installed sys.path.append(os.path.join(os.path.dirname(__file__), '..')) - import asyncio + import trollius as asyncio if sys.platform == 'win32': - from asyncio.windows_utils import Popen, PIPE - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_utils import Popen, PIPE + from trollius.windows_events import ProactorEventLoop else: from subprocess import Popen, PIPE diff --git a/examples/crawl.py b/examples/crawl.py index 4bb0b4ea..0393d626 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -15,7 +15,7 @@ # - Handle out of file descriptors directly? (How?) import argparse -import asyncio +import trollius as asyncio import asyncio.locks import cgi from http.client import BadStatusLine @@ -828,7 +828,7 @@ def main(): log = Logger(args.level) if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() asyncio.set_event_loop(loop) elif args.select: diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py index 88124efe..eea8a58d 100644 --- a/examples/echo_client_tulip.py +++ b/examples/echo_client_tulip.py @@ -1,4 +1,4 @@ -import asyncio +import trollius as asyncio END = b'Bye-bye!\n' diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py index 8167e540..e1f9f2b8 100644 --- a/examples/echo_server_tulip.py +++ b/examples/echo_server_tulip.py @@ -1,4 +1,4 @@ -import asyncio +import trollius as asyncio @asyncio.coroutine def echo_server(): diff --git a/examples/fetch0.py b/examples/fetch0.py index 180fcf26..222a97b1 100644 --- a/examples/fetch0.py +++ b/examples/fetch0.py @@ -2,7 +2,7 @@ import sys -from asyncio import * +from trollius import * @coroutine diff --git a/examples/fetch1.py b/examples/fetch1.py index 8dbb6e47..4e7037f2 100644 --- a/examples/fetch1.py +++ b/examples/fetch1.py @@ -6,7 +6,7 @@ import sys import urllib.parse -from asyncio import * +from trollius import * class Response: diff --git a/examples/fetch2.py b/examples/fetch2.py index 7617b59b..de6a288d 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -7,7 +7,7 @@ import urllib.parse from http.client import BadStatusLine -from asyncio import * +from trollius import * class Request: diff --git a/examples/fetch3.py b/examples/fetch3.py index 9419afd2..fc113d40 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -8,7 +8,7 @@ import urllib.parse from http.client import BadStatusLine -from asyncio import * +from trollius import * class ConnectionPool: @@ -214,7 +214,7 @@ def fetch(url, verbose=True, max_redirect=10): def main(): if '--iocp' in sys.argv: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) else: diff --git a/examples/fuzz_as_completed.py b/examples/fuzz_as_completed.py index 123fbf1b..f6203e84 100644 --- a/examples/fuzz_as_completed.py +++ b/examples/fuzz_as_completed.py @@ -2,7 +2,7 @@ """Fuzz tester for as_completed(), by Glenn Langford.""" -import asyncio +import trollius as asyncio import itertools import random import sys diff --git a/examples/hello_callback.py b/examples/hello_callback.py index 7ccbea1e..07205d9b 100644 --- a/examples/hello_callback.py +++ b/examples/hello_callback.py @@ -1,6 +1,6 @@ """Print 'Hello World' every two seconds, using a callback.""" -import asyncio +import trollius as asyncio def print_and_repeat(loop): diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py index b9347aa8..de716dee 100644 --- a/examples/hello_coroutine.py +++ b/examples/hello_coroutine.py @@ -1,6 +1,6 @@ """Print 'Hello World' every two seconds, using a coroutine.""" -import asyncio +import trollius as asyncio @asyncio.coroutine diff --git a/examples/shell.py b/examples/shell.py index f9343256..61991a75 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -1,8 +1,8 @@ """Examples using create_subprocess_exec() and create_subprocess_shell().""" -import asyncio +import trollius as asyncio import signal -from asyncio.subprocess import PIPE +from trollius.subprocess import PIPE @asyncio.coroutine def cat(loop): diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py index 5f874ffc..3e847b07 100644 --- a/examples/simple_tcp_server.py +++ b/examples/simple_tcp_server.py @@ -9,7 +9,7 @@ """ import sys -import asyncio +import trollius as asyncio import asyncio.streams diff --git a/examples/sink.py b/examples/sink.py index d362cbb2..8156b0ec 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -4,7 +4,7 @@ import os import sys -from asyncio import * +from trollius import * ARGS = argparse.ArgumentParser(description="TCP data sink example.") ARGS.add_argument( @@ -79,7 +79,7 @@ def main(): global args args = ARGS.parse_args() if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) else: diff --git a/examples/source.py b/examples/source.py index 7fd11fb0..61c7ab79 100644 --- a/examples/source.py +++ b/examples/source.py @@ -3,8 +3,8 @@ import argparse import sys -from asyncio import * -from asyncio import test_utils +from trollius import * +from trollius import test_utils ARGS = argparse.ArgumentParser(description="TCP data sink example.") @@ -85,7 +85,7 @@ def main(): global args args = ARGS.parse_args() if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) else: diff --git a/examples/source1.py b/examples/source1.py index 6802e963..5af467a9 100644 --- a/examples/source1.py +++ b/examples/source1.py @@ -3,8 +3,8 @@ import argparse import sys -from asyncio import * -from asyncio import test_utils +from trollius import * +from trollius import test_utils ARGS = argparse.ArgumentParser(description="TCP data sink example.") ARGS.add_argument( @@ -83,7 +83,7 @@ def main(): global args args = ARGS.parse_args() if args.iocp: - from asyncio.windows_events import ProactorEventLoop + from trollius.windows_events import ProactorEventLoop loop = ProactorEventLoop() set_event_loop(loop) else: diff --git a/examples/stacks.py b/examples/stacks.py index 0b7e0b2c..e03e78e6 100644 --- a/examples/stacks.py +++ b/examples/stacks.py @@ -1,7 +1,7 @@ """Crude demo for print_stack().""" -from asyncio import * +from trollius import * @coroutine diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py index d8a62420..2cadc849 100644 --- a/examples/subprocess_attach_read_pipe.py +++ b/examples/subprocess_attach_read_pipe.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """Example showing how to attach a read pipe to a subprocess.""" -import asyncio +import trollius as asyncio import os, sys code = """ diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py index c4e099f6..646b2691 100644 --- a/examples/subprocess_attach_write_pipe.py +++ b/examples/subprocess_attach_write_pipe.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 """Example showing how to attach a write pipe to a subprocess.""" -import asyncio +import trollius as asyncio import os, sys -from asyncio import subprocess +from trollius import subprocess code = """ import os, sys diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py index 745cb646..71d125ce 100644 --- a/examples/subprocess_shell.py +++ b/examples/subprocess_shell.py @@ -1,9 +1,9 @@ """Example writing to and reading from a subprocess at the same time using tasks.""" -import asyncio +import trollius as asyncio import os -from asyncio.subprocess import PIPE +from trollius.subprocess import PIPE @asyncio.coroutine diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index d743242a..1a0a2c61 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """TCP echo server example.""" import argparse -import asyncio +import trollius as asyncio import sys try: import signal @@ -105,7 +105,7 @@ def start_server(loop, host, port): ARGS.print_help() else: if args.iocp: - from asyncio import windows_events + from trollius import windows_events loop = windows_events.ProactorEventLoop() asyncio.set_event_loop(loop) else: diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index 883ce6d3..664ba63a 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -12,7 +12,7 @@ import time import random -import asyncio +import trollius as asyncio import asyncio.streams diff --git a/examples/udp_echo.py b/examples/udp_echo.py index 93ac7e6b..b13303ff 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -2,7 +2,7 @@ """UDP echo example.""" import argparse import sys -import asyncio +import trollius as asyncio try: import signal except ImportError: diff --git a/run_aiotest.py b/run_aiotest.py index 8d6fa293..0a08d66a 100644 --- a/run_aiotest.py +++ b/run_aiotest.py @@ -1,14 +1,14 @@ import aiotest.run -import asyncio +import trollius import sys if sys.platform == 'win32': - from asyncio.windows_utils import socketpair + from trollius.windows_utils import socketpair else: from socket import socketpair config = aiotest.TestConfig() -config.asyncio = asyncio +config.asyncio = trollius config.socketpair = socketpair -config.new_event_pool_policy = asyncio.DefaultEventLoopPolicy +config.new_event_pool_policy = trollius.DefaultEventLoopPolicy config.call_soon_check_closed = True aiotest.run.main(config) diff --git a/setup.py b/setup.py index 93cacdd3..660d38fb 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ extensions = [] if os.name == 'nt': ext = Extension( - 'asyncio._overlapped', ['overlapped.c'], libraries=['ws2_32'], + 'trollius._overlapped', ['overlapped.c'], libraries=['ws2_32'], ) extensions.append(ext) @@ -29,7 +29,7 @@ long_description = fp.read() setup( - name="asyncio", + name="trollius", version="3.4.4", description="reference implementation of PEP 3156", @@ -42,7 +42,7 @@ "Programming Language :: Python :: 3.3", ], - packages=["asyncio"], + packages=["trollius"], test_suite="runtests.runtests", ext_modules=extensions, diff --git a/tests/test_base_events.py b/tests/test_base_events.py index b1f1e56c..7e7380db 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -10,21 +10,21 @@ import unittest from unittest import mock -import asyncio -from asyncio import base_events -from asyncio import constants -from asyncio import test_utils +import trollius as asyncio +from trollius import base_events +from trollius import constants +from trollius import test_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support try: from test.support.script_helper import assert_python_ok except ImportError: try: from test.script_helper import assert_python_ok except ImportError: - from asyncio.test_support import assert_python_ok + from trollius.test_support import assert_python_ok MOCK_ANY = mock.ANY @@ -288,7 +288,7 @@ def test_set_debug(self): self.loop.set_debug(False) self.assertFalse(self.loop.get_debug()) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test__run_once_logging(self, m_logger): def slow_select(timeout): # Sleep a bit longer than a second to avoid timer resolution @@ -476,7 +476,7 @@ def zero_error(fut): 1/0 # Test call_soon (events.Handle) - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: fut = asyncio.Future(loop=self.loop) self.loop.call_soon(zero_error, fut) fut.add_done_callback(lambda fut: self.loop.stop()) @@ -486,7 +486,7 @@ def zero_error(fut): exc_info=(ZeroDivisionError, MOCK_ANY, MOCK_ANY)) # Test call_later (events.TimerHandle) - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: fut = asyncio.Future(loop=self.loop) self.loop.call_later(0.01, zero_error, fut) fut.add_done_callback(lambda fut: self.loop.stop()) @@ -504,7 +504,7 @@ def zero_error_coro(): 1/0 # Test Future.__del__ - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: fut = asyncio.ensure_future(zero_error_coro(), loop=self.loop) fut.add_done_callback(lambda *args: self.loop.stop()) self.loop.run_forever() @@ -551,7 +551,7 @@ def run_loop(): mock_handler.reset_mock() self.loop.set_exception_handler(None) - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern( @@ -574,7 +574,7 @@ def handler(loop, context): self.loop.set_exception_handler(handler) - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern( @@ -605,7 +605,7 @@ def zero_error(): loop.call_soon(zero_error) loop._run_once() - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: run_loop() log.error.assert_called_with( 'Exception in default exception handler', @@ -616,7 +616,7 @@ def custom_handler(loop, context): _context = None loop.set_exception_handler(custom_handler) - with mock.patch('asyncio.base_events.logger') as log: + with mock.patch('trollius.base_events.logger') as log: run_loop() log.error.assert_called_with( test_utils.MockPattern('Exception in default exception.*' @@ -821,7 +821,7 @@ def setUp(self): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_connection_multiple_errors(self, m_socket): class MyProto(asyncio.Protocol): @@ -854,7 +854,7 @@ def _socket(*args, **kw): self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_connection_timeout(self, m_socket): # Ensure that the socket is closed on timeout sock = mock.Mock() @@ -932,7 +932,7 @@ def getaddrinfo_task(*args, **kwds): with self.assertRaises(OSError): self.loop.run_until_complete(coro) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_connection_multiple_errors_local_addr(self, m_socket): def bind(addr): @@ -1099,7 +1099,7 @@ def test_create_server_no_getaddrinfo(self): f = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, f) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_server_cant_bind(self, m_socket): class Err(OSError): @@ -1115,7 +1115,7 @@ class Err(OSError): self.assertRaises(OSError, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo._is_coroutine = False @@ -1144,7 +1144,7 @@ def test_create_datagram_endpoint_connect_err(self): self.assertRaises( OSError, self.loop.run_until_complete, coro) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): m_socket.getaddrinfo = socket.getaddrinfo m_socket.socket.side_effect = OSError @@ -1167,7 +1167,7 @@ def test_create_datagram_endpoint_no_matching_family(self): self.assertRaises( ValueError, self.loop.run_until_complete, coro) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): m_socket.socket.return_value.setblocking.side_effect = OSError @@ -1183,7 +1183,7 @@ def test_create_datagram_endpoint_noaddr_nofamily(self): asyncio.DatagramProtocol) self.assertRaises(ValueError, self.loop.run_until_complete, coro) - @mock.patch('asyncio.base_events.socket') + @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): class Err(OSError): pass @@ -1206,7 +1206,7 @@ def test_accept_connection_retry(self): self.loop._accept_connection(MyProto, sock) self.assertFalse(sock.close.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_accept_connection_exception(self, m_log): sock = mock.Mock() sock.fileno.return_value = 10 @@ -1243,7 +1243,7 @@ def simple_coroutine(): with self.assertRaises(TypeError): self.loop.run_in_executor(None, func) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_log_slow_callbacks(self, m_logger): def stop_loop_cb(loop): loop.stop() diff --git a/tests/test_events.py b/tests/test_events.py index 8fbba8fe..402898df 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -22,15 +22,15 @@ import weakref -import asyncio -from asyncio import proactor_events -from asyncio import selector_events -from asyncio import sslproto -from asyncio import test_utils +import trollius as asyncio +from trollius import proactor_events +from trollius import selector_events +from trollius import sslproto +from trollius import test_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support def data_file(filename): @@ -1856,7 +1856,7 @@ def test_create_datagram_endpoint(self): def test_remove_fds_after_closing(self): raise unittest.SkipTest("IocpEventLoop does not have add_reader()") else: - from asyncio import selectors + from trollius import selectors class UnixEventLoopTestsMixin(EventLoopTestsMixin): def setUp(self): @@ -2333,7 +2333,7 @@ def test_get_event_loop_after_set_none(self): policy.set_event_loop(None) self.assertRaises(RuntimeError, policy.get_event_loop) - @mock.patch('asyncio.events.threading.current_thread') + @mock.patch('trollius.events.threading.current_thread') def test_get_event_loop_thread(self, m_current_thread): def f(): diff --git a/tests/test_futures.py b/tests/test_futures.py index c8b6829f..3aadb688 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -7,12 +7,12 @@ import unittest from unittest import mock -import asyncio -from asyncio import test_utils +import trollius as asyncio +from trollius import test_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support def _fakefunc(f): @@ -211,20 +211,20 @@ def test(): self.assertRaises(AssertionError, test) fut.cancel() - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_abandoned(self, m_log): fut = asyncio.Future(loop=self.loop) del fut self.assertFalse(m_log.error.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_result_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) del fut self.assertFalse(m_log.error.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_result(42) @@ -232,7 +232,7 @@ def test_tb_logger_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_exception_unretrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -240,7 +240,7 @@ def test_tb_logger_exception_unretrieved(self, m_log): test_utils.run_briefly(self.loop) self.assertTrue(m_log.error.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_exception_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -248,7 +248,7 @@ def test_tb_logger_exception_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_tb_logger_exception_result_retrieved(self, m_log): fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) @@ -273,7 +273,7 @@ def test_wrap_future_future(self): f2 = asyncio.wrap_future(f1) self.assertIs(f1, f2) - @mock.patch('asyncio.futures.events') + @mock.patch('trollius.futures.events') def test_wrap_future_use_global_loop(self, m_events): def run(arg): return (arg, threading.get_ident()) @@ -311,7 +311,7 @@ def test_future_source_traceback(self): lineno, 'test_future_source_traceback')) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def check_future_exception_never_retrieved(self, debug, m_log): self.loop.set_debug(debug) diff --git a/tests/test_locks.py b/tests/test_locks.py index dda4577a..5ad02eda 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -4,8 +4,8 @@ from unittest import mock import re -import asyncio -from asyncio import test_utils +import trollius as asyncio +from trollius import test_utils STR_RGX_REPR = ( diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index fcd9ab1e..47ae6bc5 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -4,12 +4,12 @@ import unittest from unittest import mock -import asyncio -from asyncio.proactor_events import BaseProactorEventLoop -from asyncio.proactor_events import _ProactorSocketTransport -from asyncio.proactor_events import _ProactorWritePipeTransport -from asyncio.proactor_events import _ProactorDuplexPipeTransport -from asyncio import test_utils +import trollius as asyncio +from trollius.proactor_events import BaseProactorEventLoop +from trollius.proactor_events import _ProactorSocketTransport +from trollius.proactor_events import _ProactorWritePipeTransport +from trollius.proactor_events import _ProactorDuplexPipeTransport +from trollius import test_utils def close_transport(transport): @@ -152,7 +152,7 @@ def test_loop_writing(self): self.loop._proactor.send.return_value.add_done_callback.\ assert_called_with(tr._loop_writing) - @mock.patch('asyncio.proactor_events.logger') + @mock.patch('trollius.proactor_events.logger') def test_loop_writing_err(self, m_log): err = self.loop._proactor.send.side_effect = OSError() tr = self.socket_transport() @@ -226,7 +226,7 @@ def test_close_buffer(self): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.connection_lost.called) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_fatal_error(self, m_logging): tr = self.socket_transport() tr._force_close = mock.Mock() @@ -539,7 +539,7 @@ def test_write_to_self(self): def test_process_events(self): self.loop._process_events([]) - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_create_server(self, m_log): pf = mock.Mock() call_soon = self.loop.call_soon = mock.Mock() diff --git a/tests/test_queues.py b/tests/test_queues.py index 88b4f075..b097a70d 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -3,8 +3,8 @@ import unittest from unittest import mock -import asyncio -from asyncio import test_utils +import trollius as asyncio +from trollius import test_utils class _QueueTestBase(test_utils.TestCase): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 9478b954..2ae9dc9f 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -9,14 +9,14 @@ except ImportError: ssl = None -import asyncio -from asyncio import selectors -from asyncio import test_utils -from asyncio.selector_events import BaseSelectorEventLoop -from asyncio.selector_events import _SelectorTransport -from asyncio.selector_events import _SelectorSslTransport -from asyncio.selector_events import _SelectorSocketTransport -from asyncio.selector_events import _SelectorDatagramTransport +import trollius as asyncio +from trollius import selectors +from trollius import test_utils +from trollius.selector_events import BaseSelectorEventLoop +from trollius.selector_events import _SelectorTransport +from trollius.selector_events import _SelectorSslTransport +from trollius.selector_events import _SelectorSocketTransport +from trollius.selector_events import _SelectorDatagramTransport MOCK_ANY = mock.ANY @@ -94,8 +94,8 @@ def test_make_ssl_transport(self): # execute pending callbacks to close the socket transport test_utils.run_briefly(self.loop) - @mock.patch('asyncio.selector_events.ssl', None) - @mock.patch('asyncio.sslproto.ssl', None) + @mock.patch('trollius.selector_events.ssl', None) + @mock.patch('trollius.sslproto.ssl', None) def test_make_ssl_transport_without_ssl_error(self): m = mock.Mock() self.loop.add_reader = mock.Mock() @@ -733,7 +733,7 @@ def test_force_close(self): self.assertFalse(self.loop.readers) self.assertEqual(1, self.loop.remove_reader_count[7]) - @mock.patch('asyncio.log.logger.error') + @mock.patch('trollius.log.logger.error') def test_fatal_error(self, m_exc): exc = OSError() tr = self.create_transport() @@ -969,7 +969,7 @@ def test_write_tryagain(self): self.loop.assert_writer(7, transport._write_ready) self.assertEqual(list_to_buffer([b'data']), transport._buffer) - @mock.patch('asyncio.selector_events.logger') + @mock.patch('trollius.selector_events.logger') def test_write_exception(self, m_log): err = self.sock.send.side_effect = OSError() @@ -1077,7 +1077,7 @@ def test_write_ready_exception(self): err, 'Fatal write error on socket transport') - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.base_events.logger') def test_write_ready_exception_and_close(self, m_log): self.sock.send.side_effect = OSError() remove_writer = self.loop.remove_writer = mock.Mock() @@ -1243,7 +1243,7 @@ def test_write_closing(self): transport.write(b'data') self.assertEqual(transport._conn_lost, 2) - @mock.patch('asyncio.selector_events.logger') + @mock.patch('trollius.selector_events.logger') def test_write_exception(self, m_log): transport = self._make_one() transport._conn_lost = 1 @@ -1467,7 +1467,7 @@ def test_server_hostname(self): class SelectorSslWithoutSslTransportTests(unittest.TestCase): - @mock.patch('asyncio.selector_events.ssl', None) + @mock.patch('trollius.selector_events.ssl', None) def test_ssl_transport_requires_ssl_module(self): Mock = mock.Mock with self.assertRaises(RuntimeError): @@ -1606,7 +1606,7 @@ def test_sendto_tryagain(self): self.assertEqual( [(b'data', ('0.0.0.0', 12345))], list(transport._buffer)) - @mock.patch('asyncio.selector_events.logger') + @mock.patch('trollius.selector_events.logger') def test_sendto_exception(self, m_log): data = b'data' err = self.sock.sendto.side_effect = RuntimeError() @@ -1749,7 +1749,7 @@ def test_sendto_ready_error_received_connection(self): self.assertFalse(transport._fatal_error.called) self.assertTrue(self.protocol.error_received.called) - @mock.patch('asyncio.base_events.logger.error') + @mock.patch('trollius.base_events.logger.error') def test_fatal_error_connected(self, m_exc): transport = self.datagram_transport(address=('0.0.0.0', 1)) err = ConnectionRefusedError() diff --git a/tests/test_selectors.py b/tests/test_selectors.py index a33f0fa4..5749389e 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -9,7 +9,7 @@ try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support try: from time import monotonic as time except ImportError: @@ -18,8 +18,8 @@ import resource except ImportError: resource = None -from asyncio import selectors -from asyncio.test_utils import socketpair +from trollius import selectors +from trollius.test_utils import socketpair def find_ready_matching(ready, flag): diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index a72967ea..7c7bbf80 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -7,9 +7,9 @@ except ImportError: ssl = None -import asyncio -from asyncio import sslproto -from asyncio import test_utils +import trollius as asyncio +from trollius import sslproto +from trollius import test_utils @unittest.skipIf(ssl is None, 'No ssl module') @@ -36,7 +36,7 @@ def connection_made(self, ssl_proto, do_handshake=None): def mock_handshake(callback): return [] sslpipe.do_handshake.side_effect = mock_handshake - with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe): + with mock.patch('trollius.sslproto._SSLPipe', return_value=sslpipe): ssl_proto.connection_made(transport) def test_cancel_handshake(self): diff --git a/tests/test_streams.py b/tests/test_streams.py index 2273049b..5f7eb7e7 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -11,8 +11,8 @@ except ImportError: ssl = None -import asyncio -from asyncio import test_utils +import trollius as asyncio +from trollius import test_utils class StreamReaderTests(test_utils.TestCase): @@ -31,7 +31,7 @@ def tearDown(self): gc.collect() super().tearDown() - @mock.patch('asyncio.streams.events') + @mock.patch('trollius.streams.events') def test_ctor_global_loop(self, m_events): stream = asyncio.StreamReader() self.assertIs(stream._loop, m_events.get_event_loop.return_value) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5ccdafb1..b4f3f950 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -3,16 +3,16 @@ import unittest from unittest import mock -import asyncio -from asyncio import base_subprocess -from asyncio import subprocess -from asyncio import test_utils +import trollius as asyncio +from trollius import base_subprocess +from trollius import subprocess +from trollius import test_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support if sys.platform != 'win32': - from asyncio import unix_events + from trollius import unix_events # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 251192ac..a8ceba01 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -10,20 +10,20 @@ import weakref from unittest import mock -import asyncio -from asyncio import coroutines -from asyncio import test_utils +import trollius as asyncio +from trollius import coroutines +from trollius import test_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support try: from test.support.script_helper import assert_python_ok except ImportError: try: from test.script_helper import assert_python_ok except ImportError: - from asyncio.test_support import assert_python_ok + from trollius.test_support import assert_python_ok PY34 = (sys.version_info >= (3, 4)) @@ -1699,7 +1699,7 @@ def kill_me(loop): }) mock_handler.reset_mock() - @mock.patch('asyncio.coroutines.logger') + @mock.patch('trollius.coroutines.logger') def test_coroutine_never_yielded(self, m_log): with set_coroutine_debug(True): @asyncio.coroutine diff --git a/tests/test_transports.py b/tests/test_transports.py index 3b6e3d67..1cb03e0a 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -3,8 +3,8 @@ import unittest from unittest import mock -import asyncio -from asyncio import transports +import trollius as asyncio +from trollius import transports class TransportTests(unittest.TestCase): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index dc0835c5..7f920d05 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -17,10 +17,10 @@ raise unittest.SkipTest('UNIX only') -import asyncio -from asyncio import log -from asyncio import test_utils -from asyncio import unix_events +import trollius as asyncio +from trollius import log +from trollius import test_utils +from trollius import unix_events MOCK_ANY = mock.ANY @@ -60,7 +60,7 @@ def test_handle_signal_cancelled_handler(self): self.loop._handle_signal(signal.NSIG + 1) self.loop.remove_signal_handler.assert_called_with(signal.NSIG + 1) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_add_signal_handler_setup_error(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.set_wakeup_fd.side_effect = ValueError @@ -70,7 +70,7 @@ def test_add_signal_handler_setup_error(self, m_signal): self.loop.add_signal_handler, signal.SIGINT, lambda: True) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_add_signal_handler_coroutine_error(self, m_signal): m_signal.NSIG = signal.NSIG @@ -88,7 +88,7 @@ def simple_coroutine(): self.loop.add_signal_handler, signal.SIGINT, func) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_add_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG @@ -98,7 +98,7 @@ def test_add_signal_handler(self, m_signal): self.assertIsInstance(h, asyncio.Handle) self.assertEqual(h._callback, cb) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_add_signal_handler_install_error(self, m_signal): m_signal.NSIG = signal.NSIG @@ -116,8 +116,8 @@ class Err(OSError): self.loop.add_signal_handler, signal.SIGINT, lambda: True) - @mock.patch('asyncio.unix_events.signal') - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.unix_events.signal') + @mock.patch('trollius.base_events.logger') def test_add_signal_handler_install_error2(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG @@ -133,8 +133,8 @@ class Err(OSError): self.assertFalse(m_logging.info.called) self.assertEqual(1, m_signal.set_wakeup_fd.call_count) - @mock.patch('asyncio.unix_events.signal') - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.unix_events.signal') + @mock.patch('trollius.base_events.logger') def test_add_signal_handler_install_error3(self, m_logging, m_signal): class Err(OSError): errno = errno.EINVAL @@ -148,7 +148,7 @@ class Err(OSError): self.assertFalse(m_logging.info.called) self.assertEqual(2, m_signal.set_wakeup_fd.call_count) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_remove_signal_handler(self, m_signal): m_signal.NSIG = signal.NSIG @@ -161,7 +161,7 @@ def test_remove_signal_handler(self, m_signal): self.assertEqual( (signal.SIGHUP, m_signal.SIG_DFL), m_signal.signal.call_args[0]) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_remove_signal_handler_2(self, m_signal): m_signal.NSIG = signal.NSIG m_signal.SIGINT = signal.SIGINT @@ -178,8 +178,8 @@ def test_remove_signal_handler_2(self, m_signal): (signal.SIGINT, m_signal.default_int_handler), m_signal.signal.call_args[0]) - @mock.patch('asyncio.unix_events.signal') - @mock.patch('asyncio.base_events.logger') + @mock.patch('trollius.unix_events.signal') + @mock.patch('trollius.base_events.logger') def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -189,7 +189,7 @@ def test_remove_signal_handler_cleanup_error(self, m_logging, m_signal): self.loop.remove_signal_handler(signal.SIGHUP) self.assertTrue(m_logging.info) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_remove_signal_handler_error(self, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -199,7 +199,7 @@ def test_remove_signal_handler_error(self, m_signal): self.assertRaises( OSError, self.loop.remove_signal_handler, signal.SIGHUP) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_remove_signal_handler_error2(self, m_signal): m_signal.NSIG = signal.NSIG self.loop.add_signal_handler(signal.SIGHUP, lambda: True) @@ -211,7 +211,7 @@ class Err(OSError): self.assertRaises( RuntimeError, self.loop.remove_signal_handler, signal.SIGHUP) - @mock.patch('asyncio.unix_events.signal') + @mock.patch('trollius.unix_events.signal') def test_close(self, m_signal): m_signal.NSIG = signal.NSIG @@ -275,7 +275,7 @@ def test_create_unix_server_path_inetsock(self): 'A UNIX Domain Socket was expected'): self.loop.run_until_complete(coro) - @mock.patch('asyncio.unix_events.socket') + @mock.patch('trollius.unix_events.socket') def test_create_unix_server_bind_error(self, m_socket): # Ensure that the socket is closed on any bind error sock = mock.Mock() @@ -331,7 +331,7 @@ def setUp(self): self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher = mock.patch('trollius.unix_events._set_nonblocking') blocking_patcher.start() self.addCleanup(blocking_patcher.stop) @@ -389,7 +389,7 @@ def test__read_ready_blocked(self, m_read): test_utils.run_briefly(self.loop) self.assertFalse(self.protocol.data_received.called) - @mock.patch('asyncio.log.logger.error') + @mock.patch('trollius.log.logger.error') @mock.patch('os.read') def test__read_ready_error(self, m_read, m_logexc): tr = self.read_pipe_transport() @@ -480,7 +480,7 @@ def setUp(self): self.pipe = mock.Mock(spec_set=io.RawIOBase) self.pipe.fileno.return_value = 5 - blocking_patcher = mock.patch('asyncio.unix_events._set_nonblocking') + blocking_patcher = mock.patch('trollius.unix_events._set_nonblocking') blocking_patcher.start() self.addCleanup(blocking_patcher.stop) @@ -556,7 +556,7 @@ def test_write_again(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @mock.patch('asyncio.unix_events.logger') + @mock.patch('trollius.unix_events.logger') @mock.patch('os.write') def test_write_err(self, m_write, m_log): tr = self.write_pipe_transport() @@ -646,7 +646,7 @@ def test__write_ready_empty(self, m_write): self.loop.assert_writer(5, tr._write_ready) self.assertEqual([b'data'], tr._buffer) - @mock.patch('asyncio.log.logger.error') + @mock.patch('trollius.log.logger.error') @mock.patch('os.write') def test__write_ready_err(self, m_write, m_logexc): tr = self.write_pipe_transport() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 657a4274..ec0a5ca4 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -8,10 +8,10 @@ import _winapi -import asyncio -from asyncio import _overlapped -from asyncio import test_utils -from asyncio import windows_events +import trollius as asyncio +from trollius import _overlapped +from trollius import test_utils +from trollius import windows_events class UpperProto(asyncio.Protocol): diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index d48b8bcb..191daabf 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -11,12 +11,12 @@ import _winapi -from asyncio import _overlapped -from asyncio import windows_utils +from trollius import _overlapped +from trollius import windows_utils try: from test import support except ImportError: - from asyncio import test_support as support + from trollius import test_support as support class WinsocketpairTests(unittest.TestCase): @@ -38,7 +38,7 @@ def test_winsocketpair_ipv6(self): @unittest.skipIf(hasattr(socket, 'socketpair'), 'socket.socketpair is available') - @mock.patch('asyncio.windows_utils.socket') + @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): m_socket.AF_INET = socket.AF_INET m_socket.SOCK_STREAM = socket.SOCK_STREAM @@ -58,7 +58,7 @@ def test_winsocketpair_invalid_args(self): @unittest.skipIf(hasattr(socket, 'socketpair'), 'socket.socketpair is available') - @mock.patch('asyncio.windows_utils.socket') + @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_close(self, m_socket): m_socket.AF_INET = socket.AF_INET m_socket.SOCK_STREAM = socket.SOCK_STREAM diff --git a/trollius/test_utils.py b/trollius/test_utils.py index 8cee95b8..af7f5bca 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -442,5 +442,5 @@ def mock_nonblocking_socket(): def force_legacy_ssl_support(): - return mock.patch('asyncio.sslproto._is_sslproto_available', + return mock.patch('trollius.sslproto._is_sslproto_available', return_value=False) From a01f3f4993b068dfe17454a06db6b166532dc81a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 6 Jul 2015 22:38:39 +0200 Subject: [PATCH 1392/1502] Add compatibility files from old trollius project --- trollius/compat.py | 61 +++++++++++ trollius/executor.py | 84 +++++++++++++++ trollius/py27_weakrefset.py | 202 ++++++++++++++++++++++++++++++++++++ trollius/py33_exceptions.py | 144 +++++++++++++++++++++++++ trollius/py33_winapi.py | 75 +++++++++++++ trollius/py3_ssl.py | 149 ++++++++++++++++++++++++++ trollius/time_monotonic.py | 192 ++++++++++++++++++++++++++++++++++ 7 files changed, 907 insertions(+) create mode 100644 trollius/compat.py create mode 100644 trollius/executor.py create mode 100644 trollius/py27_weakrefset.py create mode 100644 trollius/py33_exceptions.py create mode 100644 trollius/py33_winapi.py create mode 100644 trollius/py3_ssl.py create mode 100644 trollius/time_monotonic.py diff --git a/trollius/compat.py b/trollius/compat.py new file mode 100644 index 00000000..79478420 --- /dev/null +++ b/trollius/compat.py @@ -0,0 +1,61 @@ +""" +Compatibility constants and functions for the different Python versions. +""" +import sys + +# Python 2.6 or older? +PY26 = (sys.version_info < (2, 7)) + +# Python 3.0 or newer? +PY3 = (sys.version_info >= (3,)) + +# Python 3.3 or newer? +PY33 = (sys.version_info >= (3, 3)) + +# Python 3.4 or newer? +PY34 = sys.version_info >= (3, 4) + +if PY3: + integer_types = (int,) + bytes_type = bytes + text_type = str + string_types = (bytes, str) + BYTES_TYPES = (bytes, bytearray, memoryview) +else: + integer_types = (int, long,) + bytes_type = str + text_type = unicode + string_types = basestring + if PY26: + BYTES_TYPES = (str, bytearray, buffer) + else: # Python 2.7 + BYTES_TYPES = (str, bytearray, memoryview, buffer) + +def flatten_bytes(data): + """ + Convert bytes-like objects (bytes, bytearray, memoryview, buffer) to + a bytes string. + """ + if not isinstance(data, BYTES_TYPES): + raise TypeError('data argument must be byte-ish (%r)', + type(data)) + if PY34: + # In Python 3.4, socket.send() and bytes.join() accept memoryview + # and bytearray + return data + if not data: + return b'' + if not PY3 and isinstance(data, (buffer, bytearray)): + return str(data) + elif not PY26 and isinstance(data, memoryview): + return data.tobytes() + else: + return data + +if PY3: + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value +else: + exec("""def reraise(tp, value, tb=None): raise tp, value, tb""") diff --git a/trollius/executor.py b/trollius/executor.py new file mode 100644 index 00000000..9e7fdd78 --- /dev/null +++ b/trollius/executor.py @@ -0,0 +1,84 @@ +from .log import logger + +__all__ = ( + 'CancelledError', 'TimeoutError', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + ) + +# Argument for default thread pool executor creation. +_MAX_WORKERS = 5 + +try: + import concurrent.futures + import concurrent.futures._base +except ImportError: + FIRST_COMPLETED = 'FIRST_COMPLETED' + FIRST_EXCEPTION = 'FIRST_EXCEPTION' + ALL_COMPLETED = 'ALL_COMPLETED' + + class Future(object): + def __init__(self, callback, args): + try: + self._result = callback(*args) + self._exception = None + except Exception as err: + self._result = None + self._exception = err + self.callbacks = [] + + def cancelled(self): + return False + + def done(self): + return True + + def exception(self): + return self._exception + + def result(self): + if self._exception is not None: + raise self._exception + else: + return self._result + + def add_done_callback(self, callback): + callback(self) + + class Error(Exception): + """Base class for all future-related exceptions.""" + pass + + class CancelledError(Error): + """The Future was cancelled.""" + pass + + class TimeoutError(Error): + """The operation exceeded the given deadline.""" + pass + + class SynchronousExecutor: + """ + Synchronous executor: submit() blocks until it gets the result. + """ + def submit(self, callback, *args): + return Future(callback, args) + + def shutdown(self, wait): + pass + + def get_default_executor(): + logger.error("concurrent.futures module is missing: " + "use a synchrounous executor as fallback!") + return SynchronousExecutor() +else: + FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED + FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION + ALL_COMPLETED = concurrent.futures.ALL_COMPLETED + + Future = concurrent.futures.Future + Error = concurrent.futures._base.Error + CancelledError = concurrent.futures.CancelledError + TimeoutError = concurrent.futures.TimeoutError + + def get_default_executor(): + return concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) diff --git a/trollius/py27_weakrefset.py b/trollius/py27_weakrefset.py new file mode 100644 index 00000000..990c3a6b --- /dev/null +++ b/trollius/py27_weakrefset.py @@ -0,0 +1,202 @@ +# Access WeakSet through the weakref module. +# This code is separated-out because it is needed +# by abc.py to load everything else at startup. + +from _weakref import ref + +__all__ = ['WeakSet'] + + +class _IterationGuard(object): + # This context manager registers itself in the current iterators of the + # weak container, such as to delay all removals until the context manager + # exits. + # This technique should be relatively thread-safe (since sets are). + + def __init__(self, weakcontainer): + # Don't create cycles + self.weakcontainer = ref(weakcontainer) + + def __enter__(self): + w = self.weakcontainer() + if w is not None: + w._iterating.add(self) + return self + + def __exit__(self, e, t, b): + w = self.weakcontainer() + if w is not None: + s = w._iterating + s.remove(self) + if not s: + w._commit_removals() + + +class WeakSet(object): + def __init__(self, data=None): + self.data = set() + def _remove(item, selfref=ref(self)): + self = selfref() + if self is not None: + if self._iterating: + self._pending_removals.append(item) + else: + self.data.discard(item) + self._remove = _remove + # A list of keys to be removed + self._pending_removals = [] + self._iterating = set() + if data is not None: + self.update(data) + + def _commit_removals(self): + l = self._pending_removals + discard = self.data.discard + while l: + discard(l.pop()) + + def __iter__(self): + with _IterationGuard(self): + for itemref in self.data: + item = itemref() + if item is not None: + yield item + + def __len__(self): + return len(self.data) - len(self._pending_removals) + + def __contains__(self, item): + try: + wr = ref(item) + except TypeError: + return False + return wr in self.data + + def __reduce__(self): + return (self.__class__, (list(self),), + getattr(self, '__dict__', None)) + + __hash__ = None + + def add(self, item): + if self._pending_removals: + self._commit_removals() + self.data.add(ref(item, self._remove)) + + def clear(self): + if self._pending_removals: + self._commit_removals() + self.data.clear() + + def copy(self): + return self.__class__(self) + + def pop(self): + if self._pending_removals: + self._commit_removals() + while True: + try: + itemref = self.data.pop() + except KeyError: + raise KeyError('pop from empty WeakSet') + item = itemref() + if item is not None: + return item + + def remove(self, item): + if self._pending_removals: + self._commit_removals() + self.data.remove(ref(item)) + + def discard(self, item): + if self._pending_removals: + self._commit_removals() + self.data.discard(ref(item)) + + def update(self, other): + if self._pending_removals: + self._commit_removals() + for element in other: + self.add(element) + + def __ior__(self, other): + self.update(other) + return self + + def difference(self, other): + newset = self.copy() + newset.difference_update(other) + return newset + __sub__ = difference + + def difference_update(self, other): + self.__isub__(other) + def __isub__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.difference_update(ref(item) for item in other) + return self + + def intersection(self, other): + return self.__class__(item for item in other if item in self) + __and__ = intersection + + def intersection_update(self, other): + self.__iand__(other) + def __iand__(self, other): + if self._pending_removals: + self._commit_removals() + self.data.intersection_update(ref(item) for item in other) + return self + + def issubset(self, other): + return self.data.issubset(ref(item) for item in other) + __le__ = issubset + + def __lt__(self, other): + return self.data < set(ref(item) for item in other) + + def issuperset(self, other): + return self.data.issuperset(ref(item) for item in other) + __ge__ = issuperset + + def __gt__(self, other): + return self.data > set(ref(item) for item in other) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.data == set(ref(item) for item in other) + + def __ne__(self, other): + opposite = self.__eq__(other) + if opposite is NotImplemented: + return NotImplemented + return not opposite + + def symmetric_difference(self, other): + newset = self.copy() + newset.symmetric_difference_update(other) + return newset + __xor__ = symmetric_difference + + def symmetric_difference_update(self, other): + self.__ixor__(other) + def __ixor__(self, other): + if self._pending_removals: + self._commit_removals() + if self is other: + self.data.clear() + else: + self.data.symmetric_difference_update(ref(item, self._remove) for item in other) + return self + + def union(self, other): + return self.__class__(e for s in (self, other) for e in s) + __or__ = union + + def isdisjoint(self, other): + return len(self.intersection(other)) == 0 diff --git a/trollius/py33_exceptions.py b/trollius/py33_exceptions.py new file mode 100644 index 00000000..94cbfca4 --- /dev/null +++ b/trollius/py33_exceptions.py @@ -0,0 +1,144 @@ +__all__ = ['BlockingIOError', 'BrokenPipeError', 'ChildProcessError', + 'ConnectionRefusedError', 'ConnectionResetError', + 'InterruptedError', 'ConnectionAbortedError', 'PermissionError', + 'FileNotFoundError', + ] + +import errno +import select +import socket +import sys +try: + import ssl +except ImportError: + ssl = None + +from .compat import PY33 + +if PY33: + import builtins + BlockingIOError = builtins.BlockingIOError + BrokenPipeError = builtins.BrokenPipeError + ChildProcessError = builtins.ChildProcessError + ConnectionRefusedError = builtins.ConnectionRefusedError + ConnectionResetError = builtins.ConnectionResetError + InterruptedError = builtins.InterruptedError + ConnectionAbortedError = builtins.ConnectionAbortedError + PermissionError = builtins.PermissionError + FileNotFoundError = builtins.FileNotFoundError + ProcessLookupError = builtins.ProcessLookupError + +else: + # Python < 3.3 + class BlockingIOError(OSError): + pass + + class BrokenPipeError(OSError): + pass + + class ChildProcessError(OSError): + pass + + class ConnectionRefusedError(OSError): + pass + + class InterruptedError(OSError): + pass + + class ConnectionResetError(OSError): + pass + + class ConnectionAbortedError(OSError): + pass + + class PermissionError(OSError): + pass + + class FileNotFoundError(OSError): + pass + + class ProcessLookupError(OSError): + pass + + +_MAP_ERRNO = { + errno.EACCES: PermissionError, + errno.EAGAIN: BlockingIOError, + errno.EALREADY: BlockingIOError, + errno.ECHILD: ChildProcessError, + errno.ECONNABORTED: ConnectionAbortedError, + errno.ECONNREFUSED: ConnectionRefusedError, + errno.ECONNRESET: ConnectionResetError, + errno.EINPROGRESS: BlockingIOError, + errno.EINTR: InterruptedError, + errno.ENOENT: FileNotFoundError, + errno.EPERM: PermissionError, + errno.EPIPE: BrokenPipeError, + errno.ESHUTDOWN: BrokenPipeError, + errno.EWOULDBLOCK: BlockingIOError, + errno.ESRCH: ProcessLookupError, +} + +if sys.platform == 'win32': + from trollius import _overlapped + _MAP_ERRNO.update({ + _overlapped.ERROR_CONNECTION_REFUSED: ConnectionRefusedError, + _overlapped.ERROR_CONNECTION_ABORTED: ConnectionAbortedError, + _overlapped.ERROR_NETNAME_DELETED: ConnectionResetError, + }) + + +def get_error_class(key, default): + return _MAP_ERRNO.get(key, default) + + +if sys.version_info >= (3,): + def reraise(tp, value, tb=None): + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value +else: + exec("""def reraise(tp, value, tb=None): + raise tp, value, tb +""") + + +def _wrap_error(exc, mapping, key): + if key not in mapping: + return + new_err_cls = mapping[key] + new_err = new_err_cls(*exc.args) + + # raise a new exception with the original traceback + if hasattr(exc, '__traceback__'): + traceback = exc.__traceback__ + else: + traceback = sys.exc_info()[2] + reraise(new_err_cls, new_err, traceback) + + +if not PY33: + def wrap_error(func, *args, **kw): + """ + Wrap socket.error, IOError, OSError, select.error to raise new specialized + exceptions of Python 3.3 like InterruptedError (PEP 3151). + """ + try: + return func(*args, **kw) + except (socket.error, IOError, OSError) as exc: + if ssl is not None and isinstance(exc, ssl.SSLError): + raise + if hasattr(exc, 'winerror'): + _wrap_error(exc, _MAP_ERRNO, exc.winerror) + # _MAP_ERRNO does not contain all Windows errors. + # For some errors like "file not found", exc.errno should + # be used (ex: ENOENT). + _wrap_error(exc, _MAP_ERRNO, exc.errno) + raise + except select.error as exc: + if exc.args: + _wrap_error(exc, _MAP_ERRNO, exc.args[0]) + raise +else: + def wrap_error(func, *args, **kw): + return func(*args, **kw) diff --git a/trollius/py33_winapi.py b/trollius/py33_winapi.py new file mode 100644 index 00000000..792bc459 --- /dev/null +++ b/trollius/py33_winapi.py @@ -0,0 +1,75 @@ + +__all__ = [ + 'CloseHandle', 'CreateNamedPipe', 'CreateFile', 'ConnectNamedPipe', + 'NULL', + 'GENERIC_READ', 'GENERIC_WRITE', 'OPEN_EXISTING', 'INFINITE', + 'PIPE_ACCESS_INBOUND', + 'PIPE_ACCESS_DUPLEX', 'PIPE_TYPE_MESSAGE', 'PIPE_READMODE_MESSAGE', + 'PIPE_WAIT', 'PIPE_UNLIMITED_INSTANCES', 'NMPWAIT_WAIT_FOREVER', + 'FILE_FLAG_OVERLAPPED', 'FILE_FLAG_FIRST_PIPE_INSTANCE', + 'WaitForMultipleObjects', 'WaitForSingleObject', + 'WAIT_OBJECT_0', 'ERROR_IO_PENDING', + ] + +try: + # FIXME: use _overlapped on Python 3.3? see windows_utils.pipe() + from _winapi import ( + CloseHandle, CreateNamedPipe, CreateFile, ConnectNamedPipe, + NULL, + GENERIC_READ, GENERIC_WRITE, OPEN_EXISTING, INFINITE, + PIPE_ACCESS_INBOUND, + PIPE_ACCESS_DUPLEX, PIPE_TYPE_MESSAGE, PIPE_READMODE_MESSAGE, + PIPE_WAIT, PIPE_UNLIMITED_INSTANCES, NMPWAIT_WAIT_FOREVER, + FILE_FLAG_OVERLAPPED, FILE_FLAG_FIRST_PIPE_INSTANCE, + WaitForMultipleObjects, WaitForSingleObject, + WAIT_OBJECT_0, ERROR_IO_PENDING, + ) +except ImportError: + # Python < 3.3 + from _multiprocessing import win32 + import _subprocess + + from trollius import _overlapped + + CloseHandle = win32.CloseHandle + CreateNamedPipe = win32.CreateNamedPipe + CreateFile = win32.CreateFile + NULL = win32.NULL + + GENERIC_READ = win32.GENERIC_READ + GENERIC_WRITE = win32.GENERIC_WRITE + OPEN_EXISTING = win32.OPEN_EXISTING + INFINITE = win32.INFINITE + + PIPE_ACCESS_INBOUND = win32.PIPE_ACCESS_INBOUND + PIPE_ACCESS_DUPLEX = win32.PIPE_ACCESS_DUPLEX + PIPE_READMODE_MESSAGE = win32.PIPE_READMODE_MESSAGE + PIPE_TYPE_MESSAGE = win32.PIPE_TYPE_MESSAGE + PIPE_WAIT = win32.PIPE_WAIT + PIPE_UNLIMITED_INSTANCES = win32.PIPE_UNLIMITED_INSTANCES + NMPWAIT_WAIT_FOREVER = win32.NMPWAIT_WAIT_FOREVER + + FILE_FLAG_OVERLAPPED = 0x40000000 + FILE_FLAG_FIRST_PIPE_INSTANCE = 0x00080000 + + WAIT_OBJECT_0 = _subprocess.WAIT_OBJECT_0 + WaitForSingleObject = _subprocess.WaitForSingleObject + ERROR_IO_PENDING = _overlapped.ERROR_IO_PENDING + + def ConnectNamedPipe(handle, overlapped): + ov = _overlapped.Overlapped() + ov.ConnectNamedPipe(handle) + return ov + + def WaitForMultipleObjects(events, wait_all, timeout): + if not wait_all: + raise NotImplementedError() + + for ev in events: + res = WaitForSingleObject(ev, timeout) + if res != WAIT_OBJECT_0: + err = win32.GetLastError() + msg = _overlapped.FormatMessage(err) + raise WindowsError(err, msg) + + return WAIT_OBJECT_0 diff --git a/trollius/py3_ssl.py b/trollius/py3_ssl.py new file mode 100644 index 00000000..c592ee66 --- /dev/null +++ b/trollius/py3_ssl.py @@ -0,0 +1,149 @@ +""" +Backport SSL functions and exceptions: +- BACKPORT_SSL_ERRORS (bool) +- SSLWantReadError, SSLWantWriteError, SSLEOFError +- BACKPORT_SSL_CONTEXT (bool) +- SSLContext +- wrap_socket() +- wrap_ssl_error() +""" +import errno +import ssl +import sys +from trollius.py33_exceptions import _wrap_error + +__all__ = ["SSLContext", "BACKPORT_SSL_ERRORS", "BACKPORT_SSL_CONTEXT", + "SSLWantReadError", "SSLWantWriteError", "SSLEOFError", + ] + +try: + SSLWantReadError = ssl.SSLWantReadError + SSLWantWriteError = ssl.SSLWantWriteError + SSLEOFError = ssl.SSLEOFError + BACKPORT_SSL_ERRORS = False +except AttributeError: + # Python < 3.3 + BACKPORT_SSL_ERRORS = True + + class SSLWantReadError(ssl.SSLError): + pass + + class SSLWantWriteError(ssl.SSLError): + pass + + class SSLEOFError(ssl.SSLError): + pass + + +try: + SSLContext = ssl.SSLContext + BACKPORT_SSL_CONTEXT = False + wrap_socket = ssl.wrap_socket +except AttributeError: + # Python < 3.2 + BACKPORT_SSL_CONTEXT = True + + if (sys.version_info < (2, 6, 6)): + # SSLSocket constructor has bugs in Python older than 2.6.6: + # http://bugs.python.org/issue5103 + # http://bugs.python.org/issue7943 + from socket import socket, error as socket_error, _delegate_methods + import _ssl + + class BackportSSLSocket(ssl.SSLSocket): + # Override SSLSocket.__init__() + def __init__(self, sock, keyfile=None, certfile=None, + server_side=False, cert_reqs=ssl.CERT_NONE, + ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None, + do_handshake_on_connect=True, + suppress_ragged_eofs=True): + socket.__init__(self, _sock=sock._sock) + # The initializer for socket overrides the methods send(), recv(), etc. + # in the instancce, which we don't need -- but we want to provide the + # methods defined in SSLSocket. + for attr in _delegate_methods: + try: + delattr(self, attr) + except AttributeError: + pass + + if certfile and not keyfile: + keyfile = certfile + # see if it's connected + try: + socket.getpeername(self) + except socket_error as e: + if e.errno != errno.ENOTCONN: + raise + # no, no connection yet + self._connected = False + self._sslobj = None + else: + # yes, create the SSL object + self._connected = True + self._sslobj = _ssl.sslwrap(self._sock, server_side, + keyfile, certfile, + cert_reqs, ssl_version, ca_certs) + if do_handshake_on_connect: + self.do_handshake() + self.keyfile = keyfile + self.certfile = certfile + self.cert_reqs = cert_reqs + self.ssl_version = ssl_version + self.ca_certs = ca_certs + self.do_handshake_on_connect = do_handshake_on_connect + self.suppress_ragged_eofs = suppress_ragged_eofs + self._makefile_refs = 0 + + def wrap_socket(sock, server_hostname=None, **kwargs): + # ignore server_hostname parameter, not supported + kwargs.pop('server_hostname', None) + return BackportSSLSocket(sock, **kwargs) + else: + _wrap_socket = ssl.wrap_socket + + def wrap_socket(sock, **kwargs): + # ignore server_hostname parameter, not supported + kwargs.pop('server_hostname', None) + return _wrap_socket(sock, **kwargs) + + + class SSLContext(object): + def __init__(self, protocol=ssl.PROTOCOL_SSLv23): + self.protocol = protocol + self.certfile = None + self.keyfile = None + + def load_cert_chain(self, certfile, keyfile): + self.certfile = certfile + self.keyfile = keyfile + + def wrap_socket(self, sock, **kwargs): + return wrap_socket(sock, + ssl_version=self.protocol, + certfile=self.certfile, + keyfile=self.keyfile, + **kwargs) + + @property + def verify_mode(self): + return ssl.CERT_NONE + + +if BACKPORT_SSL_ERRORS: + _MAP_ERRORS = { + ssl.SSL_ERROR_WANT_READ: SSLWantReadError, + ssl.SSL_ERROR_WANT_WRITE: SSLWantWriteError, + ssl.SSL_ERROR_EOF: SSLEOFError, + } + + def wrap_ssl_error(func, *args, **kw): + try: + return func(*args, **kw) + except ssl.SSLError as exc: + if exc.args: + _wrap_error(exc, _MAP_ERRORS, exc.args[0]) + raise +else: + def wrap_ssl_error(func, *args, **kw): + return func(*args, **kw) diff --git a/trollius/time_monotonic.py b/trollius/time_monotonic.py new file mode 100644 index 00000000..e99364cc --- /dev/null +++ b/trollius/time_monotonic.py @@ -0,0 +1,192 @@ +""" +Backport of time.monotonic() of Python 3.3 (PEP 418) for Python 2.7. + +- time_monotonic(). This clock may or may not be monotonic depending on the + operating system. +- time_monotonic_resolution: Resolution of time_monotonic() clock in second + +Support Windows, Mac OS X, Linux, FreeBSD, OpenBSD and Solaris, but requires +the ctypes module. +""" +import os +import sys +from .log import logger +from .py33_exceptions import get_error_class + +__all__ = ('time_monotonic',) + +# default implementation: system clock (non monotonic!) +from time import time as time_monotonic +# the worst resolution is 15.6 ms on Windows +time_monotonic_resolution = 0.050 + +if os.name == "nt": + # Windows: use GetTickCount64() or GetTickCount() + try: + import ctypes + from ctypes import windll + from ctypes.wintypes import DWORD + except ImportError: + logger.error("time_monotonic import error", exc_info=True) + else: + # GetTickCount64() requires Windows Vista, Server 2008 or later + if hasattr(windll.kernel32, 'GetTickCount64'): + ULONGLONG = ctypes.c_uint64 + + GetTickCount64 = windll.kernel32.GetTickCount64 + GetTickCount64.restype = ULONGLONG + GetTickCount64.argtypes = () + + def time_monotonic(): + return GetTickCount64() * 1e-3 + time_monotonic_resolution = 1e-3 + else: + GetTickCount = windll.kernel32.GetTickCount + GetTickCount.restype = DWORD + GetTickCount.argtypes = () + + # Detect GetTickCount() integer overflow (32 bits, roll-over after 49.7 + # days). It increases an internal epoch (reference time) by 2^32 each + # time that an overflow is detected. The epoch is stored in the + # process-local state and so the value of time_monotonic() may be + # different in two Python processes running for more than 49 days. + def time_monotonic(): + ticks = GetTickCount() + if ticks < time_monotonic.last: + # Integer overflow detected + time_monotonic.delta += 2**32 + time_monotonic.last = ticks + return (ticks + time_monotonic.delta) * 1e-3 + time_monotonic.last = 0 + time_monotonic.delta = 0 + time_monotonic_resolution = 1e-3 + +elif sys.platform == 'darwin': + # Mac OS X: use mach_absolute_time() and mach_timebase_info() + try: + import ctypes + import ctypes.util + libc_name = ctypes.util.find_library('c') + except ImportError: + logger.error("time_monotonic import error", exc_info=True) + libc_name = None + if libc_name: + libc = ctypes.CDLL(libc_name, use_errno=True) + + mach_absolute_time = libc.mach_absolute_time + mach_absolute_time.argtypes = () + mach_absolute_time.restype = ctypes.c_uint64 + + class mach_timebase_info_data_t(ctypes.Structure): + _fields_ = ( + ('numer', ctypes.c_uint32), + ('denom', ctypes.c_uint32), + ) + mach_timebase_info_data_p = ctypes.POINTER(mach_timebase_info_data_t) + + mach_timebase_info = libc.mach_timebase_info + mach_timebase_info.argtypes = (mach_timebase_info_data_p,) + mach_timebase_info.restype = ctypes.c_int + + def time_monotonic(): + return mach_absolute_time() * time_monotonic.factor + + timebase = mach_timebase_info_data_t() + mach_timebase_info(ctypes.byref(timebase)) + time_monotonic.factor = float(timebase.numer) / timebase.denom * 1e-9 + time_monotonic_resolution = time_monotonic.factor + del timebase + +elif sys.platform.startswith(("linux", "freebsd", "openbsd", "sunos")): + # Linux, FreeBSD, OpenBSD: use clock_gettime(CLOCK_MONOTONIC) + # Solaris: use clock_gettime(CLOCK_HIGHRES) + + library = None + try: + import ctypes + import ctypes.util + except ImportError: + logger.error("time_monotonic import error", exc_info=True) + libraries = () + else: + if sys.platform.startswith(("freebsd", "openbsd")): + libraries = ('c',) + elif sys.platform.startswith("linux"): + # Linux: in glibc 2.17+, clock_gettime() is provided by the libc, + # on older versions, it is provided by librt + libraries = ('c', 'rt') + else: + # Solaris + libraries = ('rt',) + + for name in libraries: + filename = ctypes.util.find_library(name) + if not filename: + continue + library = ctypes.CDLL(filename, use_errno=True) + if not hasattr(library, 'clock_gettime'): + library = None + + if library is not None: + if sys.platform.startswith("openbsd"): + import platform + release = platform.release() + release = tuple(map(int, release.split('.'))) + if release >= (5, 5): + time_t = ctypes.c_int64 + else: + time_t = ctypes.c_int32 + else: + time_t = ctypes.c_long + clockid_t = ctypes.c_int + + class timespec(ctypes.Structure): + _fields_ = ( + ('tv_sec', time_t), + ('tv_nsec', ctypes.c_long), + ) + timespec_p = ctypes.POINTER(timespec) + + clock_gettime = library.clock_gettime + clock_gettime.argtypes = (clockid_t, timespec_p) + clock_gettime.restype = ctypes.c_int + + def ctypes_oserror(): + errno = ctypes.get_errno() + message = os.strerror(errno) + error_class = get_error_class(errno, OSError) + return error_class(errno, message) + + def time_monotonic(): + ts = timespec() + err = clock_gettime(time_monotonic.clk_id, ctypes.byref(ts)) + if err: + raise ctypes_oserror() + return ts.tv_sec + ts.tv_nsec * 1e-9 + + if sys.platform.startswith("linux"): + time_monotonic.clk_id = 1 # CLOCK_MONOTONIC + elif sys.platform.startswith("freebsd"): + time_monotonic.clk_id = 4 # CLOCK_MONOTONIC + elif sys.platform.startswith("openbsd"): + time_monotonic.clk_id = 3 # CLOCK_MONOTONIC + else: + assert sys.platform.startswith("sunos") + time_monotonic.clk_id = 4 # CLOCK_HIGHRES + + def get_resolution(): + _clock_getres = library.clock_getres + _clock_getres.argtypes = (clockid_t, timespec_p) + _clock_getres.restype = ctypes.c_int + + ts = timespec() + err = _clock_getres(time_monotonic.clk_id, ctypes.byref(ts)) + if err: + raise ctypes_oserror() + return ts.tv_sec + ts.tv_nsec * 1e-9 + time_monotonic_resolution = get_resolution() + del get_resolution + +else: + logger.error("time_monotonic: unspported platform %r", sys.platform) + From a4749501b494ecda86c2092b4d4b86cc5e4bf2ff Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 21:17:17 +0200 Subject: [PATCH 1393/1502] Port asyncio to Python 2, trollius/ directory --- trollius/__init__.py | 11 +- trollius/base_events.py | 157 +++++++------ trollius/base_subprocess.py | 7 +- trollius/coroutines.py | 134 ++++++++++- trollius/events.py | 442 +++++++++++++++++++----------------- trollius/futures.py | 98 ++++++-- trollius/locks.py | 86 +++---- trollius/proactor_events.py | 20 +- trollius/protocols.py | 2 +- trollius/queues.py | 30 +-- trollius/selector_events.py | 126 +++++----- trollius/selectors.py | 72 +++--- trollius/sslproto.py | 20 +- trollius/streams.py | 74 +++--- trollius/subprocess.py | 55 +++-- trollius/tasks.py | 236 ++++++++++++------- trollius/test_support.py | 17 +- trollius/test_utils.py | 235 +++++++++++++++---- trollius/transports.py | 15 +- trollius/unix_events.py | 148 +++++++----- trollius/windows_events.py | 93 ++++---- trollius/windows_utils.py | 26 ++- 22 files changed, 1319 insertions(+), 785 deletions(-) diff --git a/trollius/__init__.py b/trollius/__init__.py index 011466b3..a1379fbc 100644 --- a/trollius/__init__.py +++ b/trollius/__init__.py @@ -1,4 +1,4 @@ -"""The asyncio package, tracking PEP 3156.""" +"""The trollius package, tracking PEP 3156.""" import sys @@ -24,6 +24,7 @@ from .futures import * from .locks import * from .protocols import * +from .py33_exceptions import * from .queues import * from .streams import * from .subprocess import * @@ -33,6 +34,7 @@ __all__ = (base_events.__all__ + coroutines.__all__ + events.__all__ + + py33_exceptions.__all__ + futures.__all__ + locks.__all__ + protocols.__all__ + @@ -48,3 +50,10 @@ else: from .unix_events import * # pragma: no cover __all__ += unix_events.__all__ + +try: + from .py3_ssl import * + __all__ += py3_ssl.__all__ +except ImportError: + # SSL support is optionnal + pass diff --git a/trollius/base_events.py b/trollius/base_events.py index 5a536a22..c8541f19 100644 --- a/trollius/base_events.py +++ b/trollius/base_events.py @@ -15,25 +15,34 @@ import collections -import concurrent.futures import heapq import inspect import logging import os import socket import subprocess -import threading -import time -import traceback import sys -import warnings - +import traceback +try: + from collections import OrderedDict +except ImportError: + # Python 2.6: use ordereddict backport + from ordereddict import OrderedDict +try: + from threading import get_ident as _get_thread_ident +except ImportError: + # Python 2 + from threading import _get_ident as _get_thread_ident + +from . import compat from . import coroutines from . import events from . import futures from . import tasks -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .executor import get_default_executor from .log import logger +from .time_monotonic import time_monotonic, time_monotonic_resolution __all__ = ['BaseEventLoop'] @@ -171,10 +180,10 @@ def _wakeup(self): @coroutine def wait_closed(self): if self.sockets is None or self._waiters is None: - return + raise Return() waiter = futures.Future(loop=self._loop) self._waiters.append(waiter) - yield from waiter + yield From(waiter) class BaseEventLoop(events.AbstractEventLoop): @@ -191,8 +200,7 @@ def __init__(self): self._thread_id = None self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self.set_debug((not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG')))) + self.set_debug(bool(os.environ.get('TROLLIUSDEBUG'))) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 @@ -237,13 +245,13 @@ def get_task_factory(self): """Return a task factory, or None if the default one is in use.""" return self._task_factory - def _make_socket_transport(self, sock, protocol, waiter=None, *, + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): """Create socket transport.""" raise NotImplementedError def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): """Create SSL transport.""" raise NotImplementedError @@ -506,7 +514,7 @@ def run_in_executor(self, executor, func, *args): if executor is None: executor = self._default_executor if executor is None: - executor = concurrent.futures.ThreadPoolExecutor(_MAX_WORKERS) + executor = get_default_executor() self._default_executor = executor return futures.wrap_future(executor.submit(func, *args), loop=self) @@ -538,7 +546,7 @@ def _getaddrinfo_debug(self, host, port, family, type, proto, flags): logger.debug(msg) return addrinfo - def getaddrinfo(self, host, port, *, + def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): if self._debug: return self.run_in_executor(None, self._getaddrinfo_debug, @@ -551,7 +559,7 @@ def getnameinfo(self, sockaddr, flags=0): return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) @coroutine - def create_connection(self, protocol_factory, host=None, port=None, *, + def create_connection(self, protocol_factory, host=None, port=None, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None): """Connect to a TCP server. @@ -601,15 +609,15 @@ def create_connection(self, protocol_factory, host=None, port=None, *, else: f2 = None - yield from tasks.wait(fs, loop=self) + yield From(tasks.wait(fs, loop=self)) infos = f1.result() if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') if f2 is not None: laddr_infos = f2.result() if not laddr_infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') exceptions = [] for family, type, proto, cname, address in infos: @@ -621,11 +629,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, try: sock.bind(laddr) break - except OSError as exc: - exc = OSError( + except socket.error as exc: + exc = socket.error( exc.errno, 'error while ' 'attempting to bind on address ' - '{!r}: {}'.format( + '{0!r}: {1}'.format( laddr, exc.strerror.lower())) exceptions.append(exc) else: @@ -634,8 +642,8 @@ def create_connection(self, protocol_factory, host=None, port=None, *, continue if self._debug: logger.debug("connect %r to %r", sock, address) - yield from self.sock_connect(sock, address) - except OSError as exc: + yield From(self.sock_connect(sock, address)) + except socket.error as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -655,7 +663,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, raise exceptions[0] # Raise a combined exception so the user can see all # the various error messages. - raise OSError('Multiple exceptions: {}'.format( + raise socket.error('Multiple exceptions: {0}'.format( ', '.join(str(exc) for exc in exceptions))) elif sock is None: @@ -664,15 +672,15 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = yield From(self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname)) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket sock = transport.get_extra_info('socket') logger.debug("%r connected to %s:%r: (%r, %r)", sock, host, port, transport, protocol) - return transport, protocol + raise Return(transport, protocol) @coroutine def _create_connection_transport(self, sock, protocol_factory, ssl, @@ -688,12 +696,12 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, transport = self._make_socket_transport(sock, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise - return transport, protocol + raise Return(transport, protocol) @coroutine def create_datagram_endpoint(self, protocol_factory, @@ -706,17 +714,17 @@ def create_datagram_endpoint(self, protocol_factory, addr_pairs_info = (((family, proto), (None, None)),) else: # join address by (family, protocol) - addr_infos = collections.OrderedDict() + addr_infos = OrderedDict() for idx, addr in ((0, local_addr), (1, remote_addr)): if addr is not None: assert isinstance(addr, tuple) and len(addr) == 2, ( '2-tuple is expected') - infos = yield from self.getaddrinfo( + infos = yield From(self.getaddrinfo( *addr, family=family, type=socket.SOCK_DGRAM, - proto=proto, flags=flags) + proto=proto, flags=flags)) if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') for fam, _, pro, _, address in infos: key = (fam, pro) @@ -748,9 +756,9 @@ def create_datagram_endpoint(self, protocol_factory, if local_addr: sock.bind(local_address) if remote_addr: - yield from self.sock_connect(sock, remote_address) + yield From(self.sock_connect(sock, remote_address)) r_addr = remote_address - except OSError as exc: + except socket.error as exc: if sock is not None: sock.close() exceptions.append(exc) @@ -778,16 +786,15 @@ def create_datagram_endpoint(self, protocol_factory, remote_addr, transport, protocol) try: - yield from waiter + yield From(waiter) except: transport.close() raise - return transport, protocol + raise Return(transport, protocol) @coroutine def create_server(self, protocol_factory, host=None, port=None, - *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, @@ -814,11 +821,11 @@ def create_server(self, protocol_factory, host=None, port=None, if host == '': host = None - infos = yield from self.getaddrinfo( + infos = yield From(self.getaddrinfo( host, port, family=family, - type=socket.SOCK_STREAM, proto=0, flags=flags) + type=socket.SOCK_STREAM, proto=0, flags=flags)) if not infos: - raise OSError('getaddrinfo() returned empty list') + raise socket.error('getaddrinfo() returned empty list') completed = False try: @@ -846,10 +853,11 @@ def create_server(self, protocol_factory, host=None, port=None, True) try: sock.bind(sa) - except OSError as err: - raise OSError(err.errno, 'error while attempting ' - 'to bind on address %r: %s' - % (sa, err.strerror.lower())) + except socket.error as err: + raise socket.error(err.errno, + 'error while attempting ' + 'to bind on address %r: %s' + % (sa, err.strerror.lower())) completed = True finally: if not completed: @@ -867,7 +875,7 @@ def create_server(self, protocol_factory, host=None, port=None, self._start_serving(protocol_factory, sock, ssl, server) if self._debug: logger.info("%r is serving", server) - return server + raise Return(server) @coroutine def connect_read_pipe(self, protocol_factory, pipe): @@ -876,7 +884,7 @@ def connect_read_pipe(self, protocol_factory, pipe): transport = self._make_read_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise @@ -884,7 +892,7 @@ def connect_read_pipe(self, protocol_factory, pipe): if self._debug: logger.debug('Read pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) - return transport, protocol + raise Return(transport, protocol) @coroutine def connect_write_pipe(self, protocol_factory, pipe): @@ -893,7 +901,7 @@ def connect_write_pipe(self, protocol_factory, pipe): transport = self._make_write_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + yield From(waiter) except: transport.close() raise @@ -901,7 +909,7 @@ def connect_write_pipe(self, protocol_factory, pipe): if self._debug: logger.debug('Write pipe %r connected: (%r, %r)', pipe.fileno(), transport, protocol) - return transport, protocol + raise Return(transport, protocol) def _log_subprocess(self, msg, stdin, stdout, stderr): info = [msg] @@ -917,11 +925,11 @@ def _log_subprocess(self, msg, stdin, stdout, stderr): logger.debug(' '.join(info)) @coroutine - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, + def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=False, shell=True, bufsize=0, **kwargs): - if not isinstance(cmd, (bytes, str)): + if not isinstance(cmd, compat.string_types): raise ValueError("cmd must be a string") if universal_newlines: raise ValueError("universal_newlines must be False") @@ -935,17 +943,20 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, # (password) and may be too long debug_log = 'run shell command %r' % cmd self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( - protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) + transport = yield From(self._make_subprocess_transport( + protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs)) if self._debug: logger.info('%s: %r' % (debug_log, transport)) - return transport, protocol + raise Return(transport, protocol) @coroutine - def subprocess_exec(self, protocol_factory, program, *args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=False, - shell=False, bufsize=0, **kwargs): + def subprocess_exec(self, protocol_factory, program, *args, **kwargs): + stdin = kwargs.pop('stdin', subprocess.PIPE) + stdout = kwargs.pop('stdout', subprocess.PIPE) + stderr = kwargs.pop('stderr', subprocess.PIPE) + universal_newlines = kwargs.pop('universal_newlines', False) + shell = kwargs.pop('shell', False) + bufsize = kwargs.pop('bufsize', 0) if universal_newlines: raise ValueError("universal_newlines must be False") if shell: @@ -954,7 +965,7 @@ def subprocess_exec(self, protocol_factory, program, *args, raise ValueError("bufsize must be 0") popen_args = (program,) + args for arg in popen_args: - if not isinstance(arg, (str, bytes)): + if not isinstance(arg, compat.string_types ): raise TypeError("program arguments must be " "a bytes or text string, not %s" % type(arg).__name__) @@ -964,12 +975,12 @@ def subprocess_exec(self, protocol_factory, program, *args, # (password) and may be too long debug_log = 'execute program %r' % program self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = yield From(self._make_subprocess_transport( protocol, popen_args, False, stdin, stdout, stderr, - bufsize, **kwargs) + bufsize, **kwargs)) if self._debug: logger.info('%s: %r' % (debug_log, transport)) - return transport, protocol + raise Return(transport, protocol) def set_exception_handler(self, handler): """Set handler as the new event loop exception handler. @@ -985,7 +996,7 @@ def set_exception_handler(self, handler): """ if handler is not None and not callable(handler): raise TypeError('A callable object or None is expected, ' - 'got {!r}'.format(handler)) + 'got {0!r}'.format(handler)) self._exception_handler = handler def default_exception_handler(self, context): @@ -1004,7 +1015,15 @@ def default_exception_handler(self, context): exception = context.get('exception') if exception is not None: - exc_info = (type(exception), exception, exception.__traceback__) + if hasattr(exception, '__traceback__'): + # Python 3 + tb = exception.__traceback__ + else: + # call_exception_handler() is usually called indirectly + # from an except block. If it's not the case, the traceback + # is undefined... + tb = sys.exc_info()[2] + exc_info = (type(exception), exception, tb) else: exc_info = False @@ -1015,7 +1034,7 @@ def default_exception_handler(self, context): log_lines = [message] for key in sorted(context): - if key in {'message', 'exception'}: + if key in ('message', 'exception'): continue value = context[key] if key == 'source_traceback': @@ -1028,7 +1047,7 @@ def default_exception_handler(self, context): value += tb.rstrip() else: value = repr(value) - log_lines.append('{}: {}'.format(key, value)) + log_lines.append('{0}: {1}'.format(key, value)) logger.error('\n'.join(log_lines), exc_info=exc_info) @@ -1108,7 +1127,7 @@ def _run_once(self): sched_count = len(self._scheduled) if (sched_count > _MIN_SCHEDULED_TIMER_HANDLES and - self._timer_cancelled_count / sched_count > + float(self._timer_cancelled_count) / sched_count > _MIN_CANCELLED_TIMER_HANDLES_FRACTION): # Remove delayed calls that were cancelled if their number # is too high diff --git a/trollius/base_subprocess.py b/trollius/base_subprocess.py index c1477b82..d3a64655 100644 --- a/trollius/base_subprocess.py +++ b/trollius/base_subprocess.py @@ -6,7 +6,7 @@ from . import futures from . import protocols from . import transports -from .coroutines import coroutine +from .coroutines import coroutine, From from .log import logger @@ -15,7 +15,7 @@ class BaseSubprocessTransport(transports.SubprocessTransport): def __init__(self, loop, protocol, args, shell, stdin, stdout, stderr, bufsize, waiter=None, extra=None, **kwargs): - super().__init__(extra) + super(BaseSubprocessTransport, self).__init__(extra) self._closed = False self._protocol = protocol self._loop = loop @@ -221,7 +221,8 @@ def _wait(self): waiter = futures.Future(loop=self._loop) self._exit_waiters.append(waiter) - return (yield from waiter) + returncode = yield From(waiter) + return returncode def _try_finish(self): assert not self._finished diff --git a/trollius/coroutines.py b/trollius/coroutines.py index 15475f23..b12ca3e1 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -9,6 +9,7 @@ import traceback import types +from . import compat from . import events from . import futures from .log import logger @@ -18,7 +19,7 @@ # Opcode of "yield from" instruction -_YIELD_FROM = opcode.opmap['YIELD_FROM'] +_YIELD_FROM = opcode.opmap.get('YIELD_FROM', None) # If you set _DEBUG to true, @coroutine will wrap the resulting # generator objects in a CoroWrapper instance (defined below). That @@ -29,8 +30,7 @@ # before you define your coroutines. A downside of using this feature # is that tracebacks show entries for the CoroWrapper.__next__ method # when _DEBUG is true. -_DEBUG = (not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG'))) +_DEBUG = bool(os.environ.get('TROLLIUSDEBUG')) try: @@ -74,6 +74,53 @@ def yield_from_gen(gen): del has_yield_from_bug +if compat.PY33: + # Don't use the Return class on Python 3.3 and later to support asyncio + # coroutines (to avoid the warning emited in Return destructor). + # + # The problem is that Return inherits from StopIteration. "yield from + # trollius_coroutine". Task._step() does not receive the Return exception, + # because "yield from" handles it internally. So it's not possible to set + # the raised attribute to True to avoid the warning in Return destructor. + def Return(*args): + if not args: + value = None + elif len(args) == 1: + value = args[0] + else: + value = args + return StopIteration(value) +else: + class Return(StopIteration): + def __init__(self, *args): + StopIteration.__init__(self) + if not args: + self.value = None + elif len(args) == 1: + self.value = args[0] + else: + self.value = args + self.raised = False + if _DEBUG: + frame = sys._getframe(1) + self._source_traceback = traceback.extract_stack(frame) + # explicitly clear the reference to avoid reference cycles + frame = None + else: + self._source_traceback = None + + def __del__(self): + if self.raised: + return + + fmt = 'Return(%r) used without raise' + if self._source_traceback: + fmt += '\nReturn created at (most recent call last):\n' + tb = ''.join(traceback.format_list(self._source_traceback)) + fmt += tb.rstrip() + logger.error(fmt, self.value) + + def debug_wrapper(gen): # This function is called from 'sys.set_coroutine_wrapper'. # We only wrap here coroutines defined via 'async def' syntax. @@ -104,7 +151,8 @@ def __iter__(self): return self def __next__(self): - return self.gen.send(None) + return next(self.gen) + next = __next__ if _YIELD_FROM_BUG: # For for CPython issue #21209: using "yield from" and a custom @@ -180,6 +228,56 @@ def __del__(self): msg += tb.rstrip() logger.error(msg) +if not compat.PY34: + # Backport functools.update_wrapper() from Python 3.4: + # - Python 2.7 fails if assigned attributes don't exist + # - Python 2.7 and 3.1 don't set the __wrapped__ attribute + # - Python 3.2 and 3.3 set __wrapped__ before updating __dict__ + def _update_wrapper(wrapper, + wrapped, + assigned = functools.WRAPPER_ASSIGNMENTS, + updated = functools.WRAPPER_UPDATES): + """Update a wrapper function to look like the wrapped function + + wrapper is the function to be updated + wrapped is the original function + assigned is a tuple naming the attributes assigned directly + from the wrapped function to the wrapper function (defaults to + functools.WRAPPER_ASSIGNMENTS) + updated is a tuple naming the attributes of the wrapper that + are updated with the corresponding attribute from the wrapped + function (defaults to functools.WRAPPER_UPDATES) + """ + for attr in assigned: + try: + value = getattr(wrapped, attr) + except AttributeError: + pass + else: + setattr(wrapper, attr, value) + for attr in updated: + getattr(wrapper, attr).update(getattr(wrapped, attr, {})) + # Issue #17482: set __wrapped__ last so we don't inadvertently copy it + # from the wrapped function when updating __dict__ + wrapper.__wrapped__ = wrapped + # Return the wrapper so this can be used as a decorator via partial() + return wrapper + + def _wraps(wrapped, + assigned = functools.WRAPPER_ASSIGNMENTS, + updated = functools.WRAPPER_UPDATES): + """Decorator factory to apply update_wrapper() to a wrapper function + + Returns a decorator that invokes update_wrapper() with the decorated + function as the wrapper argument and the arguments to wraps() as the + remaining arguments. Default arguments are as for update_wrapper(). + This is a convenience function to simplify applying partial() to + update_wrapper(). + """ + return functools.partial(_update_wrapper, wrapped=wrapped, + assigned=assigned, updated=updated) +else: + _wraps = functools.wraps def coroutine(func): """Decorator to mark coroutines. @@ -197,7 +295,7 @@ def coroutine(func): if inspect.isgeneratorfunction(func): coro = func else: - @functools.wraps(func) + @_wraps(func) def coro(*args, **kw): res = func(*args, **kw) if isinstance(res, futures.Future) or inspect.isgenerator(res): @@ -220,7 +318,7 @@ def coro(*args, **kw): else: wrapper = _types_coroutine(coro) else: - @functools.wraps(func) + @_wraps(func) def wrapper(*args, **kwds): w = CoroWrapper(coro(*args, **kwds), func=func) if w._source_traceback: @@ -246,7 +344,13 @@ def iscoroutinefunction(func): _COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) if _CoroutineABC is not None: _COROUTINE_TYPES += (_CoroutineABC,) - +if events.asyncio is not None: + # Accept also asyncio CoroWrapper for interoperability + if hasattr(events.asyncio, 'coroutines'): + _COROUTINE_TYPES += (events.asyncio.coroutines.CoroWrapper,) + else: + # old Tulip/Python versions + _COROUTINE_TYPES += (events.asyncio.tasks.CoroWrapper,) def iscoroutine(obj): """Return True if obj is a coroutine object.""" @@ -299,3 +403,19 @@ def _format_coroutine(coro): % (coro_name, filename, lineno)) return coro_repr + + +class FromWrapper(object): + __slots__ = ('obj',) + + def __init__(self, obj): + if isinstance(obj, FromWrapper): + obj = obj.obj + assert not isinstance(obj, FromWrapper) + self.obj = obj + +def From(obj): + if not _DEBUG: + return obj + else: + return FromWrapper(obj) diff --git a/trollius/events.py b/trollius/events.py index 496075ba..3aa5b692 100644 --- a/trollius/events.py +++ b/trollius/events.py @@ -10,12 +10,23 @@ import functools import inspect -import reprlib import socket import subprocess import sys import threading import traceback +try: + import reprlib # Python 3 +except ImportError: + import repr as reprlib # Python 2 + +from trollius import compat +try: + import asyncio +except (ImportError, SyntaxError): + # ignore SyntaxError for convenience: ignore SyntaxError caused by "yield + # from" if asyncio module is in the Python path + asyncio = None _PY34 = sys.version_info >= (3, 4) @@ -75,7 +86,7 @@ def _format_callback_source(func, args): return func_repr -class Handle: +class Handle(object): """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', @@ -145,14 +156,14 @@ class TimerHandle(Handle): def __init__(self, when, callback, args, loop): assert when is not None - super().__init__(callback, args, loop) + super(TimerHandle, self).__init__(callback, args, loop) if self._source_traceback: del self._source_traceback[-1] self._when = when self._scheduled = False def _repr_info(self): - info = super()._repr_info() + info = super(TimerHandle, self)._repr_info() pos = 2 if self._cancelled else 1 info.insert(pos, 'when=%s' % self._when) return info @@ -191,10 +202,10 @@ def __ne__(self, other): def cancel(self): if not self._cancelled: self._loop._timer_handle_cancelled(self) - super().cancel() + super(TimerHandle, self).cancel() -class AbstractServer: +class AbstractServer(object): """Abstract server returned by create_server().""" def close(self): @@ -206,298 +217,303 @@ def wait_closed(self): return NotImplemented -class AbstractEventLoop: - """Abstract event loop.""" +if asyncio is not None: + # Reuse asyncio classes so asyncio.set_event_loop() and + # asyncio.set_event_loop_policy() accept Trollius event loop and trollius + # event loop policy + AbstractEventLoop = asyncio.AbstractEventLoop + AbstractEventLoopPolicy = asyncio.AbstractEventLoopPolicy +else: + class AbstractEventLoop(object): + """Abstract event loop.""" - # Running and stopping the event loop. + # Running and stopping the event loop. - def run_forever(self): - """Run the event loop until stop() is called.""" - raise NotImplementedError + def run_forever(self): + """Run the event loop until stop() is called.""" + raise NotImplementedError - def run_until_complete(self, future): - """Run the event loop until a Future is done. + def run_until_complete(self, future): + """Run the event loop until a Future is done. - Return the Future's result, or raise its exception. - """ - raise NotImplementedError + Return the Future's result, or raise its exception. + """ + raise NotImplementedError - def stop(self): - """Stop the event loop as soon as reasonable. + def stop(self): + """Stop the event loop as soon as reasonable. - Exactly how soon that is may depend on the implementation, but - no more I/O callbacks should be scheduled. - """ - raise NotImplementedError + Exactly how soon that is may depend on the implementation, but + no more I/O callbacks should be scheduled. + """ + raise NotImplementedError - def is_running(self): - """Return whether the event loop is currently running.""" - raise NotImplementedError + def is_running(self): + """Return whether the event loop is currently running.""" + raise NotImplementedError - def is_closed(self): - """Returns True if the event loop was closed.""" - raise NotImplementedError + def is_closed(self): + """Returns True if the event loop was closed.""" + raise NotImplementedError - def close(self): - """Close the loop. + def close(self): + """Close the loop. - The loop should not be running. + The loop should not be running. - This is idempotent and irreversible. + This is idempotent and irreversible. - No other methods should be called after this one. - """ - raise NotImplementedError + No other methods should be called after this one. + """ + raise NotImplementedError - # Methods scheduling callbacks. All these return Handles. + # Methods scheduling callbacks. All these return Handles. - def _timer_handle_cancelled(self, handle): - """Notification that a TimerHandle has been cancelled.""" - raise NotImplementedError + def _timer_handle_cancelled(self, handle): + """Notification that a TimerHandle has been cancelled.""" + raise NotImplementedError - def call_soon(self, callback, *args): - return self.call_later(0, callback, *args) + def call_soon(self, callback, *args): + return self.call_later(0, callback, *args) - def call_later(self, delay, callback, *args): - raise NotImplementedError + def call_later(self, delay, callback, *args): + raise NotImplementedError - def call_at(self, when, callback, *args): - raise NotImplementedError + def call_at(self, when, callback, *args): + raise NotImplementedError - def time(self): - raise NotImplementedError + def time(self): + raise NotImplementedError - # Method scheduling a coroutine object: create a task. + # Method scheduling a coroutine object: create a task. - def create_task(self, coro): - raise NotImplementedError + def create_task(self, coro): + raise NotImplementedError - # Methods for interacting with threads. + # Methods for interacting with threads. - def call_soon_threadsafe(self, callback, *args): - raise NotImplementedError + def call_soon_threadsafe(self, callback, *args): + raise NotImplementedError - def run_in_executor(self, executor, func, *args): - raise NotImplementedError + def run_in_executor(self, executor, func, *args): + raise NotImplementedError - def set_default_executor(self, executor): - raise NotImplementedError + def set_default_executor(self, executor): + raise NotImplementedError - # Network I/O methods returning Futures. + # Network I/O methods returning Futures. - def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - raise NotImplementedError + def getaddrinfo(self, host, port, family=0, type=0, proto=0, flags=0): + raise NotImplementedError - def getnameinfo(self, sockaddr, flags=0): - raise NotImplementedError + def getnameinfo(self, sockaddr, flags=0): + raise NotImplementedError - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): - raise NotImplementedError + def create_connection(self, protocol_factory, host=None, port=None, + ssl=None, family=0, proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None): + raise NotImplementedError - def create_server(self, protocol_factory, host=None, port=None, *, - family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None): - """A coroutine which creates a TCP server bound to host and port. + def create_server(self, protocol_factory, host=None, port=None, + family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, + sock=None, backlog=100, ssl=None, reuse_address=None): + """A coroutine which creates a TCP server bound to host and port. - The return value is a Server object which can be used to stop - the service. + The return value is a Server object which can be used to stop + the service. - If host is an empty string or None all interfaces are assumed - and a list of multiple sockets will be returned (most likely - one for IPv4 and another one for IPv6). + If host is an empty string or None all interfaces are assumed + and a list of multiple sockets will be returned (most likely + one for IPv4 and another one for IPv6). - family can be set to either AF_INET or AF_INET6 to force the - socket to use IPv4 or IPv6. If not set it will be determined - from host (defaults to AF_UNSPEC). + family can be set to either AF_INET or AF_INET6 to force the + socket to use IPv4 or IPv6. If not set it will be determined + from host (defaults to AF_UNSPEC). - flags is a bitmask for getaddrinfo(). + flags is a bitmask for getaddrinfo(). - sock can optionally be specified in order to use a preexisting - socket object. + sock can optionally be specified in order to use a preexisting + socket object. - backlog is the maximum number of queued connections passed to - listen() (defaults to 100). + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). - ssl can be set to an SSLContext to enable SSL over the - accepted connections. + ssl can be set to an SSLContext to enable SSL over the + accepted connections. - reuse_address tells the kernel to reuse a local socket in - TIME_WAIT state, without waiting for its natural timeout to - expire. If not specified will automatically be set to True on - UNIX. - """ - raise NotImplementedError + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + """ + raise NotImplementedError - def create_unix_connection(self, protocol_factory, path, *, - ssl=None, sock=None, - server_hostname=None): - raise NotImplementedError + def create_unix_connection(self, protocol_factory, path, + ssl=None, sock=None, + server_hostname=None): + raise NotImplementedError - def create_unix_server(self, protocol_factory, path, *, - sock=None, backlog=100, ssl=None): - """A coroutine which creates a UNIX Domain Socket server. + def create_unix_server(self, protocol_factory, path, + sock=None, backlog=100, ssl=None): + """A coroutine which creates a UNIX Domain Socket server. - The return value is a Server object, which can be used to stop - the service. + The return value is a Server object, which can be used to stop + the service. - path is a str, representing a file systsem path to bind the - server socket to. + path is a str, representing a file systsem path to bind the + server socket to. - sock can optionally be specified in order to use a preexisting - socket object. + sock can optionally be specified in order to use a preexisting + socket object. - backlog is the maximum number of queued connections passed to - listen() (defaults to 100). + backlog is the maximum number of queued connections passed to + listen() (defaults to 100). - ssl can be set to an SSLContext to enable SSL over the - accepted connections. - """ - raise NotImplementedError + ssl can be set to an SSLContext to enable SSL over the + accepted connections. + """ + raise NotImplementedError - def create_datagram_endpoint(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): - raise NotImplementedError + def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, + family=0, proto=0, flags=0): + raise NotImplementedError - # Pipes and subprocesses. + # Pipes and subprocesses. - def connect_read_pipe(self, protocol_factory, pipe): - """Register read pipe in event loop. Set the pipe to non-blocking mode. + def connect_read_pipe(self, protocol_factory, pipe): + """Register read pipe in event loop. Set the pipe to non-blocking mode. - protocol_factory should instantiate object with Protocol interface. - pipe is a file-like object. - Return pair (transport, protocol), where transport supports the - ReadTransport interface.""" - # The reason to accept file-like object instead of just file descriptor - # is: we need to own pipe and close it at transport finishing - # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. - raise NotImplementedError + protocol_factory should instantiate object with Protocol interface. + pipe is a file-like object. + Return pair (transport, protocol), where transport supports the + ReadTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError - def connect_write_pipe(self, protocol_factory, pipe): - """Register write pipe in event loop. + def connect_write_pipe(self, protocol_factory, pipe): + """Register write pipe in event loop. - protocol_factory should instantiate object with BaseProtocol interface. - Pipe is file-like object already switched to nonblocking. - Return pair (transport, protocol), where transport support - WriteTransport interface.""" - # The reason to accept file-like object instead of just file descriptor - # is: we need to own pipe and close it at transport finishing - # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. - raise NotImplementedError + protocol_factory should instantiate object with BaseProtocol interface. + Pipe is file-like object already switched to nonblocking. + Return pair (transport, protocol), where transport support + WriteTransport interface.""" + # The reason to accept file-like object instead of just file descriptor + # is: we need to own pipe and close it at transport finishing + # Can got complicated errors if pass f.fileno(), + # close fd in pipe transport then close f and vise versa. + raise NotImplementedError - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): - raise NotImplementedError + def subprocess_shell(self, protocol_factory, cmd, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, stderr=subprocess.PIPE, + **kwargs): + raise NotImplementedError - def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): - raise NotImplementedError + def subprocess_exec(self, protocol_factory, *args, **kwargs): + raise NotImplementedError - # Ready-based callback registration methods. - # The add_*() methods return None. - # The remove_*() methods return True if something was removed, - # False if there was nothing to delete. + # Ready-based callback registration methods. + # The add_*() methods return None. + # The remove_*() methods return True if something was removed, + # False if there was nothing to delete. - def add_reader(self, fd, callback, *args): - raise NotImplementedError + def add_reader(self, fd, callback, *args): + raise NotImplementedError - def remove_reader(self, fd): - raise NotImplementedError + def remove_reader(self, fd): + raise NotImplementedError - def add_writer(self, fd, callback, *args): - raise NotImplementedError + def add_writer(self, fd, callback, *args): + raise NotImplementedError - def remove_writer(self, fd): - raise NotImplementedError + def remove_writer(self, fd): + raise NotImplementedError - # Completion based I/O methods returning Futures. + # Completion based I/O methods returning Futures. - def sock_recv(self, sock, nbytes): - raise NotImplementedError + def sock_recv(self, sock, nbytes): + raise NotImplementedError - def sock_sendall(self, sock, data): - raise NotImplementedError + def sock_sendall(self, sock, data): + raise NotImplementedError - def sock_connect(self, sock, address): - raise NotImplementedError + def sock_connect(self, sock, address): + raise NotImplementedError - def sock_accept(self, sock): - raise NotImplementedError + def sock_accept(self, sock): + raise NotImplementedError - # Signal handling. + # Signal handling. - def add_signal_handler(self, sig, callback, *args): - raise NotImplementedError + def add_signal_handler(self, sig, callback, *args): + raise NotImplementedError - def remove_signal_handler(self, sig): - raise NotImplementedError + def remove_signal_handler(self, sig): + raise NotImplementedError - # Task factory. + # Task factory. - def set_task_factory(self, factory): - raise NotImplementedError + def set_task_factory(self, factory): + raise NotImplementedError - def get_task_factory(self): - raise NotImplementedError + def get_task_factory(self): + raise NotImplementedError - # Error handlers. + # Error handlers. - def set_exception_handler(self, handler): - raise NotImplementedError + def set_exception_handler(self, handler): + raise NotImplementedError - def default_exception_handler(self, context): - raise NotImplementedError + def default_exception_handler(self, context): + raise NotImplementedError - def call_exception_handler(self, context): - raise NotImplementedError + def call_exception_handler(self, context): + raise NotImplementedError - # Debug flag management. + # Debug flag management. - def get_debug(self): - raise NotImplementedError + def get_debug(self): + raise NotImplementedError - def set_debug(self, enabled): - raise NotImplementedError + def set_debug(self, enabled): + raise NotImplementedError -class AbstractEventLoopPolicy: - """Abstract policy for accessing the event loop.""" + class AbstractEventLoopPolicy(object): + """Abstract policy for accessing the event loop.""" - def get_event_loop(self): - """Get the event loop for the current context. + def get_event_loop(self): + """Get the event loop for the current context. - Returns an event loop object implementing the BaseEventLoop interface, - or raises an exception in case no event loop has been set for the - current context and the current policy does not specify to create one. + Returns an event loop object implementing the BaseEventLoop interface, + or raises an exception in case no event loop has been set for the + current context and the current policy does not specify to create one. - It should never return None.""" - raise NotImplementedError + It should never return None.""" + raise NotImplementedError - def set_event_loop(self, loop): - """Set the event loop for the current context to loop.""" - raise NotImplementedError + def set_event_loop(self, loop): + """Set the event loop for the current context to loop.""" + raise NotImplementedError - def new_event_loop(self): - """Create and return a new event loop object according to this - policy's rules. If there's need to set this loop as the event loop for - the current context, set_event_loop must be called explicitly.""" - raise NotImplementedError + def new_event_loop(self): + """Create and return a new event loop object according to this + policy's rules. If there's need to set this loop as the event loop for + the current context, set_event_loop must be called explicitly.""" + raise NotImplementedError - # Child processes handling (Unix only). + # Child processes handling (Unix only). - def get_child_watcher(self): - "Get the watcher for child processes." - raise NotImplementedError + def get_child_watcher(self): + "Get the watcher for child processes." + raise NotImplementedError - def set_child_watcher(self, watcher): - """Set the watcher for child processes.""" - raise NotImplementedError + def set_child_watcher(self, watcher): + """Set the watcher for child processes.""" + raise NotImplementedError class BaseDefaultEventLoopPolicy(AbstractEventLoopPolicy): diff --git a/trollius/futures.py b/trollius/futures.py index d06828a6..80f6d118 100644 --- a/trollius/futures.py +++ b/trollius/futures.py @@ -5,13 +5,17 @@ 'Future', 'wrap_future', ] -import concurrent.futures._base import logging -import reprlib import sys import traceback +try: + import reprlib # Python 3 +except ImportError: + import repr as reprlib # Python 2 +from . import compat from . import events +from . import executor # States for Future. _PENDING = 'PENDING' @@ -21,9 +25,9 @@ _PY34 = sys.version_info >= (3, 4) _PY35 = sys.version_info >= (3, 5) -Error = concurrent.futures._base.Error -CancelledError = concurrent.futures.CancelledError -TimeoutError = concurrent.futures.TimeoutError +Error = executor.Error +CancelledError = executor.CancelledError +TimeoutError = executor.TimeoutError STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging @@ -32,7 +36,7 @@ class InvalidStateError(Error): """The operation is not allowed in this state.""" -class _TracebackLogger: +class _TracebackLogger(object): """Helper to log a traceback upon destruction if not cleared. This solves a nasty problem with Futures and Tasks that have an @@ -112,7 +116,7 @@ def __del__(self): self.loop.call_exception_handler({'message': msg}) -class Future: +class Future(object): """This class is *almost* compatible with concurrent.futures.Future. Differences: @@ -138,10 +142,14 @@ class Future: _blocking = False # proper use of future (yield vs yield from) + # Used by Python 2 to raise the exception with the original traceback + # in the exception() method in debug mode + _exception_tb = None + _log_traceback = False # Used for Python 3.4 and later _tb_logger = None # Used for Python 3.3 only - def __init__(self, *, loop=None): + def __init__(self, loop=None): """Initialize the future. The optional event_loop argument allows to explicitly set the event @@ -168,23 +176,23 @@ def format_cb(callback): if size == 1: cb = format_cb(cb[0]) elif size == 2: - cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + cb = '{0}, {1}'.format(format_cb(cb[0]), format_cb(cb[1])) elif size > 2: - cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), - size-2, - format_cb(cb[-1])) + cb = '{0}, <{1} more>, {2}'.format(format_cb(cb[0]), + size-2, + format_cb(cb[-1])) return 'cb=[%s]' % cb def _repr_info(self): info = [self._state.lower()] if self._state == _FINISHED: if self._exception is not None: - info.append('exception={!r}'.format(self._exception)) + info.append('exception={0!r}'.format(self._exception)) else: # use reprlib to limit the length of the output, especially # for very long strings result = reprlib.repr(self._result) - info.append('result={}'.format(result)) + info.append('result={0}'.format(result)) if self._callbacks: info.append(self._format_callbacks()) if self._source_traceback: @@ -272,8 +280,13 @@ def result(self): if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None + exc_tb = self._exception_tb + self._exception_tb = None if self._exception is not None: - raise self._exception + if exc_tb is not None: + compat.reraise(type(self._exception), self._exception, exc_tb) + else: + raise self._exception return self._result def exception(self): @@ -292,6 +305,7 @@ def exception(self): if self._tb_logger is not None: self._tb_logger.clear() self._tb_logger = None + self._exception_tb = None return self._exception def add_done_callback(self, fn): @@ -334,31 +348,61 @@ def set_result(self, result): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise InvalidStateError('{0}: {1!r}'.format(self._state, self)) self._result = result self._state = _FINISHED self._schedule_callbacks() + def _get_exception_tb(self): + return self._exception_tb + def set_exception(self, exception): + self._set_exception_with_tb(exception, None) + + def _set_exception_with_tb(self, exception, exc_tb): """Mark the future done and set an exception. If the future is already done when this method is called, raises InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise InvalidStateError('{0}: {1!r}'.format(self._state, self)) if isinstance(exception, type): exception = exception() self._exception = exception + if exc_tb is not None: + self._exception_tb = exc_tb + exc_tb = None + elif self._loop.get_debug() and not compat.PY3: + self._exception_tb = sys.exc_info()[2] self._state = _FINISHED self._schedule_callbacks() if _PY34: self._log_traceback = True else: self._tb_logger = _TracebackLogger(self, exception) - # Arrange for the logger to be activated after all callbacks - # have had a chance to call result() or exception(). - self._loop.call_soon(self._tb_logger.activate) + if hasattr(exception, '__traceback__'): + # Python 3: exception contains a link to the traceback + + # Arrange for the logger to be activated after all callbacks + # have had a chance to call result() or exception(). + self._loop.call_soon(self._tb_logger.activate) + else: + if self._loop.get_debug(): + frame = sys._getframe(1) + tb = ['Traceback (most recent call last):\n'] + if self._exception_tb is not None: + tb += traceback.format_tb(self._exception_tb) + else: + tb += traceback.format_stack(frame) + tb += traceback.format_exception_only(type(exception), exception) + self._tb_logger.tb = tb + else: + self._tb_logger.tb = traceback.format_exception_only( + type(exception), + exception) + + self._tb_logger.exc = None # Truly internal methods. @@ -392,12 +436,18 @@ def __iter__(self): __await__ = __iter__ # make compatible with 'await' expression -def wrap_future(fut, *, loop=None): +if events.asyncio is not None: + # Accept also asyncio Future objects for interoperability + _FUTURE_CLASSES = (Future, events.asyncio.Future) +else: + _FUTURE_CLASSES = Future + +def wrap_future(fut, loop=None): """Wrap concurrent.futures.Future object.""" - if isinstance(fut, Future): + if isinstance(fut, _FUTURE_CLASSES): return fut - assert isinstance(fut, concurrent.futures.Future), \ - 'concurrent.futures.Future is expected, got {!r}'.format(fut) + assert isinstance(fut, executor.Future), \ + 'concurrent.futures.Future is expected, got {0!r}'.format(fut) if loop is None: loop = events.get_event_loop() new_future = Future(loop=loop) diff --git a/trollius/locks.py b/trollius/locks.py index b2e516b5..ecbe3b3b 100644 --- a/trollius/locks.py +++ b/trollius/locks.py @@ -7,7 +7,7 @@ from . import events from . import futures -from .coroutines import coroutine +from .coroutines import coroutine, From, Return _PY35 = sys.version_info >= (3, 5) @@ -19,7 +19,7 @@ class _ContextManager: This enables the following idiom for acquiring and releasing a lock around a block: - with (yield from lock): + with (yield From(lock)): while failing loudly when accidentally using: @@ -43,7 +43,7 @@ def __exit__(self, *args): self._lock = None # Crudely prevent reuse. -class _ContextManagerMixin: +class _ContextManagerMixin(object): def __enter__(self): raise RuntimeError( '"yield from" should be used as context manager expression') @@ -111,16 +111,16 @@ class Lock(_ContextManagerMixin): release() call resets the state to unlocked; first coroutine which is blocked in acquire() is being processed. - acquire() is a coroutine and should be called with 'yield from'. + acquire() is a coroutine and should be called with 'yield From'. - Locks also support the context management protocol. '(yield from lock)' + Locks also support the context management protocol. '(yield From(lock))' should be used as context manager expression. Usage: lock = Lock() ... - yield from lock + yield From(lock) try: ... finally: @@ -130,20 +130,20 @@ class Lock(_ContextManagerMixin): lock = Lock() ... - with (yield from lock): + with (yield From(lock)): ... Lock objects can be tested for locking state: if not lock.locked(): - yield from lock + yield From(lock) else: # lock is acquired ... """ - def __init__(self, *, loop=None): + def __init__(self, loop=None): self._waiters = collections.deque() self._locked = False if loop is not None: @@ -152,11 +152,11 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() + res = super(Lock, self).__repr__() extra = 'locked' if self._locked else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def locked(self): """Return True if lock is acquired.""" @@ -171,14 +171,14 @@ def acquire(self): """ if not self._waiters and not self._locked: self._locked = True - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut + yield From(fut) self._locked = True - return True + raise Return(True) finally: self._waiters.remove(fut) @@ -204,7 +204,7 @@ def release(self): raise RuntimeError('Lock is not acquired.') -class Event: +class Event(object): """Asynchronous equivalent to threading.Event. Class implementing event objects. An event manages a flag that can be set @@ -213,7 +213,7 @@ class Event: false. """ - def __init__(self, *, loop=None): + def __init__(self, loop=None): self._waiters = collections.deque() self._value = False if loop is not None: @@ -222,11 +222,11 @@ def __init__(self, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() + res = super(Event, self).__repr__() extra = 'set' if self._value else 'unset' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def is_set(self): """Return True if and only if the internal flag is true.""" @@ -259,13 +259,13 @@ def wait(self): set() to set the flag to true, then return True. """ if self._value: - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut - return True + yield From(fut) + raise Return(True) finally: self._waiters.remove(fut) @@ -280,7 +280,7 @@ class Condition(_ContextManagerMixin): A new Lock object is created and used as the underlying lock. """ - def __init__(self, lock=None, *, loop=None): + def __init__(self, lock=None, loop=None): if loop is not None: self._loop = loop else: @@ -300,11 +300,11 @@ def __init__(self, lock=None, *, loop=None): self._waiters = collections.deque() def __repr__(self): - res = super().__repr__() + res = super(Condition, self).__repr__() extra = 'locked' if self.locked() else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) @coroutine def wait(self): @@ -326,13 +326,13 @@ def wait(self): fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut - return True + yield From(fut) + raise Return(True) finally: self._waiters.remove(fut) finally: - yield from self.acquire() + yield From(self.acquire()) @coroutine def wait_for(self, predicate): @@ -344,9 +344,9 @@ def wait_for(self, predicate): """ result = predicate() while not result: - yield from self.wait() + yield From(self.wait()) result = predicate() - return result + raise Return(result) def notify(self, n=1): """By default, wake up one coroutine waiting on this condition, if any. @@ -396,7 +396,7 @@ class Semaphore(_ContextManagerMixin): ValueError is raised. """ - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1, loop=None): if value < 0: raise ValueError("Semaphore initial value must be >= 0") self._value = value @@ -407,12 +407,12 @@ def __init__(self, value=1, *, loop=None): self._loop = events.get_event_loop() def __repr__(self): - res = super().__repr__() - extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( + res = super(Semaphore, self).__repr__() + extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format( self._value) if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = '{0},waiters:{1}'.format(extra, len(self._waiters)) + return '<{0} [{1}]>'.format(res[1:-1], extra) def locked(self): """Returns True if semaphore can not be acquired immediately.""" @@ -430,14 +430,14 @@ def acquire(self): """ if not self._waiters and self._value > 0: self._value -= 1 - return True + raise Return(True) fut = futures.Future(loop=self._loop) self._waiters.append(fut) try: - yield from fut + yield From(fut) self._value -= 1 - return True + raise Return(True) finally: self._waiters.remove(fut) @@ -460,11 +460,11 @@ class BoundedSemaphore(Semaphore): above the initial value. """ - def __init__(self, value=1, *, loop=None): + def __init__(self, value=1, loop=None): self._bound_value = value - super().__init__(value, loop=loop) + super(BoundedSemaphore, self).__init__(value, loop=loop) def release(self): if self._value >= self._bound_value: raise ValueError('BoundedSemaphore released too many times') - super().release() + super(BoundedSemaphore, self).release() diff --git a/trollius/proactor_events.py b/trollius/proactor_events.py index 9c2b8f15..49d8bc3e 100644 --- a/trollius/proactor_events.py +++ b/trollius/proactor_events.py @@ -16,6 +16,9 @@ from . import sslproto from . import transports from .log import logger +from .compat import flatten_bytes +from .py33_exceptions import (BrokenPipeError, + ConnectionAbortedError, ConnectionResetError) class _ProactorBasePipeTransport(transports._FlowControlMixin, @@ -24,7 +27,7 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(extra, loop) + super(_ProactorBasePipeTransport, self).__init__(extra, loop) self._set_extra(sock) self._sock = sock self._protocol = protocol @@ -143,7 +146,8 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(loop, sock, protocol, waiter, extra, server) + super(_ProactorReadPipeTransport, self).__init__(loop, sock, protocol, + waiter, extra, server) self._paused = False self._loop.call_soon(self._loop_reading) @@ -220,9 +224,7 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, """Transport for write pipes.""" def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if self._eof_written: raise RuntimeError('write_eof() already called') @@ -301,7 +303,7 @@ def abort(self): class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): def __init__(self, *args, **kw): - super().__init__(*args, **kw) + super(_ProactorWritePipeTransport, self).__init__(*args, **kw) self._read_fut = self._loop._proactor.recv(self._sock, 16) self._read_fut.add_done_callback(self._pipe_closed) @@ -368,7 +370,7 @@ def write_eof(self): class BaseProactorEventLoop(base_events.BaseEventLoop): def __init__(self, proactor): - super().__init__() + super(BaseProactorEventLoop, self).__init__() logger.debug('Using proactor: %s', proactor.__class__.__name__) self._proactor = proactor self._selector = proactor # convenient alias @@ -383,7 +385,7 @@ def _make_socket_transport(self, sock, protocol, waiter=None, extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): if not sslproto._is_sslproto_available(): raise NotImplementedError("Proactor event loop requires Python 3.5" @@ -427,7 +429,7 @@ def close(self): self._selector = None # Close the event loop - super().close() + super(BaseProactorEventLoop, self).close() def sock_recv(self, sock, n): return self._proactor.recv(sock, n) diff --git a/trollius/protocols.py b/trollius/protocols.py index 80fcac9a..2c18287b 100644 --- a/trollius/protocols.py +++ b/trollius/protocols.py @@ -4,7 +4,7 @@ 'SubprocessProtocol'] -class BaseProtocol: +class BaseProtocol(object): """Common base class for protocol interfaces. Usually user implements protocols that derived from BaseProtocol diff --git a/trollius/queues.py b/trollius/queues.py index ed116620..d856305e 100644 --- a/trollius/queues.py +++ b/trollius/queues.py @@ -9,6 +9,7 @@ from . import events from . import futures from . import locks +from .coroutines import From, Return from .tasks import coroutine @@ -26,7 +27,7 @@ class QueueFull(Exception): pass -class Queue: +class Queue(object): """A queue, useful for coordinating producer and consumer coroutines. If maxsize is less than or equal to zero, the queue size is infinite. If it @@ -38,7 +39,7 @@ class Queue: interrupted between calling qsize() and doing an operation on the Queue. """ - def __init__(self, maxsize=0, *, loop=None): + def __init__(self, maxsize=0, loop=None): if loop is None: self._loop = events.get_event_loop() else: @@ -73,22 +74,22 @@ def __put_internal(self, item): self._finished.clear() def __repr__(self): - return '<{} at {:#x} {}>'.format( + return '<{0} at {1:#x} {2}>'.format( type(self).__name__, id(self), self._format()) def __str__(self): - return '<{} {}>'.format(type(self).__name__, self._format()) + return '<{0} {1}>'.format(type(self).__name__, self._format()) def _format(self): - result = 'maxsize={!r}'.format(self._maxsize) + result = 'maxsize={0!r}'.format(self._maxsize) if getattr(self, '_queue', None): - result += ' _queue={!r}'.format(list(self._queue)) + result += ' _queue={0!r}'.format(list(self._queue)) if self._getters: - result += ' _getters[{}]'.format(len(self._getters)) + result += ' _getters[{0}]'.format(len(self._getters)) if self._putters: - result += ' _putters[{}]'.format(len(self._putters)) + result += ' _putters[{0}]'.format(len(self._putters)) if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) + result += ' tasks={0}'.format(self._unfinished_tasks) return result def _consume_done_getters(self): @@ -149,7 +150,7 @@ def put(self, item): waiter = futures.Future(loop=self._loop) self._putters.append((item, waiter)) - yield from waiter + yield From(waiter) else: self.__put_internal(item) @@ -195,15 +196,16 @@ def get(self): # ChannelTest.test_wait. self._loop.call_soon(putter._set_result_unless_cancelled, None) - return self._get() + raise Return(self._get()) elif self.qsize(): - return self._get() + raise Return(self._get()) else: waiter = futures.Future(loop=self._loop) self._getters.append(waiter) - return (yield from waiter) + result = yield From(waiter) + raise Return(result) def get_nowait(self): """Remove and return an item from the queue. @@ -257,7 +259,7 @@ def join(self): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait() + yield From(self._finished.wait()) class PriorityQueue(Queue): diff --git a/trollius/selector_events.py b/trollius/selector_events.py index 7c5b9b5b..ec14974f 100644 --- a/trollius/selector_events.py +++ b/trollius/selector_events.py @@ -14,6 +14,8 @@ import warnings try: import ssl + from .py3_ssl import (wrap_ssl_error, SSLContext, SSLWantReadError, + SSLWantWriteError) except ImportError: # pragma: no cover ssl = None @@ -24,8 +26,26 @@ from . import selectors from . import transports from . import sslproto +from .compat import flatten_bytes from .coroutines import coroutine from .log import logger +from .py33_exceptions import (wrap_error, + BlockingIOError, InterruptedError, ConnectionAbortedError, BrokenPipeError, + ConnectionResetError) + +# On Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4, +# _SelectorSslTransport._read_ready() hangs if the socket has no data. +# Example: test_events.test_create_server_ssl() +_SSL_REQUIRES_SELECT = (sys.version_info < (2, 6, 6)) +if _SSL_REQUIRES_SELECT: + import select + + +def _get_socket_error(sock, address): + err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + if err != 0: + # Jump to the except clause below. + raise OSError(err, 'Connect call failed %s' % (address,)) def _test_selector_event(selector, fd, event): @@ -46,7 +66,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): """ def __init__(self, selector=None): - super().__init__() + super(BaseSelectorEventLoop, self).__init__() if selector is None: selector = selectors.DefaultSelector() @@ -54,13 +74,13 @@ def __init__(self, selector=None): self._selector = selector self._make_self_pipe() - def _make_socket_transport(self, sock, protocol, waiter=None, *, + def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, + server_side=False, server_hostname=None, extra=None, server=None): if not sslproto._is_sslproto_available(): return self._make_legacy_ssl_transport( @@ -75,7 +95,7 @@ def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, return ssl_protocol._app_transport def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, - waiter, *, + waiter, server_side=False, server_hostname=None, extra=None, server=None): # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used @@ -95,7 +115,7 @@ def close(self): if self.is_closed(): return self._close_self_pipe() - super().close() + super(BaseSelectorEventLoop, self).close() if self._selector is not None: self._selector.close() self._selector = None @@ -125,7 +145,7 @@ def _process_self_data(self, data): def _read_from_self(self): while True: try: - data = self._ssock.recv(4096) + data = wrap_error(self._ssock.recv, 4096) if not data: break self._process_self_data(data) @@ -143,7 +163,7 @@ def _write_to_self(self): csock = self._csock if csock is not None: try: - csock.send(b'\0') + wrap_error(csock.send, b'\0') except OSError: if self._debug: logger.debug("Fail to write a null byte into the " @@ -158,14 +178,14 @@ def _start_serving(self, protocol_factory, sock, def _accept_connection(self, protocol_factory, sock, sslcontext=None, server=None): try: - conn, addr = sock.accept() + conn, addr = wrap_error(sock.accept) if self._debug: logger.debug("%r got a new connection from %r: %r", server, addr, conn) conn.setblocking(False) except (BlockingIOError, InterruptedError, ConnectionAbortedError): pass # False alarm. - except OSError as exc: + except socket.error as exc: # There's nowhere to send the error, so just log it. if exc.errno in (errno.EMFILE, errno.ENFILE, errno.ENOBUFS, errno.ENOMEM): @@ -331,7 +351,7 @@ def _sock_recv(self, fut, registered, sock, n): if fut.cancelled(): return try: - data = sock.recv(n) + data = wrap_error(sock.recv, n) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_recv, fut, True, sock, n) except Exception as exc: @@ -368,7 +388,7 @@ def _sock_sendall(self, fut, registered, sock, data): return try: - n = sock.send(data) + n = wrap_error(sock.send, data) except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: @@ -408,7 +428,7 @@ def sock_connect(self, sock, address): def _sock_connect(self, fut, sock, address): fd = sock.fileno() try: - sock.connect(address) + wrap_error(sock.connect, address) except (BlockingIOError, InterruptedError): # Issue #23618: When the C function connect() fails with EINTR, the # connection runs in background. We have to wait until the socket @@ -430,10 +450,7 @@ def _sock_connect_cb(self, fut, sock, address): return try: - err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) - if err != 0: - # Jump to any except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + wrap_error(_get_socket_error, sock, address) except (BlockingIOError, InterruptedError): # socket is still registered, the callback will be retried later pass @@ -465,7 +482,7 @@ def _sock_accept(self, fut, registered, sock): if fut.cancelled(): return try: - conn, address = sock.accept() + conn, address = wrap_error(sock.accept) conn.setblocking(False) except (BlockingIOError, InterruptedError): self.add_reader(fd, self._sock_accept, fut, True, sock) @@ -506,7 +523,7 @@ class _SelectorTransport(transports._FlowControlMixin, _sock = None def __init__(self, loop, sock, protocol, extra=None, server=None): - super().__init__(extra, loop) + super(_SelectorTransport, self).__init__(extra, loop) self._extra['socket'] = sock self._extra['sockname'] = sock.getsockname() if 'peername' not in self._extra: @@ -593,7 +610,7 @@ def _force_close(self, exc): if self._conn_lost: return if self._buffer: - self._buffer.clear() + del self._buffer[:] self._loop.remove_writer(self._sock_fd) if not self._closing: self._closing = True @@ -623,7 +640,7 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - super().__init__(loop, sock, protocol, extra, server) + super(_SelectorSocketTransport, self).__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False @@ -657,7 +674,7 @@ def resume_reading(self): def _read_ready(self): try: - data = self._sock.recv(self.max_size) + data = wrap_error(self._sock.recv, self.max_size) except (BlockingIOError, InterruptedError): pass except Exception as exc: @@ -678,9 +695,7 @@ def _read_ready(self): self.close() def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if self._eof: raise RuntimeError('Cannot call write() after write_eof()') if not data: @@ -695,7 +710,7 @@ def write(self, data): if not self._buffer: # Optimization: try to send now. try: - n = self._sock.send(data) + n = wrap_error(self._sock.send, data) except (BlockingIOError, InterruptedError): pass except Exception as exc: @@ -715,13 +730,14 @@ def write(self, data): def _write_ready(self): assert self._buffer, 'Data should not be empty' + data = flatten_bytes(self._buffer) try: - n = self._sock.send(self._buffer) + n = wrap_error(self._sock.send, data) except (BlockingIOError, InterruptedError): pass except Exception as exc: self._loop.remove_writer(self._sock_fd) - self._buffer.clear() + del self._buffer[:] self._fatal_error(exc, 'Fatal write error on socket transport') else: if n: @@ -766,7 +782,7 @@ def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, wrap_kwargs['server_hostname'] = server_hostname sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) - super().__init__(loop, sslsock, protocol, extra, server) + super(_SelectorSslTransport, self).__init__(loop, sslsock, protocol, extra, server) # the protocol connection is only made after the SSL handshake self._protocol_connected = False @@ -797,12 +813,12 @@ def _wakeup_waiter(self, exc=None): def _on_handshake(self, start_time): try: - self._sock.do_handshake() - except ssl.SSLWantReadError: + wrap_ssl_error(self._sock.do_handshake) + except SSLWantReadError: self._loop.add_reader(self._sock_fd, self._on_handshake, start_time) return - except ssl.SSLWantWriteError: + except SSLWantWriteError: self._loop.add_writer(self._sock_fd, self._on_handshake, start_time) return @@ -842,8 +858,9 @@ def _on_handshake(self, start_time): # Add extra info that becomes available after handshake. self._extra.update(peercert=peercert, cipher=self._sock.cipher(), - compression=self._sock.compression(), ) + if hasattr(self._sock, 'compression'): + self._extra['compression'] = self._sock.compression() self._read_wants_write = False self._write_wants_read = False @@ -883,6 +900,9 @@ def resume_reading(self): if self._loop.get_debug(): logger.debug("%r resumes reading", self) + def _sock_recv(self): + return wrap_ssl_error(self._sock.recv, self.max_size) + def _read_ready(self): if self._write_wants_read: self._write_wants_read = False @@ -892,10 +912,16 @@ def _read_ready(self): self._loop.add_writer(self._sock_fd, self._write_ready) try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): + if _SSL_REQUIRES_SELECT: + rfds = (self._sock.fileno(),) + rfds = select.select(rfds, (), (), 0.0)[0] + if not rfds: + # False alarm. + return + data = wrap_error(self._sock_recv) + except (BlockingIOError, InterruptedError, SSLWantReadError): pass - except ssl.SSLWantWriteError: + except SSLWantWriteError: self._read_wants_write = True self._loop.remove_reader(self._sock_fd) self._loop.add_writer(self._sock_fd, self._write_ready) @@ -924,17 +950,18 @@ def _write_ready(self): self._loop.add_reader(self._sock_fd, self._read_ready) if self._buffer: + data = flatten_bytes(self._buffer) try: - n = self._sock.send(self._buffer) - except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): + n = wrap_error(self._sock.send, data) + except (BlockingIOError, InterruptedError, SSLWantWriteError): n = 0 - except ssl.SSLWantReadError: + except SSLWantReadError: n = 0 self._loop.remove_writer(self._sock_fd) self._write_wants_read = True except Exception as exc: self._loop.remove_writer(self._sock_fd) - self._buffer.clear() + del self._buffer[:] self._fatal_error(exc, 'Fatal write error on SSL transport') return @@ -949,9 +976,7 @@ def _write_ready(self): self._call_connection_lost(None) def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if not data: return @@ -978,7 +1003,8 @@ class _SelectorDatagramTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, address=None, waiter=None, extra=None): - super().__init__(loop, sock, protocol, extra) + super(_SelectorDatagramTransport, self).__init__(loop, sock, + protocol, extra) self._address = address self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called @@ -993,7 +1019,7 @@ def get_write_buffer_size(self): def _read_ready(self): try: - data, addr = self._sock.recvfrom(self.max_size) + data, addr = wrap_error(self._sock.recvfrom, self.max_size) except (BlockingIOError, InterruptedError): pass except OSError as exc: @@ -1004,9 +1030,7 @@ def _read_ready(self): self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + data = flatten_bytes(data) if not data: return @@ -1024,9 +1048,9 @@ def sendto(self, data, addr=None): # Attempt to send it right away first. try: if self._address: - self._sock.send(data) + wrap_error(self._sock.send, data) else: - self._sock.sendto(data, addr) + wrap_error(self._sock.sendto, data, addr) return except (BlockingIOError, InterruptedError): self._loop.add_writer(self._sock_fd, self._sendto_ready) @@ -1047,9 +1071,9 @@ def _sendto_ready(self): data, addr = self._buffer.popleft() try: if self._address: - self._sock.send(data) + wrap_error(self._sock.send, data) else: - self._sock.sendto(data, addr) + wrap_error(self._sock.sendto, data, addr) except (BlockingIOError, InterruptedError): self._buffer.appendleft((data, addr)) # Try again later. break diff --git a/trollius/selectors.py b/trollius/selectors.py index 6d569c30..d2f822cd 100644 --- a/trollius/selectors.py +++ b/trollius/selectors.py @@ -11,6 +11,9 @@ import select import sys +from .py33_exceptions import wrap_error, InterruptedError +from .compat import integer_types + # generic events, that must be mapped to implementation-specific ones EVENT_READ = (1 << 0) @@ -29,16 +32,16 @@ def _fileobj_to_fd(fileobj): Raises: ValueError if the object is invalid """ - if isinstance(fileobj, int): + if isinstance(fileobj, integer_types): fd = fileobj else: try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise ValueError("Invalid file object: " - "{!r}".format(fileobj)) from None + "{0!r}".format(fileobj)) if fd < 0: - raise ValueError("Invalid file descriptor: {}".format(fd)) + raise ValueError("Invalid file descriptor: {0}".format(fd)) return fd @@ -61,13 +64,13 @@ def __getitem__(self, fileobj): fd = self._selector._fileobj_lookup(fileobj) return self._selector._fd_to_key[fd] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) def __iter__(self): return iter(self._selector._fd_to_key) -class BaseSelector(metaclass=ABCMeta): +class BaseSelector(object): """Selector abstract base class. A selector supports registering file objects to be monitored for specific @@ -81,6 +84,7 @@ class BaseSelector(metaclass=ABCMeta): depending on the platform. The default `Selector` class uses the most efficient implementation on the current platform. """ + __metaclass__ = ABCMeta @abstractmethod def register(self, fileobj, events, data=None): @@ -179,7 +183,7 @@ def get_key(self, fileobj): try: return mapping[fileobj] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) @abstractmethod def get_map(self): @@ -223,12 +227,12 @@ def _fileobj_lookup(self, fileobj): def register(self, fileobj, events, data=None): if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {!r}".format(events)) + raise ValueError("Invalid events: {0!r}".format(events)) key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) if key.fd in self._fd_to_key: - raise KeyError("{!r} (FD {}) is already registered" + raise KeyError("{0!r} (FD {1}) is already registered" .format(fileobj, key.fd)) self._fd_to_key[key.fd] = key @@ -238,7 +242,7 @@ def unregister(self, fileobj): try: key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) return key def modify(self, fileobj, events, data=None): @@ -246,7 +250,7 @@ def modify(self, fileobj, events, data=None): try: key = self._fd_to_key[self._fileobj_lookup(fileobj)] except KeyError: - raise KeyError("{!r} is not registered".format(fileobj)) from None + raise KeyError("{0!r} is not registered".format(fileobj)) if events != key.events: self.unregister(fileobj) key = self.register(fileobj, events, data) @@ -282,12 +286,12 @@ class SelectSelector(_BaseSelectorImpl): """Select-based selector.""" def __init__(self): - super().__init__() + super(SelectSelector, self).__init__() self._readers = set() self._writers = set() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(SelectSelector, self).register(fileobj, events, data) if events & EVENT_READ: self._readers.add(key.fd) if events & EVENT_WRITE: @@ -295,7 +299,7 @@ def register(self, fileobj, events, data=None): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(SelectSelector, self).unregister(fileobj) self._readers.discard(key.fd) self._writers.discard(key.fd) return key @@ -311,7 +315,8 @@ def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) ready = [] try: - r, w, _ = self._select(self._readers, self._writers, [], timeout) + r, w, _ = wrap_error(self._select, + self._readers, self._writers, [], timeout) except InterruptedError: return ready r = set(r) @@ -335,11 +340,11 @@ class PollSelector(_BaseSelectorImpl): """Poll-based selector.""" def __init__(self): - super().__init__() + super(PollSelector, self).__init__() self._poll = select.poll() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(PollSelector, self).register(fileobj, events, data) poll_events = 0 if events & EVENT_READ: poll_events |= select.POLLIN @@ -349,7 +354,7 @@ def register(self, fileobj, events, data=None): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(PollSelector, self).unregister(fileobj) self._poll.unregister(key.fd) return key @@ -361,10 +366,10 @@ def select(self, timeout=None): else: # poll() has a resolution of 1 millisecond, round away from # zero to wait *at least* timeout seconds. - timeout = math.ceil(timeout * 1e3) + timeout = int(math.ceil(timeout * 1e3)) ready = [] try: - fd_event_list = self._poll.poll(timeout) + fd_event_list = wrap_error(self._poll.poll, timeout) except InterruptedError: return ready for fd, event in fd_event_list: @@ -386,14 +391,14 @@ class EpollSelector(_BaseSelectorImpl): """Epoll-based selector.""" def __init__(self): - super().__init__() + super(EpollSelector, self).__init__() self._epoll = select.epoll() def fileno(self): return self._epoll.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(EpollSelector, self).register(fileobj, events, data) epoll_events = 0 if events & EVENT_READ: epoll_events |= select.EPOLLIN @@ -403,7 +408,7 @@ def register(self, fileobj, events, data=None): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(EpollSelector, self).unregister(fileobj) try: self._epoll.unregister(key.fd) except OSError: @@ -429,7 +434,7 @@ def select(self, timeout=None): ready = [] try: - fd_event_list = self._epoll.poll(timeout, max_ev) + fd_event_list = wrap_error(self._epoll.poll, timeout, max_ev) except InterruptedError: return ready for fd, event in fd_event_list: @@ -446,7 +451,7 @@ def select(self, timeout=None): def close(self): self._epoll.close() - super().close() + super(EpollSelector, self).close() if hasattr(select, 'devpoll'): @@ -455,14 +460,14 @@ class DevpollSelector(_BaseSelectorImpl): """Solaris /dev/poll selector.""" def __init__(self): - super().__init__() + super(DevpollSelector, self).__init__() self._devpoll = select.devpoll() def fileno(self): return self._devpoll.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(DevpollSelector, self).register(fileobj, events, data) poll_events = 0 if events & EVENT_READ: poll_events |= select.POLLIN @@ -472,7 +477,7 @@ def register(self, fileobj, events, data=None): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(DevpollSelector, self).unregister(fileobj) self._devpoll.unregister(key.fd) return key @@ -504,7 +509,7 @@ def select(self, timeout=None): def close(self): self._devpoll.close() - super().close() + super(DevpollSelector, self).close() if hasattr(select, 'kqueue'): @@ -513,14 +518,14 @@ class KqueueSelector(_BaseSelectorImpl): """Kqueue-based selector.""" def __init__(self): - super().__init__() + super(KqueueSelector, self).__init__() self._kqueue = select.kqueue() def fileno(self): return self._kqueue.fileno() def register(self, fileobj, events, data=None): - key = super().register(fileobj, events, data) + key = super(KqueueSelector, self).register(fileobj, events, data) if events & EVENT_READ: kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_ADD) @@ -532,7 +537,7 @@ def register(self, fileobj, events, data=None): return key def unregister(self, fileobj): - key = super().unregister(fileobj) + key = super(KqueueSelector, self).unregister(fileobj) if key.events & EVENT_READ: kev = select.kevent(key.fd, select.KQ_FILTER_READ, select.KQ_EV_DELETE) @@ -557,7 +562,8 @@ def select(self, timeout=None): max_ev = len(self._fd_to_key) ready = [] try: - kev_list = self._kqueue.control(None, max_ev, timeout) + kev_list = wrap_error(self._kqueue.control, + None, max_ev, timeout) except InterruptedError: return ready for kev in kev_list: @@ -576,7 +582,7 @@ def select(self, timeout=None): def close(self): self._kqueue.close() - super().close() + super(KqueueSelector, self).close() # Choose the best implementation, roughly: diff --git a/trollius/sslproto.py b/trollius/sslproto.py index 235855e2..5f4920a5 100644 --- a/trollius/sslproto.py +++ b/trollius/sslproto.py @@ -9,6 +9,7 @@ from . import protocols from . import transports from .log import logger +from .py3_ssl import BACKPORT_SSL_CONTEXT def _create_transport_context(server_side, server_hostname): @@ -26,10 +27,11 @@ def _create_transport_context(server_side, server_hostname): else: # Fallback for Python 3.3. sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + if not BACKPORT_SSL_CONTEXT: + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.set_default_verify_paths() + sslcontext.verify_mode = ssl.CERT_REQUIRED return sslcontext @@ -43,6 +45,11 @@ def _is_sslproto_available(): _WRAPPED = "WRAPPED" _SHUTDOWN = "SHUTDOWN" +if hasattr(ssl, 'CertificateError'): + _SSL_ERRORS = (ssl.SSLError, ssl.CertificateError) +else: + _SSL_ERRORS = ssl.SSLError + class _SSLPipe(object): """An SSL "Pipe". @@ -224,7 +231,7 @@ def feed_ssldata(self, data, only_handshake=False): elif self._state == _UNWRAPPED: # Drain possible plaintext data after close_notify. appdata.append(self._incoming.read()) - except (ssl.SSLError, ssl.CertificateError) as exc: + except _SSL_ERRORS as exc: if getattr(exc, 'errno', None) not in ( ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_SYSCALL): @@ -569,7 +576,8 @@ def _on_handshake_complete(self, handshake_exc): ssl.match_hostname(peercert, self._server_hostname) except BaseException as exc: if self._loop.get_debug(): - if isinstance(exc, ssl.CertificateError): + if (hasattr(ssl, 'CertificateError') + and isinstance(exc, ssl.CertificateError)): logger.warning("%r: SSL handshake failed " "on verifying the certificate", self, exc_info=True) diff --git a/trollius/streams.py b/trollius/streams.py index 176c65e3..c235c5a5 100644 --- a/trollius/streams.py +++ b/trollius/streams.py @@ -15,7 +15,8 @@ from . import events from . import futures from . import protocols -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .py33_exceptions import ConnectionResetError from .log import logger @@ -38,7 +39,7 @@ def __init__(self, partial, expected): @coroutine -def open_connection(host=None, port=None, *, +def open_connection(host=None, port=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. @@ -61,14 +62,14 @@ def open_connection(host=None, port=None, *, loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_connection( - lambda: protocol, host, port, **kwds) + transport, _ = yield From(loop.create_connection( + lambda: protocol, host, port, **kwds)) writer = StreamWriter(transport, protocol, reader, loop) - return reader, writer + raise Return(reader, writer) @coroutine -def start_server(client_connected_cb, host=None, port=None, *, +def start_server(client_connected_cb, host=None, port=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Start a socket server, call back for each client connected. @@ -100,28 +101,29 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_server(factory, host, port, **kwds)) + server = yield From(loop.create_server(factory, host, port, **kwds)) + raise Return(server) if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform @coroutine - def open_unix_connection(path=None, *, + def open_unix_connection(path=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_unix_connection( - lambda: protocol, path, **kwds) + transport, _ = yield From(loop.create_unix_connection( + lambda: protocol, path, **kwds)) writer = StreamWriter(transport, protocol, reader, loop) - return reader, writer + raise Return(reader, writer) @coroutine - def start_unix_server(client_connected_cb, path=None, *, + def start_unix_server(client_connected_cb, path=None, loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" if loop is None: @@ -133,7 +135,8 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_unix_server(factory, path, **kwds)) + server = (yield From(loop.create_unix_server(factory, path, **kwds))) + raise Return(server) class FlowControlMixin(protocols.Protocol): @@ -199,7 +202,7 @@ def _drain_helper(self): assert waiter is None or waiter.cancelled() waiter = futures.Future(loop=self._loop) self._drain_waiter = waiter - yield from waiter + yield From(waiter) class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): @@ -212,7 +215,7 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): """ def __init__(self, stream_reader, client_connected_cb=None, loop=None): - super().__init__(loop=loop) + super(StreamReaderProtocol, self).__init__(loop=loop) self._stream_reader = stream_reader self._stream_writer = None self._client_connected_cb = client_connected_cb @@ -233,7 +236,7 @@ def connection_lost(self, exc): self._stream_reader.feed_eof() else: self._stream_reader.set_exception(exc) - super().connection_lost(exc) + super(StreamReaderProtocol, self).connection_lost(exc) def data_received(self, data): self._stream_reader.feed_data(data) @@ -242,7 +245,7 @@ def eof_received(self): self._stream_reader.feed_eof() -class StreamWriter: +class StreamWriter(object): """Wraps a Transport. This exposes write(), writelines(), [can_]write_eof(), @@ -295,16 +298,16 @@ def drain(self): The intended use is to write w.write(data) - yield from w.drain() + yield From(w.drain()) """ if self._reader is not None: exc = self._reader.exception() if exc is not None: raise exc - yield from self._protocol._drain_helper() + yield From(self._protocol._drain_helper()) -class StreamReader: +class StreamReader(object): def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; @@ -391,9 +394,16 @@ def _wait_for_data(self, func_name): raise RuntimeError('%s() called while another coroutine is ' 'already waiting for incoming data' % func_name) + # In asyncio, there is no need to recheck if we got data or EOF thanks + # to "yield from". In trollius, a StreamReader method can be called + # after the _wait_for_data() coroutine is scheduled and before it is + # really executed. + if self._buffer or self._eof: + return + self._waiter = futures.Future(loop=self._loop) try: - yield from self._waiter + yield From(self._waiter) finally: self._waiter = None @@ -410,7 +420,7 @@ def readline(self): ichar = self._buffer.find(b'\n') if ichar < 0: line.extend(self._buffer) - self._buffer.clear() + del self._buffer[:] else: ichar += 1 line.extend(self._buffer[:ichar]) @@ -425,10 +435,10 @@ def readline(self): break if not_enough: - yield from self._wait_for_data('readline') + yield From(self._wait_for_data('readline')) self._maybe_resume_transport() - return bytes(line) + raise Return(bytes(line)) @coroutine def read(self, n=-1): @@ -436,7 +446,7 @@ def read(self, n=-1): raise self._exception if not n: - return b'' + raise Return(b'') if n < 0: # This used to just loop creating a new waiter hoping to @@ -445,25 +455,25 @@ def read(self, n=-1): # bytes. So just call self.read(self._limit) until EOF. blocks = [] while True: - block = yield from self.read(self._limit) + block = yield From(self.read(self._limit)) if not block: break blocks.append(block) - return b''.join(blocks) + raise Return(b''.join(blocks)) else: if not self._buffer and not self._eof: - yield from self._wait_for_data('read') + yield From(self._wait_for_data('read')) if n < 0 or len(self._buffer) <= n: data = bytes(self._buffer) - self._buffer.clear() + del self._buffer[:] else: # n > 0 and len(self._buffer) > n data = bytes(self._buffer[:n]) del self._buffer[:n] self._maybe_resume_transport() - return data + raise Return(data) @coroutine def readexactly(self, n): @@ -479,14 +489,14 @@ def readexactly(self, n): blocks = [] while n > 0: - block = yield from self.read(n) + block = yield From(self.read(n)) if not block: partial = b''.join(blocks) raise IncompleteReadError(partial, len(partial) + n) blocks.append(block) n -= len(block) - return b''.join(blocks) + raise Return(b''.join(blocks)) if _PY35: @coroutine diff --git a/trollius/subprocess.py b/trollius/subprocess.py index 4600a9f4..2f1becf3 100644 --- a/trollius/subprocess.py +++ b/trollius/subprocess.py @@ -8,13 +8,16 @@ from . import protocols from . import streams from . import tasks -from .coroutines import coroutine +from .coroutines import coroutine, From, Return +from .py33_exceptions import (BrokenPipeError, ConnectionResetError, + ProcessLookupError) from .log import logger PIPE = subprocess.PIPE STDOUT = subprocess.STDOUT -DEVNULL = subprocess.DEVNULL +if hasattr(subprocess, 'DEVNULL'): + DEVNULL = subprocess.DEVNULL class SubprocessStreamProtocol(streams.FlowControlMixin, @@ -22,7 +25,7 @@ class SubprocessStreamProtocol(streams.FlowControlMixin, """Like StreamReaderProtocol, but for a subprocess.""" def __init__(self, limit, loop): - super().__init__(loop=loop) + super(SubprocessStreamProtocol, self).__init__(loop=loop) self._limit = limit self.stdin = self.stdout = self.stderr = None self._transport = None @@ -115,7 +118,8 @@ def wait(self): """Wait until the process exit and return the process return code. This method is a coroutine.""" - return (yield from self._transport._wait()) + return_code = yield From(self._transport._wait()) + raise Return(return_code) def send_signal(self, signal): self._transport.send_signal(signal) @@ -134,7 +138,7 @@ def _feed_stdin(self, input): logger.debug('%r communicate: feed stdin (%s bytes)', self, len(input)) try: - yield from self.stdin.drain() + yield From(self.stdin.drain()) except (BrokenPipeError, ConnectionResetError) as exc: # communicate() ignores BrokenPipeError and ConnectionResetError if debug: @@ -159,12 +163,12 @@ def _read_stream(self, fd): if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: read %s', self, name) - output = yield from stream.read() + output = yield From(stream.read()) if self._loop.get_debug(): name = 'stdout' if fd == 1 else 'stderr' logger.debug('%r communicate: close %s', self, name) transport.close() - return output + raise Return(output) @coroutine def communicate(self, input=None): @@ -180,36 +184,43 @@ def communicate(self, input=None): stderr = self._read_stream(2) else: stderr = self._noop() - stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr, - loop=self._loop) - yield from self.wait() - return (stdout, stderr) + stdin, stdout, stderr = yield From(tasks.gather(stdin, stdout, stderr, + loop=self._loop)) + yield From(self.wait()) + raise Return(stdout, stderr) @coroutine -def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None, - loop=None, limit=streams._DEFAULT_LIMIT, **kwds): +def create_subprocess_shell(cmd, **kwds): + stdin = kwds.pop('stdin', None) + stdout = kwds.pop('stdout', None) + stderr = kwds.pop('stderr', None) + loop = kwds.pop('loop', None) + limit = kwds.pop('limit', streams._DEFAULT_LIMIT) if loop is None: loop = events.get_event_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_shell( + transport, protocol = yield From(loop.subprocess_shell( protocol_factory, cmd, stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) - return Process(transport, protocol, loop) + stderr=stderr, **kwds)) + raise Return(Process(transport, protocol, loop)) @coroutine -def create_subprocess_exec(program, *args, stdin=None, stdout=None, - stderr=None, loop=None, - limit=streams._DEFAULT_LIMIT, **kwds): +def create_subprocess_exec(program, *args, **kwds): + stdin = kwds.pop('stdin', None) + stdout = kwds.pop('stdout', None) + stderr = kwds.pop('stderr', None) + loop = kwds.pop('loop', None) + limit = kwds.pop('limit', streams._DEFAULT_LIMIT) if loop is None: loop = events.get_event_loop() protocol_factory = lambda: SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_exec( + transport, protocol = yield From(loop.subprocess_exec( protocol_factory, program, *args, stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) - return Process(transport, protocol, loop) + stderr=stderr, **kwds)) + raise Return(Process(transport, protocol, loop)) diff --git a/trollius/tasks.py b/trollius/tasks.py index d8193ba4..1fb23cdc 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -14,16 +14,30 @@ import types import traceback import warnings -import weakref +try: + from weakref import WeakSet +except ImportError: + # Python 2.6 + from .py27_weakrefset import WeakSet +from . import compat from . import coroutines from . import events +from . import executor from . import futures -from .coroutines import coroutine +from .locks import Lock, Condition, Semaphore, _ContextManager +from .coroutines import coroutine, From, Return + _PY34 = (sys.version_info >= (3, 4)) +@coroutine +def _lock_coroutine(lock): + yield From(lock.acquire()) + raise Return(_ContextManager(lock)) + + class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -37,7 +51,7 @@ class Task(futures.Future): # must be _wakeup(). # Weak set containing all tasks alive. - _all_tasks = weakref.WeakSet() + _all_tasks = WeakSet() # Dictionary containing tasks that are currently active in # all running event loops. {EventLoop: Task} @@ -67,11 +81,11 @@ def all_tasks(cls, loop=None): """ if loop is None: loop = events.get_event_loop() - return {t for t in cls._all_tasks if t._loop is loop} + return set(t for t in cls._all_tasks if t._loop is loop) - def __init__(self, coro, *, loop=None): + def __init__(self, coro, loop=None): assert coroutines.iscoroutine(coro), repr(coro) - super().__init__(loop=loop) + super(Task, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] self._coro = coro @@ -96,7 +110,7 @@ def __del__(self): futures.Future.__del__(self) def _repr_info(self): - info = super()._repr_info() + info = super(Task, self)._repr_info() if self._must_cancel: # replace status @@ -109,7 +123,7 @@ def _repr_info(self): info.insert(2, 'wait_for=%r' % self._fut_waiter) return info - def get_stack(self, *, limit=None): + def get_stack(self, limit=None): """Return the list of stack frames for this task's coroutine. If the coroutine is not done, this returns the stack where it is @@ -152,7 +166,7 @@ def get_stack(self, *, limit=None): tb = tb.tb_next return frames - def print_stack(self, *, limit=None, file=None): + def print_stack(self, limit=None, file=None): """Print the stack or traceback for this task's coroutine. This produces output similar to that of the traceback module, @@ -219,9 +233,9 @@ def cancel(self): self._must_cancel = True return True - def _step(self, value=None, exc=None): + def _step(self, value=None, exc=None, exc_tb=None): assert not self.done(), \ - '_step(): already done: {!r}, {!r}, {!r}'.format(self, value, exc) + '_step(): already done: {0!r}, {1!r}, {2!r}'.format(self, value, exc) if self._must_cancel: if not isinstance(exc, futures.CancelledError): exc = futures.CancelledError() @@ -229,6 +243,10 @@ def _step(self, value=None, exc=None): coro = self._coro self._fut_waiter = None + if exc_tb is not None: + init_exc = exc + else: + init_exc = None self.__class__._current_tasks[self._loop] = self # Call either coro.throw(exc) or coro.send(value). try: @@ -237,71 +255,104 @@ def _step(self, value=None, exc=None): else: result = coro.send(value) except StopIteration as exc: - self.set_result(exc.value) + if compat.PY33: + # asyncio Task object? get the result of the coroutine + result = exc.value + else: + if isinstance(exc, Return): + exc.raised = True + result = exc.value + else: + result = None + self.set_result(result) except futures.CancelledError as exc: - super().cancel() # I.e., Future.cancel(self). - except Exception as exc: - self.set_exception(exc) + super(Task, self).cancel() # I.e., Future.cancel(self). except BaseException as exc: - self.set_exception(exc) - raise + if exc is init_exc: + self._set_exception_with_tb(exc, exc_tb) + exc_tb = None + else: + self.set_exception(exc) + + if not isinstance(exc, Exception): + # reraise BaseException + raise else: - if isinstance(result, futures.Future): - # Yielded Future must come from Future.__iter__(). - if result._blocking: - result._blocking = False - result.add_done_callback(self._wakeup) - self._fut_waiter = result - if self._must_cancel: - if self._fut_waiter.cancel(): - self._must_cancel = False + if coroutines._DEBUG: + if not coroutines._coroutine_at_yield_from(self._coro): + # trollius coroutine must "yield From(...)" + if not isinstance(result, coroutines.FromWrapper): + self._loop.call_soon( + self._step, None, + RuntimeError("yield used without From")) + return + result = result.obj else: - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from ' - 'in task {!r} with {!r}'.format(self, result))) + # asyncio coroutine using "yield from ..." + if isinstance(result, coroutines.FromWrapper): + result = result.obj + elif isinstance(result, coroutines.FromWrapper): + result = result.obj + + if coroutines.iscoroutine(result): + # "yield coroutine" creates a task, the current task + # will wait until the new task is done + result = self._loop.create_task(result) + # FIXME: faster check. common base class? hasattr? + elif isinstance(result, (Lock, Condition, Semaphore)): + coro = _lock_coroutine(result) + result = self._loop.create_task(coro) + + if isinstance(result, futures._FUTURE_CLASSES): + # Yielded Future must come from Future.__iter__(). + result.add_done_callback(self._wakeup) + self._fut_waiter = result + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False elif result is None: # Bare yield relinquishes control for one event loop iteration. self._loop.call_soon(self._step) - elif inspect.isgenerator(result): - # Yielding a generator is just wrong. - self._loop.call_soon( - self._step, None, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) else: # Yielding something else is an error. self._loop.call_soon( self._step, None, RuntimeError( - 'Task got bad yield: {!r}'.format(result))) + 'Task got bad yield: {0!r}'.format(result))) finally: self.__class__._current_tasks.pop(self._loop) self = None # Needed to break cycles when an exception occurs. def _wakeup(self, future): - try: - value = future.result() - except Exception as exc: - # This may also be a cancellation. - self._step(None, exc) + if (future._state == futures._FINISHED + and future._exception is not None): + # Get the traceback before calling exception(), because calling + # the exception() method clears the traceback + exc_tb = future._get_exception_tb() + exc = future.exception() + self._step(None, exc, exc_tb) + exc_tb = None else: - self._step(value, None) + try: + value = future.result() + except Exception as exc: + # This may also be a cancellation. + self._step(None, exc) + else: + self._step(value, None) self = None # Needed to break cycles when an exception occurs. # wait() and as_completed() similar to those in PEP 3148. -FIRST_COMPLETED = concurrent.futures.FIRST_COMPLETED -FIRST_EXCEPTION = concurrent.futures.FIRST_EXCEPTION -ALL_COMPLETED = concurrent.futures.ALL_COMPLETED +# Export symbols in trollius.tasks for compatibility with asyncio +FIRST_COMPLETED = executor.FIRST_COMPLETED +FIRST_EXCEPTION = executor.FIRST_EXCEPTION +ALL_COMPLETED = executor.ALL_COMPLETED @coroutine -def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): +def wait(fs, loop=None, timeout=None, return_when=ALL_COMPLETED): """Wait for the Futures and coroutines given by fs to complete. The sequence futures must not be empty. @@ -312,24 +363,25 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Usage: - done, pending = yield from asyncio.wait(fs) + done, pending = yield From(asyncio.wait(fs)) Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ - if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) if not fs: raise ValueError('Set of coroutines/Futures is empty.') if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): - raise ValueError('Invalid return_when value: {}'.format(return_when)) + raise ValueError('Invalid return_when value: {0}'.format(return_when)) if loop is None: loop = events.get_event_loop() - fs = {ensure_future(f, loop=loop) for f in set(fs)} + fs = set(ensure_future(f, loop=loop) for f in set(fs)) - return (yield from _wait(fs, timeout, return_when, loop)) + result = yield From(_wait(fs, timeout, return_when, loop)) + raise Return(result) def _release_waiter(waiter, *args): @@ -338,7 +390,7 @@ def _release_waiter(waiter, *args): @coroutine -def wait_for(fut, timeout, *, loop=None): +def wait_for(fut, timeout, loop=None): """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -355,7 +407,8 @@ def wait_for(fut, timeout, *, loop=None): loop = events.get_event_loop() if timeout is None: - return (yield from fut) + result = yield From(fut) + raise Return(result) waiter = futures.Future(loop=loop) timeout_handle = loop.call_later(timeout, _release_waiter, waiter) @@ -367,14 +420,14 @@ def wait_for(fut, timeout, *, loop=None): try: # wait until the future completes or the timeout try: - yield from waiter + yield From(waiter) except futures.CancelledError: fut.remove_done_callback(cb) fut.cancel() raise if fut.done(): - return fut.result() + raise Return(fut.result()) else: fut.remove_done_callback(cb) fut.cancel() @@ -394,12 +447,11 @@ def _wait(fs, timeout, return_when, loop): timeout_handle = None if timeout is not None: timeout_handle = loop.call_later(timeout, _release_waiter, waiter) - counter = len(fs) + non_local = {'counter': len(fs)} def _on_completion(f): - nonlocal counter - counter -= 1 - if (counter <= 0 or + non_local['counter'] -= 1 + if (non_local['counter'] <= 0 or return_when == FIRST_COMPLETED or return_when == FIRST_EXCEPTION and (not f.cancelled() and f.exception() is not None)): @@ -412,7 +464,7 @@ def _on_completion(f): f.add_done_callback(_on_completion) try: - yield from waiter + yield From(waiter) finally: if timeout_handle is not None: timeout_handle.cancel() @@ -424,11 +476,11 @@ def _on_completion(f): done.add(f) else: pending.add(f) - return done, pending + raise Return(done, pending) # This is *not* a @coroutine! It is just an iterator (yielding Futures). -def as_completed(fs, *, loop=None, timeout=None): +def as_completed(fs, loop=None, timeout=None): """Return an iterator whose values are coroutines. When waiting for the yielded coroutines you'll get the results (or @@ -438,18 +490,18 @@ def as_completed(fs, *, loop=None, timeout=None): This differs from PEP 3148; the proper way to use this is: for f in as_completed(fs): - result = yield from f # The 'yield from' may raise. + result = yield From(f) # The 'yield' may raise. # Use result. - If a timeout is specified, the 'yield from' will raise + If a timeout is specified, the 'yield' will raise TimeoutError when the timeout occurs before all Futures are done. Note: The futures 'f' are not necessarily members of fs. """ - if isinstance(fs, futures.Future) or coroutines.iscoroutine(fs): + if isinstance(fs, futures._FUTURE_CLASSES) or coroutines.iscoroutine(fs): raise TypeError("expect a list of futures, not %s" % type(fs).__name__) loop = loop if loop is not None else events.get_event_loop() - todo = {ensure_future(f, loop=loop) for f in set(fs)} + todo = set(ensure_future(f, loop=loop) for f in set(fs)) from .queues import Queue # Import here to avoid circular import problem. done = Queue(loop=loop) timeout_handle = None @@ -470,11 +522,11 @@ def _on_completion(f): @coroutine def _wait_for_one(): - f = yield from done.get() + f = yield From(done.get()) if f is None: # Dummy value from _on_timeout(). raise futures.TimeoutError - return f.result() # May raise f.exception(). + raise Return(f.result()) # May raise f.exception(). for f in todo: f.add_done_callback(_on_completion) @@ -485,18 +537,19 @@ def _wait_for_one(): @coroutine -def sleep(delay, result=None, *, loop=None): +def sleep(delay, result=None, loop=None): """Coroutine that completes after a given time (in seconds).""" future = futures.Future(loop=loop) h = future._loop.call_later(delay, future._set_result_unless_cancelled, result) try: - return (yield from future) + result = yield From(future) + raise Return(result) finally: h.cancel() -def async(coro_or_future, *, loop=None): +def async(coro_or_future, loop=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. @@ -515,7 +568,10 @@ def ensure_future(coro_or_future, *, loop=None): If the argument is a Future, it is returned directly. """ - if isinstance(coro_or_future, futures.Future): + # FIXME: only check if coroutines._DEBUG is True? + if isinstance(coro_or_future, coroutines.FromWrapper): + coro_or_future = coro_or_future.obj + if isinstance(coro_or_future, futures._FUTURE_CLASSES): if loop is not None and loop is not coro_or_future._loop: raise ValueError('loop argument must agree with Future') return coro_or_future @@ -538,8 +594,8 @@ class _GatheringFuture(futures.Future): cancelled. """ - def __init__(self, children, *, loop=None): - super().__init__(loop=loop) + def __init__(self, children, loop=None): + super(_GatheringFuture, self).__init__(loop=loop) self._children = children def cancel(self): @@ -550,7 +606,7 @@ def cancel(self): return True -def gather(*coros_or_futures, loop=None, return_exceptions=False): +def gather(*coros_or_futures, **kw): """Return a future aggregating results from the given coroutines or futures. @@ -570,6 +626,11 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): prevent the cancellation of one child to cause other children to be cancelled.) """ + loop = kw.pop('loop', None) + return_exceptions = kw.pop('return_exceptions', False) + if kw: + raise TypeError("unexpected keyword") + if not coros_or_futures: outer = futures.Future(loop=loop) outer.set_result([]) @@ -577,7 +638,7 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): arg_to_fut = {} for arg in set(coros_or_futures): - if not isinstance(arg, futures.Future): + if not isinstance(arg, futures._FUTURE_CLASSES): fut = ensure_future(arg, loop=loop) if loop is None: loop = fut._loop @@ -595,11 +656,10 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): children = [arg_to_fut[arg] for arg in coros_or_futures] nchildren = len(children) outer = _GatheringFuture(children, loop=loop) - nfinished = 0 + non_local = {'nfinished': 0} results = [None] * nchildren def _done_callback(i, fut): - nonlocal nfinished if outer.done(): if not fut.cancelled(): # Mark exception retrieved. @@ -619,8 +679,8 @@ def _done_callback(i, fut): else: res = fut._result results[i] = res - nfinished += 1 - if nfinished == nchildren: + non_local['nfinished'] += 1 + if non_local['nfinished'] == nchildren: outer.set_result(results) for i, fut in enumerate(children): @@ -628,16 +688,16 @@ def _done_callback(i, fut): return outer -def shield(arg, *, loop=None): +def shield(arg, loop=None): """Wait for a future, shielding it from cancellation. The statement - res = yield from shield(something()) + res = yield From(shield(something())) is exactly equivalent to the statement - res = yield from something() + res = yield From(something()) *except* that if the coroutine containing it is cancelled, the task running in something() is not cancelled. From the POV of @@ -650,7 +710,7 @@ def shield(arg, *, loop=None): you can combine shield() with a try/except clause, as follows: try: - res = yield from shield(something()) + res = yield From(shield(something())) except CancelledError: res = None """ diff --git a/trollius/test_support.py b/trollius/test_support.py index 0fadfad9..b40576a4 100644 --- a/trollius/test_support.py +++ b/trollius/test_support.py @@ -4,6 +4,7 @@ # Ignore symbol TEST_HOME_DIR: test_events works without it +from __future__ import absolute_import import functools import gc import os @@ -14,6 +15,7 @@ import sys import time +from trollius import test_utils # A constant likely larger than the underlying OS pipe buffer size, to # make writes blocking. @@ -39,7 +41,9 @@ def _assert_python(expected_success, *args, **env_vars): isolated = env_vars.pop('__isolated') else: isolated = not env_vars - cmd_line = [sys.executable, '-X', 'faulthandler'] + cmd_line = [sys.executable] + if sys.version_info >= (3, 3): + cmd_line.extend(('-X', 'faulthandler')) if isolated and sys.version_info >= (3, 4): # isolated mode: ignore Python environment variables, ignore user # site-packages, and don't add the current directory to sys.path @@ -248,7 +252,7 @@ def wrapper(*args, **kw): else: if version < min_version: min_version_txt = '.'.join(map(str, min_version)) - raise unittest.SkipTest( + raise test_utils.SkipTest( "Mac OS X %s or higher required, not %s" % (min_version_txt, version_txt)) return func(*args, **kw) @@ -275,7 +279,7 @@ def wrapper(*args, **kw): else: if version < min_version: min_version_txt = '.'.join(map(str, min_version)) - raise unittest.SkipTest( + raise test_utils.SkipTest( "%s version %s or higher required, not %s" % (sysname, min_version_txt, version_txt)) return func(*args, **kw) @@ -300,9 +304,6 @@ def requires_freebsd_version(*min_version): # Use test.script_helper if available try: - from test.support.script_helper import assert_python_ok + from test.script_helper import assert_python_ok except ImportError: - try: - from test.script_helper import assert_python_ok - except ImportError: - pass + pass diff --git a/trollius/test_utils.py b/trollius/test_utils.py index af7f5bca..caa98f71 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -7,20 +7,32 @@ import os import re import socket -import socketserver import sys import tempfile import threading import time -import unittest -from unittest import mock -from http.server import HTTPServer from wsgiref.simple_server import WSGIRequestHandler, WSGIServer +try: + import socketserver + from http.server import HTTPServer +except ImportError: + # Python 2 + import SocketServer as socketserver + from BaseHTTPServer import HTTPServer + +try: + from unittest import mock +except ImportError: + # Python < 3.3 + import mock + try: import ssl + from .py3_ssl import SSLContext, wrap_socket except ImportError: # pragma: no cover + # SSL support disabled in Python ssl = None from . import base_events @@ -37,27 +49,116 @@ else: from socket import socketpair # pragma: no cover +try: + import unittest + skipIf = unittest.skipIf + skipUnless = unittest.skipUnless + SkipTest = unittest.SkipTest + _TestCase = unittest.TestCase +except AttributeError: + # Python 2.6: use the backported unittest module called "unittest2" + import unittest2 + skipIf = unittest2.skipIf + skipUnless = unittest2.skipUnless + SkipTest = unittest2.SkipTest + _TestCase = unittest2.TestCase + + +if not hasattr(_TestCase, 'assertRaisesRegex'): + class _BaseTestCaseContext: + + def __init__(self, test_case): + self.test_case = test_case + + def _raiseFailure(self, standardMsg): + msg = self.test_case._formatMessage(self.msg, standardMsg) + raise self.test_case.failureException(msg) + + + class _AssertRaisesBaseContext(_BaseTestCaseContext): + + def __init__(self, expected, test_case, callable_obj=None, + expected_regex=None): + _BaseTestCaseContext.__init__(self, test_case) + self.expected = expected + self.test_case = test_case + if callable_obj is not None: + try: + self.obj_name = callable_obj.__name__ + except AttributeError: + self.obj_name = str(callable_obj) + else: + self.obj_name = None + if isinstance(expected_regex, (bytes, str)): + expected_regex = re.compile(expected_regex) + self.expected_regex = expected_regex + self.msg = None + + def handle(self, name, callable_obj, args, kwargs): + """ + If callable_obj is None, assertRaises/Warns is being used as a + context manager, so check for a 'msg' kwarg and return self. + If callable_obj is not None, call it passing args and kwargs. + """ + if callable_obj is None: + self.msg = kwargs.pop('msg', None) + return self + with self: + callable_obj(*args, **kwargs) + + + class _AssertRaisesContext(_AssertRaisesBaseContext): + """A context manager used to implement TestCase.assertRaises* methods.""" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_type is None: + try: + exc_name = self.expected.__name__ + except AttributeError: + exc_name = str(self.expected) + if self.obj_name: + self._raiseFailure("{0} not raised by {1}".format(exc_name, + self.obj_name)) + else: + self._raiseFailure("{0} not raised".format(exc_name)) + if not issubclass(exc_type, self.expected): + # let unexpected exceptions pass through + return False + self.exception = exc_value + if self.expected_regex is None: + return True + + expected_regex = self.expected_regex + if not expected_regex.search(str(exc_value)): + self._raiseFailure('"{0}" does not match "{1}"'.format( + expected_regex.pattern, str(exc_value))) + return True + def dummy_ssl_context(): if ssl is None: return None else: - return ssl.SSLContext(ssl.PROTOCOL_SSLv23) + return SSLContext(ssl.PROTOCOL_SSLv23) -def run_briefly(loop): +def run_briefly(loop, steps=1): @coroutine def once(): pass - gen = once() - t = loop.create_task(gen) - # Don't log a warning if the task is not done after run_until_complete(). - # It occurs if the loop is stopped or if a task raises a BaseException. - t._log_destroy_pending = False - try: - loop.run_until_complete(t) - finally: - gen.close() + for step in range(steps): + gen = once() + t = loop.create_task(gen) + # Don't log a warning if the task is not done after run_until_complete(). + # It occurs if the loop is stopped or if a task raises a BaseException. + t._log_destroy_pending = False + try: + loop.run_until_complete(t) + finally: + gen.close() def run_until(loop, pred, timeout=30): @@ -89,12 +190,12 @@ def log_message(self, format, *args): pass -class SilentWSGIServer(WSGIServer): +class SilentWSGIServer(WSGIServer, object): request_timeout = 2 def get_request(self): - request, client_addr = super().get_request() + request, client_addr = super(SilentWSGIServer, self).get_request() request.settimeout(self.request_timeout) return request, client_addr @@ -115,10 +216,10 @@ def finish_request(self, request, client_address): 'test', 'test_asyncio') keyfile = os.path.join(here, 'ssl_key.pem') certfile = os.path.join(here, 'ssl_cert.pem') - ssock = ssl.wrap_socket(request, - keyfile=keyfile, - certfile=certfile, - server_side=True) + ssock = wrap_socket(request, + keyfile=keyfile, + certfile=certfile, + server_side=True) try: self.RequestHandlerClass(ssock, client_address, self) ssock.close() @@ -131,7 +232,7 @@ class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): pass -def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): +def _run_test_server(address, use_ssl, server_cls, server_ssl_cls): def app(environ, start_response): status = '200 OK' @@ -158,7 +259,7 @@ def app(environ, start_response): if hasattr(socket, 'AF_UNIX'): - class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): + class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer, object): def server_bind(self): socketserver.UnixStreamServer.server_bind(self) @@ -166,7 +267,7 @@ def server_bind(self): self.server_port = 80 - class UnixWSGIServer(UnixHTTPServer, WSGIServer): + class UnixWSGIServer(UnixHTTPServer, WSGIServer, object): request_timeout = 2 @@ -175,7 +276,7 @@ def server_bind(self): self.setup_environ() def get_request(self): - request, client_addr = super().get_request() + request, client_addr = super(UnixWSGIServer, self).get_request() request.settimeout(self.request_timeout) # Code in the stdlib expects that get_request # will return a socket and a tuple (host, port). @@ -214,18 +315,20 @@ def unix_socket_path(): @contextlib.contextmanager - def run_test_unix_server(*, use_ssl=False): + def run_test_unix_server(use_ssl=False): with unix_socket_path() as path: - yield from _run_test_server(address=path, use_ssl=use_ssl, - server_cls=SilentUnixWSGIServer, - server_ssl_cls=UnixSSLWSGIServer) + for item in _run_test_server(address=path, use_ssl=use_ssl, + server_cls=SilentUnixWSGIServer, + server_ssl_cls=UnixSSLWSGIServer): + yield item @contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): - yield from _run_test_server(address=(host, port), use_ssl=use_ssl, - server_cls=SilentWSGIServer, - server_ssl_cls=SSLWSGIServer) +def run_test_server(host='127.0.0.1', port=0, use_ssl=False): + for item in _run_test_server(address=(host, port), use_ssl=use_ssl, + server_cls=SilentWSGIServer, + server_ssl_cls=SSLWSGIServer): + yield item def make_test_protocol(base): @@ -278,7 +381,7 @@ def gen(): """ def __init__(self, gen=None): - super().__init__() + super(TestLoop, self).__init__() if gen is None: def gen(): @@ -307,7 +410,7 @@ def advance_time(self, advance): self._time += advance def close(self): - super().close() + super(TestLoop, self).close() if self._check_on_close: try: self._gen.send(0) @@ -328,11 +431,11 @@ def remove_reader(self, fd): return False def assert_reader(self, fd, callback, *args): - assert fd in self.readers, 'fd {} is not registered'.format(fd) + assert fd in self.readers, 'fd {0} is not registered'.format(fd) handle = self.readers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( + assert handle._callback == callback, '{0!r} != {1!r}'.format( handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( + assert handle._args == args, '{0!r} != {1!r}'.format( handle._args, args) def add_writer(self, fd, callback, *args): @@ -347,11 +450,11 @@ def remove_writer(self, fd): return False def assert_writer(self, fd, callback, *args): - assert fd in self.writers, 'fd {} is not registered'.format(fd) + assert fd in self.writers, 'fd {0} is not registered'.format(fd) handle = self.writers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( + assert handle._callback == callback, '{0!r} != {1!r}'.format( handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( + assert handle._args == args, '{0!r} != {1!r}'.format( handle._args, args) def reset_counters(self): @@ -359,7 +462,7 @@ def reset_counters(self): self.remove_writer_count = collections.defaultdict(int) def _run_once(self): - super()._run_once() + super(TestLoop, self)._run_once() for when in self._timers: advance = self._gen.send(when) self.advance_time(advance) @@ -367,7 +470,7 @@ def _run_once(self): def call_at(self, when, callback, *args): self._timers.append(when) - return super().call_at(when, callback, *args) + return super(TestLoop, self).call_at(when, callback, *args) def _process_events(self, event_list): return @@ -400,8 +503,8 @@ def get_function_source(func): return source -class TestCase(unittest.TestCase): - def set_event_loop(self, loop, *, cleanup=True): +class TestCase(_TestCase): + def set_event_loop(self, loop, cleanup=True): assert loop is not None # ensure that the event loop is passed explicitly in asyncio events.set_event_loop(None) @@ -420,6 +523,48 @@ def tearDown(self): # in an except block of a generator self.assertEqual(sys.exc_info(), (None, None, None)) + if not hasattr(_TestCase, 'assertRaisesRegex'): + def assertRaisesRegex(self, expected_exception, expected_regex, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a raised exception matches a regex. + + Args: + expected_exception: Exception class expected to be raised. + expected_regex: Regex (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + msg: Optional message used in case of failure. Can only be used + when assertRaisesRegex is used as a context manager. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertRaisesContext(expected_exception, self, callable_obj, + expected_regex) + + return context.handle('assertRaisesRegex', callable_obj, args, kwargs) + + if not hasattr(_TestCase, 'assertRegex'): + def assertRegex(self, text, expected_regex, msg=None): + """Fail the test unless the text matches the regular expression.""" + if isinstance(expected_regex, (str, bytes)): + assert expected_regex, "expected_regex must not be empty." + expected_regex = re.compile(expected_regex) + if not expected_regex.search(text): + msg = msg or "Regex didn't match" + msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(msg) + + def check_soure_traceback(self, source_traceback, lineno_delta): + frame = sys._getframe(1) + filename = frame.f_code.co_filename + lineno = frame.f_lineno + lineno_delta + name = frame.f_code.co_name + self.assertIsInstance(source_traceback, list) + self.assertEqual(source_traceback[-1][:3], + (filename, + lineno, + name)) + @contextlib.contextmanager def disable_logger(): diff --git a/trollius/transports.py b/trollius/transports.py index 22df3c7a..5bdbdafc 100644 --- a/trollius/transports.py +++ b/trollius/transports.py @@ -1,6 +1,7 @@ """Abstract Transport class.""" import sys +from .compat import flatten_bytes _PY34 = sys.version_info >= (3, 4) @@ -9,7 +10,7 @@ ] -class BaseTransport: +class BaseTransport(object): """Base class for transports.""" def __init__(self, extra=None): @@ -94,12 +95,8 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ - if not _PY34: - # In Python 3.3, bytes.join() doesn't handle memoryview. - list_of_data = ( - bytes(data) if isinstance(data, memoryview) else data - for data in list_of_data) - self.write(b''.join(list_of_data)) + data = map(flatten_bytes, list_of_data) + self.write(b''.join(data)) def write_eof(self): """Close the write end after flushing buffered data. @@ -230,7 +227,7 @@ class _FlowControlMixin(Transport): override set_write_buffer_limits() (e.g. to specify different defaults). - The subclass constructor must call super().__init__(extra). This + The subclass constructor must call super(Class, self).__init__(extra). This will call set_write_buffer_limits(). The user may call set_write_buffer_limits() and @@ -239,7 +236,7 @@ class _FlowControlMixin(Transport): """ def __init__(self, extra=None, loop=None): - super().__init__(extra) + super(_FlowControlMixin, self).__init__(extra) assert loop is not None self._loop = loop self._protocol_paused = False diff --git a/trollius/unix_events.py b/trollius/unix_events.py index 75e7c9cc..fcccaaa5 100644 --- a/trollius/unix_events.py +++ b/trollius/unix_events.py @@ -1,4 +1,5 @@ """Selector event loop for Unix with signal handling.""" +from __future__ import absolute_import import errno import os @@ -13,6 +14,7 @@ from . import base_events from . import base_subprocess +from . import compat from . import constants from . import coroutines from . import events @@ -20,8 +22,13 @@ from . import selector_events from . import selectors from . import transports -from .coroutines import coroutine +from .compat import flatten_bytes +from .coroutines import coroutine, From, Return from .log import logger +from .py33_exceptions import ( + reraise, wrap_error, + BlockingIOError, BrokenPipeError, ConnectionResetError, + InterruptedError, ChildProcessError) __all__ = ['SelectorEventLoop', @@ -33,9 +40,10 @@ raise ImportError('Signals are not really supported on Windows') -def _sighandler_noop(signum, frame): - """Dummy signal handler.""" - pass +if compat.PY33: + def _sighandler_noop(signum, frame): + """Dummy signal handler.""" + pass class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): @@ -45,23 +53,27 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): """ def __init__(self, selector=None): - super().__init__(selector) + super(_UnixSelectorEventLoop, self).__init__(selector) self._signal_handlers = {} def _socketpair(self): return socket.socketpair() def close(self): - super().close() + super(_UnixSelectorEventLoop, self).close() for sig in list(self._signal_handlers): self.remove_signal_handler(sig) - def _process_self_data(self, data): - for signum in data: - if not signum: - # ignore null bytes written by _write_to_self() - continue - self._handle_signal(signum) + # On Python <= 3.2, the C signal handler of Python writes a null byte into + # the wakeup file descriptor. We cannot retrieve the signal numbers from + # the file descriptor. + if compat.PY33: + def _process_self_data(self, data): + for signum in data: + if not signum: + # ignore null bytes written by _write_to_self() + continue + self._handle_signal(signum) def add_signal_handler(self, sig, callback, *args): """Add a handler for a signal. UNIX only. @@ -88,14 +100,30 @@ def add_signal_handler(self, sig, callback, *args): self._signal_handlers[sig] = handle try: - # Register a dummy signal handler to ask Python to write the signal - # number in the wakup file descriptor. _process_self_data() will - # read signal numbers from this file descriptor to handle signals. - signal.signal(sig, _sighandler_noop) + if compat.PY33: + # On Python 3.3 and newer, the C signal handler writes the + # signal number into the wakeup file descriptor and then calls + # Py_AddPendingCall() to schedule the Python signal handler. + # + # Register a dummy signal handler to ask Python to write the + # signal number into the wakup file descriptor. + # _process_self_data() will read signal numbers from this file + # descriptor to handle signals. + signal.signal(sig, _sighandler_noop) + else: + # On Python 3.2 and older, the C signal handler first calls + # Py_AddPendingCall() to schedule the Python signal handler, + # and then write a null byte into the wakeup file descriptor. + signal.signal(sig, self._handle_signal) # Set SA_RESTART to limit EINTR occurrences. signal.siginterrupt(sig, False) - except OSError as exc: + except (RuntimeError, OSError) as exc: + # On Python 2, signal.signal(signal.SIGKILL, signal.SIG_IGN) raises + # RuntimeError(22, 'Invalid argument'). On Python 3, + # OSError(22, 'Invalid argument') is raised instead. + exc_type, exc_value, tb = sys.exc_info() + del self._signal_handlers[sig] if not self._signal_handlers: try: @@ -103,12 +131,12 @@ def add_signal_handler(self, sig, callback, *args): except (ValueError, OSError) as nexc: logger.info('set_wakeup_fd(-1) failed: %s', nexc) - if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + if isinstance(exc, RuntimeError) or exc.errno == errno.EINVAL: + raise RuntimeError('sig {0} cannot be caught'.format(sig)) else: - raise + reraise(exc_type, exc_value, tb) - def _handle_signal(self, sig): + def _handle_signal(self, sig, frame=None): """Internal helper that is the actual signal handler.""" handle = self._signal_handlers.get(sig) if handle is None: @@ -138,7 +166,7 @@ def remove_signal_handler(self, sig): signal.signal(sig, handler) except OSError as exc: if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError('sig {0} cannot be caught'.format(sig)) else: raise @@ -157,11 +185,11 @@ def _check_signal(self, sig): Raise RuntimeError if there is a problem setting up the handler. """ if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) + raise TypeError('sig must be an int, not {0!r}'.format(sig)) if not (1 <= sig < signal.NSIG): raise ValueError( - 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + 'sig {0} out of range(1, {1})'.format(sig, signal.NSIG)) def _make_read_pipe_transport(self, pipe, protocol, waiter=None, extra=None): @@ -185,7 +213,7 @@ def _make_subprocess_transport(self, protocol, args, shell, watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) try: - yield from waiter + yield From(waiter) except Exception as exc: # Workaround CPython bug #23353: using yield/yield-from in an # except block of a generator doesn't clear properly @@ -196,16 +224,16 @@ def _make_subprocess_transport(self, protocol, args, shell, if err is not None: transp.close() - yield from transp._wait() + yield From(transp._wait()) raise err - return transp + raise Return(transp) def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) @coroutine - def create_unix_connection(self, protocol_factory, path, *, + def create_unix_connection(self, protocol_factory, path, ssl=None, sock=None, server_hostname=None): assert server_hostname is None or isinstance(server_hostname, str) @@ -225,7 +253,7 @@ def create_unix_connection(self, protocol_factory, path, *, sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) try: sock.setblocking(False) - yield from self.sock_connect(sock, path) + yield From(self.sock_connect(sock, path)) except: sock.close() raise @@ -235,12 +263,12 @@ def create_unix_connection(self, protocol_factory, path, *, raise ValueError('no path and sock were specified') sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) - return transport, protocol + transport, protocol = yield From(self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname)) + raise Return(transport, protocol) @coroutine - def create_unix_server(self, protocol_factory, path=None, *, + def create_unix_server(self, protocol_factory, path=None, sock=None, backlog=100, ssl=None): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') @@ -254,13 +282,13 @@ def create_unix_server(self, protocol_factory, path=None, *, try: sock.bind(path) - except OSError as exc: + except socket.error as exc: sock.close() if exc.errno == errno.EADDRINUSE: # Let's improve the error message by adding # with what exact address it occurs. - msg = 'Address {!r} is already in use'.format(path) - raise OSError(errno.EADDRINUSE, msg) from None + msg = 'Address {0!r} is already in use'.format(path) + raise OSError(errno.EADDRINUSE, msg) else: raise except: @@ -273,7 +301,7 @@ def create_unix_server(self, protocol_factory, path=None, *, if sock.family != socket.AF_UNIX: raise ValueError( - 'A UNIX Domain Socket was expected, got {!r}'.format(sock)) + 'A UNIX Domain Socket was expected, got {0!r}'.format(sock)) server = base_events.Server(self, [sock]) sock.listen(backlog) @@ -283,6 +311,7 @@ def create_unix_server(self, protocol_factory, path=None, *, if hasattr(os, 'set_blocking'): + # Python 3.5 and newer def _set_nonblocking(fd): os.set_blocking(fd, False) else: @@ -299,7 +328,7 @@ class _UnixReadPipeTransport(transports.ReadTransport): max_size = 256 * 1024 # max bytes we read in one event loop iteration def __init__(self, loop, pipe, protocol, waiter=None, extra=None): - super().__init__(extra) + super(_UnixReadPipeTransport, self).__init__(extra) self._extra['pipe'] = pipe self._loop = loop self._pipe = pipe @@ -341,7 +370,7 @@ def __repr__(self): def _read_ready(self): try: - data = os.read(self._fileno, self.max_size) + data = wrap_error(os.read, self._fileno, self.max_size) except (BlockingIOError, InterruptedError): pass except OSError as exc: @@ -409,7 +438,7 @@ class _UnixWritePipeTransport(transports._FlowControlMixin, transports.WriteTransport): def __init__(self, loop, pipe, protocol, waiter=None, extra=None): - super().__init__(extra, loop) + super(_UnixWritePipeTransport, self).__init__(extra, loop) self._extra['pipe'] = pipe self._pipe = pipe self._fileno = pipe.fileno() @@ -475,9 +504,7 @@ def _read_ready(self): self._close() def write(self, data): - assert isinstance(data, (bytes, bytearray, memoryview)), repr(data) - if isinstance(data, bytearray): - data = memoryview(data) + data = flatten_bytes(data) if not data: return @@ -491,7 +518,7 @@ def write(self, data): if not self._buffer: # Attempt to send it right away first. try: - n = os.write(self._fileno, data) + n = wrap_error(os.write, self._fileno, data) except (BlockingIOError, InterruptedError): n = 0 except Exception as exc: @@ -511,9 +538,9 @@ def _write_ready(self): data = b''.join(self._buffer) assert data, 'Data should not be empty' - self._buffer.clear() + del self._buffer[:] try: - n = os.write(self._fileno, data) + n = wrap_error(os.write, self._fileno, data) except (BlockingIOError, InterruptedError): self._buffer.append(data) except Exception as exc: @@ -582,7 +609,7 @@ def _close(self, exc=None): self._closing = True if self._buffer: self._loop.remove_writer(self._fileno) - self._buffer.clear() + del self._buffer[:] self._loop.remove_reader(self._fileno) self._loop.call_soon(self._call_connection_lost, exc) @@ -633,11 +660,20 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, universal_newlines=False, bufsize=bufsize, **kwargs) if stdin_w is not None: + # Retrieve the file descriptor from stdin_w, stdin_w should not + # "own" the file descriptor anymore: closing stdin_fd file + # descriptor must close immediatly the file stdin.close() - self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + if hasattr(stdin_w, 'detach'): + stdin_fd = stdin_w.detach() + self._proc.stdin = os.fdopen(stdin_fd, 'wb', bufsize) + else: + stdin_dup = os.dup(stdin_w.fileno()) + stdin_w.close() + self._proc.stdin = os.fdopen(stdin_dup, 'wb', bufsize) -class AbstractChildWatcher: +class AbstractChildWatcher(object): """Abstract base class for monitoring child processes. Objects derived from this class monitor a collection of subprocesses and @@ -773,12 +809,12 @@ class SafeChildWatcher(BaseChildWatcher): """ def __init__(self): - super().__init__() + super(SafeChildWatcher, self).__init__() self._callbacks = {} def close(self): self._callbacks.clear() - super().close() + super(SafeChildWatcher, self).close() def __enter__(self): return self @@ -850,7 +886,7 @@ class FastChildWatcher(BaseChildWatcher): (O(1) each time a child terminates). """ def __init__(self): - super().__init__() + super(FastChildWatcher, self).__init__() self._callbacks = {} self._lock = threading.Lock() self._zombies = {} @@ -859,7 +895,7 @@ def __init__(self): def close(self): self._callbacks.clear() self._zombies.clear() - super().close() + super(FastChildWatcher, self).close() def __enter__(self): with self._lock: @@ -906,7 +942,7 @@ def _do_waitpid_all(self): # long as we're able to reap a child. while True: try: - pid, status = os.waitpid(-1, os.WNOHANG) + pid, status = wrap_error(os.waitpid, -1, os.WNOHANG) except ChildProcessError: # No more child processes exist. return @@ -949,7 +985,7 @@ class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): _loop_factory = _UnixSelectorEventLoop def __init__(self): - super().__init__() + super(_UnixDefaultEventLoopPolicy, self).__init__() self._watcher = None def _init_watcher(self): @@ -968,7 +1004,7 @@ def set_event_loop(self, loop): the child watcher. """ - super().set_event_loop(loop) + super(_UnixDefaultEventLoopPolicy, self).set_event_loop(loop) if self._watcher is not None and \ isinstance(threading.current_thread(), threading._MainThread): diff --git a/trollius/windows_events.py b/trollius/windows_events.py index 922594f1..7f7764e0 100644 --- a/trollius/windows_events.py +++ b/trollius/windows_events.py @@ -11,12 +11,15 @@ from . import base_subprocess from . import futures from . import proactor_events +from . import py33_winapi as _winapi from . import selector_events from . import tasks from . import windows_utils from . import _overlapped -from .coroutines import coroutine +from .coroutines import coroutine, From, Return from .log import logger +from .py33_exceptions import (wrap_error, get_error_class, + ConnectionRefusedError, BrokenPipeError) __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', @@ -42,14 +45,14 @@ class _OverlappedFuture(futures.Future): Cancelling it will immediately cancel the overlapped operation. """ - def __init__(self, ov, *, loop=None): - super().__init__(loop=loop) + def __init__(self, ov, loop=None): + super(_OverlappedFuture, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] self._ov = ov def _repr_info(self): - info = super()._repr_info() + info = super(_OverlappedFuture, self)._repr_info() if self._ov is not None: state = 'pending' if self._ov.pending else 'completed' info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) @@ -73,22 +76,22 @@ def _cancel_overlapped(self): def cancel(self): self._cancel_overlapped() - return super().cancel() + return super(_OverlappedFuture, self).cancel() def set_exception(self, exception): - super().set_exception(exception) + super(_OverlappedFuture, self).set_exception(exception) self._cancel_overlapped() def set_result(self, result): - super().set_result(result) + super(_OverlappedFuture, self).set_result(result) self._ov = None class _BaseWaitHandleFuture(futures.Future): """Subclass of Future which represents a wait handle.""" - def __init__(self, ov, handle, wait_handle, *, loop=None): - super().__init__(loop=loop) + def __init__(self, ov, handle, wait_handle, loop=None): + super(_BaseWaitHandleFuture, self).__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] # Keep a reference to the Overlapped object to keep it alive until the @@ -107,7 +110,7 @@ def _poll(self): _winapi.WAIT_OBJECT_0) def _repr_info(self): - info = super()._repr_info() + info = super(_BaseWaitHandleFuture, self)._repr_info() info.append('handle=%#x' % self._handle) if self._handle is not None: state = 'signaled' if self._poll() else 'waiting' @@ -147,15 +150,15 @@ def _unregister_wait(self): def cancel(self): self._unregister_wait() - return super().cancel() + return super(_BaseWaitHandleFuture, self).cancel() def set_exception(self, exception): self._unregister_wait() - super().set_exception(exception) + super(_BaseWaitHandleFuture, self).set_exception(exception) def set_result(self, result): self._unregister_wait() - super().set_result(result) + super(_BaseWaitHandleFuture, self).set_result(result) class _WaitCancelFuture(_BaseWaitHandleFuture): @@ -163,8 +166,9 @@ class _WaitCancelFuture(_BaseWaitHandleFuture): _WaitHandleFuture using an event. """ - def __init__(self, ov, event, wait_handle, *, loop=None): - super().__init__(ov, event, wait_handle, loop=loop) + def __init__(self, ov, event, wait_handle, loop=None): + super(_WaitCancelFuture, self).__init__(ov, event, wait_handle, + loop=loop) self._done_callback = None @@ -178,8 +182,9 @@ def _schedule_callbacks(self): class _WaitHandleFuture(_BaseWaitHandleFuture): - def __init__(self, ov, handle, wait_handle, proactor, *, loop=None): - super().__init__(ov, handle, wait_handle, loop=loop) + def __init__(self, ov, handle, wait_handle, proactor, loop=None): + super(_WaitHandleFuture, self).__init__(ov, handle, wait_handle, + loop=loop) self._proactor = proactor self._unregister_proactor = True self._event = _overlapped.CreateEvent(None, True, False, None) @@ -201,7 +206,7 @@ def _unregister_wait_cb(self, fut): self._proactor._unregister(self._ov) self._proactor = None - super()._unregister_wait_cb(fut) + super(_WaitHandleFuture, self)._unregister_wait_cb(fut) def _unregister_wait(self): if not self._registered: @@ -259,7 +264,7 @@ def _server_pipe_handle(self, first): flags = _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_OVERLAPPED if first: flags |= _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE - h = _winapi.CreateNamedPipe( + h = wrap_error(_winapi.CreateNamedPipe, self._address, flags, _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT, @@ -301,7 +306,7 @@ class ProactorEventLoop(proactor_events.BaseProactorEventLoop): def __init__(self, proactor=None): if proactor is None: proactor = IocpProactor() - super().__init__(proactor) + super(ProactorEventLoop, self).__init__(proactor) def _socketpair(self): return windows_utils.socketpair() @@ -309,11 +314,11 @@ def _socketpair(self): @coroutine def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) - pipe = yield from f + pipe = yield From(f) protocol = protocol_factory() trans = self._make_duplex_pipe_transport(pipe, protocol, extra={'addr': address}) - return trans, protocol + raise Return(trans, protocol) @coroutine def start_serving_pipe(self, protocol_factory, address): @@ -372,7 +377,7 @@ def _make_subprocess_transport(self, protocol, args, shell, waiter=waiter, extra=extra, **kwargs) try: - yield from waiter + yield From(waiter) except Exception as exc: # Workaround CPython bug #23353: using yield/yield-from in an # except block of a generator doesn't clear properly sys.exc_info() @@ -382,13 +387,13 @@ def _make_subprocess_transport(self, protocol, args, shell, if err is not None: transp.close() - yield from transp._wait() + yield From(transp._wait()) raise err - return transp + raise Return(transp) -class IocpProactor: +class IocpProactor(object): """Proactor implementation using IOCP.""" def __init__(self, concurrency=0xffffffff): @@ -426,16 +431,16 @@ def recv(self, conn, nbytes, flags=0): ov = _overlapped.Overlapped(NULL) try: if isinstance(conn, socket.socket): - ov.WSARecv(conn.fileno(), nbytes, flags) + wrap_error(ov.WSARecv, conn.fileno(), nbytes, flags) else: - ov.ReadFile(conn.fileno(), nbytes) + wrap_error(ov.ReadFile, conn.fileno(), nbytes) except BrokenPipeError: return self._result(b'') def finish_recv(trans, key, ov): try: - return ov.getresult() - except OSError as exc: + return wrap_error(ov.getresult) + except WindowsError as exc: if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: raise ConnectionResetError(*exc.args) else: @@ -453,8 +458,8 @@ def send(self, conn, buf, flags=0): def finish_send(trans, key, ov): try: - return ov.getresult() - except OSError as exc: + return wrap_error(ov.getresult) + except WindowsError as exc: if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: raise ConnectionResetError(*exc.args) else: @@ -469,7 +474,7 @@ def accept(self, listener): ov.AcceptEx(listener.fileno(), conn.fileno()) def finish_accept(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) # Use SO_UPDATE_ACCEPT_CONTEXT so getsockname() etc work. buf = struct.pack('@P', listener.fileno()) conn.setsockopt(socket.SOL_SOCKET, @@ -481,7 +486,7 @@ def finish_accept(trans, key, ov): def accept_coro(future, conn): # Coroutine closing the accept socket if the future is cancelled try: - yield from future + yield From(future) except futures.CancelledError: conn.close() raise @@ -496,7 +501,7 @@ def connect(self, conn, address): # The socket needs to be locally bound before we call ConnectEx(). try: _overlapped.BindLocal(conn.fileno(), conn.family) - except OSError as e: + except WindowsError as e: if e.winerror != errno.WSAEINVAL: raise # Probably already locally bound; check using getsockname(). @@ -506,7 +511,7 @@ def connect(self, conn, address): ov.ConnectEx(conn.fileno(), address) def finish_connect(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) # Use SO_UPDATE_CONNECT_CONTEXT so getsockname() etc work. conn.setsockopt(socket.SOL_SOCKET, _overlapped.SO_UPDATE_CONNECT_CONTEXT, 0) @@ -526,7 +531,7 @@ def accept_pipe(self, pipe): return self._result(pipe) def finish_accept_pipe(trans, key, ov): - ov.getresult() + wrap_error(ov.getresult) return pipe return self._register(ov, pipe, finish_accept_pipe) @@ -539,17 +544,17 @@ def connect_pipe(self, address): # Call CreateFile() in a loop until it doesn't fail with # ERROR_PIPE_BUSY try: - handle = _overlapped.ConnectPipe(address) + handle = wrap_error(_overlapped.ConnectPipe, address) break - except OSError as exc: + except WindowsError as exc: if exc.winerror != _overlapped.ERROR_PIPE_BUSY: raise # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - yield from tasks.sleep(delay, loop=self._loop) + yield From(tasks.sleep(delay, loop=self._loop)) - return windows_utils.PipeHandle(handle) + raise Return(windows_utils.PipeHandle(handle)) def wait_for_handle(self, handle, timeout=None): """Wait for a handle. @@ -572,7 +577,7 @@ def _wait_for_handle(self, handle, timeout, _is_cancel): else: # RegisterWaitForSingleObject() has a resolution of 1 millisecond, # round away from zero to wait *at least* timeout seconds. - ms = math.ceil(timeout * 1e3) + ms = int(math.ceil(timeout * 1e3)) # We only create ov so we can use ov.address as a key for the cache. ov = _overlapped.Overlapped(NULL) @@ -660,7 +665,7 @@ def _poll(self, timeout=None): else: # GetQueuedCompletionStatus() has a resolution of 1 millisecond, # round away from zero to wait *at least* timeout seconds. - ms = math.ceil(timeout * 1e3) + ms = int(math.ceil(timeout * 1e3)) if ms >= INFINITE: raise ValueError("timeout too big") @@ -705,7 +710,7 @@ def _poll(self, timeout=None): # Remove unregisted futures for ov in self._unregistered: self._cache.pop(ov.address, None) - self._unregistered.clear() + del self._unregistered[:] def _stop_serving(self, obj): # obj is a socket or pipe handle. It will be closed in diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index 870cd13a..b2b96177 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -1,6 +1,7 @@ """ Various Windows specific bits and pieces """ +from __future__ import absolute_import import sys @@ -16,6 +17,9 @@ import tempfile import warnings +from . import py33_winapi as _winapi +from .py33_exceptions import wrap_error, BlockingIOError, InterruptedError + __all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] @@ -64,7 +68,7 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): try: csock.setblocking(False) try: - csock.connect((addr, port)) + wrap_error(csock.connect, (addr, port)) except (BlockingIOError, InterruptedError): pass csock.setblocking(True) @@ -80,7 +84,7 @@ def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): # Replacement for os.pipe() using handles instead of fds -def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): +def pipe(duplex=False, overlapped=(True, True), bufsize=BUFSIZE): """Like os.pipe() but with overlapped support and using handles not fds.""" address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % (os.getpid(), next(_mmap_counter))) @@ -115,7 +119,12 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): flags_and_attribs, _winapi.NULL) ov = _winapi.ConnectNamedPipe(h1, overlapped=True) - ov.GetOverlappedResult(True) + if hasattr(ov, 'GetOverlappedResult'): + # _winapi module of Python 3.3 + ov.GetOverlappedResult(True) + else: + # _overlapped module + wrap_error(ov.getresult, True) return h1, h2 except: if h1 is not None: @@ -128,7 +137,7 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): # Wrapper for a pipe handle -class PipeHandle: +class PipeHandle(object): """Wrapper for an overlapped pipe handle which is vaguely file-object like. The IOCP event loop can use these instead of socket objects. @@ -152,7 +161,7 @@ def fileno(self): raise ValueError("I/O operatioon on closed pipe") return self._handle - def close(self, *, CloseHandle=_winapi.CloseHandle): + def close(self, CloseHandle=_winapi.CloseHandle): if self._handle is not None: CloseHandle(self._handle) self._handle = None @@ -200,8 +209,11 @@ def __init__(self, args, stdin=None, stdout=None, stderr=None, **kwds): else: stderr_wfd = stderr try: - super().__init__(args, stdin=stdin_rfd, stdout=stdout_wfd, - stderr=stderr_wfd, **kwds) + super(Popen, self).__init__(args, + stdin=stdin_rfd, + stdout=stdout_wfd, + stderr=stderr_wfd, + **kwds) except: for h in (stdin_wh, stdout_rh, stderr_rh): if h is not None: From 6d516e583cdf742978d5d9e9f5e62980f7c30077 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 23:22:10 +0200 Subject: [PATCH 1394/1502] Port asyncio to Python 2, tests/ directory --- tests/echo3.py | 3 +- tests/test_base_events.py | 187 +++++++++-------- tests/test_events.py | 325 ++++++++++++++++------------- tests/test_futures.py | 60 +++--- tests/test_locks.py | 161 +++++++------- tests/test_proactor_events.py | 9 +- tests/test_queues.py | 102 ++++----- tests/test_selector_events.py | 57 +++-- tests/test_selectors.py | 49 ++--- tests/test_sslproto.py | 2 +- tests/test_streams.py | 52 ++--- tests/test_subprocess.py | 109 +++++----- tests/test_tasks.py | 383 +++++++++++++++++----------------- tests/test_transports.py | 11 +- tests/test_unix_events.py | 71 ++++--- tests/test_windows_events.py | 35 ++-- tests/test_windows_utils.py | 36 ++-- 17 files changed, 875 insertions(+), 777 deletions(-) diff --git a/tests/echo3.py b/tests/echo3.py index 06449673..4c2b505d 100644 --- a/tests/echo3.py +++ b/tests/echo3.py @@ -1,4 +1,5 @@ import os +from trollius.py33_exceptions import wrap_error if __name__ == '__main__': while True: @@ -6,6 +7,6 @@ if not buf: break try: - os.write(1, b'OUT:'+buf) + wrap_error(os.write, 1, b'OUT:'+buf) except OSError as ex: os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 7e7380db..808b7090 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -8,23 +8,16 @@ import threading import time import unittest -from unittest import mock import trollius as asyncio +from trollius import Return, From from trollius import base_events from trollius import constants from trollius import test_utils -try: - from test import support -except ImportError: - from trollius import test_support as support -try: - from test.support.script_helper import assert_python_ok -except ImportError: - try: - from test.script_helper import assert_python_ok - except ImportError: - from trollius.test_support import assert_python_ok +from trollius.py33_exceptions import BlockingIOError +from trollius.test_utils import mock +from trollius.time_monotonic import time_monotonic +from trollius import test_support as support MOCK_ANY = mock.ANY @@ -61,6 +54,7 @@ def test_not_implemented(self): NotImplementedError, self.loop._make_write_pipe_transport, m, m) gen = self.loop._make_subprocess_transport(m, m, m, m, m, m, m) + # self.assertRaises(NotImplementedError, next, iter(gen)) with self.assertRaises(NotImplementedError): gen.send(None) @@ -265,9 +259,9 @@ def cb(): f.cancel() # Don't complain about abandoned Future. def test__run_once(self): - h1 = asyncio.TimerHandle(time.monotonic() + 5.0, lambda: True, (), + h1 = asyncio.TimerHandle(time_monotonic() + 5.0, lambda: True, (), self.loop) - h2 = asyncio.TimerHandle(time.monotonic() + 10.0, lambda: True, (), + h2 = asyncio.TimerHandle(time_monotonic() + 10.0, lambda: True, (), self.loop) h1.cancel() @@ -314,23 +308,21 @@ def fast_select(timeout): self.assertEqual(logging.DEBUG, m_logger.log.call_args[0][0]) def test__run_once_schedule_handle(self): - handle = None - processed = False + non_local = {'handle': None, 'processed': False} def cb(loop): - nonlocal processed, handle - processed = True - handle = loop.call_soon(lambda: True) + non_local['processed'] = True + non_local['handle'] = loop.call_soon(lambda: True) - h = asyncio.TimerHandle(time.monotonic() - 1, cb, (self.loop,), + h = asyncio.TimerHandle(time_monotonic() - 1, cb, (self.loop,), self.loop) self.loop._process_events = mock.Mock() self.loop._scheduled.append(h) self.loop._run_once() - self.assertTrue(processed) - self.assertEqual([handle], list(self.loop._ready)) + self.assertTrue(non_local['processed']) + self.assertEqual([non_local['handle']], list(self.loop._ready)) def test__run_once_cancelled_event_cleanup(self): self.loop._process_events = mock.Mock() @@ -497,10 +489,12 @@ def zero_error(fut): def test_default_exc_handler_coro(self): self.loop._process_events = mock.Mock() + self.loop.set_debug(True) + asyncio.set_event_loop(self.loop) @asyncio.coroutine def zero_error_coro(): - yield from asyncio.sleep(0.01, loop=self.loop) + yield From(asyncio.sleep(0.01, loop=self.loop)) 1/0 # Test Future.__del__ @@ -509,6 +503,7 @@ def zero_error_coro(): fut.add_done_callback(lambda *args: self.loop.stop()) self.loop.run_forever() fut = None # Trigger Future.__del__ or futures._TracebackLogger + support.gc_collect() if PY34: # Future.__del__ in Python 3.4 logs error with # an actual exception context @@ -582,7 +577,7 @@ def handler(loop, context): exc_info=(AttributeError, MOCK_ANY, MOCK_ANY)) def test_default_exc_handler_broken(self): - _context = None + contexts = [] class Loop(base_events.BaseEventLoop): @@ -590,8 +585,7 @@ class Loop(base_events.BaseEventLoop): _process_events = mock.Mock() def default_exception_handler(self, context): - nonlocal _context - _context = context + contexts.append(context) # Simulates custom buggy "default_exception_handler" raise ValueError('spam') @@ -614,7 +608,7 @@ def zero_error(): def custom_handler(loop, context): raise ValueError('ham') - _context = None + del contexts[:] loop.set_exception_handler(custom_handler) with mock.patch('trollius.base_events.logger') as log: run_loop() @@ -625,8 +619,9 @@ def custom_handler(loop, context): # Check that original context was passed to default # exception handler. - self.assertIn('context', _context) - self.assertIs(type(_context['context']['exception']), + context = contexts[0] + self.assertIn('context', context) + self.assertIs(type(context['context']['exception']), ZeroDivisionError) def test_set_task_factory_invalid(self): @@ -667,27 +662,18 @@ def coro(): def test_env_var_debug(self): code = '\n'.join(( - 'import asyncio', - 'loop = asyncio.get_event_loop()', + 'import trollius', + 'loop = trollius.get_event_loop()', 'print(loop.get_debug())')) - # Test with -E to not fail if the unit test was run with - # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = assert_python_ok('-E', '-c', code) - self.assertEqual(stdout.rstrip(), b'False') - - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='') + sts, stdout, stderr = support.assert_python_ok('-c', code, + TROLLIUSDEBUG='') self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1') + sts, stdout, stderr = support.assert_python_ok('-c', code, + TROLLIUSDEBUG='1') self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1') - self.assertEqual(stdout.rstrip(), b'False') - def test_create_task(self): class MyTask(asyncio.Task): pass @@ -829,27 +815,29 @@ class MyProto(asyncio.Protocol): @asyncio.coroutine def getaddrinfo(*args, **kw): - yield from [] - return [(2, 1, 6, '', ('107.6.106.82', 80)), - (2, 1, 6, '', ('107.6.106.82', 80))] + yield From(None) + raise Return([(2, 1, 6, '', ('107.6.106.82', 80)), + (2, 1, 6, '', ('107.6.106.82', 80))]) def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) - idx = -1 - errors = ['err1', 'err2'] + non_local = { + 'idx': -1, + 'errors': ['err1', 'err2'], + } def _socket(*args, **kw): - nonlocal idx, errors - idx += 1 - raise OSError(errors[idx]) + non_local['idx'] += 1 + raise socket.error(non_local['errors'][non_local['idx']]) + m_socket.error = socket.error m_socket.socket = _socket self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) - with self.assertRaises(OSError) as cm: + with self.assertRaises(socket.error) as cm: self.loop.run_until_complete(coro) self.assertEqual(str(cm.exception), 'Multiple exceptions: err1, err2') @@ -859,6 +847,7 @@ def test_create_connection_timeout(self, m_socket): # Ensure that the socket is closed on timeout sock = mock.Mock() m_socket.socket.return_value = sock + m_socket.error = socket.error def getaddrinfo(*args, **kw): fut = asyncio.Future(loop=self.loop) @@ -887,7 +876,7 @@ def test_create_connection_no_host_port_sock(self): def test_create_connection_no_getaddrinfo(self): @asyncio.coroutine def getaddrinfo(*args, **kw): - yield from [] + yield From(None) def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) @@ -895,24 +884,24 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_connect_err(self): @asyncio.coroutine def getaddrinfo(*args, **kw): - yield from [] - return [(2, 1, 6, '', ('107.6.106.82', 80))] + yield From(None) + raise Return([(2, 1, 6, '', ('107.6.106.82', 80))]) def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = mock.Mock() - self.loop.sock_connect.side_effect = OSError + self.loop.sock_connect.side_effect = socket.error coro = self.loop.create_connection(MyProto, 'example.com', 80) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_multiple(self): @asyncio.coroutine @@ -925,11 +914,11 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = mock.Mock() - self.loop.sock_connect.side_effect = OSError + self.loop.sock_connect.side_effect = socket.error coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET) - with self.assertRaises(OSError): + with self.assertRaises(socket.error): self.loop.run_until_complete(coro) @mock.patch('trollius.base_events.socket') @@ -937,10 +926,11 @@ def test_create_connection_multiple_errors_local_addr(self, m_socket): def bind(addr): if addr[0] == '0.0.0.1': - err = OSError('Err') + err = socket.error('Err') err.strerror = 'Err' raise err + m_socket.error = socket.error m_socket.socket.return_value.bind = bind @asyncio.coroutine @@ -953,12 +943,12 @@ def getaddrinfo_task(*args, **kwds): self.loop.getaddrinfo = getaddrinfo_task self.loop.sock_connect = mock.Mock() - self.loop.sock_connect.side_effect = OSError('Err2') + self.loop.sock_connect.side_effect = socket.error('Err2') coro = self.loop.create_connection( MyProto, 'example.com', 80, family=socket.AF_INET, local_addr=(None, 8080)) - with self.assertRaises(OSError) as cm: + with self.assertRaises(socket.error) as cm: self.loop.run_until_complete(coro) self.assertTrue(str(cm.exception).startswith('Multiple exceptions: ')) @@ -981,7 +971,7 @@ def getaddrinfo_task(*args, **kwds): MyProto, 'example.com', 80, family=socket.AF_INET, local_addr=(None, 8080)) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_connection_ssl_server_hostname_default(self): self.loop.getaddrinfo = mock.Mock() @@ -994,7 +984,9 @@ def mock_getaddrinfo(*args, **kwds): self.loop.getaddrinfo.side_effect = mock_getaddrinfo self.loop.sock_connect = mock.Mock() - self.loop.sock_connect.return_value = () + f = asyncio.Future(loop=self.loop) + f.set_result(()) + self.loop.sock_connect.return_value = f self.loop._make_ssl_transport = mock.Mock() class _SelectorTransportMock: @@ -1067,21 +1059,20 @@ def test_create_connection_ssl_server_hostname_errors(self): def test_create_server_empty_host(self): # if host is empty string use None instead - host = object() + non_local = {'host': object()} @asyncio.coroutine def getaddrinfo(*args, **kw): - nonlocal host - host = args[0] - yield from [] + non_local['host'] = args[0] + yield From(None) def getaddrinfo_task(*args, **kwds): return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) self.loop.getaddrinfo = getaddrinfo_task fut = self.loop.create_server(MyProto, '', 0) - self.assertRaises(OSError, self.loop.run_until_complete, fut) - self.assertIsNone(host) + self.assertRaises(socket.error, self.loop.run_until_complete, fut) + self.assertIsNone(non_local['host']) def test_create_server_host_port_sock(self): fut = self.loop.create_server( @@ -1093,18 +1084,25 @@ def test_create_server_no_host_port_sock(self): self.assertRaises(ValueError, self.loop.run_until_complete, fut) def test_create_server_no_getaddrinfo(self): - getaddrinfo = self.loop.getaddrinfo = mock.Mock() - getaddrinfo.return_value = [] + @asyncio.coroutine + def getaddrinfo(*args, **kw): + raise Return([]) + + def getaddrinfo_task(*args, **kwds): + return asyncio.Task(getaddrinfo(*args, **kwds), loop=self.loop) + + self.loop.getaddrinfo = getaddrinfo_task f = self.loop.create_server(MyProto, '0.0.0.0', 0) - self.assertRaises(OSError, self.loop.run_until_complete, f) + self.assertRaises(socket.error, self.loop.run_until_complete, f) @mock.patch('trollius.base_events.socket') def test_create_server_cant_bind(self, m_socket): - class Err(OSError): + class Err(socket.error): strerror = 'error' + m_socket.error = socket.error m_socket.getaddrinfo.return_value = [ (2, 1, 6, '', ('127.0.0.1', 10100))] m_socket.getaddrinfo._is_coroutine = False @@ -1112,18 +1110,19 @@ class Err(OSError): m_sock.bind.side_effect = Err fut = self.loop.create_server(MyProto, '0.0.0.0', 0) - self.assertRaises(OSError, self.loop.run_until_complete, fut) + self.assertRaises(socket.error, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_no_addrinfo(self, m_socket): + m_socket.error = socket.error m_socket.getaddrinfo.return_value = [] m_socket.getaddrinfo._is_coroutine = False coro = self.loop.create_datagram_endpoint( MyDatagramProto, local_addr=('localhost', 0)) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) def test_create_datagram_endpoint_addr_error(self): coro = self.loop.create_datagram_endpoint( @@ -1137,29 +1136,31 @@ def test_create_datagram_endpoint_addr_error(self): def test_create_datagram_endpoint_connect_err(self): self.loop.sock_connect = mock.Mock() - self.loop.sock_connect.side_effect = OSError + self.loop.sock_connect.side_effect = socket.error coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, remote_addr=('127.0.0.1', 0)) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_socket_err(self, m_socket): + m_socket.error = socket.error m_socket.getaddrinfo = socket.getaddrinfo - m_socket.socket.side_effect = OSError + m_socket.socket.side_effect = socket.error coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, local_addr=('127.0.0.1', 0)) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) - @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + @test_utils.skipUnless(support.IPV6_ENABLED, + 'IPv6 not supported or enabled') def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, @@ -1169,12 +1170,13 @@ def test_create_datagram_endpoint_no_matching_family(self): @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_setblk_err(self, m_socket): - m_socket.socket.return_value.setblocking.side_effect = OSError + m_socket.error = socket.error + m_socket.socket.return_value.setblocking.side_effect = socket.error coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, family=socket.AF_INET) self.assertRaises( - OSError, self.loop.run_until_complete, coro) + socket.error, self.loop.run_until_complete, coro) self.assertTrue( m_socket.socket.return_value.close.called) @@ -1185,10 +1187,11 @@ def test_create_datagram_endpoint_noaddr_nofamily(self): @mock.patch('trollius.base_events.socket') def test_create_datagram_endpoint_cant_bind(self, m_socket): - class Err(OSError): + class Err(socket.error): pass m_socket.AF_INET6 = socket.AF_INET6 + m_socket.error = socket.error m_socket.getaddrinfo = socket.getaddrinfo m_sock = m_socket.socket.return_value = mock.Mock() m_sock.bind.side_effect = Err @@ -1210,7 +1213,7 @@ def test_accept_connection_retry(self): def test_accept_connection_exception(self, m_log): sock = mock.Mock() sock.fileno.return_value = 10 - sock.accept.side_effect = OSError(errno.EMFILE, 'Too many open files') + sock.accept.side_effect = socket.error(errno.EMFILE, 'Too many open files') self.loop.remove_reader = mock.Mock() self.loop.call_later = mock.Mock() @@ -1250,7 +1253,7 @@ def stop_loop_cb(loop): @asyncio.coroutine def stop_loop_coro(loop): - yield from () + yield From(None) loop.stop() asyncio.set_event_loop(self.loop) @@ -1260,7 +1263,8 @@ def stop_loop_coro(loop): # slow callback self.loop.call_soon(stop_loop_cb, self.loop) self.loop.run_forever() - fmt, *args = m_logger.warning.call_args[0] + fmt = m_logger.warning.call_args[0][0] + args = m_logger.warning.call_args[0][1:] self.assertRegex(fmt % tuple(args), "^Executing " "took .* seconds$") @@ -1268,7 +1272,8 @@ def stop_loop_coro(loop): # slow task asyncio.ensure_future(stop_loop_coro(self.loop), loop=self.loop) self.loop.run_forever() - fmt, *args = m_logger.warning.call_args[0] + fmt = m_logger.warning.call_args[0][0] + args = m_logger.warning.call_args[0][1:] self.assertRegex(fmt % tuple(args), "^Executing " "took .* seconds$") diff --git a/tests/test_events.py b/tests/test_events.py index 402898df..8f600eb5 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -1,5 +1,6 @@ """Tests for events.py.""" +import contextlib import functools import gc import io @@ -8,29 +9,39 @@ import re import signal import socket -try: - import ssl -except ImportError: - ssl = None import subprocess import sys import threading -import time import errno import unittest -from unittest import mock import weakref +try: + import ssl +except ImportError: + ssl = None + +try: + import concurrent +except ImportError: + concurrent = None + +from trollius import Return, From +from trollius import futures import trollius as asyncio +from trollius import compat +from trollius import events from trollius import proactor_events from trollius import selector_events from trollius import sslproto +from trollius import test_support as support from trollius import test_utils -try: - from test import support -except ImportError: - from trollius import test_support as support +from trollius.py33_exceptions import (wrap_error, + BlockingIOError, ConnectionRefusedError, + FileNotFoundError) +from trollius.test_utils import mock +from trollius.time_monotonic import time_monotonic def data_file(filename): @@ -95,7 +106,7 @@ def connection_lost(self, exc): class MyProto(MyBaseProto): def connection_made(self, transport): - super().connection_made(transport) + super(MyProto, self).connection_made(transport) transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n') @@ -187,7 +198,7 @@ def __init__(self, loop): self.transport = None self.connected = asyncio.Future(loop=loop) self.completed = asyncio.Future(loop=loop) - self.disconnects = {fd: asyncio.Future(loop=loop) for fd in range(3)} + self.disconnects = dict((fd, futures.Future(loop=loop)) for fd in range(3)) self.data = {1: b'', 2: b''} self.returncode = None self.got_data = {1: asyncio.Event(loop=loop), @@ -221,10 +232,10 @@ def process_exited(self): self.returncode = self.transport.get_returncode() -class EventLoopTestsMixin: +class EventLoopTestsMixin(object): def setUp(self): - super().setUp() + super(EventLoopTestsMixin, self).setUp() self.loop = self.create_event_loop() self.set_event_loop(self.loop) @@ -235,12 +246,12 @@ def tearDown(self): self.loop.close() gc.collect() - super().tearDown() + super(EventLoopTestsMixin, self).tearDown() def test_run_until_complete_nesting(self): @asyncio.coroutine def coro1(): - yield + yield From(None) @asyncio.coroutine def coro2(): @@ -263,10 +274,13 @@ def test_run_until_complete_stopped(self): @asyncio.coroutine def cb(): self.loop.stop() - yield from asyncio.sleep(0.1, loop=self.loop) + yield From(asyncio.sleep(0.1, loop=self.loop)) + task = cb() self.assertRaises(RuntimeError, self.loop.run_until_complete, task) + for task in asyncio.Task.all_tasks(loop=self.loop): + task._log_destroy_pending = False def test_call_later(self): results = [] @@ -276,9 +290,9 @@ def callback(arg): self.loop.stop() self.loop.call_later(0.1, callback, 'hello world') - t0 = time.monotonic() + t0 = time_monotonic() self.loop.run_forever() - t1 = time.monotonic() + t1 = time_monotonic() self.assertEqual(results, ['hello world']) self.assertTrue(0.08 <= t1-t0 <= 0.8, t1-t0) @@ -329,13 +343,14 @@ def callback(arg): self.loop.run_forever() self.assertEqual(results, ['hello', 'world']) + @test_utils.skipIf(concurrent is None, 'need concurrent.futures') def test_run_in_executor(self): def run(arg): - return (arg, threading.get_ident()) + return (arg, threading.current_thread().ident) f2 = self.loop.run_in_executor(None, run, 'yo') res, thread_id = self.loop.run_until_complete(f2) self.assertEqual(res, 'yo') - self.assertNotEqual(thread_id, threading.get_ident()) + self.assertNotEqual(thread_id, threading.current_thread().ident) def test_reader_callback(self): r, w = test_utils.socketpair() @@ -423,7 +438,7 @@ def test_sock_client_ops(self): sock = socket.socket() self._basetest_sock_client_ops(httpd, sock) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_unix_sock_client_ops(self): with test_utils.run_test_unix_server() as httpd: sock = socket.socket(socket.AF_UNIX) @@ -463,13 +478,12 @@ def test_sock_accept(self): conn.close() listener.close() - @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + @test_utils.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') def test_add_signal_handler(self): - caught = 0 + non_local = {'caught': 0} def my_handler(): - nonlocal caught - caught += 1 + non_local['caught'] += 1 # Check error behavior first. self.assertRaises( @@ -498,7 +512,7 @@ def my_handler(): self.loop.add_signal_handler(signal.SIGINT, my_handler) os.kill(os.getpid(), signal.SIGINT) - test_utils.run_until(self.loop, lambda: caught) + test_utils.run_until(self.loop, lambda: non_local['caught']) # Removing it should restore the default handler. self.assertTrue(self.loop.remove_signal_handler(signal.SIGINT)) @@ -507,30 +521,28 @@ def my_handler(): # Removing again returns False. self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) - @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @test_utils.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. - caught = 0 + non_local = {'caught': 0} def my_handler(): - nonlocal caught - caught += 1 + non_local['caught'] += 1 self.loop.stop() self.loop.add_signal_handler(signal.SIGALRM, my_handler) signal.setitimer(signal.ITIMER_REAL, 0.01, 0) # Send SIGALRM once. self.loop.run_forever() - self.assertEqual(caught, 1) + self.assertEqual(non_local['caught'], 1) - @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @test_utils.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_args(self): some_args = (42,) - caught = 0 + non_local = {'caught': 0} def my_handler(*args): - nonlocal caught - caught += 1 + non_local['caught'] += 1 self.assertEqual(args, some_args) self.loop.add_signal_handler(signal.SIGALRM, my_handler, *some_args) @@ -538,7 +550,7 @@ def my_handler(*args): signal.setitimer(signal.ITIMER_REAL, 0.1, 0) # Send SIGALRM once. self.loop.call_later(0.5, self.loop.stop) self.loop.run_forever() - self.assertEqual(caught, 1) + self.assertEqual(non_local['caught'], 1) def _basetest_create_connection(self, connection_fut, check_sockname=True): tr, pr = self.loop.run_until_complete(connection_fut) @@ -557,7 +569,7 @@ def test_create_connection(self): lambda: MyProto(loop=self.loop), *httpd.address) self._basetest_create_connection(conn_fut) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_connection(self): # Issue #20682: On Mac OS X Tiger, getsockname() returns a # zero-length address for UNIX socket. @@ -615,7 +627,7 @@ def _test_create_ssl_connection(self, httpd, create_connection, # ssl.Purpose was introduced in Python 3.4 if hasattr(ssl, 'Purpose'): - def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, + def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=None, capath=None, cadata=None): """ @@ -632,17 +644,18 @@ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, *, self._basetest_create_ssl_connection(conn_fut, check_sockname) self.assertEqual(m.call_count, 1) - # With the real ssl.create_default_context(), certificate - # validation will fail - with self.assertRaises(ssl.SSLError) as cm: - conn_fut = create_connection(ssl=True) - # Ignore the "SSL handshake failed" log in debug mode - with test_utils.disable_logger(): - self._basetest_create_ssl_connection(conn_fut, check_sockname) + if not asyncio.BACKPORT_SSL_CONTEXT: + # With the real ssl.create_default_context(), certificate + # validation will fail + with self.assertRaises(ssl.SSLError) as cm: + conn_fut = create_connection(ssl=True) + # Ignore the "SSL handshake failed" log in debug mode + with test_utils.disable_logger(): + self._basetest_create_ssl_connection(conn_fut, check_sockname) self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: create_connection = functools.partial( @@ -655,8 +668,8 @@ def test_legacy_create_ssl_connection(self): with test_utils.force_legacy_ssl_support(): self.test_create_ssl_connection() - @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_ssl_unix_connection(self): # Issue #20682: On Mac OS X Tiger, getsockname() returns a # zero-length address for UNIX socket. @@ -691,10 +704,11 @@ def test_create_connection_local_addr_in_use(self): f = self.loop.create_connection( lambda: MyProto(loop=self.loop), *httpd.address, local_addr=httpd.address) - with self.assertRaises(OSError) as cm: + with self.assertRaises(socket.error) as cm: self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) - self.assertIn(str(httpd.address), cm.exception.strerror) + # FIXME: address missing from the message? + #self.assertIn(str(httpd.address), cm.exception.strerror) def test_create_server(self): proto = MyProto(self.loop) @@ -741,7 +755,7 @@ def _make_unix_server(self, factory, **kwargs): return server, path - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server(self): proto = MyProto(loop=self.loop) server, path = self._make_unix_server(lambda: proto) @@ -769,20 +783,23 @@ def test_create_unix_server(self): # close server server.close() - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_path_socket_error(self): proto = MyProto(loop=self.loop) sock = socket.socket() - with sock: + try: f = self.loop.create_unix_server(lambda: proto, '/test', sock=sock) with self.assertRaisesRegex(ValueError, 'path and sock can not be specified ' 'at the same time'): self.loop.run_until_complete(f) + finally: + sock.close() def _create_ssl_context(self, certfile, keyfile=None): - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) + if not asyncio.BACKPORT_SSL_CONTEXT: + sslcontext.options |= ssl.OP_NO_SSLv2 sslcontext.load_cert_chain(certfile, keyfile) return sslcontext @@ -801,7 +818,7 @@ def _make_ssl_unix_server(self, factory, certfile, keyfile=None): sslcontext = self._create_ssl_context(certfile, keyfile) return self._make_unix_server(factory, ssl=sslcontext) - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -839,8 +856,8 @@ def test_legacy_create_server_ssl(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl() - @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( @@ -874,19 +891,19 @@ def test_legacy_create_unix_server_ssl(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl() - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True - # no CA loaded f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) @@ -907,14 +924,15 @@ def test_legacy_create_server_ssl_verify_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl_verify_failed() - @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') def test_create_unix_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslcontext_client = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) sslcontext_client.options |= ssl.OP_NO_SSLv2 sslcontext_client.verify_mode = ssl.CERT_REQUIRED if hasattr(sslcontext_client, 'check_hostname'): @@ -937,33 +955,40 @@ def test_create_unix_server_ssl_verify_failed(self): self.assertIsNone(proto.transport) server.close() - def test_legacy_create_unix_server_ssl_verify_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl_verify_failed() - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_match_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext_client.options |= ssl.OP_NO_SSLv2 - sslcontext_client.verify_mode = ssl.CERT_REQUIRED - sslcontext_client.load_verify_locations( - cafile=SIGNING_CA) + sslcontext_client = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) + if not asyncio.BACKPORT_SSL_CONTEXT: + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations( + cafile=SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True + if compat.PY3: + err_msg = "hostname '127.0.0.1' doesn't match 'localhost'" + else: + # http://bugs.python.org/issue22861 + err_msg = "hostname '127.0.0.1' doesn't match u'localhost'" + # incorrect server_hostname +# if not asyncio.BACKPORT_SSL_CONTEXT: f_c = self.loop.create_connection(MyProto, host, port, ssl=sslcontext_client) with mock.patch.object(self.loop, 'call_exception_handler'): with test_utils.disable_logger(): with self.assertRaisesRegex( ssl.CertificateError, - "hostname '127.0.0.1' doesn't match 'localhost'"): + err_msg): self.loop.run_until_complete(f_c) # close connection @@ -974,17 +999,18 @@ def test_legacy_create_server_ssl_match_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl_match_failed() - @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext_client.options |= ssl.OP_NO_SSLv2 - sslcontext_client.verify_mode = ssl.CERT_REQUIRED - sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + sslcontext_client = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) + if not asyncio.BACKPORT_SSL_CONTEXT: + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True @@ -1004,16 +1030,17 @@ def test_legacy_create_unix_server_ssl_verified(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl_verified() - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( lambda: proto, SIGNED_CERTFILE) - sslcontext_client = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext_client.options |= ssl.OP_NO_SSLv2 - sslcontext_client.verify_mode = ssl.CERT_REQUIRED - sslcontext_client.load_verify_locations(cafile=SIGNING_CA) + sslcontext_client = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) + if not asyncio.BACKPORT_SSL_CONTEXT: + sslcontext_client.options |= ssl.OP_NO_SSLv2 + sslcontext_client.verify_mode = ssl.CERT_REQUIRED + sslcontext_client.load_verify_locations(cafile=SIGNING_CA) if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True @@ -1026,6 +1053,7 @@ def test_create_server_ssl_verified(self): # close connection proto.transport.close() client.close() + server.close() self.loop.run_until_complete(proto.done) @@ -1034,12 +1062,12 @@ def test_legacy_create_server_ssl_verified(self): self.test_create_server_ssl_verified() def test_create_server_sock(self): - proto = asyncio.Future(loop=self.loop) + non_local = {'proto': asyncio.Future(loop=self.loop)} class TestMyProto(MyProto): def connection_made(self, transport): - super().connection_made(transport) - proto.set_result(self) + super(TestMyProto, self).connection_made(transport) + non_local['proto'].set_result(self) sock_ob = socket.socket(type=socket.SOCK_STREAM) sock_ob.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) @@ -1069,19 +1097,19 @@ def test_create_server_addr_in_use(self): host, port = sock.getsockname() f = self.loop.create_server(MyProto, host=host, port=port) - with self.assertRaises(OSError) as cm: + with self.assertRaises(socket.error) as cm: self.loop.run_until_complete(f) self.assertEqual(cm.exception.errno, errno.EADDRINUSE) server.close() - @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + @test_utils.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_server_dual_stack(self): f_proto = asyncio.Future(loop=self.loop) class TestMyProto(MyProto): def connection_made(self, transport): - super().connection_made(transport) + super(TestMyProto, self).connection_made(transport) f_proto.set_result(self) try_count = 0 @@ -1090,7 +1118,7 @@ def connection_made(self, transport): port = support.find_unused_port() f = self.loop.create_server(TestMyProto, host=None, port=port) server = self.loop.run_until_complete(f) - except OSError as ex: + except socket.error as ex: if ex.errno == errno.EADDRINUSE: try_count += 1 self.assertGreaterEqual(5, try_count) @@ -1131,16 +1159,17 @@ def test_server_close(self): client = socket.socket() self.assertRaises( - ConnectionRefusedError, client.connect, ('127.0.0.1', port)) + ConnectionRefusedError, wrap_error, client.connect, + ('127.0.0.1', port)) client.close() def test_create_datagram_endpoint(self): class TestMyDatagramProto(MyDatagramProto): def __init__(inner_self): - super().__init__(loop=self.loop) + super(TestMyDatagramProto, inner_self).__init__(loop=self.loop) def datagram_received(self, data, addr): - super().datagram_received(data, addr) + super(TestMyDatagramProto, self).datagram_received(data, addr) self.transport.sendto(b'resp:'+data, addr) coro = self.loop.create_datagram_endpoint( @@ -1192,7 +1221,7 @@ def test_internal_fds(self): self.assertIsNone(loop._csock) self.assertIsNone(loop._ssock) - @unittest.skipUnless(sys.platform != 'win32', + @test_utils.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_read_pipe(self): proto = MyReadPipeProto(loop=self.loop) @@ -1202,8 +1231,8 @@ def test_read_pipe(self): @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe( - lambda: proto, pipeobj) + t, p = yield From(self.loop.connect_read_pipe( + lambda: proto, pipeobj)) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) @@ -1227,7 +1256,7 @@ def connect(): # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) - @unittest.skipUnless(sys.platform != 'win32', + @test_utils.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") # select, poll and kqueue don't support character devices (PTY) on Mac OS X # older than 10.6 (Snow Leopard) @@ -1242,8 +1271,8 @@ def test_read_pty_output(self): @asyncio.coroutine def connect(): - t, p = yield from self.loop.connect_read_pipe(lambda: proto, - master_read_obj) + t, p = yield From(self.loop.connect_read_pipe(lambda: proto, + master_read_obj)) self.assertIs(p, proto) self.assertIs(t, proto.transport) self.assertEqual(['INITIAL', 'CONNECTED'], proto.state) @@ -1267,8 +1296,8 @@ def connect(): # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) - @unittest.skipUnless(sys.platform != 'win32', - "Don't support pipes for Windows") + @test_utils.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") def test_write_pipe(self): rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) @@ -1306,12 +1335,17 @@ def reader(data): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) - @unittest.skipUnless(sys.platform != 'win32', + @test_utils.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): rsock, wsock = test_utils.socketpair() rsock.setblocking(False) - pipeobj = io.open(wsock.detach(), 'wb', 1024) + if hasattr(wsock, 'detach'): + wsock_fd = wsock.detach() + else: + # Python 2 + wsock_fd = wsock.fileno() + pipeobj = io.open(wsock_fd, 'wb', 1024) proto = MyWritePipeProto(loop=self.loop) connect = self.loop.connect_write_pipe(lambda: proto, pipeobj) @@ -1329,8 +1363,8 @@ def test_write_pipe_disconnect_on_close(self): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) - @unittest.skipUnless(sys.platform != 'win32', - "Don't support pipes for Windows") + @test_utils.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") # select, poll and kqueue don't support character devices (PTY) on Mac OS X # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) @@ -1385,19 +1419,19 @@ def test_prompt_cancellation(self): def main(): try: self.loop.call_soon(f.cancel) - yield from f + yield From(f) except asyncio.CancelledError: res = 'cancelled' else: res = None finally: self.loop.stop() - return res + raise Return(res) - start = time.monotonic() + start = time_monotonic() t = asyncio.Task(main(), loop=self.loop) self.loop.run_forever() - elapsed = time.monotonic() - start + elapsed = time_monotonic() - start self.assertLess(elapsed, 0.1) self.assertEqual(t.result(), 'cancelled') @@ -1421,19 +1455,20 @@ def _run_once(): @asyncio.coroutine def wait(): loop = self.loop - yield from asyncio.sleep(1e-2, loop=loop) - yield from asyncio.sleep(1e-4, loop=loop) - yield from asyncio.sleep(1e-6, loop=loop) - yield from asyncio.sleep(1e-8, loop=loop) - yield from asyncio.sleep(1e-10, loop=loop) + yield From(asyncio.sleep(1e-2, loop=loop)) + yield From(asyncio.sleep(1e-4, loop=loop)) + yield From(asyncio.sleep(1e-6, loop=loop)) + yield From(asyncio.sleep(1e-8, loop=loop)) + yield From(asyncio.sleep(1e-10, loop=loop)) self.loop.run_until_complete(wait()) - # The ideal number of call is 12, but on some platforms, the selector + # The ideal number of call is 22, but on some platforms, the selector # may sleep at little bit less than timeout depending on the resolution # of the clock used by the kernel. Tolerate a few useless calls on # these platforms. - self.assertLessEqual(self.loop._run_once_counter, 20, - {'clock_resolution': self.loop._clock_resolution, + self.assertLessEqual(self.loop._run_once_counter, 30, + {'calls': self.loop._run_once_counter, + 'clock_resolution': self.loop._clock_resolution, 'selector': self.loop._selector.__class__.__name__}) def test_sock_connect_address(self): @@ -1451,7 +1486,7 @@ def test_sock_connect_address(self): for family, address in addresses: for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): sock = socket.socket(family, sock_type) - with sock: + with contextlib.closing(sock): sock.setblocking(False) connect = self.loop.sock_connect(sock, address) with self.assertRaises(ValueError) as cm: @@ -1525,7 +1560,7 @@ def test(): self.loop.add_signal_handler(signal.SIGTERM, func) -class SubprocessTestsMixin: +class SubprocessTestsMixin(object): def check_terminated(self, returncode): if sys.platform == 'win32': @@ -1656,7 +1691,7 @@ def test_subprocess_terminate(self): self.check_terminated(proto.returncode) transp.close() - @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + @test_utils.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_subprocess_send_signal(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') @@ -1748,6 +1783,7 @@ def test_subprocess_close_client_stream(self): self.loop.run_until_complete(proto.completed) self.check_killed(proto.returncode) + @test_utils.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") def test_subprocess_wait_no_same_group(self): # start the new process in a new session connect = self.loop.subprocess_shell( @@ -1762,9 +1798,9 @@ def test_subprocess_wait_no_same_group(self): def test_subprocess_exec_invalid_args(self): @asyncio.coroutine def connect(**kwds): - yield from self.loop.subprocess_exec( + yield From(self.loop.subprocess_exec( asyncio.SubprocessProtocol, - 'pwd', **kwds) + 'pwd', **kwds)) with self.assertRaises(ValueError): self.loop.run_until_complete(connect(universal_newlines=True)) @@ -1778,9 +1814,9 @@ def test_subprocess_shell_invalid_args(self): def connect(cmd=None, **kwds): if not cmd: cmd = 'pwd' - yield from self.loop.subprocess_shell( + yield From(self.loop.subprocess_shell( asyncio.SubprocessProtocol, - cmd, **kwds) + cmd, **kwds)) with self.assertRaises(ValueError): self.loop.run_until_complete(connect(['ls', '-l'])) @@ -1860,14 +1896,14 @@ def test_remove_fds_after_closing(self): class UnixEventLoopTestsMixin(EventLoopTestsMixin): def setUp(self): - super().setUp() + super(UnixEventLoopTestsMixin, self).setUp() watcher = asyncio.SafeChildWatcher() watcher.attach_loop(self.loop) asyncio.set_child_watcher(watcher) def tearDown(self): asyncio.set_child_watcher(None) - super().tearDown() + super(UnixEventLoopTestsMixin, self).tearDown() if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(UnixEventLoopTestsMixin, @@ -1883,16 +1919,16 @@ def create_event_loop(self): @support.requires_mac_ver(10, 9) # Issue #20667: KqueueEventLoopTests.test_read_pty_output() # hangs on OpenBSD 5.5 - @unittest.skipIf(sys.platform.startswith('openbsd'), - 'test hangs on OpenBSD') + @test_utils.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') def test_read_pty_output(self): - super().test_read_pty_output() + super(KqueueEventLoopTests, self).test_read_pty_output() # kqueue doesn't support character devices (PTY) on Mac OS X older # than 10.9 (Maverick) @support.requires_mac_ver(10, 9) def test_write_pty(self): - super().test_write_pty() + super(KqueueEventLoopTests, self).test_write_pty() if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(UnixEventLoopTestsMixin, @@ -2017,7 +2053,7 @@ def test_handle_repr_debug(self): self.loop.get_debug.return_value = True # simple function - create_filename = __file__ + create_filename = sys._getframe().f_code.co_filename create_lineno = sys._getframe().f_lineno + 1 h = asyncio.Handle(noop, (1, 2), self.loop) filename, lineno = test_utils.get_function_source(noop) @@ -2069,13 +2105,13 @@ def check_source_traceback(h): check_source_traceback(h) -class TimerTests(unittest.TestCase): +class TimerTests(test_utils.TestCase): def setUp(self): self.loop = mock.Mock() def test_hash(self): - when = time.monotonic() + when = time_monotonic() h = asyncio.TimerHandle(when, lambda: False, (), mock.Mock()) self.assertEqual(hash(h), hash(when)) @@ -2085,7 +2121,7 @@ def callback(*args): return args args = (1, 2, 3) - when = time.monotonic() + when = time_monotonic() h = asyncio.TimerHandle(when, callback, args, mock.Mock()) self.assertIs(h._callback, callback) self.assertIs(h._args, args) @@ -2120,7 +2156,7 @@ def test_timer_repr_debug(self): self.loop.get_debug.return_value = True # simple function - create_filename = __file__ + create_filename = sys._getframe().f_code.co_filename create_lineno = sys._getframe().f_lineno + 1 h = asyncio.TimerHandle(123, noop, (), self.loop) filename, lineno = test_utils.get_function_source(noop) @@ -2141,7 +2177,7 @@ def test_timer_comparison(self): def callback(*args): return args - when = time.monotonic() + when = time_monotonic() h1 = asyncio.TimerHandle(when, callback, (), self.loop) h2 = asyncio.TimerHandle(when, callback, (), self.loop) @@ -2178,7 +2214,7 @@ def callback(*args): self.assertIs(NotImplemented, h1.__ne__(h3)) -class AbstractEventLoopTests(unittest.TestCase): +class AbstractEventLoopTests(test_utils.TestCase): def test_not_implemented(self): f = mock.Mock() @@ -2191,8 +2227,13 @@ def test_not_implemented(self): NotImplementedError, loop.stop) self.assertRaises( NotImplementedError, loop.is_running) - self.assertRaises( - NotImplementedError, loop.is_closed) + # skip some tests if the AbstractEventLoop class comes from asyncio + # and the asyncio version (python version in fact) is older than 3.4.2 + if events.asyncio is None or sys.version_info >= (3, 4, 2): + self.assertRaises( + NotImplementedError, loop.is_closed) + self.assertRaises( + NotImplementedError, loop.create_task, None) self.assertRaises( NotImplementedError, loop.close) self.assertRaises( @@ -2266,7 +2307,7 @@ def test_not_implemented(self): NotImplementedError, loop.set_debug, f) -class ProtocolsAbsTests(unittest.TestCase): +class ProtocolsAbsTests(test_utils.TestCase): def test_empty(self): f = mock.Mock() @@ -2290,7 +2331,7 @@ def test_empty(self): self.assertIsNone(sp.process_exited()) -class PolicyTests(unittest.TestCase): +class PolicyTests(test_utils.TestCase): def test_event_loop_policy(self): policy = asyncio.AbstractEventLoopPolicy() diff --git a/tests/test_futures.py b/tests/test_futures.py index 3aadb688..34e68fc1 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -1,19 +1,23 @@ """Tests for futures.py.""" -import concurrent.futures +try: + import concurrent.futures +except ImportError: + concurrent = None import re import sys import threading import unittest -from unittest import mock import trollius as asyncio +from trollius import compat +from trollius import test_support as support from trollius import test_utils -try: - from test import support -except ImportError: - from trollius import test_support as support +from trollius.test_utils import mock + +def get_thread_ident(): + return threading.current_thread().ident def _fakefunc(f): return f @@ -43,10 +47,6 @@ def test_init_constructor_default_loop(self): f = asyncio.Future() self.assertIs(f._loop, self.loop) - def test_constructor_positional(self): - # Make sure Future doesn't accept a positional argument - self.assertRaises(TypeError, asyncio.Future, 42) - def test_cancel(self): f = asyncio.Future(loop=self.loop) self.assertTrue(f.cancel()) @@ -90,24 +90,6 @@ def test_exception_class(self): f.set_exception(RuntimeError) self.assertIsInstance(f.exception(), RuntimeError) - def test_yield_from_twice(self): - f = asyncio.Future(loop=self.loop) - - def fixture(): - yield 'A' - x = yield from f - yield 'B', x - y = yield from f - yield 'C', y - - g = fixture() - self.assertEqual(next(g), 'A') # yield 'A'. - self.assertEqual(next(g), f) # First yield from f. - f.set_result(42) - self.assertEqual(next(g), ('B', 42)) # yield 'B', x. - # The second "yield from f" does not yield f. - self.assertEqual(next(g), ('C', 42)) # yield 'C', y. - def test_future_repr(self): self.loop.set_debug(True) f_pending_debug = asyncio.Future(loop=self.loop) @@ -140,7 +122,8 @@ def test_future_repr(self): def func_repr(func): filename, lineno = test_utils.get_function_source(func) - text = '%s() at %s:%s' % (func.__qualname__, filename, lineno) + func_name = getattr(func, '__qualname__', func.__name__) + text = '%s() at %s:%s' % (func_name, filename, lineno) return re.escape(text) f_one_callbacks = asyncio.Future(loop=self.loop) @@ -234,10 +217,13 @@ def test_tb_logger_result_retrieved(self, m_log): @mock.patch('trollius.base_events.logger') def test_tb_logger_exception_unretrieved(self, m_log): + self.loop.set_debug(True) + asyncio.set_event_loop(self.loop) fut = asyncio.Future(loop=self.loop) fut.set_exception(RuntimeError('boom')) del fut test_utils.run_briefly(self.loop) + support.gc_collect() self.assertTrue(m_log.error.called) @mock.patch('trollius.base_events.logger') @@ -256,32 +242,35 @@ def test_tb_logger_exception_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) + @test_utils.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future(self): def run(arg): - return (arg, threading.get_ident()) + return (arg, get_thread_ident()) ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = asyncio.wrap_future(f1, loop=self.loop) res, ident = self.loop.run_until_complete(f2) self.assertIsInstance(f2, asyncio.Future) self.assertEqual(res, 'oi') - self.assertNotEqual(ident, threading.get_ident()) + self.assertNotEqual(ident, get_thread_ident()) def test_wrap_future_future(self): f1 = asyncio.Future(loop=self.loop) f2 = asyncio.wrap_future(f1) self.assertIs(f1, f2) + @test_utils.skipIf(concurrent is None, 'need concurrent.futures') @mock.patch('trollius.futures.events') def test_wrap_future_use_global_loop(self, m_events): def run(arg): - return (arg, threading.get_ident()) + return (arg, get_thread_ident()) ex = concurrent.futures.ThreadPoolExecutor(1) f1 = ex.submit(run, 'oi') f2 = asyncio.wrap_future(f1) self.assertIs(m_events.get_event_loop.return_value, f2._loop) + @test_utils.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future_cancel(self): f1 = concurrent.futures.Future() f2 = asyncio.wrap_future(f1, loop=self.loop) @@ -290,6 +279,7 @@ def test_wrap_future_cancel(self): self.assertTrue(f1.cancelled()) self.assertTrue(f2.cancelled()) + @test_utils.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future_cancel2(self): f1 = concurrent.futures.Future() f2 = asyncio.wrap_future(f1, loop=self.loop) @@ -367,12 +357,16 @@ def memory_error(): r'MemoryError$' ).format(filename=re.escape(frame[0]), lineno=frame[1]) - else: + elif compat.PY3: regex = (r'^Future/Task exception was never retrieved\n' r'Traceback \(most recent call last\):\n' r'.*\n' r'MemoryError$' ) + else: + regex = (r'^Future/Task exception was never retrieved\n' + r'MemoryError$' + ) m_log.error.assert_called_once_with(mock.ANY, exc_info=False) message = m_log.error.call_args[0][0] self.assertRegex(message, re.compile(regex, re.DOTALL)) diff --git a/tests/test_locks.py b/tests/test_locks.py index 5ad02eda..ec7dbba2 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -1,11 +1,12 @@ """Tests for lock.py""" import unittest -from unittest import mock import re import trollius as asyncio +from trollius import From, Return from trollius import test_utils +from trollius.test_utils import mock STR_RGX_REPR = ( @@ -42,7 +43,7 @@ def test_repr(self): @asyncio.coroutine def acquire_lock(): - yield from lock + yield From(lock.acquire()) self.loop.run_until_complete(acquire_lock()) self.assertTrue(repr(lock).endswith('[locked]>')) @@ -53,7 +54,8 @@ def test_lock(self): @asyncio.coroutine def acquire_lock(): - return (yield from lock) + yield From(lock.acquire()) + raise Return(lock) res = self.loop.run_until_complete(acquire_lock()) @@ -71,21 +73,21 @@ def test_acquire(self): @asyncio.coroutine def c1(result): - if (yield from lock.acquire()): + if (yield From(lock.acquire())): result.append(1) - return True + raise Return(True) @asyncio.coroutine def c2(result): - if (yield from lock.acquire()): + if (yield From(lock.acquire())): result.append(2) - return True + raise Return(True) @asyncio.coroutine def c3(result): - if (yield from lock.acquire()): + if (yield From(lock.acquire())): result.append(3) - return True + raise Return(True) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) @@ -147,22 +149,22 @@ def test_cancel_race(self): @asyncio.coroutine def lockit(name, blocker): - yield from lock.acquire() + yield From(lock.acquire()) try: if blocker is not None: - yield from blocker + yield From(blocker) finally: lock.release() fa = asyncio.Future(loop=self.loop) ta = asyncio.Task(lockit('A', fa), loop=self.loop) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertTrue(lock.locked()) tb = asyncio.Task(lockit('B', None), loop=self.loop) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual(len(lock._waiters), 1) tc = asyncio.Task(lockit('C', None), loop=self.loop) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual(len(lock._waiters), 2) # Create the race and check. @@ -170,7 +172,7 @@ def lockit(name, blocker): fa.set_result(None) tb.cancel() self.assertTrue(lock._waiters[0].cancelled()) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertFalse(lock.locked()) self.assertTrue(ta.done()) self.assertTrue(tb.cancelled()) @@ -194,7 +196,7 @@ def test_context_manager(self): @asyncio.coroutine def acquire_lock(): - return (yield from lock) + raise Return((yield From(lock))) with self.loop.run_until_complete(acquire_lock()): self.assertTrue(lock.locked()) @@ -206,9 +208,9 @@ def test_context_manager_cant_reuse(self): @asyncio.coroutine def acquire_lock(): - return (yield from lock) + raise Return((yield From(lock))) - # This spells "yield from lock" outside a generator. + # This spells "yield From(lock)" outside a generator. cm = self.loop.run_until_complete(acquire_lock()) with cm: self.assertTrue(lock.locked()) @@ -228,7 +230,7 @@ def test_context_manager_no_yield(self): except RuntimeError as err: self.assertEqual( str(err), - '"yield from" should be used as context manager expression') + '"yield" should be used as context manager expression') self.assertFalse(lock.locked()) @@ -273,30 +275,30 @@ def test_wait(self): @asyncio.coroutine def c1(result): - if (yield from ev.wait()): + if (yield From(ev.wait())): result.append(1) @asyncio.coroutine def c2(result): - if (yield from ev.wait()): + if (yield From(ev.wait())): result.append(2) @asyncio.coroutine def c3(result): - if (yield from ev.wait()): + if (yield From(ev.wait())): result.append(3) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([], result) t3 = asyncio.Task(c3(result), loop=self.loop) ev.set() - test_utils.run_briefly(self.loop) - self.assertEqual([3, 1, 2], result) + test_utils.run_briefly(self.loop, 2) + self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) self.assertIsNone(t1.result()) @@ -338,9 +340,9 @@ def test_clear_with_waiters(self): @asyncio.coroutine def c1(result): - if (yield from ev.wait()): + if (yield From(ev.wait())): result.append(1) - return True + raise Return(True) t = asyncio.Task(c1(result), loop=self.loop) test_utils.run_briefly(self.loop) @@ -386,56 +388,56 @@ def test_wait(self): @asyncio.coroutine def c1(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(1) - return True + raise Return(True) @asyncio.coroutine def c2(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(2) - return True + raise Return(True) @asyncio.coroutine def c3(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(3) - return True + raise Return(True) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([], result) self.assertFalse(cond.locked()) self.assertTrue(self.loop.run_until_complete(cond.acquire())) cond.notify() - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([], result) self.assertTrue(cond.locked()) cond.release() - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.notify(2) - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([1], result) self.assertTrue(cond.locked()) cond.release() - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([1, 2], result) self.assertTrue(cond.locked()) cond.release() - test_utils.run_briefly(self.loop) + test_utils.run_briefly(self.loop, 2) self.assertEqual([1, 2, 3], result) self.assertTrue(cond.locked()) @@ -475,11 +477,11 @@ def predicate(): @asyncio.coroutine def c1(result): - yield from cond.acquire() - if (yield from cond.wait_for(predicate)): + yield From(cond.acquire()) + if (yield From(cond.wait_for(predicate))): result.append(1) cond.release() - return True + raise Return(True) t = asyncio.Task(c1(result), loop=self.loop) @@ -520,27 +522,27 @@ def test_notify(self): @asyncio.coroutine def c1(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(1) cond.release() - return True + raise Return(True) @asyncio.coroutine def c2(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(2) cond.release() - return True + raise Return(True) @asyncio.coroutine def c3(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(3) cond.release() - return True + raise Return(True) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) @@ -552,14 +554,16 @@ def c3(result): self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.release() - test_utils.run_briefly(self.loop) + # each coroutine requires 2 runs of the event loop + test_utils.run_briefly(self.loop, 2) self.assertEqual([1], result) self.loop.run_until_complete(cond.acquire()) cond.notify(1) cond.notify(2048) cond.release() - test_utils.run_briefly(self.loop) + # each coroutine requires 2 runs of the event loop + test_utils.run_briefly(self.loop, 4) self.assertEqual([1, 2, 3], result) self.assertTrue(t1.done()) @@ -576,19 +580,19 @@ def test_notify_all(self): @asyncio.coroutine def c1(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(1) cond.release() - return True + raise Return(True) @asyncio.coroutine def c2(result): - yield from cond.acquire() - if (yield from cond.wait()): + yield From(cond.acquire()) + if (yield From(cond.wait())): result.append(2) cond.release() - return True + raise Return(True) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) @@ -599,7 +603,8 @@ def c2(result): self.loop.run_until_complete(cond.acquire()) cond.notify_all() cond.release() - test_utils.run_briefly(self.loop) + # each coroutine requires 2 runs of the event loop + test_utils.run_briefly(self.loop, 4) self.assertEqual([1, 2], result) self.assertTrue(t1.done()) @@ -636,7 +641,7 @@ def test_context_manager(self): @asyncio.coroutine def acquire_cond(): - return (yield from cond) + raise Return((yield From(cond))) with self.loop.run_until_complete(acquire_cond()): self.assertTrue(cond.locked()) @@ -652,7 +657,7 @@ def test_context_manager_no_yield(self): except RuntimeError as err: self.assertEqual( str(err), - '"yield from" should be used as context manager expression') + '"yield From" should be used as context manager expression') self.assertFalse(cond.locked()) @@ -718,7 +723,8 @@ def test_semaphore(self): @asyncio.coroutine def acquire_lock(): - return (yield from sem) + yield From(sem.acquire()) + raise Return(sem) res = self.loop.run_until_complete(acquire_lock()) @@ -743,33 +749,34 @@ def test_acquire(self): @asyncio.coroutine def c1(result): - yield from sem.acquire() + yield From(sem.acquire()) result.append(1) - return True + raise Return(True) @asyncio.coroutine def c2(result): - yield from sem.acquire() + yield From(sem.acquire()) result.append(2) - return True + raise Return(True) @asyncio.coroutine def c3(result): - yield from sem.acquire() + yield From(sem.acquire()) result.append(3) - return True + raise Return(True) @asyncio.coroutine def c4(result): - yield from sem.acquire() + yield From(sem.acquire()) result.append(4) - return True + raise Return(True) t1 = asyncio.Task(c1(result), loop=self.loop) t2 = asyncio.Task(c2(result), loop=self.loop) t3 = asyncio.Task(c3(result), loop=self.loop) - test_utils.run_briefly(self.loop) + # each coroutine requires 2 runs of the event loop + test_utils.run_briefly(self.loop, 2) self.assertEqual([1], result) self.assertTrue(sem.locked()) self.assertEqual(2, len(sem._waiters)) @@ -829,7 +836,7 @@ def test_context_manager(self): @asyncio.coroutine def acquire_lock(): - return (yield from sem) + raise Return((yield From(sem))) with self.loop.run_until_complete(acquire_lock()): self.assertFalse(sem.locked()) @@ -849,7 +856,7 @@ def test_context_manager_no_yield(self): except RuntimeError as err: self.assertEqual( str(err), - '"yield from" should be used as context manager expression') + '"yield" should be used as context manager expression') self.assertEqual(2, sem._value) diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index 47ae6bc5..b55fff32 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -2,14 +2,15 @@ import socket import unittest -from unittest import mock -import trollius as asyncio +from trollius import test_utils from trollius.proactor_events import BaseProactorEventLoop +from trollius.proactor_events import _ProactorDuplexPipeTransport from trollius.proactor_events import _ProactorSocketTransport from trollius.proactor_events import _ProactorWritePipeTransport -from trollius.proactor_events import _ProactorDuplexPipeTransport -from trollius import test_utils +from trollius.py33_exceptions import ConnectionAbortedError, ConnectionResetError +from trollius.test_utils import mock +import trollius as asyncio def close_transport(transport): diff --git a/tests/test_queues.py b/tests/test_queues.py index b097a70d..7e92fbc9 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -1,10 +1,11 @@ """Tests for queues.py""" import unittest -from unittest import mock import trollius as asyncio +from trollius import Return, From from trollius import test_utils +from trollius.test_utils import mock class _QueueTestBase(test_utils.TestCase): @@ -32,7 +33,7 @@ def gen(): q = asyncio.Queue(loop=loop) self.assertTrue(fn(q).startswith('= (3,): + UNICODE_STR = 'unicode' +else: + UNICODE_STR = unicode('unicode') + try: + memoryview + except NameError: + # Python 2.6 + memoryview = buffer + MOCK_ANY = mock.ANY @@ -900,7 +917,7 @@ def test_write_memoryview(self): transport = self.socket_transport() transport.write(data) - self.sock.send.assert_called_with(data) + self.sock.send.assert_called_with(b'data') def test_write_no_data(self): transport = self.socket_transport() @@ -1112,7 +1129,7 @@ def test_write_eof_buffer(self): tr.close() -@unittest.skipIf(ssl is None, 'No ssl module') +@test_utils.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(test_utils.TestCase): def setUp(self): @@ -1150,13 +1167,13 @@ def test_on_handshake(self): def test_on_handshake_reader_retry(self): self.loop.set_debug(False) - self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.sslsock.do_handshake.side_effect = SSLWantReadError transport = self.ssl_transport() self.loop.assert_reader(1, transport._on_handshake, None) def test_on_handshake_writer_retry(self): self.loop.set_debug(False) - self.sslsock.do_handshake.side_effect = ssl.SSLWantWriteError + self.sslsock.do_handshake.side_effect = SSLWantWriteError transport = self.ssl_transport() self.loop.assert_writer(1, transport._on_handshake, None) @@ -1234,7 +1251,7 @@ def test_write_no_data(self): def test_write_str(self): transport = self._make_one() - self.assertRaises(TypeError, transport.write, 'str') + self.assertRaises(TypeError, transport.write, UNICODE_STR) def test_write_closing(self): transport = self._make_one() @@ -1255,6 +1272,7 @@ def test_write_exception(self, m_log): transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv(self): self.sslsock.recv.return_value = b'data' transport = self._make_one() @@ -1276,6 +1294,7 @@ def test_read_ready_write_wants_read(self): self.loop.add_writer.assert_called_with( transport._sock_fd, transport._write_ready) + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() @@ -1284,6 +1303,7 @@ def test_read_ready_recv_eof(self): transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() @@ -1292,8 +1312,9 @@ def test_read_ready_recv_conn_reset(self): transport._read_ready() transport._force_close.assert_called_with(err) + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_retry(self): - self.sslsock.recv.side_effect = ssl.SSLWantReadError + self.sslsock.recv.side_effect = SSLWantReadError transport = self._make_one() transport._read_ready() self.assertTrue(self.sslsock.recv.called) @@ -1307,10 +1328,11 @@ def test_read_ready_recv_retry(self): transport._read_ready() self.assertFalse(self.protocol.data_received.called) + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_write(self): self.loop.remove_reader = mock.Mock() self.loop.add_writer = mock.Mock() - self.sslsock.recv.side_effect = ssl.SSLWantWriteError + self.sslsock.recv.side_effect = SSLWantWriteError transport = self._make_one() transport._read_ready() self.assertFalse(self.protocol.data_received.called) @@ -1320,6 +1342,7 @@ def test_read_ready_recv_write(self): self.loop.add_writer.assert_called_with( transport._sock_fd, transport._write_ready) + @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() @@ -1383,7 +1406,7 @@ def test_write_ready_send_retry(self): transport = self._make_one() transport._buffer = list_to_buffer([b'data']) - self.sslsock.send.side_effect = ssl.SSLWantWriteError + self.sslsock.send.side_effect = SSLWantWriteError transport._write_ready() self.assertEqual(list_to_buffer([b'data']), transport._buffer) @@ -1396,7 +1419,7 @@ def test_write_ready_send_read(self): transport._buffer = list_to_buffer([b'data']) self.loop.remove_writer = mock.Mock() - self.sslsock.send.side_effect = ssl.SSLWantReadError + self.sslsock.send.side_effect = SSLWantReadError transport._write_ready() self.assertFalse(self.protocol.data_received.called) self.assertTrue(transport._write_wants_read) @@ -1457,7 +1480,7 @@ def test_close_not_connected(self): self.assertFalse(self.protocol.connection_made.called) self.assertFalse(self.protocol.connection_lost.called) - @unittest.skipIf(ssl is None, 'No SSL support') + @test_utils.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): self.ssl_transport(server_hostname='localhost') self.sslcontext.wrap_socket.assert_called_with( @@ -1465,7 +1488,7 @@ def test_server_hostname(self): server_hostname='localhost') -class SelectorSslWithoutSslTransportTests(unittest.TestCase): +class SelectorSslWithoutSslTransportTests(test_utils.TestCase): @mock.patch('trollius.selector_events.ssl', None) def test_ssl_transport_requires_ssl_module(self): @@ -1550,7 +1573,7 @@ def test_sendto_memoryview(self): transport.sendto(data, ('0.0.0.0', 1234)) self.assertTrue(self.sock.sendto.called) self.assertEqual( - self.sock.sendto.call_args[0], (data, ('0.0.0.0', 1234))) + self.sock.sendto.call_args[0], (b'data', ('0.0.0.0', 1234))) def test_sendto_no_data(self): transport = self.datagram_transport() @@ -1655,7 +1678,7 @@ def test_sendto_error_received_connected(self): def test_sendto_str(self): transport = self.datagram_transport() - self.assertRaises(TypeError, transport.sendto, 'str', ()) + self.assertRaises(TypeError, transport.sendto, UNICODE_STR, ()) def test_sendto_connected_addr(self): transport = self.datagram_transport(address=('0.0.0.0', 1)) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 5749389e..591b4aab 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -4,22 +4,17 @@ import signal import sys from time import sleep -import unittest -import unittest.mock -try: - from test import support -except ImportError: - from trollius import test_support as support -try: - from time import monotonic as time -except ImportError: - from time import time as time try: import resource except ImportError: resource = None + from trollius import selectors +from trollius import test_support as support +from trollius import test_utils +from trollius.test_utils import mock from trollius.test_utils import socketpair +from trollius.time_monotonic import time_monotonic as time def find_ready_matching(ready, flag): @@ -30,7 +25,7 @@ def find_ready_matching(ready, flag): return match -class BaseSelectorTestCase(unittest.TestCase): +class BaseSelectorTestCase(test_utils.TestCase): def make_socketpair(self): rd, wr = socketpair() @@ -91,7 +86,7 @@ def test_unregister_after_fd_close(self): s.unregister(r) s.unregister(w) - @unittest.skipUnless(os.name == 'posix', "requires posix") + @test_utils.skipUnless(os.name == 'posix', "requires posix") def test_unregister_after_fd_close_and_reuse(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -151,8 +146,8 @@ def test_modify(self): # modify use a shortcut d3 = object() - s.register = unittest.mock.Mock() - s.unregister = unittest.mock.Mock() + s.register = mock.Mock() + s.unregister = mock.Mock() s.modify(rd, selectors.EVENT_READ, d3) self.assertFalse(s.register.called) @@ -304,8 +299,8 @@ def test_selector(self): self.assertEqual(bufs, [MSG] * NUM_SOCKETS) - @unittest.skipIf(sys.platform == 'win32', - 'select.select() cannot be used with empty fd sets') + @test_utils.skipIf(sys.platform == 'win32', + 'select.select() cannot be used with empty fd sets') def test_empty_select(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -337,8 +332,8 @@ def test_timeout(self): # Tolerate 2.0 seconds for very slow buildbots self.assertTrue(0.8 <= dt <= 2.0, dt) - @unittest.skipUnless(hasattr(signal, "alarm"), - "signal.alarm() required for this test") + @test_utils.skipUnless(hasattr(signal, "alarm"), + "signal.alarm() required for this test") def test_select_interrupt(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -361,7 +356,7 @@ class ScalableSelectorMixIn: # see issue #18963 for why it's skipped on older OS X versions @support.requires_mac_ver(10, 5) - @unittest.skipUnless(resource, "Test needs resource module") + @test_utils.skipUnless(resource, "Test needs resource module") def test_above_fd_setsize(self): # A scalable implementation should have no problem with more than # FD_SETSIZE file descriptors. Since we don't know the value, we just @@ -413,29 +408,29 @@ class SelectSelectorTestCase(BaseSelectorTestCase): SELECTOR = selectors.SelectSelector -@unittest.skipUnless(hasattr(selectors, 'PollSelector'), - "Test needs selectors.PollSelector") +@test_utils.skipUnless(hasattr(selectors, 'PollSelector'), + "Test needs selectors.PollSelector") class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): SELECTOR = getattr(selectors, 'PollSelector', None) -@unittest.skipUnless(hasattr(selectors, 'EpollSelector'), - "Test needs selectors.EpollSelector") +@test_utils.skipUnless(hasattr(selectors, 'EpollSelector'), + "Test needs selectors.EpollSelector") class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): SELECTOR = getattr(selectors, 'EpollSelector', None) -@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'), - "Test needs selectors.KqueueSelector)") +@test_utils.skipUnless(hasattr(selectors, 'KqueueSelector'), + "Test needs selectors.KqueueSelector)") class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): SELECTOR = getattr(selectors, 'KqueueSelector', None) -@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'), - "Test needs selectors.DevpollSelector") +@test_utils.skipUnless(hasattr(selectors, 'DevpollSelector'), + "Test needs selectors.DevpollSelector") class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): SELECTOR = getattr(selectors, 'DevpollSelector', None) diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 7c7bbf80..8ea0f975 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -1,7 +1,6 @@ """Tests for asyncio/sslproto.py.""" import unittest -from unittest import mock try: import ssl except ImportError: @@ -10,6 +9,7 @@ import trollius as asyncio from trollius import sslproto from trollius import test_utils +from trollius.test_utils import mock @unittest.skipIf(ssl is None, 'No ssl module') diff --git a/tests/test_streams.py b/tests/test_streams.py index 5f7eb7e7..42b80cff 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -1,18 +1,21 @@ """Tests for streams.py.""" import gc +import io import os import socket import sys import unittest -from unittest import mock try: import ssl except ImportError: ssl = None import trollius as asyncio +from trollius import Return, From +from trollius import compat from trollius import test_utils +from trollius.test_utils import mock class StreamReaderTests(test_utils.TestCase): @@ -29,7 +32,7 @@ def tearDown(self): self.loop.close() gc.collect() - super().tearDown() + super(StreamReaderTests, self).tearDown() @mock.patch('trollius.streams.events') def test_ctor_global_loop(self, m_events): @@ -53,7 +56,7 @@ def test_open_connection(self): loop=self.loop) self._basetest_open_connection(conn_fut) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection(self): with test_utils.run_test_unix_server() as httpd: conn_fut = asyncio.open_unix_connection(httpd.address, @@ -72,7 +75,7 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): writer.close() - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: conn_fut = asyncio.open_connection( @@ -82,8 +85,8 @@ def test_open_connection_no_loop_ssl(self): self._basetest_open_connection_no_loop_ssl(conn_fut) - @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipIf(ssl is None, 'No ssl module') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection_no_loop_ssl(self): with test_utils.run_test_unix_server(use_ssl=True) as httpd: conn_fut = asyncio.open_unix_connection( @@ -109,7 +112,7 @@ def test_open_connection_error(self): loop=self.loop) self._basetest_open_connection_error(conn_fut) - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection_error(self): with test_utils.run_test_unix_server() as httpd: conn_fut = asyncio.open_unix_connection(httpd.address, @@ -413,7 +416,7 @@ def test_exception_waiter(self): @asyncio.coroutine def set_err(): - stream.set_exception(ValueError()) + self.loop.call_soon(stream.set_exception, ValueError()) t1 = asyncio.Task(stream.readline(), loop=self.loop) t2 = asyncio.Task(set_err(), loop=self.loop) @@ -444,7 +447,7 @@ def __init__(self, loop): @asyncio.coroutine def handle_client(self, client_reader, client_writer): - data = yield from client_reader.readline() + data = yield From(client_reader.readline()) client_writer.write(data) def start(self): @@ -483,14 +486,14 @@ def stop(self): @asyncio.coroutine def client(addr): - reader, writer = yield from asyncio.open_connection( - *addr, loop=self.loop) + reader, writer = yield From(asyncio.open_connection( + *addr, loop=self.loop)) # send a line writer.write(b"hello world!\n") # read it back - msgback = yield from reader.readline() + msgback = yield From(reader.readline()) writer.close() - return msgback + raise Return(msgback) # test the server variant with a coroutine as client handler server = MyServer(self.loop) @@ -508,7 +511,7 @@ def client(addr): server.stop() self.assertEqual(msg, b"hello world!\n") - @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_start_unix_server(self): class MyServer: @@ -520,7 +523,7 @@ def __init__(self, loop, path): @asyncio.coroutine def handle_client(self, client_reader, client_writer): - data = yield from client_reader.readline() + data = yield From(client_reader.readline()) client_writer.write(data) def start(self): @@ -551,14 +554,14 @@ def stop(self): @asyncio.coroutine def client(path): - reader, writer = yield from asyncio.open_unix_connection( - path, loop=self.loop) + reader, writer = yield From(asyncio.open_unix_connection( + path, loop=self.loop)) # send a line writer.write(b"hello world!\n") # read it back - msgback = yield from reader.readline() + msgback = yield From(reader.readline()) writer.close() - return msgback + raise Return(msgback) # test the server variant with a coroutine as client handler with test_utils.unix_socket_path() as path: @@ -578,7 +581,7 @@ def client(path): server.stop() self.assertEqual(msg, b"hello world!\n") - @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") + @test_utils.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See Tulip issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the @@ -595,7 +598,7 @@ def test_read_all_from_pipe_reader(self): rfd, wfd = os.pipe() args = [sys.executable, '-c', code, str(wfd)] - pipe = open(rfd, 'rb', 0) + pipe = io.open(rfd, 'rb', 0) reader = asyncio.StreamReader(loop=self.loop, limit=1) protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop) transport, _ = self.loop.run_until_complete( @@ -605,9 +608,10 @@ def test_read_all_from_pipe_reader(self): watcher.attach_loop(self.loop) try: asyncio.set_child_watcher(watcher) - create = asyncio.create_subprocess_exec(*args, - pass_fds={wfd}, - loop=self.loop) + kw = {'loop': self.loop} + if compat.PY3: + kw['pass_fds'] = set((wfd,)) + create = asyncio.create_subprocess_exec(*args, **kw) proc = self.loop.run_until_complete(create) self.loop.run_until_complete(proc.wait()) finally: diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b4f3f950..99071ee3 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -1,28 +1,32 @@ +from trollius import subprocess +from trollius import test_utils +import trollius as asyncio +import os import signal import sys import unittest -from unittest import mock +from trollius import From, Return +from trollius import test_support as support +from trollius.test_utils import mock +from trollius.py33_exceptions import BrokenPipeError, ConnectionResetError -import trollius as asyncio -from trollius import base_subprocess -from trollius import subprocess -from trollius import test_utils -try: - from test import support -except ImportError: - from trollius import test_support as support if sys.platform != 'win32': from trollius import unix_events + # Program blocking PROGRAM_BLOCKED = [sys.executable, '-c', 'import time; time.sleep(3600)'] # Program copying input to output -PROGRAM_CAT = [ - sys.executable, '-c', - ';'.join(('import sys', - 'data = sys.stdin.buffer.read()', - 'sys.stdout.buffer.write(data)'))] +if sys.version_info >= (3,): + PROGRAM_CAT = ';'.join(('import sys', + 'data = sys.stdin.buffer.read()', + 'sys.stdout.buffer.write(data)')) +else: + PROGRAM_CAT = ';'.join(('import sys', + 'data = sys.stdin.read()', + 'sys.stdout.write(data)')) +PROGRAM_CAT = [sys.executable, '-c', PROGRAM_CAT] class TestSubprocessTransport(base_subprocess.BaseSubprocessTransport): def _start(self, *args, **kwargs): @@ -81,21 +85,21 @@ def test_stdin_stdout(self): @asyncio.coroutine def run(data): - proc = yield from asyncio.create_subprocess_exec( - *args, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - loop=self.loop) + proc = yield From(asyncio.create_subprocess_exec( + *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + loop=self.loop)) # feed data proc.stdin.write(data) - yield from proc.stdin.drain() + yield From(proc.stdin.drain()) proc.stdin.close() # get output and exitcode - data = yield from proc.stdout.read() - exitcode = yield from proc.wait() - return (exitcode, data) + data = yield From(proc.stdout.read()) + exitcode = yield From(proc.wait()) + raise Return(exitcode, data) task = run(b'some data') task = asyncio.wait_for(task, 60.0, loop=self.loop) @@ -108,13 +112,13 @@ def test_communicate(self): @asyncio.coroutine def run(data): - proc = yield from asyncio.create_subprocess_exec( + proc = yield From(asyncio.create_subprocess_exec( *args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, - loop=self.loop) - stdout, stderr = yield from proc.communicate(data) - return proc.returncode, stdout + loop=self.loop)) + stdout, stderr = yield From(proc.communicate(data)) + raise Return(proc.returncode, stdout) task = run(b'some data') task = asyncio.wait_for(task, 60.0, loop=self.loop) @@ -129,10 +133,14 @@ def test_shell(self): exitcode = self.loop.run_until_complete(proc.wait()) self.assertEqual(exitcode, 7) + @test_utils.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") def test_start_new_session(self): + def start_new_session(): + os.setsid() + # start the new process in a new session create = asyncio.create_subprocess_shell('exit 8', - start_new_session=True, + preexec_fn=start_new_session, loop=self.loop) proc = self.loop.run_until_complete(create) exitcode = self.loop.run_until_complete(proc.wait()) @@ -162,9 +170,13 @@ def test_terminate(self): else: self.assertEqual(-signal.SIGTERM, returncode) - @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + @test_utils.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_send_signal(self): - code = 'import time; print("sleeping", flush=True); time.sleep(3600)' + code = '; '.join(( + 'import sys, time', + 'print("sleeping")', + 'sys.stdout.flush()', + 'time.sleep(3600)')) args = [sys.executable, '-c', code] create = asyncio.create_subprocess_exec(*args, stdout=subprocess.PIPE, @@ -174,12 +186,12 @@ def test_send_signal(self): @asyncio.coroutine def send_signal(proc): # basic synchronization to wait until the program is sleeping - line = yield from proc.stdout.readline() + line = yield From(proc.stdout.readline()) self.assertEqual(line, b'sleeping\n') proc.send_signal(signal.SIGHUP) - returncode = (yield from proc.wait()) - return returncode + returncode = yield From(proc.wait()) + raise Return(returncode) returncode = self.loop.run_until_complete(send_signal(proc)) self.assertEqual(-signal.SIGHUP, returncode) @@ -202,7 +214,7 @@ def test_stdin_broken_pipe(self): @asyncio.coroutine def write_stdin(proc, data): proc.stdin.write(data) - yield from proc.stdin.drain() + yield From(proc.stdin.drain()) coro = write_stdin(proc, large_data) # drain() must raise BrokenPipeError or ConnectionResetError @@ -235,27 +247,28 @@ def test_pause_reading(): @asyncio.coroutine def connect_read_pipe_mock(*args, **kw): - transport, protocol = yield from connect_read_pipe(*args, **kw) + connect = connect_read_pipe(*args, **kw) + transport, protocol = yield From(connect) transport.pause_reading = mock.Mock() transport.resume_reading = mock.Mock() - return (transport, protocol) + raise Return(transport, protocol) self.loop.connect_read_pipe = connect_read_pipe_mock - proc = yield from asyncio.create_subprocess_exec( + proc = yield From(asyncio.create_subprocess_exec( sys.executable, '-c', code, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, limit=limit, - loop=self.loop) + loop=self.loop)) stdout_transport = proc._transport.get_pipe_transport(1) - stdout, stderr = yield from proc.communicate() + stdout, stderr = yield From(proc.communicate()) # The child process produced more than limit bytes of output, # the stream reader transport should pause the protocol to not # allocate too much memory. - return (stdout, stdout_transport) + raise Return(stdout, stdout_transport) # Issue #22685: Ensure that the stream reader pauses the protocol # when the child process produces too much data @@ -271,16 +284,16 @@ def test_stdin_not_inheritable(self): @asyncio.coroutine def len_message(message): code = 'import sys; data = sys.stdin.read(); print(len(data))' - proc = yield from asyncio.create_subprocess_exec( + proc = yield From(asyncio.create_subprocess_exec( sys.executable, '-c', code, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, close_fds=False, - loop=self.loop) - stdout, stderr = yield from proc.communicate(message) - exitcode = yield from proc.wait() - return (stdout, exitcode) + loop=self.loop)) + stdout, stderr = yield From(proc.communicate(message)) + exitcode = yield From(proc.wait()) + raise Return(stdout, exitcode) output, exitcode = self.loop.run_until_complete(len_message(b'abc')) self.assertEqual(output.rstrip(), b'3') @@ -308,7 +321,7 @@ def cancel_wait(): # Kill the process and wait until it is done proc.kill() - yield from proc.wait() + yield From(proc.wait()) self.loop.run_until_complete(cancel_wait()) @@ -321,7 +334,7 @@ def cancel_make_transport(): self.loop.call_soon(task.cancel) try: - yield from task + yield From(task) except asyncio.CancelledError: pass @@ -339,7 +352,7 @@ def cancel_make_transport(): self.loop.call_soon(task.cancel) try: - yield from task + yield From(task) except asyncio.CancelledError: pass diff --git a/tests/test_tasks.py b/tests/test_tasks.py index a8ceba01..c2045a53 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,31 +1,22 @@ """Tests for tasks.py.""" import contextlib -import functools import os import re import sys import types import unittest import weakref -from unittest import mock import trollius as asyncio +from trollius import From, Return from trollius import coroutines +from trollius import test_support as support from trollius import test_utils -try: - from test import support -except ImportError: - from trollius import test_support as support -try: - from test.support.script_helper import assert_python_ok -except ImportError: - try: - from test.script_helper import assert_python_ok - except ImportError: - from trollius.test_support import assert_python_ok +from trollius.test_utils import mock +PY33 = (sys.version_info >= (3, 3)) PY34 = (sys.version_info >= (3, 4)) PY35 = (sys.version_info >= (3, 5)) @@ -167,8 +158,8 @@ def test_task_repr(self): @asyncio.coroutine def notmuch(): - yield from [] - return 'abc' + yield From(None) + raise Return('abc') # test coroutine function self.assertEqual(notmuch.__name__, 'notmuch') @@ -182,7 +173,7 @@ def notmuch(): # test coroutine object gen = notmuch() - if coroutines._DEBUG or PY35: + if PY35 or (coroutines._DEBUG and PY33): coro_qualname = 'TaskTests.test_task_repr..notmuch' else: coro_qualname = 'notmuch' @@ -226,7 +217,7 @@ def test_task_repr_coro_decorator(self): @asyncio.coroutine def notmuch(): - # notmuch() function doesn't use yield from: it will be wrapped by + # notmuch() function doesn't use yield: it will be wrapped by # @coroutine decorator return 123 @@ -240,13 +231,16 @@ def notmuch(): # test coroutine object gen = notmuch() - if coroutines._DEBUG or PY35: + if PY35 or coroutines._DEBUG: # On Python >= 3.5, generators now inherit the name of the # function, as expected, and have a qualified name (__qualname__ # attribute). coro_name = 'notmuch' - coro_qualname = ('TaskTests.test_task_repr_coro_decorator' - '..notmuch') + if PY35 or (coroutines._DEBUG and PY33): + coro_qualname = ('TaskTests.test_task_repr_coro_decorator' + '..notmuch') + else: + coro_qualname = 'notmuch' else: # On Python < 3.5, generators inherit the name of the code, not of # the function. See: http://bugs.python.org/issue21205 @@ -294,7 +288,8 @@ def test_task_repr_wait_for(self): @asyncio.coroutine def wait_for(fut): - return (yield from fut) + res = yield From(fut) + raise Return(res) fut = asyncio.Future(loop=self.loop) task = asyncio.Task(wait_for(fut), loop=self.loop) @@ -331,9 +326,9 @@ def func(x, y): def test_task_basics(self): @asyncio.coroutine def outer(): - a = yield from inner1() - b = yield from inner2() - return a+b + a = yield From(inner1()) + b = yield From(inner2()) + raise Return(a+b) @asyncio.coroutine def inner1(): @@ -357,8 +352,8 @@ def gen(): @asyncio.coroutine def task(): - yield from asyncio.sleep(10.0, loop=loop) - return 12 + yield From(asyncio.sleep(10.0, loop=loop)) + raise Return(12) t = asyncio.Task(task(), loop=loop) loop.call_soon(t.cancel) @@ -371,9 +366,9 @@ def task(): def test_cancel_yield(self): @asyncio.coroutine def task(): - yield - yield - return 12 + yield From(None) + yield From(None) + raise Return(12) t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) # start coro @@ -389,8 +384,8 @@ def test_cancel_inner_future(self): @asyncio.coroutine def task(): - yield from f - return 12 + yield From(f) + raise Return(12) t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) # start task @@ -405,8 +400,8 @@ def test_cancel_both_task_and_inner_future(self): @asyncio.coroutine def task(): - yield from f - return 12 + yield From(f) + raise Return(12) t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -431,7 +426,7 @@ def task(): try: yield from fut2 except asyncio.CancelledError: - return 42 + raise Return(42) t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -452,13 +447,13 @@ def test_cancel_task_ignoring(self): @asyncio.coroutine def task(): - yield from fut1 + yield From(fut1) try: - yield from fut2 + yield From(fut2) except asyncio.CancelledError: pass - res = yield from fut3 - return res + res = yield From(fut3) + raise Return(res) t = asyncio.Task(task(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -485,8 +480,8 @@ def task(): t.cancel() self.assertTrue(t._must_cancel) # White-box test. # The sleep should be cancelled immediately. - yield from asyncio.sleep(100, loop=loop) - return 12 + yield From(asyncio.sleep(100, loop=loop)) + raise Return(12) t = asyncio.Task(task(), loop=loop) self.assertRaises( @@ -508,17 +503,16 @@ def gen(): loop = self.new_test_loop(gen) - x = 0 + non_local = {'x': 0} waiters = [] @asyncio.coroutine def task(): - nonlocal x - while x < 10: + while non_local['x'] < 10: waiters.append(asyncio.sleep(0.1, loop=loop)) - yield from waiters[-1] - x += 1 - if x == 2: + yield From(waiters[-1]) + non_local['x'] += 1 + if non_local['x'] == 3: loop.stop() t = asyncio.Task(task(), loop=loop) @@ -527,7 +521,7 @@ def task(): self.assertEqual(str(cm.exception), 'Event loop stopped before Future completed.') self.assertFalse(t.done()) - self.assertEqual(x, 2) + self.assertEqual(non_local['x'], 3) self.assertAlmostEqual(0.3, loop.time()) # close generators @@ -538,6 +532,7 @@ def task(): def test_wait_for(self): + @asyncio.coroutine def gen(): when = yield self.assertAlmostEqual(0.2, when) @@ -547,27 +542,34 @@ def gen(): loop = self.new_test_loop(gen) - foo_running = None + non_local = {'foo_running': None} @asyncio.coroutine def foo(): - nonlocal foo_running - foo_running = True + non_local['foo_running'] = True try: - yield from asyncio.sleep(0.2, loop=loop) + yield From(asyncio.sleep(0.2, loop=loop)) finally: - foo_running = False - return 'done' + non_local['foo_running'] = False + raise Return('done') fut = asyncio.Task(foo(), loop=loop) + test_utils.run_briefly(loop) with self.assertRaises(asyncio.TimeoutError): loop.run_until_complete(asyncio.wait_for(fut, 0.1, loop=loop)) + + # Trollius issue #2: need to run the loop briefly to ensure that the + # cancellation is propagated to all tasks + waiter = asyncio.Future(loop=loop) + fut.add_done_callback(lambda f: waiter.set_result(True)) + loop.run_until_complete(waiter) + self.assertTrue(fut.done()) # it should have been cancelled due to the timeout self.assertTrue(fut.cancelled()) self.assertAlmostEqual(0.1, loop.time()) - self.assertEqual(foo_running, False) + self.assertEqual(non_local['foo_running'], False) def test_wait_for_blocking(self): loop = self.new_test_loop() @@ -594,17 +596,24 @@ def gen(): @asyncio.coroutine def foo(): - yield from asyncio.sleep(0.2, loop=loop) - return 'done' + yield From(asyncio.sleep(0.2, loop=loop)) + raise Return('done') asyncio.set_event_loop(loop) try: fut = asyncio.Task(foo(), loop=loop) + test_utils.run_briefly(loop) with self.assertRaises(asyncio.TimeoutError): loop.run_until_complete(asyncio.wait_for(fut, 0.01)) finally: asyncio.set_event_loop(None) + # Trollius issue #2: need to run the loop briefly to ensure that the + # cancellation is propagated to all tasks + waiter = asyncio.Future(loop=loop) + fut.add_done_callback(lambda f: waiter.set_result(True)) + loop.run_until_complete(waiter) + self.assertAlmostEqual(0.01, loop.time()) self.assertTrue(fut.done()) self.assertTrue(fut.cancelled()) @@ -640,10 +649,10 @@ def gen(): @asyncio.coroutine def foo(): - done, pending = yield from asyncio.wait([b, a], loop=loop) + done, pending = yield From(asyncio.wait([b, a], loop=loop)) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) - return 42 + raise Return(42) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertEqual(res, 42) @@ -670,10 +679,10 @@ def gen(): @asyncio.coroutine def foo(): - done, pending = yield from asyncio.wait([b, a]) + done, pending = yield From(asyncio.wait([b, a])) self.assertEqual(done, set([a, b])) self.assertEqual(pending, set()) - return 42 + raise Return(42) asyncio.set_event_loop(loop) res = loop.run_until_complete( @@ -694,7 +703,7 @@ def coro(s): done, pending = self.loop.run_until_complete(task) self.assertFalse(pending) - self.assertEqual(set(f.result() for f in done), {'test', 'spam'}) + self.assertEqual(set(f.result() for f in done), set(('test', 'spam'))) def test_wait_errors(self): self.assertRaises( @@ -728,8 +737,8 @@ def gen(): loop=loop) done, pending = loop.run_until_complete(task) - self.assertEqual({b}, done) - self.assertEqual({a}, pending) + self.assertEqual(set((b,)), done) + self.assertEqual(set((a,)), pending) self.assertFalse(a.done()) self.assertTrue(b.done()) self.assertIsNone(b.result()) @@ -745,12 +754,12 @@ def test_wait_really_done(self): @asyncio.coroutine def coro1(): - yield + yield From(None) @asyncio.coroutine def coro2(): - yield - yield + yield From(None) + yield From(None) a = asyncio.Task(coro1(), loop=self.loop) b = asyncio.Task(coro2(), loop=self.loop) @@ -760,7 +769,7 @@ def coro2(): loop=self.loop) done, pending = self.loop.run_until_complete(task) - self.assertEqual({a, b}, done) + self.assertEqual(set((a, b)), done) self.assertTrue(a.done()) self.assertIsNone(a.result()) self.assertTrue(b.done()) @@ -789,8 +798,8 @@ def exc(): loop=loop) done, pending = loop.run_until_complete(task) - self.assertEqual({b}, done) - self.assertEqual({a}, pending) + self.assertEqual(set((b,)), done) + self.assertEqual(set((a,)), pending) self.assertAlmostEqual(0, loop.time()) # move forward to close generator @@ -813,7 +822,7 @@ def gen(): @asyncio.coroutine def exc(): - yield from asyncio.sleep(0.01, loop=loop) + yield From(asyncio.sleep(0.01, loop=loop)) raise ZeroDivisionError('err') b = asyncio.Task(exc(), loop=loop) @@ -821,8 +830,8 @@ def exc(): loop=loop) done, pending = loop.run_until_complete(task) - self.assertEqual({b}, done) - self.assertEqual({a}, pending) + self.assertEqual(set((b,)), done) + self.assertEqual(set((a,)), pending) self.assertAlmostEqual(0.01, loop.time()) # move forward to close generator @@ -844,14 +853,14 @@ def gen(): @asyncio.coroutine def sleeper(): - yield from asyncio.sleep(0.15, loop=loop) + yield From(asyncio.sleep(0.15, loop=loop)) raise ZeroDivisionError('really') b = asyncio.Task(sleeper(), loop=loop) @asyncio.coroutine def foo(): - done, pending = yield from asyncio.wait([b, a], loop=loop) + done, pending = yield From(asyncio.wait([b, a], loop=loop)) self.assertEqual(len(done), 2) self.assertEqual(pending, set()) errors = set(f for f in done if f.exception() is not None) @@ -881,8 +890,8 @@ def gen(): @asyncio.coroutine def foo(): - done, pending = yield from asyncio.wait([b, a], timeout=0.11, - loop=loop) + done, pending = yield From(asyncio.wait([b, a], timeout=0.11, + loop=loop)) self.assertEqual(done, set([a])) self.assertEqual(pending, set([b])) @@ -932,17 +941,16 @@ def gen(): # disable "slow callback" warning loop.slow_callback_duration = 1.0 completed = set() - time_shifted = False + non_local = {'time_shifted': False} @asyncio.coroutine def sleeper(dt, x): - nonlocal time_shifted - yield from asyncio.sleep(dt, loop=loop) + yield From(asyncio.sleep(dt, loop=loop)) completed.add(x) - if not time_shifted and 'a' in completed and 'b' in completed: - time_shifted = True + if not non_local['time_shifted'] and 'a' in completed and 'b' in completed: + non_local['time_shifted'] = True loop.advance_time(0.14) - return x + raise Return(x) a = sleeper(0.01, 'a') b = sleeper(0.01, 'b') @@ -952,8 +960,8 @@ def sleeper(dt, x): def foo(): values = [] for f in asyncio.as_completed([b, c, a], loop=loop): - values.append((yield from f)) - return values + values.append((yield From(f))) + raise Return(values) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertAlmostEqual(0.15, loop.time()) @@ -985,11 +993,11 @@ def foo(): if values: loop.advance_time(0.02) try: - v = yield from f + v = yield From(f) values.append((1, v)) except asyncio.TimeoutError as exc: values.append((2, exc)) - return values + raise Return(values) res = loop.run_until_complete(asyncio.Task(foo(), loop=loop)) self.assertEqual(len(res), 2, res) @@ -1016,7 +1024,7 @@ def gen(): @asyncio.coroutine def foo(): for f in asyncio.as_completed([a], timeout=1, loop=loop): - v = yield from f + v = yield From(f) self.assertEqual(v, 'a') loop.run_until_complete(asyncio.Task(foo(), loop=loop)) @@ -1032,7 +1040,7 @@ def gen(): a = asyncio.sleep(0.05, 'a', loop=loop) b = asyncio.sleep(0.10, 'b', loop=loop) - fs = {a, b} + fs = set((a, b)) futs = list(asyncio.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) @@ -1057,12 +1065,12 @@ def gen(): a = asyncio.sleep(0.05, 'a', loop=loop) b = asyncio.sleep(0.05, 'b', loop=loop) - fs = {a, b} + fs = set((a, b)) futs = list(asyncio.as_completed(fs, loop=loop)) self.assertEqual(len(futs), 2) waiter = asyncio.wait(futs, loop=loop) done, pending = loop.run_until_complete(waiter) - self.assertEqual(set(f.result() for f in done), {'a', 'b'}) + self.assertEqual(set(f.result() for f in done), set(('a', 'b'))) def test_as_completed_duplicate_coroutines(self): @@ -1076,13 +1084,13 @@ def runner(): c = coro('ham') for f in asyncio.as_completed([c, c, coro('spam')], loop=self.loop): - result.append((yield from f)) - return result + result.append((yield From(f))) + raise Return(result) fut = asyncio.Task(runner(), loop=self.loop) self.loop.run_until_complete(fut) result = fut.result() - self.assertEqual(set(result), {'ham', 'spam'}) + self.assertEqual(set(result), set(('ham', 'spam'))) self.assertEqual(len(result), 2) def test_sleep(self): @@ -1098,9 +1106,9 @@ def gen(): @asyncio.coroutine def sleeper(dt, arg): - yield from asyncio.sleep(dt/2, loop=loop) - res = yield from asyncio.sleep(dt/2, arg, loop=loop) - return res + yield From(asyncio.sleep(dt/2, loop=loop)) + res = yield From(asyncio.sleep(dt/2, arg, loop=loop)) + raise Return(res) t = asyncio.Task(sleeper(0.1, 'yeah'), loop=loop) loop.run_until_complete(t) @@ -1120,22 +1128,21 @@ def gen(): t = asyncio.Task(asyncio.sleep(10.0, 'yeah', loop=loop), loop=loop) - handle = None + non_local = {'handle': None} orig_call_later = loop.call_later def call_later(delay, callback, *args): - nonlocal handle - handle = orig_call_later(delay, callback, *args) - return handle + non_local['handle'] = orig_call_later(delay, callback, *args) + return non_local['handle'] loop.call_later = call_later test_utils.run_briefly(loop) - self.assertFalse(handle._cancelled) + self.assertFalse(non_local['handle']._cancelled) t.cancel() test_utils.run_briefly(loop) - self.assertTrue(handle._cancelled) + self.assertTrue(non_local['handle']._cancelled) def test_task_cancel_sleeping_task(self): @@ -1150,18 +1157,18 @@ def gen(): @asyncio.coroutine def sleep(dt): - yield from asyncio.sleep(dt, loop=loop) + yield From(asyncio.sleep(dt, loop=loop)) @asyncio.coroutine def doit(): sleeper = asyncio.Task(sleep(5000), loop=loop) loop.call_later(0.1, sleeper.cancel) try: - yield from sleeper + yield From(sleeper) except asyncio.CancelledError: - return 'cancelled' + raise Return('cancelled') else: - return 'slept in' + raise Return('slept in') doer = doit() self.assertEqual(loop.run_until_complete(doer), 'cancelled') @@ -1172,7 +1179,7 @@ def test_task_cancel_waiter_future(self): @asyncio.coroutine def coro(): - yield from fut + yield From(fut) task = asyncio.Task(coro(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -1200,9 +1207,9 @@ def notmuch(): def test_step_result(self): @asyncio.coroutine def notmuch(): - yield None - yield 1 - return 'ko' + yield From(None) + yield From(1) + raise Return('ko') self.assertRaises( RuntimeError, self.loop.run_until_complete, notmuch()) @@ -1213,19 +1220,18 @@ def test_step_result_future(self): class Fut(asyncio.Future): def __init__(self, *args, **kwds): self.cb_added = False - super().__init__(*args, **kwds) + super(Fut, self).__init__(*args, **kwds) def add_done_callback(self, fn): self.cb_added = True - super().add_done_callback(fn) + super(Fut, self).add_done_callback(fn) fut = Fut(loop=self.loop) - result = None + non_local = {'result': None} @asyncio.coroutine def wait_for_future(): - nonlocal result - result = yield from fut + non_local['result'] = yield From(fut) t = asyncio.Task(wait_for_future(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -1234,7 +1240,7 @@ def wait_for_future(): res = object() fut.set_result(res) test_utils.run_briefly(self.loop) - self.assertIs(res, result) + self.assertIs(res, non_local['result']) self.assertTrue(t.done()) self.assertIsNone(t.result()) @@ -1260,24 +1266,24 @@ def gen(): @asyncio.coroutine def sleeper(): - yield from asyncio.sleep(10, loop=loop) + yield From(asyncio.sleep(10, loop=loop)) base_exc = BaseException() @asyncio.coroutine def notmutch(): try: - yield from sleeper() + yield From(sleeper()) except asyncio.CancelledError: raise base_exc task = asyncio.Task(notmutch(), loop=loop) - test_utils.run_briefly(loop) + test_utils.run_briefly(loop, 2) task.cancel() self.assertFalse(task.done()) - self.assertRaises(BaseException, test_utils.run_briefly, loop) + self.assertRaises(BaseException, test_utils.run_briefly, loop, 2) self.assertTrue(task.done()) self.assertFalse(task.cancelled()) @@ -1379,7 +1385,7 @@ def test_current_task_with_interleaving_tasks(self): @asyncio.coroutine def coro1(loop): self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) - yield from fut1 + yield From(fut1) self.assertTrue(asyncio.Task.current_task(loop=loop) is task1) fut2.set_result(True) @@ -1387,7 +1393,7 @@ def coro1(loop): def coro2(loop): self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) fut1.set_result(True) - yield from fut2 + yield From(fut2) self.assertTrue(asyncio.Task.current_task(loop=loop) is task2) task1 = asyncio.Task(coro1(self.loop), loop=self.loop) @@ -1402,54 +1408,50 @@ def coro2(loop): def test_yield_future_passes_cancel(self): # Cancelling outer() cancels inner() cancels waiter. - proof = 0 + non_local = {'proof': 0} waiter = asyncio.Future(loop=self.loop) @asyncio.coroutine def inner(): - nonlocal proof try: - yield from waiter + yield From(waiter) except asyncio.CancelledError: - proof += 1 + non_local['proof'] += 1 raise else: self.fail('got past sleep() in inner()') @asyncio.coroutine def outer(): - nonlocal proof try: - yield from inner() + yield From(inner()) except asyncio.CancelledError: - proof += 100 # Expect this path. + non_local['proof'] += 100 # Expect this path. else: - proof += 10 + non_local['proof'] += 10 f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) f.cancel() self.loop.run_until_complete(f) - self.assertEqual(proof, 101) + self.assertEqual(non_local['proof'], 101) self.assertTrue(waiter.cancelled()) def test_yield_wait_does_not_shield_cancel(self): # Cancelling outer() makes wait() return early, leaves inner() # running. - proof = 0 + non_local = {'proof': 0} waiter = asyncio.Future(loop=self.loop) @asyncio.coroutine def inner(): - nonlocal proof - yield from waiter - proof += 1 + yield From(waiter) + non_local['proof'] += 1 @asyncio.coroutine def outer(): - nonlocal proof - d, p = yield from asyncio.wait([inner()], loop=self.loop) - proof += 100 + d, p = yield From(asyncio.wait([inner()], loop=self.loop)) + non_local['proof'] += 100 f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -1458,7 +1460,7 @@ def outer(): asyncio.CancelledError, self.loop.run_until_complete, f) waiter.set_result(None) test_utils.run_briefly(self.loop) - self.assertEqual(proof, 1) + self.assertEqual(non_local['proof'], 1) def test_shield_result(self): inner = asyncio.Future(loop=self.loop) @@ -1492,20 +1494,18 @@ def test_shield_shortcut(self): def test_shield_effect(self): # Cancelling outer() does not affect inner(). - proof = 0 + non_local = {'proof': 0} waiter = asyncio.Future(loop=self.loop) @asyncio.coroutine def inner(): - nonlocal proof - yield from waiter - proof += 1 + yield From(waiter) + non_local['proof'] += 1 @asyncio.coroutine def outer(): - nonlocal proof - yield from asyncio.shield(inner(), loop=self.loop) - proof += 100 + yield From(asyncio.shield(inner(), loop=self.loop)) + non_local['proof'] += 100 f = asyncio.ensure_future(outer(), loop=self.loop) test_utils.run_briefly(self.loop) @@ -1514,7 +1514,7 @@ def outer(): self.loop.run_until_complete(f) waiter.set_result(None) test_utils.run_briefly(self.loop) - self.assertEqual(proof, 1) + self.assertEqual(non_local['proof'], 1) def test_shield_gather(self): child1 = asyncio.Future(loop=self.loop) @@ -1583,7 +1583,7 @@ def check(): def coro(): # The actual coroutine. self.assertTrue(gen.gi_running) - yield from fut + yield From(fut) # A completed Future used to run the coroutine. fut = asyncio.Future(loop=self.loop) @@ -1622,7 +1622,8 @@ def t1(): def t2(): f = asyncio.Future(loop=self.loop) asyncio.Task(t3(f), loop=self.loop) - return (yield from f) + res = yield From(f) + raise Return(res) @asyncio.coroutine def t3(f): @@ -1635,15 +1636,16 @@ def t3(f): def test_yield_from_corowrapper_send(self): def foo(): a = yield - return a + raise Return(a) def call(arg): - cw = asyncio.coroutines.CoroWrapper(foo()) + cw = asyncio.coroutines.CoroWrapper(foo(), foo) cw.send(None) try: cw.send(arg) except StopIteration as ex: - return ex.args[0] + ex.raised = True + return ex.value else: raise AssertionError('StopIteration was expected') @@ -1652,18 +1654,19 @@ def call(arg): def test_corowrapper_weakref(self): wd = weakref.WeakValueDictionary() - def foo(): yield from [] - cw = asyncio.coroutines.CoroWrapper(foo()) + def foo(): + yield From(None) + cw = asyncio.coroutines.CoroWrapper(foo(), foo) wd['cw'] = cw # Would fail without __weakref__ slot. cw.gen = None # Suppress warning from __del__. - @unittest.skipUnless(PY34, - 'need python 3.4 or later') + @test_utils.skipUnless(PY34, + 'need python 3.4 or later') def test_log_destroyed_pending_task(self): @asyncio.coroutine def kill_me(loop): future = asyncio.Future(loop=loop) - yield from future + yield From(future) # at this point, the only reference to kill_me() task is # the Task._wakeup() method in future._callbacks raise Exception("code never reached") @@ -1675,7 +1678,7 @@ def kill_me(loop): # schedule the task coro = kill_me(self.loop) task = asyncio.ensure_future(coro, loop=self.loop) - self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), {task}) + self.assertEqual(asyncio.Task.all_tasks(loop=self.loop), set((task,))) # execute the task so it waits for future self.loop._run_once() @@ -1706,7 +1709,7 @@ def test_coroutine_never_yielded(self, m_log): def coro_noop(): pass - tb_filename = __file__ + tb_filename = sys._getframe().f_code.co_filename tb_lineno = sys._getframe().f_lineno + 2 # create a coroutine object but don't use it coro_noop() @@ -1715,14 +1718,14 @@ def coro_noop(): self.assertTrue(m_log.error.called) message = m_log.error.call_args[0][0] func_filename, func_lineno = test_utils.get_function_source(coro_noop) - - regex = (r'^ ' + coro_name = getattr(coro_noop, '__qualname__', coro_noop.__name__) + regex = (r'^ ' r'was never yielded from\n' r'Coroutine object created at \(most recent call last\):\n' r'.*\n' r' File "%s", line %s, in test_coroutine_never_yielded\n' r' coro_noop\(\)$' - % (re.escape(coro_noop.__qualname__), + % (re.escape(coro_name), re.escape(func_filename), func_lineno, re.escape(tb_filename), tb_lineno)) @@ -1748,7 +1751,7 @@ def _test_cancel_wait_for(self, timeout): def blocking_coroutine(): fut = asyncio.Future(loop=loop) # Block: fut result is never set - yield from fut + yield From(fut) task = loop.create_task(blocking_coroutine()) @@ -1842,30 +1845,19 @@ def test_env_var_debug(self): aio_path = os.path.dirname(os.path.dirname(asyncio.__file__)) code = '\n'.join(( - 'import asyncio.coroutines', - 'print(asyncio.coroutines._DEBUG)')) + 'import trollius.coroutines', + 'print(trollius.coroutines._DEBUG)')) - # Test with -E to not fail if the unit test was run with - # PYTHONASYNCIODEBUG set to a non-empty string - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-c', code, + TROLLIUSDEBUG='', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'False') - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='', - PYTHONPATH=aio_path) - self.assertEqual(stdout.rstrip(), b'False') - - sts, stdout, stderr = assert_python_ok('-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) + sts, stdout, stderr = support.assert_python_ok('-c', code, + TROLLIUSDEBUG='1', + PYTHONPATH=aio_path) self.assertEqual(stdout.rstrip(), b'True') - sts, stdout, stderr = assert_python_ok('-E', '-c', code, - PYTHONASYNCIODEBUG='1', - PYTHONPATH=aio_path) - self.assertEqual(stdout.rstrip(), b'False') - class FutureGatherTests(GatherTestsBase, test_utils.TestCase): @@ -1954,7 +1946,7 @@ def test_result_exception_one_cancellation(self): class CoroutineGatherTests(GatherTestsBase, test_utils.TestCase): def setUp(self): - super().setUp() + super(CoroutineGatherTests, self).setUp() asyncio.set_event_loop(self.one_loop) def wrap_futures(self, *futures): @@ -1962,7 +1954,8 @@ def wrap_futures(self, *futures): for fut in futures: @asyncio.coroutine def coro(fut=fut): - return (yield from fut) + result = (yield From(fut)) + raise Return(result) coros.append(coro()) return coros @@ -1994,37 +1987,35 @@ def coro(s): def test_cancellation_broadcast(self): # Cancelling outer() cancels all children. - proof = 0 + non_local = {'proof': 0} waiter = asyncio.Future(loop=self.one_loop) @asyncio.coroutine def inner(): - nonlocal proof - yield from waiter - proof += 1 + yield From(waiter) + non_local['proof'] += 1 child1 = asyncio.ensure_future(inner(), loop=self.one_loop) child2 = asyncio.ensure_future(inner(), loop=self.one_loop) - gatherer = None + non_local['gatherer'] = None @asyncio.coroutine def outer(): - nonlocal proof, gatherer - gatherer = asyncio.gather(child1, child2, loop=self.one_loop) - yield from gatherer - proof += 100 + non_local['gatherer'] = asyncio.gather(child1, child2, loop=self.one_loop) + yield From(non_local['gatherer']) + non_local['proof'] += 100 f = asyncio.ensure_future(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) self.assertTrue(f.cancel()) with self.assertRaises(asyncio.CancelledError): self.one_loop.run_until_complete(f) - self.assertFalse(gatherer.cancel()) + self.assertFalse(non_local['gatherer'].cancel()) self.assertTrue(waiter.cancelled()) self.assertTrue(child1.cancelled()) self.assertTrue(child2.cancelled()) test_utils.run_briefly(self.one_loop) - self.assertEqual(proof, 0) + self.assertEqual(non_local['proof'], 0) def test_exception_marking(self): # Test for the first line marked "Mark exception retrieved." @@ -2039,7 +2030,7 @@ def inner(f): @asyncio.coroutine def outer(): - yield from asyncio.gather(inner(a), inner(b), loop=self.one_loop) + yield From(asyncio.gather(inner(a), inner(b), loop=self.one_loop)) f = asyncio.ensure_future(outer(), loop=self.one_loop) test_utils.run_briefly(self.one_loop) diff --git a/tests/test_transports.py b/tests/test_transports.py index 1cb03e0a..42f7729f 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -1,13 +1,20 @@ """Tests for transports.py.""" import unittest -from unittest import mock import trollius as asyncio +from trollius import test_utils from trollius import transports +from trollius.test_utils import mock +try: + memoryview +except NameError: + # Python 2.6 + memoryview = buffer -class TransportTests(unittest.TestCase): + +class TransportTests(test_utils.TestCase): def test_ctor_extra_is_none(self): transport = asyncio.Transport() diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 7f920d05..1223d86c 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1,6 +1,7 @@ """Tests for unix_events.py.""" import collections +import contextlib import errno import io import os @@ -11,7 +12,6 @@ import tempfile import threading import unittest -from unittest import mock if sys.platform == 'win32': raise unittest.SkipTest('UNIX only') @@ -21,6 +21,8 @@ from trollius import log from trollius import test_utils from trollius import unix_events +from trollius.py33_exceptions import BlockingIOError, ChildProcessError +from trollius.test_utils import mock MOCK_ANY = mock.ANY @@ -35,7 +37,7 @@ def close_pipe_transport(transport): transport._pipe = None -@unittest.skipUnless(signal, 'Signals are not supported') +@test_utils.skipUnless(signal, 'Signals are not supported') class SelectorEventLoopSignalTests(test_utils.TestCase): def setUp(self): @@ -76,7 +78,7 @@ def test_add_signal_handler_coroutine_error(self, m_signal): @asyncio.coroutine def simple_coroutine(): - yield from [] + yield None # callback must not be a coroutine function coro_func = simple_coroutine @@ -756,7 +758,7 @@ def test_write_eof_pending(self): self.assertFalse(self.protocol.connection_lost.called) -class AbstractChildWatcherTests(unittest.TestCase): +class AbstractChildWatcherTests(test_utils.TestCase): def test_not_implemented(self): f = mock.Mock() @@ -775,7 +777,7 @@ def test_not_implemented(self): NotImplementedError, watcher.__exit__, f, f, f) -class BaseChildWatcherTests(unittest.TestCase): +class BaseChildWatcherTests(test_utils.TestCase): def test_not_implemented(self): f = mock.Mock() @@ -845,19 +847,27 @@ def test_create_watcher(self): def waitpid_mocks(func): def wrapped_func(self): + exit_stack = [] + def patch(target, wrapper): - return mock.patch(target, wraps=wrapper, - new_callable=mock.Mock) - - with patch('os.WTERMSIG', self.WTERMSIG) as m_WTERMSIG, \ - patch('os.WEXITSTATUS', self.WEXITSTATUS) as m_WEXITSTATUS, \ - patch('os.WIFSIGNALED', self.WIFSIGNALED) as m_WIFSIGNALED, \ - patch('os.WIFEXITED', self.WIFEXITED) as m_WIFEXITED, \ - patch('os.waitpid', self.waitpid) as m_waitpid: + m = mock.patch(target, wraps=wrapper) + exit_stack.append(m) + return m.__enter__() + + m_waitpid = patch('os.waitpid', self.waitpid) + m_WIFEXITED = patch('os.WIFEXITED', self.WIFEXITED) + m_WIFSIGNALED = patch('os.WIFSIGNALED', self.WIFSIGNALED) + m_WEXITSTATUS = patch('os.WEXITSTATUS', self.WEXITSTATUS) + m_WTERMSIG = patch('os.WTERMSIG', self.WTERMSIG) + try: func(self, WaitPidMocks(m_waitpid, m_WIFEXITED, m_WIFSIGNALED, m_WEXITSTATUS, m_WTERMSIG, )) + finally: + for obj in reversed(exit_stack): + obj.__exit__(None, None, None) + return wrapped_func @waitpid_mocks @@ -1330,17 +1340,18 @@ def test_sigchld_unknown_pid_during_registration(self, m): callback1 = mock.Mock() callback2 = mock.Mock() - with self.ignore_warnings, self.watcher: - self.running = True - # child 1 terminates - self.add_zombie(591, 7) - # an unknown child terminates - self.add_zombie(593, 17) + with self.ignore_warnings: + with self.watcher: + self.running = True + # child 1 terminates + self.add_zombie(591, 7) + # an unknown child terminates + self.add_zombie(593, 17) - self.watcher._sig_chld() + self.watcher._sig_chld() - self.watcher.add_child_handler(591, callback1) - self.watcher.add_child_handler(592, callback2) + self.watcher.add_child_handler(591, callback1) + self.watcher.add_child_handler(592, callback2) callback1.assert_called_once_with(591, 7) self.assertFalse(callback2.called) @@ -1359,15 +1370,15 @@ def test_set_loop(self, m): self.loop = self.new_test_loop() patch = mock.patch.object - with patch(old_loop, "remove_signal_handler") as m_old_remove, \ - patch(self.loop, "add_signal_handler") as m_new_add: + with patch(old_loop, "remove_signal_handler") as m_old_remove: + with patch(self.loop, "add_signal_handler") as m_new_add: - self.watcher.attach_loop(self.loop) + self.watcher.attach_loop(self.loop) - m_old_remove.assert_called_once_with( - signal.SIGCHLD) - m_new_add.assert_called_once_with( - signal.SIGCHLD, self.watcher._sig_chld) + m_old_remove.assert_called_once_with( + signal.SIGCHLD) + m_new_add.assert_called_once_with( + signal.SIGCHLD, self.watcher._sig_chld) # child terminates self.running = False @@ -1479,7 +1490,7 @@ def create_watcher(self): return asyncio.FastChildWatcher() -class PolicyTests(unittest.TestCase): +class PolicyTests(test_utils.TestCase): def create_policy(self): return asyncio.DefaultEventLoopPolicy() diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index ec0a5ca4..bcf23504 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,17 +1,18 @@ +from trollius import test_utils import os import sys import unittest -from unittest import mock if sys.platform != 'win32': - raise unittest.SkipTest('Windows only') - -import _winapi + raise test_utils.SkipTest('Windows only') import trollius as asyncio +from trollius import Return, From from trollius import _overlapped -from trollius import test_utils +from trollius import py33_winapi as _winapi from trollius import windows_events +from trollius.py33_exceptions import PermissionError, FileNotFoundError +from trollius.test_utils import mock class UpperProto(asyncio.Protocol): @@ -58,11 +59,11 @@ def _test_pipe(self): ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid() with self.assertRaises(FileNotFoundError): - yield from self.loop.create_pipe_connection( - asyncio.Protocol, ADDRESS) + yield From(self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS)) - [server] = yield from self.loop.start_serving_pipe( - UpperProto, ADDRESS) + [server] = yield From(self.loop.start_serving_pipe( + UpperProto, ADDRESS)) self.assertIsInstance(server, windows_events.PipeServer) clients = [] @@ -70,27 +71,27 @@ def _test_pipe(self): stream_reader = asyncio.StreamReader(loop=self.loop) protocol = asyncio.StreamReaderProtocol(stream_reader, loop=self.loop) - trans, proto = yield from self.loop.create_pipe_connection( - lambda: protocol, ADDRESS) + trans, proto = yield From(self.loop.create_pipe_connection( + lambda: protocol, ADDRESS)) self.assertIsInstance(trans, asyncio.Transport) self.assertEqual(protocol, proto) clients.append((stream_reader, trans)) for i, (r, w) in enumerate(clients): - w.write('lower-{}\n'.format(i).encode()) + w.write('lower-{0}\n'.format(i).encode()) for i, (r, w) in enumerate(clients): - response = yield from r.readline() - self.assertEqual(response, 'LOWER-{}\n'.format(i).encode()) + response = yield From(r.readline()) + self.assertEqual(response, 'LOWER-{0}\n'.format(i).encode()) w.close() server.close() with self.assertRaises(FileNotFoundError): - yield from self.loop.create_pipe_connection( - asyncio.Protocol, ADDRESS) + yield From(self.loop.create_pipe_connection( + asyncio.Protocol, ADDRESS)) - return 'done' + raise Return('done') def test_connect_pipe_cancel(self): exc = OSError() diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index 191daabf..bc21bb18 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -4,19 +4,17 @@ import sys import unittest import warnings -from unittest import mock if sys.platform != 'win32': - raise unittest.SkipTest('Windows only') - -import _winapi + from trollius.test_utils import SkipTest + raise SkipTest('Windows only') from trollius import _overlapped +from trollius import py33_winapi as _winapi +from trollius import test_support as support +from trollius import test_utils from trollius import windows_utils -try: - from test import support -except ImportError: - from trollius import test_support as support +from trollius.test_utils import mock class WinsocketpairTests(unittest.TestCase): @@ -31,13 +29,14 @@ def test_winsocketpair(self): ssock, csock = windows_utils.socketpair() self.check_winsocketpair(ssock, csock) - @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + @test_utils.skipUnless(support.IPV6_ENABLED, + 'IPv6 not supported or enabled') def test_winsocketpair_ipv6(self): ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) self.check_winsocketpair(ssock, csock) - @unittest.skipIf(hasattr(socket, 'socketpair'), - 'socket.socketpair is available') + @test_utils.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): m_socket.AF_INET = socket.AF_INET @@ -56,8 +55,8 @@ def test_winsocketpair_invalid_args(self): self.assertRaises(ValueError, windows_utils.socketpair, proto=1) - @unittest.skipIf(hasattr(socket, 'socketpair'), - 'socket.socketpair is available') + @test_utils.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_close(self, m_socket): m_socket.AF_INET = socket.AF_INET @@ -84,7 +83,7 @@ def test_pipe_overlapped(self): ERROR_IO_INCOMPLETE = 996 try: ov1.getresult() - except OSError as e: + except WindowsError as e: self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE) else: raise RuntimeError('expected ERROR_IO_INCOMPLETE') @@ -94,15 +93,15 @@ def test_pipe_overlapped(self): self.assertEqual(ov2.error, 0) ov2.WriteFile(h2, b"hello") - self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertIn(ov2.error, set((0, _winapi.ERROR_IO_PENDING))) - res = _winapi.WaitForMultipleObjects([ov2.event], False, 100) + res = _winapi.WaitForSingleObject(ov2.event, 100) self.assertEqual(res, _winapi.WAIT_OBJECT_0) self.assertFalse(ov1.pending) self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE) self.assertFalse(ov2.pending) - self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING}) + self.assertIn(ov2.error, set((0, _winapi.ERROR_IO_PENDING))) self.assertEqual(ov1.getresult(), b"hello") finally: _winapi.CloseHandle(h1) @@ -117,7 +116,8 @@ def test_pipe_handle(self): # check garbage collection of p closes handle with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "", ResourceWarning) + if sys.version_info >= (3, 4): + warnings.filterwarnings("ignore", "", ResourceWarning) del p support.gc_collect() try: From d8b296df530c7a39c4d1a6b534bb7728ee37a4de Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 23:37:05 +0200 Subject: [PATCH 1395/1502] Add documentation --- doc/Makefile | 153 ++++++++++++ doc/asyncio.rst | 185 ++++++++++++++ doc/changelog.rst | 604 ++++++++++++++++++++++++++++++++++++++++++++++ doc/conf.py | 240 ++++++++++++++++++ doc/dev.rst | 85 +++++++ doc/index.rst | 80 ++++++ doc/install.rst | 111 +++++++++ doc/make.bat | 190 +++++++++++++++ doc/trollius.jpg | Bin 0 -> 30083 bytes doc/using.rst | 85 +++++++ 10 files changed, 1733 insertions(+) create mode 100644 doc/Makefile create mode 100644 doc/asyncio.rst create mode 100644 doc/changelog.rst create mode 100644 doc/conf.py create mode 100644 doc/dev.rst create mode 100644 doc/index.rst create mode 100644 doc/install.rst create mode 100644 doc/make.bat create mode 100644 doc/trollius.jpg create mode 100644 doc/using.rst diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 00000000..314751af --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,153 @@ +# Makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +BUILDDIR = build + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . +# the i18n builder cannot share the environment and doctrees with the others +I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files" + @echo " dirhtml to make HTML files named index.html in directories" + @echo " singlehtml to make a single large HTML file" + @echo " pickle to make pickle files" + @echo " json to make JSON files" + @echo " htmlhelp to make HTML files and a HTML help project" + @echo " qthelp to make HTML files and a qthelp project" + @echo " devhelp to make HTML files and a Devhelp project" + @echo " epub to make an epub" + @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" + @echo " latexpdf to make LaTeX files and run them through pdflatex" + @echo " text to make text files" + @echo " man to make manual pages" + @echo " texinfo to make Texinfo files" + @echo " info to make Texinfo files and run them through makeinfo" + @echo " gettext to make PO message catalogs" + @echo " changes to make an overview of all changed/added/deprecated items" + @echo " linkcheck to check all external links for integrity" + @echo " doctest to run all doctests embedded in the documentation (if enabled)" + +clean: + -rm -rf $(BUILDDIR)/* + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." + +dirhtml: + $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml + @echo + @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." + +singlehtml: + $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml + @echo + @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." + +pickle: + $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle + @echo + @echo "Build finished; now you can process the pickle files." + +json: + $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json + @echo + @echo "Build finished; now you can process the JSON files." + +htmlhelp: + $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp + @echo + @echo "Build finished; now you can run HTML Help Workshop with the" \ + ".hhp project file in $(BUILDDIR)/htmlhelp." + +qthelp: + $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp + @echo + @echo "Build finished; now you can run "qcollectiongenerator" with the" \ + ".qhcp project file in $(BUILDDIR)/qthelp, like this:" + @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Trollius.qhcp" + @echo "To view the help file:" + @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Trollius.qhc" + +devhelp: + $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp + @echo + @echo "Build finished." + @echo "To view the help file:" + @echo "# mkdir -p $$HOME/.local/share/devhelp/Trollius" + @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Trollius" + @echo "# devhelp" + +epub: + $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub + @echo + @echo "Build finished. The epub file is in $(BUILDDIR)/epub." + +latex: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo + @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." + @echo "Run \`make' in that directory to run these through (pdf)latex" \ + "(use \`make latexpdf' here to do that automatically)." + +latexpdf: + $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex + @echo "Running LaTeX files through pdflatex..." + $(MAKE) -C $(BUILDDIR)/latex all-pdf + @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." + +text: + $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text + @echo + @echo "Build finished. The text files are in $(BUILDDIR)/text." + +man: + $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man + @echo + @echo "Build finished. The manual pages are in $(BUILDDIR)/man." + +texinfo: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo + @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." + @echo "Run \`make' in that directory to run these through makeinfo" \ + "(use \`make info' here to do that automatically)." + +info: + $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo + @echo "Running Texinfo files through makeinfo..." + make -C $(BUILDDIR)/texinfo info + @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." + +gettext: + $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale + @echo + @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." + +changes: + $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes + @echo + @echo "The overview file is in $(BUILDDIR)/changes." + +linkcheck: + $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck + @echo + @echo "Link check complete; look for any errors in the above output " \ + "or in $(BUILDDIR)/linkcheck/output.txt." + +doctest: + $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest + @echo "Testing of doctests in the sources finished, look at the " \ + "results in $(BUILDDIR)/doctest/output.txt." diff --git a/doc/asyncio.rst b/doc/asyncio.rst new file mode 100644 index 00000000..5866d62f --- /dev/null +++ b/doc/asyncio.rst @@ -0,0 +1,185 @@ +++++++++++++++++++ +Trollius and Tulip +++++++++++++++++++ + +Differences between Trollius and Tulip +====================================== + +Syntax of coroutines +-------------------- + +The major difference between Trollius and Tulip is the syntax of coroutines: + +================== ====================== +Tulip Trollius +================== ====================== +``yield from ...`` ``yield From(...)`` +``yield from []`` ``yield From(None)`` +``return`` ``raise Return()`` +``return x`` ``raise Return(x)`` +``return x, y`` ``raise Return(x, y)`` +================== ====================== + +Because of this major difference, it was decided to call the module +``trollius`` instead of ``asyncio``. This choice also allows to use Trollius on +Python 3.4 and later. Changing imports is not enough to use Trollius code with +asyncio: the asyncio event loop explicit rejects coroutines using ``yield`` +(instead of ``yield from``). + +OSError and socket.error exceptions +----------------------------------- + +The ``OSError`` exception changed in Python 3.3: there are now subclasses like +``ConnectionResetError`` or ``BlockingIOError``. The exception hierarchy also +changed: ``socket.error`` is now an alias to ``OSError``. The ``asyncio`` +module is written for Python 3.3 and newer and so is based on these new +exceptions. + +.. seealso:: + + `PEP 3151: Reworking the OS and IO exception hierarchy + `_. + +On Python 3.2 and older, Trollius wraps ``OSError``, ``IOError``, +``socket.error`` and ``select.error`` exceptions on operating system and socket +operations to raise more specific exceptions, subclasses of ``OSError``: + +* ``trollius.BlockingIOError`` +* ``trollius.BrokenPipeError`` +* ``trollius.ChildProcessError`` +* ``trollius.ConnectionAbortedError`` +* ``trollius.ConnectionRefusedError`` +* ``trollius.ConnectionResetError`` +* ``trollius.FileNotFoundError`` +* ``trollius.InterruptedError`` +* ``trollius.PermissionError`` + +On Python 3.3 and newer, these symbols are just aliases to builtin exceptions. + +.. note:: + + ``ssl.SSLError`` exceptions are not wrapped to ``OSError``, even if + ``ssl.SSLError`` is a subclass of ``socket.error``. + + +SSLError +-------- + +On Python 3.2 and older, Trollius wraps ``ssl.SSLError`` exceptions to raise +more specific exceptions, subclasses of ``ssl.SSLError``, to mimic the Python +3.3: + +* ``trollius.SSLEOFError`` +* ``trollius.SSLWantReadError`` +* ``trollius.SSLWantWriteError`` + +On Python 3.3 and newer, these symbols are just aliases to exceptions of the +``ssl`` module. + +``trollius.BACKPORT_SSL_ERRORS`` constant: + +* ``True`` if ``ssl.SSLError`` are wrapped to Trollius exceptions (Python 2 + older than 2.7.9, or Python 3 older than 3.3), +* ``False`` is trollius SSL exceptions are just aliases. + + +SSLContext +---------- + +Python 3.3 has a new ``ssl.SSLContext`` class: see the `documentaton of the +ssl.SSLContext class +`_. + +On Python 3.2 and older, Trollius has a basic ``trollius.SSLContext`` class to +mimic Python 3.3 API, but it only has a few features: + +* ``protocol``, ``certfile`` and ``keyfile`` attributes +* read-only ``verify_mode`` attribute: its value is ``CERT_NONE`` +* ``load_cert_chain(certfile, keyfile)`` method +* ``wrap_socket(sock, **kw)`` method: see the ``ssl.wrap_socket()`` + documentation of your Python version for the keyword parameters + +Example of missing features: + +* no ``options`` attribute +* the ``verify_mode`` attriubte cannot be modified +* no ``set_default_verify_paths()`` method +* no "Server Name Indication" (SNI) support +* etc. + +On Python 3.2 and older, the trollius SSL transport does not have the +``'compression'`` extra info. + +``trollius.BACKPORT_SSL_CONTEXT`` constant: + +* ``True`` if ``trollius.SSLContext`` is the backported class (Python 2 older + than 2.7.9, or Python 3 older than 3.3), +* ``False`` if ``trollius.SSLContext`` is just an alias to ``ssl.SSLContext``. + + +Other differences +----------------- + +* Trollius uses the ``TROLLIUSDEBUG`` envrionment variable instead of + the ``PYTHONASYNCIODEBUG`` envrionment variable. ``TROLLIUSDEBUG`` variable + is used even if the Python command line option ``-E`` is used. +* ``asyncio.subprocess`` has no ``DEVNULL`` constant +* Python 2 does not support keyword-only parameters. +* If the ``concurrent.futures`` module is missing, + ``BaseEventLoop.run_in_executor()`` uses a synchronous executor instead of a + pool of threads. It blocks until the function returns. For example, DNS + resolutions are blocking in this case. +* Trollius has more symbols than Tulip for compatibility with Python older than + 3.3: + + - ``From``: part of ``yield From(...)`` syntax + - ``Return``: part of ``raise Return(...)`` syntax + + +Write code working on Trollius and Tulip +======================================== + +Trollius and Tulip are different, especially for coroutines (``yield +From(...)`` vs ``yield from ...``). + +To use asyncio or Trollius on Python 2 and Python 3, add the following code at +the top of your file:: + + try: + # Use builtin asyncio on Python 3.4+, or Tulip on Python 3.3 + import asyncio + except ImportError: + # Use Trollius on Python <= 3.2 + import trollius as asyncio + +It is possible to write code working on both projects using only callbacks. +This option is used by the following projects which work on Trollius and Tulip: + +* `AutobahnPython `_: WebSocket & + WAMP for Python, it works on Trollius (Python 2.6 and 2.7), Tulip (Python + 3.3) and Python 3.4 (asyncio), and also on Twisted. +* `Pulsar `_: Event driven concurrent + framework for Python. With pulsar you can write asynchronous servers + performing one or several activities in different threads and/or processes. + Trollius 0.3 requires Pulsar 0.8.2 or later. Pulsar uses the ``asyncio`` + module if available, or import ``trollius``. +* `Tornado `_ supports Tulip and Trollius since + Tornado 3.2: `tornado.platform.asyncio — Bridge between asyncio and Tornado + `_. It tries to import + asyncio or fallback on importing trollius. + +Another option is to provide functions returning ``Future`` objects, so the +caller can decide to use callback using ``fut.add_done_callback(callback)`` or +to use coroutines (``yield From(fut)`` for Trollius, or ``yield from fut`` for +Tulip). This option is used by the `aiodns `_ +project for example. + +Since Trollius 0.4, it's possible to use Tulip and Trollius coroutines in the +same process. The only limit is that the event loop must be a Trollius event +loop. + +.. note:: + + The Trollius module was called ``asyncio`` in Trollius version 0.2. The + module name changed to ``trollius`` to support Python 3.4. + diff --git a/doc/changelog.rst b/doc/changelog.rst new file mode 100644 index 00000000..1ddfc218 --- /dev/null +++ b/doc/changelog.rst @@ -0,0 +1,604 @@ +++++++++++ +Change log +++++++++++ + +Version 1.0.5 +============= + +Major changes: on Python 3.5+ ProactorEventLoop now supports SSL, a lot of +bugfixes (random race conditions) in the ProactorEventLoop. + +API changes: + +* Python issue #23209, #23225: selectors.BaseSelector.get_key() now raises a + RuntimeError if the selector is closed. And selectors.BaseSelector.close() + now clears its internal reference to the selector mapping to break a + reference cycle. Initial patch written by Martin Richard. +* PipeHandle.fileno() of asyncio.windows_utils now raises an exception if the + pipe is closed. +* Remove Overlapped.WaitNamedPipeAndConnect() of the _overlapped module, + it is no more used and it had issues. + +New SSL implementation: + +* Python issue #22560: On Python 3.5 and newer, a new SSL implementation based + on ssl.MemoryBIO instead of the legacy SSL implementation. Patch written by + Antoine Pitrou, based on the work of Geert Jansen. +* If available, the new SSL implementation can be used by ProactorEventLoop to + support SSL. + +Enhance, fix and cleanup the IocpProactor: + +* Python issue #23293: Rewrite IocpProactor.connect_pipe(). Add + _overlapped.ConnectPipe() which tries to connect to the pipe for asynchronous + I/O (overlapped): call CreateFile() in a loop until it doesn't fail with + ERROR_PIPE_BUSY. Use an increasing delay between 1 ms and 100 ms. +* Tulip issue #204: Fix IocpProactor.accept_pipe(). + Overlapped.ConnectNamedPipe() now returns a boolean: True if the pipe is + connected (if ConnectNamedPipe() failed with ERROR_PIPE_CONNECTED), False if + the connection is in progress. +* Tulip issue #204: Fix IocpProactor.recv(). If ReadFile() fails with + ERROR_BROKEN_PIPE, the operation is not pending: don't register the + overlapped. +* Python issue #23095: Rewrite _WaitHandleFuture.cancel(). + _WaitHandleFuture.cancel() now waits until the wait is cancelled to clear its + reference to the overlapped object. To wait until the cancellation is done, + UnregisterWaitEx() is used with an event instead of UnregisterWait(). +* Python issue #23293: Rewrite IocpProactor.connect_pipe() as a coroutine. Use + a coroutine with asyncio.sleep() instead of call_later() to ensure that the + schedule call is cancelled. +* Fix ProactorEventLoop.start_serving_pipe(). If a client connected before the + server was closed: drop the client (close the pipe) and exit +* Python issue #23293: Cleanup IocpProactor.close(). The special case for + connect_pipe() is not more needed. connect_pipe() doesn't use overlapped + operations anymore. +* IocpProactor.close(): don't cancel futures which are already cancelled +* Enhance (fix) BaseProactorEventLoop._loop_self_reading(). Handle correctly + CancelledError: just exit. On error, log the exception and exit; don't try to + close the event loop (it doesn't work). + +Bugfixes: + +* Close transports on error. Fix create_datagram_endpoint(), + connect_read_pipe() and connect_write_pipe(): close the transport if the task + is cancelled or on error. +* Close the transport on subprocess creation failure +* Fix _ProactorBasePipeTransport.close(). Set the _read_fut attribute to None + after cancelling it. +* Python issue #23243: Fix _UnixWritePipeTransport.close(). Do nothing if the + transport is already closed. Before it was not possible to close the + transport twice. +* Python issue #23242: SubprocessStreamProtocol now closes the subprocess + transport at subprocess exit. Clear also its reference to the transport. +* Fix BaseEventLoop._create_connection_transport(). Close the transport if the + creation of the transport (if the waiter) gets an exception. +* Python issue #23197: On SSL handshake failure, check if the waiter is + cancelled before setting its exception. +* Python issue #23173: Fix SubprocessStreamProtocol.connection_made() to handle + cancelled waiter. +* Python issue #23173: If an exception is raised during the creation of a + subprocess, kill the subprocess (close pipes, kill and read the return + status). Log an error in such case. +* Python issue #23209: Break some reference cycles in asyncio. Patch written by + Martin Richard. + +Changes: + +* Python issue #23208: Add BaseEventLoop._current_handle. In debug mode, + BaseEventLoop._run_once() now sets the BaseEventLoop._current_handle + attribute to the handle currently executed. +* Replace test_selectors.py with the file of Python 3.5 adapted for asyncio and + Python 3.3. +* Tulip issue #184: FlowControlMixin constructor now get the event loop if the + loop parameter is not set. +* _ProactorBasePipeTransport now sets the _sock attribute to None when the + transport is closed. +* Python issue #23219: cancelling wait_for() now cancels the task +* Python issue #23243: Close explicitly event loops and transports in tests +* Python issue #23140: Fix cancellation of Process.wait(). Check the state of + the waiter future before setting its result. +* Python issue #23046: Expose the BaseEventLoop class in the asyncio namespace +* Python issue #22926: In debug mode, call_soon(), call_at() and call_later() + methods of BaseEventLoop now use the identifier of the current thread to + ensure that they are called from the thread running the event loop. Before, + the get_event_loop() method was used to check the thread, and no exception + was raised when the thread had no event loop. Now the methods always raise an + exception in debug mode when called from the wrong thread. It should help to + notice misusage of the API. + + +2014-12-19: Version 1.0.4 +========================= + +Changes: + +* Python issue #22922: create_task(), call_at(), call_soon(), + call_soon_threadsafe() and run_in_executor() now raise an error if the event + loop is closed. Initial patch written by Torsten Landschoff. +* Python issue #22921: Don't require OpenSSL SNI to pass hostname to ssl + functions. Patch by Donald Stufft. +* Add run_aiotest.py: run the aiotest test suite. +* tox now also run the aiotest test suite +* Python issue #23074: get_event_loop() now raises an exception if the thread + has no event loop even if assertions are disabled. + +Bugfixes: + +* Fix a race condition in BaseSubprocessTransport._try_finish(): ensure that + connection_made() is called before connection_lost(). +* Python issue #23009: selectors, make sure EpollSelecrtor.select() works when + no file descriptor is registered. +* Python issue #22922: Fix ProactorEventLoop.close(). Call + _stop_accept_futures() before sestting the _closed attribute, otherwise + call_soon() raises an error. +* Python issue #22429: Fix EventLoop.run_until_complete(), don't stop the event + loop if a BaseException is raised, because the event loop is already stopped. +* Initialize more Future and Task attributes in the class definition to avoid + attribute errors in destructors. +* Python issue #22685: Set the transport of stdout and stderr StreamReader + objects in the SubprocessStreamProtocol. It allows to pause the transport to + not buffer too much stdout or stderr data. +* BaseSelectorEventLoop.close() now closes the self-pipe before calling the + parent close() method. If the event loop is already closed, the self-pipe is + not unregistered from the selector. + + +2014-10-20: Version 1.0.3 +========================= + +Changes: + +* On Python 2 in debug mode, Future.set_exception() now stores the traceback + object of the exception in addition to the exception object. When a task + waiting for another task and the other task raises an exception, the + traceback object is now copied with the exception. Be careful, storing the + traceback object may create reference leaks. +* Use ssl.create_default_context() if available to create the default SSL + context: Python 2.7.9 and newer, or Python 3.4 and newer. +* On Python 3.5 and newer, reuse socket.socketpair() in the windows_utils + submodule. +* On Python 3.4 and newer, use os.set_inheritable(). +* Enhance protocol representation: add "closed" or "closing" info. +* run_forever() now consumes BaseException of the temporary task. If the + coroutine raised a BaseException, consume the exception to not log a warning. + The caller doesn't have access to the local task. +* Python issue 22448: cleanup _run_once(), only iterate once to remove delayed + calls that were cancelled. +* The destructor of the Return class now shows where the Return object was + created. +* run_tests.py doesn't catch any exceptions anymore when loading tests, only + catch SkipTest. +* Fix (SSL) tests for the future Python 2.7.9 which includes a "new" ssl + module: module backported from Python 3.5. +* BaseEventLoop.add_signal_handler() now raises an exception if the parameter + is a coroutine function. +* Coroutine functions and objects are now rejected with a TypeError by the + following functions: add_signal_handler(), call_at(), call_later(), + call_soon(), call_soon_threadsafe(), run_in_executor(). + + +2014-10-02: Version 1.0.2 +========================= + +This release fixes bugs. It also provides more information in debug mode on +error. + +Major changes: + +* Tulip issue #203: Add _FlowControlMixin.get_write_buffer_limits() method. +* Python issue #22063: socket operations (socket,recv, sock_sendall, + sock_connect, sock_accept) of SelectorEventLoop now raise an exception in + debug mode if sockets are in blocking mode. + +Major bugfixes: + +* Tulip issue #205: Fix a race condition in BaseSelectorEventLoop.sock_connect(). +* Tulip issue #201: Fix a race condition in wait_for(). Don't raise a + TimeoutError if we reached the timeout and the future completed in the same + iteration of the event loop. A side effect of the bug is that Queue.get() + looses items. +* PipeServer.close() now cancels the "accept pipe" future which cancels the + overlapped operation. + +Other changes: + +* Python issue #22448: Improve cancelled timer callback handles cleanup. Patch + by Joshua Moore-Oliva. +* Python issue #22369: Change "context manager protocol" to "context management + protocol". Patch written by Serhiy Storchaka. +* Tulip issue #206: In debug mode, keep the callback in the representation of + Handle and TimerHandle after cancel(). +* Tulip issue #207: Fix test_tasks.test_env_var_debug() to use correct asyncio + module. +* runtests.py: display a message to mention if tests are run in debug or + release mode +* Tulip issue #200: Log errors in debug mode instead of simply ignoring them. +* Tulip issue #200: _WaitHandleFuture._unregister_wait() now catchs and logs + exceptions. +* _fatal_error() method of _UnixReadPipeTransport and _UnixWritePipeTransport + now log all exceptions in debug mode +* Fix debug log in BaseEventLoop.create_connection(): get the socket object + from the transport because SSL transport closes the old socket and creates a + new SSL socket object. +* Remove the _SelectorSslTransport._rawsock attribute: it contained the closed + socket (not very useful) and it was not used. +* Fix _SelectorTransport.__repr__() if the transport was closed +* Use the new os.set_blocking() function of Python 3.5 if available + + +2014-07-30: Version 1.0.1 +========================= + +This release supports PyPy and has a better support of asyncio coroutines, +especially in debug mode. + +Changes: + +* Tulip issue #198: asyncio.Condition now accepts an optional lock object. +* Enhance representation of Future and Future subclasses: add "created at". + +Bugfixes: + +* Fix Trollius issue #9: @trollius.coroutine now works on callbable objects + (without ``__name__`` attribute), not only on functions. +* Fix Trollius issue #13: asyncio futures are now accepted in all functions: + as_completed(), async(), @coroutine, gather(), run_until_complete(), + wrap_future(). +* Fix support of asyncio coroutines in debug mode. If the last instruction + of the coroutine is "yield from", it's an asyncio coroutine and it does not + need to use From(). +* Fix and enhance _WaitHandleFuture.cancel(): + + - Tulip issue #195: Fix a crash on Windows: don't call UnregisterWait() twice + if a _WaitHandleFuture is cancelled twice. + - Fix _WaitHandleFuture.cancel(): return the result of the parent cancel() + method (True or False). + - _WaitHandleFuture.cancel() now notify IocpProactor through the overlapped + object that the wait was cancelled. + +* Tulip issue #196: _OverlappedFuture now clears its reference to the + overlapped object. IocpProactor keeps a reference to the overlapped object + until it is notified of its completion. Log also an error in debug mode if it + gets unexpected notifications. +* Fix runtest.py to be able to log at level DEBUG. + +Other changes: + +* BaseSelectorEventLoop._write_to_self() now logs errors in debug mode. +* Fix as_completed(): it's not a coroutine, don't use ``yield From(...)`` but + ``yield ...`` +* Tulip issue #193: Convert StreamWriter.drain() to a classic coroutine. +* Tulip issue #194: Don't use sys.getrefcount() in unit tests: the full test + suite now pass on PyPy. + + +2014-07-21: Version 1.0 +======================= + +Major Changes +------------- + +* Event loops have a new ``create_task()`` method, which is now the recommanded + way to create a task object. This method can be overriden by third-party + event loops to use their own task class. +* The debug mode has been improved a lot. Set ``TROLLIUSDEBUG`` envrironment + variable to ``1`` and configure logging to log at level ``logging.DEBUG`` + (ex: ``logging.basicConfig(level=logging.DEBUG)``). Changes: + + - much better representation of Trollius objects (ex: ``repr(task)``): + unified ```` format, use qualified name when available + - show the traceback where objects were created + - show the current filename and line number for coroutine + - show the filename and line number where objects were created + - log most important socket events + - log most important subprocess events + +* ``Handle.cancel()`` now clears references to callback and args +* Log an error if a Task is destroyed while it is still pending, but only on + Python 3.4 and newer. +* Fix for asyncio coroutines when passing tuple value in debug mode. + ``CoroWrapper.send()`` now checks if it is called from a "yield from" + generator to decide if the parameter should be unpacked or not. +* ``Process.communicate()`` now ignores ``BrokenPipeError`` and + ``ConnectionResetError`` exceptions. +* Rewrite signal handling on Python 3.3 and newer to fix a race condition: use + the "self-pipe" to get signal numbers. + + +Other Changes +------------- + +* Fix ``ProactorEventLoop()`` in debug mode +* Fix a race condition when setting the result of a Future with + ``call_soon()``. Add an helper, a private method, to set the result only if + the future was not cancelled. +* Fix ``asyncio.__all__``: export also ``unix_events`` and ``windows_events`` + symbols. For example, on Windows, it was not possible to get + ``ProactorEventLoop`` or ``DefaultEventLoopPolicy`` using ``from asyncio + import *``. +* ``Handle.cancel()`` now clears references to callback and args +* Make Server attributes and methods private, the sockets attribute remains + public. +* BaseEventLoop.create_datagram_endpoint() now waits until + protocol.connection_made() has been called. Document also why transport + constructors use a waiter. +* _UnixSubprocessTransport: fix file mode of stdin: open stdin in write mode, + not in read mode. + + +2014-06-23: version 0.4 +======================= + +Changes between Trollius 0.3 and 0.4: + +* Trollius event loop now supports asyncio coroutines: + + - Trollius coroutines can yield asyncio coroutines, + - asyncio coroutines can yield Trollius coroutines, + - asyncio.set_event_loop() accepts a Trollius event loop, + - asyncio.set_event_loop_policy() accepts a Trollius event loop policy. + +* The ``PYTHONASYNCIODEBUG`` envrionment variable has been renamed to + ``TROLLIUSDEBUG``. The environment variable is now used even if the Python + command line option ``-E`` is used. +* Synchronize with Tulip. +* Support PyPy (fix subproces, fix unit tests). + +Tulip changes: + +* Tulip issue #171: BaseEventLoop.close() now raises an exception if the event + loop is running. You must first stop the event loop and then wait until it + stopped, before closing it. +* Tulip issue #172: only log selector timing in debug mode +* Enable the debug mode of event loops when the ``TROLLIUSDEBUG`` environment + variable is set +* BaseEventLoop._assert_is_current_event_loop() now only raises an exception if + the current loop is set. +* Tulip issue #105: in debug mode, log callbacks taking more than 100 ms to be + executed. +* Python issue 21595: ``BaseSelectorEventLoop._read_from_self()`` reads all + available bytes from the "self pipe", not only a single byte. This change + reduces the risk of having the pipe full and so getting the "BlockingIOError: + [Errno 11] Resource temporarily unavailable" message. +* Python issue 21723: asyncio.Queue: support any type of number (ex: float) for + the maximum size. Patch written by Vajrasky Kok. +* Issue #173: Enhance repr(Handle) and repr(Task): add the filename and line + number, when available. For task, the current line number of the coroutine + is used. +* Add BaseEventLoop.is_closed() method. run_forever() and run_until_complete() + methods now raises an exception if the event loop was closed. +* Make sure that socketpair() close sockets on error. Close the listening + socket if sock.bind() raises an exception. +* Fix ResourceWarning: close sockets on errors. + BaseEventLoop.create_connection(), BaseEventLoop.create_datagram_endpoint() + and _UnixSelectorEventLoop.create_unix_server() now close the newly created + socket on error. +* Rephrase and fix docstrings. +* Fix tests on Windows: wait for the subprocess exit. Before, regrtest failed + to remove the temporary test directory because the process was still running + in this directory. +* Refactor unit tests. + +On Python 3.5, generators now get their name from the function, no more from +the code. So the ``@coroutine`` decorator doesn't loose the original name of +the function anymore. + + +2014-05-26: version 0.3 +======================= + +Rename the Python module ``asyncio`` to ``trollius`` to support Python 3.4. On +Python 3.4, there is already a module called ``asyncio`` in the standard +library which conflicted with ``asyncio`` module of Trollius 0.2. To write +asyncio code working on Trollius and Tulip, use ``import trollius as asyncio``. + +Changes between Trollius 0.2 and 0.3: + +* Synchronize with Tulip 3.4.1. +* Enhance Trollius documentation. +* Trollius issue #7: Fix ``asyncio.time_monotonic`` on Windows older than + Vista (ex: Windows 2000 and Windows XP). +* Fedora packages have been accepted. + +Changes between Tulip 3.4.0 and 3.4.1: + +* Pull in Solaris ``devpoll`` support by Giampaolo Rodola + (``trollius.selectors`` module). +* Add options ``-r`` and ``--randomize`` to runtests.py to randomize test + order. +* Add a simple echo client/server example. +* Tulip issue #166: Add ``__weakref__`` slots to ``Handle`` and + ``CoroWrapper``. +* ``EventLoop.create_unix_server()`` now raises a ``ValueError`` if path and + sock are specified at the same time. +* Ensure ``call_soon()``, ``call_later()`` and ``call_at()`` are invoked on + current loop in debug mode. Raise a ``RuntimeError`` if the event loop of the + current thread is different. The check should help to debug thread-safetly + issue. Patch written by David Foster. +* Tulip issue #157: Improve test_events.py, avoid ``run_briefly()`` which is + not reliable. +* Reject add/remove reader/writer when event loop is closed. + +Bugfixes of Tulip 3.4.1: + +* Tulip issue #168: ``StreamReader.read(-1)`` from pipe may hang if + data exceeds buffer limit. +* CPython issue #21447: Fix a race condition in + ``BaseEventLoop._write_to_self()``. +* Different bugfixes in ``CoroWrapper`` of ``trollius.coroutines``, class used + when running Trollius in debug mode: + + - Fix ``CoroWrapper`` to workaround yield-from bug in CPython 3.4.0. The + CPython bug is now fixed in CPython 3.4.1 and 3.5. + - Make sure ``CoroWrapper.send`` proxies one argument correctly. + - CPython issue #21340: Be careful accessing instance variables in ``__del__``. + - Tulip issue #163: Add ``gi_{frame,running,code}`` properties to + ``CoroWrapper``. + +* Fix ``ResourceWarning`` warnings +* Tulip issue #159: Fix ``windows_utils.socketpair()``. Use ``"127.0.0.1"`` + (IPv4) or ``"::1"`` (IPv6) host instead of ``"localhost"``, because + ``"localhost"`` may be a different IP address. Reject also invalid arguments: + only ``AF_INET`` and ``AF_INET6`` with ``SOCK_STREAM`` (and ``proto=0``) are + supported. +* Tulip issue #158: ``Task._step()`` now also sets ``self`` to ``None`` if an + exception is raised. ``self`` is set to ``None`` to break a reference cycle. + + +2014-03-04: version 0.2 +======================= + +Trollius now uses ``yield From(...)`` syntax which looks close to Tulip ``yield +from ...`` and allows to port more easily Trollius code to Tulip. The usage of +``From()`` is not mandatory yet, but it may become mandatory in a future +version. However, if ``yield`` is used without ``From``, an exception is +raised if the event loop is running in debug mode. + +Major changes: + +* Replace ``yield ...`` syntax with ``yield From(...)`` +* On Python 2, Future.set_exception() now only saves the traceback if the debug + mode of the event loop is enabled for best performances in production mode. + Use ``loop.set_debug(True)`` to save the traceback. + +Bugfixes: + +* Fix ``BaseEventLoop.default_exception_handler()`` on Python 2: get the + traceback from ``sys.exc_info()`` +* Fix unit tests on SSL sockets on Python older than 2.6.6. Example: + Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4. +* Fix error handling in the asyncio.time_monotonic module +* Fix acquire() method of Lock, Condition and Semaphore: don't return a context + manager but True, as Tulip. Task._step() now does the trick. + +Other changes: + +* tox.ini: set PYTHONASYNCIODEBUG to 1 to run tests + +2014-02-25: version 0.1.6 +========================= + +Trollius changes: + +* Add a new Sphinx documentation: + http://trollius.readthedocs.org/ +* tox: pass posargs to nosetests. Patch contributed by Ian Wienand. +* Fix support of Python 3.2 and add py32 to tox.ini +* Merge with Tulip 0.4.1 + +Major changes of Tulip 0.4.1: + +* Issue #81: Add support for UNIX Domain Sockets. New APIs: + + - loop.create_unix_connection() + - loop.create_unix_server() + - streams.open_unix_connection() + - streams.start_unix_server() + +* Issue #80: Add new event loop exception handling API. New APIs: + + - loop.set_exception_handler() + - loop.call_exception_handler() + - loop.default_exception_handler() + +* Issue #136: Add get_debug() and set_debug() methods to BaseEventLoopTests. + Add also a ``PYTHONASYNCIODEBUG`` environment variable to debug coroutines + since Python startup, to be able to debug coroutines defined directly in the + asyncio module. + +Other changes of Tulip 0.4.1: + +* asyncio.subprocess: Fix a race condition in communicate() +* Fix _ProactorWritePipeTransport._pipe_closed() +* Issue #139: Improve error messages on "fatal errors". +* Issue #140: WriteTransport.set_write_buffer_size() to call + _maybe_pause_protocol() +* Issue #129: BaseEventLoop.sock_connect() now raises an error if the address + is not resolved (hostname instead of an IP address) for AF_INET and + AF_INET6 address families. +* Issue #131: as_completed() and wait() now raises a TypeError if the list of + futures is not a list but a Future, Task or coroutine object +* Python issue #20495: Skip test_read_pty_output() of test_asyncio on FreeBSD + older than FreeBSD 8 +* Issue #130: Add more checks on subprocess_exec/subprocess_shell parameters +* Issue #126: call_soon(), call_soon_threadsafe(), call_later(), call_at() + and run_in_executor() now raise a TypeError if the callback is a coroutine + function. +* Python issue #20505: BaseEventLoop uses again the resolution of the clock + to decide if scheduled tasks should be executed or not. + + +2014-02-10: version 0.1.5 +========================= + +- Merge with Tulip 0.3.1: + + * New asyncio.subprocess module + * _UnixWritePipeTransport now also supports character devices, as + _UnixReadPipeTransport. Patch written by Jonathan Slenders. + * StreamReader.readexactly() now raises an IncompleteReadError if the + end of stream is reached before we received enough bytes, instead of + returning less bytes than requested. + * poll and epoll selectors now round the timeout away from zero (instead of + rounding towards zero) to fix a performance issue + * asyncio.queue: Empty renamed to QueueEmpty, Full to QueueFull + * _fatal_error() of _UnixWritePipeTransport and _ProactorBasePipeTransport + don't log BrokenPipeError nor ConnectionResetError + * Future.set_exception(exc) now instanciate exc if it is a class + * streams.StreamReader: Use bytearray instead of deque of bytes for internal + buffer + +- Fix test_wait_for() unit test + +2014-01-22: version 0.1.4 +========================= + +- The project moved to https://bitbucket.org/enovance/trollius +- Fix CoroWrapper (_DEBUG=True): add missing import +- Emit a warning when Return is not raised +- Merge with Tulip to get latest Tulip bugfixes +- Fix dependencies in tox.ini for the different Python versions + +2014-01-13: version 0.1.3 +========================= + +- Workaround bugs in the ssl module of Python older than 2.6.6. For example, + Mac OS 10.6 (Snow Leopard) uses Python 2.6.1. +- ``return x, y`` is now written ``raise Return(x, y)`` instead of + ``raise Return((x, y))`` +- Support "with (yield lock):" syntax for Lock, Condition and Semaphore +- SSL support is now optional: don't fail if the ssl module is missing +- Add tox.ini, tool to run unit tests. For example, "tox -e py27" creates a + virtual environment to run tests with Python 2.7. + +2014-01-08: version 0.1.2 +========================= + +- Trollius now supports CPython 2.6-3.4, PyPy and Windows. All unit tests + pass with CPython 2.7 on Linux. +- Fix Windows support. Fix compilation of the _overlapped module and add a + asyncio._winapi module (written in pure Python). Patch written by Marc + Schlaich. +- Support Python 2.6: require an extra dependency, + ordereddict (and unittest2 for unit tests) +- Support Python 3.2, 3.3 and 3.4 +- Support PyPy 2.2 +- Don't modify __builtins__ nor the ssl module to inject backported exceptions + like BlockingIOError or SSLWantReadError. Exceptions are available in the + asyncio module, ex: asyncio.BlockingIOError. + +2014-01-06: version 0.1.1 +========================= + +- Fix asyncio.time_monotonic on Mac OS X +- Fix create_connection(ssl=True) +- Don't export backported SSLContext in the ssl module anymore to not confuse + libraries testing hasattr(ssl, "SSLContext") +- Relax dependency on the backported concurrent.futures module: use a + synchronous executor if the module is missing + +2014-01-04: version 0.1 +======================= + +- First public release + diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 00000000..0bcd2da9 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +# +# Trollius documentation build configuration file, created by +# sphinx-quickstart on Fri Feb 21 11:05:42 2014. +# +# This file is execfile()d with the current directory set to its containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +#import sys, os +#sys.path.insert(0, os.path.abspath('.')) + +# -- General configuration ----------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +#needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be extensions +# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. +extensions = [] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['templates'] + +# The suffix of source filenames. +source_suffix = '.rst' + +# The encoding of source files. +#source_encoding = 'utf-8-sig' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = u'Trollius' +copyright = u'2014, Victor Stinner' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = release = '1.0.5' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +#language = None + +# There are two options for replacing |today|: either, you set today to some +# non-false value, then it is used: +#today = '' +# Else, today_fmt is used as the format for a strftime call. +#today_fmt = '%B %d, %Y' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +exclude_patterns = ['build'] + +# The reST default role (used for this markup: `text`) to use for all documents. +#default_role = None + +# If true, '()' will be appended to :func: etc. cross-reference text. +#add_function_parentheses = True + +# If true, the current module name will be prepended to all description +# unit titles (such as .. function::). +#add_module_names = True + +# If true, sectionauthor and moduleauthor directives will be shown in the +# output. They are ignored by default. +#show_authors = False + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# A list of ignored prefixes for module index sorting. +#modindex_common_prefix = [] + + +# -- Options for HTML output --------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +#html_theme_options = {} + +# Add any paths that contain custom themes here, relative to this directory. +#html_theme_path = [] + +# The name for this set of Sphinx documents. If None, it defaults to +# " v documentation". +#html_title = None + +# A shorter title for the navigation bar. Default is the same as html_title. +#html_short_title = None + +# The name of an image file (relative to this directory) to place at the top +# of the sidebar. +#html_logo = None + +# The name of an image file (within the static path) to use as favicon of the +# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 +# pixels large. +#html_favicon = None + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['static'] + +# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, +# using the given strftime format. +#html_last_updated_fmt = '%b %d, %Y' + +# If true, SmartyPants will be used to convert quotes and dashes to +# typographically correct entities. +#html_use_smartypants = True + +# Custom sidebar templates, maps document names to template names. +#html_sidebars = {} + +# Additional templates that should be rendered to pages, maps page names to +# template names. +#html_additional_pages = {} + +# If false, no module index is generated. +#html_domain_indices = True + +# If false, no index is generated. +#html_use_index = True + +# If true, the index is split into individual pages for each letter. +#html_split_index = False + +# If true, links to the reST sources are added to the pages. +#html_show_sourcelink = True + +# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. +#html_show_sphinx = True + +# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. +#html_show_copyright = True + +# If true, an OpenSearch description file will be output, and all pages will +# contain a tag referring to it. The value of this option must be the +# base URL from which the finished HTML is served. +#html_use_opensearch = '' + +# This is the file name suffix for HTML files (e.g. ".xhtml"). +#html_file_suffix = None + +# Output file base name for HTML help builder. +htmlhelp_basename = 'Trolliusdoc' + + +# -- Options for LaTeX output -------------------------------------------------- + +latex_elements = { +# The paper size ('letterpaper' or 'a4paper'). +#'papersize': 'letterpaper', + +# The font size ('10pt', '11pt' or '12pt'). +#'pointsize': '10pt', + +# Additional stuff for the LaTeX preamble. +#'preamble': '', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, author, documentclass [howto/manual]). +latex_documents = [ + ('index', 'Trollius.tex', u'Trollius Documentation', + u'Victor Stinner', 'manual'), +] + +# The name of an image file (relative to this directory) to place at the top of +# the title page. +#latex_logo = None + +# For "manual" documents, if this is true, then toplevel headings are parts, +# not chapters. +#latex_use_parts = False + +# If true, show page references after internal links. +#latex_show_pagerefs = False + +# If true, show URL addresses after external links. +#latex_show_urls = False + +# Documents to append as an appendix to all manuals. +#latex_appendices = [] + +# If false, no module index is generated. +#latex_domain_indices = True + + +# -- Options for manual page output -------------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + ('index', 'trollius', u'Trollius Documentation', + [u'Victor Stinner'], 1) +] + +# If true, show URL addresses after external links. +#man_show_urls = False + + +# -- Options for Texinfo output ------------------------------------------------ + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ('index', 'Trollius', u'Trollius Documentation', + u'Victor Stinner', 'Trollius', 'One line description of project.', + 'Miscellaneous'), +] + +# Documents to append as an appendix to all manuals. +#texinfo_appendices = [] + +# If false, no module index is generated. +#texinfo_domain_indices = True + +# How to display URL addresses: 'footnote', 'no', or 'inline'. +#texinfo_show_urls = 'footnote' diff --git a/doc/dev.rst b/doc/dev.rst new file mode 100644 index 00000000..5965306c --- /dev/null +++ b/doc/dev.rst @@ -0,0 +1,85 @@ +Run tests +========= + +Run tests with tox +------------------ + +The `tox project `_ can be used to build a +virtual environment with all runtime and test dependencies and run tests +against different Python versions (2.6, 2.7, 3.2, 3.3). + +For example, to run tests with Python 2.7, just type:: + + tox -e py27 + +To run tests against other Python versions: + +* ``py26``: Python 2.6 +* ``py27``: Python 2.7 +* ``py32``: Python 3.2 +* ``py33``: Python 3.3 + + +Test Dependencies +----------------- + +On Python older than 3.3, unit tests require the `mock +`_ module. Python 2.6 requires also +`unittest2 `_. + +To run ``run_aiotest.py``, you need the `aiotest +`_ test suite: ``pip install aiotest``. + + +Run tests on UNIX +----------------- + +Run the following commands from the directory of the Trollius project. + +To run tests:: + + make test + +To run coverage (``coverage`` package is required):: + + make coverage + + +Run tests on Windows +-------------------- + +Run the following commands from the directory of the Trollius project. + +You can run the tests as follows:: + + C:\Python27\python.exe runtests.py + +And coverage as follows:: + + C:\Python27\python.exe runtests.py --coverage + + +CPython bugs +============ + +The development of asyncio and trollius helped to identify different bugs in CPython: + +* 2.5.0 <= python <= 3.4.2: `sys.exc_info() bug when yield/yield-from is used + in an except block in a generator (#23353>) + `_. The fix will be part of Python 3.4.3. + _UnixSelectorEventLoop._make_subprocess_transport() and + ProactorEventLoop._make_subprocess_transport() work around the bug. +* python == 3.4.0: `Segfault in gc with cyclic trash (#21435) + `_. + Regression introduced in Python 3.4.0, fixed in Python 3.4.1. + Status in Ubuntu the February, 3th 2015: only Ubuntu Trusty (14.04 LTS) is + impacted (`bug #1367907: Segfault in gc with cyclic trash + `_, see + also `update Python3 for trusty #1348954 + `_) +* 3.3.0 <= python <= 3.4.0: `gen.send(tuple) unpacks the tuple instead of + passing 1 argument (the tuple) when gen is an object with a send() method, + not a classic generator (#21209) `_. + Regression introduced in Python 3.4.0, fixed in Python 3.4.1. + trollius.CoroWrapper.send() works around the issue, the bug is checked at + runtime once, when the module is imported. diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 00000000..86250930 --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,80 @@ +Trollius +======== + +.. image:: trollius.jpg + :alt: Trollius altaicus from Khangai Mountains (Mongòlia) + :align: right + :target: http://commons.wikimedia.org/wiki/File:Trollius_altaicus.jpg + +Trollius provides infrastructure for writing single-threaded concurrent +code using coroutines, multiplexing I/O access over sockets and other +resources, running network clients and servers, and other related primitives. +Here is a more detailed list of the package contents: + +* a pluggable event loop with various system-specific implementations; + +* transport and protocol abstractions (similar to those in `Twisted + `_); + +* concrete support for TCP, UDP, SSL, subprocess pipes, delayed calls, and + others (some may be system-dependent); + +* a ``Future`` class that mimics the one in the ``concurrent.futures`` module, + but adapted for use with the event loop; + +* coroutines and tasks based on generators (``yield``), to help write + concurrent code in a sequential fashion; + +* cancellation support for ``Future``\s and coroutines; + +* synchronization primitives for use between coroutines in a single thread, + mimicking those in the ``threading`` module; + +* an interface for passing work off to a threadpool, for times when you + absolutely, positively have to use a library that makes blocking I/O calls. + +Trollius is a portage of the `Tulip project `_ +(``asyncio`` module, `PEP 3156 `_) +on Python 2. Trollius works on Python 2.6-3.5. It has been tested on Windows, +Linux, Mac OS X, FreeBSD and OpenIndiana. + +* `Asyncio documentation `_ +* `Trollius documentation `_ (this document) +* `Trollius project in the Python Cheeseshop (PyPI) + `_ (download wheel packages and + tarballs) +* `Trollius project at Bitbucket `_ + (bug tracker, source code) +* Mailing list: `python-tulip Google Group + `_ +* IRC: ``#asyncio`` channel on the `Freenode network `_ +* Copyright/license: Open source, Apache 2.0. Enjoy! + +See also the `Tulip project `_ (asyncio module +for Python 3.3). + + +Table Of Contents +================= + +.. toctree:: + + using + install + asyncio + dev + changelog + + +Trollius name +============= + +Extract of `Trollius Wikipedia article +`_: + +Trollius is a genus of about 30 species of plants in the family Ranunculaceae, +closely related to Ranunculus. The common name of some species is globeflower +or globe flower. Native to the cool temperate regions of the Northern +Hemisphere, with the greatest diversity of species in Asia, trollius usually +grow in heavy, wet clay soils. + diff --git a/doc/install.rst b/doc/install.rst new file mode 100644 index 00000000..ea2b4557 --- /dev/null +++ b/doc/install.rst @@ -0,0 +1,111 @@ +++++++++++++++++ +Install Trollius +++++++++++++++++ + +Packages for Linux +================== + +* `Debian package + `_ +* `ArchLinux package + `_ +* `Fedora and CentOS package: python-trollius + `_ + + +Install Trollius on Windows using pip +===================================== + +Since Trollius 0.2, `precompiled wheel packages `_ +are now distributed on the Python Cheeseshop (PyPI). Procedure to install +Trollius on Windows: + +* `Install pip + `_, download + ``get-pip.py`` and type:: + + \Python27\python.exe get-pip.py + +* If you already have pip, ensure that you have at least pip 1.4. If you need + to upgrade:: + + \Python27\python.exe -m pip install -U pip + +* Install Trollius:: + + \Python27\python.exe -m pip install trollius + +* pip also installs the ``futures`` dependency + +.. note:: + + Only wheel packages for Python 2.7 are currently distributed on the + Cheeseshop (PyPI). If you need wheel packages for other Python versions, + please ask. + +Download source code +==================== + +Command to download the development version of the source code (``trollius`` +branch):: + + hg clone 'https://bitbucket.org/enovance/trollius#trollius' + +The actual code lives in the ``trollius`` subdirectory. Tests are in the +``tests`` subdirectory. + +See the `trollius project at Bitbucket +`_. + +The source code of the Trollius project is in the ``trollius`` branch of the +Mercurial repository, not in the default branch. The default branch is the +Tulip project, Trollius repository is a fork of the Tulip repository. + + +Dependencies +============ + +On Python older than 3.2, the `futures `_ +project is needed to get a backport of ``concurrent.futures``. + +Python 2.6 requires also `ordereddict +`_. + + +Build manually Trollius on Windows +================================== + +On Windows, if you cannot use precompiled wheel packages, an extension module +must be compiled: the ``_overlapped`` module (source code: ``overlapped.c``). +Read `Compile Python extensions on Windows +`_ +to prepare your environment to build the Python extension. Then build the +extension using:: + + C:\Python27\python.exe setup.py build_ext + + +Backports +========= + +To support Python 2.6-3.4, many Python modules of the standard library have +been backported: + +======================== ========= ======================= +Name Python Backport +======================== ========= ======================= +OSError 3.3 asyncio.py33_exceptions +_overlapped 3.4 asyncio._overlapped +_winapi 3.3 asyncio.py33_winapi +collections.OrderedDict 2.7, 3.1 ordereddict (PyPI) +concurrent.futures 3.2 futures (PyPI) +selectors 3.4 asyncio.selectors +ssl 3.2, 3.3 asyncio.py3_ssl +time.monotonic 3.3 asyncio.time_monotonic +unittest 2.7, 3.1 unittest2 (PyPI) +unittest.mock 3.3 mock (PyPI) +weakref.WeakSet 2.7, 3.0 asyncio.py27_weakrefset +======================== ========= ======================= + + + diff --git a/doc/make.bat b/doc/make.bat new file mode 100644 index 00000000..5789d413 --- /dev/null +++ b/doc/make.bat @@ -0,0 +1,190 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. linkcheck to check all external links for integrity + echo. doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\Trollius.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\Trollius.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. + goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +:end diff --git a/doc/trollius.jpg b/doc/trollius.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f4976c72b14a6039ddbcf871f0e1b5de4988392b GIT binary patch literal 30083 zcmb4~RahL+wyir@aCdiy;1Z;RyEN_ucXto&?jGDBxI=IncX#*TE|j zEFwH2A_6=D0unMhDiSgpG6DiBHYyqhCKeVJA_@*JHYP4QCKl#@9`Xh9^ByQ@IA~}% zOe6#(%>VED=mns`egS*|Kz>03K%jkrMEmm54MO(-0PJU% z;LHDX|F;JKKtg?ifCj*P{a67YLVk8pAyENemj%oFy@0;?e1s_ksjsQ%+SeL3!n*MDidK?>7uKQ5_ zXY#ECYQ#yL5Bscop(am9K z`zBfS+KfvyVR5hLDZ44`c=3v&MDXj;@KX?VO^lc$*oUCKS!R+>-)ZLlY@dn9eeVbA z@Mw-n-}*9FXk6e5p1BA>adY&mdM$WQ)9s_W*o_xV#8$8)6D1=ER3xQM>3KI>JQg8Z zeiN{p(u}5t!khj3t9aoxv~0eErue`{b|~Tjf9=L#9HDoR-*7V{=%n*@g<_xnr91sF zL&OKjGN6|K=!H`&{bvA6#+RnO21jxl`x~sCueA_|7{Zz2t{utuXQ8!F)8Cb0n0Urs zq8X9Nngl1q2l#2+#G{65t*A80GTegqz7nc4gN67e*i9mp?uRWoyX_Y-J4nTPOfc-# z{{r%hr|ejIo#ftsKdDb^uzj_0@W2U6As$Ut*E_-~c9Q)r%UWY;rD!^%mW(Kg`&F={ zr`m3dszumZ1eOAVm0rPZ_5oDEQ#eyau0k+B#nKx1-N{b+18`Ck_af}k6tnOvWM?6N zl>3lP1uS{Ks_Peq8hhfzUE%d)HDI8|vL^6vJo>~0sl?R5Z^Dzhf?t+9|A%>wy;)qb z-4R%JZ{9mu?RIbF`2&#mu6bHeMo~*fM$tI*XF9^fnN|7l0)UlXs_tR~(;z!wczcU` zPmVnkO&-Nfjqs4d{hl?f=l?^zuX>vUhV{)X4S*o2b zp=@$Vytu_DdTAmiNjI&LJggBW?`aF-DeRI3YdKkDkwwQ(Gh{0*Y=EIqB;c5+U`8_& z#5g?~JJR9u?`_W59Q;hShUWrD>}EDO?)R>mN?A4Ru!oeNDsC2@indxArYi#U%Llk{ z&mspzv|;-#bN9m}&DBHc&^FsEVX2ZQ8mL}#$_--;Do5wbQav{_znzQ!+1Os@I^0Zznq3u&?Jm_FSF z>GWq!>LJggw^A{WR3oND)c*`#JfR1v#!w_ga)%DD;@rtaSI& z638j^?LH9MJGWzPEl6^t=;SLmLCOB)8gy;p{bz5ng(}osh{*aX zz2=vBpH}!WzLu{s{!96H0=eU@u?ooQQ&aT}nq1I1x8jiO22ZF&=o{n!U~1yiRW7+iE2XWO>#_arznCJH8kFWTupmHTR+qiBqQd3xn>M?-B zi{L2zixvhf;G^J}+Fb$wM=aqSrlDtJ;ju@t2DncslFf`N}*iY=LyPD5hoI9E0 z%)b@1)-SlGkrBra#4Dijey9{Ud?Vp1$YDui7lKMKm`%8|`ZJ$?7}~P3k&+bqOHzk1 zxXj|4YB@Z1b7Ny17=Rd7|F0CvXWf>sjM%6_`L;_2|29V&t|0z2$>I!)& zV|TivwMJ2<#;IS!l$bzi`5k%6`H_y3YS12cK z2X@hpJM2BphG@F9uKR2et@O4%IxK0>(%EviXko%hxL+LaHifp77>mE3$XXT@Fcu${1^u^ znm13n=?Xshte+_)w^Q4OJl)u-s%5O8WDF`V&&sa4k|T$oE7nRyuy-&wqPDIo+^Y!rU#6>*N;yjvSkB81XJOYl(N!-gR;B)ix=x{-BM3Z4M9pZey@ox17}3cZivv( z0TQIp`rgC(#)AvGc>=p%ZQ4gpc;CUn5&Vxk4LfUd(i_`qY32sF;NNxQglb%i+zr zxIEdv>w2I>B2uYRFj>DE59O(dAaV!Y1i^gbNVPTKYQ?cz~0*N%Z9ff%InQBmTM?$m?V`k{07 zz;5h}QY1E=S`yTkzD&F?S`<}nHHb&Hn&FX3PA;0t(+hutMdQ#C-lqR4%qkR5q0hi} z=%vW$dv#r$w$Ysut^7KvUr_JBA+jzJdx*Ads?e;#IFwZ_189wl-+kp}EE=?6sisYl zv(A-pA<}8CG=C{3z9{Mb1Un{F0K{i~f%?z&$#MUsKqvq-8afOr1|}IfD;5RiS8O(R zQYvAgPt=3`=xn-&_^&iPuF%q*%1vw|Efh|DxA zUH{%?s8|IYyp&yh)AA{O4tu-7&F!!c72H5TGR|(F4O^|;TGTq~mNn;Fc-x9y>BdgE z1coNPRgI8dQ;iOQjZ7R&B~Gxf*Oz7*f=HU!t>|%64Sj$b|Z( zqwQM7#fTEFX~gIr2Ek-SBzat>wW5|xzGEEpvg!N4`)?;pr^Iy5;MFE#E!U*upsa5c zkBW80aXdzO=R0hpvIogvb8ch(KUK5WLic@l%ktv&~AG z`z524I2f^EKDy^h)IcpmF6^I{Doq;Q)-8W?>MP1Sf2Br4sdx`JnPp+bjoT~c^I07? zN7SG&xt?>H+Z1tAjr>X=PTYPT4m%IhIs5>WJhSZ>tjBh#6(G><)w*2Hj)}rwg*_}} zD=+i4`Bb02>d;x;31aCDd+5igmRGMCo=Mjdiy2@oIrfF$a$(TKuw}yF7O_AN0O{ZS zJQoqP?rN{fURwU79QKaVB^xaAjy!|2vV}*aTfF=RuW~U;>sa20tq6^*$%R(_S4A>d&sl&P_O-PT(w9agJFru4?^hb-FKeWQmYgM0e!e!D*U0rrqJVemL_1khf`Q+CWe$q9IKJNBv5V>eSw; zS6*;<3GE>Om;8?BX|`J)RRJ$@-uZIcEfC%unLJiI8vZo&&G)0)WrUZ0avX{8(SAwK zk%%IFfYd5t%RtgP@|{m%2kX`q==F`C_qf4l)tDC3*y+@5jany!d+a7}S=w#_fkHkA zD+UU96S8}ga|)U=g)~=Dhu4SSVmX(e&;`SDJtpxsNToJKzp*ZqUB(Hv{weKDwh4W< zG3(sKTsr;L)k&i&x;Cp}d=*)$u30g+qHA%bwYUGMN{z)iV2pWa%Vx&gJAf@^um3i5 zG4uuqwDt0VuE-fKRa+mLhJlkcZFnv$Hlv|P7>&SLxXxy=V4f$w;M&XOEt_vPKWR*? z0i~eb;QxC*;Jy7O?^KHW>Q|RT&7zGyN|K&k%5NcEnEB=eM8d$Mh|pi>*@~{kWu3tL zyQ?LeOaHiFv4EzGXG?*6Heq3f#iVkRpqzU>GJZcvgBk06Mp}XT)sc?yoKRE-&KglA z7HsvsD{~)%{lZwXted^TlU{J~E|73LRfjZvsQLGKkT#dqrRZ1aUC+R^b6mgp6f|*p zpZ@ZmpOdViR=fM5KlW?|XMC;V~b4zW{v(DT}F8cPj`sc_$O-|FWF2vSU2% zDx92`MRgi_uGW6#_Ra`gF8z~IJEM%mMi_|;w6a_7eN3+B9iUk=sSO)W4zNz;sj|Y5 z+x-{m;>tR}rZYgCljnG)a2iz=>;sYugeahsX`)uY(gY}cj zdDDwCj`ZK|Tb7g5MhlGAp`~U_^W0Ow!Kro8)*ZyqOnx8DUHyKLBL4 zJ_iZWvd*ky6g!O#VHL^lR$HwpOQoi};E#&Mjinz@t!zZFR<; zlJs}XEdV#Xr4cE#txm+P?p=5(22-A718ywXCgCp1zQHhdgW`(1xt>9)df*B_g;E9l zCpQWuwsEXDWHFBGB@yi{p_*NOH-htYQRCuMG1KRlSI^a*GUl0om(KFDvoLt`7Pk?3 zjeg&-O}+3qP2{D@PLA7RTkg%8F~!%SX{(@$sftKM<#&6rAS^}#iN+dyZ?2+)D|ocTvuLMeU~X; zOvz5Iu2Zp}*V3uG6VqV()7UGJZZvDn^z&EH*tV$Y#jayxT#ETI{)YVHaTb()A)yy1q(~mvZ ztWH3?G#%QLot~R=5$D^M!H*#OXZi4`UoREa2w0UfP-eIWGb&DT6N$B*-+t@BNf0-2 zb;H=TWQXfCoGON!9Q5~lS~f*3epj7ongWh&UaE`lJy6#`V<{eq ziECMI-6lEf&BVQ&v2jnUQ%`SWX?3~YJ(EhHmJ>XnJXra{N2Gg|pjolNo+TDSTKeQQ zoy0$w7>>PgcMh5k49~kzBTWV?Y!u{QZ>G3VdocC&+cCRkkN3vz+w9`^38)sUTGmvx z*3WsRJ5~sCOtMcnWRz2QPIM2I&kdt(B|)d$HZnfT8XNC<9zUcwPtJQl3#?SmW>giI zDmAOVCFD-Cv=JKkn03ONza8T7P%Uy20Nr>enYuXmy-P@q|EO+wYqc=%Rhm)qOvq5x zXBqFwI~}pEQh(nvT(T6<&sr+SUw?MB>gY}jJfJo@Dj;5hr6Mqmd>NFtI>I_ooo-~F zq+C*6(d1q1`*p&IdyE6U$awlBcWdIuYrYR;sxVe{QDe`o=DpqhzNMgS$p^bPVFj(m zZt;#}^EHLG-ssP~2Hr9?`CzXC*<~f4E#1|2rbY$Wdh9Sd7H-9ZlZfZ(YnImX)}I2J zhbsHxjWdgB261bu1@c(4{B#G-O~vhQ*oy`^K($jRGwmJiVs9~% zSXI7yo2Bd&391ei+x8&G@;* zkgl^bPe0*YZGUe|S*-rew4_yuAKmm~ioK*slH#RFwMgx8dl;TL`3cg@pI+)GLjPxd zI=ugNQNKVyqCugPv7$m_kPCq@Nh#P^46uaRMU))=OVXciEAY$U^tCn`1@2W;Y&{;P z#NSK`a@payf<#=NtEA8n9s<-8A||-TWxUM9%U3g6SEaSq>I?ky7AEfw&4ZPahvGE zz0ET^Y6xe1ksx*0k)^tA_y^!$>H?@@wn|PRV1-C-Al}H3eY?Q!yeDD7_?c`1tvF_6 zMjTEf?w^Tk&M{x^)I~ZT*il01eX7PL-UC=Ju<$7MjO8^dUNdVV!7#W@^YVbyM7vwJ znaCpcv&uUj&L*60^qK<7t7mF@`WA@kgg$wqPRy?@-GpmeKU<;>Ca|YNiku+kw-vt{ zQQ`_GN>ZnuCfv#u8hn@@cAILtRQS5B1o*&urssN^MUD za<%I?^k6lz3Fcj_pO62nck$y8(MHl~+C_|9xipp$NLu;;fT-Rl6 z$zsc+>ck9jiLvK(jau%3oubIk^o8e6zSVMc!4%mPQHGx!&Hz=-vx=WE z?M1EUb;<&bix?=^I*uEW!}syM=I_cxC%omnrA?BwId!;Y@+(xj6nSbLd1#%*OZeQF zV^!w7^!rfG#`(31`Ch27I756AK^|B)>bnkjUQYdKf8tMUMj!LXiRPu@8(PI;S^Jn` zi1$R|Wgu+g!OqDJYvZ;Bo2k!yB+;z`e9C8?l~Fq%fYOk9^il)-0UVj=Y<#G*8mDJ= zYjv2E2kZR{AFNS?fD`T$frE&r@I@mtpNuMg1G83i@TSv7n|O+n{kxE~k}cRRA=b8g z{{v7fTV!aH3btJl!P`na45CAxhfcBLbZp!aa0h2Sp7{|+GcD3w(Pd_N{2LH2j-6vF z4Wdv{A(e|OMoY|ip({QYSj`(Tl0(7iO`69YAP=%U zQtdhCXtOYK9`Qx(V~#oI7Y#B|S{pJhx)4@!3$Ih?B&+Juj@!ABlxAa_gPD&UiRrZkS2A^JtcFOJE_{T5hm`&A@0VgU(D@Ggj9nUEZa zuKWr=@PyWqphTRbG*1(Q)Ts7kyfNP+9(j1NyzE`TbG|0t8NYSE9+`l37At1I`T>Yb z_9{xNQ$67};Q)zNBs#qijf>@$f&!dGt7d^IysQkolGP=Nc95}KcRn0kUSWf~aDt+B zla<7&!ZL1AWCJ8vvsm44$VW*Z0LnT`Mwgm!o5)tQ+C?lrPX@37<$v2 z&0x2wp=?ZOy@>pKcRzHF*15S19ZUFF2a2#2ZyFw1^ncjys1wM^`2d7Tx{D-%I^W`} z1mn|YM@2a~H7t+)HefPrK@O&MSu0H$+RcVYXg5iZoE+NiF4PHHv2Vfg%GxrwzKf#t zqhJ+G@|b?JfIFTla7LBF)Fv+pL-8uBUj3v=W+ygjL!*V}TP-(t0NT*m?1_ms$8#_{ zD`V=GJ|5aBuuipDjQ)x0e3Yrys)QtlhpCHAJ0DcJ^k*7L*YRHfyPBUe$l0KeeokEw zCl>TzfHa>lw@#Ja=V%yv*k9s5>Oc(rZ53f(d#hw_&5Nw2dSal-yBPGX2BY0N=z;NZ zUaIjtzY_<~TFYJH=-%KFJEhj5Twgd!sh@|bnq4EFK16cmSLA*mzOCC9tzkMlz`~){U#^u&}G!9tJV)MT`=HqFi>( z^Lai1;ZDoCb=6}9GF8Fx{Y#DZk%MB+B(m8f)CF^{SLj-!nG@2T!bgsvs_MHw{bty$gr_%_6;qyBop13W zn6HQ%oV&D&h8-Fpq@gV;LdH)%P|1aZVK%xRH}T>UeDDNgugMTIwdqRE{Lq#|1l1d}u^#^J{jM7~D;4gJ@G$m~s~Mrr zowH1;r3vw%h}u(9(iF9o+R7;sxi-luQTe`CFrR09-x$zdQJ+@lIJ%opJ!{e60RCPO zBXw*9><>*fX-O`_qw-?4gY zLsyazNoMUPTGqe|Z~87vAeS#LP=G~@)38C2u<9lQP1ql~;&>A{N0RWwZb1FpV_GLE z3{1YRqO_&8*A>g|xHeQSdwoGZyiGPih0#{iKVwB<1{rroy0=W+`u2)nC5!1=tSVak zfb|MOAbnwtK`V351T%E|gNJ+cA-z!&i+>I^^5=Vt&eqhq3M|ogE;44=cy^w9!_x-* zU6DwJ+Y`Aj-Xtrt;TKCjI<<0cr+^(Wz(00>rcIPY)E<;#s8Q&!OroBce}%6 zNFj}V%kXxwhHE)pJN{n&!=kI_{mw5HV!0!#+6an+CAk9p#2F) zzXL6eAh_rA_zwV-W#!G%*^nVxtg#$Qk`H3V4dKkzdWu+nsqL9hc+>)8=Mb!jJ4f?K zq-y-L0H6BvipppXIc*xJXUnlcI)1iXMoGFLR9wL;emr|E?)?*m-!<8_^aibsMrD4tD;*^dw&(Hkw e>^JuoN%)pniv31B_&HeK*eyFHf=?eCz$X~J* zXFb<(b9G;)v!O@@hbGWUQ$=oH(MPUur)|B8(HRByFA828hYSx1gqPl`<4}n2L!>ta zN8!K8awPFl%Ij&ilCz?QgsZia7<&=3rqv_cCEB9?miie1ZYGNV^~jpyt=vm9^0)SU z+mQ;+*viaQCt^3n;&_zf&ZID+VkMGOO(d*f;q&M!*G26m!<@ux*ciN)7o3U#5#w$r z_vpxp$rxsN(s}Wwiv8J0%4epO=Eam6r~uK4v#$(5-`dF3sd4ra?jmcNR)7a#g@((k zI+IIoJza3M_c+v5%$83xRCkp1h5fP4d97n)uN4B=)}%PZHMsV%XksKd>Y^uH9w#xx zUz)6YN+M&~vBoE%9$8y_p^C*<6X%PlihhVFM^F`<0?&888-q{+9x64Xz^yX0j?5os&Nl# zFl4(Z;|Pum=B=dHbc0RN@Z1^3!ZaBuuElO9|JK;<2iWcRB#JlxxP2ou*GxQIg@F*TD)?p-@ z0C!2rpuC6c3W#`5Zo82(t$6zdF_`{W~nnH;tS^lI3BRH z3CIhh%1CY%f`EyfDepPi{YAX|sj62SF9iC+ITvqPp-N+0U0IP;4s*alE6tSZYI?9twV(2- ze%3ZIXV8q0JvSfxhNdZ?_O?Ll2bp^5D}y+1k!(i+dN%coZxGp&_mWy0S2Zt zZ(k_{ylD^4sd5!{ytyk{6Oeuel zTM)g}v?KIlpXazh-;jSk(lmK~TJCoz0QY;utC=At$^gm4QLYHJG#GQk(xuNGM+f+HHzoD!o86aeFy~iEdGt%V-(3o!_z=NYpjl(CCj5g=t=$zx@SasOA+=; zipR4(la1WVRP->#Y@@HYgG!IR6_Mh@_7t@&CzPmHkV&L~hLYD1A2J>gA3nUdyiGo# z0GCdGFPapKAKLFRF1S@^st4nw9j-#TGnOgr36~Iqt@}K%8)>4@yZE@tvkgB09O^Dk z$8oRmm2H^a@$$tL-wh87Bseyf(TYiWtld{4oUg)CHp1mZ9bMVd#A*DRoO=htDU^?r zv~E5p+~30(hZ4@2lDf0(Ev(mxx@P8Vas-1Os-dEVv))nZZiAyHM19zTZ#DA@5O$cT zvt7`X^=E{8DvG&=55pHs_uzmevK1y8XRiAma7AmYGqkh-21W|A-M++t|1vA(yt zX1r?`cAsTf{@(pi{S1feFz@0?@X}-}p`cXcyF>@NBJyc;+&R&UiAHXkW{~-hKeY z3M8&mNA$ssUYgj8&~%N%v*NoHdsYv%UNQw|u`l=8(mpw~1L;G0<4)kTt!4$egb6rL zPHdeLRzI6|~QLekLV*ACw}(8?Mu2hI(Mc>a*(-4gMGSiMoQQfG?1cUp@n>u+aaL3;4`QAcPa%ENv^eSdH5u z$%qMF@u;>9GYGonq?#1RMLcAJx{!vg9!!zyH0nzz3~piEc1&JT+JWVmZR~UMWF82S zr7=s{4l`GGlCP8(1_$r82qWoc)7zE5*%%q;KX0!JaGmI_<^E`)pC{S-@`jq)R3w-u z5rWCFjWdZ}23n@#nTyWJ?NlalEzR&!ofk;qO_?0D&lss>a%#A_J;|@2E|_kZbPK9x z3pY}Zg@O*L#?iSf!obI=SVV)o{&rU-&e$O@|xxIaQ9}Uhcn=9WVHg7kn5iaoLNy+Wuy>r#d zP6K<1jgz2teR^cSAjneCyB#Pf)1XHKI+_O!l8NU8hIo)t zdO|5Qxag*l$o|87?ZSJ^ITS*Gmp`)2R!&!ssBn^PFeaHcv8)ZQn_)zBNfqR6IlT10 z)O8!8=xEwnlUFAcnqG^zR`ffTW?`Quhz!ZWWIysO$#p0!|78Jc%6+;KK@a``!%up5 zk^|V`f61ssiSY0LM>X^TD5pK7{bXFyr*ajE`R15B!Lx9A8YQ#VcvvJi=No0|hUY&c zZKJk(MG{QQDkuH3&hK=GnG+Y#Qft>Jdy!%7U=*V->pJ#f)T*ZbAU-<(>QxGjmj*otgDBu2lzaOF;4>2SP!^vXaMUrs67VO6 zt(a!C9NFyG8@TAPSuq!jX#2UWmr(Jy4+SIwQd%spz9ho2kVR|DwGwpA+B60Y65+Jf zhHis4Rn=u94$hL#P1+AMw{$%$q`!*AUtW}{4CnOp5E`sNPEm=IUJUi8xF)W1mnc-# z2T$90urN{%f4^vb05TmL6&)j8xPRWBK7Bi&*IL1p?fPeaJp@*1%xlQrCO&!Q26*18 ze@Bq0L#?CChLh)Pva~Mc=cGgMx8(`Fh^&s8gu46HF53y8h4S00hV#T(^^yR%Mk}Tl>fj_n3K!_x0nn3?J3h19KLU?M0}R;NDhHO4l>5P zJ~!crDIF<%7C!M`V>*eX89OinyH}~4`E2jlh8^VLGder0d8Uga#+sKv{8Sj=Od3ic zAdc*Kugq8O}vHEpHB2QHCttC4(zZHG97Ux(vUwI z|6s{D(xl5hyFLpEr#U!$-@ys_cdZNI6+PZCFp`6J1eUZT;KNyfpAE4Z*LI zXg-6jOc>`ry$oLwikb_CnJxUP-mzVtXTZ!2(~#{wTOt#?eH!ARS6+_RvK@9|XNWYY1}^2DIRS@-&nACaDi#iD?T({9ozX>}phfvUGpu7w zAwn;sRNR{-oOnmH&G04DtmkRiRDi0nQ~j-V%bjU2$9c$I_EsJ~MHvNqiD@ zZ$Gm!8-`CXD(#zgEd(&Qe6srRA?+Y_9QEbu{=DTIj7B7z=2Y1O6G{|J8~wYkMj#wr zzV*7q+C<;4|H`X*7A*X$hM!#|DqGTCwhme)o?V6^7)2J^f2yf0QU;HShhEieH)P9x z036ej)yl5!)#BNyid^7U}-FG0H!l!)) zLT7lu*@_1esN{zO2}UBATsuTQe9VMPcTvSsWyg(dn0CPd_TEe2yC_^r)a$(Ph)?mt z8N~KBLe%h4QF~R(kf>nHgC+Z~tbcJrv&m%}Cz*Ws_nYy&y%bCo@j6oDHcCR~*ugXz z%a4VwvbY}rv<`e(uY%%8X7&kQ#;H8^9%?Z$;BH<){%l-aA)tjjT#A*SipzDWU6jb);gWL+7mCyf>jC5|M^Hi6Jc3ng) z_VX*p3m&aWfF7z>?CJ8 ztap5cbA0(%Xnc-AHDfR(m$Zm64C=g#^qf&}^$+H2B_$ z9XA8hF@MLS!QM5JCI5SB36LN`?g!@)x%A1!^x74QcVp#mv9P{0OD@Lf9Ye^8UQp!W7ku~jEm^bm&BOaz5+znAa^|8|!f(OG?;b>lgw2;o=f zkr>W*^^_~q-k5tG>8gZ5!pGEwv`0c{Xl96rZ$7IGKyf{>WfYWN8N>RjW}gJPW=`?} z5Q@#{N)}f!xX&PJGhvfl(`Q2T-ZE5#|O;(s+voGJlgLQFE`Qs zeJh}UUQFJK9Jkb-4?D@CR<0Z)6?%Omq+%UU5!hVyEp1k}MBBt~&u<->E!_~2{W16Z zbQOuO9U)lU`kKXFck=eBgmY|2<-Y!y=0`Rm{K1@)Lc;TDr{cM$VC2Bpaf97%_e~Qy zRhZf-?KokM`;i5{Wa&`-2u{3L;a^p!uM!R5N>uF-+vDoDSUhDqPumqYO+b%5oq?z9cIR8O?~wE{Y-7U9(^O9M|2~55-EQW zS2igs(G(FX*NN)GidJo+(7>&R1SaFHCR83L(6RgiURFCBFQ%WcLl=%u#e#BU&tCJ; zp+GAH(n-ZK5OW0E7}-+d7Wk>Ys8K)F z)_-H!pCOU|{&V*QfJVkD1Tt{&kN?zLsMD9*SN|Kw29h~r>Ng08l69f&z`6frP5M%8 zKwlf~*o+m!6)J3m=#refjT&(gM*V9f{9X$!T0DdF=P8eLjbaq_ORlg8DvAS=+PkUl zS;(L0xi-3i8Nz0dSxC`7;4!R^t)las7+h-~t*%V0k_B->K$Ios9B~`hU8Z^FNgUFQ z`W@ED!#x|atIm!W$^73RfK4aF!vLfB^wxRTWJo8Qz}-jrlJr#15ra&JsJSs^W)1Hs z6v|%yK4sXiJV`0;RIa2{0)Y$)+02Q^f;;K{2hQ7LpcS2_UbMrkV;#+Y6Fs_AQC7lG z)T1N`6AhHg8Pg8%yIU-pNY49@0w4>UF&Ux?17V~A{xxDNin?#?T92JE?fPR7>0eKDtCgmt554lD>I2tu^^y=V}k|$<ZyEsMvOfJ2!_#8`vbnAU6ja#L6Svjej< zQ2pl$fBwpaxlRIzO}IVzi}zy^0kL_MxGp7AAHha-mxp)CHN%|a0gN2k0iOooEY~g0 zo2Xc(Y1Dxe;VVT0mDWRSEndEiKxE;%E|IX9(JyoO&;7DirKfXP%?K@L^YA9ZeNEN|UFy-j zi|OLxyTuIi)hs|O4(U^^1KTGG-8VE3niP(jhZa3FE2j>f$-1wqzvcd7ztxQ_7;gPF zdHG$0;O1?pA*l2-<4H@t`ngLZzf}N5nNlLSzisfF2{VE_5HxgY6mWR)y#_tU1w+h^ z6DXTi>@Be*zVVyI7N&oH-*6J!I0w>nk+vg5m7`26b-Y8vJQ-eyxFE%7{a%EIl7#gS zG&K~gDGf@9lY(V%P2;27KVnnSoFD(@;BoCOR5Q1)<{Ubv;V;vhh5qTcd;pL=Z+93C zM?d4s^?>c-h3?S3m>pzo6TWwbPE#Fi)jgUEY&}LfNH(y?@ypM-qA3?%5ZfQkuS5N) zl8hW*o)BF03@jCYM~SWy1o+1tTqdk1rHkvy*gMSFLQdf{(0u@GZ4x2~ozjL9(HF(i zkfMD_F;aemzHY1dp{g9IS9O=US6@&yF*Ji_{iDHtuA~;kd@S$ADs&7uDK+P`%q7F= z6#c8uNm%NssRtly z0J8xFhqQf*x*x$0VZz;&MkZrY02j!#@JL8i}ViwqsDJ;_0a<~)gNH??hqx# zHtO@1WTGJD)3QV>%S!!9m4a3eBUG2NQI0e~nRpS^&GNM-pzL-iamTuhT<<~!neByq zTka6mO8w`WB?gcB*D+}m3@Mm(2V)%@8dbWIxeov?nUpF>oP@#e%%L8Y_WlE4XfLR4 z)(tWZDV2L2X`T;&JV2(9C)dP3xU|3e;-#$iQj>~sgJ`#Ga>1glqOzEH0z2rB$NE0j zZO`0cWZzEymfV>h6^db8tbmorLot#;I_8{bHFRK1>JE4j_2olv_pW=BndV`4+Lx_f zPiIFNy3hkusR&$)nt6JL^_I$1&PeiSt4FqVk&2JTI&^#|lYQ--EZl_LE5jY#Jw<9{ zq|20YMu9kUuPIJ2Ohwgt4y1)tm2ARiH(6(Fe3`Av z!Yb}EiPVrKT%J(Svb5!m75M-}O5~z<0}hl=!&`0WtR`m=!t89%WTK8;`djiO@MQ(; zYQnaMCTH}Tf|J}7YKO*PAZO@%C0!Hgz5KUBoJ`Skj{b>d?w+#y{lYejFL{rH-|RPn zcMKO(M1nH-0Jx3<#OE7VSy6^l!)zpgEI-yA%k)sbcc&es8n#ljN5Su}3EfG{tCnyc zf`YTy#H~}Sb$#&a1K&mPTd=r(;d^WxGI1;tK~78b7fEj%{GR}eG<3@xOZ?6T-01%2 zM6b&3yzlb`mbXoh)*$JlMwhl~D3w1|+%AFDDbh2@JW5NwjgAz+(Pxfzc*`z~uMQB3 z8IA6v4&`AUZ#*xt?OHm|Bn4Q~zbD>TL({@gWVW;pugXvd ze;cA8&iGjz+V*`Q2W^?+*C?sTv9DcDfFM=Hf2=U7zE&&da()u_s*VPOW1V%1!oq6Q zz@P3Rjw_MBh_}k$$|gs6L9@a8O^Jszy1o@(FA=qk@#`RTN>sCb+^P>Xwjq0d6xJf z>?$n5F%}0&VI2;e4mxG%t;3q;eHXEb@Vid2GmevXw^8z-DBf`U`v-Ez?MxK`voVDK z05NA%{_K|W4pM|}JSf^8U@>hk&8*VfZ&_IBuF>2Ww+F`K9#Vx9^og}Ne85@CR?BmIfZu=5xXlKSK;b( z1KhK@yA!$9kiJF(=c$Du@M(ZKi``VT z@sjeyM`GfsBwA5JVOXKoxrhYRRjcg|ptDGW;L#G=M=Sf5YY(g8+y$Ux0ezLkpk^y2 zgh9Z&KQf?$&0H!8*k?XaUK*m>CFW&SdEC55Npr^#*^p7>K49D%kLy<)DpI;JfC}aP z#{{RuD7cX1{;r~?9?qguhzF>P5w{@u?jGG1wQ|7US|W=2m4|Sw1Qj39|xZ<2YI)w_4?U`Co=ne^p(1Zoc>KLO;MO*&>3;qD2WY_fT&!Q`V8N7Rm{{ST} zvD25{f?;>aSi(0E6v8k!bNnLMHpW*Ki>b8Lr{OHgg_m~TdX`173;IlaSW)^v?6ZIs zkMu-R{*nFtpu;>$;xgD%zjt2adgp@f4HGfT-veZKRv-4|NYa2+^^@xiBDbOQ+|WxA zUUL#7ahSbGkgoP3fGdQkm6-SZadC;l-*2<0xkeII)c#MzC&a;1_%Gg|%rpmQ$1|VT zJ{k8+F6msxVdoa>qw^{Wn_aX&-9@wUrG1TkVYV}fjJbDMzY~OyAJ;j4=^5`bSKWpe zCa-V;%@l1NIV_Lr7P17q7Pz?}j2M}2DKm2_)g9^>01gTFd?kgu@e;{`2)dX>hORPH zrMwz^Cug29FYxxqr<2UdX8Sx!l=}#N)av7`_sq)OybI6wEQMYr)orKLaikaV69&j9 zNvW84w+Xb9?E=MSsK;@DXAn~GObl)$q`QLBE}~8U00We{rOdHR$+(UaC>|L?17P&H zN9at)jZl0mf9yml0H=$92FB{8(SyuXbLdM;UL}5xz0IG%Hya%iq*ezl!O#l0#NJDC zr++n29>6r0G9B^}Q0;@M=gcVikT{(pcA)z$#8xc%jcF^BDyDFwC>X+6g1C>wIjDNM zXKH9tH@S#JH5lZ3d+rQ{lp1}zo9*|$6qlQr9a_#JV*z);D0J{?HJWwlK4Luu+3|an zIi4DykNuWH-?bJ~*p483iU*psf6-V07G{~ zJejUJj~PX)wdPs{%Zq)_SSGC(n9LTf;S^x=hAc#tjHsYn9PZ^3xSEp{pOZBbnMClP zcr9H=;aG&<6nJQ0D5r>52VcgBz~?NzkbvSXh1{fb_Di0zk`%4R;$Og-f5lYX_tOIA zqXq)SfVj-!c(Q0*Stik5nOT4Kd)T zRW)a6mw~|6H2kH6S1tvddLS+Ckg{{T4yKI0h@>z-TOXW35=Czh4BWu(LY{s5Op2xNJbzqsYs zI&`#!tJFaaa*G^+zAw66&5EC%sBBk~0&2`dWi}mwpjaWm581K4TkK1{^_G$_y94Nt zp7abfSP+YwR1<*J9i%*Ea?)UjT{lho@~HY@eC|0)V`g3a0W!-w!1wOL=~uzKk&HFw z1~C{6fXf*tp38Z3iBQD<0Pv+tZHGu^9jC5SOISC{Y`mFKr-PE)Tw2jvC>p~X9N&*L zAZm60!~iD|0RaI40RRF40|WyB0RR910RRypF+oufVR7(*k)g35(cv(`|Jncu0RsU6 zKM+^Sz}sZTUoG`_@f6}w{09)DPQ$#NI7qSZ!7&WT_Cv-AY2Pl-Azhw(h@DgP9WDdH6T*!!5ZtvEH^l zzLqA0>Es~D9injZA&s9W8pZE5L+c-TpKK`)=`Q1b_8`{o)$w7$gnbi?BVH(n33te` z2KF1p-E0-ViSxF1VFM+?{fQizWzdCQNt`8@Q1E*K?JW0=o%=w1PbIWCdc8WeJYlv<}4LM}l#cvz3!V7&_ zYVr8JVL90>{J)0}kQf$S7s0o-4!3{xn7?8}^n%%gkF?;+Z7*ArCj>LHx2{N{G)eb1 z^|w2pwf2#m$yQFmvSsYc=O;^#2n;VB_7Q_wA)j&}`dGFWZNnVY0rG!epV|R-24F{u zLkv9I>TBs_Pj9Orj-0o5@FsRVDbjg9Y@y}yBz|OmCoq^K3|dFx;3j4|beF6dcI>dj zn<4-%D!`e~dP|w{ z3D;*@Z&DFP@5O^AXG>^m%BKPuEo^#;dS$=Nyg zSq3(y1y>-#XTbYE_x(T^Nj$Q!k=9Npdt|WS=?p{h21Et= zxeqc=!*h~$94rx~w~jw0*oiZCHYPkImlNsM-aSM=W&Z$+LR7sxeg1zrCg+L#CG7No zys>YS&94p+XK(2n{ntdEL=)mTW%{p`^4=3OWw3_+^5gp+e?PV0G2{pps}4TJx7CBQ zJ+vPL9gnPy5&M=Ax@X*k-JftqMnxu@avWYW&5c|m$p`hG`j$dpy_)2hE-igW?m+3GHTo zNZ);fRP*jw=O3?_d_mq?bK(~L*~ilR{Ql&_T@C)S$gBJZ&)fa%Q~Y;&XJOBMnabbW z@Pzrck0lQ?>d_R9_5L!#3!dV^rL&Ul!{*^5CykH}#nMVa z8LyZN<(0x@yAH&8I>Pr~-sL#~J2LRo^stKazlo(^cdVb*0RrzHy@ABvfLK@d4<%Nj zL8-g%^KVRYG3*PBnfNCl=_d~?$7~Fpl&fI50ML8b^EkF+5$MTzA&;_4f$R_d+go}t zbscH?T%<3}roSIBm~(F-0O=UP20;PF z4~WP!=^#d?Nh$63XW}675&UNNT#`e!vQgKT2sp=@^C2Simoohf>!>mMu{+qXO;W+oH!WYqS3uM2?NZ*0A1IBcuOF0?P9 zcrwGFUth7XVoAebTJgSe9;14YH(~tm>vj*HxBki(MrOT`=HKsX5`r0I#>lv(O}fmr z2TOkgjN(&gC8#b0tAR9oG6CSo6xpE)hUz~=YZ5*#eEr_q3}XmQ9>?4ZJ`3px20S7A ziN~obYV($~$}{+2!*&*oeVl0a!^!z!V{F+@x1VO?asDsnp1^+Sdl?=<2PMxNDES*7 z;A!x_fR{hmc`AktvscFBH7_Fk2N-eE%MT-gdmP|Q5kB@p0(E!o%O z`Tqb=c7M0O=Jg-d6{=7dk=kYH$?Eexw&lN58@I@yH`pFXx#fr-gYa{(MM#<3l!ImL zuaj;X*hcaI69XQy6PZS9|p zZ*VQZ`~9P_V`Q=IxY!!(;MOqY-%=nX=mogtoi}V$?Z&Ig^lzDP8^Mu=MEjjdy_oa< zCBL1rw&)Ap%QfQpLK9Xsw`3zfq?nx}*#W(Mo`=QNytZQGyEYR&<7>=m4f5z)+3)`V z2k@|*qPDR0cHonf3$@!Z%id{TUB))Ew%K8KKaF=uM`UK%J0$WC5c~fCn){!^dLyK{K5WNr{A6<33>%HxXM?+! zB3236S725V?WX0H<1XL-!~iD|0RRF50RsaC0RaI3000000RRypF+mVfVR4c0fuXU% z(IDaR|Jncu0RjO5KM*mZI#5w1V9VW+g%b4-gx*9q7-0osYBH(A0;OWgv`*5Qct{6i z3ki_HkK33gp@Rbow{fu-I1W=0NP>tqqcb7URzfXBYZnWKUb5&>=3~-!{lyd7zum*v zE+6d0t@M@xTTv>YjZ?RYXk)xBvJ)iYiPH4fXv+z1ieRC5l}ZCq_=j9eg9XPUF#sbK z!ExOmVe04peJ9)ziY2kK?kRPhnTg>NiMZPn0JE5e%7tz@9pIMe)DduoI-ufdvks*u z5SoizL6qRc6MZg86_Zv_vZM2^{W$*nd6#Tua4`mD4v@)Xh;o^8XwH+UV+F#rL1pZjh>)0#A!x|c>e$q2;Rh0N}-j~ z8bCBg#gP+G<{}yZ&IrOrRzGz0u+$SQL4qub;tWnM2%F$$-azuiTzN~Nz<&n)BJ%@K zRh#Fap)P{oZY4DU$l#3W#Kki*rV^XVTyi*qJH>GpXahYWpi4o-Kio$QzR|#qjA+N+ zuwb`G)PDrofw*B9pv+4mF*31IoX=!n)w4oiK)H64FtdqKI$u(z-H%`S4d?ZPV7~?b z06%h`ruftB(=r;v2A08PVsjN!Ns1B6se-0<*T4L!Qfzl$)k;zKpX^%jwL?%;7YGc( zq*#88?0PW6qOD~?a}cfs!5z9aOJJCQ7BG%xmf`(s8savhu#sIQMIUTXmkR(0jB+XC-BCWbaP6mieWNqbncUwjnHkM)3uc)Jn zrihm*tQ8Y7))vl3L^~dqwvdLz%Nsz|T?rgZoJ$s<;>uehW+P@2rI~dM9P4u7++MRU zpZWZUue|)IDz?s%jK`qDAB_FsCCsc_aP4qNiGF2~U`6O4`Jx(y0G*h%$8qn^{y-SR z1tF)$Fv$0b%z$byn8+dvq-b7cCH0t#1+qGVp(5e38XJG_KYM~tWDx5t2{-UvF8xB7E6^?0;qr%2$I|VYxNJa(hh7fS$XK= zicpV3sG-v}nhW@r0qE9viaYZb+qkS@qbgQRO-Z6!F!(bCLx;RTHJbfEjOG;N5J~|O z-ygU&7Bi8kGW1J4r7Vi7gvWe({{YDRZH4a>aa0jX=AbZ`!XRm3!IW||h;$#LF&pZ> zeAGFkm^02~f&IaUe16!5zkl=hH-YuIk=ScdIwG?sVG*l|j3OY?A7O)W#u~&e;$cj_ ziSm?7js0T5tHUVmFboVos4jl+?Ho}pAm$Zh1V32M=sy^0IhFyE1vvrnF3?()-9+vS zL&fSrs;|t&QrHrZ+fe>*n2FLh5%R>kHXiZbn%JUxClNh~wLzsOwE;1B7JqOGfOjd# zdqsu7lY?Hf6K6iW{OpmhQCzE5)O$DH78WwClFbBFf&jbBwk*0L$feJ`PkjQ*0HOxf z9^Nx4+~hSawTF$q(Ph|h-{;ru`9P?|E#Ni=Dw%$S27+HuofhKFFmVW8pNY8&GW(tW zd){U=uMNvdT2AKXmKMzN{NfDUqp~G~5d*C}2;N`$DCaS;H5{VMSenE(UBhauvKmKr zZdGrd(b3?D@Ii2i7D(EaUC}i`;b7W zIPueDp}D~hi5IPnK3ttIbYb+&9?fV$%}q*TOX_*ly%9eb~=x;9v~U2 zG@etK2T%^;gNa^|5*A_BhRA;=-7k*#q7aL@?irABGcFb3LY zO{k{??-f3f5lkaQuc<>}2e~~X(*FR;&csaI87vr^XWMaIJj4R6Ei$6Wvl$aS2t6?< z>6O`?jP4RGMZ;l}sOFG8zJ8&F7~xI8=amk+377>PDFzHyQm(uOohE=?@B?U#tE5b1Vhi3{`Xvi4JZE+%mQR;sMbA z05~yO&{PeqW*{-lyXA>oItg3^GFt)w;7jt4h+!y3f#VHGP9b$n7sOh)QZq?OnZT9) zn-?+$cttNHfn3B!T0nBZoIx&>HwOg8QnD5x6w17k@2$;~1w}Oj{zBzcd!}HO3??^* zm&`OmE+R^m1hNcSFk2G3M49%2)*{pm)U#57iGAe`AzA;#045Lt0RRF50s;a80|5a5 z000015da}EK~Z6GfsvuH!O`LH@eu#o00;pA00BP`krhb+n?Oz?1k0{4Paj!8n{PPf zD3X}LHl!QI)ab^J1QaA5vBZ(LQxq88KQf!*Vdh zvV9p~yqUb%#m1WTfh9*Q)+N=xyUS8Rn*@y($0g(Cmw7n`t|~XT2uB?)R;fH;pD;FW zH(eg^iJ%qh3WB@+n!*f}de|wOR)J#;+rR_wJ3eve0=ETgM_7^h!y31)UT{X+u;(A~ z1rBid97Zb*V0Qy2@C^bvg70et+VH=5QTjor-UX3%p*$Fm$a~IgwQ0izZ7H_|X~A~m z6_KDXFH_sbHWjQevE;uJvwh`FP;3CE!jKyQ2H?`|bxBm&swLHNAg*pobDOlyV{DY? z3Z|coQ~^xXCmdnPQ=o8-m!*uc6m`6|x1;AZ7l!bs=L*DoSE;b99&8mnunR!VYL!-tBz^D-kMJS`Dc6Weg>?41SxeoWKn#D^()XK4*SR7G$jTz%-0^u7!cmtc>Gx(y+)6fZynL;brk-xrm z)YdHGWpu>!h?@}|F^lHHji$dC2!-IL91kEI0El z6<@3edL%7+e!0y!mQR$|czy?w3Oa@hXSZ%4a3xb6Y%t{rpWY(1i|;6w-#8*}ed8ON zqqXy#qZfeX!Z)oF?<{F39AyedhN+1bMR{|Zye3P=+;(t+4Ll3WFWyBBTyyxtP?{Sx z_&PtF6@733pL>51S{LNr|Pv@*` zd0Q3wM8r`Y&Y#W^Nde%&3ZDgVjR(y~z_9s84PyYxXikH1p}Ed3+kpV|3yep1LzZ|SM~pKNJL3^aqR)g8`MV^IuCA zcFb#+mHJ{CpcY^f6@vsRw1*UWc)$-{j*ux}v-;9c1O5x^XUeU9C?&CJHe{2>vVwmMF_?%8T`OOf;cpo2le1w^BwuaYD zx-N7yK2<+G;)uI6WzIvZ=Qp!+h~=0dql;4cxQT~}s^mUeACBCMhK3}8=?|}Y4;;4x z(32|E?Fx_VezFK_hlJznI8Kk&E4bS|7^z6{i!rlc_cM)0Rd#9|D*$WG7aMe)IdERBtppE|g@kg4jI3IothnVt!RHdt zzJ*yR#wsY5GgvxSYKCbgmmo>4l&!a{1+aI|cwq49aN1JnC1fe7mG?*%Q0{b4%!F#UyQLD&LUR?tBJ>qF+W)-U5>1{+^(?{bT2OO3JA@? z`kA6k(s)m!eT;3MHix zrjfub1~);sUh+zh7YKxZ89~D&@xCUYXF#3r4#F)SF_p%kNbH&JVPlr8^^8vVPFlU%CP$i z{{RWBt7vd=Vb+7SMbM#RdLXL(au)w#nDI<3}*W&rkO`|CKUq99xai)+_rv}Xje~L>`#WY7;=wH6` zePTVC>Srm3yb7?9TMpNQ`Nl*78)(o0U#o^`Js~6ojN>cKepbb1?lV{$r?1U4Y&(hMJ z%gK#TYqA!chnHC`$7L!{4}XjEjDTt49AQ0g-T{x#1ZoTD+W^{{9#PtrKX#qncZ^Yy zVnM2Tzb#JOV~7uM)INB_y~C=8DGOLof*jhD7EOR6xT#tNcns94LoK3+C?jJuWr12z zFxT0EyuA5^RBasd{ISSK)qGdFN;xyX+7Nw%RcF*c|AvwX;6n{AlKjX&FsL4+7i(C7NaX1f4%FGcZ! zNyMQx=5wUkavNI$c^ z!cL?IY!uk@@b{E0G@$@Es+`ZRaik;=pelk?ipOl2w(3fw0fypn!%L$i#=&?38_};n zHOn9=;V1Kz{-dJ$`S*_Al@~_`-v0ntHM$_^G770js7_p*H4veAk@4s212Z91Ux;ab z1HEluv2Cog}C)(AbKq}d_U@uza@5urOW`!Z6304KbDTvutJ+f~zA#Bm_wQlc#) zdLpdWt&Xd>b_V_s>mhle6cmUfuy4$9B8y~;w!M`5;_N6hfG)TtKYlZ2j){duFPtk) z2Na0bGA0Ya`$UE$RruTPAM@V0$rgbC{5|_9y+y*TR8Vcg2% z6zvWgIVU?d&qX{83Z@#N6%T-%_tS~DYy@*=zZ-DXFzAG@(P8gi8ZjzmIr)Eh{;?E1 zfNF;wagThaRP};*KX~vu0P^DjK&G)N(>Kf$!HSC8AmIj+$j8o2G(aEw-Xk!@=U0EL zf-tsgjRAno>Ts^tpi6*N<2055S5vY8fICgDz^^5EA@)G|6A?Vk+(I}2OIoiXjS%*s z?)dYHWfA}Z#Msh`E(S`CgV0h^YEP(Oq7oqF2t5RKmPmBCUiqC}f1KJ)k#Y&4kD0y$ zDArVHYMPZHpd^r&DUXKjbyJX_px|L4mkSgI)I{(G-QY?Z0c2COAf}iA4D-WfCb&(4 zJ|^P^tu%y*=<-55WJ41JLK_4=+un0wXT(l8GekM3#5TNqW`t(#L)Bk-u(VNV0%;z;vp1>(hfn4^)IIoa2S;A8hiE{7oyhUYCu#a zIT!x`QMQ;Ppv4*8!{0Ys!7T*Z;3C;g7ir?*=YtW5OHM~be9SGS(L=I!x}Cp~ z$-h?lE?K}izJ0l2q#AE)vb?2Nadj}>v?M_Z^x!}On=jcLgammL3gRj|3pngLYVBxE z0W@+bV?oLpl3So>i@e(maTM@c6|tTUG5C^3fxdgh0QQA9)&C6QD)J=Q#VFwftnredud-Pgy2xeFOvC zN#wWD=NP6YhSUfh032KhFqgX&mi6T?B+Ja706xV(INs`8-+O%HLj7cPYtACG5^sYL z^5i>>TdsbpIeX7brPQ5s*WLot@*9Ir<>S0vPctZJ)Ejqii~|p(b&7J!gIXB?y{GDo zU3Dk)Wz81b2I`t1?_0*05oH?@z+e-5&TSM$A|ingDdBm=fUjDic~8IyrMLjR3NDr{ zUGcfzA*mtSl0;HIHu^9UVBF_uyB^iMn2AY?ho^Ukj~U7itX+uF_*`fsJXE0qu)m&W zD`Ynjo9M=D1Woj4{TX2%Kl3Kj*E5~K-dq({;~m(435*^ry2PTPzDHOqV1yr&+K1i| zCYtS0^BMJvROt3~oook(>5(MUgM=QaM$eyf5>Z>lhew%t>fcz1C}^(Fvf>8yh9}-e zH#lW{S(}XttO?-QWNW6b`ZqVeFl11w{Je%2f2d8d+StDIVRPd^C{^6_cdQT^LGtQO zmOlLB)GC6_t12q!;QT_{jjEQ>J#z6f2h>iIkr=9DNhT~nV^2X~)mdw7fCbqxD)Xbi zZCTFS+AQCXtTsaWPOC4+ zjA*ceI_xUeSbsQZA{j~zQA?9o$ihINGy52Cal-~Qr6xk_!jca*SK#VH6^GY3=1af||JKnl)_tRM)}q_zi~;U&RA-#lQrDI1am z0YgS}RyKkcD=7E|=kVPUt@$?l-arlwUZk#d;{pI4M^}N*;tm_ctcKWdUKRP?6flWn zIlAjKZb(l}Wqx}w+LH(%mTO$w`M^^I6c<6<*8SuQVj%pTICjzjrXc=X_kc1gzE^*o zR4YPLlYir!(`~3UbY%lvA&GOZ0Mp{a@-wf0>ln~;fuY2%MG4519+P;vK^w@*5gHE zSX-lpp7F{BXcI}&wOH~rLir_0A}L)bFTOX8(ilIc(FfA}n=uM#1ix@q9r*77a3lr# zPJWv9Vsw-g0W^GX^f!@dItyMO9Di^nxcw*x+Fw5J~t*GErz2;futI4k@1hT8<8h0;r8R&jtHDy@ER=yBdy3T z+pshz1$#e8if@cN{hWikl~>Lt08us9Sj&3t?-ClI20I{%*5bqL0->DL!DyN89*o)Y z%eCX)3dt7J07^ud90vo_Oe$6p&xpTA&mo_m69Y|mgP$a(1w~07Io?P`D?K;q8^zDwP-_K) z&@1-20@2Yw73J?MPyjvxA0)ub9{_Ka?0Idv$QXY5ulwn3jn6$wj9C z%RDYEaFOC??aF5uEh1cRD>n?>2=#zs=bhv~E&E|+e?6EEk?x){V53{8l?pWSIXD+3 zofD|;q;B#a2-O;;)$`|XuJIB?0tby6Gj{kZO#~|&*e5&? zg>0W0(p(6@OFfBFO@mOTq-_z}tx2wZWvOhN;0Ra=0YqfN?eP z{{V1S?Oo=*kp_$Az9_BvEuFJeR9g2P#%TIYJu_r#Wj01~V{{Wda z!MGuOfmxNN3LobcC#gYl3$tDvUJgyzkQC4Y9`UUCm0r^GRcqht5`-?N7^O-)PB68f zCoF3j8YjFjbKYNQa0`JR5+)he^mC41PdTA(@LGBEhjlRoAjA=UU`ku-D4QiGCTe^7 z3?`576xGvujMN6*P6>!i<&J*{NC07B#-d8aS`@+EkM=zW1_G5s<&Za#~II=t$a9cd*|JjEMj-CJj literal 0 HcmV?d00001 diff --git a/doc/using.rst b/doc/using.rst new file mode 100644 index 00000000..c730f86b --- /dev/null +++ b/doc/using.rst @@ -0,0 +1,85 @@ +++++++++++++++ +Using Trollius +++++++++++++++ + +Documentation of the asyncio module +=================================== + +The documentation of the asyncio is part of the Python project. It can be read +online: `asyncio - Asynchronous I/O, event loop, coroutines and tasks +`_. + +To adapt asyncio examples for Trollius, "just": + +* replace ``asyncio`` with ``trollius`` + (or use ``import trollius as asyncio``) +* replace ``yield from ...`` with ``yield From(...)`` +* replace ``yield from []`` with ``yield From(None)`` +* in coroutines, replace ``return res`` with ``raise Return(res)`` + + +Trollius Hello World +==================== + +Print ``Hello World`` every two seconds, using a coroutine:: + + import trollius + from trollius import From + + @trollius.coroutine + def greet_every_two_seconds(): + while True: + print('Hello World') + yield From(trollius.sleep(2)) + + loop = trollius.get_event_loop() + loop.run_until_complete(greet_every_two_seconds()) + + +Debug mode +========== + +To enable the debug mode: + +* Set ``TROLLIUSDEBUG`` envrironment variable to ``1`` +* Configure logging to log at level ``logging.DEBUG``, + ``logging.basicConfig(level=logging.DEBUG)`` for example + +The ``BaseEventLoop.set_debug()`` method can be used to set the debug mode on a +specific event loop. The environment variable enables also the debug mode for +coroutines. + +Effect of the debug mode: + +* On Python 2, :meth:`Future.set_exception` stores the traceback, so + ``loop.run_until_complete()`` raises the exception with the original + traceback. +* Log coroutines defined but never "yielded" +* BaseEventLoop.call_soon() and BaseEventLoop.call_at() methods raise an + exception if they are called from the wrong thread. +* Log the execution time of the selector +* Log callbacks taking more than 100 ms to be executed. The + BaseEventLoop.slow_callback_duration attribute is the minimum duration in + seconds of "slow" callbacks. +* Log most important subprocess events: + + - Log stdin, stdout and stderr transports and protocols + - Log process identifier (pid) + - Log connection of pipes + - Log process exit + - Log Process.communicate() tasks: feed stdin, read stdout and stderr + +* Log most important socket events: + + - Socket connected + - New client (socket.accept()) + - Connection reset or closed by peer (EOF) + - Log time elapsed in DNS resolution (getaddrinfo) + - Log pause/resume reading + - Log time of SSL handshake + - Log SSL handshake errors + +See `Debug mode of asyncio +`_ +for more information. + From 26a3183085595b7bc7210dfcb00cff444e37f6c1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 23:42:21 +0200 Subject: [PATCH 1396/1502] add test_asyncio.py --- tests/test_asyncio.py | 141 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 tests/test_asyncio.py diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py new file mode 100644 index 00000000..39d9e1aa --- /dev/null +++ b/tests/test_asyncio.py @@ -0,0 +1,141 @@ +from trollius import test_utils +from trollius import From, Return +import trollius +import trollius.coroutines +import unittest + +try: + import asyncio +except ImportError: + from trollius.test_utils import SkipTest + raise SkipTest('need asyncio') + + +@asyncio.coroutine +def asyncio_noop(value): + yield from [] + return (value,) + +@asyncio.coroutine +def asyncio_coroutine(coro, value): + res = yield from coro + return res + (value,) + +@trollius.coroutine +def trollius_noop(value): + yield From(None) + raise Return((value,)) + +@trollius.coroutine +def trollius_coroutine(coro, value): + res = yield trollius.From(coro) + raise trollius.Return(res + (value,)) + + +class AsyncioTests(test_utils.TestCase): + def setUp(self): + policy = trollius.get_event_loop_policy() + + asyncio.set_event_loop_policy(policy) + self.addCleanup(asyncio.set_event_loop_policy, None) + + self.loop = policy.new_event_loop() + self.addCleanup(self.loop.close) + policy.set_event_loop(self.loop) + + def test_policy(self): + self.assertIs(asyncio.get_event_loop(), self.loop) + + def test_asyncio(self): + coro = asyncio_noop("asyncio") + res = self.loop.run_until_complete(coro) + self.assertEqual(res, ("asyncio",)) + + def test_asyncio_in_trollius(self): + coro1 = asyncio_noop(1) + coro2 = asyncio_coroutine(coro1, 2) + res = self.loop.run_until_complete(trollius_coroutine(coro2, 3)) + self.assertEqual(res, (1, 2, 3)) + + def test_trollius_in_asyncio(self): + coro1 = trollius_noop(4) + coro2 = trollius_coroutine(coro1, 5) + res = self.loop.run_until_complete(asyncio_coroutine(coro2, 6)) + self.assertEqual(res, (4, 5, 6)) + + def test_step_future(self): + old_debug = trollius.coroutines._DEBUG + try: + def step_future(): + future = asyncio.Future() + self.loop.call_soon(future.set_result, "asyncio.Future") + return (yield from future) + + # test in release mode + trollius.coroutines._DEBUG = False + result = self.loop.run_until_complete(step_future()) + self.assertEqual(result, "asyncio.Future") + + # test in debug mode + trollius.coroutines._DEBUG = True + result = self.loop.run_until_complete(step_future()) + self.assertEqual(result, "asyncio.Future") + finally: + trollius.coroutines._DEBUG = old_debug + + def test_async(self): + fut = asyncio.Future() + self.assertIs(fut._loop, self.loop) + + fut2 = trollius.async(fut) + self.assertIs(fut2, fut) + self.assertIs(fut._loop, self.loop) + + def test_wrap_future(self): + fut = asyncio.Future() + self.assertIs(trollius.wrap_future(fut), fut) + + def test_run_until_complete(self): + fut = asyncio.Future() + fut.set_result("ok") + self.assertEqual(self.loop.run_until_complete(fut), + "ok") + + def test_coroutine_decorator(self): + @trollius.coroutine + def asyncio_future(fut): + return fut + + fut = asyncio.Future() + self.loop.call_soon(fut.set_result, 'ok') + res = self.loop.run_until_complete(asyncio_future(fut)) + self.assertEqual(res, "ok") + + def test_as_completed(self): + fut = asyncio.Future() + fut.set_result("ok") + + with self.assertRaises(TypeError): + for f in trollius.as_completed(fut): + pass + + @trollius.coroutine + def get_results(fut): + results = [] + for f in trollius.as_completed([fut]): + res = yield trollius.From(f) + results.append(res) + raise trollius.Return(results) + + results = self.loop.run_until_complete(get_results(fut)) + self.assertEqual(results, ["ok"]) + + def test_gather(self): + fut = asyncio.Future() + fut.set_result("ok") + results = self.loop.run_until_complete(trollius.gather(fut)) + self.assertEqual(results, ["ok"]) + + +if __name__ == '__main__': + unittest.main() From 603a58ed4bc354660ef20c4f6bf1f3b234453eef Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 23:50:08 +0200 Subject: [PATCH 1397/1502] Port asyncio to Python 2, examples/ directory --- examples/cacheclt.py | 50 +++++------ examples/cachesvr.py | 23 ++--- examples/child_process.py | 19 ++-- examples/crawl.py | 105 +++++++++++++---------- examples/echo_client_tulip.py | 5 +- examples/echo_server_tulip.py | 5 +- examples/fetch0.py | 9 +- examples/fetch1.py | 24 ++++-- examples/fetch2.py | 50 +++++++---- examples/fetch3.py | 79 ++++++++++------- examples/fuzz_as_completed.py | 11 ++- examples/hello_callback.py | 4 +- examples/hello_coroutine.py | 9 +- examples/shell.py | 23 ++--- examples/simple_tcp_server.py | 34 +++++--- examples/sink.py | 10 ++- examples/source.py | 7 +- examples/source1.py | 20 +++-- examples/stacks.py | 4 +- examples/subprocess_attach_read_pipe.py | 14 +-- examples/subprocess_attach_write_pipe.py | 25 +++--- examples/subprocess_shell.py | 16 ++-- examples/tcp_echo.py | 2 +- examples/timing_tcp_server.py | 39 +++++---- examples/udp_echo.py | 6 +- 25 files changed, 340 insertions(+), 253 deletions(-) diff --git a/examples/cacheclt.py b/examples/cacheclt.py index 3e9de31a..1f8ece4f 100644 --- a/examples/cacheclt.py +++ b/examples/cacheclt.py @@ -5,6 +5,7 @@ import argparse import trollius as asyncio +from trollius import From, Return from trollius import test_utils import json import logging @@ -62,24 +63,24 @@ def __init__(self, host, port, sslctx=None, loop=None): @asyncio.coroutine def get(self, key): - resp = yield from self.request('get', key) + resp = yield From(self.request('get', key)) if resp is None: - return None - return resp.get('value') + raise Return() + raise Return(resp.get('value')) @asyncio.coroutine def set(self, key, value): - resp = yield from self.request('set', key, value) + resp = yield From(self.request('set', key, value)) if resp is None: - return False - return resp.get('status') == 'ok' + raise Return(False) + raise Return(resp.get('status') == 'ok') @asyncio.coroutine def delete(self, key): - resp = yield from self.request('delete', key) + resp = yield From(self.request('delete', key)) if resp is None: - return False - return resp.get('status') == 'ok' + raise Return(False) + raise Return(resp.get('status') == 'ok') @asyncio.coroutine def request(self, type, key, value=None): @@ -91,24 +92,25 @@ def request(self, type, key, value=None): waiter = asyncio.Future(loop=self.loop) if self.initialized: try: - yield from self.send(payload, waiter) + yield From(self.send(payload, waiter)) except IOError: self.todo.add((payload, waiter)) else: self.todo.add((payload, waiter)) - return (yield from waiter) + result = (yield From(waiter)) + raise Return(result) @asyncio.coroutine def activity(self): backoff = 0 while True: try: - self.reader, self.writer = yield from asyncio.open_connection( - self.host, self.port, ssl=self.sslctx, loop=self.loop) + self.reader, self.writer = yield From(asyncio.open_connection( + self.host, self.port, ssl=self.sslctx, loop=self.loop)) except Exception as exc: backoff = min(args.max_backoff, backoff + (backoff//2) + 1) logging.info('Error connecting: %r; sleep %s', exc, backoff) - yield from asyncio.sleep(backoff, loop=self.loop) + yield From(asyncio.sleep(backoff, loop=self.loop)) continue backoff = 0 self.next_id = 0 @@ -118,9 +120,9 @@ def activity(self): while self.todo: payload, waiter = self.todo.pop() if not waiter.done(): - yield from self.send(payload, waiter) + yield From(self.send(payload, waiter)) while True: - resp_id, resp = yield from self.process() + resp_id, resp = yield From(self.process()) if resp_id in self.pending: payload, waiter = self.pending.pop(resp_id) if not waiter.done(): @@ -143,11 +145,11 @@ def send(self, payload, waiter): self.writer.write(frame.encode('ascii')) self.writer.write(payload) self.pending[req_id] = payload, waiter - yield from self.writer.drain() + yield From(self.writer.drain()) @asyncio.coroutine def process(self): - frame = yield from self.reader.readline() + frame = yield From(self.reader.readline()) if not frame: raise EOFError() head, tail = frame.split(None, 1) @@ -156,11 +158,11 @@ def process(self): if head != b'response': raise IOError('Bad frame: %r' % frame) resp_id, resp_size = map(int, tail.split()) - data = yield from self.reader.readexactly(resp_size) + data = yield From(self.reader.readexactly(resp_size)) if len(data) != resp_size: raise EOFError() resp = json.loads(data.decode('utf8')) - return resp_id, resp + raise Return(resp_id, resp) def main(): @@ -193,13 +195,13 @@ def w(g): while True: logging.info('%s %s', label, '-'*20) try: - ret = yield from w(cache.set(key, 'hello-%s-world' % label)) + ret = yield From(w(cache.set(key, 'hello-%s-world' % label))) logging.info('%s set %s', label, ret) - ret = yield from w(cache.get(key)) + ret = yield From(w(cache.get(key))) logging.info('%s get %s', label, ret) - ret = yield from w(cache.delete(key)) + ret = yield From(w(cache.delete(key))) logging.info('%s del %s', label, ret) - ret = yield from w(cache.get(key)) + ret = yield From(w(cache.get(key))) logging.info('%s get2 %s', label, ret) except asyncio.TimeoutError: logging.warn('%s Timeout', label) diff --git a/examples/cachesvr.py b/examples/cachesvr.py index 27ce6c30..20a54e4a 100644 --- a/examples/cachesvr.py +++ b/examples/cachesvr.py @@ -58,6 +58,7 @@ import argparse import trollius as asyncio +from trollius import From import json import logging import os @@ -104,7 +105,7 @@ def handle_client(self, reader, writer): peer = writer.get_extra_info('socket').getpeername() logging.info('got a connection from %s', peer) try: - yield from self.frame_parser(reader, writer) + yield From(self.frame_parser(reader, writer)) except Exception as exc: logging.error('error %r from %s', exc, peer) else: @@ -122,13 +123,13 @@ def frame_parser(self, reader, writer): # if the client doesn't send enough data but doesn't # disconnect either. We add a timeout to each. (But the # timeout should really be implemented by StreamReader.) - framing_b = yield from asyncio.wait_for( + framing_b = yield From(asyncio.wait_for( reader.readline(), - timeout=args.timeout, loop=self.loop) + timeout=args.timeout, loop=self.loop)) if random.random()*100 < args.fail_percent: logging.warn('Inserting random failure') - yield from asyncio.sleep(args.fail_sleep*random.random(), - loop=self.loop) + yield From(asyncio.sleep(args.fail_sleep*random.random(), + loop=self.loop)) writer.write(b'error random failure\r\n') break logging.debug('framing_b = %r', framing_b) @@ -151,9 +152,9 @@ def frame_parser(self, reader, writer): writer.write(b'error invalid frame parameters\r\n') break last_request_id = request_id - request_b = yield from asyncio.wait_for( + request_b = yield From(asyncio.wait_for( reader.readexactly(byte_count), - timeout=args.timeout, loop=self.loop) + timeout=args.timeout, loop=self.loop)) try: request = json.loads(request_b.decode('utf8')) except ValueError: @@ -165,10 +166,10 @@ def frame_parser(self, reader, writer): break response_b = json.dumps(response).encode('utf8') + b'\r\n' byte_count = len(response_b) - framing_s = 'response {} {}\r\n'.format(request_id, byte_count) + framing_s = 'response {0} {1}\r\n'.format(request_id, byte_count) writer.write(framing_s.encode('ascii')) - yield from asyncio.sleep(args.resp_sleep*random.random(), - loop=self.loop) + yield From(asyncio.sleep(args.resp_sleep*random.random(), + loop=self.loop)) writer.write(response_b) def handle_request(self, request): @@ -226,7 +227,7 @@ def main(): import ssl # TODO: take cert/key from args as well. here = os.path.join(os.path.dirname(__file__), '..', 'tests') - sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + sslctx = asyncio.SSLContext(ssl.PROTOCOL_SSLv23) sslctx.options |= ssl.OP_NO_SSLv2 sslctx.load_cert_chain( certfile=os.path.join(here, 'ssl_cert.pem'), diff --git a/examples/child_process.py b/examples/child_process.py index 915e358f..9e403a4d 100644 --- a/examples/child_process.py +++ b/examples/child_process.py @@ -15,6 +15,7 @@ # asyncio is not installed sys.path.append(os.path.join(os.path.dirname(__file__), '..')) import trollius as asyncio +from trollius import From, Return if sys.platform == 'win32': from trollius.windows_utils import Popen, PIPE @@ -29,8 +30,8 @@ @asyncio.coroutine def connect_write_pipe(file): loop = asyncio.get_event_loop() - transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, file) - return transport + transport, _ = yield From(loop.connect_write_pipe(asyncio.Protocol, file)) + raise Return(transport) # # Wrap a readable pipe in a stream @@ -42,8 +43,8 @@ def connect_read_pipe(file): stream_reader = asyncio.StreamReader(loop=loop) def factory(): return asyncio.StreamReaderProtocol(stream_reader) - transport, _ = yield from loop.connect_read_pipe(factory, file) - return stream_reader, transport + transport, _ = yield From(loop.connect_read_pipe(factory, file)) + raise Return(stream_reader, transport) # @@ -80,9 +81,9 @@ def writeall(fd, buf): p = Popen([sys.executable, '-c', code], stdin=PIPE, stdout=PIPE, stderr=PIPE) - stdin = yield from connect_write_pipe(p.stdin) - stdout, stdout_transport = yield from connect_read_pipe(p.stdout) - stderr, stderr_transport = yield from connect_read_pipe(p.stderr) + stdin = yield From(connect_write_pipe(p.stdin)) + stdout, stdout_transport = yield From(connect_read_pipe(p.stdout)) + stderr, stderr_transport = yield From(connect_read_pipe(p.stderr)) # interact with subprocess name = {stdout:'OUT', stderr:'ERR'} @@ -100,9 +101,9 @@ def writeall(fd, buf): # get and print lines from stdout, stderr timeout = None while registered: - done, pending = yield from asyncio.wait( + done, pending = yield From(asyncio.wait( registered, timeout=timeout, - return_when=asyncio.FIRST_COMPLETED) + return_when=asyncio.FIRST_COMPLETED)) if not done: break for f in done: diff --git a/examples/crawl.py b/examples/crawl.py index 0393d626..7f540593 100644 --- a/examples/crawl.py +++ b/examples/crawl.py @@ -1,7 +1,9 @@ -#!/usr/bin/env python3.4 +#!/usr/bin/env python """A simple web crawler.""" +from __future__ import print_function + # TODO: # - More organized logging (with task ID or URL?). # - Use logging module for Logger. @@ -16,14 +18,22 @@ import argparse import trollius as asyncio +from trollius import From, Return import asyncio.locks import cgi -from http.client import BadStatusLine import logging import re import sys import time -import urllib.parse +try: + from httplib import BadStatusLine + import urlparse + from urllib import splitport as urllib_splitport +except ImportError: + # Python 3 + from http.client import BadStatusLine + from urllib import parse as urlparse + from urllib.parse import splitport as urllib_splitport ARGS = argparse.ArgumentParser(description="Web crawler") @@ -96,7 +106,8 @@ def __init__(self, level): def _log(self, n, args): if self.level >= n: - print(*args, file=sys.stderr, flush=True) + print(*args, file=sys.stderr) + sys.stderr.flush() def log(self, n, *args): self._log(n, args) @@ -133,14 +144,14 @@ def close(self): for conn in conns: conn.close() self.connections.clear() - self.queue.clear() + del self.queue[:] @asyncio.coroutine def get_connection(self, host, port, ssl): """Create or reuse a connection.""" port = port or (443 if ssl else 80) try: - ipaddrs = yield from self.loop.getaddrinfo(host, port) + ipaddrs = yield From(self.loop.getaddrinfo(host, port)) except Exception as exc: self.log(0, 'Exception %r for (%r, %r)' % (exc, host, port)) raise @@ -148,7 +159,8 @@ def get_connection(self, host, port, ssl): (host, ', '.join(ip[4][0] for ip in ipaddrs))) # Look for a reusable connection. - for _, _, _, _, (h, p, *_) in ipaddrs: + for _, _, _, _, addr in ipaddrs: + h, p = addr[:2] key = h, p, ssl conn = None conns = self.connections.get(key) @@ -163,13 +175,13 @@ def get_connection(self, host, port, ssl): else: self.log(1, '* Reusing pooled connection', key, 'FD =', conn.fileno()) - return conn + raise Return(conn) # Create a new connection. conn = Connection(self.log, self, host, port, ssl) - yield from conn.connect() + yield From(conn.connect()) self.log(1, '* New connection', conn.key, 'FD =', conn.fileno()) - return conn + raise Return(conn) def recycle_connection(self, conn): """Make a connection available for reuse. @@ -258,8 +270,8 @@ def fileno(self): @asyncio.coroutine def connect(self): - self.reader, self.writer = yield from asyncio.open_connection( - self.host, self.port, ssl=self.ssl) + self.reader, self.writer = yield From(asyncio.open_connection( + self.host, self.port, ssl=self.ssl)) peername = self.writer.get_extra_info('peername') if peername: self.host, self.port = peername[:2] @@ -286,7 +298,7 @@ def __init__(self, log, url, pool): self.log = log self.url = url self.pool = pool - self.parts = urllib.parse.urlparse(self.url) + self.parts = urlparse.urlparse(self.url) self.scheme = self.parts.scheme assert self.scheme in ('http', 'https'), repr(url) self.ssl = self.parts.scheme == 'https' @@ -311,8 +323,8 @@ def connect(self): (self.hostname, self.port, 'ssl' if self.ssl else 'tcp', self.url)) - self.conn = yield from self.pool.get_connection(self.hostname, - self.port, self.ssl) + self.conn = yield From(self.pool.get_connection(self.hostname, + self.port, self.ssl)) def close(self, recycle=False): """Close the connection, recycle if requested.""" @@ -336,7 +348,7 @@ def send_request(self): """Send the request.""" request_line = '%s %s %s' % (self.method, self.full_path, self.http_version) - yield from self.putline(request_line) + yield From(self.putline(request_line)) # TODO: What if a header is already set? self.headers.append(('User-Agent', 'asyncio-example-crawl/0.0')) self.headers.append(('Host', self.netloc)) @@ -344,15 +356,15 @@ def send_request(self): ##self.headers.append(('Accept-Encoding', 'gzip')) for key, value in self.headers: line = '%s: %s' % (key, value) - yield from self.putline(line) - yield from self.putline('') + yield From(self.putline(line)) + yield From(self.putline('')) @asyncio.coroutine def get_response(self): """Receive the response.""" response = Response(self.log, self.conn.reader) - yield from response.read_headers() - return response + yield From(response.read_headers()) + raise Return(response) class Response: @@ -374,14 +386,15 @@ def __init__(self, log, reader): @asyncio.coroutine def getline(self): """Read one line from the connection.""" - line = (yield from self.reader.readline()).decode('latin-1').rstrip() + line = (yield From(self.reader.readline())) + line = line.decode('latin-1').rstrip() self.log(2, '<', line) - return line + raise Return(line) @asyncio.coroutine def read_headers(self): """Read the response status and the request headers.""" - status_line = yield from self.getline() + status_line = yield From(self.getline()) status_parts = status_line.split(None, 2) if len(status_parts) != 3: self.log(0, 'bad status_line', repr(status_line)) @@ -389,7 +402,7 @@ def read_headers(self): self.http_version, status, self.reason = status_parts self.status = int(status) while True: - header_line = yield from self.getline() + header_line = yield From(self.getline()) if not header_line: break # TODO: Continuation lines. @@ -426,7 +439,7 @@ def read(self): self.log(2, 'parsing chunked response') blocks = [] while True: - size_header = yield from self.reader.readline() + size_header = yield From(self.reader.readline()) if not size_header: self.log(0, 'premature end of chunked response') break @@ -435,10 +448,10 @@ def read(self): size = int(parts[0], 16) if size: self.log(3, 'reading chunk of', size, 'bytes') - block = yield from self.reader.readexactly(size) + block = yield From(self.reader.readexactly(size)) assert len(block) == size, (len(block), size) blocks.append(block) - crlf = yield from self.reader.readline() + crlf = yield From(self.reader.readline()) assert crlf == b'\r\n', repr(crlf) if not size: break @@ -447,12 +460,12 @@ def read(self): 'bytes in', len(blocks), 'blocks') else: self.log(3, 'reading until EOF') - body = yield from self.reader.read() + body = yield From(self.reader.read()) # TODO: Should make sure not to recycle the connection # in this case. else: - body = yield from self.reader.readexactly(nbytes) - return body + body = yield From(self.reader.readexactly(nbytes)) + raise Return(body) class Fetcher: @@ -504,10 +517,10 @@ def fetch(self): self.request = None try: self.request = Request(self.log, self.url, self.crawler.pool) - yield from self.request.connect() - yield from self.request.send_request() - self.response = yield from self.request.get_response() - self.body = yield from self.response.read() + yield From(self.request.connect()) + yield From(self.request.send_request()) + self.response = yield From(self.request.get_response()) + self.body = yield From(self.response.read()) h_conn = self.response.get_header('connection').lower() if h_conn != 'close': self.request.close(recycle=True) @@ -531,7 +544,7 @@ def fetch(self): return next_url = self.response.get_redirect_url() if next_url: - self.next_url = urllib.parse.urljoin(self.url, next_url) + self.next_url = urlparse.urljoin(self.url, next_url) if self.max_redirect > 0: self.log(1, 'redirect to', self.next_url, 'from', self.url) self.crawler.add_url(self.next_url, self.max_redirect-1) @@ -556,8 +569,8 @@ def fetch(self): self.new_urls = set() for url in self.urls: url = unescape(url) - url = urllib.parse.urljoin(self.url, url) - url, frag = urllib.parse.urldefrag(url) + url = urlparse.urljoin(self.url, url) + url, frag = urlparse.urldefrag(url) if self.crawler.add_url(url): self.new_urls.add(url) @@ -657,8 +670,8 @@ def __init__(self, log, self.pool = ConnectionPool(self.log, max_pool, max_tasks) self.root_domains = set() for root in roots: - parts = urllib.parse.urlparse(root) - host, port = urllib.parse.splitport(parts.netloc) + parts = urlparse.urlparse(root) + host, port = urllib_splitport(parts.netloc) if not host: continue if re.match(r'\A[\d\.]*\Z', host): @@ -731,11 +744,11 @@ def add_url(self, url, max_redirect=None): """Add a URL to the todo list if not seen before.""" if self.exclude and re.search(self.exclude, url): return False - parts = urllib.parse.urlparse(url) + parts = urlparse.urlparse(url) if parts.scheme not in ('http', 'https'): self.log(2, 'skipping non-http scheme in', url) return False - host, port = urllib.parse.splitport(parts.netloc) + host, port = urllib_splitport(parts.netloc) if not self.host_okay(host): self.log(2, 'skipping non-root host in', url) return False @@ -750,7 +763,7 @@ def add_url(self, url, max_redirect=None): @asyncio.coroutine def crawl(self): """Run the crawler until all finished.""" - with (yield from self.termination): + with (yield From(self.termination)): while self.todo or self.busy: if self.todo: url, max_redirect = self.todo.popitem() @@ -762,7 +775,7 @@ def crawl(self): self.busy[url] = fetcher fetcher.task = asyncio.Task(self.fetch(fetcher)) else: - yield from self.termination.wait() + yield From(self.termination.wait()) self.t1 = time.time() @asyncio.coroutine @@ -772,13 +785,13 @@ def fetch(self, fetcher): Once this returns, move the fetcher from busy to done. """ url = fetcher.url - with (yield from self.governor): + with (yield From(self.governor)): try: - yield from fetcher.fetch() # Fetcher gonna fetch. + yield From(fetcher.fetch()) # Fetcher gonna fetch. finally: # Force GC of the task, so the error is logged. fetcher.task = None - with (yield from self.termination): + with (yield From(self.termination)): self.done[url] = fetcher del self.busy[url] self.termination.notify() diff --git a/examples/echo_client_tulip.py b/examples/echo_client_tulip.py index eea8a58d..0a609260 100644 --- a/examples/echo_client_tulip.py +++ b/examples/echo_client_tulip.py @@ -1,15 +1,16 @@ import trollius as asyncio +from trollius import From END = b'Bye-bye!\n' @asyncio.coroutine def echo_client(): - reader, writer = yield from asyncio.open_connection('localhost', 8000) + reader, writer = yield From(asyncio.open_connection('localhost', 8000)) writer.write(b'Hello, world\n') writer.write(b'What a fine day it is.\n') writer.write(END) while True: - line = yield from reader.readline() + line = yield From(reader.readline()) print('received:', line) if line == END or not line: break diff --git a/examples/echo_server_tulip.py b/examples/echo_server_tulip.py index e1f9f2b8..d7e6e29d 100644 --- a/examples/echo_server_tulip.py +++ b/examples/echo_server_tulip.py @@ -1,13 +1,14 @@ import trollius as asyncio +from trollius import From @asyncio.coroutine def echo_server(): - yield from asyncio.start_server(handle_connection, 'localhost', 8000) + yield From(asyncio.start_server(handle_connection, 'localhost', 8000)) @asyncio.coroutine def handle_connection(reader, writer): while True: - data = yield from reader.read(8192) + data = yield From(reader.read(8192)) if not data: break writer.write(data) diff --git a/examples/fetch0.py b/examples/fetch0.py index 222a97b1..f98feeb3 100644 --- a/examples/fetch0.py +++ b/examples/fetch0.py @@ -1,5 +1,6 @@ """Simplest possible HTTP client.""" +from __future__ import print_function import sys from trollius import * @@ -7,19 +8,19 @@ @coroutine def fetch(): - r, w = yield from open_connection('python.org', 80) + r, w = yield From(open_connection('python.org', 80)) request = 'GET / HTTP/1.0\r\n\r\n' print('>', request, file=sys.stderr) w.write(request.encode('latin-1')) while True: - line = yield from r.readline() + line = yield From(r.readline()) line = line.decode('latin-1').rstrip() if not line: break print('<', line, file=sys.stderr) print(file=sys.stderr) - body = yield from r.read() - return body + body = yield From(r.read()) + raise Return(body) def main(): diff --git a/examples/fetch1.py b/examples/fetch1.py index 4e7037f2..9e9a1caf 100644 --- a/examples/fetch1.py +++ b/examples/fetch1.py @@ -3,8 +3,12 @@ This version adds URL parsing (including SSL) and a Response object. """ +from __future__ import print_function import sys -import urllib.parse +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse from trollius import * @@ -22,13 +26,15 @@ def __init__(self, verbose=True): def read(self, reader): @coroutine def getline(): - return (yield from reader.readline()).decode('latin-1').rstrip() - status_line = yield from getline() + line = (yield From(reader.readline())) + line = line.decode('latin-1').rstrip() + raise Return(line) + status_line = yield From(getline()) if self.verbose: print('<', status_line, file=sys.stderr) self.http_version, status, self.reason = status_line.split(None, 2) self.status = int(status) while True: - header_line = yield from getline() + header_line = yield From(getline()) if not header_line: break if self.verbose: print('<', header_line, file=sys.stderr) @@ -40,7 +46,7 @@ def getline(): @coroutine def fetch(url, verbose=True): - parts = urllib.parse.urlparse(url) + parts = urlparse(url) if parts.scheme == 'http': ssl = False elif parts.scheme == 'https': @@ -57,12 +63,12 @@ def fetch(url, verbose=True): request = 'GET %s HTTP/1.0\r\n\r\n' % path if verbose: print('>', request, file=sys.stderr, end='') - r, w = yield from open_connection(parts.hostname, port, ssl=ssl) + r, w = yield From(open_connection(parts.hostname, port, ssl=ssl)) w.write(request.encode('latin-1')) response = Response(verbose) - yield from response.read(r) - body = yield from r.read() - return body + yield From(response.read(r)) + body = yield From(r.read()) + raise Return(body) def main(): diff --git a/examples/fetch2.py b/examples/fetch2.py index de6a288d..5a321a8a 100644 --- a/examples/fetch2.py +++ b/examples/fetch2.py @@ -3,9 +3,15 @@ This version adds a Request object. """ +from __future__ import print_function import sys -import urllib.parse -from http.client import BadStatusLine +try: + from urllib.parse import urlparse + from http.client import BadStatusLine +except ImportError: + # Python 2 + from urlparse import urlparse + from httplib import BadStatusLine from trollius import * @@ -15,7 +21,7 @@ class Request: def __init__(self, url, verbose=True): self.url = url self.verbose = verbose - self.parts = urllib.parse.urlparse(self.url) + self.parts = urlparse(self.url) self.scheme = self.parts.scheme assert self.scheme in ('http', 'https'), repr(url) self.ssl = self.parts.scheme == 'https' @@ -40,9 +46,9 @@ def connect(self): print('* Connecting to %s:%s using %s' % (self.hostname, self.port, 'ssl' if self.ssl else 'tcp'), file=sys.stderr) - self.reader, self.writer = yield from open_connection(self.hostname, + self.reader, self.writer = yield From(open_connection(self.hostname, self.port, - ssl=self.ssl) + ssl=self.ssl)) if self.verbose: print('* Connected to %s' % (self.writer.get_extra_info('peername'),), @@ -67,8 +73,8 @@ def send_request(self): @coroutine def get_response(self): response = Response(self.reader, self.verbose) - yield from response.read_headers() - return response + yield From(response.read_headers()) + raise Return(response) class Response: @@ -83,11 +89,13 @@ def __init__(self, reader, verbose=True): @coroutine def getline(self): - return (yield from self.reader.readline()).decode('latin-1').rstrip() + line = (yield From(self.reader.readline())) + line = line.decode('latin-1').rstrip() + raise Return(line) @coroutine def read_headers(self): - status_line = yield from self.getline() + status_line = yield From(self.getline()) if self.verbose: print('<', status_line, file=sys.stderr) status_parts = status_line.split(None, 2) if len(status_parts) != 3: @@ -95,7 +103,7 @@ def read_headers(self): self.http_version, status, self.reason = status_parts self.status = int(status) while True: - header_line = yield from self.getline() + header_line = yield From(self.getline()) if not header_line: break if self.verbose: print('<', header_line, file=sys.stderr) @@ -112,20 +120,20 @@ def read(self): nbytes = int(value) break if nbytes is None: - body = yield from self.reader.read() + body = yield From(self.reader.read()) else: - body = yield from self.reader.readexactly(nbytes) - return body + body = yield From(self.reader.readexactly(nbytes)) + raise Return(body) @coroutine def fetch(url, verbose=True): request = Request(url, verbose) - yield from request.connect() - yield from request.send_request() - response = yield from request.get_response() - body = yield from response.read() - return body + yield From(request.connect()) + yield From(request.send_request()) + response = yield From(request.get_response()) + body = yield From(response.read()) + raise Return(body) def main(): @@ -134,7 +142,11 @@ def main(): body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) finally: loop.close() - sys.stdout.buffer.write(body) + if hasattr(sys.stdout, 'buffer'): + sys.stdout.buffer.write(body) + else: + # Python 2 + sys.stdout.write(body) if __name__ == '__main__': diff --git a/examples/fetch3.py b/examples/fetch3.py index fc113d40..0fc56d1d 100644 --- a/examples/fetch3.py +++ b/examples/fetch3.py @@ -4,9 +4,15 @@ chunked transfer-encoding. It also supports a --iocp flag. """ +from __future__ import print_function import sys -import urllib.parse -from http.client import BadStatusLine +try: + from urllib.parse import urlparse + from http.client import BadStatusLine +except ImportError: + # Python 2 + from urlparse import urlparse + from httplib import BadStatusLine from trollius import * @@ -25,12 +31,13 @@ def close(self): @coroutine def open_connection(self, host, port, ssl): port = port or (443 if ssl else 80) - ipaddrs = yield from get_event_loop().getaddrinfo(host, port) + ipaddrs = yield From(get_event_loop().getaddrinfo(host, port)) if self.verbose: print('* %s resolves to %s' % (host, ', '.join(ip[4][0] for ip in ipaddrs)), file=sys.stderr) - for _, _, _, _, (h, p, *_) in ipaddrs: + for _, _, _, _, addr in ipaddrs: + h, p = addr[:2] key = h, p, ssl conn = self.connections.get(key) if conn: @@ -40,14 +47,15 @@ def open_connection(self, host, port, ssl): continue if self.verbose: print('* Reusing pooled connection', key, file=sys.stderr) - return conn - reader, writer = yield from open_connection(host, port, ssl=ssl) - host, port, *_ = writer.get_extra_info('peername') + raise Return(conn) + reader, writer = yield From(open_connection(host, port, ssl=ssl)) + addr = writer.get_extra_info('peername') + host, port = addr[:2] key = host, port, ssl self.connections[key] = reader, writer if self.verbose: print('* New connection', key, file=sys.stderr) - return reader, writer + raise Return(reader, writer) class Request: @@ -55,7 +63,7 @@ class Request: def __init__(self, url, verbose=True): self.url = url self.verbose = verbose - self.parts = urllib.parse.urlparse(self.url) + self.parts = urlparse(self.url) self.scheme = self.parts.scheme assert self.scheme in ('http', 'https'), repr(url) self.ssl = self.parts.scheme == 'https' @@ -83,9 +91,9 @@ def connect(self, pool): self.vprint('* Connecting to %s:%s using %s' % (self.hostname, self.port, 'ssl' if self.ssl else 'tcp')) self.reader, self.writer = \ - yield from pool.open_connection(self.hostname, + yield From(pool.open_connection(self.hostname, self.port, - ssl=self.ssl) + ssl=self.ssl)) self.vprint('* Connected to %s' % (self.writer.get_extra_info('peername'),)) @@ -93,24 +101,24 @@ def connect(self, pool): def putline(self, line): self.vprint('>', line) self.writer.write(line.encode('latin-1') + b'\r\n') - ##yield from self.writer.drain() + ##yield From(self.writer.drain()) @coroutine def send_request(self): request = '%s %s %s' % (self.method, self.full_path, self.http_version) - yield from self.putline(request) + yield From(self.putline(request)) if 'host' not in {key.lower() for key, _ in self.headers}: self.headers.insert(0, ('Host', self.netloc)) for key, value in self.headers: line = '%s: %s' % (key, value) - yield from self.putline(line) - yield from self.putline('') + yield From(self.putline(line)) + yield From(self.putline('')) @coroutine def get_response(self): response = Response(self.reader, self.verbose) - yield from response.read_headers() - return response + yield From(response.read_headers()) + raise Return(response) class Response: @@ -129,20 +137,21 @@ def vprint(self, *args): @coroutine def getline(self): - line = (yield from self.reader.readline()).decode('latin-1').rstrip() + line = (yield From(self.reader.readline())) + line = line.decode('latin-1').rstrip() self.vprint('<', line) - return line + raise Return(line) @coroutine def read_headers(self): - status_line = yield from self.getline() + status_line = yield From(self.getline()) status_parts = status_line.split(None, 2) if len(status_parts) != 3: raise BadStatusLine(status_line) self.http_version, status, self.reason = status_parts self.status = int(status) while True: - header_line = yield from self.getline() + header_line = yield From(self.getline()) if not header_line: break # TODO: Continuation lines. @@ -173,23 +182,23 @@ def read(self): blocks = [] size = -1 while size: - size_header = yield from self.reader.readline() + size_header = yield From(self.reader.readline()) if not size_header: break parts = size_header.split(b';') size = int(parts[0], 16) if size: - block = yield from self.reader.readexactly(size) + block = yield From(self.reader.readexactly(size)) assert len(block) == size, (len(block), size) blocks.append(block) - crlf = yield from self.reader.readline() + crlf = yield From(self.reader.readline()) assert crlf == b'\r\n', repr(crlf) body = b''.join(blocks) else: - body = yield from self.reader.read() + body = yield From(self.reader.read()) else: - body = yield from self.reader.readexactly(nbytes) - return body + body = yield From(self.reader.readexactly(nbytes)) + raise Return(body) @coroutine @@ -198,16 +207,16 @@ def fetch(url, verbose=True, max_redirect=10): try: for _ in range(max_redirect): request = Request(url, verbose) - yield from request.connect(pool) - yield from request.send_request() - response = yield from request.get_response() - body = yield from response.read() + yield From(request.connect(pool)) + yield From(request.send_request()) + response = yield From(request.get_response()) + body = yield From(response.read()) next_url = response.get_redirect_url() if not next_url: break url = urllib.parse.urljoin(url, next_url) print('redirect to', url, file=sys.stderr) - return body + raise Return(body) finally: pool.close() @@ -223,7 +232,11 @@ def main(): body = loop.run_until_complete(fetch(sys.argv[1], '-v' in sys.argv)) finally: loop.close() - sys.stdout.buffer.write(body) + if hasattr(sys.stdout, 'buffer'): + sys.stdout.buffer.write(body) + else: + # Python 2 + sys.stdout.write(body) if __name__ == '__main__': diff --git a/examples/fuzz_as_completed.py b/examples/fuzz_as_completed.py index f6203e84..7e74fe78 100644 --- a/examples/fuzz_as_completed.py +++ b/examples/fuzz_as_completed.py @@ -2,26 +2,29 @@ """Fuzz tester for as_completed(), by Glenn Langford.""" +from __future__ import print_function + import trollius as asyncio +from trollius import From, Return import itertools import random import sys @asyncio.coroutine def sleeper(time): - yield from asyncio.sleep(time) - return time + yield From(asyncio.sleep(time)) + raise Return(time) @asyncio.coroutine def watcher(tasks,delay=False): res = [] for t in asyncio.as_completed(tasks): - r = yield from t + r = yield From(t) res.append(r) if delay: # simulate processing delay process_time = random.random() / 10 - yield from asyncio.sleep(process_time) + yield From(asyncio.sleep(process_time)) #print(res) #assert(sorted(res) == res) if sorted(res) != res: diff --git a/examples/hello_callback.py b/examples/hello_callback.py index 07205d9b..f192c8dc 100644 --- a/examples/hello_callback.py +++ b/examples/hello_callback.py @@ -1,6 +1,6 @@ """Print 'Hello World' every two seconds, using a callback.""" -import trollius as asyncio +import trollius def print_and_repeat(loop): @@ -9,7 +9,7 @@ def print_and_repeat(loop): if __name__ == '__main__': - loop = asyncio.get_event_loop() + loop = trollius.get_event_loop() print_and_repeat(loop) try: loop.run_forever() diff --git a/examples/hello_coroutine.py b/examples/hello_coroutine.py index de716dee..e6a4e6ca 100644 --- a/examples/hello_coroutine.py +++ b/examples/hello_coroutine.py @@ -1,17 +1,18 @@ """Print 'Hello World' every two seconds, using a coroutine.""" -import trollius as asyncio +import trollius +from trollius import From -@asyncio.coroutine +@trollius.coroutine def greet_every_two_seconds(): while True: print('Hello World') - yield from asyncio.sleep(2) + yield From(trollius.sleep(2)) if __name__ == '__main__': - loop = asyncio.get_event_loop() + loop = trollius.get_event_loop() try: loop.run_until_complete(greet_every_two_seconds()) finally: diff --git a/examples/shell.py b/examples/shell.py index 61991a75..91ba7fb1 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -1,31 +1,33 @@ """Examples using create_subprocess_exec() and create_subprocess_shell().""" import trollius as asyncio +from trollius import From import signal from trollius.subprocess import PIPE +from trollius.py33_exceptions import ProcessLookupError @asyncio.coroutine def cat(loop): - proc = yield from asyncio.create_subprocess_shell("cat", + proc = yield From(asyncio.create_subprocess_shell("cat", stdin=PIPE, - stdout=PIPE) + stdout=PIPE)) print("pid: %s" % proc.pid) message = "Hello World!" print("cat write: %r" % message) - stdout, stderr = yield from proc.communicate(message.encode('ascii')) + stdout, stderr = yield From(proc.communicate(message.encode('ascii'))) print("cat read: %r" % stdout.decode('ascii')) - exitcode = yield from proc.wait() + exitcode = yield From(proc.wait()) print("(exit code %s)" % exitcode) @asyncio.coroutine def ls(loop): - proc = yield from asyncio.create_subprocess_exec("ls", - stdout=PIPE) + proc = yield From(asyncio.create_subprocess_exec("ls", + stdout=PIPE)) while True: - line = yield from proc.stdout.readline() + line = yield From(proc.stdout.readline()) if not line: break print("ls>>", line.decode('ascii').rstrip()) @@ -35,10 +37,11 @@ def ls(loop): pass @asyncio.coroutine -def test_call(*args, timeout=None): - proc = yield from asyncio.create_subprocess_exec(*args) +def test_call(*args, **kw): + timeout = kw.pop('timeout', None) try: - exitcode = yield from asyncio.wait_for(proc.wait(), timeout) + proc = yield From(asyncio.create_subprocess_exec(*args)) + exitcode = yield From(asyncio.wait_for(proc.wait(), timeout)) print("%s: exit code %s" % (' '.join(args), exitcode)) except asyncio.TimeoutError: print("timeout! (%.1f sec)" % timeout) diff --git a/examples/simple_tcp_server.py b/examples/simple_tcp_server.py index 3e847b07..247f6e6c 100644 --- a/examples/simple_tcp_server.py +++ b/examples/simple_tcp_server.py @@ -8,9 +8,11 @@ fail if this port is currently in use. """ +from __future__ import print_function import sys import trollius as asyncio import asyncio.streams +from trollius import From, Return class MyServer: @@ -58,28 +60,31 @@ def _handle_client(self, client_reader, client_writer): out one or more lines back to the client with the result. """ while True: - data = (yield from client_reader.readline()).decode("utf-8") + data = (yield From(client_reader.readline())) + data = data.decode("utf-8") if not data: # an empty string means the client disconnected break - cmd, *args = data.rstrip().split(' ') + parts = data.rstrip().split(' ') + cmd = parts[0] + args = parts[1:] if cmd == 'add': arg1 = float(args[0]) arg2 = float(args[1]) retval = arg1 + arg2 - client_writer.write("{!r}\n".format(retval).encode("utf-8")) + client_writer.write("{0!r}\n".format(retval).encode("utf-8")) elif cmd == 'repeat': times = int(args[0]) msg = args[1] client_writer.write("begin\n".encode("utf-8")) for idx in range(times): - client_writer.write("{}. {}\n".format(idx+1, msg) + client_writer.write("{0}. {1}\n".format(idx+1, msg) .encode("utf-8")) client_writer.write("end\n".encode("utf-8")) else: - print("Bad command {!r}".format(data), file=sys.stderr) + print("Bad command {0!r}".format(data), file=sys.stderr) # This enables us to have flow control in our connection. - yield from client_writer.drain() + yield From(client_writer.drain()) def start(self, loop): """ @@ -115,32 +120,33 @@ def main(): @asyncio.coroutine def client(): - reader, writer = yield from asyncio.streams.open_connection( - '127.0.0.1', 12345, loop=loop) + reader, writer = yield From(asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop)) def send(msg): print("> " + msg) writer.write((msg + '\n').encode("utf-8")) def recv(): - msgback = (yield from reader.readline()).decode("utf-8").rstrip() + msgback = (yield From(reader.readline())) + msgback = msgback.decode("utf-8").rstrip() print("< " + msgback) - return msgback + raise Return(msgback) # send a line send("add 1 2") - msg = yield from recv() + msg = yield From(recv()) send("repeat 5 hello") - msg = yield from recv() + msg = yield From(recv()) assert msg == 'begin' while True: - msg = yield from recv() + msg = yield From(recv()) if msg == 'end': break writer.close() - yield from asyncio.sleep(0.5) + yield From(asyncio.sleep(0.5)) # creates a client and connects to our server try: diff --git a/examples/sink.py b/examples/sink.py index 8156b0ec..fb28adef 100644 --- a/examples/sink.py +++ b/examples/sink.py @@ -1,5 +1,6 @@ """Test service that accepts connections and reads all data off them.""" +from __future__ import print_function import argparse import os import sys @@ -63,16 +64,17 @@ def start(loop, host, port): import ssl # TODO: take cert/key from args as well. here = os.path.join(os.path.dirname(__file__), '..', 'tests') - sslctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslctx.options |= ssl.OP_NO_SSLv2 + sslctx = SSLContext(ssl.PROTOCOL_SSLv23) + if not BACKPORT_SSL_CONTEXT: + sslctx.options |= ssl.OP_NO_SSLv2 sslctx.load_cert_chain( certfile=os.path.join(here, 'ssl_cert.pem'), keyfile=os.path.join(here, 'ssl_key.pem')) - server = yield from loop.create_server(Service, host, port, ssl=sslctx) + server = yield From(loop.create_server(Service, host, port, ssl=sslctx)) dprint('serving TLS' if sslctx else 'serving', [s.getsockname() for s in server.sockets]) - yield from server.wait_closed() + yield From(server.wait_closed()) def main(): diff --git a/examples/source.py b/examples/source.py index 61c7ab79..c3ebd558 100644 --- a/examples/source.py +++ b/examples/source.py @@ -1,5 +1,6 @@ """Test client that connects and sends infinite data.""" +from __future__ import print_function import argparse import sys @@ -74,11 +75,11 @@ def start(loop, host, port): sslctx = None if args.tls: sslctx = test_utils.dummy_ssl_context() - tr, pr = yield from loop.create_connection(Client, host, port, - ssl=sslctx) + tr, pr = yield From(loop.create_connection(Client, host, port, + ssl=sslctx)) dprint('tr =', tr) dprint('pr =', pr) - yield from pr.waiter + yield From(pr.waiter) def main(): diff --git a/examples/source1.py b/examples/source1.py index 5af467a9..48a53af9 100644 --- a/examples/source1.py +++ b/examples/source1.py @@ -1,5 +1,6 @@ """Like source.py, but uses streams.""" +from __future__ import print_function import argparse import sys @@ -33,7 +34,7 @@ class Debug: overwriting = False label = 'stream1:' - def print(self, *args): + def print_(self, *args): if self.overwriting: print(file=sys.stderr) self.overwriting = 0 @@ -46,7 +47,8 @@ def oprint(self, *args): if self.overwriting == 3: print(self.label, '[...]', file=sys.stderr) end = '\r' - print(self.label, *args, file=sys.stderr, end=end, flush=True) + print(self.label, *args, file=sys.stderr, end=end) + sys.stdout.flush() @coroutine @@ -55,11 +57,11 @@ def start(loop, args): total = 0 sslctx = None if args.tls: - d.print('using dummy SSLContext') + d.print_('using dummy SSLContext') sslctx = test_utils.dummy_ssl_context() - r, w = yield from open_connection(args.host, args.port, ssl=sslctx) - d.print('r =', r) - d.print('w =', w) + r, w = yield From(open_connection(args.host, args.port, ssl=sslctx)) + d.print_('r =', r) + d.print_('w =', w) if args.stop: w.write(b'stop') w.close() @@ -73,10 +75,10 @@ def start(loop, args): w.write(data) f = w.drain() if f: - d.print('pausing') - yield from f + d.print_('pausing') + yield From(f) except (ConnectionResetError, BrokenPipeError) as exc: - d.print('caught', repr(exc)) + d.print_('caught', repr(exc)) def main(): diff --git a/examples/stacks.py b/examples/stacks.py index e03e78e6..abe24a0f 100644 --- a/examples/stacks.py +++ b/examples/stacks.py @@ -10,9 +10,9 @@ def helper(r): for t in Task.all_tasks(): t.print_stack() print('--- end helper ---') - line = yield from r.readline() + line = yield From(r.readline()) 1/0 - return line + raise Return(line) def doit(): l = get_event_loop() diff --git a/examples/subprocess_attach_read_pipe.py b/examples/subprocess_attach_read_pipe.py index 2cadc849..a2f9bb5d 100644 --- a/examples/subprocess_attach_read_pipe.py +++ b/examples/subprocess_attach_read_pipe.py @@ -2,6 +2,7 @@ """Example showing how to attach a read pipe to a subprocess.""" import trollius as asyncio import os, sys +from trollius import From code = """ import os, sys @@ -17,16 +18,19 @@ def task(): rfd, wfd = os.pipe() args = [sys.executable, '-c', code, str(wfd)] - pipe = open(rfd, 'rb', 0) + pipe = os.fdopen(rfd, 'rb', 0) reader = asyncio.StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.connect_read_pipe(lambda: protocol, pipe) + transport, _ = yield From(loop.connect_read_pipe(lambda: protocol, pipe)) - proc = yield from asyncio.create_subprocess_exec(*args, pass_fds={wfd}) - yield from proc.wait() + kwds = {} + if sys.version_info >= (3, 2): + kwds['pass_fds'] = (wfd,) + proc = yield From(asyncio.create_subprocess_exec(*args, **kwds)) + yield From(proc.wait()) os.close(wfd) - data = yield from reader.read() + data = yield From(reader.read()) print("read = %r" % data.decode()) loop.run_until_complete(task()) diff --git a/examples/subprocess_attach_write_pipe.py b/examples/subprocess_attach_write_pipe.py index 646b2691..8b9e7ec9 100644 --- a/examples/subprocess_attach_write_pipe.py +++ b/examples/subprocess_attach_write_pipe.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 """Example showing how to attach a write pipe to a subprocess.""" import trollius as asyncio +from trollius import From import os, sys from trollius import subprocess @@ -8,7 +9,11 @@ import os, sys fd = int(sys.argv[1]) data = os.read(fd, 1024) -sys.stdout.buffer.write(data) +if sys.version_info >= (3,): + stdout = sys.stdout.buffer +else: + stdout = sys.stdout +stdout.write(data) """ loop = asyncio.get_event_loop() @@ -17,19 +22,19 @@ def task(): rfd, wfd = os.pipe() args = [sys.executable, '-c', code, str(rfd)] - proc = yield from asyncio.create_subprocess_exec( - *args, - pass_fds={rfd}, - stdout=subprocess.PIPE) + kwargs = {'stdout': subprocess.PIPE} + if sys.version_info >= (3, 2): + kwargs['pass_fds'] = (rfd,) + proc = yield From(asyncio.create_subprocess_exec(*args, **kwargs)) - pipe = open(wfd, 'wb', 0) - transport, _ = yield from loop.connect_write_pipe(asyncio.Protocol, - pipe) + pipe = os.fdopen(wfd, 'wb', 0) + transport, _ = yield From(loop.connect_write_pipe(asyncio.Protocol, + pipe)) transport.write(b'data') - stdout, stderr = yield from proc.communicate() + stdout, stderr = yield From(proc.communicate()) print("stdout = %r" % stdout.decode()) - transport.close() + pipe.close() loop.run_until_complete(task()) loop.close() diff --git a/examples/subprocess_shell.py b/examples/subprocess_shell.py index 71d125ce..89412367 100644 --- a/examples/subprocess_shell.py +++ b/examples/subprocess_shell.py @@ -3,19 +3,21 @@ import trollius as asyncio import os +from trollius import From from trollius.subprocess import PIPE +from trollius.py33_exceptions import BrokenPipeError, ConnectionResetError @asyncio.coroutine def send_input(writer, input): try: for line in input: - print('sending', len(line), 'bytes') + print('sending %s bytes' % len(line)) writer.write(line) d = writer.drain() if d: print('pause writing') - yield from d + yield From(d) print('resume writing') writer.close() except BrokenPipeError: @@ -26,7 +28,7 @@ def send_input(writer, input): @asyncio.coroutine def log_errors(reader): while True: - line = yield from reader.readline() + line = yield From(reader.readline()) if not line: break print('ERROR', repr(line)) @@ -34,7 +36,7 @@ def log_errors(reader): @asyncio.coroutine def read_stdout(stdout): while True: - line = yield from stdout.readline() + line = yield From(stdout.readline()) print('received', repr(line)) if not line: break @@ -47,7 +49,7 @@ def start(cmd, input=None, **kwds): kwds['stdin'] = None else: kwds['stdin'] = PIPE - proc = yield from asyncio.create_subprocess_shell(cmd, **kwds) + proc = yield From(asyncio.create_subprocess_shell(cmd, **kwds)) tasks = [] if input is not None: @@ -66,9 +68,9 @@ def start(cmd, input=None, **kwds): if tasks: # feed stdin while consuming stdout to avoid hang # when stdin pipe is full - yield from asyncio.wait(tasks) + yield From(asyncio.wait(tasks)) - exitcode = yield from proc.wait() + exitcode = yield From(proc.wait()) print("exit code: %s" % exitcode) diff --git a/examples/tcp_echo.py b/examples/tcp_echo.py index 1a0a2c61..773327f7 100755 --- a/examples/tcp_echo.py +++ b/examples/tcp_echo.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python """TCP echo server example.""" import argparse import trollius as asyncio diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index 664ba63a..c93c407d 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -8,12 +8,14 @@ fail if this port is currently in use. """ +from __future__ import print_function import sys import time import random import trollius as asyncio import asyncio.streams +from trollius import From, Return class MyServer: @@ -61,29 +63,32 @@ def _handle_client(self, client_reader, client_writer): out one or more lines back to the client with the result. """ while True: - data = (yield from client_reader.readline()).decode("utf-8") + data = (yield From(client_reader.readline())) + data = data.decode("utf-8") if not data: # an empty string means the client disconnected break - cmd, *args = data.rstrip().split(' ') + parts = data.rstrip().split(' ') + cmd = parts[0] + args = parts[1:] if cmd == 'add': arg1 = float(args[0]) arg2 = float(args[1]) retval = arg1 + arg2 - client_writer.write("{!r}\n".format(retval).encode("utf-8")) + client_writer.write("{0!r}\n".format(retval).encode("utf-8")) elif cmd == 'repeat': times = int(args[0]) msg = args[1] client_writer.write("begin\n".encode("utf-8")) for idx in range(times): - client_writer.write("{}. {}\n".format( + client_writer.write("{0}. {1}\n".format( idx+1, msg + 'x'*random.randint(10, 50)) .encode("utf-8")) client_writer.write("end\n".encode("utf-8")) else: - print("Bad command {!r}".format(data), file=sys.stderr) + print("Bad command {0!r}".format(data), file=sys.stderr) # This enables us to have flow control in our connection. - yield from client_writer.drain() + yield From(client_writer.drain()) def start(self, loop): """ @@ -119,42 +124,44 @@ def main(): @asyncio.coroutine def client(): - reader, writer = yield from asyncio.streams.open_connection( - '127.0.0.1', 12345, loop=loop) + reader, writer = yield From(asyncio.streams.open_connection( + '127.0.0.1', 12345, loop=loop)) def send(msg): print("> " + msg) writer.write((msg + '\n').encode("utf-8")) def recv(): - msgback = (yield from reader.readline()).decode("utf-8").rstrip() + msgback = (yield From(reader.readline())) + msgback = msgback.decode("utf-8").rstrip() print("< " + msgback) - return msgback + raise Return(msgback) # send a line send("add 1 2") - msg = yield from recv() + msg = yield From(recv()) Ns = list(range(100, 100000, 10000)) times = [] for N in Ns: t0 = time.time() - send("repeat {} hello world ".format(N)) - msg = yield from recv() + send("repeat {0} hello world ".format(N)) + msg = yield From(recv()) assert msg == 'begin' while True: - msg = (yield from reader.readline()).decode("utf-8").rstrip() + msg = (yield From(reader.readline())) + msg = msg.decode("utf-8").rstrip() if msg == 'end': break t1 = time.time() dt = t1 - t0 - print("Time taken: {:.3f} seconds ({:.6f} per repetition)" + print("Time taken: {0:.3f} seconds ({1:.6f} per repetition)" .format(dt, dt/N)) times.append(dt) writer.close() - yield from asyncio.sleep(0.5) + yield From(asyncio.sleep(0.5)) # creates a client and connects to our server try: diff --git a/examples/udp_echo.py b/examples/udp_echo.py index b13303ff..bd646396 100755 --- a/examples/udp_echo.py +++ b/examples/udp_echo.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python """UDP echo example.""" import argparse import sys @@ -32,12 +32,12 @@ class MyClientUdpEchoProtocol: def connection_made(self, transport): self.transport = transport - print('sending "{}"'.format(self.message)) + print('sending "{0}"'.format(self.message)) self.transport.sendto(self.message.encode()) print('waiting to receive') def datagram_received(self, data, addr): - print('received "{}"'.format(data.decode())) + print('received "{0}"'.format(data.decode())) self.transport.close() def error_received(self, exc): From 22df34e7ebd65838c1cabdec1280edbdb5dda2ac Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Tue, 7 Jul 2015 23:50:26 +0200 Subject: [PATCH 1398/1502] Add interop_asyncio.py example --- examples/interop_asyncio.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 examples/interop_asyncio.py diff --git a/examples/interop_asyncio.py b/examples/interop_asyncio.py new file mode 100644 index 00000000..b20e3edb --- /dev/null +++ b/examples/interop_asyncio.py @@ -0,0 +1,53 @@ +import asyncio +import trollius + +@asyncio.coroutine +def asyncio_noop(): + pass + +@asyncio.coroutine +def asyncio_coroutine(coro): + print("asyncio coroutine") + res = yield from coro + print("asyncio inner coroutine result: %r" % (res,)) + print("asyncio coroutine done") + return "asyncio" + +@trollius.coroutine +def trollius_noop(): + pass + +@trollius.coroutine +def trollius_coroutine(coro): + print("trollius coroutine") + res = yield trollius.From(coro) + print("trollius inner coroutine result: %r" % (res,)) + print("trollius coroutine done") + raise trollius.Return("trollius") + +def main(): + # use trollius event loop policy in asyncio + policy = trollius.get_event_loop_policy() + asyncio.set_event_loop_policy(policy) + + # create an event loop for the main thread: use Trollius event loop + loop = trollius.get_event_loop() + assert asyncio.get_event_loop() is loop + + print("[ asyncio coroutine called from trollius coroutine ]") + coro1 = asyncio_noop() + coro2 = asyncio_coroutine(coro1) + res = loop.run_until_complete(trollius_coroutine(coro2)) + print("trollius coroutine result: %r" % res) + print("") + + print("[ asyncio coroutine called from trollius coroutine ]") + coro1 = trollius_noop() + coro2 = trollius_coroutine(coro1) + res = loop.run_until_complete(asyncio_coroutine(coro2)) + print("asyncio coroutine result: %r" % res) + print("") + + loop.close() + +main() From 962dea8ccf212b1f095755521f19fc688cecdf8b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:00:15 +0200 Subject: [PATCH 1399/1502] copy files from trollius --- AUTHORS | 40 ++++++++++++--------------------- MANIFEST.in | 16 +++++++------ Makefile | 4 +++- setup.py | 65 ++++++++++++++++++++++++++++++++++++----------------- tox.ini | 44 +++++++++++++++++++++++++++++------- 5 files changed, 106 insertions(+), 63 deletions(-) diff --git a/AUTHORS b/AUTHORS index d25b4465..c625633a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,26 +1,14 @@ -A. Jesse Jiryu Davis -Aaron Griffith -Andrew Svetlov -Anthony Baire -Antoine Pitrou -Arnaud Faure -Aymeric Augustin -Brett Cannon -Charles-François Natali -Christian Heimes -Donald Stufft -Eli Bendersky -Geert Jansen -Giampaolo Rodola' -Guido van Rossum : creator of the Tulip project and author of the PEP 3156 -Gustavo Carneiro -Jeff Quast -Jonathan Slenders -Nikolay Kim -Richard Oudkerk -Saúl Ibarra Corretgé -Serhiy Storchaka -Vajrasky Kok -Victor Stinner -Vladimir Kryachko -Yury Selivanov +Trollius authors +================ + +Ian Wienand +Marc Schlaich +Victor Stinner - creator of the Trollius project + +The photo of Trollis flower was taken by Imartin6 and distributed under the CC +BY-SA 3.0 license. It comes from: +http://commons.wikimedia.org/wiki/File:Trollius_altaicus.jpg + +Trollius is a port of the Tulip project on Python 2, see also authors of the +Tulip project (AUTHORS file of the Tulip project). + diff --git a/MANIFEST.in b/MANIFEST.in index d0dbde14..76c3383f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,11 +1,13 @@ -include AUTHORS COPYING +include AUTHORS COPYING TODO tox.ini include Makefile include overlapped.c pypi.bat include check.py runtests.py run_aiotest.py release.py -include update_stdlib.sh +include update-tulip*.sh -recursive-include examples *.py -recursive-include tests *.crt -recursive-include tests *.key -recursive-include tests *.pem -recursive-include tests *.py +include doc/conf.py doc/make.bat doc/Makefile +include doc/*.rst doc/*.jpg + +include examples/*.py + +include tests/*.crt tests/*.pem tests/*.key +include tests/*.py diff --git a/Makefile b/Makefile index eda02f2d..768298b1 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # Some simple testing tasks (sorry, UNIX only). -PYTHON=python3 +PYTHON=python VERBOSE=$(V) V= 0 FLAGS= @@ -40,6 +40,8 @@ clean: rm -rf build rm -rf asyncio.egg-info rm -f MANIFEST + rm -rf trollius.egg-info + rm -rf .tox # For distribution builders only! diff --git a/setup.py b/setup.py index 660d38fb..701d8ad6 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,36 @@ # Release procedure: -# - run tox (to run runtests.py and run_aiotest.py) -# - maybe test examples -# - update version in setup.py +# - fill Tulip changelog +# - run maybe update_tulip.sh +# - run unit tests with concurrent.futures +# - run unit tests without concurrent.futures +# - run unit tests without ssl: set sys.modules['ssl']=None at startup +# - test examples +# - update version in setup.py (version) and doc/conf.py (version, release) +# - set release date in doc/changelog.rst +# - check that "python setup.py sdist" contains all files tracked by +# the SCM (Mercurial): update MANIFEST.in if needed # - hg ci -# - hg tag VERSION +# - hg tag trollius-VERSION # - hg push -# - run on Linux: python setup.py register sdist upload -# - run on Windows: python release.py VERSION -# - increment version in setup.py +# - On Linux: python setup.py register sdist bdist_wheel upload +# - On Windows: python release.py release +# - increment version in setup.py (version) and doc/conf.py (version, release) # - hg ci && hg push import os +import sys try: from setuptools import setup, Extension + SETUPTOOLS = True except ImportError: + SETUPTOOLS = False # Use distutils.core as a fallback. # We won't be able to build the Wheel file on Windows. from distutils.core import setup, Extension +with open("README") as fp: + long_description = fp.read() + extensions = [] if os.name == 'nt': ext = Extension( @@ -25,25 +38,35 @@ ) extensions.append(ext) -with open("README.rst") as fp: - long_description = fp.read() +requirements = [] +if sys.version_info < (2, 7): + requirements.append('ordereddict') +if sys.version_info < (3,): + requirements.append('futures') -setup( - name="trollius", - version="3.4.4", +install_options = { + "name": "trollius", + "version": "1.0.5", + "license": "Apache License 2.0", + "author": 'Victor Stinner', + "author_email": 'victor.stinner@gmail.com', - description="reference implementation of PEP 3156", - long_description=long_description, - url="http://www.python.org/dev/peps/pep-3156/", + "description": "Port of the Tulip project (asyncio module, PEP 3156) on Python 2", + "long_description": long_description, + "url": "https://bitbucket.org/enovance/trollius/", - classifiers=[ + "classifiers": [ "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.3", + "License :: OSI Approved :: Apache Software License", ], - packages=["trollius"], - test_suite="runtests.runtests", + "packages": ["trollius"], + "test_suite": "runtests.runtests", + + "ext_modules": extensions, +} +if SETUPTOOLS: + install_options['install_requires'] = requirements - ext_modules=extensions, -) +setup(**install_options) diff --git a/tox.ini b/tox.ini index 3030441f..229b13cb 100644 --- a/tox.ini +++ b/tox.ini @@ -1,21 +1,49 @@ [tox] -envlist = py33,py34,py3_release +envlist = py26,py27,py2_release,py32,py33,py34,py3_release [testenv] deps= aiotest -# Run tests in debug mode setenv = - PYTHONASYNCIODEBUG = 1 + TROLLIUSDEBUG = 1 commands= - python -Wd runtests.py -r {posargs} - python -Wd run_aiotest.py -r {posargs} + python runtests.py -r {posargs} + python run_aiotest.py -r {posargs} -[testenv:py3_release] +[testenv:py26] +deps= + aiotest + futures + mock + ordereddict + unittest2 + +[testenv:py27] +deps= + aiotest + futures + mock + +[testenv:py2_release] # Run tests in release mode +deps= + aiotest + futures + mock setenv = - PYTHONASYNCIODEBUG = -basepython = python3 + TROLLIUSDEBUG = +basepython = python2.7 + +[testenv:py32] +deps= + aiotest + mock [testenv:py35] basepython = python3.5 + +[testenv:py3_release] +# Run tests in release mode +setenv = + TROLLIUSDEBUG = +basepython = python3 From 3d64966ed97ba78fdb8aec1d6d4ff4369f36f29e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:00:35 +0200 Subject: [PATCH 1400/1502] replace update_stdlib with update-tulip --- update-tulip-step1.sh | 9 ++++++ update-tulip-step2.sh | 41 +++++++++++++++++++++++++ update-tulip-step3.sh | 4 +++ update_stdlib.sh | 70 ------------------------------------------- 4 files changed, 54 insertions(+), 70 deletions(-) create mode 100755 update-tulip-step1.sh create mode 100755 update-tulip-step2.sh create mode 100755 update-tulip-step3.sh delete mode 100755 update_stdlib.sh diff --git a/update-tulip-step1.sh b/update-tulip-step1.sh new file mode 100755 index 00000000..0b39f7ee --- /dev/null +++ b/update-tulip-step1.sh @@ -0,0 +1,9 @@ +set -e -x +hg update trollius +hg pull --update +hg update default +hg pull https://code.google.com/p/tulip/ +hg update +hg update trollius +hg merge default +echo "Now run ./update-tulip-step2.sh" diff --git a/update-tulip-step2.sh b/update-tulip-step2.sh new file mode 100755 index 00000000..ebf98bdd --- /dev/null +++ b/update-tulip-step2.sh @@ -0,0 +1,41 @@ +set -e + +# Check for merge conflicts +if $(hg resolve -l | grep -q -v '^R'); then + echo "Fix the following conflicts:" + hg resolve -l | grep -v '^R' + exit 1 +fi + +# Ensure that yield from is not used +if $(hg diff|grep -q 'yield from'); then + echo "yield from present in changed code!" + hg diff | grep 'yield from' -B5 -A3 + exit 1 +fi + +# Ensure that mock patchs trollius module, not asyncio +if $(grep -q 'patch.*asyncio' tests/*.py); then + echo "Fix following patch lines in tests/" + grep 'patch.*asyncio' tests/*.py + exit 1 +fi + +# Python 2.6 compatibility +if $(grep -q -E '\{[^0-9].*format' */*.py); then + echo "Issues with Python 2.6 compatibility:" + grep -E '\{[^0-9].*format' */*.py + exit 1 +fi +if $(grep -q -E 'unittest\.skip' tests/*.py); then + echo "Issues with Python 2.6 compatibility:" + grep -E 'unittest\.skip' tests/*.py + exit 1 +fi +if $(grep -q -F 'super()' */*.py); then + echo "Issues with Python 2.6 compatibility:" + grep -F 'super()' */*.py + exit 1 +fi + +echo "Now run ./update-tulip-step3.sh" diff --git a/update-tulip-step3.sh b/update-tulip-step3.sh new file mode 100755 index 00000000..202b44bc --- /dev/null +++ b/update-tulip-step3.sh @@ -0,0 +1,4 @@ +set -e -x +./update-tulip-step2.sh +tox -e py27,py34 +hg ci -m 'Merge Tulip into Trollius' diff --git a/update_stdlib.sh b/update_stdlib.sh deleted file mode 100755 index 0cdbb1bd..00000000 --- a/update_stdlib.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash - -# Script to copy asyncio files to the standard library tree. -# Optional argument is the root of the Python 3.4 tree. -# Assumes you have already created Lib/asyncio and -# Lib/test/test_asyncio in the destination tree. - -CPYTHON=${1-$HOME/cpython} - -if [ ! -d $CPYTHON ] -then - echo Bad destination $CPYTHON - exit 1 -fi - -if [ ! -f asyncio/__init__.py ] -then - echo Bad current directory - exit 1 -fi - -maybe_copy() -{ - SRC=$1 - DST=$CPYTHON/$2 - if cmp $DST $SRC - then - return - fi - echo ======== $SRC === $DST ======== - diff -u $DST $SRC - echo -n "Copy $SRC? [y/N/back] " - read X - case $X in - [yY]*) echo Copying $SRC; cp $SRC $DST;; - back) echo Copying TO $SRC; cp $DST $SRC;; - *) echo Not copying $SRC;; - esac -} - -for i in `(cd asyncio && ls *.py)` -do - if [ $i == test_support.py ] - then - continue - fi - - if [ $i == selectors.py ] - then - if [ "`(cd $CPYTHON; hg branch)`" == "3.4" ] - then - echo "Destination is 3.4 branch -- ignoring selectors.py" - else - maybe_copy asyncio/$i Lib/$i - fi - else - maybe_copy asyncio/$i Lib/asyncio/$i - fi -done - -for i in `(cd tests && ls *.py *.pem)` -do - if [ $i == test_selectors.py ] - then - continue - fi - maybe_copy tests/$i Lib/test/test_asyncio/$i -done - -maybe_copy overlapped.c Modules/overlapped.c From f09e57b9ad38d8eee6947c76794395c2715cddf6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:01:07 +0200 Subject: [PATCH 1401/1502] Port overlapped.c --- overlapped.c | 71 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/overlapped.c b/overlapped.c index ef77c887..52ff9deb 100644 --- a/overlapped.c +++ b/overlapped.c @@ -31,6 +31,18 @@ #define T_HANDLE T_POINTER +#if PY_MAJOR_VERSION >= 3 +# define PYTHON3 +#endif + +#ifndef Py_MIN +# define Py_MIN(X, Y) (((X) < (Y)) ? (X) : (Y)) +#endif + +#ifndef Py_MAX +# define Py_MAX(X, Y) (((X) > (Y)) ? (X) : (Y)) +#endif + enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_WRITE, TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; @@ -63,6 +75,7 @@ SetFromWindowsErr(DWORD err) if (err == 0) err = GetLastError(); +#if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3) || PY_MAJOR_VERSION > 3 switch (err) { case ERROR_CONNECTION_REFUSED: exception_type = PyExc_ConnectionRefusedError; @@ -73,6 +86,9 @@ SetFromWindowsErr(DWORD err) default: exception_type = PyExc_OSError; } +#else + exception_type = PyExc_WindowsError; +#endif return PyErr_SetExcFromWindowsErr(exception_type, err); } @@ -345,7 +361,11 @@ overlapped_CreateEvent(PyObject *self, PyObject *args) Py_UNICODE *Name; HANDLE Event; +#ifdef PYTHON3 if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "Z", +#else + if (!PyArg_ParseTuple(args, "O" F_BOOL F_BOOL "z", +#endif &EventAttributes, &ManualReset, &InitialState, &Name)) return NULL; @@ -822,7 +842,11 @@ Overlapped_WriteFile(OverlappedObject *self, PyObject *args) return NULL; } +#ifdef PYTHON3 if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) +#else + if (!PyArg_Parse(bufobj, "s*", &self->write_buffer)) +#endif return NULL; #if SIZEOF_SIZE_T > SIZEOF_LONG @@ -878,7 +902,11 @@ Overlapped_WSASend(OverlappedObject *self, PyObject *args) return NULL; } +#ifdef PYTHON3 if (!PyArg_Parse(bufobj, "y*", &self->write_buffer)) +#else + if (!PyArg_Parse(bufobj, "s*", &self->write_buffer)) +#endif return NULL; #if SIZEOF_SIZE_T > SIZEOF_LONG @@ -1136,8 +1164,9 @@ static PyObject * ConnectPipe(OverlappedObject *self, PyObject *args) { PyObject *AddressObj; - wchar_t *Address; HANDLE PipeHandle; +#ifdef PYTHON3 + wchar_t *Address; if (!PyArg_ParseTuple(args, "U", &AddressObj)) return NULL; @@ -1146,14 +1175,26 @@ ConnectPipe(OverlappedObject *self, PyObject *args) if (Address == NULL) return NULL; +# define CREATE_FILE CreateFileW +#else + char *Address; + + if (!PyArg_ParseTuple(args, "s", &Address)) + return NULL; + +# define CREATE_FILE CreateFileA +#endif + Py_BEGIN_ALLOW_THREADS - PipeHandle = CreateFileW(Address, + PipeHandle = CREATE_FILE(Address, GENERIC_READ | GENERIC_WRITE, 0, NULL, OPEN_EXISTING, FILE_FLAG_OVERLAPPED, NULL); Py_END_ALLOW_THREADS +#ifdef PYTHON3 PyMem_Free(Address); +#endif if (PipeHandle == INVALID_HANDLE_VALUE) return SetFromWindowsErr(0); return Py_BuildValue(F_HANDLE, PipeHandle); @@ -1284,6 +1325,7 @@ static PyMethodDef overlapped_functions[] = { {NULL} }; +#ifdef PYTHON3 static struct PyModuleDef overlapped_module = { PyModuleDef_HEAD_INIT, "_overlapped", @@ -1295,12 +1337,13 @@ static struct PyModuleDef overlapped_module = { NULL, NULL }; +#endif #define WINAPI_CONSTANT(fmt, con) \ PyDict_SetItemString(d, #con, Py_BuildValue(fmt, con)) -PyMODINIT_FUNC -PyInit__overlapped(void) +PyObject* +_init_overlapped(void) { PyObject *m, *d; @@ -1316,7 +1359,11 @@ PyInit__overlapped(void) if (PyType_Ready(&OverlappedType) < 0) return NULL; +#ifdef PYTHON3 m = PyModule_Create(&overlapped_module); +#else + m = Py_InitModule("_overlapped", overlapped_functions); +#endif if (PyModule_AddObject(m, "Overlapped", (PyObject *)&OverlappedType) < 0) return NULL; @@ -1332,6 +1379,22 @@ PyInit__overlapped(void) WINAPI_CONSTANT(F_DWORD, SO_UPDATE_ACCEPT_CONTEXT); WINAPI_CONSTANT(F_DWORD, SO_UPDATE_CONNECT_CONTEXT); WINAPI_CONSTANT(F_DWORD, TF_REUSE_SOCKET); + WINAPI_CONSTANT(F_DWORD, ERROR_CONNECTION_REFUSED); + WINAPI_CONSTANT(F_DWORD, ERROR_CONNECTION_ABORTED); return m; } + +#ifdef PYTHON3 +PyMODINIT_FUNC +PyInit__overlapped(void) +{ + return _init_overlapped(); +} +#else +PyMODINIT_FUNC +init_overlapped(void) +{ + _init_overlapped(); +} +#endif From 2fd3bf4921aad316da774317e183f3795012904d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:01:21 +0200 Subject: [PATCH 1402/1502] Port remaining files --- check.py | 2 +- run_aiotest.py | 2 +- runtests.py | 112 ++++++++++++++++++++++++++++--------------------- 3 files changed, 67 insertions(+), 49 deletions(-) mode change 100644 => 100755 runtests.py diff --git a/check.py b/check.py index 6db82d64..dcefc185 100644 --- a/check.py +++ b/check.py @@ -37,7 +37,7 @@ def process(fn): line = line.rstrip('\n') sline = line.rstrip() if len(line) >= 80 or line != sline or not isascii(line): - print('{}:{:d}:{}{}'.format( + print('{0}:{1:d}:{2}{3}'.format( fn, i+1, sline, '_' * (len(line) - len(sline)))) finally: f.close() diff --git a/run_aiotest.py b/run_aiotest.py index 0a08d66a..da133286 100644 --- a/run_aiotest.py +++ b/run_aiotest.py @@ -1,6 +1,6 @@ import aiotest.run -import trollius import sys +import trollius if sys.platform == 'win32': from trollius.windows_utils import socketpair else: diff --git a/runtests.py b/runtests.py old mode 100644 new mode 100755 index c38b0c18..63c0fba2 --- a/runtests.py +++ b/runtests.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3 +#!/usr/bin/env python """Run Tulip unittests. Usage: @@ -20,86 +20,100 @@ # Originally written by Beech Horn (for NDB). -import argparse +from __future__ import print_function +import optparse import gc import logging import os import random import re import sys -import unittest import textwrap -import importlib.machinery +from trollius.compat import PY33 +if PY33: + import importlib.machinery +else: + import imp try: import coverage except ImportError: coverage = None +if sys.version_info < (3,): + sys.exc_clear() -from unittest.signals import installHandler - -assert sys.version >= '3.3', 'Please use Python 3.3 or higher.' - -ARGS = argparse.ArgumentParser(description="Run all unittests.") -ARGS.add_argument( - '-v', action="store", dest='verbose', - nargs='?', const=1, type=int, default=0, help='verbose') -ARGS.add_argument( +try: + import unittest + from unittest.signals import installHandler +except ImportError: + import unittest2 as unittest + from unittest2.signals import installHandler + +ARGS = optparse.OptionParser(description="Run all unittests.", usage="%prog [options] [pattern] [pattern2 ...]") +ARGS.add_option( + '-v', '--verbose', action="store_true", dest='verbose', + default=0, help='verbose') +ARGS.add_option( '-x', action="store_true", dest='exclude', help='exclude tests') -ARGS.add_argument( +ARGS.add_option( '-f', '--failfast', action="store_true", default=False, dest='failfast', help='Stop on first fail or error') -ARGS.add_argument( +ARGS.add_option( '-c', '--catch', action="store_true", default=False, dest='catchbreak', help='Catch control-C and display results') -ARGS.add_argument( +ARGS.add_option( '--forever', action="store_true", dest='forever', default=False, help='run tests forever to catch sporadic errors') -ARGS.add_argument( +ARGS.add_option( '--findleaks', action='store_true', dest='findleaks', help='detect tests that leak memory') -ARGS.add_argument('-r', '--randomize', action='store_true', - help='randomize test execution order.') -ARGS.add_argument('--seed', type=int, - help='random seed to reproduce a previous random run') -ARGS.add_argument( +ARGS.add_option( + '-r', '--randomize', action='store_true', + help='randomize test execution order.') +ARGS.add_option( + '--seed', type=int, + help='random seed to reproduce a previous random run') +ARGS.add_option( '-q', action="store_true", dest='quiet', help='quiet') -ARGS.add_argument( +ARGS.add_option( '--tests', action="store", dest='testsdir', default='tests', help='tests directory') -ARGS.add_argument( +ARGS.add_option( '--coverage', action="store_true", dest='coverage', help='enable html coverage report') -ARGS.add_argument( - 'pattern', action="store", nargs="*", - help='optional regex patterns to match test ids (default all tests)') -COV_ARGS = argparse.ArgumentParser(description="Run all unittests.") -COV_ARGS.add_argument( - '--coverage', action="store", dest='coverage', nargs='?', const='', - help='enable coverage report and provide python files directory') + +if PY33: + def load_module(modname, sourcefile): + loader = importlib.machinery.SourceFileLoader(modname, sourcefile) + return loader.load_module() +else: + def load_module(modname, sourcefile): + return imp.load_source(modname, sourcefile) def load_modules(basedir, suffix='.py'): + import trollius.test_utils + def list_dir(prefix, dir): files = [] modpath = os.path.join(dir, '__init__.py') if os.path.isfile(modpath): mod = os.path.split(dir)[-1] - files.append(('{}{}'.format(prefix, mod), modpath)) + files.append(('{0}{1}'.format(prefix, mod), modpath)) - prefix = '{}{}.'.format(prefix, mod) + prefix = '{0}{1}.'.format(prefix, mod) for name in os.listdir(dir): path = os.path.join(dir, name) if os.path.isdir(path): - files.extend(list_dir('{}{}.'.format(prefix, name), path)) + files.extend(list_dir('{0}{1}.'.format(prefix, name), path)) else: if (name != '__init__.py' and name.endswith(suffix) and not name.startswith(('.', '_'))): - files.append(('{}{}'.format(prefix, name[:-3]), path)) + files.append(('{0}{1}'.format(prefix, name[:-3]), path)) return files @@ -107,13 +121,17 @@ def list_dir(prefix, dir): for modname, sourcefile in list_dir('', basedir): if modname == 'runtests': continue + if modname == 'test_asyncio' and sys.version_info <= (3, 3): + print("Skipping '{0}': need at least Python 3.3".format(modname), + file=sys.stderr) + continue try: - loader = importlib.machinery.SourceFileLoader(modname, sourcefile) - mods.append((loader.load_module(), sourcefile)) + mod = load_module(modname, sourcefile) + mods.append((mod, sourcefile)) except SyntaxError: raise - except unittest.SkipTest as err: - print("Skipping '{}': {}".format(modname, err), file=sys.stderr) + except trollius.test_utils.SkipTest as err: + print("Skipping '{0}': {1}".format(modname, err), file=sys.stderr) return mods @@ -198,7 +216,7 @@ class TestRunner(unittest.TextTestRunner): def run(self, test): result = super().run(test) if result.leaks: - self.stream.writeln("{} tests leaks:".format(len(result.leaks))) + self.stream.writeln("{0} tests leaks:".format(len(result.leaks))) for name, leaks in result.leaks: self.stream.writeln(' '*4 + name + ':') for leak in leaks: @@ -218,7 +236,7 @@ def _runtests(args, tests): def runtests(): - args = ARGS.parse_args() + args, pattern = ARGS.parse_args() if args.coverage and coverage is None: URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" @@ -238,15 +256,15 @@ def runtests(): testsdir = os.path.abspath(args.testsdir) if not os.path.isdir(testsdir): - print("Tests directory is not found: {}\n".format(testsdir)) + print("Tests directory is not found: {0}\n".format(testsdir)) ARGS.print_help() return excludes = includes = [] if args.exclude: - excludes = args.pattern + excludes = pattern else: - includes = args.pattern + includes = pattern v = 0 if args.quiet else args.verbose + 1 failfast = args.failfast @@ -273,8 +291,8 @@ def runtests(): finder = TestsFinder(args.testsdir, includes, excludes) if args.catchbreak: installHandler() - import asyncio.coroutines - if asyncio.coroutines._DEBUG: + import trollius.coroutines + if trollius.coroutines._DEBUG: print("Run tests in debug mode") else: print("Run tests in release mode") @@ -297,7 +315,7 @@ def runtests(): cov.report(show_missing=False) here = os.path.dirname(os.path.abspath(__file__)) print("\nFor html report:") - print("open file://{}/htmlcov/index.html".format(here)) + print("open file://{0}/htmlcov/index.html".format(here)) if __name__ == '__main__': From fa093793a8dff6d94682de4b2e2d15e5bebbd4f2 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:01:39 +0200 Subject: [PATCH 1403/1502] Add TODO.rst from Trollius --- MANIFEST.in | 2 +- TODO.rst | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 TODO.rst diff --git a/MANIFEST.in b/MANIFEST.in index 76c3383f..e804c6ad 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ -include AUTHORS COPYING TODO tox.ini +include AUTHORS COPYING TODO.rst tox.ini include Makefile include overlapped.c pypi.bat include check.py runtests.py run_aiotest.py release.py diff --git a/TODO.rst b/TODO.rst new file mode 100644 index 00000000..ff21b225 --- /dev/null +++ b/TODO.rst @@ -0,0 +1,20 @@ +Unsorted "TODO" tasks: + +* reuse selectors backport from PyPI +* check ssl.SSLxxx in update_xxx.sh +* document how to port asyncio to trollius +* use six instead of compat +* Replace logger with warning in monotonic clock and synchronous executor +* Windows: use _overlapped in py33_winapi? +* Fix tests failing with PyPy: + + - sys.getrefcount() + - test_queues.test_repr + - test_futures.test_tb_logger_exception_unretrieved + +* write unit test for create_connection(ssl=True) +* Fix examples: + + - stacks.py: 'exceptions.ZeroDivisionError' object has no attribute '__traceback__' + +* Fix all FIXME in the code From e6e6e0a703f140f952490696fa1adc2040c273e9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:02:31 +0200 Subject: [PATCH 1404/1502] use releaser * add releaser.conf * remove release.py --- release.py | 517 -------------------------------------------------- releaser.conf | 7 + 2 files changed, 7 insertions(+), 517 deletions(-) delete mode 100755 release.py create mode 100644 releaser.conf diff --git a/release.py b/release.py deleted file mode 100755 index a5acbc88..00000000 --- a/release.py +++ /dev/null @@ -1,517 +0,0 @@ -#!/usr/bin/env python3 -""" -Script to upload 32 bits and 64 bits wheel packages for Python 3.3 on Windows. - -Usage: "python release.py HG_TAG" where HG_TAG is a Mercurial tag, usually -a version number like "3.4.2". - -Requirements: - -- Python 3.3 and newer requires the Windows SDK 7.1 to build wheel packages -- Python 2.7 requires the Windows SDK 7.0 -- the aiotest module is required to run aiotest tests -""" -import contextlib -import optparse -import os -import platform -import re -import shutil -import subprocess -import sys -import tempfile -import textwrap - -PROJECT = 'asyncio' -DEBUG_ENV_VAR = 'PYTHONASYNCIODEBUG' -PYTHON_VERSIONS = ( - (3, 3), -) -PY3 = (sys.version_info >= (3,)) -HG = 'hg' -SDK_ROOT = r"C:\Program Files\Microsoft SDKs\Windows" -BATCH_FAIL_ON_ERROR = "@IF %errorlevel% neq 0 exit /b %errorlevel%" -WINDOWS = (sys.platform == 'win32') - - -def get_architecture_bits(): - arch = platform.architecture()[0] - return int(arch[:2]) - - -class PythonVersion: - def __init__(self, major, minor, bits): - self.major = major - self.minor = minor - self.bits = bits - self._executable = None - - @staticmethod - def running(): - bits = get_architecture_bits() - pyver = PythonVersion(sys.version_info.major, - sys.version_info.minor, - bits) - pyver._executable = sys.executable - return pyver - - def _get_executable_windows(self, app): - if self.bits == 32: - executable = 'c:\\Python%s%s_32bit\\python.exe' - else: - executable = 'c:\\Python%s%s\\python.exe' - executable = executable % (self.major, self.minor) - if not os.path.exists(executable): - print("Unable to find python %s" % self) - print("%s does not exists" % executable) - sys.exit(1) - return executable - - def _get_executable_unix(self, app): - return 'python%s.%s' % (self.major, self.minor) - - def get_executable(self, app): - if self._executable: - return self._executable - - if WINDOWS: - executable = self._get_executable_windows(app) - else: - executable = self._get_executable_unix(app) - - code = ( - 'import platform, sys; ' - 'print("{ver.major}.{ver.minor} {bits}".format(' - 'ver=sys.version_info, ' - 'bits=platform.architecture()[0]))' - ) - try: - exitcode, stdout = app.get_output(executable, '-c', code, - ignore_stderr=True) - except OSError as exc: - print("Error while checking %s:" % self) - print(str(exc)) - print("Executable: %s" % executable) - sys.exit(1) - else: - stdout = stdout.rstrip() - expected = "%s.%s %sbit" % (self.major, self.minor, self.bits) - if stdout != expected: - print("Python version or architecture doesn't match") - print("got %r, expected %r" % (stdout, expected)) - print("Executable: %s" % executable) - sys.exit(1) - - self._executable = executable - return executable - - def __str__(self): - return 'Python %s.%s (%s bits)' % (self.major, self.minor, self.bits) - - -class Release(object): - def __init__(self): - root = os.path.dirname(__file__) - self.root = os.path.realpath(root) - # Set these attributes to True to run also register sdist upload - self.wheel = False - self.test = False - self.register = False - self.sdist = False - self.aiotest = False - self.verbose = False - self.upload = False - # Release mode: enable more tests - self.release = False - self.python_versions = [] - if WINDOWS: - supported_archs = (32, 64) - else: - bits = get_architecture_bits() - supported_archs = (bits,) - for major, minor in PYTHON_VERSIONS: - for bits in supported_archs: - pyver = PythonVersion(major, minor, bits) - self.python_versions.append(pyver) - - @contextlib.contextmanager - def _popen(self, args, **kw): - verbose = kw.pop('verbose', True) - if self.verbose and verbose: - print('+ ' + ' '.join(args)) - if PY3: - kw['universal_newlines'] = True - proc = subprocess.Popen(args, **kw) - try: - yield proc - except: - proc.kill() - proc.wait() - raise - - def get_output(self, *args, **kw): - kw['stdout'] = subprocess.PIPE - ignore_stderr = kw.pop('ignore_stderr', False) - if ignore_stderr: - devnull = open(os.path.devnull, 'wb') - kw['stderr'] = devnull - else: - kw['stderr'] = subprocess.STDOUT - try: - with self._popen(args, **kw) as proc: - stdout, stderr = proc.communicate() - return proc.returncode, stdout - finally: - if ignore_stderr: - devnull.close() - - def check_output(self, *args, **kw): - exitcode, output = self.get_output(*args, **kw) - if exitcode: - sys.stdout.write(output) - sys.stdout.flush() - sys.exit(1) - return output - - def run_command(self, *args, **kw): - with self._popen(args, **kw) as proc: - exitcode = proc.wait() - if exitcode: - sys.exit(exitcode) - - def get_local_changes(self): - status = self.check_output(HG, 'status') - return [line for line in status.splitlines() - if not line.startswith("?")] - - def remove_directory(self, name): - path = os.path.join(self.root, name) - if os.path.exists(path): - if self.verbose: - print("Remove directory: %s" % name) - shutil.rmtree(path) - - def remove_file(self, name): - path = os.path.join(self.root, name) - if os.path.exists(path): - if self.verbose: - print("Remove file: %s" % name) - os.unlink(path) - - def windows_sdk_setenv(self, pyver): - if (pyver.major, pyver.minor) >= (3, 3): - path = "v7.1" - sdkver = (7, 1) - else: - path = "v7.0" - sdkver = (7, 0) - setenv = os.path.join(SDK_ROOT, path, 'Bin', 'SetEnv.cmd') - if not os.path.exists(setenv): - print("Unable to find Windows SDK %s.%s for %s" - % (sdkver[0], sdkver[1], pyver)) - print("Please download and install it") - print("%s does not exists" % setenv) - sys.exit(1) - if pyver.bits == 64: - arch = '/x64' - else: - arch = '/x86' - cmd = ["CALL", setenv, "/release", arch] - return (cmd, sdkver) - - def quote(self, arg): - if not re.search("[ '\"]", arg): - return arg - # FIXME: should we escape "? - return '"%s"' % arg - - def quote_args(self, args): - return ' '.join(self.quote(arg) for arg in args) - - def cleanup(self): - if self.verbose: - print("Cleanup") - self.remove_directory('build') - self.remove_directory('dist') - self.remove_file('_overlapped.pyd') - self.remove_file(os.path.join(PROJECT, '_overlapped.pyd')) - - def sdist_upload(self): - self.cleanup() - self.run_command(sys.executable, 'setup.py', 'sdist', 'upload') - - def build_inplace(self, pyver): - print("Build for %s" % pyver) - self.build(pyver, 'build') - - if WINDOWS: - if pyver.bits == 64: - arch = 'win-amd64' - else: - arch = 'win32' - build_dir = 'lib.%s-%s.%s' % (arch, pyver.major, pyver.minor) - src = os.path.join(self.root, 'build', build_dir, - PROJECT, '_overlapped.pyd') - dst = os.path.join(self.root, PROJECT, '_overlapped.pyd') - shutil.copyfile(src, dst) - - def runtests(self, pyver): - print("Run tests on %s" % pyver) - - if WINDOWS and not self.options.no_compile: - self.build_inplace(pyver) - - release_env = dict(os.environ) - release_env.pop(DEBUG_ENV_VAR, None) - - dbg_env = dict(os.environ) - dbg_env[DEBUG_ENV_VAR] = '1' - - python = pyver.get_executable(self) - args = (python, 'runtests.py', '-r') - - if self.release: - print("Run runtests.py in release mode on %s" % pyver) - self.run_command(*args, env=release_env) - - print("Run runtests.py in debug mode on %s" % pyver) - self.run_command(*args, env=dbg_env) - - if self.aiotest: - args = (python, 'run_aiotest.py') - - if self.release: - print("Run aiotest in release mode on %s" % pyver) - self.run_command(*args, env=release_env) - - print("Run aiotest in debug mode on %s" % pyver) - self.run_command(*args, env=dbg_env) - print("") - - def _build_windows(self, pyver, cmd): - setenv, sdkver = self.windows_sdk_setenv(pyver) - - temp = tempfile.NamedTemporaryFile(mode="w", suffix=".bat", - delete=False) - with temp: - temp.write("SETLOCAL EnableDelayedExpansion\n") - temp.write(self.quote_args(setenv) + "\n") - temp.write(BATCH_FAIL_ON_ERROR + "\n") - # Restore console colors: lightgrey on black - temp.write("COLOR 07\n") - temp.write("\n") - temp.write("SET DISTUTILS_USE_SDK=1\n") - temp.write("SET MSSDK=1\n") - temp.write("CD %s\n" % self.quote(self.root)) - temp.write(self.quote_args(cmd) + "\n") - temp.write(BATCH_FAIL_ON_ERROR + "\n") - - try: - if self.verbose: - print("Setup Windows SDK %s.%s" % sdkver) - print("+ " + ' '.join(cmd)) - # SDK 7.1 uses the COLOR command which makes SetEnv.cmd failing - # if the stdout is not a TTY (if we redirect stdout into a file) - if self.verbose or sdkver >= (7, 1): - self.run_command(temp.name, verbose=False) - else: - self.check_output(temp.name, verbose=False) - finally: - os.unlink(temp.name) - - def _build_unix(self, pyver, cmd): - self.check_output(*cmd) - - def build(self, pyver, *cmds): - self.cleanup() - - python = pyver.get_executable(self) - cmd = [python, 'setup.py'] + list(cmds) - - if WINDOWS: - self._build_windows(pyver, cmd) - else: - self._build_unix(pyver, cmd) - - def test_wheel(self, pyver): - print("Test building wheel package for %s" % pyver) - self.build(pyver, 'bdist_wheel') - - def publish_wheel(self, pyver): - print("Build and publish wheel package for %s" % pyver) - self.build(pyver, 'bdist_wheel', 'upload') - - def parse_options(self): - parser = optparse.OptionParser( - description="Run all unittests.", - usage="%prog [options] command") - parser.add_option( - '-v', '--verbose', action="store_true", dest='verbose', - default=0, help='verbose') - parser.add_option( - '-t', '--tag', type="str", - help='Mercurial tag or revision, required to release') - parser.add_option( - '-p', '--python', type="str", - help='Only build/test one specific Python version, ex: "2.7:32"') - parser.add_option( - '-C', "--no-compile", action="store_true", - help="Don't compile the module, this options implies --running", - default=False) - parser.add_option( - '-r', "--running", action="store_true", - help='Only use the running Python version', - default=False) - parser.add_option( - '--ignore', action="store_true", - help='Ignore local changes', - default=False) - self.options, args = parser.parse_args() - if len(args) == 1: - command = args[0] - else: - command = None - - if self.options.no_compile: - self.options.running = True - - if command == 'clean': - self.options.verbose = True - elif command == 'build': - self.options.running = True - elif command == 'test_wheel': - self.wheel = True - elif command == 'test': - self.test = True - elif command == 'release': - if not self.options.tag: - print("The release command requires the --tag option") - sys.exit(1) - - self.release = True - self.wheel = True - self.test = True - self.upload = True - else: - if command: - print("Invalid command: %s" % command) - else: - parser.print_help() - print("") - - print("Available commands:") - print("- build: build asyncio in place, imply --running") - print("- test: run tests") - print("- test_wheel: test building wheel packages") - print("- release: run tests and publish wheel packages,") - print(" require the --tag option") - print("- clean: cleanup the project") - sys.exit(1) - - if self.options.python and self.options.running: - print("--python and --running options are exclusive") - sys.exit(1) - - python = self.options.python - if python: - match = re.match("^([23])\.([0-9])/(32|64)$", python) - if not match: - print("Invalid Python version: %s" % python) - print('Format of a Python version: "x.y/bits"') - print("Example: 2.7/32") - sys.exit(1) - major = int(match.group(1)) - minor = int(match.group(2)) - bits = int(match.group(3)) - self.python_versions = [PythonVersion(major, minor, bits)] - - if self.options.running: - self.python_versions = [PythonVersion.running()] - - self.verbose = self.options.verbose - self.command = command - - def main(self): - self.parse_options() - - print("Directory: %s" % self.root) - os.chdir(self.root) - - if self.command == "clean": - self.cleanup() - sys.exit(1) - - if self.command == "build": - if len(self.python_versions) != 1: - print("build command requires one specific Python version") - print("Use the --python command line option") - sys.exit(1) - pyver = self.python_versions[0] - self.build_inplace(pyver) - - if (self.register or self.upload) and (not self.options.ignore): - lines = self.get_local_changes() - else: - lines = () - if lines: - print("ERROR: Found local changes") - for line in lines: - print(line) - print("") - print("Revert local changes") - print("or use the --ignore command line option") - sys.exit(1) - - hg_tag = self.options.tag - if hg_tag: - print("Update repository to revision %s" % hg_tag) - self.check_output(HG, 'update', hg_tag) - - hg_rev = self.check_output(HG, 'id').rstrip() - - if self.wheel: - for pyver in self.python_versions: - self.test_wheel(pyver) - - if self.test: - for pyver in self.python_versions: - self.runtests(pyver) - - if self.register: - self.run_command(sys.executable, 'setup.py', 'register') - - if self.sdist: - self.sdist_upload() - - if self.upload: - for pyver in self.python_versions: - self.publish_wheel(pyver) - - hg_rev2 = self.check_output(HG, 'id').rstrip() - if hg_rev != hg_rev2: - print("ERROR: The Mercurial revision changed") - print("Before: %s" % hg_rev) - print("After: %s" % hg_rev2) - sys.exit(1) - - print("") - print("Mercurial revision: %s" % hg_rev) - if self.command == 'build': - print("Inplace compilation done") - if self.wheel: - print("Compilation of wheel packages succeeded") - if self.test: - print("Tests succeeded") - if self.register: - print("Project registered on the Python cheeseshop (PyPI)") - if self.sdist: - print("Project source code uploaded to the Python " - "cheeseshop (PyPI)") - if self.upload: - print("Wheel packages uploaded to the Python cheeseshop (PyPI)") - for pyver in self.python_versions: - print("- %s" % pyver) - - -if __name__ == "__main__": - Release().main() diff --git a/releaser.conf b/releaser.conf new file mode 100644 index 00000000..37281395 --- /dev/null +++ b/releaser.conf @@ -0,0 +1,7 @@ +# Configuration file for the tool "releaser" +# https://bitbucket.org/haypo/misc/src/tip/bin/releaser.py + +[project] +name = trollius +debug_env_var = TROLLIUSDEBUG +python_versions = 2.7, 3.3, 3.4 From 35906103b9f40df213cce2df48322f6106daa31e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 8 Jul 2015 00:04:03 +0200 Subject: [PATCH 1405/1502] remove asyncio changelog trollius has already doc/changelog.rst --- ChangeLog | 332 ------------------------------------------------------ 1 file changed, 332 deletions(-) delete mode 100644 ChangeLog diff --git a/ChangeLog b/ChangeLog deleted file mode 100644 index 25155a98..00000000 --- a/ChangeLog +++ /dev/null @@ -1,332 +0,0 @@ -2015-02-04: Tulip 3.4.3 -======================= - -Major changes -------------- - -* New SSL implementation using ssl.MemoryBIO. The new implementation requires - Python 3.5 and newer, otherwise the legacy implementation is used. -* On Python 3.5 and newer usable, the ProactorEventLoop now supports SSL - thanks to the new SSL implementation. -* Fix multiple resource leaks: close sockets on error, explicitly clear - references, emit ResourceWarning when event loops and transports are not - closed explicitly, etc. -* The proactor event loop is now much more reliable (no more known race - condition). -* Enhance handling of task cancellation. - -Changes of the asyncio API --------------------------- - -* Export BaseEventLoop symbol in the asyncio namespace -* create_task(), call_soon(), call_soon_threadsafe(), call_later(), - call_at() and run_in_executor() methods of BaseEventLoop now raise an - exception if the event loop is closed. -* call_soon(), call_soon_threadsafe(), call_later(), call_at() and - run_in_executor() methods of BaseEventLoop now raise an exception if the - callback is a coroutine object. -* BaseEventLoopPolicy.get_event_loop() now always raises a RuntimeError - if there is no event loop in the curren thread, instead of using an - assertion (which can be disabld at runtime) and so raises an AssertionError. -* selectors: Selector.get_key() now raises an exception if the selector is - closed. -* If wait_for() is cancelled, the waited task is also cancelled. -* _UnixSelectorEventLoop.add_signal_handler() now raises an exception if - the callback is a coroutine object or a coroutine function. It also raises - an exception if the event loop is closed. - -Performances ------------- - -* sock_connect() doesn't check if the address is already resolved anymore. - The check is only done in debug mode. Moreover, the check uses inet_pton() - instead of getaddrinfo(), if inet_pton() is available, because getaddrinfo() - is slow (around 10 us per call). - -Debug ------ - -* Better repr() of _ProactorBasePipeTransport, _SelectorTransport, - _UnixReadPipeTransport and _UnixWritePipeTransport: add closed/closing - status and the file descriptor -* Add repr(PipeHandle) -* PipeHandle destructor now emits a ResourceWarning is the pipe is not closed - explicitly. -* In debug mode, call_at() method of BaseEventLoop now raises an exception - if called from the wrong thread (not from the thread running the event - loop). Before, it only raised an exception if current thread had an event - loop. -* A ResourceWarning is now emitted when event loops and transports are - destroyed before being closed. -* BaseEventLoop.call_exception_handler() now logs the traceback where - the current handle was created (if no source_traceback was specified). -* BaseSubprocessTransport.close() now logs a warning when the child process is - still running and the method kills it. - -Bug fixes ---------- - -* windows_utils.socketpair() now reuses socket.socketpair() if available - (Python 3.5 or newer). -* Fix IocpProactor.accept_pipe(): handle ERROR_PIPE_CONNECTED, it means - that the pipe is connected. _overlapped.Overlapped.ConnectNamedPipe() now - returns True on ERROR_PIPE_CONNECTED. -* Rewrite IocpProactor.connect_pipe() using polling to avoid tricky bugs - if the connection is cancelled, instead of using QueueUserWorkItem() to run - blocking code. -* Fix IocpProactor.recv(): handle BrokenPipeError, set the result to an empty - string. -* Fix ProactorEventLoop.start_serving_pipe(): if a client connected while the - server is closing, drop the client connection. -* Fix a tricky race condition when IocpProactor.wait_for_handle() is - cancelled: wait until the wait is really cancelled before destroying the - overlapped object. Unregister also the overlapped operation to not block - in IocpProactor.close() before the wait will never complete. -* Fix _UnixSubprocessTransport._start(): make the write end of the stdin pipe - non-inheritable. -* Set more attributes in the body of classes to avoid attribute errors in - destructors if an error occurred in the constructor. -* Fix SubprocessStreamProtocol.process_exited(): close the transport - and clear its reference to the transport. -* Fix SubprocessStreamProtocol.connection_made(): set the transport of - stdout and stderr streams to respect reader buffer limits (stop reading when - the buffer is full). -* Fix FlowControlMixin constructor: if the loop parameter is None, get the - current event loop. -* Fix selectors.EpollSelector.select(): don't fail anymore if no file - descriptor is registered. -* Fix _SelectorTransport: don't wakeup the waiter if it was cancelled -* Fix _SelectorTransport._call_connection_lost(): only call connection_lost() - if connection_made() was already called. -* Fix BaseSelectorEventLoop._accept_connection(): close the transport on - error. In debug mode, log errors (ex: SSL handshake failure) on the creation - of the transport for incoming connection. -* Fix BaseProactorEventLoop.close(): stop the proactor before closing the - event loop because stopping the proactor may schedule new callbacks, which - is now forbidden when the event loop is closed. -* Fix wrap_future() to not use a free variable and so not keep a frame alive - too long. -* Fix formatting of the "Future/Task exception was never retrieved" log: add - a newline before the traceback. -* WriteSubprocessPipeProto.connection_lost() now clears its reference to the - subprocess.Popen object. -* If the creation of a subprocess transport fails, the child process is killed - and the event loop waits asynchronously for its completion. -* BaseEventLoop.run_until_complete() now consumes the exception to not log a - warning when a BaseException like KeyboardInterrupt is raised and - run_until_complete() is not a future (but a coroutine object). -* create_connection(), create_datagram_endpoint(), connect_read_pipe() and - connect_write_pipe() methods of BaseEventLoop now close the transport on - error. - -Other changes -------------- - -* Add tox.ini to run tests using tox. -* _FlowControlMixin constructor now requires an event loop. -* Embed asyncio/test_support.py to not depend on test.support of the system - Python. For example, test.support is not installed by default on Windows. -* selectors.Selector.close() now clears its reference to the mapping object. -* _SelectorTransport and _UnixWritePipeTransport now only starts listening for - read events after protocol.connection_made() has been called -* _SelectorTransport._fatal_error() now only logs ConnectionAbortedError - in debug mode. -* BaseProactorEventLoop._loop_self_reading() now handles correctly - CancelledError (just exit) and logs an error for other exceptions. -* _ProactorBasePipeTransport now clears explicitly references to read and - write future and to the socket -* BaseSubprocessTransport constructor now calls the internal _connect_pipes() - method (previously called _post_init()). The constructor now accepts an - optional waiter parameter to notify when the transport is ready. -* send_signal(), terminate() and kill() methods of BaseSubprocessTransport now - raise a ProcessLookupError if the process already exited. -* Add run_aiotest.py to run the aiotest test suite -* Add release.py script to build wheel packages on Windows and run unit tests - - -2014-09-30: Tulip 3.4.2 -======================= - -New shiny methods like create_task(), better documentation, much better debug -mode, better tests. - -asyncio API ------------ - -* Add BaseEventLoop.create_task() method: schedule a coroutine object. - It allows other asyncio implementations to use their own Task class to - change its behaviour. - -* New BaseEventLoop methods: - - - create_task(): schedule a coroutine - - get_debug() - - is_closed() - - set_debug() - -* Add _FlowControlMixin.get_write_buffer_limits() method - -* sock_recv(), sock_sendall(), sock_connect(), sock_accept() methods of - SelectorEventLoop now raise an exception if the socket is blocking mode - -* Include unix_events/windows_events symbols in asyncio.__all__. - Examples: SelectorEventLoop, ProactorEventLoop, DefaultEventLoopPolicy. - -* attach(), detach(), loop, active_count and waiters attributes of the Server - class are now private - -* BaseEventLoop: run_forever(), run_until_complete() now raises an exception if - the event loop was closed - -* close() now raises an exception if the event loop is running, because pending - callbacks would be lost - -* Queue now accepts a float for the maximum size. - -* Process.communicate() now ignores BrokenPipeError and ConnectionResetError - exceptions, as Popen.communicate() of the subprocess module - - -Performances ------------- - -* Optimize handling of cancelled timers - - -Debug ------ - -* Future (and Task), CoroWrapper and Handle now remembers where they were - created (new _source_traceback object), traceback displayed when errors are - logged. - -* On Python 3.4 and newer, Task destrutor now logs a warning if the task was - destroyed while it was still pending. It occurs if the last reference - to the task was removed, while the coroutine didn't finish yet. - -* Much more useful events are logged: - - - Event loop closed - - Network connection - - Creation of a subprocess - - Pipe lost - - Log many errors previously silently ignored - - SSL handshake failure - - etc. - -* BaseEventLoop._debug is now True if the envrionement variable - PYTHONASYNCIODEBUG is set - -* Log the duration of DNS resolution and SSL handshake - -* Log a warning if a callback blocks the event loop longer than 100 ms - (configurable duration) - -* repr(CoroWrapper) and repr(Task) now contains the current status of the - coroutine (running, done), current filename and line number, and filename and - line number where the object was created - -* Enhance representation (repr) of transports: add the file descriptor, status - (idle, polling, writing, etc.), size of the write buffer, ... - -* Add repr(BaseEventLoop) - -* run_until_complete() doesn't log a warning anymore when called with a - coroutine object which raises an exception. - - -Bugfixes --------- - -* windows_utils.socketpair() now ensures that sockets are closed in case - of error. - -* Rewrite bricks of the IocpProactor() to make it more reliable - -* IocpProactor destructor now closes it. - -* _OverlappedFuture.set_exception() now cancels the overlapped operation. - -* Rewrite _WaitHandleFuture: - - - cancel() is now able to signal the cancellation to the overlapped object - - _unregister_wait() now catchs and logs exceptions - -* PipeServer.close() (class used on Windows) now cancels the accept pipe - future. - -* Rewrite signal handling in the UNIX implementation of SelectorEventLoop: - use the self-pipe to store pending signals instead of registering a - signal handler calling directly _handle_signal(). The change fixes a - race condition. - -* create_unix_server(): close the socket on error. - -* Fix wait_for() - -* Rewrite gather() - -* drain() is now a classic coroutine, no more special return value (empty - tuple) - -* Rewrite SelectorEventLoop.sock_connect() to handle correctly timeout - -* Process data of the self-pipe faster to accept more pending events, - especially signals written by signal handlers: the callback reads all pending - data, not only a single byte - -* Don't try to set the result of a Future anymore if it was cancelled - (explicitly or by a timeout) - -* CoroWrapper now works around CPython issue #21209: yield from & custom - generator classes don't work together, issue with the send() method. It only - affected asyncio in debug mode on Python older than 3.4.2 - - -Misc changes ------------- - -* windows_utils.socketpair() now supports IPv6. - -* Better documentation (online & docstrings): fill remaining XXX, more examples - -* new asyncio.coroutines submodule, to ease maintenance with the trollius - project: @coroutine, _DEBUG, iscoroutine() and iscoroutinefunction() have - been moved from asyncio.tasks to asyncio.coroutines - -* Cleanup code, ex: remove unused attribute (ex: _rawsock) - -* Reuse os.set_blocking() of Python 3.5. - -* Close explicitly the event loop in Tulip examples. - -* runtests.py now mention if tests are running in release or debug mode. - - -2014-05-19: Tulip 3.4.1 -======================= - -2014-02-24: Tulip 0.4.1 -======================= - -2014-02-10: Tulip 0.3.1 -======================= - -* Add asyncio.subprocess submodule and the Process class. - -2013-11-25: Tulip 0.2.1 -======================= - -* Add support of subprocesses using transports and protocols. - -2013-10-22: Tulip 0.1.1 -======================= - -* First release. - -Creation of the project -======================= - -* 2013-10-14: The tulip package was renamed to asyncio. -* 2012-10-16: Creation of the Tulip project, started as mail threads on the - python-ideas mailing list. From 71046948a2ab96ae75b422db67dcc5ac1df8f0bf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 00:12:12 +0200 Subject: [PATCH 1406/1502] Fix Python 3.4 issues --- runtests.py | 3 +-- setup.py | 2 +- tests/echo3.py | 3 +-- tests/test_asyncio.py | 2 +- tests/test_locks.py | 4 ++-- tests/test_subprocess.py | 1 + tests/test_tasks.py | 33 ++------------------------------- tests/test_unix_events.py | 2 +- tox.ini | 9 ++++++++- trollius/base_events.py | 11 ++++++----- trollius/coroutines.py | 21 +++++++++++++++++++-- trollius/locks.py | 15 +++++++++++++-- trollius/selector_events.py | 3 +-- trollius/subprocess.py | 5 +---- trollius/tasks.py | 5 +---- trollius/windows_events.py | 3 +-- trollius/windows_utils.py | 1 - 17 files changed, 60 insertions(+), 63 deletions(-) diff --git a/runtests.py b/runtests.py index 63c0fba2..efc6418a 100755 --- a/runtests.py +++ b/runtests.py @@ -50,7 +50,7 @@ ARGS = optparse.OptionParser(description="Run all unittests.", usage="%prog [options] [pattern] [pattern2 ...]") ARGS.add_option( - '-v', '--verbose', action="store_true", dest='verbose', + '-v', '--verbose', type=int, dest='verbose', default=0, help='verbose') ARGS.add_option( '-x', action="store_true", dest='exclude', help='exclude tests') @@ -275,7 +275,6 @@ def runtests(): ) cov.start() - logger = logging.getLogger() if v == 0: level = logging.CRITICAL elif v == 1: diff --git a/setup.py b/setup.py index 701d8ad6..b40169cd 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ # We won't be able to build the Wheel file on Windows. from distutils.core import setup, Extension -with open("README") as fp: +with open("README.rst") as fp: long_description = fp.read() extensions = [] diff --git a/tests/echo3.py b/tests/echo3.py index 4c2b505d..06449673 100644 --- a/tests/echo3.py +++ b/tests/echo3.py @@ -1,5 +1,4 @@ import os -from trollius.py33_exceptions import wrap_error if __name__ == '__main__': while True: @@ -7,6 +6,6 @@ if not buf: break try: - wrap_error(os.write, 1, b'OUT:'+buf) + os.write(1, b'OUT:'+buf) except OSError as ex: os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 39d9e1aa..d75054d6 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -87,7 +87,7 @@ def test_async(self): fut = asyncio.Future() self.assertIs(fut._loop, self.loop) - fut2 = trollius.async(fut) + fut2 = trollius.ensure_future(fut) self.assertIs(fut2, fut) self.assertIs(fut._loop, self.loop) diff --git a/tests/test_locks.py b/tests/test_locks.py index ec7dbba2..f3ea3b0b 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -230,7 +230,7 @@ def test_context_manager_no_yield(self): except RuntimeError as err: self.assertEqual( str(err), - '"yield" should be used as context manager expression') + '"yield From" should be used as context manager expression') self.assertFalse(lock.locked()) @@ -856,7 +856,7 @@ def test_context_manager_no_yield(self): except RuntimeError as err: self.assertEqual( str(err), - '"yield" should be used as context manager expression') + '"yield From" should be used as context manager expression') self.assertEqual(2, sem._value) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 99071ee3..b2eda3e9 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -6,6 +6,7 @@ import sys import unittest from trollius import From, Return +from trollius import base_subprocess from trollius import test_support as support from trollius.test_utils import mock from trollius.py33_exceptions import BrokenPipeError, ConnectionResetError diff --git a/tests/test_tasks.py b/tests/test_tasks.py index c2045a53..595b0040 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1,6 +1,7 @@ """Tests for tasks.py.""" import contextlib +import functools import os import re import sys @@ -356,6 +357,7 @@ def task(): raise Return(12) t = asyncio.Task(task(), loop=loop) + test_utils.run_briefly(loop) loop.call_soon(t.cancel) with self.assertRaises(asyncio.CancelledError): loop.run_until_complete(t) @@ -1304,37 +1306,6 @@ def fn2(): yield self.assertTrue(asyncio.iscoroutinefunction(fn2)) - def test_yield_vs_yield_from(self): - fut = asyncio.Future(loop=self.loop) - - @asyncio.coroutine - def wait_for_future(): - yield fut - - task = wait_for_future() - with self.assertRaises(RuntimeError): - self.loop.run_until_complete(task) - - self.assertFalse(fut.done()) - - def test_yield_vs_yield_from_generator(self): - @asyncio.coroutine - def coro(): - yield - - @asyncio.coroutine - def wait_for_future(): - gen = coro() - try: - yield gen - finally: - gen.close() - - task = wait_for_future() - self.assertRaises( - RuntimeError, - self.loop.run_until_complete, task) - def test_coroutine_non_gen_function(self): @asyncio.coroutine def func(): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 1223d86c..698b1308 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -1,7 +1,6 @@ """Tests for unix_events.py.""" import collections -import contextlib import errno import io import os @@ -282,6 +281,7 @@ def test_create_unix_server_bind_error(self, m_socket): # Ensure that the socket is closed on any bind error sock = mock.Mock() m_socket.socket.return_value = sock + m_socket.error = socket.error sock.bind.side_effect = OSError coro = self.loop.create_unix_server(lambda: None, path="/test") diff --git a/tox.ini b/tox.ini index 229b13cb..caaab829 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py26,py27,py2_release,py32,py33,py34,py3_release +envlist = py26,py27,py2_release,py32,py33,py34,py3_release,pyflakes3 [testenv] deps= @@ -10,6 +10,13 @@ commands= python runtests.py -r {posargs} python run_aiotest.py -r {posargs} +[testenv:pyflakes3] +basepython = python3 +deps= + pyflakes +commands= + pyflakes trollius tests runtests.py check.py run_aiotest.py setup.py + [testenv:py26] deps= aiotest diff --git a/trollius/base_events.py b/trollius/base_events.py index c8541f19..a3391b48 100644 --- a/trollius/base_events.py +++ b/trollius/base_events.py @@ -23,6 +23,7 @@ import subprocess import sys import traceback +import warnings try: from collections import OrderedDict except ImportError: @@ -198,7 +199,7 @@ def __init__(self): # Identifier of the thread running the event loop, or None if the # event loop is not running self._thread_id = None - self._clock_resolution = time.get_clock_info('monotonic').resolution + self._clock_resolution = time_monotonic_resolution self._exception_handler = None self.set_debug(bool(os.environ.get('TROLLIUSDEBUG'))) # In debug mode, if the execution of a callback or a step of a task @@ -301,7 +302,7 @@ def run_forever(self): if self.is_running(): raise RuntimeError('Event loop is running.') self._set_coroutine_wrapper(self._debug) - self._thread_id = threading.get_ident() + self._thread_id = _get_thread_ident() try: while True: try: @@ -325,7 +326,7 @@ def run_until_complete(self, future): """ self._check_closed() - new_task = not isinstance(future, futures.Future) + new_task = not isinstance(future, futures._FUTURE_CLASSES) future = tasks.ensure_future(future, loop=self) if new_task: # An exception is raised if the future didn't complete, so there @@ -404,7 +405,7 @@ def time(self): epoch, precision, accuracy and drift are unspecified and may differ per event loop. """ - return time.monotonic() + return time_monotonic() def call_later(self, delay, callback, *args): """Arrange for a callback to be called at a given time. @@ -484,7 +485,7 @@ def _check_thread(self): """ if self._thread_id is None: return - thread_id = threading.get_ident() + thread_id = _get_thread_ident() if thread_id != self._thread_id: raise RuntimeError( "Non-thread-safe operation invoked on an event loop other " diff --git a/trollius/coroutines.py b/trollius/coroutines.py index b12ca3e1..a4a1dd5e 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -1,5 +1,6 @@ __all__ = ['coroutine', - 'iscoroutinefunction', 'iscoroutine'] + 'iscoroutinefunction', 'iscoroutine', + 'From', 'Return'] import functools import inspect @@ -129,6 +130,21 @@ def debug_wrapper(gen): return CoroWrapper(gen, None) +def _coroutine_at_yield_from(coro): + """Test if the last instruction of a coroutine is "yield from". + + Return False if the coroutine completed. + """ + frame = coro.gi_frame + if frame is None: + return False + code = coro.gi_code + assert frame.f_lasti >= 0 + offset = frame.f_lasti + 1 + instr = code.co_code[offset] + return (instr == _YIELD_FROM) + + class CoroWrapper: # Wrapper for coroutine object in _DEBUG mode. @@ -298,7 +314,8 @@ def coroutine(func): @_wraps(func) def coro(*args, **kw): res = func(*args, **kw) - if isinstance(res, futures.Future) or inspect.isgenerator(res): + if (isinstance(res, futures._FUTURE_CLASSES) + or inspect.isgenerator(res)): res = yield from res elif _AwaitableABC is not None: # If 'func' returns an Awaitable (new in 3.5) we diff --git a/trollius/locks.py b/trollius/locks.py index ecbe3b3b..3917694a 100644 --- a/trollius/locks.py +++ b/trollius/locks.py @@ -46,7 +46,7 @@ def __exit__(self, *args): class _ContextManagerMixin(object): def __enter__(self): raise RuntimeError( - '"yield from" should be used as context manager expression') + '"yield From" should be used as context manager expression') def __exit__(self, *args): # This must exist because __enter__ exists, even though that @@ -331,8 +331,19 @@ def wait(self): finally: self._waiters.remove(fut) - finally: + except Exception as exc: + # Workaround CPython bug #23353: using yield/yield-from in an + # except block of a generator doesn't clear properly + # sys.exc_info() + err = exc + else: + err = None + + if err is not None: yield From(self.acquire()) + raise err + + yield From(self.acquire()) @coroutine def wait_for(self, predicate): diff --git a/trollius/selector_events.py b/trollius/selector_events.py index ec14974f..cf0681f3 100644 --- a/trollius/selector_events.py +++ b/trollius/selector_events.py @@ -14,8 +14,7 @@ import warnings try: import ssl - from .py3_ssl import (wrap_ssl_error, SSLContext, SSLWantReadError, - SSLWantWriteError) + from .py3_ssl import wrap_ssl_error, SSLWantReadError, SSLWantWriteError except ImportError: # pragma: no cover ssl = None diff --git a/trollius/subprocess.py b/trollius/subprocess.py index 2f1becf3..1e767689 100644 --- a/trollius/subprocess.py +++ b/trollius/subprocess.py @@ -1,16 +1,13 @@ __all__ = ['create_subprocess_exec', 'create_subprocess_shell'] -import collections import subprocess from . import events -from . import futures from . import protocols from . import streams from . import tasks from .coroutines import coroutine, From, Return -from .py33_exceptions import (BrokenPipeError, ConnectionResetError, - ProcessLookupError) +from .py33_exceptions import BrokenPipeError, ConnectionResetError from .log import logger diff --git a/trollius/tasks.py b/trollius/tasks.py index 1fb23cdc..963b54dd 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -6,12 +6,9 @@ 'gather', 'shield', 'ensure_future', ] -import concurrent.futures import functools -import inspect import linecache import sys -import types import traceback import warnings try: @@ -407,7 +404,7 @@ def wait_for(fut, timeout, loop=None): loop = events.get_event_loop() if timeout is None: - result = yield From(fut) + result = yield From(fut) raise Return(result) waiter = futures.Future(loop=loop) diff --git a/trollius/windows_events.py b/trollius/windows_events.py index 7f7764e0..9274a9d9 100644 --- a/trollius/windows_events.py +++ b/trollius/windows_events.py @@ -18,8 +18,7 @@ from . import _overlapped from .coroutines import coroutine, From, Return from .log import logger -from .py33_exceptions import (wrap_error, get_error_class, - ConnectionRefusedError, BrokenPipeError) +from .py33_exceptions import wrap_error, BrokenPipeError __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index b2b96177..7d21bbc0 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -8,7 +8,6 @@ if sys.platform != 'win32': # pragma: no cover raise ImportError('win32 only') -import _winapi import itertools import msvcrt import os From 14d618e2959019c0c44191eea2481da76a8f0e72 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 01:04:55 +0200 Subject: [PATCH 1407/1502] Python 2.7 fixes --- TODO.rst | 2 ++ tests/echo3.py | 10 ++++++- tests/test_events.py | 3 +- tests/test_futures.py | 16 ++--------- tests/test_selectors.py | 33 ++++++++++------------ tests/test_sslproto.py | 1 + tests/test_subprocess.py | 24 ++++++++-------- tests/test_tasks.py | 25 +++++++++-------- tests/test_unix_events.py | 5 ++-- tox.ini | 10 ++++++- trollius/base_events.py | 6 ++-- trollius/base_subprocess.py | 19 +++++++------ trollius/coroutines.py | 55 ++++++++++++++++++++----------------- trollius/events.py | 2 +- trollius/futures.py | 11 ++------ trollius/locks.py | 52 ++++++++++++----------------------- trollius/py33_exceptions.py | 2 +- trollius/selector_events.py | 4 +-- trollius/selectors.py | 2 +- trollius/sslproto.py | 1 + trollius/streams.py | 23 ++++++++-------- trollius/subprocess.py | 2 ++ trollius/tasks.py | 3 +- trollius/test_utils.py | 13 ++++----- trollius/windows_events.py | 2 +- 25 files changed, 161 insertions(+), 165 deletions(-) diff --git a/TODO.rst b/TODO.rst index ff21b225..f9b25a5e 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,5 +1,7 @@ Unsorted "TODO" tasks: +* test_utils.py: remove assertRaisesRegex, assertRegex +* streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * reuse selectors backport from PyPI * check ssl.SSLxxx in update_xxx.sh * document how to port asyncio to trollius diff --git a/tests/echo3.py b/tests/echo3.py index 06449673..a009ea34 100644 --- a/tests/echo3.py +++ b/tests/echo3.py @@ -1,4 +1,12 @@ import os +import sys + +asyncio_path = os.path.join(os.path.dirname(__file__), '..') +asyncio_path = os.path.abspath(asyncio_path) + +sys.path.insert(0, asyncio_path) +from trollius.py33_exceptions import wrap_error +sys.path.remove(asyncio_path) if __name__ == '__main__': while True: @@ -6,6 +14,6 @@ if not buf: break try: - os.write(1, b'OUT:'+buf) + wrap_error(os.write, 1, b'OUT:'+buf) except OSError as ex: os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii')) diff --git a/tests/test_events.py b/tests/test_events.py index 8f600eb5..225c7ce3 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -2083,8 +2083,9 @@ def test_handle_source_traceback(self): def check_source_traceback(h): lineno = sys._getframe(1).f_lineno - 1 self.assertIsInstance(h._source_traceback, list) + filename = sys._getframe().f_code.co_filename self.assertEqual(h._source_traceback[-1][:3], - (__file__, + (filename, lineno, 'test_handle_source_traceback')) diff --git a/tests/test_futures.py b/tests/test_futures.py index 34e68fc1..7a98d127 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -10,6 +10,7 @@ import unittest import trollius as asyncio +from trollius import From from trollius import compat from trollius import test_support as support from trollius import test_utils @@ -182,18 +183,6 @@ def test_copy_state(self): newf_cancelled._copy_state(f_cancelled) self.assertTrue(newf_cancelled.cancelled()) - def test_iter(self): - fut = asyncio.Future(loop=self.loop) - - def coro(): - yield from fut - - def test(): - arg1, arg2 = coro() - - self.assertRaises(AssertionError, test) - fut.cancel() - @mock.patch('trollius.base_events.logger') def test_tb_logger_abandoned(self, m_log): fut = asyncio.Future(loop=self.loop) @@ -296,8 +285,9 @@ def test_future_source_traceback(self): future = asyncio.Future(loop=self.loop) lineno = sys._getframe().f_lineno - 1 self.assertIsInstance(future._source_traceback, list) + filename = sys._getframe().f_code.co_filename self.assertEqual(future._source_traceback[-1][:3], - (__file__, + (filename, lineno, 'test_future_source_traceback')) diff --git a/tests/test_selectors.py b/tests/test_selectors.py index 591b4aab..dde82ec9 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -3,6 +3,7 @@ import random import signal import sys +import unittest from time import sleep try: import resource @@ -25,7 +26,7 @@ def find_ready_matching(ready, flag): return match -class BaseSelectorTestCase(test_utils.TestCase): +class BaseSelectorTestCase(object): def make_socketpair(self): rd, wr = socketpair() @@ -352,7 +353,7 @@ def test_select_interrupt(self): self.assertLess(time() - t, 2.5) -class ScalableSelectorMixIn: +class ScalableSelectorMixIn(object): # see issue #18963 for why it's skipped on older OS X versions @support.requires_mac_ver(10, 5) @@ -398,52 +399,48 @@ def test_above_fd_setsize(self): self.assertEqual(NUM_FDS // 2, len(s.select())) -class DefaultSelectorTestCase(BaseSelectorTestCase): +class DefaultSelectorTestCase(BaseSelectorTestCase, test_utils.TestCase): SELECTOR = selectors.DefaultSelector -class SelectSelectorTestCase(BaseSelectorTestCase): +class SelectSelectorTestCase(BaseSelectorTestCase, test_utils.TestCase): SELECTOR = selectors.SelectSelector @test_utils.skipUnless(hasattr(selectors, 'PollSelector'), "Test needs selectors.PollSelector") -class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): +class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, + test_utils.TestCase): SELECTOR = getattr(selectors, 'PollSelector', None) @test_utils.skipUnless(hasattr(selectors, 'EpollSelector'), "Test needs selectors.EpollSelector") -class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): +class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, + test_utils.TestCase): SELECTOR = getattr(selectors, 'EpollSelector', None) @test_utils.skipUnless(hasattr(selectors, 'KqueueSelector'), "Test needs selectors.KqueueSelector)") -class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): +class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, + test_utils.TestCase): SELECTOR = getattr(selectors, 'KqueueSelector', None) @test_utils.skipUnless(hasattr(selectors, 'DevpollSelector'), "Test needs selectors.DevpollSelector") -class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn): +class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, + test_utils.TestCase): SELECTOR = getattr(selectors, 'DevpollSelector', None) -def test_main(): - tests = [DefaultSelectorTestCase, SelectSelectorTestCase, - PollSelectorTestCase, EpollSelectorTestCase, - KqueueSelectorTestCase, DevpollSelectorTestCase] - support.run_unittest(*tests) - support.reap_children() - - -if __name__ == "__main__": - test_main() +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 8ea0f975..f60bbf8c 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -10,6 +10,7 @@ from trollius import sslproto from trollius import test_utils from trollius.test_utils import mock +from trollius import ConnectionResetError @unittest.skipIf(ssl is None, 'No ssl module') diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index b2eda3e9..f258ff15 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -9,7 +9,7 @@ from trollius import base_subprocess from trollius import test_support as support from trollius.test_utils import mock -from trollius.py33_exceptions import BrokenPipeError, ConnectionResetError +from trollius import BrokenPipeError, ConnectionResetError, ProcessLookupError if sys.platform != 'win32': from trollius import unix_events @@ -305,15 +305,15 @@ def test_cancel_process_wait(self): @asyncio.coroutine def cancel_wait(): - proc = yield from asyncio.create_subprocess_exec( + proc = yield From(asyncio.create_subprocess_exec( *PROGRAM_BLOCKED, - loop=self.loop) + loop=self.loop)) # Create an internal future waiting on the process exit task = self.loop.create_task(proc.wait()) self.loop.call_soon(task.cancel) try: - yield from task + yield From(task) except asyncio.CancelledError: pass @@ -368,12 +368,11 @@ def test_close_kill_running(self): def kill_running(): create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, *PROGRAM_BLOCKED) - transport, protocol = yield from create + transport, protocol = yield From(create) - kill_called = False + non_local = {'kill_called': False} def kill(): - nonlocal kill_called - kill_called = True + non_local['kill_called'] = True orig_kill() proc = transport.get_extra_info('subprocess') @@ -381,8 +380,8 @@ def kill(): proc.kill = kill returncode = transport.get_returncode() transport.close() - yield from transport._wait() - return (returncode, kill_called) + yield From(transport._wait()) + raise Return(returncode, non_local['kill_called']) # Ignore "Close running child process: kill ..." log with test_utils.disable_logger(): @@ -398,7 +397,7 @@ def test_close_dont_kill_finished(self): def kill_running(): create = self.loop.subprocess_exec(asyncio.SubprocessProtocol, *PROGRAM_BLOCKED) - transport, protocol = yield from create + transport, protocol = yield From(create) proc = transport.get_extra_info('subprocess') # kill the process (but asyncio is not notified immediatly) @@ -409,7 +408,8 @@ def kill_running(): proc_returncode = proc.poll() transport_returncode = transport.get_returncode() transport.close() - return (proc_returncode, transport_returncode, proc.kill.called) + raise Return(proc_returncode, transport_returncode, + proc.kill.called) # Ignore "Unknown child process pid ..." log of SafeChildWatcher, # emitted because the test already consumes the exit status: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 595b0040..d70acabb 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -26,6 +26,9 @@ def coroutine_function(): pass +@asyncio.coroutine +def coroutine_function2(x, y): + yield From(asyncio.sleep(0)) @contextlib.contextmanager def set_coroutine_debug(enabled): @@ -307,11 +310,8 @@ def test_task_repr_partial_corowrapper(self): with set_coroutine_debug(True): self.loop.set_debug(True) - @asyncio.coroutine - def func(x, y): - yield from asyncio.sleep(0) - - partial_func = asyncio.coroutine(functools.partial(func, 1)) + cb = functools.partial(coroutine_function2, 1) + partial_func = asyncio.coroutine(cb) task = self.loop.create_task(partial_func(2)) # make warnings quiet @@ -319,8 +319,7 @@ def func(x, y): self.addCleanup(task._coro.close) coro_repr = repr(task._coro) - expected = ('.func(1)() running, ') + expected = (' - # - # as an alternative to: - # - # yield from lock.acquire() - # try: - # - # finally: - # lock.release() - yield from self.acquire() - return _ContextManager(self) - - if _PY35: - - def __await__(self): - # To make "with await lock" work. - yield from self.acquire() - return _ContextManager(self) - - @coroutine - def __aenter__(self): - yield from self.acquire() - # We have no use for the "as ..." clause in the with - # statement for locks. - return None - - @coroutine - def __aexit__(self, exc_type, exc, tb): - self.release() + # FIXME: support PEP 492? + #if _PY35: + + # def __await__(self): + # # To make "with await lock" work. + # yield from self.acquire() + # return _ContextManager(self) + + # @coroutine + # def __aenter__(self): + # yield from self.acquire() + # # We have no use for the "as ..." clause in the with + # # statement for locks. + # return None + + # @coroutine + # def __aexit__(self, exc_type, exc, tb): + # self.release() class Lock(_ContextManagerMixin): diff --git a/trollius/py33_exceptions.py b/trollius/py33_exceptions.py index 94cbfca4..f10dfe9e 100644 --- a/trollius/py33_exceptions.py +++ b/trollius/py33_exceptions.py @@ -1,7 +1,7 @@ __all__ = ['BlockingIOError', 'BrokenPipeError', 'ChildProcessError', 'ConnectionRefusedError', 'ConnectionResetError', 'InterruptedError', 'ConnectionAbortedError', 'PermissionError', - 'FileNotFoundError', + 'FileNotFoundError', 'ProcessLookupError', ] import errno diff --git a/trollius/selector_events.py b/trollius/selector_events.py index cf0681f3..4eb4bc5d 100644 --- a/trollius/selector_events.py +++ b/trollius/selector_events.py @@ -26,7 +26,7 @@ from . import transports from . import sslproto from .compat import flatten_bytes -from .coroutines import coroutine +from .coroutines import coroutine, From from .log import logger from .py33_exceptions import (wrap_error, BlockingIOError, InterruptedError, ConnectionAbortedError, BrokenPipeError, @@ -226,7 +226,7 @@ def _accept_connection2(self, protocol_factory, conn, extra, server=server) try: - yield from waiter + yield From(waiter) except: transport.close() raise diff --git a/trollius/selectors.py b/trollius/selectors.py index d2f822cd..cf0475de 100644 --- a/trollius/selectors.py +++ b/trollius/selectors.py @@ -411,7 +411,7 @@ def unregister(self, fileobj): key = super(EpollSelector, self).unregister(fileobj) try: self._epoll.unregister(key.fd) - except OSError: + except IOError: # This can happen if the FD was closed since it # was registered. pass diff --git a/trollius/sslproto.py b/trollius/sslproto.py index 5f4920a5..2da10822 100644 --- a/trollius/sslproto.py +++ b/trollius/sslproto.py @@ -9,6 +9,7 @@ from . import protocols from . import transports from .log import logger +from .py33_exceptions import BrokenPipeError, ConnectionResetError from .py3_ssl import BACKPORT_SSL_CONTEXT diff --git a/trollius/streams.py b/trollius/streams.py index c235c5a5..c9532dc9 100644 --- a/trollius/streams.py +++ b/trollius/streams.py @@ -498,14 +498,15 @@ def readexactly(self, n): raise Return(b''.join(blocks)) - if _PY35: - @coroutine - def __aiter__(self): - return self - - @coroutine - def __anext__(self): - val = yield from self.readline() - if val == b'': - raise StopAsyncIteration - return val + # FIXME: should we support __aiter__ and __anext__ in Trollius? + #if _PY35: + # @coroutine + # def __aiter__(self): + # return self + + # @coroutine + # def __anext__(self): + # val = yield from self.readline() + # if val == b'': + # raise StopAsyncIteration + # return val diff --git a/trollius/subprocess.py b/trollius/subprocess.py index 1e767689..4ed2b5c5 100644 --- a/trollius/subprocess.py +++ b/trollius/subprocess.py @@ -1,3 +1,5 @@ +from __future__ import absolute_import + __all__ = ['create_subprocess_exec', 'create_subprocess_shell'] import subprocess diff --git a/trollius/tasks.py b/trollius/tasks.py index 963b54dd..2f461b3d 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -1,4 +1,5 @@ """Support for tasks, coroutines and the scheduler.""" +from __future__ import print_function __all__ = ['Task', 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', @@ -560,7 +561,7 @@ def async(coro_or_future, loop=None): return ensure_future(coro_or_future, loop=loop) -def ensure_future(coro_or_future, *, loop=None): +def ensure_future(coro_or_future, loop=None): """Wrap a coroutine in a future. If the argument is a Future, it is returned directly. diff --git a/trollius/test_utils.py b/trollius/test_utils.py index caa98f71..f6fb2468 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -50,18 +50,17 @@ from socket import socketpair # pragma: no cover try: - import unittest - skipIf = unittest.skipIf - skipUnless = unittest.skipUnless - SkipTest = unittest.SkipTest - _TestCase = unittest.TestCase -except AttributeError: - # Python 2.6: use the backported unittest module called "unittest2" import unittest2 skipIf = unittest2.skipIf skipUnless = unittest2.skipUnless SkipTest = unittest2.SkipTest _TestCase = unittest2.TestCase +except ImportError: + import unittest + skipIf = unittest.skipIf + skipUnless = unittest.skipUnless + SkipTest = unittest.SkipTest + _TestCase = unittest.TestCase if not hasattr(_TestCase, 'assertRaisesRegex'): diff --git a/trollius/windows_events.py b/trollius/windows_events.py index 9274a9d9..cecfa44e 100644 --- a/trollius/windows_events.py +++ b/trollius/windows_events.py @@ -18,7 +18,7 @@ from . import _overlapped from .coroutines import coroutine, From, Return from .log import logger -from .py33_exceptions import wrap_error, BrokenPipeError +from .py33_exceptions import wrap_error, BrokenPipeError, ConnectionResetError __all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', From 352b8cec079145179f52db6fa70cede8c49ad0c0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 01:57:49 +0200 Subject: [PATCH 1408/1502] Fixes for Python 2.6 --- TODO.rst | 1 + runtests.py | 6 +++--- tests/test_events.py | 25 +++++++++++++------------ tests/test_selector_events.py | 4 ++-- tests/test_sslproto.py | 2 +- tests/test_unix_events.py | 4 ++-- trollius/coroutines.py | 2 +- trollius/events.py | 2 +- 8 files changed, 24 insertions(+), 22 deletions(-) diff --git a/TODO.rst b/TODO.rst index f9b25a5e..09df760c 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,5 +1,6 @@ Unsorted "TODO" tasks: +* Drop Python 2.6 support * test_utils.py: remove assertRaisesRegex, assertRegex * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * reuse selectors backport from PyPI diff --git a/runtests.py b/runtests.py index efc6418a..1fabcffd 100755 --- a/runtests.py +++ b/runtests.py @@ -42,11 +42,11 @@ sys.exc_clear() try: - import unittest - from unittest.signals import installHandler -except ImportError: import unittest2 as unittest from unittest2.signals import installHandler +except ImportError: + import unittest + from unittest.signals import installHandler ARGS = optparse.OptionParser(description="Run all unittests.", usage="%prog [options] [pattern] [pattern2 ...]") ARGS.add_option( diff --git a/tests/test_events.py b/tests/test_events.py index 225c7ce3..ad7e4686 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -653,7 +653,7 @@ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, with test_utils.disable_logger(): self._basetest_create_ssl_connection(conn_fut, check_sockname) - self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') @test_utils.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): @@ -981,18 +981,19 @@ def test_create_server_ssl_match_failed(self): err_msg = "hostname '127.0.0.1' doesn't match u'localhost'" # incorrect server_hostname -# if not asyncio.BACKPORT_SSL_CONTEXT: - f_c = self.loop.create_connection(MyProto, host, port, - ssl=sslcontext_client) - with mock.patch.object(self.loop, 'call_exception_handler'): - with test_utils.disable_logger(): - with self.assertRaisesRegex( - ssl.CertificateError, - err_msg): - self.loop.run_until_complete(f_c) + if not asyncio.BACKPORT_SSL_CONTEXT: + f_c = self.loop.create_connection(MyProto, host, port, + ssl=sslcontext_client) + with mock.patch.object(self.loop, 'call_exception_handler'): + with test_utils.disable_logger(): + with self.assertRaisesRegex( + ssl.CertificateError, + err_msg): + self.loop.run_until_complete(f_c) + + # close connection + proto.transport.close() - # close connection - proto.transport.close() server.close() def test_legacy_create_server_ssl_match_failed(self): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 06b6e875..66c1c01a 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -86,7 +86,7 @@ def test_make_socket_transport(self): close_transport(transport) - @unittest.skipIf(ssl is None, 'No ssl module') + @test_utils.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() @@ -1475,7 +1475,7 @@ def test_close(self): self.assertTrue(self.protocol.connection_lost.called) def test_close_not_connected(self): - self.sslsock.do_handshake.side_effect = ssl.SSLWantReadError + self.sslsock.do_handshake.side_effect = SSLWantReadError self.check_close() self.assertFalse(self.protocol.connection_made.called) self.assertFalse(self.protocol.connection_lost.called) diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index f60bbf8c..630c1616 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -13,7 +13,7 @@ from trollius import ConnectionResetError -@unittest.skipIf(ssl is None, 'No ssl module') +@test_utils.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): def setUp(self): diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index f130c7f0..137a699f 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -230,8 +230,8 @@ def test_close(self, m_signal): m_signal.set_wakeup_fd.assert_called_once_with(-1) -@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), - 'UNIX Sockets are not supported') +@test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') class SelectorEventLoopUnixSocketTests(test_utils.TestCase): def setUp(self): diff --git a/trollius/coroutines.py b/trollius/coroutines.py index 2e5e4de0..45c2fe21 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -387,7 +387,7 @@ def _format_coroutine(coro): func = coro.func coro_name = coro.__qualname__ if coro_name is not None: - coro_name = '{}()'.format(coro_name) + coro_name = '{0}()'.format(coro_name) else: func = coro diff --git a/trollius/events.py b/trollius/events.py index bb2599ba..859ac074 100644 --- a/trollius/events.py +++ b/trollius/events.py @@ -137,7 +137,7 @@ def _run(self): self._callback(*self._args) except Exception as exc: cb = _format_callback_source(self._callback, self._args) - msg = 'Exception in callback {}'.format(cb) + msg = 'Exception in callback {0}'.format(cb) context = { 'message': msg, 'exception': exc, From 94fa7e219baa99d9b23f0b06c5cd5431ec9cd0e6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:10:56 +0200 Subject: [PATCH 1409/1502] tox.ini: fix py2_release --- tox.ini | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index d102e75d..bccb3b81 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,6 @@ [tox] -envlist = py26,py27,py2_release,py32,py33,py34,py3_release,pyflakes2,pyflakes3 +envlist = py26,py27,py2_release,py32,py33,py34,py3_release +# and: pyflakes2,pyflakes3 [testenv] deps= @@ -45,6 +46,7 @@ deps= aiotest futures mock + unittest2 setenv = TROLLIUSDEBUG = basepython = python2.7 From 03a4d4dc7b12c88ab9eb76188ea6f50ba7f7c1f1 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:32:12 +0200 Subject: [PATCH 1410/1502] pick README from trollius --- README.rst | 98 +++++++++++++----------------------------------------- setup.py | 2 +- 2 files changed, 24 insertions(+), 76 deletions(-) diff --git a/README.rst b/README.rst index 629368e2..f67c00db 100644 --- a/README.rst +++ b/README.rst @@ -1,93 +1,41 @@ -The asyncio module provides infrastructure for writing single-threaded -concurrent code using coroutines, multiplexing I/O access over sockets and -other resources, running network clients and servers, and other related -primitives. Here is a more detailed list of the package contents: +Trollius provides infrastructure for writing single-threaded concurrent +code using coroutines, multiplexing I/O access over sockets and other +resources, running network clients and servers, and other related primitives. +Here is a more detailed list of the package contents: * a pluggable event loop with various system-specific implementations; -* transport and protocol abstractions (similar to those in Twisted); +* transport and protocol abstractions (similar to those in `Twisted + `_); * concrete support for TCP, UDP, SSL, subprocess pipes, delayed calls, and others (some may be system-dependent); -* a Future class that mimics the one in the concurrent.futures module, but - adapted for use with the event loop; +* a ``Future`` class that mimics the one in the ``concurrent.futures`` module, + but adapted for use with the event loop; -* coroutines and tasks based on ``yield from`` (PEP 380), to help write +* coroutines and tasks based on generators (``yield``), to help write concurrent code in a sequential fashion; -* cancellation support for Futures and coroutines; +* cancellation support for ``Future``\s and coroutines; * synchronization primitives for use between coroutines in a single thread, - mimicking those in the threading module; + mimicking those in the ``threading`` module; * an interface for passing work off to a threadpool, for times when you absolutely, positively have to use a library that makes blocking I/O calls. +Trollius is a portage of the `Tulip project `_ +(``asyncio`` module, `PEP 3156 `_) +on Python 2. Trollius works on Python 2.6-3.5. It has been tested on Windows, +Linux, Mac OS X, FreeBSD and OpenIndiana. -Installation -============ - -To install asyncio, type:: - - pip install asyncio - -asyncio requires Python 3.3 or later! The asyncio module is part of the Python -standard library since Python 3.4. - -asyncio is a free software distributed under the Apache license version 2.0. - - -Websites -======== - -* `asyncio project at GitHub `_: source - code, bug tracker -* `asyncio documentation `_ -* Mailing list: `python-tulip Google Group - `_ -* IRC: join the ``#asyncio`` channel on the Freenode network - - -Development -=========== - -The actual code lives in the 'asyncio' subdirectory. Tests are in the 'tests' -subdirectory. - -To run tests, run:: - - tox - -Or use the Makefile:: - - make test - -To run coverage (coverage package is required):: - - make coverage - -On Windows, things are a little more complicated. Assume 'P' is your -Python binary (for example C:\Python33\python.exe). - -You must first build the _overlapped.pyd extension and have it placed -in the asyncio directory, as follows: - - C> P setup.py build_ext --inplace - -If this complains about vcvars.bat, you probably don't have the -required version of Visual Studio installed. Compiling extensions for -Python 3.3 requires Microsoft Visual C++ 2010 (MSVC 10.0) of any -edition; you can download Visual Studio Express 2010 for free from -http://www.visualstudio.com/downloads (scroll down to Visual C++ 2010 -Express). - -Once you have built the _overlapped.pyd extension successfully you can -run the tests as follows: - - C> P runtests.py - -And coverage as follows: - - C> P runtests.py --coverage +* `Asyncio documentation `_ +* `Trollius documentation `_ +* `Trollius project in the Python Cheeseshop (PyPI) + `_ +* `Trollius project at Github `_ (code, + bug tracker) +* Copyright/license: Open source, Apache 2.0. Enjoy! +See also the `asyncio project at Github `_. diff --git a/setup.py b/setup.py index b40169cd..c98c9751 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ # Release procedure: -# - fill Tulip changelog +# - fill trollius changelog # - run maybe update_tulip.sh # - run unit tests with concurrent.futures # - run unit tests without concurrent.futures From 6f52fbfaacae0322a9c3303f15ae44444c4b1e13 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:35:26 +0200 Subject: [PATCH 1411/1502] README: replace Tulip with asyncio --- README.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index f67c00db..a55218d5 100644 --- a/README.rst +++ b/README.rst @@ -25,10 +25,11 @@ Here is a more detailed list of the package contents: * an interface for passing work off to a threadpool, for times when you absolutely, positively have to use a library that makes blocking I/O calls. -Trollius is a portage of the `Tulip project `_ -(``asyncio`` module, `PEP 3156 `_) -on Python 2. Trollius works on Python 2.6-3.5. It has been tested on Windows, -Linux, Mac OS X, FreeBSD and OpenIndiana. +Trollius is a portage of the `asyncio project +`_ (`PEP 3156 +`_) on Python 2. Trollius works on +Python 2.6-3.5. It has been tested on Windows, Linux, Mac OS X, FreeBSD and +OpenIndiana. * `Asyncio documentation `_ * `Trollius documentation `_ From fe673cc69295a2a082f9840c85bfd3cbe303e117 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:39:43 +0200 Subject: [PATCH 1412/1502] replace Tulip with asyncio --- AUTHORS | 4 ++-- doc/asyncio.rst | 36 ++++++++++++++++++------------------ doc/index.rst | 5 ++--- trollius/coroutines.py | 2 +- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/AUTHORS b/AUTHORS index c625633a..3c3966b9 100644 --- a/AUTHORS +++ b/AUTHORS @@ -9,6 +9,6 @@ The photo of Trollis flower was taken by Imartin6 and distributed under the CC BY-SA 3.0 license. It comes from: http://commons.wikimedia.org/wiki/File:Trollius_altaicus.jpg -Trollius is a port of the Tulip project on Python 2, see also authors of the -Tulip project (AUTHORS file of the Tulip project). +Trollius is a port of the asyncio project on Python 2, see also authors of the +asyncio project (AUTHORS file). diff --git a/doc/asyncio.rst b/doc/asyncio.rst index 5866d62f..011a9a8b 100644 --- a/doc/asyncio.rst +++ b/doc/asyncio.rst @@ -1,17 +1,17 @@ -++++++++++++++++++ -Trollius and Tulip -++++++++++++++++++ +++++++++++++++++++++ +Trollius and asyncio +++++++++++++++++++++ -Differences between Trollius and Tulip -====================================== +Differences between Trollius and asyncio +======================================== Syntax of coroutines -------------------- -The major difference between Trollius and Tulip is the syntax of coroutines: +The major difference between Trollius and asyncio is the syntax of coroutines: ================== ====================== -Tulip Trollius +asyncio Trollius ================== ====================== ``yield from ...`` ``yield From(...)`` ``yield from []`` ``yield From(None)`` @@ -129,41 +129,41 @@ Other differences ``BaseEventLoop.run_in_executor()`` uses a synchronous executor instead of a pool of threads. It blocks until the function returns. For example, DNS resolutions are blocking in this case. -* Trollius has more symbols than Tulip for compatibility with Python older than - 3.3: +* Trollius has more symbols than asyncio for compatibility with Python older + than 3.3: - ``From``: part of ``yield From(...)`` syntax - ``Return``: part of ``raise Return(...)`` syntax -Write code working on Trollius and Tulip -======================================== +Write code working on Trollius and asyncio +========================================== -Trollius and Tulip are different, especially for coroutines (``yield +Trollius and asyncio are different, especially for coroutines (``yield From(...)`` vs ``yield from ...``). To use asyncio or Trollius on Python 2 and Python 3, add the following code at the top of your file:: try: - # Use builtin asyncio on Python 3.4+, or Tulip on Python 3.3 + # Use builtin asyncio on Python 3.4+, or asyncio on Python 3.3 import asyncio except ImportError: # Use Trollius on Python <= 3.2 import trollius as asyncio It is possible to write code working on both projects using only callbacks. -This option is used by the following projects which work on Trollius and Tulip: +This option is used by the following projects which work on Trollius and asyncio: * `AutobahnPython `_: WebSocket & - WAMP for Python, it works on Trollius (Python 2.6 and 2.7), Tulip (Python + WAMP for Python, it works on Trollius (Python 2.6 and 2.7), asyncio (Python 3.3) and Python 3.4 (asyncio), and also on Twisted. * `Pulsar `_: Event driven concurrent framework for Python. With pulsar you can write asynchronous servers performing one or several activities in different threads and/or processes. Trollius 0.3 requires Pulsar 0.8.2 or later. Pulsar uses the ``asyncio`` module if available, or import ``trollius``. -* `Tornado `_ supports Tulip and Trollius since +* `Tornado `_ supports asyncio and Trollius since Tornado 3.2: `tornado.platform.asyncio — Bridge between asyncio and Tornado `_. It tries to import asyncio or fallback on importing trollius. @@ -171,10 +171,10 @@ This option is used by the following projects which work on Trollius and Tulip: Another option is to provide functions returning ``Future`` objects, so the caller can decide to use callback using ``fut.add_done_callback(callback)`` or to use coroutines (``yield From(fut)`` for Trollius, or ``yield from fut`` for -Tulip). This option is used by the `aiodns `_ +asyncio). This option is used by the `aiodns `_ project for example. -Since Trollius 0.4, it's possible to use Tulip and Trollius coroutines in the +Since Trollius 0.4, it's possible to use asyncio and Trollius coroutines in the same process. The only limit is that the event loop must be a Trollius event loop. diff --git a/doc/index.rst b/doc/index.rst index 86250930..317f9245 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -33,7 +33,7 @@ Here is a more detailed list of the package contents: * an interface for passing work off to a threadpool, for times when you absolutely, positively have to use a library that makes blocking I/O calls. -Trollius is a portage of the `Tulip project `_ +Trollius is a portage of the `asyncio project `_ (``asyncio`` module, `PEP 3156 `_) on Python 2. Trollius works on Python 2.6-3.5. It has been tested on Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. @@ -50,8 +50,7 @@ Linux, Mac OS X, FreeBSD and OpenIndiana. * IRC: ``#asyncio`` channel on the `Freenode network `_ * Copyright/license: Open source, Apache 2.0. Enjoy! -See also the `Tulip project `_ (asyncio module -for Python 3.3). +See also the `asyncio project at Github `_. Table Of Contents diff --git a/trollius/coroutines.py b/trollius/coroutines.py index 45c2fe21..8bd0b7e5 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -371,7 +371,7 @@ def iscoroutinefunction(func): if hasattr(events.asyncio, 'coroutines'): _COROUTINE_TYPES += (events.asyncio.coroutines.CoroWrapper,) else: - # old Tulip/Python versions + # old asyncio/Python versions _COROUTINE_TYPES += (events.asyncio.tasks.CoroWrapper,) def iscoroutine(obj): From c335bfe4990605bcc473a530987346ce50cd0e2c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:40:44 +0200 Subject: [PATCH 1413/1502] queues.py: import coroutine from .coroutines not from .tasks --- asyncio/queues.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index ed116620..3b4dc21a 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -9,7 +9,7 @@ from . import events from . import futures from . import locks -from .tasks import coroutine +from .coroutines import coroutine class QueueEmpty(Exception): From f5ae096b6945f047f26a7fafb602fb0781450537 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:41:01 +0200 Subject: [PATCH 1414/1502] Replace Tulip with asyncio in test comments --- tests/test_selector_events.py | 2 +- tests/test_streams.py | 6 +++--- tests/test_subprocess.py | 2 +- tests/test_windows_events.py | 6 ++++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index 9478b954..f0fcdd22 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -348,7 +348,7 @@ def test_sock_connect(self): self.loop._sock_connect.call_args[0]) def test_sock_connect_timeout(self): - # Tulip issue #205: sock_connect() must unregister the socket on + # asyncio issue #205: sock_connect() must unregister the socket on # timeout error # prepare mocks diff --git a/tests/test_streams.py b/tests/test_streams.py index 2273049b..242b377e 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -580,7 +580,7 @@ def client(path): @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): - # See Tulip issue 168. This test is derived from the example + # See asyncio issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the # StreamReader's limit so that twice it is less than the size # of the data writter. Also we must explicitly attach a child @@ -621,7 +621,7 @@ def test_streamreader_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # Tulip issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that StreamReaderProtocol constructor # retrieves the current loop if the loop parameter is not set reader = asyncio.StreamReader() self.assertIs(reader._loop, self.loop) @@ -630,7 +630,7 @@ def test_streamreaderprotocol_constructor(self): self.addCleanup(asyncio.set_event_loop, None) asyncio.set_event_loop(self.loop) - # Tulip issue #184: Ensure that StreamReaderProtocol constructor + # asyncio issue #184: Ensure that StreamReaderProtocol constructor # retrieves the current loop if the loop parameter is not set reader = mock.Mock() protocol = asyncio.StreamReaderProtocol(reader) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 5ccdafb1..ea85e191 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -266,7 +266,7 @@ def connect_read_pipe_mock(*args, **kw): self.assertTrue(transport.resume_reading.called) def test_stdin_not_inheritable(self): - # Tulip issue #209: stdin must not be inheritable, otherwise + # asyncio issue #209: stdin must not be inheritable, otherwise # the Process.communicate() hangs @asyncio.coroutine def len_message(message): diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 657a4274..7fcf4023 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -132,7 +132,8 @@ def test_wait_for_handle(self): self.assertTrue(fut.result()) self.assertTrue(0 <= elapsed < 0.3, elapsed) - # Tulip issue #195: cancelling a done _WaitHandleFuture must not crash + # asyncio issue #195: cancelling a done _WaitHandleFuture + # must not crash fut.cancel() def test_wait_for_handle_cancel(self): @@ -149,7 +150,8 @@ def test_wait_for_handle_cancel(self): elapsed = self.loop.time() - start self.assertTrue(0 <= elapsed < 0.1, elapsed) - # Tulip issue #195: cancelling a _WaitHandleFuture twice must not crash + # asyncio issue #195: cancelling a _WaitHandleFuture twice + # must not crash fut = self.loop._proactor.wait_for_handle(event) fut.cancel() fut.cancel() From 46c187af2656418aae2b56d24a66b5485e80ac4b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:43:11 +0200 Subject: [PATCH 1415/1502] replace tulip with asyncio README: mention the old Tulip name --- AUTHORS | 2 +- README.rst | 2 ++ examples/timing_tcp_server.py | 2 +- runtests.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/AUTHORS b/AUTHORS index d25b4465..85913761 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,7 +12,7 @@ Donald Stufft Eli Bendersky Geert Jansen Giampaolo Rodola' -Guido van Rossum : creator of the Tulip project and author of the PEP 3156 +Guido van Rossum : creator of the asyncio project and author of the PEP 3156 Gustavo Carneiro Jeff Quast Jonathan Slenders diff --git a/README.rst b/README.rst index 629368e2..9f03922f 100644 --- a/README.rst +++ b/README.rst @@ -24,6 +24,8 @@ primitives. Here is a more detailed list of the package contents: * an interface for passing work off to a threadpool, for times when you absolutely, positively have to use a library that makes blocking I/O calls. +Note: The implementation of asyncio was previously called "Tulip". + Installation ============ diff --git a/examples/timing_tcp_server.py b/examples/timing_tcp_server.py index 883ce6d3..3fcdc974 100644 --- a/examples/timing_tcp_server.py +++ b/examples/timing_tcp_server.py @@ -1,7 +1,7 @@ """ A variant of simple_tcp_server.py that measures the time it takes to send N messages for a range of N. (This was O(N**2) in a previous -version of Tulip.) +version of asyncio.) Note that running this example starts both the TCP server and client in the same process. It listens on port 1234 on 127.0.0.1, so it will diff --git a/runtests.py b/runtests.py index c38b0c18..b6ed71e3 100644 --- a/runtests.py +++ b/runtests.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Run Tulip unittests. +"""Run asyncio unittests. Usage: python3 runtests.py [flags] [pattern] ... From 67bf1da7561a54c82e2dcafd2c7bdfd813e880c8 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 02:52:51 +0200 Subject: [PATCH 1416/1502] update doc for unittest2 --- doc/dev.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/dev.rst b/doc/dev.rst index 5965306c..1bed7f8e 100644 --- a/doc/dev.rst +++ b/doc/dev.rst @@ -24,7 +24,7 @@ Test Dependencies ----------------- On Python older than 3.3, unit tests require the `mock -`_ module. Python 2.6 requires also +`_ module. Python 2.6 and 2.7 require also `unittest2 `_. To run ``run_aiotest.py``, you need the `aiotest From 4c6a4e1784f7595917f991ec470a395b3bbfcd67 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:16:24 +0200 Subject: [PATCH 1417/1502] Fix import on windows --- trollius/windows_events.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trollius/windows_events.py b/trollius/windows_events.py index cecfa44e..3102d230 100644 --- a/trollius/windows_events.py +++ b/trollius/windows_events.py @@ -1,6 +1,5 @@ """Selector and proactor event loops for Windows.""" -import _winapi import errno import math import socket From d15529ba71d2ffbc41b09f4f7330dd566adda1ab Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:23:14 +0200 Subject: [PATCH 1418/1502] get unittest from trollius.test_utils --- tests/test_asyncio.py | 2 +- tests/test_base_events.py | 6 +-- tests/test_events.py | 72 +++++++++++++++++------------------ tests/test_futures.py | 10 ++--- tests/test_locks.py | 2 +- tests/test_proactor_events.py | 2 +- tests/test_queues.py | 3 +- tests/test_selector_events.py | 20 +++++----- tests/test_selectors.py | 26 ++++++------- tests/test_sslproto.py | 6 +-- tests/test_streams.py | 16 ++++---- tests/test_subprocess.py | 8 ++-- tests/test_tasks.py | 6 +-- tests/test_transports.py | 3 +- tests/test_unix_events.py | 9 ++--- tests/test_windows_events.py | 5 +-- tests/test_windows_utils.py | 17 ++++----- trollius/test_utils.py | 23 +++++------ 18 files changed, 114 insertions(+), 122 deletions(-) diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index d75054d6..0421db04 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -2,7 +2,7 @@ from trollius import From, Return import trollius import trollius.coroutines -import unittest +from trollius.test_utils import unittest try: import asyncio diff --git a/tests/test_base_events.py b/tests/test_base_events.py index 808b7090..02ecfeb3 100644 --- a/tests/test_base_events.py +++ b/tests/test_base_events.py @@ -7,7 +7,6 @@ import sys import threading import time -import unittest import trollius as asyncio from trollius import Return, From @@ -17,6 +16,7 @@ from trollius.py33_exceptions import BlockingIOError from trollius.test_utils import mock from trollius.time_monotonic import time_monotonic +from trollius.test_utils import unittest from trollius import test_support as support @@ -1159,8 +1159,8 @@ def test_create_datagram_endpoint_socket_err(self, m_socket): self.assertRaises( socket.error, self.loop.run_until_complete, coro) - @test_utils.skipUnless(support.IPV6_ENABLED, - 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, + 'IPv6 not supported or enabled') def test_create_datagram_endpoint_no_matching_family(self): coro = self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, diff --git a/tests/test_events.py b/tests/test_events.py index ad7e4686..5153312b 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -13,7 +13,6 @@ import sys import threading import errno -import unittest import weakref try: @@ -37,6 +36,7 @@ from trollius import sslproto from trollius import test_support as support from trollius import test_utils +from trollius.test_utils import unittest from trollius.py33_exceptions import (wrap_error, BlockingIOError, ConnectionRefusedError, FileNotFoundError) @@ -343,7 +343,7 @@ def callback(arg): self.loop.run_forever() self.assertEqual(results, ['hello', 'world']) - @test_utils.skipIf(concurrent is None, 'need concurrent.futures') + @unittest.skipIf(concurrent is None, 'need concurrent.futures') def test_run_in_executor(self): def run(arg): return (arg, threading.current_thread().ident) @@ -438,7 +438,7 @@ def test_sock_client_ops(self): sock = socket.socket() self._basetest_sock_client_ops(httpd, sock) - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_unix_sock_client_ops(self): with test_utils.run_test_unix_server() as httpd: sock = socket.socket(socket.AF_UNIX) @@ -478,7 +478,7 @@ def test_sock_accept(self): conn.close() listener.close() - @test_utils.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') + @unittest.skipUnless(hasattr(signal, 'SIGKILL'), 'No SIGKILL') def test_add_signal_handler(self): non_local = {'caught': 0} @@ -521,7 +521,7 @@ def my_handler(): # Removing again returns False. self.assertFalse(self.loop.remove_signal_handler(signal.SIGINT)) - @test_utils.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_while_selecting(self): # Test with a signal actually arriving during a select() call. non_local = {'caught': 0} @@ -536,7 +536,7 @@ def my_handler(): self.loop.run_forever() self.assertEqual(non_local['caught'], 1) - @test_utils.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') + @unittest.skipUnless(hasattr(signal, 'SIGALRM'), 'No SIGALRM') def test_signal_handling_args(self): some_args = (42,) non_local = {'caught': 0} @@ -569,7 +569,7 @@ def test_create_connection(self): lambda: MyProto(loop=self.loop), *httpd.address) self._basetest_create_connection(conn_fut) - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_connection(self): # Issue #20682: On Mac OS X Tiger, getsockname() returns a # zero-length address for UNIX socket. @@ -655,7 +655,7 @@ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): with test_utils.run_test_server(use_ssl=True) as httpd: create_connection = functools.partial( @@ -668,8 +668,8 @@ def test_legacy_create_ssl_connection(self): with test_utils.force_legacy_ssl_support(): self.test_create_ssl_connection() - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_ssl_unix_connection(self): # Issue #20682: On Mac OS X Tiger, getsockname() returns a # zero-length address for UNIX socket. @@ -755,7 +755,7 @@ def _make_unix_server(self, factory, **kwargs): return server, path - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server(self): proto = MyProto(loop=self.loop) server, path = self._make_unix_server(lambda: proto) @@ -783,7 +783,7 @@ def test_create_unix_server(self): # close server server.close() - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_path_socket_error(self): proto = MyProto(loop=self.loop) sock = socket.socket() @@ -818,7 +818,7 @@ def _make_ssl_unix_server(self, factory, certfile, keyfile=None): sslcontext = self._create_ssl_context(certfile, keyfile) return self._make_unix_server(factory, ssl=sslcontext) - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -856,8 +856,8 @@ def test_legacy_create_server_ssl(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl() - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( @@ -891,8 +891,8 @@ def test_legacy_create_unix_server_ssl(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl() - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -924,9 +924,9 @@ def test_legacy_create_server_ssl_verify_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl_verify_failed() - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') - @test_utils.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') def test_create_unix_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( @@ -959,7 +959,7 @@ def test_legacy_create_unix_server_ssl_verify_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl_verify_failed() - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_match_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -1000,8 +1000,8 @@ def test_legacy_create_server_ssl_match_failed(self): with test_utils.force_legacy_ssl_support(): self.test_create_server_ssl_match_failed() - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_create_unix_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( @@ -1031,7 +1031,7 @@ def test_legacy_create_unix_server_ssl_verified(self): with test_utils.force_legacy_ssl_support(): self.test_create_unix_server_ssl_verified() - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_create_server_ssl_verified(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -1104,7 +1104,7 @@ def test_create_server_addr_in_use(self): server.close() - @test_utils.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, 'IPv6 not supported or enabled') def test_create_server_dual_stack(self): f_proto = asyncio.Future(loop=self.loop) @@ -1222,7 +1222,7 @@ def test_internal_fds(self): self.assertIsNone(loop._csock) self.assertIsNone(loop._ssock) - @test_utils.skipUnless(sys.platform != 'win32', + @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_read_pipe(self): proto = MyReadPipeProto(loop=self.loop) @@ -1257,7 +1257,7 @@ def connect(): # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) - @test_utils.skipUnless(sys.platform != 'win32', + @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") # select, poll and kqueue don't support character devices (PTY) on Mac OS X # older than 10.6 (Snow Leopard) @@ -1297,8 +1297,8 @@ def connect(): # extra info is available self.assertIsNotNone(proto.transport.get_extra_info('pipe')) - @test_utils.skipUnless(sys.platform != 'win32', - "Don't support pipes for Windows") + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") def test_write_pipe(self): rpipe, wpipe = os.pipe() pipeobj = io.open(wpipe, 'wb', 1024) @@ -1336,7 +1336,7 @@ def reader(data): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) - @test_utils.skipUnless(sys.platform != 'win32', + @unittest.skipUnless(sys.platform != 'win32', "Don't support pipes for Windows") def test_write_pipe_disconnect_on_close(self): rsock, wsock = test_utils.socketpair() @@ -1364,8 +1364,8 @@ def test_write_pipe_disconnect_on_close(self): self.loop.run_until_complete(proto.done) self.assertEqual('CLOSED', proto.state) - @test_utils.skipUnless(sys.platform != 'win32', - "Don't support pipes for Windows") + @unittest.skipUnless(sys.platform != 'win32', + "Don't support pipes for Windows") # select, poll and kqueue don't support character devices (PTY) on Mac OS X # older than 10.6 (Snow Leopard) @support.requires_mac_ver(10, 6) @@ -1692,7 +1692,7 @@ def test_subprocess_terminate(self): self.check_terminated(proto.returncode) transp.close() - @test_utils.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_subprocess_send_signal(self): prog = os.path.join(os.path.dirname(__file__), 'echo.py') @@ -1784,7 +1784,7 @@ def test_subprocess_close_client_stream(self): self.loop.run_until_complete(proto.completed) self.check_killed(proto.returncode) - @test_utils.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") + @unittest.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") def test_subprocess_wait_no_same_group(self): # start the new process in a new session connect = self.loop.subprocess_shell( @@ -1920,8 +1920,8 @@ def create_event_loop(self): @support.requires_mac_ver(10, 9) # Issue #20667: KqueueEventLoopTests.test_read_pty_output() # hangs on OpenBSD 5.5 - @test_utils.skipIf(sys.platform.startswith('openbsd'), - 'test hangs on OpenBSD') + @unittest.skipIf(sys.platform.startswith('openbsd'), + 'test hangs on OpenBSD') def test_read_pty_output(self): super(KqueueEventLoopTests, self).test_read_pty_output() diff --git a/tests/test_futures.py b/tests/test_futures.py index 7a98d127..6467befd 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -7,7 +7,6 @@ import re import sys import threading -import unittest import trollius as asyncio from trollius import From @@ -15,6 +14,7 @@ from trollius import test_support as support from trollius import test_utils from trollius.test_utils import mock +from trollius.test_utils import unittest def get_thread_ident(): @@ -231,7 +231,7 @@ def test_tb_logger_exception_result_retrieved(self, m_log): del fut self.assertFalse(m_log.error.called) - @test_utils.skipIf(concurrent is None, 'need concurrent.futures') + @unittest.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future(self): def run(arg): @@ -249,7 +249,7 @@ def test_wrap_future_future(self): f2 = asyncio.wrap_future(f1) self.assertIs(f1, f2) - @test_utils.skipIf(concurrent is None, 'need concurrent.futures') + @unittest.skipIf(concurrent is None, 'need concurrent.futures') @mock.patch('trollius.futures.events') def test_wrap_future_use_global_loop(self, m_events): def run(arg): @@ -259,7 +259,7 @@ def run(arg): f2 = asyncio.wrap_future(f1) self.assertIs(m_events.get_event_loop.return_value, f2._loop) - @test_utils.skipIf(concurrent is None, 'need concurrent.futures') + @unittest.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future_cancel(self): f1 = concurrent.futures.Future() f2 = asyncio.wrap_future(f1, loop=self.loop) @@ -268,7 +268,7 @@ def test_wrap_future_cancel(self): self.assertTrue(f1.cancelled()) self.assertTrue(f2.cancelled()) - @test_utils.skipIf(concurrent is None, 'need concurrent.futures') + @unittest.skipIf(concurrent is None, 'need concurrent.futures') def test_wrap_future_cancel2(self): f1 = concurrent.futures.Future() f2 = asyncio.wrap_future(f1, loop=self.loop) diff --git a/tests/test_locks.py b/tests/test_locks.py index f3ea3b0b..71a6cb36 100644 --- a/tests/test_locks.py +++ b/tests/test_locks.py @@ -1,12 +1,12 @@ """Tests for lock.py""" -import unittest import re import trollius as asyncio from trollius import From, Return from trollius import test_utils from trollius.test_utils import mock +from trollius.test_utils import unittest STR_RGX_REPR = ( diff --git a/tests/test_proactor_events.py b/tests/test_proactor_events.py index b55fff32..ceb28e2b 100644 --- a/tests/test_proactor_events.py +++ b/tests/test_proactor_events.py @@ -1,7 +1,6 @@ """Tests for proactor_events.py""" import socket -import unittest from trollius import test_utils from trollius.proactor_events import BaseProactorEventLoop @@ -10,6 +9,7 @@ from trollius.proactor_events import _ProactorWritePipeTransport from trollius.py33_exceptions import ConnectionAbortedError, ConnectionResetError from trollius.test_utils import mock +from trollius.test_utils import unittest import trollius as asyncio diff --git a/tests/test_queues.py b/tests/test_queues.py index 7e92fbc9..e75ae4fd 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -1,11 +1,10 @@ """Tests for queues.py""" -import unittest - import trollius as asyncio from trollius import Return, From from trollius import test_utils from trollius.test_utils import mock +from trollius.test_utils import unittest class _QueueTestBase(test_utils.TestCase): diff --git a/tests/test_selector_events.py b/tests/test_selector_events.py index a746e147..2f3d5ce9 100644 --- a/tests/test_selector_events.py +++ b/tests/test_selector_events.py @@ -3,7 +3,6 @@ import errno import socket import sys -import unittest try: import ssl except ImportError: @@ -24,6 +23,7 @@ from trollius.selector_events import _SelectorTransport from trollius.selector_events import _SSL_REQUIRES_SELECT from trollius.test_utils import mock +from trollius.test_utils import unittest if sys.version_info >= (3,): @@ -86,7 +86,7 @@ def test_make_socket_transport(self): close_transport(transport) - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_make_ssl_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() @@ -1129,7 +1129,7 @@ def test_write_eof_buffer(self): tr.close() -@test_utils.skipIf(ssl is None, 'No ssl module') +@unittest.skipIf(ssl is None, 'No ssl module') class SelectorSslTransportTests(test_utils.TestCase): def setUp(self): @@ -1272,7 +1272,7 @@ def test_write_exception(self, m_log): transport.write(b'data') m_log.warning.assert_called_with('socket.send() raised exception.') - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv(self): self.sslsock.recv.return_value = b'data' transport = self._make_one() @@ -1294,7 +1294,7 @@ def test_read_ready_write_wants_read(self): self.loop.add_writer.assert_called_with( transport._sock_fd, transport._write_ready) - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_eof(self): self.sslsock.recv.return_value = b'' transport = self._make_one() @@ -1303,7 +1303,7 @@ def test_read_ready_recv_eof(self): transport.close.assert_called_with() self.protocol.eof_received.assert_called_with() - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_conn_reset(self): err = self.sslsock.recv.side_effect = ConnectionResetError() transport = self._make_one() @@ -1312,7 +1312,7 @@ def test_read_ready_recv_conn_reset(self): transport._read_ready() transport._force_close.assert_called_with(err) - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_retry(self): self.sslsock.recv.side_effect = SSLWantReadError transport = self._make_one() @@ -1328,7 +1328,7 @@ def test_read_ready_recv_retry(self): transport._read_ready() self.assertFalse(self.protocol.data_received.called) - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_write(self): self.loop.remove_reader = mock.Mock() self.loop.add_writer = mock.Mock() @@ -1342,7 +1342,7 @@ def test_read_ready_recv_write(self): self.loop.add_writer.assert_called_with( transport._sock_fd, transport._write_ready) - @test_utils.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') + @unittest.skipIf(_SSL_REQUIRES_SELECT, 'buggy ssl with the workaround') def test_read_ready_recv_exc(self): err = self.sslsock.recv.side_effect = OSError() transport = self._make_one() @@ -1480,7 +1480,7 @@ def test_close_not_connected(self): self.assertFalse(self.protocol.connection_made.called) self.assertFalse(self.protocol.connection_lost.called) - @test_utils.skipIf(ssl is None, 'No SSL support') + @unittest.skipIf(ssl is None, 'No SSL support') def test_server_hostname(self): self.ssl_transport(server_hostname='localhost') self.sslcontext.wrap_socket.assert_called_with( diff --git a/tests/test_selectors.py b/tests/test_selectors.py index dde82ec9..a06596ea 100644 --- a/tests/test_selectors.py +++ b/tests/test_selectors.py @@ -3,7 +3,6 @@ import random import signal import sys -import unittest from time import sleep try: import resource @@ -15,6 +14,7 @@ from trollius import test_utils from trollius.test_utils import mock from trollius.test_utils import socketpair +from trollius.test_utils import unittest from trollius.time_monotonic import time_monotonic as time @@ -87,7 +87,7 @@ def test_unregister_after_fd_close(self): s.unregister(r) s.unregister(w) - @test_utils.skipUnless(os.name == 'posix', "requires posix") + @unittest.skipUnless(os.name == 'posix', "requires posix") def test_unregister_after_fd_close_and_reuse(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -300,8 +300,8 @@ def test_selector(self): self.assertEqual(bufs, [MSG] * NUM_SOCKETS) - @test_utils.skipIf(sys.platform == 'win32', - 'select.select() cannot be used with empty fd sets') + @unittest.skipIf(sys.platform == 'win32', + 'select.select() cannot be used with empty fd sets') def test_empty_select(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -333,8 +333,8 @@ def test_timeout(self): # Tolerate 2.0 seconds for very slow buildbots self.assertTrue(0.8 <= dt <= 2.0, dt) - @test_utils.skipUnless(hasattr(signal, "alarm"), - "signal.alarm() required for this test") + @unittest.skipUnless(hasattr(signal, "alarm"), + "signal.alarm() required for this test") def test_select_interrupt(self): s = self.SELECTOR() self.addCleanup(s.close) @@ -357,7 +357,7 @@ class ScalableSelectorMixIn(object): # see issue #18963 for why it's skipped on older OS X versions @support.requires_mac_ver(10, 5) - @test_utils.skipUnless(resource, "Test needs resource module") + @unittest.skipUnless(resource, "Test needs resource module") def test_above_fd_setsize(self): # A scalable implementation should have no problem with more than # FD_SETSIZE file descriptors. Since we don't know the value, we just @@ -409,23 +409,23 @@ class SelectSelectorTestCase(BaseSelectorTestCase, test_utils.TestCase): SELECTOR = selectors.SelectSelector -@test_utils.skipUnless(hasattr(selectors, 'PollSelector'), - "Test needs selectors.PollSelector") +@unittest.skipUnless(hasattr(selectors, 'PollSelector'), + "Test needs selectors.PollSelector") class PollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, test_utils.TestCase): SELECTOR = getattr(selectors, 'PollSelector', None) -@test_utils.skipUnless(hasattr(selectors, 'EpollSelector'), - "Test needs selectors.EpollSelector") +@unittest.skipUnless(hasattr(selectors, 'EpollSelector'), + "Test needs selectors.EpollSelector") class EpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, test_utils.TestCase): SELECTOR = getattr(selectors, 'EpollSelector', None) -@test_utils.skipUnless(hasattr(selectors, 'KqueueSelector'), +@unittest.skipUnless(hasattr(selectors, 'KqueueSelector'), "Test needs selectors.KqueueSelector)") class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, test_utils.TestCase): @@ -433,7 +433,7 @@ class KqueueSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, SELECTOR = getattr(selectors, 'KqueueSelector', None) -@test_utils.skipUnless(hasattr(selectors, 'DevpollSelector'), +@unittest.skipUnless(hasattr(selectors, 'DevpollSelector'), "Test needs selectors.DevpollSelector") class DevpollSelectorTestCase(BaseSelectorTestCase, ScalableSelectorMixIn, test_utils.TestCase): diff --git a/tests/test_sslproto.py b/tests/test_sslproto.py index 630c1616..aa2af326 100644 --- a/tests/test_sslproto.py +++ b/tests/test_sslproto.py @@ -1,19 +1,19 @@ """Tests for asyncio/sslproto.py.""" -import unittest try: import ssl except ImportError: ssl = None import trollius as asyncio +from trollius import ConnectionResetError from trollius import sslproto from trollius import test_utils from trollius.test_utils import mock -from trollius import ConnectionResetError +from trollius.test_utils import unittest -@test_utils.skipIf(ssl is None, 'No ssl module') +@unittest.skipIf(ssl is None, 'No ssl module') class SslProtoHandshakeTests(test_utils.TestCase): def setUp(self): diff --git a/tests/test_streams.py b/tests/test_streams.py index 9b983e09..d20e9fef 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -5,7 +5,6 @@ import os import socket import sys -import unittest try: import ssl except ImportError: @@ -16,6 +15,7 @@ from trollius import compat from trollius import test_utils from trollius.test_utils import mock +from trollius.test_utils import unittest class StreamReaderTests(test_utils.TestCase): @@ -56,7 +56,7 @@ def test_open_connection(self): loop=self.loop) self._basetest_open_connection(conn_fut) - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection(self): with test_utils.run_test_unix_server() as httpd: conn_fut = asyncio.open_unix_connection(httpd.address, @@ -75,7 +75,7 @@ def _basetest_open_connection_no_loop_ssl(self, open_connection_fut): writer.close() - @test_utils.skipIf(ssl is None, 'No ssl module') + @unittest.skipIf(ssl is None, 'No ssl module') def test_open_connection_no_loop_ssl(self): with test_utils.run_test_server(use_ssl=True) as httpd: conn_fut = asyncio.open_connection( @@ -85,8 +85,8 @@ def test_open_connection_no_loop_ssl(self): self._basetest_open_connection_no_loop_ssl(conn_fut) - @test_utils.skipIf(ssl is None, 'No ssl module') - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipIf(ssl is None, 'No ssl module') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection_no_loop_ssl(self): with test_utils.run_test_unix_server(use_ssl=True) as httpd: conn_fut = asyncio.open_unix_connection( @@ -112,7 +112,7 @@ def test_open_connection_error(self): loop=self.loop) self._basetest_open_connection_error(conn_fut) - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_open_unix_connection_error(self): with test_utils.run_test_unix_server() as httpd: conn_fut = asyncio.open_unix_connection(httpd.address, @@ -511,7 +511,7 @@ def client(addr): server.stop() self.assertEqual(msg, b"hello world!\n") - @test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') + @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') def test_start_unix_server(self): class MyServer: @@ -581,7 +581,7 @@ def client(path): server.stop() self.assertEqual(msg, b"hello world!\n") - @test_utils.skipIf(sys.platform == 'win32', "Don't have pipes") + @unittest.skipIf(sys.platform == 'win32', "Don't have pipes") def test_read_all_from_pipe_reader(self): # See asyncio issue 168. This test is derived from the example # subprocess_attach_read_pipe.py, but we configure the diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index c7333dc1..a8138346 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -4,12 +4,12 @@ import os import signal import sys -import unittest +from trollius import BrokenPipeError, ConnectionResetError, ProcessLookupError from trollius import From, Return from trollius import base_subprocess from trollius import test_support as support from trollius.test_utils import mock -from trollius import BrokenPipeError, ConnectionResetError, ProcessLookupError +from trollius.test_utils import unittest if sys.platform != 'win32': from trollius import unix_events @@ -134,7 +134,7 @@ def test_shell(self): exitcode = self.loop.run_until_complete(proc.wait()) self.assertEqual(exitcode, 7) - @test_utils.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") + @unittest.skipUnless(hasattr(os, 'setsid'), "need os.setsid()") def test_start_new_session(self): def start_new_session(): os.setsid() @@ -171,7 +171,7 @@ def test_terminate(self): else: self.assertEqual(-signal.SIGTERM, returncode) - @test_utils.skipIf(sys.platform == 'win32', "Don't have SIGHUP") + @unittest.skipIf(sys.platform == 'win32', "Don't have SIGHUP") def test_send_signal(self): code = '; '.join(( 'import sys, time', diff --git a/tests/test_tasks.py b/tests/test_tasks.py index d70acabb..afa1190e 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -6,7 +6,6 @@ import re import sys import types -import unittest import weakref import trollius as asyncio @@ -15,6 +14,7 @@ from trollius import test_support as support from trollius import test_utils from trollius.test_utils import mock +from trollius.test_utils import unittest PY33 = (sys.version_info >= (3, 3)) @@ -1631,8 +1631,8 @@ def foo(): wd['cw'] = cw # Would fail without __weakref__ slot. cw.gen = None # Suppress warning from __del__. - @test_utils.skipUnless(PY34, - 'need python 3.4 or later') + @unittest.skipUnless(PY34, + 'need python 3.4 or later') def test_log_destroyed_pending_task(self): @asyncio.coroutine def kill_me(loop): diff --git a/tests/test_transports.py b/tests/test_transports.py index 42f7729f..d4c57805 100644 --- a/tests/test_transports.py +++ b/tests/test_transports.py @@ -1,11 +1,10 @@ """Tests for transports.py.""" -import unittest - import trollius as asyncio from trollius import test_utils from trollius import transports from trollius.test_utils import mock +from trollius.test_utils import unittest try: memoryview diff --git a/tests/test_unix_events.py b/tests/test_unix_events.py index 137a699f..b6b5bdcc 100644 --- a/tests/test_unix_events.py +++ b/tests/test_unix_events.py @@ -11,12 +11,11 @@ import sys import tempfile import threading -import unittest +from trollius.test_utils import unittest if sys.platform == 'win32': raise unittest.SkipTest('UNIX only') - import trollius as asyncio from trollius import log from trollius import test_utils @@ -37,7 +36,7 @@ def close_pipe_transport(transport): transport._pipe = None -@test_utils.skipUnless(signal, 'Signals are not supported') +@unittest.skipUnless(signal, 'Signals are not supported') class SelectorEventLoopSignalTests(test_utils.TestCase): def setUp(self): @@ -230,8 +229,8 @@ def test_close(self, m_signal): m_signal.set_wakeup_fd.assert_called_once_with(-1) -@test_utils.skipUnless(hasattr(socket, 'AF_UNIX'), - 'UNIX Sockets are not supported') +@unittest.skipUnless(hasattr(socket, 'AF_UNIX'), + 'UNIX Sockets are not supported') class SelectorEventLoopUnixSocketTests(test_utils.TestCase): def setUp(self): diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index cedb0f72..81e26f5d 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -1,10 +1,9 @@ -from trollius import test_utils import os import sys -import unittest +from trollius.test_utils import unittest if sys.platform != 'win32': - raise test_utils.SkipTest('Windows only') + raise unittest.SkipTest('Windows only') import trollius as asyncio from trollius import Return, From diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index bc21bb18..a7d7b885 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -2,12 +2,11 @@ import socket import sys -import unittest import warnings +from trollius.test_utils import unittest if sys.platform != 'win32': - from trollius.test_utils import SkipTest - raise SkipTest('Windows only') + raise unittest.SkipTest('Windows only') from trollius import _overlapped from trollius import py33_winapi as _winapi @@ -29,14 +28,14 @@ def test_winsocketpair(self): ssock, csock = windows_utils.socketpair() self.check_winsocketpair(ssock, csock) - @test_utils.skipUnless(support.IPV6_ENABLED, - 'IPv6 not supported or enabled') + @unittest.skipUnless(support.IPV6_ENABLED, + 'IPv6 not supported or enabled') def test_winsocketpair_ipv6(self): ssock, csock = windows_utils.socketpair(family=socket.AF_INET6) self.check_winsocketpair(ssock, csock) - @test_utils.skipIf(hasattr(socket, 'socketpair'), - 'socket.socketpair is available') + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_exc(self, m_socket): m_socket.AF_INET = socket.AF_INET @@ -55,8 +54,8 @@ def test_winsocketpair_invalid_args(self): self.assertRaises(ValueError, windows_utils.socketpair, proto=1) - @test_utils.skipIf(hasattr(socket, 'socketpair'), - 'socket.socketpair is available') + @unittest.skipIf(hasattr(socket, 'socketpair'), + 'socket.socketpair is available') @mock.patch('trollius.windows_utils.socket') def test_winsocketpair_close(self, m_socket): m_socket.AF_INET = socket.AF_INET diff --git a/trollius/test_utils.py b/trollius/test_utils.py index f6fb2468..1d7080ad 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -50,20 +50,17 @@ from socket import socketpair # pragma: no cover try: - import unittest2 - skipIf = unittest2.skipIf - skipUnless = unittest2.skipUnless - SkipTest = unittest2.SkipTest - _TestCase = unittest2.TestCase + # Prefer unittest2 if available (on Python 2) + import unittest2 as unittest except ImportError: import unittest - skipIf = unittest.skipIf - skipUnless = unittest.skipUnless - SkipTest = unittest.SkipTest - _TestCase = unittest.TestCase +skipIf = unittest.skipIf +skipUnless = unittest.skipUnless +SkipTest = unittest.SkipTest -if not hasattr(_TestCase, 'assertRaisesRegex'): + +if not hasattr(unittest.TestCase, 'assertRaisesRegex'): class _BaseTestCaseContext: def __init__(self, test_case): @@ -502,7 +499,7 @@ def get_function_source(func): return source -class TestCase(_TestCase): +class TestCase(unittest.TestCase): def set_event_loop(self, loop, cleanup=True): assert loop is not None # ensure that the event loop is passed explicitly in asyncio @@ -522,7 +519,7 @@ def tearDown(self): # in an except block of a generator self.assertEqual(sys.exc_info(), (None, None, None)) - if not hasattr(_TestCase, 'assertRaisesRegex'): + if not hasattr(unittest.TestCase, 'assertRaisesRegex'): def assertRaisesRegex(self, expected_exception, expected_regex, callable_obj=None, *args, **kwargs): """Asserts that the message in a raised exception matches a regex. @@ -542,7 +539,7 @@ def assertRaisesRegex(self, expected_exception, expected_regex, return context.handle('assertRaisesRegex', callable_obj, args, kwargs) - if not hasattr(_TestCase, 'assertRegex'): + if not hasattr(unittest.TestCase, 'assertRegex'): def assertRegex(self, text, expected_regex, msg=None): """Fail the test unless the text matches the regular expression.""" if isinstance(expected_regex, (str, bytes)): From d735bba97153567b273926fbd0aa1d2a76c27dcf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:36:36 +0200 Subject: [PATCH 1419/1502] fix import --- tests/test_windows_events.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_windows_events.py b/tests/test_windows_events.py index 81e26f5d..ef0ab92c 100644 --- a/tests/test_windows_events.py +++ b/tests/test_windows_events.py @@ -11,6 +11,7 @@ from trollius import py33_winapi as _winapi from trollius import windows_events from trollius.py33_exceptions import PermissionError, FileNotFoundError +from trollius import test_utils from trollius.test_utils import mock From e752c82a05f840ea683d0a55f6454f953eb933db Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:41:21 +0200 Subject: [PATCH 1420/1502] getaddrinfo() doesn't support keywords on Python 2 --- trollius/base_events.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trollius/base_events.py b/trollius/base_events.py index 1245d485..d4dc4483 100644 --- a/trollius/base_events.py +++ b/trollius/base_events.py @@ -117,10 +117,10 @@ def _check_resolved_address(sock, address): type_mask |= socket.SOCK_CLOEXEC try: socket.getaddrinfo(host, port, - family=family, - type=(sock.type & ~type_mask), - proto=sock.proto, - flags=socket.AI_NUMERICHOST) + family, + (sock.type & ~type_mask), + sock.proto, + socket.AI_NUMERICHOST) except socket.gaierror as err: raise ValueError("address must be resolved (IP address), " "got host %r: %s" From 170c465d22795deaeee8078730aff6e2fb3cc864 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:47:24 +0200 Subject: [PATCH 1421/1502] subprocess.Popen doesn't support context manager on Python 2 --- tests/test_windows_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_windows_utils.py b/tests/test_windows_utils.py index a7d7b885..f73e2631 100644 --- a/tests/test_windows_utils.py +++ b/tests/test_windows_utils.py @@ -172,9 +172,10 @@ def test_popen(self): self.assertTrue(msg.upper().rstrip().startswith(out)) self.assertTrue(b"stderr".startswith(err)) - # The context manager calls wait() and closes resources - with p: - pass + p.stdin.close() + p.stdout.close() + p.stderr.close() + p.wait() if __name__ == '__main__': From f751a25ebf553ece3b3ccfb0018299e579c64cbb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 11:54:16 +0200 Subject: [PATCH 1422/1502] Fix PipeHandle.__del__ on Python 2 --- trollius/windows_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index 7d21bbc0..2a10ce8d 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -17,6 +17,7 @@ import warnings from . import py33_winapi as _winapi +from . import compat from .py33_exceptions import wrap_error, BlockingIOError, InterruptedError @@ -167,7 +168,8 @@ def close(self, CloseHandle=_winapi.CloseHandle): def __del__(self): if self._handle is not None: - warnings.warn("unclosed %r" % self, ResourceWarning) + if compat.PY3: + warnings.warn("unclosed %r" % self, ResourceWarning) self.close() def __enter__(self): From 3acc9d03176d234feda36701d63fe2dddd063482 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 12:05:47 +0200 Subject: [PATCH 1423/1502] fix test_events on python 3.4.0 and 3.4.1 --- tests/test_events.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 5153312b..8cd0ae86 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -2238,8 +2238,6 @@ def test_not_implemented(self): NotImplementedError, loop.create_task, None) self.assertRaises( NotImplementedError, loop.close) - self.assertRaises( - NotImplementedError, loop.create_task, None) self.assertRaises( NotImplementedError, loop.call_later, None, None) self.assertRaises( From 3b6a64a9fb6ec4ad0c984532aa776b130067c901 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Thu, 9 Jul 2015 22:52:48 +0200 Subject: [PATCH 1424/1502] Add asyncio.compat module Move compatibility helpers for the different Python versions to a new asyncio.compat module. --- asyncio/compat.py | 17 +++++++++++++++++ asyncio/coroutines.py | 6 ++---- asyncio/events.py | 7 +++---- asyncio/futures.py | 10 ++++------ asyncio/locks.py | 6 ++---- asyncio/streams.py | 4 ++-- asyncio/tasks.py | 5 ++--- asyncio/transports.py | 10 +++------- 8 files changed, 35 insertions(+), 30 deletions(-) create mode 100644 asyncio/compat.py diff --git a/asyncio/compat.py b/asyncio/compat.py new file mode 100644 index 00000000..660b7e7e --- /dev/null +++ b/asyncio/compat.py @@ -0,0 +1,17 @@ +"""Compatibility helpers for the different Python versions.""" + +import sys + +PY34 = sys.version_info >= (3, 4) +PY35 = sys.version_info >= (3, 5) + + +def flatten_list_bytes(list_of_data): + """Concatenate a sequence of bytes-like objects.""" + if not PY34: + # On Python 3.3 and older, bytes.join() doesn't handle + # memoryview. + list_of_data = ( + bytes(data) if isinstance(data, memoryview) else data + for data in list_of_data) + return b''.join(list_of_data) diff --git a/asyncio/coroutines.py b/asyncio/coroutines.py index 15475f23..e11b21b0 100644 --- a/asyncio/coroutines.py +++ b/asyncio/coroutines.py @@ -9,14 +9,12 @@ import traceback import types +from . import compat from . import events from . import futures from .log import logger -_PY35 = sys.version_info >= (3, 5) - - # Opcode of "yield from" instruction _YIELD_FROM = opcode.opmap['YIELD_FROM'] @@ -140,7 +138,7 @@ def gi_running(self): def gi_code(self): return self.gen.gi_code - if _PY35: + if compat.PY35: __await__ = __iter__ # make compatible with 'await' expression diff --git a/asyncio/events.py b/asyncio/events.py index 496075ba..d5f0d451 100644 --- a/asyncio/events.py +++ b/asyncio/events.py @@ -17,12 +17,11 @@ import threading import traceback - -_PY34 = sys.version_info >= (3, 4) +from asyncio import compat def _get_function_source(func): - if _PY34: + if compat.PY34: func = inspect.unwrap(func) elif hasattr(func, '__wrapped__'): func = func.__wrapped__ @@ -31,7 +30,7 @@ def _get_function_source(func): return (code.co_filename, code.co_firstlineno) if isinstance(func, functools.partial): return _get_function_source(func.func) - if _PY34 and isinstance(func, functools.partialmethod): + if compat.PY34 and isinstance(func, functools.partialmethod): return _get_function_source(func.func) return None diff --git a/asyncio/futures.py b/asyncio/futures.py index d06828a6..dbe06c4a 100644 --- a/asyncio/futures.py +++ b/asyncio/futures.py @@ -11,6 +11,7 @@ import sys import traceback +from . import compat from . import events # States for Future. @@ -18,9 +19,6 @@ _CANCELLED = 'CANCELLED' _FINISHED = 'FINISHED' -_PY34 = sys.version_info >= (3, 4) -_PY35 = sys.version_info >= (3, 5) - Error = concurrent.futures._base.Error CancelledError = concurrent.futures.CancelledError TimeoutError = concurrent.futures.TimeoutError @@ -199,7 +197,7 @@ def __repr__(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if _PY34: + if compat.PY34: def __del__(self): if not self._log_traceback: # set_exception() was not called, or result() or exception() @@ -352,7 +350,7 @@ def set_exception(self, exception): self._exception = exception self._state = _FINISHED self._schedule_callbacks() - if _PY34: + if compat.PY34: self._log_traceback = True else: self._tb_logger = _TracebackLogger(self, exception) @@ -388,7 +386,7 @@ def __iter__(self): assert self.done(), "yield from wasn't used with future" return self.result() # May raise too. - if _PY35: + if compat.PY35: __await__ = __iter__ # make compatible with 'await' expression diff --git a/asyncio/locks.py b/asyncio/locks.py index b2e516b5..cc6f2bf7 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -5,14 +5,12 @@ import collections import sys +from . import compat from . import events from . import futures from .coroutines import coroutine -_PY35 = sys.version_info >= (3, 5) - - class _ContextManager: """Context manager. @@ -70,7 +68,7 @@ def __iter__(self): yield from self.acquire() return _ContextManager(self) - if _PY35: + if compat.PY35: def __await__(self): # To make "with await lock" work. diff --git a/asyncio/streams.py b/asyncio/streams.py index 176c65e3..6cd60c42 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -12,6 +12,7 @@ __all__.extend(['open_unix_connection', 'start_unix_server']) from . import coroutines +from . import compat from . import events from . import futures from . import protocols @@ -20,7 +21,6 @@ _DEFAULT_LIMIT = 2**16 -_PY35 = sys.version_info >= (3, 5) class IncompleteReadError(EOFError): @@ -488,7 +488,7 @@ def readexactly(self, n): return b''.join(blocks) - if _PY35: + if compat.PY35: @coroutine def __aiter__(self): return self diff --git a/asyncio/tasks.py b/asyncio/tasks.py index d8193ba4..1d5f8654 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -16,13 +16,12 @@ import warnings import weakref +from . import compat from . import coroutines from . import events from . import futures from .coroutines import coroutine -_PY34 = (sys.version_info >= (3, 4)) - class Task(futures.Future): """A coroutine wrapped in a Future.""" @@ -83,7 +82,7 @@ def __init__(self, coro, *, loop=None): # On Python 3.3 or older, objects with a destructor that are part of a # reference cycle are never destroyed. That's not the case any more on # Python 3.4 thanks to the PEP 442. - if _PY34: + if compat.PY34: def __del__(self): if self._state == futures._PENDING and self._log_destroy_pending: context = { diff --git a/asyncio/transports.py b/asyncio/transports.py index 22df3c7a..7a28d908 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -2,7 +2,7 @@ import sys -_PY34 = sys.version_info >= (3, 4) +from asyncio import compat __all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', 'Transport', 'DatagramTransport', 'SubprocessTransport', @@ -94,12 +94,8 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ - if not _PY34: - # In Python 3.3, bytes.join() doesn't handle memoryview. - list_of_data = ( - bytes(data) if isinstance(data, memoryview) else data - for data in list_of_data) - self.write(b''.join(list_of_data)) + data = compat.flatten_list_bytes(list_of_data) + self.write(data) def write_eof(self): """Close the write end after flushing buffered data. From 1f85dc7d6ba07d1a8be18446d251ebe5b7575d34 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 10 Jul 2015 22:51:13 +0200 Subject: [PATCH 1425/1502] Issue #234: Drop JoinableQueue on Python 3.5+ --- ChangeLog | 6 ++++++ asyncio/queues.py | 10 ++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ChangeLog b/ChangeLog index 25155a98..421704c0 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,9 @@ +Tulip 3.4.4 +=========== + +* Issue #234: Drop asyncio.JoinableQueue on Python 3.5 and newer + + 2015-02-04: Tulip 3.4.3 ======================= diff --git a/asyncio/queues.py b/asyncio/queues.py index 3b4dc21a..c55dd8bb 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -1,11 +1,11 @@ """Queues""" -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty', - 'JoinableQueue'] +__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] import collections import heapq +from . import compat from . import events from . import futures from . import locks @@ -289,5 +289,7 @@ def _get(self): return self._queue.pop() -JoinableQueue = Queue -"""Deprecated alias for Queue.""" +if not compat.PY35: + JoinableQueue = Queue + """Deprecated alias for Queue.""" + __all__.append('JoinableQueue') From 53a8e99b5074766bd6bd7d02909bd31fcf686eb3 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 11:35:40 +0200 Subject: [PATCH 1426/1502] port update-tulip-step1.sh to git+github --- update-tulip-step1.sh | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/update-tulip-step1.sh b/update-tulip-step1.sh index 0b39f7ee..7d7a3838 100755 --- a/update-tulip-step1.sh +++ b/update-tulip-step1.sh @@ -1,9 +1,8 @@ set -e -x -hg update trollius -hg pull --update -hg update default -hg pull https://code.google.com/p/tulip/ -hg update -hg update trollius -hg merge default +git checkout trollius +git pull -u +git checkout master +git pull https://github.com/python/asyncio.git +git checkout trollius +git merge master echo "Now run ./update-tulip-step2.sh" From 05e99d691bc080e0d551c0eea50b0097c098e088 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 11:37:07 +0200 Subject: [PATCH 1427/1502] rename update-tulip...sh scripts to update-asyncio-...sh --- update-tulip-step1.sh => update-asyncio-step1.sh | 0 update-tulip-step2.sh => update-asyncio-step2.sh | 0 update-tulip-step3.sh => update-asyncio-step3.sh | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename update-tulip-step1.sh => update-asyncio-step1.sh (100%) rename update-tulip-step2.sh => update-asyncio-step2.sh (100%) rename update-tulip-step3.sh => update-asyncio-step3.sh (100%) diff --git a/update-tulip-step1.sh b/update-asyncio-step1.sh similarity index 100% rename from update-tulip-step1.sh rename to update-asyncio-step1.sh diff --git a/update-tulip-step2.sh b/update-asyncio-step2.sh similarity index 100% rename from update-tulip-step2.sh rename to update-asyncio-step2.sh diff --git a/update-tulip-step3.sh b/update-asyncio-step3.sh similarity index 100% rename from update-tulip-step3.sh rename to update-asyncio-step3.sh From cfcc3c0f51b9d3e8fbd6a945db7b6ce4cd5ebc49 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 11:38:14 +0200 Subject: [PATCH 1428/1502] update scripts to git --- update-asyncio-step2.sh | 8 ++++---- update-asyncio-step3.sh | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/update-asyncio-step2.sh b/update-asyncio-step2.sh index ebf98bdd..558940bb 100755 --- a/update-asyncio-step2.sh +++ b/update-asyncio-step2.sh @@ -1,16 +1,16 @@ set -e # Check for merge conflicts -if $(hg resolve -l | grep -q -v '^R'); then +if $(git status --porcelain|grep -q '^.U '); then echo "Fix the following conflicts:" - hg resolve -l | grep -v '^R' + git status exit 1 fi # Ensure that yield from is not used -if $(hg diff|grep -q 'yield from'); then +if $(git diff|grep -q 'yield from'); then echo "yield from present in changed code!" - hg diff | grep 'yield from' -B5 -A3 + git diff | grep 'yield from' -B5 -A3 exit 1 fi diff --git a/update-asyncio-step3.sh b/update-asyncio-step3.sh index 202b44bc..d7d63bad 100755 --- a/update-asyncio-step3.sh +++ b/update-asyncio-step3.sh @@ -1,4 +1,4 @@ set -e -x ./update-tulip-step2.sh tox -e py27,py34 -hg ci -m 'Merge Tulip into Trollius' +git commit -m 'Merge asyncio into trollius' From fe8894a6abc0d1b509518ddb163cfb6a4e98e04a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 11:41:44 +0200 Subject: [PATCH 1429/1502] rename compat.py to be able to merge asyncio --- trollius/compat.py | 61 ---------------------------------------------- 1 file changed, 61 deletions(-) delete mode 100644 trollius/compat.py diff --git a/trollius/compat.py b/trollius/compat.py deleted file mode 100644 index 79478420..00000000 --- a/trollius/compat.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -Compatibility constants and functions for the different Python versions. -""" -import sys - -# Python 2.6 or older? -PY26 = (sys.version_info < (2, 7)) - -# Python 3.0 or newer? -PY3 = (sys.version_info >= (3,)) - -# Python 3.3 or newer? -PY33 = (sys.version_info >= (3, 3)) - -# Python 3.4 or newer? -PY34 = sys.version_info >= (3, 4) - -if PY3: - integer_types = (int,) - bytes_type = bytes - text_type = str - string_types = (bytes, str) - BYTES_TYPES = (bytes, bytearray, memoryview) -else: - integer_types = (int, long,) - bytes_type = str - text_type = unicode - string_types = basestring - if PY26: - BYTES_TYPES = (str, bytearray, buffer) - else: # Python 2.7 - BYTES_TYPES = (str, bytearray, memoryview, buffer) - -def flatten_bytes(data): - """ - Convert bytes-like objects (bytes, bytearray, memoryview, buffer) to - a bytes string. - """ - if not isinstance(data, BYTES_TYPES): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) - if PY34: - # In Python 3.4, socket.send() and bytes.join() accept memoryview - # and bytearray - return data - if not data: - return b'' - if not PY3 and isinstance(data, (buffer, bytearray)): - return str(data) - elif not PY26 and isinstance(data, memoryview): - return data.tobytes() - else: - return data - -if PY3: - def reraise(tp, value, tb=None): - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value -else: - exec("""def reraise(tp, value, tb=None): raise tp, value, tb""") From 61199a49a0782b1cf83f7dc32719b456d950ba3f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 11:48:22 +0200 Subject: [PATCH 1430/1502] enhance git merge --- update-asyncio-step1.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/update-asyncio-step1.sh b/update-asyncio-step1.sh index 7d7a3838..e2ac4f6f 100755 --- a/update-asyncio-step1.sh +++ b/update-asyncio-step1.sh @@ -3,6 +3,10 @@ git checkout trollius git pull -u git checkout master git pull https://github.com/python/asyncio.git + git checkout trollius -git merge master +# rename-threshold=25: a similarity of 25% is enough to consider two files +# rename candidates +git merge -X rename-threshold=25 master + echo "Now run ./update-tulip-step2.sh" From 53bf283c2966c74f98069f69645ea8a6c616a904 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 12:03:16 +0200 Subject: [PATCH 1431/1502] don't commit in update-asyncio-step3.sh update setup.py to use git --- setup.py | 11 ++++++----- update-asyncio-step3.sh | 8 +++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index c98c9751..9ec90cc8 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ # Release procedure: # - fill trollius changelog -# - run maybe update_tulip.sh +# - run maybe ./update_asyncio_step1.sh # - run unit tests with concurrent.futures # - run unit tests without concurrent.futures # - run unit tests without ssl: set sys.modules['ssl']=None at startup @@ -9,13 +9,14 @@ # - set release date in doc/changelog.rst # - check that "python setup.py sdist" contains all files tracked by # the SCM (Mercurial): update MANIFEST.in if needed -# - hg ci -# - hg tag trollius-VERSION -# - hg push +# - git commit +# - git tag trollius-VERSION +# - git push --tags +# - git push # - On Linux: python setup.py register sdist bdist_wheel upload # - On Windows: python release.py release # - increment version in setup.py (version) and doc/conf.py (version, release) -# - hg ci && hg push +# - gt commit && git push import os import sys diff --git a/update-asyncio-step3.sh b/update-asyncio-step3.sh index cdd3208c..cc13503e 100755 --- a/update-asyncio-step3.sh +++ b/update-asyncio-step3.sh @@ -1,4 +1,10 @@ set -e -x ./update-asyncio-step2.sh tox -e py27,py34 -git commit -m 'Merge asyncio into trollius' + +git status +echo +echo "Now type:" +echo "git commit -m 'Merge asyncio into trollius'" +echo +echo "You may have to add unstaged files" From 2173eee4d21671139ba836f25b092b6b311cf3fb Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 12:14:18 +0200 Subject: [PATCH 1432/1502] fix 2.6 compat --- TODO.rst | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/TODO.rst b/TODO.rst index 09df760c..260b491d 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,6 +1,6 @@ Unsorted "TODO" tasks: -* Drop Python 2.6 support +* Drop Python 2.6 and 3.2 support * test_utils.py: remove assertRaisesRegex, assertRegex * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * reuse selectors backport from PyPI diff --git a/tox.ini b/tox.ini index bccb3b81..9a1b1aaf 100644 --- a/tox.ini +++ b/tox.ini @@ -29,7 +29,7 @@ commands= deps= aiotest futures - mock + mock==1.0.1 ordereddict unittest2 From e82d8c676a7adcd797836bd5974ced28ca8febbd Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 12:21:51 +0200 Subject: [PATCH 1433/1502] Fix Python 3.2 compat --- tests/test_events.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_events.py b/tests/test_events.py index 8cd0ae86..c8e9605f 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -653,7 +653,9 @@ def _dummy_ssl_create_context(purpose=ssl.Purpose.SERVER_AUTH, with test_utils.disable_logger(): self._basetest_create_ssl_connection(conn_fut, check_sockname) - self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') + # Test for Python 3.2 + if hasattr(ssl.SSLError, 'reason'): + self.assertEqual(cm.exception.reason, 'CERTIFICATE_VERIFY_FAILED') @unittest.skipIf(ssl is None, 'No ssl module') def test_create_ssl_connection(self): @@ -910,7 +912,7 @@ def test_create_server_ssl_verify_failed(self): with mock.patch.object(self.loop, 'call_exception_handler'): with test_utils.disable_logger(): with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): + 'certificate verify failed'): self.loop.run_until_complete(f_c) # execute the loop to log the connection error @@ -945,7 +947,7 @@ def test_create_unix_server_ssl_verify_failed(self): with mock.patch.object(self.loop, 'call_exception_handler'): with test_utils.disable_logger(): with self.assertRaisesRegex(ssl.SSLError, - 'certificate verify failed '): + 'certificate verify failed'): self.loop.run_until_complete(f_c) # execute the loop to log the connection error From 1ade79951df23e8bd7d173ea1025b95768e8b2b7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 16:11:26 +0200 Subject: [PATCH 1434/1502] Test without ssl or without concurrent --- runtests.py | 19 ++++++++++++++++--- setup.py | 4 +--- tox.ini | 29 +++++++++++++++++++++++++++-- 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/runtests.py b/runtests.py index 38626af0..541a47ea 100755 --- a/runtests.py +++ b/runtests.py @@ -29,7 +29,8 @@ import re import sys import textwrap -from trollius.compat import PY33 +PY2 = (sys.version_info < (3,)) +PY33 = (sys.version_info >= (3, 3)) if PY33: import importlib.machinery else: @@ -38,7 +39,7 @@ import coverage except ImportError: coverage = None -if sys.version_info < (3,): +if PY2: sys.exc_clear() try: @@ -57,6 +58,12 @@ ARGS.add_option( '-f', '--failfast', action="store_true", default=False, dest='failfast', help='Stop on first fail or error') +ARGS.add_option( + '--no-ssl', action="store_true", default=False, + help='Disable the SSL module') +ARGS.add_option( + '--no-concurrent', action="store_true", default=False, + help='Disable the concurrent module') ARGS.add_option( '-c', '--catch', action="store_true", default=False, dest='catchbreak', help='Catch control-C and display results') @@ -121,7 +128,7 @@ def list_dir(prefix, dir): for modname, sourcefile in list_dir('', basedir): if modname == 'runtests': continue - if modname == 'test_asyncio' and sys.version_info <= (3, 3): + if modname == 'test_asyncio' and not PY33: print("Skipping '{0}': need at least Python 3.3".format(modname), file=sys.stderr) continue @@ -238,6 +245,12 @@ def _runtests(args, tests): def runtests(): args, pattern = ARGS.parse_args() + if args.no_ssl: + sys.modules['ssl'] = None + + if args.no_concurrent: + sys.modules['concurrent'] = None + if args.coverage and coverage is None: URL = "bitbucket.org/pypa/setuptools/raw/bootstrap/ez_setup.py" print(textwrap.dedent(""" diff --git a/setup.py b/setup.py index 9ec90cc8..407e3142 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,7 @@ # Release procedure: # - fill trollius changelog # - run maybe ./update_asyncio_step1.sh -# - run unit tests with concurrent.futures -# - run unit tests without concurrent.futures -# - run unit tests without ssl: set sys.modules['ssl']=None at startup +# - run all tests: tox # - test examples # - update version in setup.py (version) and doc/conf.py (version, release) # - set release date in doc/changelog.rst diff --git a/tox.ini b/tox.ini index 9a1b1aaf..5630059d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py26,py27,py2_release,py32,py33,py34,py3_release +envlist = py26,py27,py2_release,py2_no_ssl,py2_no_concurrent,py32,py33,py34,py3_release,py3_no_ssl # and: pyflakes2,pyflakes3 [testenv] @@ -42,6 +42,7 @@ deps= [testenv:py2_release] # Run tests in release mode +basepython = python2 deps= aiotest futures @@ -49,7 +50,26 @@ deps= unittest2 setenv = TROLLIUSDEBUG = -basepython = python2.7 + +[testenv:py2_no_ssl] +basepython = python2 +deps= + aiotest + futures + mock + unittest2 +commands= + python runtests.py --no-ssl -r {posargs} + +[testenv:py2_no_concurrent] +basepython = python2 +deps= + aiotest + futures + mock + unittest2 +commands= + python runtests.py --no-concurrent -r {posargs} [testenv:py32] deps= @@ -61,6 +81,11 @@ basepython = python3.5 [testenv:py3_release] # Run tests in release mode +basepython = python3 setenv = TROLLIUSDEBUG = + +[testenv:py3_no_ssl] basepython = python3 +commands= + python runtests.py --no-ssl -r {posargs} From b5e33de48e1773a9d7f85b9e2ea74acb2ddc48df Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 16:38:21 +0200 Subject: [PATCH 1435/1502] Enable warnings to see ResourceWarning on Python 3 --- tox.ini | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tox.ini b/tox.ini index 5630059d..6f0a3421 100644 --- a/tox.ini +++ b/tox.ini @@ -8,8 +8,8 @@ deps= setenv = TROLLIUSDEBUG = 1 commands= - python runtests.py -r {posargs} - python run_aiotest.py -r {posargs} + python -Wd runtests.py -r {posargs} + python -Wd run_aiotest.py -r {posargs} [testenv:pyflakes2] basepython = python2 @@ -59,7 +59,7 @@ deps= mock unittest2 commands= - python runtests.py --no-ssl -r {posargs} + python -Wd runtests.py --no-ssl -r {posargs} [testenv:py2_no_concurrent] basepython = python2 @@ -69,7 +69,7 @@ deps= mock unittest2 commands= - python runtests.py --no-concurrent -r {posargs} + python -Wd runtests.py --no-concurrent -r {posargs} [testenv:py32] deps= @@ -88,4 +88,4 @@ setenv = [testenv:py3_no_ssl] basepython = python3 commands= - python runtests.py --no-ssl -r {posargs} + python -Wd runtests.py --no-ssl -r {posargs} From 05ecc3ccf71ed9168acbd9c34ca2afc540bd7d07 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 11 Jul 2015 16:01:12 +0200 Subject: [PATCH 1436/1502] update doc * update project URL * set version to 2.0 --- README.rst | 4 +-- doc/changelog.rst | 87 +++++++++++++++++++++++++++++++++++++++++++---- doc/conf.py | 2 +- doc/index.rst | 2 +- doc/install.rst | 6 ++-- setup.py | 4 +-- 6 files changed, 90 insertions(+), 15 deletions(-) diff --git a/README.rst b/README.rst index a55218d5..a1bf4953 100644 --- a/README.rst +++ b/README.rst @@ -35,8 +35,8 @@ OpenIndiana. * `Trollius documentation `_ * `Trollius project in the Python Cheeseshop (PyPI) `_ -* `Trollius project at Github `_ (code, - bug tracker) +* `Trollius project at Github `_ + (bug tracker, source code) * Copyright/license: Open source, Apache 2.0. Enjoy! See also the `asyncio project at Github `_. diff --git a/doc/changelog.rst b/doc/changelog.rst index 1ddfc218..578e21a9 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -2,14 +2,32 @@ Change log ++++++++++ -Version 1.0.5 -============= +Version 2.0 +=========== Major changes: on Python 3.5+ ProactorEventLoop now supports SSL, a lot of bugfixes (random race conditions) in the ProactorEventLoop. +The Trollius project moved from Bitbucket to Github. The project is now a fork +of the Git repository of the asyncio project (previously called the "tulip" +project), the trollius source code lives in the trollius branch. + +The new Trollius home page is now: https://github.com/haypo/trollius + +The asyncio project moved to: https://github.com/python/asyncio + +Note: the PEP 492 is not supported in trollius (yet?). + API changes: +* Issue #234: Drop JoinableQueue on Python 3.5+ +* add the asyncio.ensure_future() function, previously called async(). + The async() function is now deprecated. +* New event loop methods: set_task_factory() and get_task_factory(). +* Python issue #23347: Make BaseSubprocessTransport.wait() private. +* Python issue #23347: send_signal(), kill() and terminate() methods of + BaseSubprocessTransport now check if the transport was closed and if the + process exited. * Python issue #23209, #23225: selectors.BaseSelector.get_key() now raises a RuntimeError if the selector is closed. And selectors.BaseSelector.close() now clears its internal reference to the selector mapping to break a @@ -18,6 +36,11 @@ API changes: pipe is closed. * Remove Overlapped.WaitNamedPipeAndConnect() of the _overlapped module, it is no more used and it had issues. +* Python issue #23537: Remove 2 unused private methods of + BaseSubprocessTransport: _make_write_subprocess_pipe_proto, + _make_read_subprocess_pipe_proto. Methods only raise NotImplementedError and + are never used. +* Remove unused SSLProtocol._closing attribute New SSL implementation: @@ -57,8 +80,35 @@ Enhance, fix and cleanup the IocpProactor: CancelledError: just exit. On error, log the exception and exit; don't try to close the event loop (it doesn't work). -Bugfixes: - +Bug fixes: + +* Fix LifoQueue's and PriorityQueue's put() and task_done(). +* Issue #222: Fix the @coroutine decorator for functions without __name__ + attribute like functools.partial(). Enhance also the representation of a + CoroWrapper if the coroutine function is a functools.partial(). +* Python issue #23879: SelectorEventLoop.sock_connect() must not call connect() + again if the first call to connect() raises an InterruptedError. When the C + function connect() fails with EINTR, the connection runs in background. We + have to wait until the socket becomes writable to be notified when the + connection succeed or fails. +* Fix _SelectorTransport.__repr__() if the event loop is closed +* Fix repr(BaseSubprocessTransport) if it didn't start yet +* Workaround CPython bug #23353. Don't use yield/yield-from in an except block + of a generator. Store the exception and handle it outside the except block. +* Fix BaseSelectorEventLoop._accept_connection(). Close the transport on error. + In debug mode, log errors using call_exception_handler(). +* Fix _UnixReadPipeTransport and _UnixWritePipeTransport. Only start reading + when connection_made() has been called. +* Fix _SelectorSslTransport.close(). Don't call protocol.connection_lost() if + protocol.connection_made() was not called yet: if the SSL handshake failed or + is still in progress. The close() method can be called if the creation of the + connection is cancelled, by a timeout for example. +* Fix _SelectorDatagramTransport constructor. Only start reading after + connection_made() has been called. +* Fix _SelectorSocketTransport constructor. Only start reading when + connection_made() has been called: protocol.data_received() must not be + called before protocol.connection_made(). +* Fix SSLProtocol.eof_received(). Wake-up the waiter if it is not done yet. * Close transports on error. Fix create_datagram_endpoint(), connect_read_pipe() and connect_write_pipe(): close the transport if the task is cancelled or on error. @@ -82,8 +132,34 @@ Bugfixes: * Python issue #23209: Break some reference cycles in asyncio. Patch written by Martin Richard. -Changes: +Optimization: + +* Only call _check_resolved_address() in debug mode. _check_resolved_address() + is implemented with getaddrinfo() which is slow. If available, use + socket.inet_pton() instead of socket.getaddrinfo(), because it is much faster +Other changes: + +* Python issue #23456: Add missing @coroutine decorators +* Python issue #23475: Fix test_close_kill_running(). Really kill the child + process, don't mock completly the Popen.kill() method. This change fix memory + leaks and reference leaks. +* BaseSubprocessTransport: repr() mentions when the child process is running +* BaseSubprocessTransport.close() doesn't try to kill the process if it already + finished. +* Tulip issue #221: Fix docstring of QueueEmpty and QueueFull +* Fix subprocess_attach_write_pipe example. Close the transport, not directly + the pipe. +* Python issue #23347: send_signal(), terminate(), kill() don't check if the + transport was closed. The check broken a Tulip example and this limitation is + arbitrary. Check if _proc is None should be enough. Enhance also close(): do + nothing when called the second time. +* Python issue #23347: Refactor creation of subprocess transports. +* Python issue #23243: On Python 3.4 and newer, emit a ResourceWarning when an + event loop or a transport is not explicitly closed +* tox.ini: enable ResourceWarning warnings +* Python issue #23243: test_sslproto: Close explicitly transports +* SSL transports now clear their reference to the waiter. * Python issue #23208: Add BaseEventLoop._current_handle. In debug mode, BaseEventLoop._run_once() now sets the BaseEventLoop._current_handle attribute to the handle currently executed. @@ -106,7 +182,6 @@ Changes: exception in debug mode when called from the wrong thread. It should help to notice misusage of the API. - 2014-12-19: Version 1.0.4 ========================= diff --git a/doc/conf.py b/doc/conf.py index 0bcd2da9..818d8190 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,7 +48,7 @@ # built documents. # # The short X.Y version. -version = release = '1.0.5' +version = release = '2.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/doc/index.rst b/doc/index.rst index 317f9245..c6b70bf8 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -43,7 +43,7 @@ Linux, Mac OS X, FreeBSD and OpenIndiana. * `Trollius project in the Python Cheeseshop (PyPI) `_ (download wheel packages and tarballs) -* `Trollius project at Bitbucket `_ +* `Trollius project at Github `_ (bug tracker, source code) * Mailing list: `python-tulip Google Group `_ diff --git a/doc/install.rst b/doc/install.rst index ea2b4557..72e20e8f 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -49,13 +49,13 @@ Download source code Command to download the development version of the source code (``trollius`` branch):: - hg clone 'https://bitbucket.org/enovance/trollius#trollius' + git clone https://github.com/haypo/trollius.git -b trollius The actual code lives in the ``trollius`` subdirectory. Tests are in the ``tests`` subdirectory. -See the `trollius project at Bitbucket -`_. +See the `trollius project at Github +`_. The source code of the Trollius project is in the ``trollius`` branch of the Mercurial repository, not in the default branch. The default branch is the diff --git a/setup.py b/setup.py index 407e3142..f3028e9e 100644 --- a/setup.py +++ b/setup.py @@ -45,14 +45,14 @@ install_options = { "name": "trollius", - "version": "1.0.5", + "version": "2.0", "license": "Apache License 2.0", "author": 'Victor Stinner', "author_email": 'victor.stinner@gmail.com', "description": "Port of the Tulip project (asyncio module, PEP 3156) on Python 2", "long_description": long_description, - "url": "https://bitbucket.org/enovance/trollius/", + "url": "https://github.com/haypo/trollius", "classifiers": [ "Programming Language :: Python", From b4f24c28431205dd14521179f722af8e19256cde Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 12 Jul 2015 00:25:50 +0200 Subject: [PATCH 1437/1502] cleanup changelog --- doc/changelog.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 578e21a9..3e9c5f1b 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -44,9 +44,9 @@ API changes: New SSL implementation: -* Python issue #22560: On Python 3.5 and newer, a new SSL implementation based - on ssl.MemoryBIO instead of the legacy SSL implementation. Patch written by - Antoine Pitrou, based on the work of Geert Jansen. +* Python issue #22560: On Python 3.5 and newer, use new SSL implementation + based on ssl.MemoryBIO instead of the legacy SSL implementation. Patch + written by Antoine Pitrou, based on the work of Geert Jansen. * If available, the new SSL implementation can be used by ProactorEventLoop to support SSL. @@ -69,13 +69,13 @@ Enhance, fix and cleanup the IocpProactor: UnregisterWaitEx() is used with an event instead of UnregisterWait(). * Python issue #23293: Rewrite IocpProactor.connect_pipe() as a coroutine. Use a coroutine with asyncio.sleep() instead of call_later() to ensure that the - schedule call is cancelled. -* Fix ProactorEventLoop.start_serving_pipe(). If a client connected before the - server was closed: drop the client (close the pipe) and exit + scheduled call is cancelled. +* Fix ProactorEventLoop.start_serving_pipe(). If a client was connected before + the server was closed: drop the client (close the pipe) and exit * Python issue #23293: Cleanup IocpProactor.close(). The special case for - connect_pipe() is not more needed. connect_pipe() doesn't use overlapped + connect_pipe() is no more needed. connect_pipe() doesn't use overlapped operations anymore. -* IocpProactor.close(): don't cancel futures which are already cancelled +* IocpProactor.close(): don't cancel futures which are already cancelled * Enhance (fix) BaseProactorEventLoop._loop_self_reading(). Handle correctly CancelledError: just exit. On error, log the exception and exit; don't try to close the event loop (it doesn't work). From 14bdbc05ee70cf89091685561dfcbd64740454e7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 16:55:13 +0200 Subject: [PATCH 1438/1502] changelog: summary --- doc/changelog.rst | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 3e9c5f1b..727dd5a5 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,8 +5,16 @@ Change log Version 2.0 =========== -Major changes: on Python 3.5+ ProactorEventLoop now supports SSL, a lot of -bugfixes (random race conditions) in the ProactorEventLoop. +Summary: + +* SSL support on Windows for proactor event loop with Python 3.5 and newer +* Many race conditions were fixed in the proactor event loop +* Trollius moved to Github and the fork was recreated on top to asyncio git + repository +* Many resource leaks (ex: unclosed sockets) were fixed +* Optimization of socket connections: avoid: don't call the slow getaddrinfo() + function to ensure that the address is already resolved. The check is now + only done in debug mode. The Trollius project moved from Bitbucket to Github. The project is now a fork of the Git repository of the asyncio project (previously called the "tulip" @@ -16,7 +24,7 @@ The new Trollius home page is now: https://github.com/haypo/trollius The asyncio project moved to: https://github.com/python/asyncio -Note: the PEP 492 is not supported in trollius (yet?). +Note: the PEP 492 is not supported in trollius yet. API changes: From 242222facef6af710d55a7d108be23c32b58c8ac Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 17:33:39 +0200 Subject: [PATCH 1439/1502] Fix sslproto when ssl is not available --- TODO.rst | 1 + tests/test_events.py | 9 +++++++-- trollius/selector_events.py | 2 +- trollius/sslproto.py | 11 ++++++----- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/TODO.rst b/TODO.rst index 260b491d..4a433c22 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,6 +1,7 @@ Unsorted "TODO" tasks: * Drop Python 2.6 and 3.2 support +* Drop platform without ssl module? * test_utils.py: remove assertRaisesRegex, assertRegex * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * reuse selectors backport from PyPI diff --git a/tests/test_events.py b/tests/test_events.py index c8e9605f..af7113b6 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -64,6 +64,11 @@ def osx_tiger(): return version < (10, 5) +def skip_if_backported_sslcontext(): + backported = getattr(asyncio, 'BACKPORT_SSL_CONTEXT', False) + return unittest.skipIf(backported, 'need ssl.SSLContext') + + ONLYCERT = data_file('ssl_cert.pem') ONLYKEY = data_file('ssl_key.pem') SIGNED_CERTFILE = data_file('keycert3.pem') @@ -894,7 +899,7 @@ def test_legacy_create_unix_server_ssl(self): self.test_create_unix_server_ssl() @unittest.skipIf(ssl is None, 'No ssl module') - @unittest.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') + @skip_if_backported_sslcontext() def test_create_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, host, port = self._make_ssl_server( @@ -928,7 +933,7 @@ def test_legacy_create_server_ssl_verify_failed(self): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(socket, 'AF_UNIX'), 'No UNIX Sockets') - @unittest.skipIf(asyncio.BACKPORT_SSL_CONTEXT, 'need ssl.SSLContext') + @skip_if_backported_sslcontext() def test_create_unix_server_ssl_verify_failed(self): proto = MyProto(loop=self.loop) server, path = self._make_ssl_unix_server( diff --git a/trollius/selector_events.py b/trollius/selector_events.py index 4eb4bc5d..dc27ed14 100644 --- a/trollius/selector_events.py +++ b/trollius/selector_events.py @@ -23,8 +23,8 @@ from . import events from . import futures from . import selectors -from . import transports from . import sslproto +from . import transports from .compat import flatten_bytes from .coroutines import coroutine, From from .log import logger diff --git a/trollius/sslproto.py b/trollius/sslproto.py index 2da10822..707cc6d6 100644 --- a/trollius/sslproto.py +++ b/trollius/sslproto.py @@ -3,6 +3,7 @@ import warnings try: import ssl + from .py3_ssl import BACKPORT_SSL_CONTEXT except ImportError: # pragma: no cover ssl = None @@ -10,7 +11,6 @@ from . import transports from .log import logger from .py33_exceptions import BrokenPipeError, ConnectionResetError -from .py3_ssl import BACKPORT_SSL_CONTEXT def _create_transport_context(server_side, server_hostname): @@ -46,10 +46,11 @@ def _is_sslproto_available(): _WRAPPED = "WRAPPED" _SHUTDOWN = "SHUTDOWN" -if hasattr(ssl, 'CertificateError'): - _SSL_ERRORS = (ssl.SSLError, ssl.CertificateError) -else: - _SSL_ERRORS = ssl.SSLError +if ssl is not None: + if hasattr(ssl, 'CertificateError'): + _SSL_ERRORS = (ssl.SSLError, ssl.CertificateError) + else: + _SSL_ERRORS = ssl.SSLError class _SSLPipe(object): From 4a3cad7b7ae78c46cbe185e58f3632b642f0d225 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 17:54:46 +0200 Subject: [PATCH 1440/1502] Fix TestCase.tearDown() for skipped tests --- trollius/compat.py | 3 +++ trollius/test_utils.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/trollius/compat.py b/trollius/compat.py index 220a90f1..2b4e621b 100644 --- a/trollius/compat.py +++ b/trollius/compat.py @@ -2,6 +2,9 @@ import sys +# Python 2 or older? +PY2 = (sys.version_info <= (2,)) + # Python 2.6 or older? PY26 = (sys.version_info < (2, 7)) diff --git a/trollius/test_utils.py b/trollius/test_utils.py index 1d7080ad..ed331ed7 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -36,6 +36,7 @@ ssl = None from . import base_events +from . import compat from . import events from . import futures from . import selectors @@ -517,7 +518,11 @@ def tearDown(self): # Detect CPython bug #23353: ensure that yield/yield-from is not used # in an except block of a generator - self.assertEqual(sys.exc_info(), (None, None, None)) + if sys.exc_info()[0] == SkipTest: + if compat.PY2: + sys.exc_clear() + else: + self.assertEqual(sys.exc_info(), (None, None, None)) if not hasattr(unittest.TestCase, 'assertRaisesRegex'): def assertRaisesRegex(self, expected_exception, expected_regex, From 5d4fcb54a4281f95dc0254013ebf426a717a8b8a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:19:12 +0200 Subject: [PATCH 1441/1502] prepare release 2.0 --- MANIFEST.in | 4 +++- doc/changelog.rst | 4 ++-- doc/install.rst | 4 ++-- setup.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index e804c6ad..f3b496f4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,9 @@ include AUTHORS COPYING TODO.rst tox.ini include Makefile include overlapped.c pypi.bat include check.py runtests.py run_aiotest.py release.py -include update-tulip*.sh +include update-asyncio-*.sh +include .travis.yml +include releaser.conf include doc/conf.py doc/make.bat doc/Makefile include doc/*.rst doc/*.jpg diff --git a/doc/changelog.rst b/doc/changelog.rst index 727dd5a5..684dd770 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -2,8 +2,8 @@ Change log ++++++++++ -Version 2.0 -=========== +Version 2.0 (2015-07-13) +======================== Summary: diff --git a/doc/install.rst b/doc/install.rst index 72e20e8f..4dd6b1a8 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -39,8 +39,8 @@ Trollius on Windows: .. note:: - Only wheel packages for Python 2.7 are currently distributed on the - Cheeseshop (PyPI). If you need wheel packages for other Python versions, + Only wheel packages for Python 2.7, 3.3 and 3.4 are currently distributed on + the Cheeseshop (PyPI). If you need wheel packages for other Python versions, please ask. Download source code diff --git a/setup.py b/setup.py index f3028e9e..94a0bfbb 100644 --- a/setup.py +++ b/setup.py @@ -4,15 +4,16 @@ # - run all tests: tox # - test examples # - update version in setup.py (version) and doc/conf.py (version, release) -# - set release date in doc/changelog.rst # - check that "python setup.py sdist" contains all files tracked by # the SCM (Mercurial): update MANIFEST.in if needed +# - run test on Windows: releaser.py test +# - set release date in doc/changelog.rst # - git commit # - git tag trollius-VERSION # - git push --tags # - git push # - On Linux: python setup.py register sdist bdist_wheel upload -# - On Windows: python release.py release +# - On Windows: python releaser.py release # - increment version in setup.py (version) and doc/conf.py (version, release) # - gt commit && git push From 5e9854d7b7bed6eb6e182808379342355e2bfca4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:20:03 +0200 Subject: [PATCH 1442/1502] update release procedure --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 94a0bfbb..37e3b332 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,12 @@ # Release procedure: # - fill trollius changelog -# - run maybe ./update_asyncio_step1.sh +# - run maybe ./update-asyncio-step1.sh # - run all tests: tox # - test examples -# - update version in setup.py (version) and doc/conf.py (version, release) # - check that "python setup.py sdist" contains all files tracked by # the SCM (Mercurial): update MANIFEST.in if needed # - run test on Windows: releaser.py test +# - update version in setup.py (version) and doc/conf.py (version, release) # - set release date in doc/changelog.rst # - git commit # - git tag trollius-VERSION From 306b7b20a56c55da50a21a99d978b7a86cfa5d60 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:29:52 +0200 Subject: [PATCH 1443/1502] post-release: set version to 2.0.1 --- doc/conf.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 818d8190..0d3b8dd8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,7 +48,7 @@ # built documents. # # The short X.Y version. -version = release = '2.0' +version = release = '2.0.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index 37e3b332..05e96c91 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ # - On Linux: python setup.py register sdist bdist_wheel upload # - On Windows: python releaser.py release # - increment version in setup.py (version) and doc/conf.py (version, release) -# - gt commit && git push +# - git commit -a && git push import os import sys @@ -46,7 +46,7 @@ install_options = { "name": "trollius", - "version": "2.0", + "version": "2.0.1", "license": "Apache License 2.0", "author": 'Victor Stinner', "author_email": 'victor.stinner@gmail.com', From 9414da199decf669a792b3c77f9d87471abefdcf Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:31:35 +0200 Subject: [PATCH 1444/1502] test_utils.py: remove assertRaisesRegex unittest2 is now used on Python 2.6 and 2.7. --- TODO.rst | 1 - trollius/test_utils.py | 31 ------------------------------- 2 files changed, 32 deletions(-) diff --git a/TODO.rst b/TODO.rst index 4a433c22..11f6cbb0 100644 --- a/TODO.rst +++ b/TODO.rst @@ -2,7 +2,6 @@ Unsorted "TODO" tasks: * Drop Python 2.6 and 3.2 support * Drop platform without ssl module? -* test_utils.py: remove assertRaisesRegex, assertRegex * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * reuse selectors backport from PyPI * check ssl.SSLxxx in update_xxx.sh diff --git a/trollius/test_utils.py b/trollius/test_utils.py index ed331ed7..f67475d2 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -524,37 +524,6 @@ def tearDown(self): else: self.assertEqual(sys.exc_info(), (None, None, None)) - if not hasattr(unittest.TestCase, 'assertRaisesRegex'): - def assertRaisesRegex(self, expected_exception, expected_regex, - callable_obj=None, *args, **kwargs): - """Asserts that the message in a raised exception matches a regex. - - Args: - expected_exception: Exception class expected to be raised. - expected_regex: Regex (re pattern object or string) expected - to be found in error message. - callable_obj: Function to be called. - msg: Optional message used in case of failure. Can only be used - when assertRaisesRegex is used as a context manager. - args: Extra args. - kwargs: Extra kwargs. - """ - context = _AssertRaisesContext(expected_exception, self, callable_obj, - expected_regex) - - return context.handle('assertRaisesRegex', callable_obj, args, kwargs) - - if not hasattr(unittest.TestCase, 'assertRegex'): - def assertRegex(self, text, expected_regex, msg=None): - """Fail the test unless the text matches the regular expression.""" - if isinstance(expected_regex, (str, bytes)): - assert expected_regex, "expected_regex must not be empty." - expected_regex = re.compile(expected_regex) - if not expected_regex.search(text): - msg = msg or "Regex didn't match" - msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) - raise self.failureException(msg) - def check_soure_traceback(self, source_traceback, lineno_delta): frame = sys._getframe(1) filename = frame.f_code.co_filename From a0bad4611ab403b261c86a629e05abd331415032 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:49:46 +0200 Subject: [PATCH 1445/1502] update TODO list --- TODO.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/TODO.rst b/TODO.rst index 11f6cbb0..31683a50 100644 --- a/TODO.rst +++ b/TODO.rst @@ -3,7 +3,8 @@ Unsorted "TODO" tasks: * Drop Python 2.6 and 3.2 support * Drop platform without ssl module? * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? -* reuse selectors backport from PyPI +* replace selectors.py with selectors34: + https://github.com/berkerpeksag/selectors34/pull/2 * check ssl.SSLxxx in update_xxx.sh * document how to port asyncio to trollius * use six instead of compat From d444158a05eb2d1dd171f8659f48e2a9f78eb3c6 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 13 Jul 2015 18:54:45 +0200 Subject: [PATCH 1446/1502] Use the six module --- doc/install.rst | 2 ++ setup.py | 2 +- tests/test_events.py | 3 ++- tests/test_futures.py | 3 ++- tests/test_streams.py | 3 ++- tox.ini | 7 +++++++ trollius/compat.py | 13 ++++--------- trollius/futures.py | 3 ++- trollius/test_utils.py | 2 +- trollius/windows_utils.py | 2 +- 10 files changed, 24 insertions(+), 16 deletions(-) diff --git a/doc/install.rst b/doc/install.rst index 4dd6b1a8..db4a8b15 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -65,6 +65,8 @@ Tulip project, Trollius repository is a fork of the Tulip repository. Dependencies ============ +Trollius requires the `six `_ module. + On Python older than 3.2, the `futures `_ project is needed to get a backport of ``concurrent.futures``. diff --git a/setup.py b/setup.py index 05e96c91..012429af 100644 --- a/setup.py +++ b/setup.py @@ -38,7 +38,7 @@ ) extensions.append(ext) -requirements = [] +requirements = ['six'] if sys.version_info < (2, 7): requirements.append('ordereddict') if sys.version_info < (3,): diff --git a/tests/test_events.py b/tests/test_events.py index af7113b6..59c25d4d 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -8,6 +8,7 @@ import platform import re import signal +import six import socket import subprocess import sys @@ -981,7 +982,7 @@ def test_create_server_ssl_match_failed(self): if hasattr(sslcontext_client, 'check_hostname'): sslcontext_client.check_hostname = True - if compat.PY3: + if six.PY3: err_msg = "hostname '127.0.0.1' doesn't match 'localhost'" else: # http://bugs.python.org/issue22861 diff --git a/tests/test_futures.py b/tests/test_futures.py index 6467befd..78a097b2 100644 --- a/tests/test_futures.py +++ b/tests/test_futures.py @@ -5,6 +5,7 @@ except ImportError: concurrent = None import re +import six import sys import threading @@ -347,7 +348,7 @@ def memory_error(): r'MemoryError$' ).format(filename=re.escape(frame[0]), lineno=frame[1]) - elif compat.PY3: + elif six.PY3: regex = (r'^Future/Task exception was never retrieved\n' r'Traceback \(most recent call last\):\n' r'.*\n' diff --git a/tests/test_streams.py b/tests/test_streams.py index d20e9fef..9ecbb663 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -4,6 +4,7 @@ import io import os import socket +import six import sys try: import ssl @@ -609,7 +610,7 @@ def test_read_all_from_pipe_reader(self): try: asyncio.set_child_watcher(watcher) kw = {'loop': self.loop} - if compat.PY3: + if six.PY3: kw['pass_fds'] = set((wfd,)) create = asyncio.create_subprocess_exec(*args, **kw) proc = self.loop.run_until_complete(create) diff --git a/tox.ini b/tox.ini index 6f0a3421..2dde9431 100644 --- a/tox.ini +++ b/tox.ini @@ -5,6 +5,7 @@ envlist = py26,py27,py2_release,py2_no_ssl,py2_no_concurrent,py32,py33,py34,py3_ [testenv] deps= aiotest + six setenv = TROLLIUSDEBUG = 1 commands= @@ -31,6 +32,7 @@ deps= futures mock==1.0.1 ordereddict + six unittest2 [testenv:py27] @@ -38,6 +40,7 @@ deps= aiotest futures mock + six unittest2 [testenv:py2_release] @@ -47,6 +50,7 @@ deps= aiotest futures mock + six unittest2 setenv = TROLLIUSDEBUG = @@ -57,6 +61,7 @@ deps= aiotest futures mock + six unittest2 commands= python -Wd runtests.py --no-ssl -r {posargs} @@ -67,6 +72,7 @@ deps= aiotest futures mock + six unittest2 commands= python -Wd runtests.py --no-concurrent -r {posargs} @@ -75,6 +81,7 @@ commands= deps= aiotest mock + six [testenv:py35] basepython = python3.5 diff --git a/trollius/compat.py b/trollius/compat.py index 2b4e621b..df64abac 100644 --- a/trollius/compat.py +++ b/trollius/compat.py @@ -1,16 +1,11 @@ """Compatibility helpers for the different Python versions.""" +import six import sys -# Python 2 or older? -PY2 = (sys.version_info <= (2,)) - # Python 2.6 or older? PY26 = (sys.version_info < (2, 7)) -# Python 3.0 or newer? -PY3 = (sys.version_info >= (3,)) - # Python 3.3 or newer? PY33 = (sys.version_info >= (3, 3)) @@ -20,7 +15,7 @@ # Python 3.5 or newer? PY35 = sys.version_info >= (3, 5) -if PY3: +if six.PY3: integer_types = (int,) bytes_type = bytes text_type = str @@ -37,7 +32,7 @@ BYTES_TYPES = (str, bytearray, memoryview, buffer) -if PY3: +if six.PY3: def reraise(tp, value, tb=None): if value.__traceback__ is not tb: raise value.with_traceback(tb) @@ -60,7 +55,7 @@ def flatten_bytes(data): return data if not data: return b'' - if not PY3 and isinstance(data, (buffer, bytearray)): + if six.PY2 and isinstance(data, (buffer, bytearray)): return str(data) elif not PY26 and isinstance(data, memoryview): return data.tobytes() diff --git a/trollius/futures.py b/trollius/futures.py index 9899ae59..4d4e20f6 100644 --- a/trollius/futures.py +++ b/trollius/futures.py @@ -6,6 +6,7 @@ ] import logging +import six import sys import traceback try: @@ -370,7 +371,7 @@ def _set_exception_with_tb(self, exception, exc_tb): if exc_tb is not None: self._exception_tb = exc_tb exc_tb = None - elif self._loop.get_debug() and not compat.PY3: + elif self._loop.get_debug() and not six.PY3: self._exception_tb = sys.exc_info()[2] self._state = _FINISHED self._schedule_callbacks() diff --git a/trollius/test_utils.py b/trollius/test_utils.py index f67475d2..dbff44b8 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -519,7 +519,7 @@ def tearDown(self): # Detect CPython bug #23353: ensure that yield/yield-from is not used # in an except block of a generator if sys.exc_info()[0] == SkipTest: - if compat.PY2: + if six.PY2: sys.exc_clear() else: self.assertEqual(sys.exc_info(), (None, None, None)) diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index 2a10ce8d..d25d2a4b 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -168,7 +168,7 @@ def close(self, CloseHandle=_winapi.CloseHandle): def __del__(self): if self._handle is not None: - if compat.PY3: + if six.PY3: warnings.warn("unclosed %r" % self, ResourceWarning) self.close() From ce3ad816a2ef9456b4b1c26b99dfc85ea1236811 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 14 Jul 2015 13:06:47 +0200 Subject: [PATCH 1447/1502] Return True from StreamReader.eof_received() to fix http://bugs.python.org/issue24539 (but still needs a unittest). Add StreamReader.__repr__() for easy debugging. --- asyncio/streams.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/asyncio/streams.py b/asyncio/streams.py index 6cd60c42..902d1ca2 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -240,6 +240,7 @@ def data_received(self, data): def eof_received(self): self._stream_reader.feed_eof() + return True class StreamWriter: @@ -321,6 +322,24 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): self._transport = None self._paused = False + def __repr__(self): + info = ['StreamReader'] + if self._buffer: + info.append('%d bytes' % len(info)) + if self._eof: + info.append('eof') + if self._limit != _DEFAULT_LIMIT: + info.append('l=%d' % self._limit) + if self._waiter: + info.append('w=%r' % self._waiter) + if self._exception: + info.append('e=%r' % self._exception) + if self._transport: + info.append('t=%r' % self._transport) + if self._paused: + info.append('paused') + return '<%s>' % ' '.join(info) + def exception(self): return self._exception From d9d88818ab6c887d05939d0d570b62fb977a44ef Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 17 Jul 2015 18:46:08 +0200 Subject: [PATCH 1448/1502] setup.py: don't use bdist_wheel --- setup.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 012429af..19d9670d 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,9 @@ # - git tag trollius-VERSION # - git push --tags # - git push -# - On Linux: python setup.py register sdist bdist_wheel upload +# - On Linux: python setup.py register sdist upload +# FIXME: don't use bdist_wheel because of +# FIXME: https://github.com/haypo/trollius/issues/1 # - On Windows: python releaser.py release # - increment version in setup.py (version) and doc/conf.py (version, release) # - git commit -a && git push From d1630dc71ac8cf0ada17386ebd0c16941d248b4b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sat, 18 Jul 2015 23:32:44 +0200 Subject: [PATCH 1449/1502] doc: add libraries --- doc/index.rst | 1 + doc/libraries.rst | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 doc/libraries.rst diff --git a/doc/index.rst b/doc/index.rst index c6b70bf8..5135da1b 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -60,6 +60,7 @@ Table Of Contents using install + libraries asyncio dev changelog diff --git a/doc/libraries.rst b/doc/libraries.rst new file mode 100644 index 00000000..424fd283 --- /dev/null +++ b/doc/libraries.rst @@ -0,0 +1,30 @@ +++++++++++++++++++ +Trollius Libraries +++++++++++++++++++ + +Libraries compatible with asyncio and trollius +============================================== + +* `aioeventlet `_: asyncio API + implemented on top of eventlet +* `aiogevent `_: asyncio API + implemented on top of gevent +* `AutobahnPython `_: WebSocket & + WAMP for Python, it works on Trollius (Python 2.6 and 2.7), asyncio (Python + 3.3) and Python 3.4 (asyncio), and also on Twisted. +* `Pulsar `_: Event driven concurrent + framework for Python. With pulsar you can write asynchronous servers + performing one or several activities in different threads and/or processes. + Trollius 0.3 requires Pulsar 0.8.2 or later. Pulsar uses the ``asyncio`` + module if available, or import ``trollius``. +* `Tornado `_ supports asyncio and Trollius since + Tornado 3.2: `tornado.platform.asyncio — Bridge between asyncio and Tornado + `_. It tries to import + asyncio or fallback on importing trollius. + +Specific Ports +============== + +* `trollius-redis `_: + A port of `asyncio-redis `_ to + trollius From b13b14a88fbb51e9982d1d5547655d6519b23d5d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 19 Jul 2015 01:13:18 +0200 Subject: [PATCH 1450/1502] fix test_utils: add missing "import six" --- trollius/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trollius/test_utils.py b/trollius/test_utils.py index dbff44b8..12cdd454 100644 --- a/trollius/test_utils.py +++ b/trollius/test_utils.py @@ -14,6 +14,8 @@ from wsgiref.simple_server import WSGIRequestHandler, WSGIServer +import six + try: import socketserver from http.server import HTTPServer From f25cb291d8439d9f3d44f52811607d3fdb305d1f Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 19 Jul 2015 01:14:28 +0200 Subject: [PATCH 1451/1502] fix windows_utils: add missing "import six" --- trollius/windows_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trollius/windows_utils.py b/trollius/windows_utils.py index d25d2a4b..288d5478 100644 --- a/trollius/windows_utils.py +++ b/trollius/windows_utils.py @@ -16,6 +16,8 @@ import tempfile import warnings +import six + from . import py33_winapi as _winapi from . import compat from .py33_exceptions import wrap_error, BlockingIOError, InterruptedError From 9bb67431adc916d9d4b4e23ca257658c980d035d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Sun, 19 Jul 2015 01:16:46 +0200 Subject: [PATCH 1452/1502] remove unused imports --- asyncio/locks.py | 1 - asyncio/streams.py | 1 - asyncio/subprocess.py | 2 -- asyncio/tasks.py | 2 -- asyncio/transports.py | 2 -- 5 files changed, 8 deletions(-) diff --git a/asyncio/locks.py b/asyncio/locks.py index cc6f2bf7..7a132796 100644 --- a/asyncio/locks.py +++ b/asyncio/locks.py @@ -3,7 +3,6 @@ __all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] import collections -import sys from . import compat from . import events diff --git a/asyncio/streams.py b/asyncio/streams.py index 902d1ca2..6484c435 100644 --- a/asyncio/streams.py +++ b/asyncio/streams.py @@ -6,7 +6,6 @@ ] import socket -import sys if hasattr(socket, 'AF_UNIX'): __all__.extend(['open_unix_connection', 'start_unix_server']) diff --git a/asyncio/subprocess.py b/asyncio/subprocess.py index 4600a9f4..ead4039b 100644 --- a/asyncio/subprocess.py +++ b/asyncio/subprocess.py @@ -1,10 +1,8 @@ __all__ = ['create_subprocess_exec', 'create_subprocess_shell'] -import collections import subprocess from . import events -from . import futures from . import protocols from . import streams from . import tasks diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 1d5f8654..9bfc1cf8 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -10,8 +10,6 @@ import functools import inspect import linecache -import sys -import types import traceback import warnings import weakref diff --git a/asyncio/transports.py b/asyncio/transports.py index 7a28d908..70b323f2 100644 --- a/asyncio/transports.py +++ b/asyncio/transports.py @@ -1,7 +1,5 @@ """Abstract Transport class.""" -import sys - from asyncio import compat __all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', From ce41fba64d504647e3c2379094408acd16a38338 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 24 Jul 2015 22:03:00 +0200 Subject: [PATCH 1453/1502] Fix ResourceWarning warnings in test_streams --- tests/test_streams.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/test_streams.py b/tests/test_streams.py index 242b377e..ef6f6030 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -446,6 +446,8 @@ def __init__(self, loop): def handle_client(self, client_reader, client_writer): data = yield from client_reader.readline() client_writer.write(data) + yield from client_writer.drain() + client_writer.close() def start(self): sock = socket.socket() @@ -457,12 +459,8 @@ def start(self): return sock.getsockname() def handle_client_callback(self, client_reader, client_writer): - task = asyncio.Task(client_reader.readline(), loop=self.loop) - - def done(task): - client_writer.write(task.result()) - - task.add_done_callback(done) + self.loop.create_task(self.handle_client(client_reader, + client_writer)) def start_callback(self): sock = socket.socket() @@ -522,6 +520,8 @@ def __init__(self, loop, path): def handle_client(self, client_reader, client_writer): data = yield from client_reader.readline() client_writer.write(data) + yield from client_writer.drain() + client_writer.close() def start(self): self.server = self.loop.run_until_complete( @@ -530,18 +530,14 @@ def start(self): loop=self.loop)) def handle_client_callback(self, client_reader, client_writer): - task = asyncio.Task(client_reader.readline(), loop=self.loop) - - def done(task): - client_writer.write(task.result()) - - task.add_done_callback(done) + self.loop.create_task(self.handle_client(client_reader, + client_writer)) def start_callback(self): - self.server = self.loop.run_until_complete( - asyncio.start_unix_server(self.handle_client_callback, - path=self.path, - loop=self.loop)) + start = asyncio.start_unix_server(self.handle_client_callback, + path=self.path, + loop=self.loop) + self.server = self.loop.run_until_complete(start) def stop(self): if self.server is not None: From 5eac181ac595bdb58c7e97690c8882c6f5651ac9 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 31 Jul 2015 17:44:49 +0200 Subject: [PATCH 1454/1502] Fix ResourceWarning in BaseSubprocessTransport Python issue #24763: Fix resource warnings when BaseSubprocessTransport constructor fails, if subprocess.Popen raises an exception for example. Patch written by Martin Richard, test written by me. --- asyncio/base_subprocess.py | 9 +++++++-- tests/test_subprocess.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index c1477b82..a6971b1d 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -35,8 +35,13 @@ def __init__(self, loop, protocol, args, shell, self._pipes[2] = None # Create the child process: set the _proc attribute - self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, - stderr=stderr, bufsize=bufsize, **kwargs) + try: + self._start(args=args, shell=shell, stdin=stdin, stdout=stdout, + stderr=stderr, bufsize=bufsize, **kwargs) + except: + self.close() + raise + self._pid = self._proc.pid self._extra['subprocess'] = self._proc diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index ea85e191..d138c263 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -1,6 +1,7 @@ import signal import sys import unittest +import warnings from unittest import mock import asyncio @@ -413,6 +414,20 @@ def kill_running(): # the transport was not notified yet self.assertFalse(killed) + def test_popen_error(self): + # Issue #24763: check that the subprocess transport is closed + # when BaseSubprocessTransport fails + with mock.patch('subprocess.Popen') as popen: + exc = ZeroDivisionError + popen.side_effect = exc + + create = asyncio.create_subprocess_exec(sys.executable, '-c', + 'pass', loop=self.loop) + with warnings.catch_warnings(record=True) as warns: + with self.assertRaises(exc): + self.loop.run_until_complete(create) + self.assertEqual(warns, []) + if sys.platform != 'win32': # Unix From 4f45ac71aa8e1aa97007c6da4f12f7c159b24a36 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 31 Jul 2015 17:59:13 +0200 Subject: [PATCH 1455/1502] Use compat.PY34 --- asyncio/base_events.py | 3 ++- asyncio/base_subprocess.py | 4 ++-- asyncio/proactor_events.py | 4 ++-- asyncio/selector_events.py | 4 ++-- asyncio/sslproto.py | 4 ++-- asyncio/unix_events.py | 5 +++-- 6 files changed, 13 insertions(+), 11 deletions(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 5a536a22..8e4ad4f6 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -28,6 +28,7 @@ import sys import warnings +from . import compat from . import coroutines from . import events from . import futures @@ -378,7 +379,7 @@ def is_closed(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if not self.is_closed(): warnings.warn("unclosed event loop %r" % self, ResourceWarning) diff --git a/asyncio/base_subprocess.py b/asyncio/base_subprocess.py index a6971b1d..6851cd2b 100644 --- a/asyncio/base_subprocess.py +++ b/asyncio/base_subprocess.py @@ -1,8 +1,8 @@ import collections import subprocess -import sys import warnings +from . import compat from . import futures from . import protocols from . import transports @@ -116,7 +116,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if not self._closed: warnings.warn("unclosed transport %r" % self, ResourceWarning) diff --git a/asyncio/proactor_events.py b/asyncio/proactor_events.py index 9c2b8f15..abe4c129 100644 --- a/asyncio/proactor_events.py +++ b/asyncio/proactor_events.py @@ -7,10 +7,10 @@ __all__ = ['BaseProactorEventLoop'] import socket -import sys import warnings from . import base_events +from . import compat from . import constants from . import futures from . import sslproto @@ -79,7 +79,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if self._sock is not None: warnings.warn("unclosed transport %r" % self, ResourceWarning) diff --git a/asyncio/selector_events.py b/asyncio/selector_events.py index 7c5b9b5b..4a996584 100644 --- a/asyncio/selector_events.py +++ b/asyncio/selector_events.py @@ -10,7 +10,6 @@ import errno import functools import socket -import sys import warnings try: import ssl @@ -18,6 +17,7 @@ ssl = None from . import base_events +from . import compat from . import constants from . import events from . import futures @@ -568,7 +568,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if self._sock is not None: warnings.warn("unclosed transport %r" % self, ResourceWarning) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index 235855e2..e566946e 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -1,11 +1,11 @@ import collections -import sys import warnings try: import ssl except ImportError: # pragma: no cover ssl = None +from . import compat from . import protocols from . import transports from .log import logger @@ -317,7 +317,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if not self._closed: warnings.warn("unclosed transport %r" % self, ResourceWarning) diff --git a/asyncio/unix_events.py b/asyncio/unix_events.py index 75e7c9cc..bf3b0844 100644 --- a/asyncio/unix_events.py +++ b/asyncio/unix_events.py @@ -13,6 +13,7 @@ from . import base_events from . import base_subprocess +from . import compat from . import constants from . import coroutines from . import events @@ -370,7 +371,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if self._pipe is not None: warnings.warn("unclosed transport %r" % self, ResourceWarning) @@ -555,7 +556,7 @@ def close(self): # On Python 3.3 and older, objects with a destructor part of a reference # cycle are never destroyed. It's not more the case on Python 3.4 thanks # to the PEP 442. - if sys.version_info >= (3, 4): + if compat.PY34: def __del__(self): if self._pipe is not None: warnings.warn("unclosed transport %r" % self, ResourceWarning) From 4851618b6b154272e79bdc1b5f63b4887e12d8e9 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sun, 2 Aug 2015 10:23:11 -0400 Subject: [PATCH 1456/1502] tasks: Fix code style --- asyncio/tasks.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 9bfc1cf8..45c6d1b0 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -249,9 +249,8 @@ def _step(self, value=None, exc=None): result._blocking = False result.add_done_callback(self._wakeup) self._fut_waiter = result - if self._must_cancel: - if self._fut_waiter.cancel(): - self._must_cancel = False + if self._must_cancel and self._fut_waiter.cancel(): + self._must_cancel = False else: self._loop.call_soon( self._step, None, From 51a3206600ff4c18f5368d0389157f6720126a4e Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sun, 2 Aug 2015 16:47:28 -0400 Subject: [PATCH 1457/1502] Revert "tasks: Fix code style" This reverts commit 4851618b6b154272e79bdc1b5f63b4887e12d8e9 per Guido's comment: I don't like the refactoring. The original way (two nested if's) was IMO clearer about the side effect (of calling cancel()) that only happens if the first condition is true. The refactored version sweeps the side effect under the proverbial mat. --- asyncio/tasks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 45c6d1b0..9bfc1cf8 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -249,8 +249,9 @@ def _step(self, value=None, exc=None): result._blocking = False result.add_done_callback(self._wakeup) self._fut_waiter = result - if self._must_cancel and self._fut_waiter.cancel(): - self._must_cancel = False + if self._must_cancel: + if self._fut_waiter.cancel(): + self._must_cancel = False else: self._loop.call_soon( self._step, None, From 27f3499f968e8734fef91677eb339b5d32a6f675 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Tue, 4 Aug 2015 15:35:46 -0400 Subject: [PATCH 1458/1502] Use '==' operator instead of 'is' --- asyncio/base_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncio/base_events.py b/asyncio/base_events.py index 8e4ad4f6..c2054454 100644 --- a/asyncio/base_events.py +++ b/asyncio/base_events.py @@ -1206,7 +1206,7 @@ def _set_coroutine_wrapper(self, enabled): return enabled = bool(enabled) - if self._coroutine_wrapper_set is enabled: + if self._coroutine_wrapper_set == enabled: return wrapper = coroutines.debug_wrapper From 5d71b68fdbdafeac85652acb030d6878ace683e6 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 5 Aug 2015 12:06:19 -0400 Subject: [PATCH 1459/1502] Run asyncio with python nightly builds too on travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 2c5838b8..5a2c7d7b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ os: python: - 3.3 - 3.4 + - "nightly" install: - pip install asyncio From 1dd213ee66e4b8fbca658652f8b1bc3958707c0b Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 5 Aug 2015 13:47:44 -0400 Subject: [PATCH 1460/1502] Merge PR #256: fix issue23812 of queues loosing items on cancellation Patch by @gjcarneiro. --- asyncio/queues.py | 47 +++++++++++++++++++++++++++------- tests/test_queues.py | 61 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 98 insertions(+), 10 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index c55dd8bb..b26edfbe 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -47,7 +47,7 @@ def __init__(self, maxsize=0, *, loop=None): # Futures. self._getters = collections.deque() - # Pairs of (item, Future). + # Futures self._putters = collections.deque() self._unfinished_tasks = 0 self._finished = locks.Event(loop=self._loop) @@ -98,7 +98,7 @@ def _consume_done_getters(self): def _consume_done_putters(self): # Delete waiters at the head of the put() queue who've timed out. - while self._putters and self._putters[0][1].done(): + while self._putters and self._putters[0].done(): self._putters.popleft() def qsize(self): @@ -148,8 +148,9 @@ def put(self, item): elif self._maxsize > 0 and self._maxsize <= self.qsize(): waiter = futures.Future(loop=self._loop) - self._putters.append((item, waiter)) + self._putters.append(waiter) yield from waiter + self._put(item) else: self.__put_internal(item) @@ -186,8 +187,7 @@ def get(self): self._consume_done_putters() if self._putters: assert self.full(), 'queue not full, why are putters waiting?' - item, putter = self._putters.popleft() - self.__put_internal(item) + putter = self._putters.popleft() # When a getter runs and frees up a slot so this putter can # run, we need to defer the put for a tick to ensure that @@ -201,9 +201,39 @@ def get(self): return self._get() else: waiter = futures.Future(loop=self._loop) - self._getters.append(waiter) - return (yield from waiter) + try: + return (yield from waiter) + except futures.CancelledError: + # if we get CancelledError, it means someone cancelled this + # get() coroutine. But there is a chance that the waiter + # already is ready and contains an item that has just been + # removed from the queue. In this case, we need to put the item + # back into the front of the queue. This get() must either + # succeed without fault or, if it gets cancelled, it must be as + # if it never happened. + if waiter.done(): + self._put_it_back(waiter.result()) + raise + + def _put_it_back(self, item): + """ + This is called when we have a waiter to get() an item and this waiter + gets cancelled. In this case, we put the item back: wake up another + waiter or put it in the _queue. + """ + self._consume_done_getters() + if self._getters: + assert not self._queue, ( + 'queue non-empty, why are getters waiting?') + + getter = self._getters.popleft() + self._put_internal(item) + + # getter cannot be cancelled, we just removed done getters + getter.set_result(item) + else: + self._queue.appendleft(item) def get_nowait(self): """Remove and return an item from the queue. @@ -213,8 +243,7 @@ def get_nowait(self): self._consume_done_putters() if self._putters: assert self.full(), 'queue not full, why are putters waiting?' - item, putter = self._putters.popleft() - self.__put_internal(item) + putter = self._putters.popleft() # Wake putter on next tick. # getter cannot be cancelled, we just removed done putters diff --git a/tests/test_queues.py b/tests/test_queues.py index 88b4f075..7c7d0eae 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -171,7 +171,7 @@ def test_get_with_putters(self): q.put_nowait(1) waiter = asyncio.Future(loop=self.loop) - q._putters.append((2, waiter)) + q._putters.append(waiter) res = self.loop.run_until_complete(q.get()) self.assertEqual(1, res) @@ -322,6 +322,64 @@ def test_nonblocking_put(self): q.put_nowait(1) self.assertEqual(1, q.get_nowait()) + def test_get_cancel_drop(self): + def gen(): + yield 0.01 + yield 0.1 + + loop = self.new_test_loop(gen) + + q = asyncio.Queue(loop=loop) + + reader = loop.create_task(q.get()) + + loop.run_until_complete(asyncio.sleep(0.01, loop=loop)) + + q.put_nowait(1) + q.put_nowait(2) + reader.cancel() + + try: + loop.run_until_complete(reader) + except asyncio.CancelledError: + # try again + reader = loop.create_task(q.get()) + loop.run_until_complete(reader) + + result = reader.result() + # if we get 2, it means 1 got dropped! + self.assertEqual(1, result) + + def test_put_cancel_drop(self): + + def gen(): + yield 0.01 + yield 0.1 + + loop = self.new_test_loop(gen) + q = asyncio.Queue(1, loop=loop) + + q.put_nowait(1) + + # putting a second item in the queue has to block (qsize=1) + writer = loop.create_task(q.put(2)) + loop.run_until_complete(asyncio.sleep(0.01, loop=loop)) + + value1 = q.get_nowait() + self.assertEqual(value1, 1) + + writer.cancel() + try: + loop.run_until_complete(writer) + except asyncio.CancelledError: + # try again + writer = loop.create_task(q.put(2)) + loop.run_until_complete(writer) + + value2 = q.get_nowait() + self.assertEqual(value2, 2) + self.assertEqual(q.qsize(), 0) + def test_nonblocking_put_exception(self): q = asyncio.Queue(maxsize=1, loop=self.loop) q.put_nowait(1) @@ -374,6 +432,7 @@ def test_put_cancelled_race(self): test_utils.run_briefly(self.loop) self.assertTrue(put_c.done()) self.assertEqual(q.get_nowait(), 'a') + test_utils.run_briefly(self.loop) self.assertEqual(q.get_nowait(), 'b') self.loop.run_until_complete(put_b) From f57cfc623ad37f35876c4336e051406f871976ed Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Wed, 5 Aug 2015 12:01:51 -0400 Subject: [PATCH 1461/1502] Make sure BaseException is re-raised in SSLProtocol --- asyncio/sslproto.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/asyncio/sslproto.py b/asyncio/sslproto.py index e566946e..e5ae49a5 100644 --- a/asyncio/sslproto.py +++ b/asyncio/sslproto.py @@ -613,7 +613,8 @@ def _process_write_backlog(self): if data: ssldata, offset = self._sslpipe.feed_appdata(data, offset) elif offset: - ssldata = self._sslpipe.do_handshake(self._on_handshake_complete) + ssldata = self._sslpipe.do_handshake( + self._on_handshake_complete) offset = 1 else: ssldata = self._sslpipe.shutdown(self._finalize) @@ -637,9 +638,13 @@ def _process_write_backlog(self): self._write_buffer_size -= len(data) except BaseException as exc: if self._in_handshake: + # BaseExceptions will be re-raised in _on_handshake_complete. self._on_handshake_complete(exc) else: self._fatal_error(exc, 'Fatal error on SSL transport') + if not isinstance(exc, Exception): + # BaseException + raise def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. From f4111812967fded634637936225205f29035a449 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Thu, 6 Aug 2015 13:59:28 -0400 Subject: [PATCH 1462/1502] queues: Fix getter-cancellation with many pending getters code path This commit fixes a typo in a method call and adds a unittest to make sure that that code path is tested. --- asyncio/queues.py | 2 +- tests/test_queues.py | 37 ++++++++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/asyncio/queues.py b/asyncio/queues.py index b26edfbe..021043d6 100644 --- a/asyncio/queues.py +++ b/asyncio/queues.py @@ -228,7 +228,7 @@ def _put_it_back(self, item): 'queue non-empty, why are getters waiting?') getter = self._getters.popleft() - self._put_internal(item) + self.__put_internal(item) # getter cannot be cancelled, we just removed done getters getter.set_result(item) diff --git a/tests/test_queues.py b/tests/test_queues.py index 7c7d0eae..8e38175e 100644 --- a/tests/test_queues.py +++ b/tests/test_queues.py @@ -322,7 +322,7 @@ def test_nonblocking_put(self): q.put_nowait(1) self.assertEqual(1, q.get_nowait()) - def test_get_cancel_drop(self): + def test_get_cancel_drop_one_pending_reader(self): def gen(): yield 0.01 yield 0.1 @@ -350,6 +350,41 @@ def gen(): # if we get 2, it means 1 got dropped! self.assertEqual(1, result) + def test_get_cancel_drop_many_pending_readers(self): + def gen(): + yield 0.01 + yield 0.1 + + loop = self.new_test_loop(gen) + loop.set_debug(True) + + q = asyncio.Queue(loop=loop) + + reader1 = loop.create_task(q.get()) + reader2 = loop.create_task(q.get()) + reader3 = loop.create_task(q.get()) + + loop.run_until_complete(asyncio.sleep(0.01, loop=loop)) + + q.put_nowait(1) + q.put_nowait(2) + reader1.cancel() + + try: + loop.run_until_complete(reader1) + except asyncio.CancelledError: + pass + + loop.run_until_complete(reader3) + + # reader2 will receive `2`, because it was added to the + # queue of pending readers *before* put_nowaits were called. + self.assertEqual(reader2.result(), 2) + # reader3 will receive `1`, because reader1 was cancelled + # before is had a chance to execute, and `2` was already + # pushed to reader2 by second `put_nowait`. + self.assertEqual(reader3.result(), 1) + def test_put_cancel_drop(self): def gen(): From f3ed6c3e8a857d7e93ff7ae74f3f48d7b1252179 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Sun, 9 Aug 2015 18:19:30 -0400 Subject: [PATCH 1463/1502] Fix asyncio tests on windows --- tests/test_subprocess.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index d138c263..38f0ceeb 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -417,7 +417,11 @@ def kill_running(): def test_popen_error(self): # Issue #24763: check that the subprocess transport is closed # when BaseSubprocessTransport fails - with mock.patch('subprocess.Popen') as popen: + if sys.platform == 'win32': + target = 'asyncio.windows_utils.Popen' + else: + target = 'subprocess.Popen' + with mock.patch(target) as popen: exc = ZeroDivisionError popen.side_effect = exc From 8d79c57726e30fd19d5fadf46375853df7895516 Mon Sep 17 00:00:00 2001 From: Yury Selivanov Date: Fri, 14 Aug 2015 15:29:28 -0400 Subject: [PATCH 1464/1502] Fix Task.get_stask() for 'async def' coroutines --- asyncio/tasks.py | 6 +++++- tests/test_tasks.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/asyncio/tasks.py b/asyncio/tasks.py index 9bfc1cf8..a235e742 100644 --- a/asyncio/tasks.py +++ b/asyncio/tasks.py @@ -128,7 +128,11 @@ def get_stack(self, *, limit=None): returned for a suspended coroutine. """ frames = [] - f = self._coro.gi_frame + try: + # 'async def' coroutines + f = self._coro.cr_frame + except AttributeError: + f = self._coro.gi_frame if f is not None: while f is not None: if limit is not None: diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 251192ac..04267873 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -2,6 +2,7 @@ import contextlib import functools +import io import os import re import sys @@ -162,6 +163,37 @@ def test_async_warning(self): 'function is deprecated, use ensure_'): self.assertIs(f, asyncio.async(f)) + def test_get_stack(self): + T = None + + @asyncio.coroutine + def foo(): + yield from bar() + + @asyncio.coroutine + def bar(): + # test get_stack() + f = T.get_stack(limit=1) + try: + self.assertEqual(f[0].f_code.co_name, 'foo') + finally: + f = None + + # test print_stack() + file = io.StringIO() + T.print_stack(limit=1, file=file) + file.seek(0) + tb = file.read() + self.assertRegex(tb, r'foo\(\) running') + + @asyncio.coroutine + def runner(): + nonlocal T + T = asyncio.ensure_future(foo(), loop=self.loop) + yield from T + + self.loop.run_until_complete(runner()) + def test_task_repr(self): self.loop.set_debug(False) From f285592a7193e2e2a96af344efffb6dc8ad9836a Mon Sep 17 00:00:00 2001 From: Dhawal Yogesh Bhanushali Date: Thu, 12 Nov 2015 18:37:25 -0800 Subject: [PATCH 1465/1502] better exception traceback --- trollius/futures.py | 2 +- trollius/tasks.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/trollius/futures.py b/trollius/futures.py index 4d4e20f6..746124b1 100644 --- a/trollius/futures.py +++ b/trollius/futures.py @@ -371,7 +371,7 @@ def _set_exception_with_tb(self, exception, exc_tb): if exc_tb is not None: self._exception_tb = exc_tb exc_tb = None - elif self._loop.get_debug() and not six.PY3: + elif not six.PY3: self._exception_tb = sys.exc_info()[2] self._state = _FINISHED self._schedule_callbacks() diff --git a/trollius/tasks.py b/trollius/tasks.py index 3e0e1b10..af5c8683 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -235,6 +235,7 @@ def cancel(self): def _step(self, value=None, exc=None, exc_tb=None): assert not self.done(), \ '_step(): already done: {0!r}, {1!r}, {2!r}'.format(self, value, exc) + if self._must_cancel: if not isinstance(exc, futures.CancelledError): exc = futures.CancelledError() @@ -250,7 +251,10 @@ def _step(self, value=None, exc=None, exc_tb=None): # Call either coro.throw(exc) or coro.send(value). try: if exc is not None: - result = coro.throw(exc) + if exc_tb is not None: + result = coro.throw(exc, None, exc_tb) + else: + result = coro.throw(exc) else: result = coro.send(value) except StopIteration as exc: From 056a39469ae441d54743aae79c9e5b1c348305ad Mon Sep 17 00:00:00 2001 From: Gabi Davar Date: Fri, 13 Nov 2015 13:31:14 +0200 Subject: [PATCH 1466/1502] fix test on py2 win32 --- tests/test_subprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 21e003a8..63288cef 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -432,7 +432,7 @@ def test_popen_error(self): # Issue #24763: check that the subprocess transport is closed # when BaseSubprocessTransport fails if sys.platform == 'win32': - target = 'asyncio.windows_utils.Popen' + target = 'trollius.windows_utils.Popen' else: target = 'subprocess.Popen' with mock.patch(target) as popen: From b335213a8e0c56c10b8c1de970fa9b1e2b119322 Mon Sep 17 00:00:00 2001 From: Gabi Davar Date: Fri, 13 Nov 2015 13:32:13 +0200 Subject: [PATCH 1467/1502] ignore inline built extensions --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 16fae2b7..4f2b0741 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ dist *.egg-info .tox .idea/ -*.iml \ No newline at end of file +*.iml +trollius/_overlapped.pyd From e22a30b026996681d6dc3ee737ba4a2eeded729e Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 11:55:55 +0100 Subject: [PATCH 1468/1502] Fix Future on Python 3.5 Future.__iter__() doesn't exist on Python 3.5. --- trollius/futures.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/trollius/futures.py b/trollius/futures.py index 746124b1..b77d1a11 100644 --- a/trollius/futures.py +++ b/trollius/futures.py @@ -423,9 +423,6 @@ def _copy_state(self, other): result = other.result() self.set_result(result) - if compat.PY35: - __await__ = __iter__ # make compatible with 'await' expression - if events.asyncio is not None: # Accept also asyncio Future objects for interoperability _FUTURE_CLASSES = (Future, events.asyncio.Future) From e1a4ed96d64b623cb8a8e8c39da49a3e1a3e5f9d Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 12:28:08 +0100 Subject: [PATCH 1469/1502] document change on exception --- doc/changelog.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index 684dd770..186d4e1d 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -2,6 +2,14 @@ Change log ++++++++++ +Version 2.1 +=========== + +Changes: + +* Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. + + Version 2.0 (2015-07-13) ======================== From a8b8ad449c72433685677c621ff33eb8099e744b Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 12:59:42 +0100 Subject: [PATCH 1470/1502] Drop support of Python 2.6 and 3.2 --- TODO.rst | 1 - doc/changelog.rst | 1 + doc/dev.rst | 5 ++--- doc/index.rst | 4 ++-- doc/install.rst | 22 +++++++++++++++------- tox.ini | 2 +- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/TODO.rst b/TODO.rst index 31683a50..e4a785ed 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,6 +1,5 @@ Unsorted "TODO" tasks: -* Drop Python 2.6 and 3.2 support * Drop platform without ssl module? * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * replace selectors.py with selectors34: diff --git a/doc/changelog.rst b/doc/changelog.rst index 186d4e1d..78fe94ed 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -8,6 +8,7 @@ Version 2.1 Changes: * Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. +* Drop support for Python 2.6 and 3.2. Version 2.0 (2015-07-13) diff --git a/doc/dev.rst b/doc/dev.rst index 1bed7f8e..c0bbf364 100644 --- a/doc/dev.rst +++ b/doc/dev.rst @@ -6,7 +6,7 @@ Run tests with tox The `tox project `_ can be used to build a virtual environment with all runtime and test dependencies and run tests -against different Python versions (2.6, 2.7, 3.2, 3.3). +against different Python versions (2.7, 3.3, 3.4). For example, to run tests with Python 2.7, just type:: @@ -14,10 +14,9 @@ For example, to run tests with Python 2.7, just type:: To run tests against other Python versions: -* ``py26``: Python 2.6 * ``py27``: Python 2.7 -* ``py32``: Python 3.2 * ``py33``: Python 3.3 +* ``py34``: Python 3.4 Test Dependencies diff --git a/doc/index.rst b/doc/index.rst index 5135da1b..cb344e47 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -35,8 +35,8 @@ Here is a more detailed list of the package contents: Trollius is a portage of the `asyncio project `_ (``asyncio`` module, `PEP 3156 `_) -on Python 2. Trollius works on Python 2.6-3.5. It has been tested on Windows, -Linux, Mac OS X, FreeBSD and OpenIndiana. +on Python 2. Trollius works on Python 2.7, 3.3 and 3.4. It has been tested on +Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Asyncio documentation `_ * `Trollius documentation `_ (this document) diff --git a/doc/install.rst b/doc/install.rst index db4a8b15..845e6e5f 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -2,6 +2,17 @@ Install Trollius ++++++++++++++++ +Trollius supports Python 2.7, 3.3 and 3.4. + +There is an experimental support of Python 3.5. Issues with Python 3.5: + +* don't support asyncio coroutines +* ``Task.get_task()`` is broken +* ``repr(Task)`` is broken + +Support of Python 2.6 and 3.2 was dropped in Trollius 2.1. + + Packages for Linux ================== @@ -67,11 +78,8 @@ Dependencies Trollius requires the `six `_ module. -On Python older than 3.2, the `futures `_ -project is needed to get a backport of ``concurrent.futures``. - -Python 2.6 requires also `ordereddict -`_. +Python 2.7 requires `futures `_ to get a +backport of ``concurrent.futures``. Build manually Trollius on Windows @@ -90,8 +98,8 @@ extension using:: Backports ========= -To support Python 2.6-3.4, many Python modules of the standard library have -been backported: +To support old Python versions, many Python modules of the standard library +have been backported: ======================== ========= ======================= Name Python Backport diff --git a/tox.ini b/tox.ini index 2dde9431..5058dd84 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py26,py27,py2_release,py2_no_ssl,py2_no_concurrent,py32,py33,py34,py3_release,py3_no_ssl +envlist = py27,py2_release,py2_no_ssl,py2_no_concurrent,py33,py34,py3_release,py3_no_ssl # and: pyflakes2,pyflakes3 [testenv] From 9fb2a3d01056adf2dc83022dea8220bab67d30ca Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 12:22:48 +0100 Subject: [PATCH 1471/1502] Ugly hack to support Python 3.5 with the PEP 479 --- TODO.rst | 2 + doc/changelog.rst | 2 + tests/test_tasks.py | 34 ++++++++--- trollius/coroutines.py | 124 ++++++++++++++++++++++++++++++----------- trollius/tasks.py | 14 ++++- 5 files changed, 131 insertions(+), 45 deletions(-) diff --git a/TODO.rst b/TODO.rst index e4a785ed..f600cbdf 100644 --- a/TODO.rst +++ b/TODO.rst @@ -1,5 +1,7 @@ Unsorted "TODO" tasks: +* Python 3.5: Fix test_task_repr() +* Python 3.4: Fix test_asyncio() * Drop platform without ssl module? * streams.py:FIXME: should we support __aiter__ and __anext__ in Trollius? * replace selectors.py with selectors34: diff --git a/doc/changelog.rst b/doc/changelog.rst index 78fe94ed..11f02437 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -7,6 +7,8 @@ Version 2.1 Changes: +* Ugly hack to support Python 3.5 with the PEP 479. asyncio coroutines are + not supported on Python 3.5. * Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. * Drop support for Python 2.6 and 3.2. diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 6576ddbf..9aad2596 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -159,6 +159,7 @@ def test_async_warning(self): 'function is deprecated, use ensure_'): self.assertIs(f, asyncio.async(f)) + @unittest.skipIf(PY35, 'FIXME: test broken on Python 3.5') def test_get_stack(self): non_local = {'T': None} @@ -225,29 +226,37 @@ def notmuch(): coro = format_coroutine(coro_qualname, 'running', src, t._source_traceback, generator=True) - self.assertEqual(repr(t), - '()]>' % coro) + # FIXME: it correctly broken on Python 3.5+ + if not coroutines._PEP479: + self.assertEqual(repr(t), + '()]>' % coro) # test cancelling Task t.cancel() # Does not take immediate effect! - self.assertEqual(repr(t), - '()]>' % coro) + # FIXME: it correctly broken on Python 3.5+ + if not coroutines._PEP479: + self.assertEqual(repr(t), + '()]>' % coro) # test cancelled Task self.assertRaises(asyncio.CancelledError, self.loop.run_until_complete, t) coro = format_coroutine(coro_qualname, 'done', src, t._source_traceback) - self.assertEqual(repr(t), - '' % coro) + # FIXME: it correctly broken on Python 3.5+ + if not coroutines._PEP479: + self.assertEqual(repr(t), + '' % coro) # test finished Task t = asyncio.Task(notmuch(), loop=self.loop) self.loop.run_until_complete(t) coro = format_coroutine(coro_qualname, 'done', src, t._source_traceback) - self.assertEqual(repr(t), - "" % coro) + # FIXME: it correctly broken on Python 3.5+ + if not coroutines._PEP479: + self.assertEqual(repr(t), + "" % coro) def test_task_repr_coro_decorator(self): self.loop.set_debug(False) @@ -1647,6 +1656,9 @@ def call(arg): cw.send(None) try: cw.send(arg) + except coroutines.ReturnException as ex: + ex.raised = True + return ex.value except StopIteration as ex: ex.raised = True return ex.value @@ -1689,7 +1701,11 @@ def kill_me(loop): self.assertEqual(len(self.loop._ready), 0) # remove the future used in kill_me(), and references to the task - del coro.gi_frame.f_locals['future'] + if coroutines._PEP479: + coro = coro.gi_frame.f_locals.pop('coro') + del coro.gi_frame.f_locals['future'] + else: + del coro.gi_frame.f_locals['future'] coro = None source_traceback = task._source_traceback task = None diff --git a/trollius/coroutines.py b/trollius/coroutines.py index eea8c60c..9def9847 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -7,6 +7,7 @@ import opcode import os import sys +import textwrap import traceback import types @@ -77,14 +78,50 @@ def yield_from_gen(gen): _YIELD_FROM_BUG = False -if compat.PY33: - # Don't use the Return class on Python 3.3 and later to support asyncio +if compat.PY35: + return_base_class = Exception +else: + return_base_class = StopIteration + +class ReturnException(return_base_class): + def __init__(self, *args): + return_base_class.__init__(self) + if not args: + self.value = None + elif len(args) == 1: + self.value = args[0] + else: + self.value = args + self.raised = False + if _DEBUG: + frame = sys._getframe(1) + self._source_traceback = traceback.extract_stack(frame) + # explicitly clear the reference to avoid reference cycles + frame = None + else: + self._source_traceback = None + + def __del__(self): + if self.raised: + return + + fmt = 'Return(%r) used without raise' + if self._source_traceback: + fmt += '\nReturn created at (most recent call last):\n' + tb = ''.join(traceback.format_list(self._source_traceback)) + fmt += tb.rstrip() + logger.error(fmt, self.value) + + +if compat.PY33 and not compat.PY35: + # Don't use the Return class on Python 3.3 and 3.4 to support asyncio # coroutines (to avoid the warning emited in Return destructor). # - # The problem is that Return inherits from StopIteration. "yield from - # trollius_coroutine". Task._step() does not receive the Return exception, - # because "yield from" handles it internally. So it's not possible to set - # the raised attribute to True to avoid the warning in Return destructor. + # The problem is that ReturnException inherits from StopIteration. + # "yield from trollius_coroutine". Task._step() does not receive the Return + # exception, because "yield from" handles it internally. So it's not + # possible to set the raised attribute to True to avoid the warning in + # Return destructor. def Return(*args): if not args: value = None @@ -94,34 +131,7 @@ def Return(*args): value = args return StopIteration(value) else: - class Return(StopIteration): - def __init__(self, *args): - StopIteration.__init__(self) - if not args: - self.value = None - elif len(args) == 1: - self.value = args[0] - else: - self.value = args - self.raised = False - if _DEBUG: - frame = sys._getframe(1) - self._source_traceback = traceback.extract_stack(frame) - # explicitly clear the reference to avoid reference cycles - frame = None - else: - self._source_traceback = None - - def __del__(self): - if self.raised: - return - - fmt = 'Return(%r) used without raise' - if self._source_traceback: - fmt += '\nReturn created at (most recent call last):\n' - tb = ''.join(traceback.format_list(self._source_traceback)) - fmt += tb.rstrip() - logger.error(fmt, self.value) + Return = ReturnException def debug_wrapper(gen): @@ -297,6 +307,47 @@ def _wraps(wrapped, else: _wraps = functools.wraps +_PEP479 = (sys.version_info >= (3, 5)) +if _PEP479: + # Need exec() because yield+return raises a SyntaxError on Python 2 + exec(textwrap.dedent(''' + def pep479_wrapper(func, coro_func): + @_wraps(func) + def pep479_wrapped(*args, **kw): + coro = coro_func(*args, **kw) + value = None + error = None + while True: + try: + if error is not None: + value = coro.throw(error) + elif value is not None: + value = coro.send(value) + else: + value = next(coro) + except RuntimeError: + # FIXME: special case for + # FIXME: "isinstance(exc.__context__, StopIteration)"? + raise + except StopIteration as exc: + return exc.value + except Return as exc: + exc.raised = True + return exc.value + except BaseException as exc: + raise + + try: + value = yield value + error = None + except BaseException as exc: + value = None + error = exc + + return pep479_wrapped + ''')) + + def coroutine(func): """Decorator to mark coroutines. @@ -331,6 +382,11 @@ def coro(*args, **kw): res = yield From(await_meth()) raise Return(res) + if _PEP479: + # FIXME: use @_wraps + coro = pep479_wrapper(func, coro) + coro = _wraps(func)(coro) + if not _DEBUG: if _types_coroutine is None: wrapper = coro diff --git a/trollius/tasks.py b/trollius/tasks.py index af5c8683..440a6d8b 100644 --- a/trollius/tasks.py +++ b/trollius/tasks.py @@ -23,7 +23,7 @@ from . import executor from . import futures from .locks import Lock, Condition, Semaphore, _ContextManager -from .coroutines import coroutine, From, Return +from .coroutines import coroutine, From, Return, ReturnException @@ -257,12 +257,22 @@ def _step(self, value=None, exc=None, exc_tb=None): result = coro.throw(exc) else: result = coro.send(value) + # On Python 3.3 and Python 3.4, ReturnException is not used in + # practice. But this except is kept to have a single code base + # for all Python versions. + except coroutines.ReturnException as exc: + if isinstance(exc, ReturnException): + exc.raised = True + result = exc.value + else: + result = None + self.set_result(result) except StopIteration as exc: if compat.PY33: # asyncio Task object? get the result of the coroutine result = exc.value else: - if isinstance(exc, Return): + if isinstance(exc, ReturnException): exc.raised = True result = exc.value else: From f7a315a22528a93ba3aa0dcdda5ba8c8001f7836 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 13:15:50 +0100 Subject: [PATCH 1472/1502] Fix CoroWrapper.throw() Add support for up to 3 parameters, not only 1. --- trollius/coroutines.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trollius/coroutines.py b/trollius/coroutines.py index 9def9847..6a06989b 100644 --- a/trollius/coroutines.py +++ b/trollius/coroutines.py @@ -198,8 +198,8 @@ def send(self, *value): def send(self, value): return self.gen.send(value) - def throw(self, exc): - return self.gen.throw(exc) + def throw(self, exc_type, exc_value=None, exc_tb=None): + return self.gen.throw(exc_type, exc_value, exc_tb) def close(self): return self.gen.close() From 5641f314fbc11ba4c3360eb3ecea2c59d77ddb7c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 13:25:27 +0100 Subject: [PATCH 1473/1502] document the win32 fix --- doc/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/changelog.rst b/doc/changelog.rst index 11f02437..23640a76 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -11,6 +11,7 @@ Changes: not supported on Python 3.5. * Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. * Drop support for Python 2.6 and 3.2. +* Fix tests on Windows with Python 2. Patch written by Gabi Davar. Version 2.0 (2015-07-13) From 2c7f6ac0a0c8ad584ecde846383d658e809c068a Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 13:26:19 +0100 Subject: [PATCH 1474/1502] set version to 2.1 --- doc/conf.py | 2 +- setup.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 0d3b8dd8..46ec218c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,7 +48,7 @@ # built documents. # # The short X.Y version. -version = release = '2.0.1' +version = release = '2.1' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index 19d9670d..2a3df465 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,9 @@ -# Release procedure: +# Prepare a release: +# # - fill trollius changelog # - run maybe ./update-asyncio-step1.sh -# - run all tests: tox +# - run all tests on Linux: tox +# - run tests on Windows # - test examples # - check that "python setup.py sdist" contains all files tracked by # the SCM (Mercurial): update MANIFEST.in if needed @@ -9,13 +11,19 @@ # - update version in setup.py (version) and doc/conf.py (version, release) # - set release date in doc/changelog.rst # - git commit +# - git push +# +# Release a new version: +# # - git tag trollius-VERSION # - git push --tags -# - git push # - On Linux: python setup.py register sdist upload # FIXME: don't use bdist_wheel because of # FIXME: https://github.com/haypo/trollius/issues/1 # - On Windows: python releaser.py release +# +# After the release: +# # - increment version in setup.py (version) and doc/conf.py (version, release) # - git commit -a && git push @@ -48,7 +56,7 @@ install_options = { "name": "trollius", - "version": "2.0.1", + "version": "2.1", "license": "Apache License 2.0", "author": 'Victor Stinner', "author_email": 'victor.stinner@gmail.com', From 98ba7f856929b8c72920e93c0cf4fe49f510e968 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 5 Feb 2016 15:42:16 +0100 Subject: [PATCH 1475/1502] Deprecate Trollius --- doc/asyncio.rst | 3 ++ doc/changelog.rst | 4 ++ doc/deprecated.rst | 98 ++++++++++++++++++++++++++++++++++++++++++++++ doc/dev.rst | 3 ++ doc/index.rst | 4 ++ doc/install.rst | 4 ++ doc/libraries.rst | 5 +++ doc/using.rst | 3 ++ setup.py | 6 +++ 9 files changed, 130 insertions(+) create mode 100644 doc/deprecated.rst diff --git a/doc/asyncio.rst b/doc/asyncio.rst index 011a9a8b..dd4cef3e 100644 --- a/doc/asyncio.rst +++ b/doc/asyncio.rst @@ -2,6 +2,9 @@ Trollius and asyncio ++++++++++++++++++++ +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Differences between Trollius and asyncio ======================================== diff --git a/doc/changelog.rst b/doc/changelog.rst index 23640a76..d4750cf6 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -2,11 +2,15 @@ Change log ++++++++++ +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Version 2.1 =========== Changes: +* :ref:`The Trollius project is now deprecated `. * Ugly hack to support Python 3.5 with the PEP 479. asyncio coroutines are not supported on Python 3.5. * Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. diff --git a/doc/deprecated.rst b/doc/deprecated.rst new file mode 100644 index 00000000..f5f5743a --- /dev/null +++ b/doc/deprecated.rst @@ -0,0 +1,98 @@ +.. _deprecated: + +Trollius is deprecated +====================== + +.. warning:: + The Trollius project is now deprecated! + +Trollius is deprecated since the release 2.1. The maintainer of Trollius, +Victor Stinner, doesn't want to maintain the project anymore for many reasons. +This page lists some reasons. + +DON'T PANIC! There is the asyncio project which has the same API and is well +maintained! Only trollius is deprecated. + +Since the Trollius is used for some projects in the wild, Trollius will +not disappear. You can only expect *minimum* maintenance like minor bugfixes. +Don't expect new features nor synchronization with the latest asyncio. + +To be clear: I am looking for a new maintainer. If you want to take over +trollius: please do it, I will give you everything you need (and maybe more!). + +asyncio +------- + +`asyncio is here `_! asyncio is well +maintainted, has a growing community, has many libraries and don't stop +evolving to be enhanced by users feedbacks. I (Victor Stinner) even heard that +it is fast! + +asyncio requires Python 3.3 or newer. Yeah, you all have a long list of reasons +to not port your legacy code for Python 3. But I have my own reasons to prefer +to invest in the Python 3 rather than in legacy Python (Python 2). + + +No Trollius Community +--------------------- + +* Very the asyncio is growing everyday, there is no trollius community. + Sadly, asyncio libraries don't work for trollius. +* Only :ref:`very few libraries support Trollius `: to be clear, + there is no HTTP client for Trollius, whereas HTTP is the most common + protocol in 2015. +* It's a deliberate choice of library authors to not support Trollius to + keep a simple code base. The Python 3 is simpler than Python 2: supporting + Python 2 in a library requires more work. For example, aiohttp doesn't + want to support trollius. + +Python 2 +-------- + +* Seriously? Come on! Stop procrastination and upgrade your code to Python 3! + +Lack of interest +---------------- + +* The Trollius project was created to replace eventlet with asyncio (trollius) + in the OpenStack project, but replacing eventlet with trollius has failed for + different reasons. The original motivation is gone. + +Technical issues with trollius +------------------------------ + +* While Trollius API is "simple", the implementation is very complex to be + efficient on all platforms. +* Trollius requires :ref:`backports ` of libraries to support + old Python versions. These backports are not as well supported as the version + in the Python standard library. +* Supporting Python 2.7, Python 3.3 and Python 3.5 in the same code base + and supporting asyncio is very difficult. Generators and coroutines changed + a lot in each Python version. For example, hacks are required to support + Python 3.5 with the PEP 479 which changed the usage of the ``StopIteration`` + exception. Trollius initially also supported Python 2.6 and 3.2. + +Technical issues related to asyncio and yield from +-------------------------------------------------- + +* Synchronizing trollius with asyncio is a complex, tricky and error-prone task + which requires to be very carefull and a lot of manual changes. +* Porting Python 3 asyncio to Python 2 requires a lot of subtle changes which + takes a lot of time at each synchronization. +* It is not possible to use asyncio ``yield from`` coroutines in Python 2, + since the ``yield from`` instruction raises a ``SyntaxError`` exceptions. + Supporting Trollius and asyncio requires to duplicate some parts of the + library and application code which makes the code more complex and more + difficult to maintain. +* Trollius coroutines are slower than asyncio coroutines: the ``yield`` + instruction requires to delegate manually nested coroutines, whereas the + ``yield from`` instruction delegates them directly in the Python language. + asyncio requires less loop iterations than trollius for the same nested + coroutine. + +Other technical issues +---------------------- + +* Building wheel packages on Windows is difficult (need a running Windows, + need working Windows SDKs for each version of Python, need to test + and fix bugs specific to Windows, etc.) diff --git a/doc/dev.rst b/doc/dev.rst index c0bbf364..a6a3fcd9 100644 --- a/doc/dev.rst +++ b/doc/dev.rst @@ -1,6 +1,9 @@ Run tests ========= +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Run tests with tox ------------------ diff --git a/doc/index.rst b/doc/index.rst index cb344e47..afc5167a 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,6 +1,9 @@ Trollius ======== +.. warning:: + :ref:`The Trollius project is now deprecated! ` + .. image:: trollius.jpg :alt: Trollius altaicus from Khangai Mountains (Mongòlia) :align: right @@ -58,6 +61,7 @@ Table Of Contents .. toctree:: + deprecated using install libraries diff --git a/doc/install.rst b/doc/install.rst index 845e6e5f..d67aa59a 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -2,6 +2,9 @@ Install Trollius ++++++++++++++++ +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Trollius supports Python 2.7, 3.3 and 3.4. There is an experimental support of Python 3.5. Issues with Python 3.5: @@ -94,6 +97,7 @@ extension using:: C:\Python27\python.exe setup.py build_ext +.. _backports: Backports ========= diff --git a/doc/libraries.rst b/doc/libraries.rst index 424fd283..07939141 100644 --- a/doc/libraries.rst +++ b/doc/libraries.rst @@ -1,7 +1,12 @@ +.. _libraries: + ++++++++++++++++++ Trollius Libraries ++++++++++++++++++ +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Libraries compatible with asyncio and trollius ============================================== diff --git a/doc/using.rst b/doc/using.rst index c730f86b..c8185dc2 100644 --- a/doc/using.rst +++ b/doc/using.rst @@ -2,6 +2,9 @@ Using Trollius ++++++++++++++ +.. warning:: + :ref:`The Trollius project is now deprecated! ` + Documentation of the asyncio module =================================== diff --git a/setup.py b/setup.py index 2a3df465..b42cba1f 100644 --- a/setup.py +++ b/setup.py @@ -79,4 +79,10 @@ if SETUPTOOLS: install_options['install_requires'] = requirements +print("!!! WARNING !!! The Trollius project is now deprecated!") +print("") + setup(**install_options) + +print("") +print("!!! WARNING !!! The Trollius project is now deprecated!") From dc11b97a0f88eb1b4a3a1f66a4b6319bf6fc7225 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 19 Feb 2016 14:44:53 +0100 Subject: [PATCH 1476/1502] post-release: set version to 2.2 --- doc/conf.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 46ec218c..8921f510 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,7 +48,7 @@ # built documents. # # The short X.Y version. -version = release = '2.1' +version = release = '2.2' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index b42cba1f..64f14daa 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ install_options = { "name": "trollius", - "version": "2.1", + "version": "2.2", "license": "Apache License 2.0", "author": 'Victor Stinner', "author_email": 'victor.stinner@gmail.com', From 907fac43a50567368115121d39cfc7f53c29f907 Mon Sep 17 00:00:00 2001 From: pavan123k Date: Tue, 23 Feb 2016 17:05:45 -0800 Subject: [PATCH 1477/1502] From(proc.wait()) instead of form proc.wait() --- examples/shell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/shell.py b/examples/shell.py index 91ba7fb1..c78594ed 100644 --- a/examples/shell.py +++ b/examples/shell.py @@ -46,7 +46,7 @@ def test_call(*args, **kw): except asyncio.TimeoutError: print("timeout! (%.1f sec)" % timeout) proc.kill() - yield from proc.wait() + yield From(proc.wait()) loop = asyncio.get_event_loop() loop.run_until_complete(cat(loop)) From 20544a61679f2c8d69ff50d667768f3303623102 Mon Sep 17 00:00:00 2001 From: Adam Chainz Date: Thu, 2 Jun 2016 22:02:32 +0100 Subject: [PATCH 1478/1502] Convert readthedocs link for their .org -> .io migration for hosted projects MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As per [their blog post of the 27th April](https://blog.readthedocs.com/securing-subdomains/) ‘Securing subdomains’: > Starting today, Read the Docs will start hosting projects from subdomains on the domain readthedocs.io, instead of on readthedocs.org. This change addresses some security concerns around site cookies while hosting user generated data on the same domain as our dashboard. Test Plan: Manually visited all the links I’ve modified. --- README.rst | 2 +- doc/asyncio.rst | 2 +- doc/changelog.rst | 2 +- doc/index.rst | 2 +- doc/install.rst | 2 +- doc/libraries.rst | 6 +++--- 6 files changed, 8 insertions(+), 8 deletions(-) diff --git a/README.rst b/README.rst index a1bf4953..50caa739 100644 --- a/README.rst +++ b/README.rst @@ -32,7 +32,7 @@ Python 2.6-3.5. It has been tested on Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Asyncio documentation `_ -* `Trollius documentation `_ +* `Trollius documentation `_ * `Trollius project in the Python Cheeseshop (PyPI) `_ * `Trollius project at Github `_ diff --git a/doc/asyncio.rst b/doc/asyncio.rst index dd4cef3e..2c379353 100644 --- a/doc/asyncio.rst +++ b/doc/asyncio.rst @@ -168,7 +168,7 @@ This option is used by the following projects which work on Trollius and asyncio module if available, or import ``trollius``. * `Tornado `_ supports asyncio and Trollius since Tornado 3.2: `tornado.platform.asyncio — Bridge between asyncio and Tornado - `_. It tries to import + `_. It tries to import asyncio or fallback on importing trollius. Another option is to provide functions returning ``Future`` objects, so the diff --git a/doc/changelog.rst b/doc/changelog.rst index d4750cf6..8c54c967 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -580,7 +580,7 @@ Other changes: Trollius changes: * Add a new Sphinx documentation: - http://trollius.readthedocs.org/ + https://trollius.readthedocs.io/ * tox: pass posargs to nosetests. Patch contributed by Ian Wienand. * Fix support of Python 3.2 and add py32 to tox.ini * Merge with Tulip 0.4.1 diff --git a/doc/index.rst b/doc/index.rst index afc5167a..1ea3d974 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -42,7 +42,7 @@ on Python 2. Trollius works on Python 2.7, 3.3 and 3.4. It has been tested on Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Asyncio documentation `_ -* `Trollius documentation `_ (this document) +* `Trollius documentation `_ (this document) * `Trollius project in the Python Cheeseshop (PyPI) `_ (download wheel packages and tarballs) diff --git a/doc/install.rst b/doc/install.rst index d67aa59a..d9e65ccd 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -91,7 +91,7 @@ Build manually Trollius on Windows On Windows, if you cannot use precompiled wheel packages, an extension module must be compiled: the ``_overlapped`` module (source code: ``overlapped.c``). Read `Compile Python extensions on Windows -`_ +`_ to prepare your environment to build the Python extension. Then build the extension using:: diff --git a/doc/libraries.rst b/doc/libraries.rst index 07939141..9abc3c07 100644 --- a/doc/libraries.rst +++ b/doc/libraries.rst @@ -10,7 +10,7 @@ Trollius Libraries Libraries compatible with asyncio and trollius ============================================== -* `aioeventlet `_: asyncio API +* `aioeventlet `_: asyncio API implemented on top of eventlet * `aiogevent `_: asyncio API implemented on top of gevent @@ -24,12 +24,12 @@ Libraries compatible with asyncio and trollius module if available, or import ``trollius``. * `Tornado `_ supports asyncio and Trollius since Tornado 3.2: `tornado.platform.asyncio — Bridge between asyncio and Tornado - `_. It tries to import + `_. It tries to import asyncio or fallback on importing trollius. Specific Ports ============== * `trollius-redis `_: - A port of `asyncio-redis `_ to + A port of `asyncio-redis `_ to trollius From 3d8e1db7738552f546a3ee7da85e1d8161d594a7 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Mon, 20 Nov 2017 17:35:07 +0100 Subject: [PATCH 1479/1502] Update GitHub URL --- README.rst | 2 +- doc/index.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 50caa739..ddb59ccc 100644 --- a/README.rst +++ b/README.rst @@ -35,7 +35,7 @@ OpenIndiana. * `Trollius documentation `_ * `Trollius project in the Python Cheeseshop (PyPI) `_ -* `Trollius project at Github `_ +* `Trollius project at Github `_ (bug tracker, source code) * Copyright/license: Open source, Apache 2.0. Enjoy! diff --git a/doc/index.rst b/doc/index.rst index 1ea3d974..b667bc9e 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -46,7 +46,7 @@ Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Trollius project in the Python Cheeseshop (PyPI) `_ (download wheel packages and tarballs) -* `Trollius project at Github `_ +* `Trollius project at Github `_ (bug tracker, source code) * Mailing list: `python-tulip Google Group `_ From b4bacc05d9455c84838c00044c65a430552a9445 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Wed, 22 Nov 2017 18:05:57 +0100 Subject: [PATCH 1480/1502] trollius.readthedocs.io has been removed Update also releaser URL --- README.rst | 1 - doc/index.rst | 1 - releaser.conf | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/README.rst b/README.rst index ddb59ccc..4c55d506 100644 --- a/README.rst +++ b/README.rst @@ -32,7 +32,6 @@ Python 2.6-3.5. It has been tested on Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Asyncio documentation `_ -* `Trollius documentation `_ * `Trollius project in the Python Cheeseshop (PyPI) `_ * `Trollius project at Github `_ diff --git a/doc/index.rst b/doc/index.rst index b667bc9e..7119f029 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -42,7 +42,6 @@ on Python 2. Trollius works on Python 2.7, 3.3 and 3.4. It has been tested on Windows, Linux, Mac OS X, FreeBSD and OpenIndiana. * `Asyncio documentation `_ -* `Trollius documentation `_ (this document) * `Trollius project in the Python Cheeseshop (PyPI) `_ (download wheel packages and tarballs) diff --git a/releaser.conf b/releaser.conf index 37281395..1b45fd09 100644 --- a/releaser.conf +++ b/releaser.conf @@ -1,5 +1,5 @@ # Configuration file for the tool "releaser" -# https://bitbucket.org/haypo/misc/src/tip/bin/releaser.py +# https://github.com/vstinner/misc/blob/master/bin/releaser.py [project] name = trollius From c6c86417fe40646181d0b1f7b7ee071de95798bc Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Mar 2018 10:59:38 +0100 Subject: [PATCH 1481/1502] Remove aiotest tests --- MANIFEST.in | 2 +- doc/dev.rst | 3 --- run_aiotest.py | 14 -------------- tox.ini | 12 ++---------- 4 files changed, 3 insertions(+), 28 deletions(-) delete mode 100644 run_aiotest.py diff --git a/MANIFEST.in b/MANIFEST.in index f3b496f4..9dabbdcd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ include AUTHORS COPYING TODO.rst tox.ini include Makefile include overlapped.c pypi.bat -include check.py runtests.py run_aiotest.py release.py +include check.py runtests.py release.py include update-asyncio-*.sh include .travis.yml include releaser.conf diff --git a/doc/dev.rst b/doc/dev.rst index a6a3fcd9..a05179bd 100644 --- a/doc/dev.rst +++ b/doc/dev.rst @@ -29,9 +29,6 @@ On Python older than 3.3, unit tests require the `mock `_ module. Python 2.6 and 2.7 require also `unittest2 `_. -To run ``run_aiotest.py``, you need the `aiotest -`_ test suite: ``pip install aiotest``. - Run tests on UNIX ----------------- diff --git a/run_aiotest.py b/run_aiotest.py deleted file mode 100644 index da133286..00000000 --- a/run_aiotest.py +++ /dev/null @@ -1,14 +0,0 @@ -import aiotest.run -import sys -import trollius -if sys.platform == 'win32': - from trollius.windows_utils import socketpair -else: - from socket import socketpair - -config = aiotest.TestConfig() -config.asyncio = trollius -config.socketpair = socketpair -config.new_event_pool_policy = trollius.DefaultEventLoopPolicy -config.call_soon_check_closed = True -aiotest.run.main(config) diff --git a/tox.ini b/tox.ini index 5058dd84..f51177da 100644 --- a/tox.ini +++ b/tox.ini @@ -4,31 +4,28 @@ envlist = py27,py2_release,py2_no_ssl,py2_no_concurrent,py33,py34,py3_release,py [testenv] deps= - aiotest six setenv = TROLLIUSDEBUG = 1 commands= python -Wd runtests.py -r {posargs} - python -Wd run_aiotest.py -r {posargs} [testenv:pyflakes2] basepython = python2 deps= pyflakes commands= - pyflakes trollius tests runtests.py check.py run_aiotest.py setup.py + pyflakes trollius tests runtests.py check.py setup.py [testenv:pyflakes3] basepython = python3 deps= pyflakes commands= - pyflakes trollius tests runtests.py check.py run_aiotest.py setup.py + pyflakes trollius tests runtests.py check.py setup.py [testenv:py26] deps= - aiotest futures mock==1.0.1 ordereddict @@ -37,7 +34,6 @@ deps= [testenv:py27] deps= - aiotest futures mock six @@ -47,7 +43,6 @@ deps= # Run tests in release mode basepython = python2 deps= - aiotest futures mock six @@ -58,7 +53,6 @@ setenv = [testenv:py2_no_ssl] basepython = python2 deps= - aiotest futures mock six @@ -69,7 +63,6 @@ commands= [testenv:py2_no_concurrent] basepython = python2 deps= - aiotest futures mock six @@ -79,7 +72,6 @@ commands= [testenv:py32] deps= - aiotest mock six From 97c56f0103fc581b191be74da3900f42ad090ca0 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Mar 2018 11:01:29 +0100 Subject: [PATCH 1482/1502] Add "No Maintenance Intended" badge --- README.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.rst b/README.rst index 4c55d506..5aa64bcd 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,11 @@ +======== +Trollius +======== + +.. image:: http://unmaintained.tech/badge.svg + :target: http://unmaintained.tech/ + :alt: No Maintenance Intended + Trollius provides infrastructure for writing single-threaded concurrent code using coroutines, multiplexing I/O access over sockets and other resources, running network clients and servers, and other related primitives. From fa13a1c0182840b72faa17078eb78c8b5e5f45c4 Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Mar 2018 11:08:10 +0100 Subject: [PATCH 1483/1502] Prepare 2.2 release --- doc/changelog.rst | 16 ++++++++++++++-- setup.py | 8 ++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 8c54c967..31704701 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,8 +5,20 @@ Change log .. warning:: :ref:`The Trollius project is now deprecated! ` -Version 2.1 -=========== +Version 2.2 (2018-03-09) +======================== + +Changes: + +* ``run_aiotest.py`` has been removed since the ``aiotest`` project has been + removed +* Add the "No Maintenance Intended" badge to README +* The Trollius documentation is no longer online: + http://trollius.readthedocs.io/ has been removed +* Update the GitHub URL to: https://github.com/vstinner/trollius + +Version 2.1 (2016-02-05) +======================== Changes: diff --git a/setup.py b/setup.py index 64f14daa..f68e143c 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,6 @@ # Prepare a release: # +# - git pull --rebase # - fill trollius changelog # - run maybe ./update-asyncio-step1.sh # - run all tests on Linux: tox @@ -17,15 +18,18 @@ # # - git tag trollius-VERSION # - git push --tags -# - On Linux: python setup.py register sdist upload +# - Remove untracked files/dirs: git clean -fdx +# - On Linux: python2 setup.py sdist # FIXME: don't use bdist_wheel because of # FIXME: https://github.com/haypo/trollius/issues/1 +# - twine upload dist/* # - On Windows: python releaser.py release # # After the release: # # - increment version in setup.py (version) and doc/conf.py (version, release) -# - git commit -a && git push +# - git commit -a -m "post release X.Y" +# - git push import os import sys From 30d97606a3129b865e1d00700b2bb46cf7fe48ef Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Mar 2018 11:18:30 +0100 Subject: [PATCH 1484/1502] post release: set version to 2.3 --- doc/conf.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 8921f510..bed8900b 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -48,7 +48,7 @@ # built documents. # # The short X.Y version. -version = release = '2.2' +version = release = '2.3' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/setup.py b/setup.py index f68e143c..569ca50a 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ install_options = { "name": "trollius", - "version": "2.2", + "version": "2.3", "license": "Apache License 2.0", "author": 'Victor Stinner', "author_email": 'victor.stinner@gmail.com', From 479788fb7b568c0c685599e55e0e98cab9fe474c Mon Sep 17 00:00:00 2001 From: Victor Stinner Date: Fri, 9 Mar 2018 11:33:01 +0100 Subject: [PATCH 1485/1502] README: add deprecation warning --- README.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.rst b/README.rst index 5aa64bcd..773d0562 100644 --- a/README.rst +++ b/README.rst @@ -6,6 +6,10 @@ Trollius :target: http://unmaintained.tech/ :alt: No Maintenance Intended +.. warning:: + The Trollius project is deprecated since the release 2.1.! Use asyncio on + Python 3 instead. + Trollius provides infrastructure for writing single-threaded concurrent code using coroutines, multiplexing I/O access over sockets and other resources, running network clients and servers, and other related primitives. From 6d0a8772af9b721d167f498817575acf4042c679 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 11:42:09 -0500 Subject: [PATCH 1486/1502] Add appveyor.yml from https://github.com/vstinner/trollius/pull/13/files --- appveyor.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 appveyor.yml diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 00000000..41fd7927 --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,29 @@ +version: build-{build}-{branch} + +environment: + matrix: + # http://www.appveyor.com/docs/installed-software#python lists available + # versions + - PYTHON: "C:\\Python27" + - PYTHON: "C:\\Python27-x64" + # 32-bit + - PYTHON: "C:\\pypy2.7-v7.1.0-win32" + PYTHON_ID: "pypy" + PYTHON_EXE: pypy + PYTHON_VERSION: "2.7.x" + PYTHON_ARCH: "32" + +install: + - "set PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" + - python --version + - pip install wheel aiotest futures mock six unittest2 + - pip install -e . + +build_script: + - python setup.py bdist_wheel + +test_script: + - python runtests.py + +artifacts: + - path: dist/*.whl From cd162671d51427981cb40bb1866fcd81b25fcd0f Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:05:53 -0500 Subject: [PATCH 1487/1502] Metadata updates. Be explicit about supporting only Python 2.7 Drop testing on appveyor because the needed dependencies no longer exist. Incorporate https://github.com/vstinner/trollius/pull/15/ --- MANIFEST.in | 1 + README.rst | 46 +--------------------- appveyor.yml | 32 ++++++++++----- doc/changelog.rst | 15 +++++-- setup.py | 91 +++++++++---------------------------------- trollius/selectors.py | 3 +- 6 files changed, 57 insertions(+), 131 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 9dabbdcd..5d4033fc 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,6 +4,7 @@ include overlapped.c pypi.bat include check.py runtests.py release.py include update-asyncio-*.sh include .travis.yml +include appveyor.yml include releaser.conf include doc/conf.py doc/make.bat doc/Makefile diff --git a/README.rst b/README.rst index 773d0562..79764d3d 100644 --- a/README.rst +++ b/README.rst @@ -7,47 +7,5 @@ Trollius :alt: No Maintenance Intended .. warning:: - The Trollius project is deprecated since the release 2.1.! Use asyncio on - Python 3 instead. - -Trollius provides infrastructure for writing single-threaded concurrent -code using coroutines, multiplexing I/O access over sockets and other -resources, running network clients and servers, and other related primitives. -Here is a more detailed list of the package contents: - -* a pluggable event loop with various system-specific implementations; - -* transport and protocol abstractions (similar to those in `Twisted - `_); - -* concrete support for TCP, UDP, SSL, subprocess pipes, delayed calls, and - others (some may be system-dependent); - -* a ``Future`` class that mimics the one in the ``concurrent.futures`` module, - but adapted for use with the event loop; - -* coroutines and tasks based on generators (``yield``), to help write - concurrent code in a sequential fashion; - -* cancellation support for ``Future``\s and coroutines; - -* synchronization primitives for use between coroutines in a single thread, - mimicking those in the ``threading`` module; - -* an interface for passing work off to a threadpool, for times when you - absolutely, positively have to use a library that makes blocking I/O calls. - -Trollius is a portage of the `asyncio project -`_ (`PEP 3156 -`_) on Python 2. Trollius works on -Python 2.6-3.5. It has been tested on Windows, Linux, Mac OS X, FreeBSD and -OpenIndiana. - -* `Asyncio documentation `_ -* `Trollius project in the Python Cheeseshop (PyPI) - `_ -* `Trollius project at Github `_ - (bug tracker, source code) -* Copyright/license: Open source, Apache 2.0. Enjoy! - -See also the `asyncio project at Github `_. + The Trollius project is deprecated and unsupported. It is on PyPI + to support existing dependencies only. diff --git a/appveyor.yml b/appveyor.yml index 41fd7927..255b638b 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,12 +1,14 @@ -version: build-{build}-{branch} - environment: matrix: - # http://www.appveyor.com/docs/installed-software#python lists available - # versions - - PYTHON: "C:\\Python27" - PYTHON: "C:\\Python27-x64" - # 32-bit + PYTHON_VERSION: "2.7.x" # currently 2.7.13 + PYTHON_ARCH: "64" + PYTHON_EXE: python + - PYTHON: "C:\\Python27" + PYTHON_VERSION: "2.7.x" # currently 2.7.13 + PYTHON_ARCH: "32" + PYTHON_EXE: python + GWHEEL_ONLY: true - PYTHON: "C:\\pypy2.7-v7.1.0-win32" PYTHON_ID: "pypy" PYTHON_EXE: pypy @@ -15,15 +17,25 @@ environment: install: - "set PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - - python --version - - pip install wheel aiotest futures mock six unittest2 - - pip install -e . + - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\bin;%PATH%" + - "SET PYEXE=%PYTHON%\\%PYTHON_EXE%.exe" + + # Check that we have the expected version and architecture for Python + - "%PYEXE% --version" + + - "%PYEXE% -m pip install --disable-pip-version-check -U pip" + - "%PYEXE% -m pip install -U setuptools wheel" build_script: - python setup.py bdist_wheel +# Theer's nothing to actually test anymore; the aiotest dependency is gone. test_script: - - python runtests.py + - python --version artifacts: - path: dist/*.whl + +cache: + - "%TMP%\\py\\" + - '%LOCALAPPDATA%\pip\Cache' diff --git a/doc/changelog.rst b/doc/changelog.rst index 31704701..710625a9 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,10 +1,18 @@ -++++++++++ -Change log -++++++++++ +============ + Change log +============ .. warning:: :ref:`The Trollius project is now deprecated! ` +Version 2.3 (2019-07-29) +======================== + +This is a packaging-only release. It is intended to be the last +release. + +- Release Windows wheels. + Version 2.2 (2018-03-09) ======================== @@ -712,4 +720,3 @@ Other changes of Tulip 0.4.1: ======================= - First public release - diff --git a/setup.py b/setup.py index 569ca50a..015b7e58 100644 --- a/setup.py +++ b/setup.py @@ -1,46 +1,6 @@ -# Prepare a release: -# -# - git pull --rebase -# - fill trollius changelog -# - run maybe ./update-asyncio-step1.sh -# - run all tests on Linux: tox -# - run tests on Windows -# - test examples -# - check that "python setup.py sdist" contains all files tracked by -# the SCM (Mercurial): update MANIFEST.in if needed -# - run test on Windows: releaser.py test -# - update version in setup.py (version) and doc/conf.py (version, release) -# - set release date in doc/changelog.rst -# - git commit -# - git push -# -# Release a new version: -# -# - git tag trollius-VERSION -# - git push --tags -# - Remove untracked files/dirs: git clean -fdx -# - On Linux: python2 setup.py sdist -# FIXME: don't use bdist_wheel because of -# FIXME: https://github.com/haypo/trollius/issues/1 -# - twine upload dist/* -# - On Windows: python releaser.py release -# -# After the release: -# -# - increment version in setup.py (version) and doc/conf.py (version, release) -# - git commit -a -m "post release X.Y" -# - git push - import os import sys -try: - from setuptools import setup, Extension - SETUPTOOLS = True -except ImportError: - SETUPTOOLS = False - # Use distutils.core as a fallback. - # We won't be able to build the Wheel file on Windows. - from distutils.core import setup, Extension +from setuptools import setup, Extension with open("README.rst") as fp: long_description = fp.read() @@ -53,40 +13,27 @@ extensions.append(ext) requirements = ['six'] -if sys.version_info < (2, 7): - requirements.append('ordereddict') if sys.version_info < (3,): requirements.append('futures') -install_options = { - "name": "trollius", - "version": "2.3", - "license": "Apache License 2.0", - "author": 'Victor Stinner', - "author_email": 'victor.stinner@gmail.com', - - "description": "Port of the Tulip project (asyncio module, PEP 3156) on Python 2", - "long_description": long_description, - "url": "https://github.com/haypo/trollius", - - "classifiers": [ +setup( + name="trollius", + version="2.3", + license="Apache License 2.0", + author='Victor Stinner', + author_email='victor.stinner@gmail.com', + description="Deprecated, unmaintained port of the asyncio module (PEP 3156) on Python 2", + long_description=long_description, + url="https://github.com/jamadden/trollius", + classifiers=[ "Programming Language :: Python", - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 2.7", "License :: OSI Approved :: Apache Software License", ], - - "packages": ["trollius"], - "test_suite": "runtests.runtests", - - "ext_modules": extensions, -} -if SETUPTOOLS: - install_options['install_requires'] = requirements - -print("!!! WARNING !!! The Trollius project is now deprecated!") -print("") - -setup(**install_options) - -print("") -print("!!! WARNING !!! The Trollius project is now deprecated!") + packages=[ + "trollius", + ], + ext_modules=extensions, + install_requires=requirements, + python_requires=">=2.7, < 3", +) diff --git a/trollius/selectors.py b/trollius/selectors.py index cf0475de..2b24e5e0 100644 --- a/trollius/selectors.py +++ b/trollius/selectors.py @@ -309,7 +309,8 @@ def _select(self, r, w, _, timeout=None): r, w, x = select.select(r, w, w, timeout) return r, w + x, [] else: - _select = select.select + def _select(self, r, w, x, timeout=None): + return select.select(r, w, x, timeout) def select(self, timeout=None): timeout = None if timeout is None else max(timeout, 0) From 439637de0a88959711b7f293e6ed0665796d0575 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:10:03 -0500 Subject: [PATCH 1488/1502] Copy logic from gevent to install pypy. --- appveyor.yml | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index 255b638b..c61e3e8a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -16,7 +16,44 @@ environment: PYTHON_ARCH: "32" install: - - "set PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" + # If there is a newer build queued for the same PR, cancel this one. + # The AppVeyor 'rollout builds' option is supposed to serve the same + # purpose but it is problematic because it tends to cancel builds pushed + # directly to master instead of just PR builds (or the converse). + # credits: JuliaLang developers. + - ps: if ($env:APPVEYOR_PULL_REQUEST_NUMBER -and $env:APPVEYOR_BUILD_NUMBER -ne ((Invoke-RestMethod ` + https://ci.appveyor.com/api/projects/$env:APPVEYOR_ACCOUNT_NAME/$env:APPVEYOR_PROJECT_SLUG/history?recordsNumber=50).builds | ` + Where-Object pullRequestId -eq $env:APPVEYOR_PULL_REQUEST_NUMBER)[0].buildNumber) { ` + throw "There are newer queued builds for this pull request, failing early." } + - ECHO "Filesystem root:" + - ps: "ls \"C:/\"" + + - ECHO "Installed SDKs:" + - ps: "ls \"C:/Program Files/Microsoft SDKs/Windows\"" + + # Install Python (from the official .msi of http://python.org) and pip when + # not already installed. + # PyPy portion based on https://github.com/wbond/asn1crypto/blob/master/appveyor.yml + - ps: + $env:PYTMP = "${env:TMP}\py"; + if (!(Test-Path "$env:PYTMP")) { + New-Item -ItemType directory -Path "$env:PYTMP" | Out-Null; + } + if ("${env:PYTHON_ID}" -eq "pypy") { + if (!(Test-Path "${env:PYTMP}\pypy2-v7.1.0-win32.zip")) { + (New-Object Net.WebClient).DownloadFile('https://bitbucket.org/pypy/pypy/downloads/pypy2.7-v7.1.0-win32.zip', "${env:PYTMP}\pypy2-v7.1.0-win32.zip"); + } + 7z x -y "${env:PYTMP}\pypy2-v7.1.0-win32.zip" -oC:\ | Out-Null; + & "${env:PYTHON}\pypy.exe" "-mensurepip"; + + } + elseif (-not(Test-Path($env:PYTHON))) { + & appveyor\install.ps1; + } + + # Prepend newly installed Python to the PATH of this build (this cannot be + # done from inside the powershell script as it would require to restart + # the parent CMD process). - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\bin;%PATH%" - "SET PYEXE=%PYTHON%\\%PYTHON_EXE%.exe" From 1a76339089c7edee88df54316dc392d13acca7c6 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:13:31 -0500 Subject: [PATCH 1489/1502] Use right executable for pypy --- appveyor.yml | 2 +- setup.cfg | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 setup.cfg diff --git a/appveyor.yml b/appveyor.yml index c61e3e8a..b45d2110 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -68,7 +68,7 @@ build_script: # Theer's nothing to actually test anymore; the aiotest dependency is gone. test_script: - - python --version + - "%PYEXE% --version" artifacts: - path: dist/*.whl diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 00000000..e058b5a4 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,8 @@ +[bdist_wheel] +universal = 0 + +[zest.releaser] +create-wheel = no + +[metadata] +long_description_content_type = text/x-rst From 85e055dfea4d02354e63935c1a20dfb2708444be Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:15:54 -0500 Subject: [PATCH 1490/1502] No really. --- appveyor.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/appveyor.yml b/appveyor.yml index b45d2110..ccb86e5a 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -64,7 +64,7 @@ install: - "%PYEXE% -m pip install -U setuptools wheel" build_script: - - python setup.py bdist_wheel + - "%PYEXE% setup.py bdist_wheel" # Theer's nothing to actually test anymore; the aiotest dependency is gone. test_script: From 6b843ec24315e9092447c667dcdeeb7ba1878aa4 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:19:48 -0500 Subject: [PATCH 1491/1502] Nope, drop PyPy: It does not actually build. --- appveyor.yml | 14 +++++++++----- doc/changelog.rst | 7 ++++--- setup.py | 2 +- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index ccb86e5a..2848eecb 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -9,11 +9,15 @@ environment: PYTHON_ARCH: "32" PYTHON_EXE: python GWHEEL_ONLY: true - - PYTHON: "C:\\pypy2.7-v7.1.0-win32" - PYTHON_ID: "pypy" - PYTHON_EXE: pypy - PYTHON_VERSION: "2.7.x" - PYTHON_ARCH: "32" + # PyPy 7.1 won't actually build the _overlapped extension: + # overlapped.c(92) : warning C4013: 'PyErr_SetExcFromWindowsErr' undefined; assuming extern returning int + # overlapped.c(92) : warning C4047: 'return' : 'PyObject *' differs in levels of indirection from 'int' + # overlapped.c(1166) : warning C4101: 'AddressObj' : unreferenced local variable + # - PYTHON: "C:\\pypy2.7-v7.1.0-win32" + # PYTHON_ID: "pypy" + # PYTHON_EXE: pypy + # PYTHON_VERSION: "2.7.x" + # PYTHON_ARCH: "32" install: # If there is a newer build queued for the same PR, cancel this one. diff --git a/doc/changelog.rst b/doc/changelog.rst index 710625a9..f61ccdae 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,13 +5,14 @@ .. warning:: :ref:`The Trollius project is now deprecated! ` -Version 2.3 (2019-07-29) -======================== +Version 2.2.post1 (2019-07-29) +============================== This is a packaging-only release. It is intended to be the last release. -- Release Windows wheels. +- Release Windows wheels for CPython 2.7. +- Use ``python_requires`` to restrict installation to Python 2.7. Version 2.2 (2018-03-09) ======================== diff --git a/setup.py b/setup.py index 015b7e58..268f0fa0 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name="trollius", - version="2.3", + version="2.2.post1", license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From a692f5aa68129984f053c4eb3f6af56f4f496152 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:21:36 -0500 Subject: [PATCH 1492/1502] Tweaking versions. --- doc/changelog.rst | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index f61ccdae..8574be12 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,8 +5,8 @@ .. warning:: :ref:`The Trollius project is now deprecated! ` -Version 2.2.post1 (2019-07-29) -============================== +Version 2.2.post1 (Unknown) +=========================== This is a packaging-only release. It is intended to be the last release. diff --git a/setup.py b/setup.py index 268f0fa0..4243636e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name="trollius", - version="2.2.post1", + version="2.2.post1.dev0", license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From 65e0c607d1a4f3eb31ad457a0741f4513b19369e Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:22:47 -0500 Subject: [PATCH 1493/1502] add missing metadata. --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 4243636e..065d02bf 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,8 @@ packages=[ "trollius", ], + zip_safe=False, + keywords="Deprecated Unmaintained asyncio backport", ext_modules=extensions, install_requires=requirements, python_requires=">=2.7, < 3", From 44ce7ec3fb28a30cc9bc076083460325f9ebb92b Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:23:10 -0500 Subject: [PATCH 1494/1502] Preparing release 2.2.post1 --- doc/changelog.rst | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 8574be12..350790df 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,8 +5,8 @@ .. warning:: :ref:`The Trollius project is now deprecated! ` -Version 2.2.post1 (Unknown) -=========================== +2.2.post1 (2019-07-29) +====================== This is a packaging-only release. It is intended to be the last release. diff --git a/setup.py b/setup.py index 065d02bf..7abc9f96 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name="trollius", - version="2.2.post1.dev0", + version="2.2.post1", license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From 25857261afd85911acb372e8f4776d87d4b6ef93 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 12:25:42 -0500 Subject: [PATCH 1495/1502] Back to development: 2.2.post2 --- doc/changelog.rst | 6 ++++++ setup.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 350790df..ef63cd07 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -5,6 +5,12 @@ .. warning:: :ref:`The Trollius project is now deprecated! ` +2.2.post2 (unreleased) +====================== + +- Nothing changed yet. + + 2.2.post1 (2019-07-29) ====================== diff --git a/setup.py b/setup.py index 7abc9f96..14c777a0 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ setup( name="trollius", - version="2.2.post1", + version="2.2.post2.dev0", license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From 2aeeb23e9771b67234dd6fef338e57000412b784 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Mon, 29 Jul 2019 16:59:07 -0500 Subject: [PATCH 1496/1502] Add Development Status :: 7 - Inactive classifier. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 14c777a0..ebe612d9 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "Programming Language :: Python", "Programming Language :: Python :: 2.7", "License :: OSI Approved :: Apache Software License", + "Development Status :: 7 - Inactive", ], packages=[ "trollius", From a74449ad5e7de34ca6cadedd062d9ee4307db911 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:11:34 -0500 Subject: [PATCH 1497/1502] Make socket.error with errno EBADF get raised as OSError. Fixes #17 --- CHANGES.rst | 8 ++++++++ MANIFEST.in | 2 ++ runtests.py | 4 ++-- setup.py | 5 ++++- tests/test_py33_exceptions.py | 35 +++++++++++++++++++++++++++++++++++ trollius/py33_exceptions.py | 3 +++ 6 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 CHANGES.rst create mode 100644 tests/test_py33_exceptions.py diff --git a/CHANGES.rst b/CHANGES.rst new file mode 100644 index 00000000..5082340a --- /dev/null +++ b/CHANGES.rst @@ -0,0 +1,8 @@ +========= + Changes +========= + +2.2.1 (unreleased) +================== + +- Properly reraise socket.error with an errno of EBADF as an OSError. diff --git a/MANIFEST.in b/MANIFEST.in index 5d4033fc..017fcfba 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -14,3 +14,5 @@ include examples/*.py include tests/*.crt tests/*.pem tests/*.key include tests/*.py + +include *.rst diff --git a/runtests.py b/runtests.py index 541a47ea..7eea28c1 100755 --- a/runtests.py +++ b/runtests.py @@ -151,7 +151,7 @@ def randomize_tests(tests, seed): random.shuffle(tests._tests) -class TestsFinder: +class TestsFinder(object): def __init__(self, testsdir, includes=(), excludes=()): self._testsdir = testsdir @@ -167,7 +167,7 @@ def find_available_tests(self): mods = [mod for mod, _ in load_modules(self._testsdir)] for mod in mods: for name in set(dir(mod)): - if name.endswith('Tests'): + if name.endswith('Tests') or name.startswith('Test'): self._test_factories.append(getattr(mod, name)) def load_tests(self): diff --git a/setup.py b/setup.py index 7abc9f96..20c654b8 100644 --- a/setup.py +++ b/setup.py @@ -5,6 +5,9 @@ with open("README.rst") as fp: long_description = fp.read() +with open('CHANGES.rst') as fp: + long_description += '\n\n' + fp.read() + extensions = [] if os.name == 'nt': ext = Extension( @@ -18,7 +21,7 @@ setup( name="trollius", - version="2.2.post1", + version="2.2.1", license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', diff --git a/tests/test_py33_exceptions.py b/tests/test_py33_exceptions.py new file mode 100644 index 00000000..42fb4e23 --- /dev/null +++ b/tests/test_py33_exceptions.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +""" +Tests for py33_exceptions. + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import unittest +from trollius import py33_exceptions + +class TestWrapErrors(unittest.TestCase): + + def test_ebadf_wrapped_to_OSError(self): + # https://github.com/jamadden/trollius/issues/17 + import socket + import os + import errno + s = socket.socket() + os.close(s.fileno()) + + with self.assertRaises(socket.error) as exc: + s.send(b'abc') + + self.assertEqual(exc.exception.errno, errno.EBADF) + + with self.assertRaises(OSError) as exc: + py33_exceptions.wrap_error(s.send, b'abc') + + self.assertEqual(exc.exception.errno, errno.EBADF) + +if __name__ == '__main__': + unittest.main() diff --git a/trollius/py33_exceptions.py b/trollius/py33_exceptions.py index f10dfe9e..477aaddf 100644 --- a/trollius/py33_exceptions.py +++ b/trollius/py33_exceptions.py @@ -79,6 +79,9 @@ class ProcessLookupError(OSError): errno.ESRCH: ProcessLookupError, } +if hasattr(errno, 'EBADF') and errno.EBADF not in _MAP_ERRNO: + _MAP_ERRNO[errno.EBADF] = OSError + if sys.platform == 'win32': from trollius import _overlapped _MAP_ERRNO.update({ From e60f33a3fc1269465189df6c19aa279b8d08ac9c Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:20:17 -0500 Subject: [PATCH 1498/1502] Update changelog. --- CHANGES.rst | 721 ++++++++++++++++++++++++++++++++++++++++++++++ doc/changelog.rst | 721 +--------------------------------------------- 2 files changed, 722 insertions(+), 720 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 5082340a..3d0dc8c3 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,7 +2,728 @@ Changes ========= +.. warning:: + The Trollius project is now deprecated! + + 2.2.1 (unreleased) ================== - Properly reraise socket.error with an errno of EBADF as an OSError. + +2.2.post1 (2019-07-29) +====================== + +This is a packaging-only release. It is intended to be the last +release. + +- Release Windows wheels for CPython 2.7. +- Use ``python_requires`` to restrict installation to Python 2.7. + +Version 2.2 (2018-03-09) +======================== + +Changes: + +* ``run_aiotest.py`` has been removed since the ``aiotest`` project has been + removed +* Add the "No Maintenance Intended" badge to README +* The Trollius documentation is no longer online: + http://trollius.readthedocs.io/ has been removed +* Update the GitHub URL to: https://github.com/vstinner/trollius + +Version 2.1 (2016-02-05) +======================== + +Changes: + +* The Trollius project is now deprecated. +* Ugly hack to support Python 3.5 with the PEP 479. asyncio coroutines are + not supported on Python 3.5. +* Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. +* Drop support for Python 2.6 and 3.2. +* Fix tests on Windows with Python 2. Patch written by Gabi Davar. + + +Version 2.0 (2015-07-13) +======================== + +Summary: + +* SSL support on Windows for proactor event loop with Python 3.5 and newer +* Many race conditions were fixed in the proactor event loop +* Trollius moved to Github and the fork was recreated on top to asyncio git + repository +* Many resource leaks (ex: unclosed sockets) were fixed +* Optimization of socket connections: avoid: don't call the slow getaddrinfo() + function to ensure that the address is already resolved. The check is now + only done in debug mode. + +The Trollius project moved from Bitbucket to Github. The project is now a fork +of the Git repository of the asyncio project (previously called the "tulip" +project), the trollius source code lives in the trollius branch. + +The new Trollius home page is now: https://github.com/haypo/trollius + +The asyncio project moved to: https://github.com/python/asyncio + +Note: the PEP 492 is not supported in trollius yet. + +API changes: + +* Issue #234: Drop JoinableQueue on Python 3.5+ +* add the asyncio.ensure_future() function, previously called async(). + The async() function is now deprecated. +* New event loop methods: set_task_factory() and get_task_factory(). +* Python issue #23347: Make BaseSubprocessTransport.wait() private. +* Python issue #23347: send_signal(), kill() and terminate() methods of + BaseSubprocessTransport now check if the transport was closed and if the + process exited. +* Python issue #23209, #23225: selectors.BaseSelector.get_key() now raises a + RuntimeError if the selector is closed. And selectors.BaseSelector.close() + now clears its internal reference to the selector mapping to break a + reference cycle. Initial patch written by Martin Richard. +* PipeHandle.fileno() of asyncio.windows_utils now raises an exception if the + pipe is closed. +* Remove Overlapped.WaitNamedPipeAndConnect() of the _overlapped module, + it is no more used and it had issues. +* Python issue #23537: Remove 2 unused private methods of + BaseSubprocessTransport: _make_write_subprocess_pipe_proto, + _make_read_subprocess_pipe_proto. Methods only raise NotImplementedError and + are never used. +* Remove unused SSLProtocol._closing attribute + +New SSL implementation: + +* Python issue #22560: On Python 3.5 and newer, use new SSL implementation + based on ssl.MemoryBIO instead of the legacy SSL implementation. Patch + written by Antoine Pitrou, based on the work of Geert Jansen. +* If available, the new SSL implementation can be used by ProactorEventLoop to + support SSL. + +Enhance, fix and cleanup the IocpProactor: + +* Python issue #23293: Rewrite IocpProactor.connect_pipe(). Add + _overlapped.ConnectPipe() which tries to connect to the pipe for asynchronous + I/O (overlapped): call CreateFile() in a loop until it doesn't fail with + ERROR_PIPE_BUSY. Use an increasing delay between 1 ms and 100 ms. +* Tulip issue #204: Fix IocpProactor.accept_pipe(). + Overlapped.ConnectNamedPipe() now returns a boolean: True if the pipe is + connected (if ConnectNamedPipe() failed with ERROR_PIPE_CONNECTED), False if + the connection is in progress. +* Tulip issue #204: Fix IocpProactor.recv(). If ReadFile() fails with + ERROR_BROKEN_PIPE, the operation is not pending: don't register the + overlapped. +* Python issue #23095: Rewrite _WaitHandleFuture.cancel(). + _WaitHandleFuture.cancel() now waits until the wait is cancelled to clear its + reference to the overlapped object. To wait until the cancellation is done, + UnregisterWaitEx() is used with an event instead of UnregisterWait(). +* Python issue #23293: Rewrite IocpProactor.connect_pipe() as a coroutine. Use + a coroutine with asyncio.sleep() instead of call_later() to ensure that the + scheduled call is cancelled. +* Fix ProactorEventLoop.start_serving_pipe(). If a client was connected before + the server was closed: drop the client (close the pipe) and exit +* Python issue #23293: Cleanup IocpProactor.close(). The special case for + connect_pipe() is no more needed. connect_pipe() doesn't use overlapped + operations anymore. +* IocpProactor.close(): don't cancel futures which are already cancelled +* Enhance (fix) BaseProactorEventLoop._loop_self_reading(). Handle correctly + CancelledError: just exit. On error, log the exception and exit; don't try to + close the event loop (it doesn't work). + +Bug fixes: + +* Fix LifoQueue's and PriorityQueue's put() and task_done(). +* Issue #222: Fix the @coroutine decorator for functions without __name__ + attribute like functools.partial(). Enhance also the representation of a + CoroWrapper if the coroutine function is a functools.partial(). +* Python issue #23879: SelectorEventLoop.sock_connect() must not call connect() + again if the first call to connect() raises an InterruptedError. When the C + function connect() fails with EINTR, the connection runs in background. We + have to wait until the socket becomes writable to be notified when the + connection succeed or fails. +* Fix _SelectorTransport.__repr__() if the event loop is closed +* Fix repr(BaseSubprocessTransport) if it didn't start yet +* Workaround CPython bug #23353. Don't use yield/yield-from in an except block + of a generator. Store the exception and handle it outside the except block. +* Fix BaseSelectorEventLoop._accept_connection(). Close the transport on error. + In debug mode, log errors using call_exception_handler(). +* Fix _UnixReadPipeTransport and _UnixWritePipeTransport. Only start reading + when connection_made() has been called. +* Fix _SelectorSslTransport.close(). Don't call protocol.connection_lost() if + protocol.connection_made() was not called yet: if the SSL handshake failed or + is still in progress. The close() method can be called if the creation of the + connection is cancelled, by a timeout for example. +* Fix _SelectorDatagramTransport constructor. Only start reading after + connection_made() has been called. +* Fix _SelectorSocketTransport constructor. Only start reading when + connection_made() has been called: protocol.data_received() must not be + called before protocol.connection_made(). +* Fix SSLProtocol.eof_received(). Wake-up the waiter if it is not done yet. +* Close transports on error. Fix create_datagram_endpoint(), + connect_read_pipe() and connect_write_pipe(): close the transport if the task + is cancelled or on error. +* Close the transport on subprocess creation failure +* Fix _ProactorBasePipeTransport.close(). Set the _read_fut attribute to None + after cancelling it. +* Python issue #23243: Fix _UnixWritePipeTransport.close(). Do nothing if the + transport is already closed. Before it was not possible to close the + transport twice. +* Python issue #23242: SubprocessStreamProtocol now closes the subprocess + transport at subprocess exit. Clear also its reference to the transport. +* Fix BaseEventLoop._create_connection_transport(). Close the transport if the + creation of the transport (if the waiter) gets an exception. +* Python issue #23197: On SSL handshake failure, check if the waiter is + cancelled before setting its exception. +* Python issue #23173: Fix SubprocessStreamProtocol.connection_made() to handle + cancelled waiter. +* Python issue #23173: If an exception is raised during the creation of a + subprocess, kill the subprocess (close pipes, kill and read the return + status). Log an error in such case. +* Python issue #23209: Break some reference cycles in asyncio. Patch written by + Martin Richard. + +Optimization: + +* Only call _check_resolved_address() in debug mode. _check_resolved_address() + is implemented with getaddrinfo() which is slow. If available, use + socket.inet_pton() instead of socket.getaddrinfo(), because it is much faster + +Other changes: + +* Python issue #23456: Add missing @coroutine decorators +* Python issue #23475: Fix test_close_kill_running(). Really kill the child + process, don't mock completly the Popen.kill() method. This change fix memory + leaks and reference leaks. +* BaseSubprocessTransport: repr() mentions when the child process is running +* BaseSubprocessTransport.close() doesn't try to kill the process if it already + finished. +* Tulip issue #221: Fix docstring of QueueEmpty and QueueFull +* Fix subprocess_attach_write_pipe example. Close the transport, not directly + the pipe. +* Python issue #23347: send_signal(), terminate(), kill() don't check if the + transport was closed. The check broken a Tulip example and this limitation is + arbitrary. Check if _proc is None should be enough. Enhance also close(): do + nothing when called the second time. +* Python issue #23347: Refactor creation of subprocess transports. +* Python issue #23243: On Python 3.4 and newer, emit a ResourceWarning when an + event loop or a transport is not explicitly closed +* tox.ini: enable ResourceWarning warnings +* Python issue #23243: test_sslproto: Close explicitly transports +* SSL transports now clear their reference to the waiter. +* Python issue #23208: Add BaseEventLoop._current_handle. In debug mode, + BaseEventLoop._run_once() now sets the BaseEventLoop._current_handle + attribute to the handle currently executed. +* Replace test_selectors.py with the file of Python 3.5 adapted for asyncio and + Python 3.3. +* Tulip issue #184: FlowControlMixin constructor now get the event loop if the + loop parameter is not set. +* _ProactorBasePipeTransport now sets the _sock attribute to None when the + transport is closed. +* Python issue #23219: cancelling wait_for() now cancels the task +* Python issue #23243: Close explicitly event loops and transports in tests +* Python issue #23140: Fix cancellation of Process.wait(). Check the state of + the waiter future before setting its result. +* Python issue #23046: Expose the BaseEventLoop class in the asyncio namespace +* Python issue #22926: In debug mode, call_soon(), call_at() and call_later() + methods of BaseEventLoop now use the identifier of the current thread to + ensure that they are called from the thread running the event loop. Before, + the get_event_loop() method was used to check the thread, and no exception + was raised when the thread had no event loop. Now the methods always raise an + exception in debug mode when called from the wrong thread. It should help to + notice misusage of the API. + +2014-12-19: Version 1.0.4 +========================= + +Changes: + +* Python issue #22922: create_task(), call_at(), call_soon(), + call_soon_threadsafe() and run_in_executor() now raise an error if the event + loop is closed. Initial patch written by Torsten Landschoff. +* Python issue #22921: Don't require OpenSSL SNI to pass hostname to ssl + functions. Patch by Donald Stufft. +* Add run_aiotest.py: run the aiotest test suite. +* tox now also run the aiotest test suite +* Python issue #23074: get_event_loop() now raises an exception if the thread + has no event loop even if assertions are disabled. + +Bugfixes: + +* Fix a race condition in BaseSubprocessTransport._try_finish(): ensure that + connection_made() is called before connection_lost(). +* Python issue #23009: selectors, make sure EpollSelecrtor.select() works when + no file descriptor is registered. +* Python issue #22922: Fix ProactorEventLoop.close(). Call + _stop_accept_futures() before sestting the _closed attribute, otherwise + call_soon() raises an error. +* Python issue #22429: Fix EventLoop.run_until_complete(), don't stop the event + loop if a BaseException is raised, because the event loop is already stopped. +* Initialize more Future and Task attributes in the class definition to avoid + attribute errors in destructors. +* Python issue #22685: Set the transport of stdout and stderr StreamReader + objects in the SubprocessStreamProtocol. It allows to pause the transport to + not buffer too much stdout or stderr data. +* BaseSelectorEventLoop.close() now closes the self-pipe before calling the + parent close() method. If the event loop is already closed, the self-pipe is + not unregistered from the selector. + + +2014-10-20: Version 1.0.3 +========================= + +Changes: + +* On Python 2 in debug mode, Future.set_exception() now stores the traceback + object of the exception in addition to the exception object. When a task + waiting for another task and the other task raises an exception, the + traceback object is now copied with the exception. Be careful, storing the + traceback object may create reference leaks. +* Use ssl.create_default_context() if available to create the default SSL + context: Python 2.7.9 and newer, or Python 3.4 and newer. +* On Python 3.5 and newer, reuse socket.socketpair() in the windows_utils + submodule. +* On Python 3.4 and newer, use os.set_inheritable(). +* Enhance protocol representation: add "closed" or "closing" info. +* run_forever() now consumes BaseException of the temporary task. If the + coroutine raised a BaseException, consume the exception to not log a warning. + The caller doesn't have access to the local task. +* Python issue 22448: cleanup _run_once(), only iterate once to remove delayed + calls that were cancelled. +* The destructor of the Return class now shows where the Return object was + created. +* run_tests.py doesn't catch any exceptions anymore when loading tests, only + catch SkipTest. +* Fix (SSL) tests for the future Python 2.7.9 which includes a "new" ssl + module: module backported from Python 3.5. +* BaseEventLoop.add_signal_handler() now raises an exception if the parameter + is a coroutine function. +* Coroutine functions and objects are now rejected with a TypeError by the + following functions: add_signal_handler(), call_at(), call_later(), + call_soon(), call_soon_threadsafe(), run_in_executor(). + + +2014-10-02: Version 1.0.2 +========================= + +This release fixes bugs. It also provides more information in debug mode on +error. + +Major changes: + +* Tulip issue #203: Add _FlowControlMixin.get_write_buffer_limits() method. +* Python issue #22063: socket operations (socket,recv, sock_sendall, + sock_connect, sock_accept) of SelectorEventLoop now raise an exception in + debug mode if sockets are in blocking mode. + +Major bugfixes: + +* Tulip issue #205: Fix a race condition in BaseSelectorEventLoop.sock_connect(). +* Tulip issue #201: Fix a race condition in wait_for(). Don't raise a + TimeoutError if we reached the timeout and the future completed in the same + iteration of the event loop. A side effect of the bug is that Queue.get() + looses items. +* PipeServer.close() now cancels the "accept pipe" future which cancels the + overlapped operation. + +Other changes: + +* Python issue #22448: Improve cancelled timer callback handles cleanup. Patch + by Joshua Moore-Oliva. +* Python issue #22369: Change "context manager protocol" to "context management + protocol". Patch written by Serhiy Storchaka. +* Tulip issue #206: In debug mode, keep the callback in the representation of + Handle and TimerHandle after cancel(). +* Tulip issue #207: Fix test_tasks.test_env_var_debug() to use correct asyncio + module. +* runtests.py: display a message to mention if tests are run in debug or + release mode +* Tulip issue #200: Log errors in debug mode instead of simply ignoring them. +* Tulip issue #200: _WaitHandleFuture._unregister_wait() now catchs and logs + exceptions. +* _fatal_error() method of _UnixReadPipeTransport and _UnixWritePipeTransport + now log all exceptions in debug mode +* Fix debug log in BaseEventLoop.create_connection(): get the socket object + from the transport because SSL transport closes the old socket and creates a + new SSL socket object. +* Remove the _SelectorSslTransport._rawsock attribute: it contained the closed + socket (not very useful) and it was not used. +* Fix _SelectorTransport.__repr__() if the transport was closed +* Use the new os.set_blocking() function of Python 3.5 if available + + +2014-07-30: Version 1.0.1 +========================= + +This release supports PyPy and has a better support of asyncio coroutines, +especially in debug mode. + +Changes: + +* Tulip issue #198: asyncio.Condition now accepts an optional lock object. +* Enhance representation of Future and Future subclasses: add "created at". + +Bugfixes: + +* Fix Trollius issue #9: @trollius.coroutine now works on callbable objects + (without ``__name__`` attribute), not only on functions. +* Fix Trollius issue #13: asyncio futures are now accepted in all functions: + as_completed(), async(), @coroutine, gather(), run_until_complete(), + wrap_future(). +* Fix support of asyncio coroutines in debug mode. If the last instruction + of the coroutine is "yield from", it's an asyncio coroutine and it does not + need to use From(). +* Fix and enhance _WaitHandleFuture.cancel(): + + - Tulip issue #195: Fix a crash on Windows: don't call UnregisterWait() twice + if a _WaitHandleFuture is cancelled twice. + - Fix _WaitHandleFuture.cancel(): return the result of the parent cancel() + method (True or False). + - _WaitHandleFuture.cancel() now notify IocpProactor through the overlapped + object that the wait was cancelled. + +* Tulip issue #196: _OverlappedFuture now clears its reference to the + overlapped object. IocpProactor keeps a reference to the overlapped object + until it is notified of its completion. Log also an error in debug mode if it + gets unexpected notifications. +* Fix runtest.py to be able to log at level DEBUG. + +Other changes: + +* BaseSelectorEventLoop._write_to_self() now logs errors in debug mode. +* Fix as_completed(): it's not a coroutine, don't use ``yield From(...)`` but + ``yield ...`` +* Tulip issue #193: Convert StreamWriter.drain() to a classic coroutine. +* Tulip issue #194: Don't use sys.getrefcount() in unit tests: the full test + suite now pass on PyPy. + + +2014-07-21: Version 1.0 +======================= + +Major Changes +------------- + +* Event loops have a new ``create_task()`` method, which is now the recommanded + way to create a task object. This method can be overriden by third-party + event loops to use their own task class. +* The debug mode has been improved a lot. Set ``TROLLIUSDEBUG`` envrironment + variable to ``1`` and configure logging to log at level ``logging.DEBUG`` + (ex: ``logging.basicConfig(level=logging.DEBUG)``). Changes: + + - much better representation of Trollius objects (ex: ``repr(task)``): + unified ```` format, use qualified name when available + - show the traceback where objects were created + - show the current filename and line number for coroutine + - show the filename and line number where objects were created + - log most important socket events + - log most important subprocess events + +* ``Handle.cancel()`` now clears references to callback and args +* Log an error if a Task is destroyed while it is still pending, but only on + Python 3.4 and newer. +* Fix for asyncio coroutines when passing tuple value in debug mode. + ``CoroWrapper.send()`` now checks if it is called from a "yield from" + generator to decide if the parameter should be unpacked or not. +* ``Process.communicate()`` now ignores ``BrokenPipeError`` and + ``ConnectionResetError`` exceptions. +* Rewrite signal handling on Python 3.3 and newer to fix a race condition: use + the "self-pipe" to get signal numbers. + + +Other Changes +------------- + +* Fix ``ProactorEventLoop()`` in debug mode +* Fix a race condition when setting the result of a Future with + ``call_soon()``. Add an helper, a private method, to set the result only if + the future was not cancelled. +* Fix ``asyncio.__all__``: export also ``unix_events`` and ``windows_events`` + symbols. For example, on Windows, it was not possible to get + ``ProactorEventLoop`` or ``DefaultEventLoopPolicy`` using ``from asyncio + import *``. +* ``Handle.cancel()`` now clears references to callback and args +* Make Server attributes and methods private, the sockets attribute remains + public. +* BaseEventLoop.create_datagram_endpoint() now waits until + protocol.connection_made() has been called. Document also why transport + constructors use a waiter. +* _UnixSubprocessTransport: fix file mode of stdin: open stdin in write mode, + not in read mode. + + +2014-06-23: version 0.4 +======================= + +Changes between Trollius 0.3 and 0.4: + +* Trollius event loop now supports asyncio coroutines: + + - Trollius coroutines can yield asyncio coroutines, + - asyncio coroutines can yield Trollius coroutines, + - asyncio.set_event_loop() accepts a Trollius event loop, + - asyncio.set_event_loop_policy() accepts a Trollius event loop policy. + +* The ``PYTHONASYNCIODEBUG`` envrionment variable has been renamed to + ``TROLLIUSDEBUG``. The environment variable is now used even if the Python + command line option ``-E`` is used. +* Synchronize with Tulip. +* Support PyPy (fix subproces, fix unit tests). + +Tulip changes: + +* Tulip issue #171: BaseEventLoop.close() now raises an exception if the event + loop is running. You must first stop the event loop and then wait until it + stopped, before closing it. +* Tulip issue #172: only log selector timing in debug mode +* Enable the debug mode of event loops when the ``TROLLIUSDEBUG`` environment + variable is set +* BaseEventLoop._assert_is_current_event_loop() now only raises an exception if + the current loop is set. +* Tulip issue #105: in debug mode, log callbacks taking more than 100 ms to be + executed. +* Python issue 21595: ``BaseSelectorEventLoop._read_from_self()`` reads all + available bytes from the "self pipe", not only a single byte. This change + reduces the risk of having the pipe full and so getting the "BlockingIOError: + [Errno 11] Resource temporarily unavailable" message. +* Python issue 21723: asyncio.Queue: support any type of number (ex: float) for + the maximum size. Patch written by Vajrasky Kok. +* Issue #173: Enhance repr(Handle) and repr(Task): add the filename and line + number, when available. For task, the current line number of the coroutine + is used. +* Add BaseEventLoop.is_closed() method. run_forever() and run_until_complete() + methods now raises an exception if the event loop was closed. +* Make sure that socketpair() close sockets on error. Close the listening + socket if sock.bind() raises an exception. +* Fix ResourceWarning: close sockets on errors. + BaseEventLoop.create_connection(), BaseEventLoop.create_datagram_endpoint() + and _UnixSelectorEventLoop.create_unix_server() now close the newly created + socket on error. +* Rephrase and fix docstrings. +* Fix tests on Windows: wait for the subprocess exit. Before, regrtest failed + to remove the temporary test directory because the process was still running + in this directory. +* Refactor unit tests. + +On Python 3.5, generators now get their name from the function, no more from +the code. So the ``@coroutine`` decorator doesn't loose the original name of +the function anymore. + + +2014-05-26: version 0.3 +======================= + +Rename the Python module ``asyncio`` to ``trollius`` to support Python 3.4. On +Python 3.4, there is already a module called ``asyncio`` in the standard +library which conflicted with ``asyncio`` module of Trollius 0.2. To write +asyncio code working on Trollius and Tulip, use ``import trollius as asyncio``. + +Changes between Trollius 0.2 and 0.3: + +* Synchronize with Tulip 3.4.1. +* Enhance Trollius documentation. +* Trollius issue #7: Fix ``asyncio.time_monotonic`` on Windows older than + Vista (ex: Windows 2000 and Windows XP). +* Fedora packages have been accepted. + +Changes between Tulip 3.4.0 and 3.4.1: + +* Pull in Solaris ``devpoll`` support by Giampaolo Rodola + (``trollius.selectors`` module). +* Add options ``-r`` and ``--randomize`` to runtests.py to randomize test + order. +* Add a simple echo client/server example. +* Tulip issue #166: Add ``__weakref__`` slots to ``Handle`` and + ``CoroWrapper``. +* ``EventLoop.create_unix_server()`` now raises a ``ValueError`` if path and + sock are specified at the same time. +* Ensure ``call_soon()``, ``call_later()`` and ``call_at()`` are invoked on + current loop in debug mode. Raise a ``RuntimeError`` if the event loop of the + current thread is different. The check should help to debug thread-safetly + issue. Patch written by David Foster. +* Tulip issue #157: Improve test_events.py, avoid ``run_briefly()`` which is + not reliable. +* Reject add/remove reader/writer when event loop is closed. + +Bugfixes of Tulip 3.4.1: + +* Tulip issue #168: ``StreamReader.read(-1)`` from pipe may hang if + data exceeds buffer limit. +* CPython issue #21447: Fix a race condition in + ``BaseEventLoop._write_to_self()``. +* Different bugfixes in ``CoroWrapper`` of ``trollius.coroutines``, class used + when running Trollius in debug mode: + + - Fix ``CoroWrapper`` to workaround yield-from bug in CPython 3.4.0. The + CPython bug is now fixed in CPython 3.4.1 and 3.5. + - Make sure ``CoroWrapper.send`` proxies one argument correctly. + - CPython issue #21340: Be careful accessing instance variables in ``__del__``. + - Tulip issue #163: Add ``gi_{frame,running,code}`` properties to + ``CoroWrapper``. + +* Fix ``ResourceWarning`` warnings +* Tulip issue #159: Fix ``windows_utils.socketpair()``. Use ``"127.0.0.1"`` + (IPv4) or ``"::1"`` (IPv6) host instead of ``"localhost"``, because + ``"localhost"`` may be a different IP address. Reject also invalid arguments: + only ``AF_INET`` and ``AF_INET6`` with ``SOCK_STREAM`` (and ``proto=0``) are + supported. +* Tulip issue #158: ``Task._step()`` now also sets ``self`` to ``None`` if an + exception is raised. ``self`` is set to ``None`` to break a reference cycle. + + +2014-03-04: version 0.2 +======================= + +Trollius now uses ``yield From(...)`` syntax which looks close to Tulip ``yield +from ...`` and allows to port more easily Trollius code to Tulip. The usage of +``From()`` is not mandatory yet, but it may become mandatory in a future +version. However, if ``yield`` is used without ``From``, an exception is +raised if the event loop is running in debug mode. + +Major changes: + +* Replace ``yield ...`` syntax with ``yield From(...)`` +* On Python 2, Future.set_exception() now only saves the traceback if the debug + mode of the event loop is enabled for best performances in production mode. + Use ``loop.set_debug(True)`` to save the traceback. + +Bugfixes: + +* Fix ``BaseEventLoop.default_exception_handler()`` on Python 2: get the + traceback from ``sys.exc_info()`` +* Fix unit tests on SSL sockets on Python older than 2.6.6. Example: + Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4. +* Fix error handling in the asyncio.time_monotonic module +* Fix acquire() method of Lock, Condition and Semaphore: don't return a context + manager but True, as Tulip. Task._step() now does the trick. + +Other changes: + +* tox.ini: set PYTHONASYNCIODEBUG to 1 to run tests + +2014-02-25: version 0.1.6 +========================= + +Trollius changes: + +* Add a new Sphinx documentation: + https://trollius.readthedocs.io/ +* tox: pass posargs to nosetests. Patch contributed by Ian Wienand. +* Fix support of Python 3.2 and add py32 to tox.ini +* Merge with Tulip 0.4.1 + +Major changes of Tulip 0.4.1: + +* Issue #81: Add support for UNIX Domain Sockets. New APIs: + + - loop.create_unix_connection() + - loop.create_unix_server() + - streams.open_unix_connection() + - streams.start_unix_server() + +* Issue #80: Add new event loop exception handling API. New APIs: + + - loop.set_exception_handler() + - loop.call_exception_handler() + - loop.default_exception_handler() + +* Issue #136: Add get_debug() and set_debug() methods to BaseEventLoopTests. + Add also a ``PYTHONASYNCIODEBUG`` environment variable to debug coroutines + since Python startup, to be able to debug coroutines defined directly in the + asyncio module. + +Other changes of Tulip 0.4.1: + +* asyncio.subprocess: Fix a race condition in communicate() +* Fix _ProactorWritePipeTransport._pipe_closed() +* Issue #139: Improve error messages on "fatal errors". +* Issue #140: WriteTransport.set_write_buffer_size() to call + _maybe_pause_protocol() +* Issue #129: BaseEventLoop.sock_connect() now raises an error if the address + is not resolved (hostname instead of an IP address) for AF_INET and + AF_INET6 address families. +* Issue #131: as_completed() and wait() now raises a TypeError if the list of + futures is not a list but a Future, Task or coroutine object +* Python issue #20495: Skip test_read_pty_output() of test_asyncio on FreeBSD + older than FreeBSD 8 +* Issue #130: Add more checks on subprocess_exec/subprocess_shell parameters +* Issue #126: call_soon(), call_soon_threadsafe(), call_later(), call_at() + and run_in_executor() now raise a TypeError if the callback is a coroutine + function. +* Python issue #20505: BaseEventLoop uses again the resolution of the clock + to decide if scheduled tasks should be executed or not. + + +2014-02-10: version 0.1.5 +========================= + +- Merge with Tulip 0.3.1: + + * New asyncio.subprocess module + * _UnixWritePipeTransport now also supports character devices, as + _UnixReadPipeTransport. Patch written by Jonathan Slenders. + * StreamReader.readexactly() now raises an IncompleteReadError if the + end of stream is reached before we received enough bytes, instead of + returning less bytes than requested. + * poll and epoll selectors now round the timeout away from zero (instead of + rounding towards zero) to fix a performance issue + * asyncio.queue: Empty renamed to QueueEmpty, Full to QueueFull + * _fatal_error() of _UnixWritePipeTransport and _ProactorBasePipeTransport + don't log BrokenPipeError nor ConnectionResetError + * Future.set_exception(exc) now instanciate exc if it is a class + * streams.StreamReader: Use bytearray instead of deque of bytes for internal + buffer + +- Fix test_wait_for() unit test + +2014-01-22: version 0.1.4 +========================= + +- The project moved to https://bitbucket.org/enovance/trollius +- Fix CoroWrapper (_DEBUG=True): add missing import +- Emit a warning when Return is not raised +- Merge with Tulip to get latest Tulip bugfixes +- Fix dependencies in tox.ini for the different Python versions + +2014-01-13: version 0.1.3 +========================= + +- Workaround bugs in the ssl module of Python older than 2.6.6. For example, + Mac OS 10.6 (Snow Leopard) uses Python 2.6.1. +- ``return x, y`` is now written ``raise Return(x, y)`` instead of + ``raise Return((x, y))`` +- Support "with (yield lock):" syntax for Lock, Condition and Semaphore +- SSL support is now optional: don't fail if the ssl module is missing +- Add tox.ini, tool to run unit tests. For example, "tox -e py27" creates a + virtual environment to run tests with Python 2.7. + +2014-01-08: version 0.1.2 +========================= + +- Trollius now supports CPython 2.6-3.4, PyPy and Windows. All unit tests + pass with CPython 2.7 on Linux. +- Fix Windows support. Fix compilation of the _overlapped module and add a + asyncio._winapi module (written in pure Python). Patch written by Marc + Schlaich. +- Support Python 2.6: require an extra dependency, + ordereddict (and unittest2 for unit tests) +- Support Python 3.2, 3.3 and 3.4 +- Support PyPy 2.2 +- Don't modify __builtins__ nor the ssl module to inject backported exceptions + like BlockingIOError or SSLWantReadError. Exceptions are available in the + asyncio module, ex: asyncio.BlockingIOError. + +2014-01-06: version 0.1.1 +========================= + +- Fix asyncio.time_monotonic on Mac OS X +- Fix create_connection(ssl=True) +- Don't export backported SSLContext in the ssl module anymore to not confuse + libraries testing hasattr(ssl, "SSLContext") +- Relax dependency on the backported concurrent.futures module: use a + synchronous executor if the module is missing + +2014-01-04: version 0.1 +======================= + +- First public release diff --git a/doc/changelog.rst b/doc/changelog.rst index 350790df..c287fbf9 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,723 +1,4 @@ -============ - Change log -============ - .. warning:: :ref:`The Trollius project is now deprecated! ` -2.2.post1 (2019-07-29) -====================== - -This is a packaging-only release. It is intended to be the last -release. - -- Release Windows wheels for CPython 2.7. -- Use ``python_requires`` to restrict installation to Python 2.7. - -Version 2.2 (2018-03-09) -======================== - -Changes: - -* ``run_aiotest.py`` has been removed since the ``aiotest`` project has been - removed -* Add the "No Maintenance Intended" badge to README -* The Trollius documentation is no longer online: - http://trollius.readthedocs.io/ has been removed -* Update the GitHub URL to: https://github.com/vstinner/trollius - -Version 2.1 (2016-02-05) -======================== - -Changes: - -* :ref:`The Trollius project is now deprecated `. -* Ugly hack to support Python 3.5 with the PEP 479. asyncio coroutines are - not supported on Python 3.5. -* Better exception traceback. Patch written by Dhawal Yogesh Bhanushali. -* Drop support for Python 2.6 and 3.2. -* Fix tests on Windows with Python 2. Patch written by Gabi Davar. - - -Version 2.0 (2015-07-13) -======================== - -Summary: - -* SSL support on Windows for proactor event loop with Python 3.5 and newer -* Many race conditions were fixed in the proactor event loop -* Trollius moved to Github and the fork was recreated on top to asyncio git - repository -* Many resource leaks (ex: unclosed sockets) were fixed -* Optimization of socket connections: avoid: don't call the slow getaddrinfo() - function to ensure that the address is already resolved. The check is now - only done in debug mode. - -The Trollius project moved from Bitbucket to Github. The project is now a fork -of the Git repository of the asyncio project (previously called the "tulip" -project), the trollius source code lives in the trollius branch. - -The new Trollius home page is now: https://github.com/haypo/trollius - -The asyncio project moved to: https://github.com/python/asyncio - -Note: the PEP 492 is not supported in trollius yet. - -API changes: - -* Issue #234: Drop JoinableQueue on Python 3.5+ -* add the asyncio.ensure_future() function, previously called async(). - The async() function is now deprecated. -* New event loop methods: set_task_factory() and get_task_factory(). -* Python issue #23347: Make BaseSubprocessTransport.wait() private. -* Python issue #23347: send_signal(), kill() and terminate() methods of - BaseSubprocessTransport now check if the transport was closed and if the - process exited. -* Python issue #23209, #23225: selectors.BaseSelector.get_key() now raises a - RuntimeError if the selector is closed. And selectors.BaseSelector.close() - now clears its internal reference to the selector mapping to break a - reference cycle. Initial patch written by Martin Richard. -* PipeHandle.fileno() of asyncio.windows_utils now raises an exception if the - pipe is closed. -* Remove Overlapped.WaitNamedPipeAndConnect() of the _overlapped module, - it is no more used and it had issues. -* Python issue #23537: Remove 2 unused private methods of - BaseSubprocessTransport: _make_write_subprocess_pipe_proto, - _make_read_subprocess_pipe_proto. Methods only raise NotImplementedError and - are never used. -* Remove unused SSLProtocol._closing attribute - -New SSL implementation: - -* Python issue #22560: On Python 3.5 and newer, use new SSL implementation - based on ssl.MemoryBIO instead of the legacy SSL implementation. Patch - written by Antoine Pitrou, based on the work of Geert Jansen. -* If available, the new SSL implementation can be used by ProactorEventLoop to - support SSL. - -Enhance, fix and cleanup the IocpProactor: - -* Python issue #23293: Rewrite IocpProactor.connect_pipe(). Add - _overlapped.ConnectPipe() which tries to connect to the pipe for asynchronous - I/O (overlapped): call CreateFile() in a loop until it doesn't fail with - ERROR_PIPE_BUSY. Use an increasing delay between 1 ms and 100 ms. -* Tulip issue #204: Fix IocpProactor.accept_pipe(). - Overlapped.ConnectNamedPipe() now returns a boolean: True if the pipe is - connected (if ConnectNamedPipe() failed with ERROR_PIPE_CONNECTED), False if - the connection is in progress. -* Tulip issue #204: Fix IocpProactor.recv(). If ReadFile() fails with - ERROR_BROKEN_PIPE, the operation is not pending: don't register the - overlapped. -* Python issue #23095: Rewrite _WaitHandleFuture.cancel(). - _WaitHandleFuture.cancel() now waits until the wait is cancelled to clear its - reference to the overlapped object. To wait until the cancellation is done, - UnregisterWaitEx() is used with an event instead of UnregisterWait(). -* Python issue #23293: Rewrite IocpProactor.connect_pipe() as a coroutine. Use - a coroutine with asyncio.sleep() instead of call_later() to ensure that the - scheduled call is cancelled. -* Fix ProactorEventLoop.start_serving_pipe(). If a client was connected before - the server was closed: drop the client (close the pipe) and exit -* Python issue #23293: Cleanup IocpProactor.close(). The special case for - connect_pipe() is no more needed. connect_pipe() doesn't use overlapped - operations anymore. -* IocpProactor.close(): don't cancel futures which are already cancelled -* Enhance (fix) BaseProactorEventLoop._loop_self_reading(). Handle correctly - CancelledError: just exit. On error, log the exception and exit; don't try to - close the event loop (it doesn't work). - -Bug fixes: - -* Fix LifoQueue's and PriorityQueue's put() and task_done(). -* Issue #222: Fix the @coroutine decorator for functions without __name__ - attribute like functools.partial(). Enhance also the representation of a - CoroWrapper if the coroutine function is a functools.partial(). -* Python issue #23879: SelectorEventLoop.sock_connect() must not call connect() - again if the first call to connect() raises an InterruptedError. When the C - function connect() fails with EINTR, the connection runs in background. We - have to wait until the socket becomes writable to be notified when the - connection succeed or fails. -* Fix _SelectorTransport.__repr__() if the event loop is closed -* Fix repr(BaseSubprocessTransport) if it didn't start yet -* Workaround CPython bug #23353. Don't use yield/yield-from in an except block - of a generator. Store the exception and handle it outside the except block. -* Fix BaseSelectorEventLoop._accept_connection(). Close the transport on error. - In debug mode, log errors using call_exception_handler(). -* Fix _UnixReadPipeTransport and _UnixWritePipeTransport. Only start reading - when connection_made() has been called. -* Fix _SelectorSslTransport.close(). Don't call protocol.connection_lost() if - protocol.connection_made() was not called yet: if the SSL handshake failed or - is still in progress. The close() method can be called if the creation of the - connection is cancelled, by a timeout for example. -* Fix _SelectorDatagramTransport constructor. Only start reading after - connection_made() has been called. -* Fix _SelectorSocketTransport constructor. Only start reading when - connection_made() has been called: protocol.data_received() must not be - called before protocol.connection_made(). -* Fix SSLProtocol.eof_received(). Wake-up the waiter if it is not done yet. -* Close transports on error. Fix create_datagram_endpoint(), - connect_read_pipe() and connect_write_pipe(): close the transport if the task - is cancelled or on error. -* Close the transport on subprocess creation failure -* Fix _ProactorBasePipeTransport.close(). Set the _read_fut attribute to None - after cancelling it. -* Python issue #23243: Fix _UnixWritePipeTransport.close(). Do nothing if the - transport is already closed. Before it was not possible to close the - transport twice. -* Python issue #23242: SubprocessStreamProtocol now closes the subprocess - transport at subprocess exit. Clear also its reference to the transport. -* Fix BaseEventLoop._create_connection_transport(). Close the transport if the - creation of the transport (if the waiter) gets an exception. -* Python issue #23197: On SSL handshake failure, check if the waiter is - cancelled before setting its exception. -* Python issue #23173: Fix SubprocessStreamProtocol.connection_made() to handle - cancelled waiter. -* Python issue #23173: If an exception is raised during the creation of a - subprocess, kill the subprocess (close pipes, kill and read the return - status). Log an error in such case. -* Python issue #23209: Break some reference cycles in asyncio. Patch written by - Martin Richard. - -Optimization: - -* Only call _check_resolved_address() in debug mode. _check_resolved_address() - is implemented with getaddrinfo() which is slow. If available, use - socket.inet_pton() instead of socket.getaddrinfo(), because it is much faster - -Other changes: - -* Python issue #23456: Add missing @coroutine decorators -* Python issue #23475: Fix test_close_kill_running(). Really kill the child - process, don't mock completly the Popen.kill() method. This change fix memory - leaks and reference leaks. -* BaseSubprocessTransport: repr() mentions when the child process is running -* BaseSubprocessTransport.close() doesn't try to kill the process if it already - finished. -* Tulip issue #221: Fix docstring of QueueEmpty and QueueFull -* Fix subprocess_attach_write_pipe example. Close the transport, not directly - the pipe. -* Python issue #23347: send_signal(), terminate(), kill() don't check if the - transport was closed. The check broken a Tulip example and this limitation is - arbitrary. Check if _proc is None should be enough. Enhance also close(): do - nothing when called the second time. -* Python issue #23347: Refactor creation of subprocess transports. -* Python issue #23243: On Python 3.4 and newer, emit a ResourceWarning when an - event loop or a transport is not explicitly closed -* tox.ini: enable ResourceWarning warnings -* Python issue #23243: test_sslproto: Close explicitly transports -* SSL transports now clear their reference to the waiter. -* Python issue #23208: Add BaseEventLoop._current_handle. In debug mode, - BaseEventLoop._run_once() now sets the BaseEventLoop._current_handle - attribute to the handle currently executed. -* Replace test_selectors.py with the file of Python 3.5 adapted for asyncio and - Python 3.3. -* Tulip issue #184: FlowControlMixin constructor now get the event loop if the - loop parameter is not set. -* _ProactorBasePipeTransport now sets the _sock attribute to None when the - transport is closed. -* Python issue #23219: cancelling wait_for() now cancels the task -* Python issue #23243: Close explicitly event loops and transports in tests -* Python issue #23140: Fix cancellation of Process.wait(). Check the state of - the waiter future before setting its result. -* Python issue #23046: Expose the BaseEventLoop class in the asyncio namespace -* Python issue #22926: In debug mode, call_soon(), call_at() and call_later() - methods of BaseEventLoop now use the identifier of the current thread to - ensure that they are called from the thread running the event loop. Before, - the get_event_loop() method was used to check the thread, and no exception - was raised when the thread had no event loop. Now the methods always raise an - exception in debug mode when called from the wrong thread. It should help to - notice misusage of the API. - -2014-12-19: Version 1.0.4 -========================= - -Changes: - -* Python issue #22922: create_task(), call_at(), call_soon(), - call_soon_threadsafe() and run_in_executor() now raise an error if the event - loop is closed. Initial patch written by Torsten Landschoff. -* Python issue #22921: Don't require OpenSSL SNI to pass hostname to ssl - functions. Patch by Donald Stufft. -* Add run_aiotest.py: run the aiotest test suite. -* tox now also run the aiotest test suite -* Python issue #23074: get_event_loop() now raises an exception if the thread - has no event loop even if assertions are disabled. - -Bugfixes: - -* Fix a race condition in BaseSubprocessTransport._try_finish(): ensure that - connection_made() is called before connection_lost(). -* Python issue #23009: selectors, make sure EpollSelecrtor.select() works when - no file descriptor is registered. -* Python issue #22922: Fix ProactorEventLoop.close(). Call - _stop_accept_futures() before sestting the _closed attribute, otherwise - call_soon() raises an error. -* Python issue #22429: Fix EventLoop.run_until_complete(), don't stop the event - loop if a BaseException is raised, because the event loop is already stopped. -* Initialize more Future and Task attributes in the class definition to avoid - attribute errors in destructors. -* Python issue #22685: Set the transport of stdout and stderr StreamReader - objects in the SubprocessStreamProtocol. It allows to pause the transport to - not buffer too much stdout or stderr data. -* BaseSelectorEventLoop.close() now closes the self-pipe before calling the - parent close() method. If the event loop is already closed, the self-pipe is - not unregistered from the selector. - - -2014-10-20: Version 1.0.3 -========================= - -Changes: - -* On Python 2 in debug mode, Future.set_exception() now stores the traceback - object of the exception in addition to the exception object. When a task - waiting for another task and the other task raises an exception, the - traceback object is now copied with the exception. Be careful, storing the - traceback object may create reference leaks. -* Use ssl.create_default_context() if available to create the default SSL - context: Python 2.7.9 and newer, or Python 3.4 and newer. -* On Python 3.5 and newer, reuse socket.socketpair() in the windows_utils - submodule. -* On Python 3.4 and newer, use os.set_inheritable(). -* Enhance protocol representation: add "closed" or "closing" info. -* run_forever() now consumes BaseException of the temporary task. If the - coroutine raised a BaseException, consume the exception to not log a warning. - The caller doesn't have access to the local task. -* Python issue 22448: cleanup _run_once(), only iterate once to remove delayed - calls that were cancelled. -* The destructor of the Return class now shows where the Return object was - created. -* run_tests.py doesn't catch any exceptions anymore when loading tests, only - catch SkipTest. -* Fix (SSL) tests for the future Python 2.7.9 which includes a "new" ssl - module: module backported from Python 3.5. -* BaseEventLoop.add_signal_handler() now raises an exception if the parameter - is a coroutine function. -* Coroutine functions and objects are now rejected with a TypeError by the - following functions: add_signal_handler(), call_at(), call_later(), - call_soon(), call_soon_threadsafe(), run_in_executor(). - - -2014-10-02: Version 1.0.2 -========================= - -This release fixes bugs. It also provides more information in debug mode on -error. - -Major changes: - -* Tulip issue #203: Add _FlowControlMixin.get_write_buffer_limits() method. -* Python issue #22063: socket operations (socket,recv, sock_sendall, - sock_connect, sock_accept) of SelectorEventLoop now raise an exception in - debug mode if sockets are in blocking mode. - -Major bugfixes: - -* Tulip issue #205: Fix a race condition in BaseSelectorEventLoop.sock_connect(). -* Tulip issue #201: Fix a race condition in wait_for(). Don't raise a - TimeoutError if we reached the timeout and the future completed in the same - iteration of the event loop. A side effect of the bug is that Queue.get() - looses items. -* PipeServer.close() now cancels the "accept pipe" future which cancels the - overlapped operation. - -Other changes: - -* Python issue #22448: Improve cancelled timer callback handles cleanup. Patch - by Joshua Moore-Oliva. -* Python issue #22369: Change "context manager protocol" to "context management - protocol". Patch written by Serhiy Storchaka. -* Tulip issue #206: In debug mode, keep the callback in the representation of - Handle and TimerHandle after cancel(). -* Tulip issue #207: Fix test_tasks.test_env_var_debug() to use correct asyncio - module. -* runtests.py: display a message to mention if tests are run in debug or - release mode -* Tulip issue #200: Log errors in debug mode instead of simply ignoring them. -* Tulip issue #200: _WaitHandleFuture._unregister_wait() now catchs and logs - exceptions. -* _fatal_error() method of _UnixReadPipeTransport and _UnixWritePipeTransport - now log all exceptions in debug mode -* Fix debug log in BaseEventLoop.create_connection(): get the socket object - from the transport because SSL transport closes the old socket and creates a - new SSL socket object. -* Remove the _SelectorSslTransport._rawsock attribute: it contained the closed - socket (not very useful) and it was not used. -* Fix _SelectorTransport.__repr__() if the transport was closed -* Use the new os.set_blocking() function of Python 3.5 if available - - -2014-07-30: Version 1.0.1 -========================= - -This release supports PyPy and has a better support of asyncio coroutines, -especially in debug mode. - -Changes: - -* Tulip issue #198: asyncio.Condition now accepts an optional lock object. -* Enhance representation of Future and Future subclasses: add "created at". - -Bugfixes: - -* Fix Trollius issue #9: @trollius.coroutine now works on callbable objects - (without ``__name__`` attribute), not only on functions. -* Fix Trollius issue #13: asyncio futures are now accepted in all functions: - as_completed(), async(), @coroutine, gather(), run_until_complete(), - wrap_future(). -* Fix support of asyncio coroutines in debug mode. If the last instruction - of the coroutine is "yield from", it's an asyncio coroutine and it does not - need to use From(). -* Fix and enhance _WaitHandleFuture.cancel(): - - - Tulip issue #195: Fix a crash on Windows: don't call UnregisterWait() twice - if a _WaitHandleFuture is cancelled twice. - - Fix _WaitHandleFuture.cancel(): return the result of the parent cancel() - method (True or False). - - _WaitHandleFuture.cancel() now notify IocpProactor through the overlapped - object that the wait was cancelled. - -* Tulip issue #196: _OverlappedFuture now clears its reference to the - overlapped object. IocpProactor keeps a reference to the overlapped object - until it is notified of its completion. Log also an error in debug mode if it - gets unexpected notifications. -* Fix runtest.py to be able to log at level DEBUG. - -Other changes: - -* BaseSelectorEventLoop._write_to_self() now logs errors in debug mode. -* Fix as_completed(): it's not a coroutine, don't use ``yield From(...)`` but - ``yield ...`` -* Tulip issue #193: Convert StreamWriter.drain() to a classic coroutine. -* Tulip issue #194: Don't use sys.getrefcount() in unit tests: the full test - suite now pass on PyPy. - - -2014-07-21: Version 1.0 -======================= - -Major Changes -------------- - -* Event loops have a new ``create_task()`` method, which is now the recommanded - way to create a task object. This method can be overriden by third-party - event loops to use their own task class. -* The debug mode has been improved a lot. Set ``TROLLIUSDEBUG`` envrironment - variable to ``1`` and configure logging to log at level ``logging.DEBUG`` - (ex: ``logging.basicConfig(level=logging.DEBUG)``). Changes: - - - much better representation of Trollius objects (ex: ``repr(task)``): - unified ```` format, use qualified name when available - - show the traceback where objects were created - - show the current filename and line number for coroutine - - show the filename and line number where objects were created - - log most important socket events - - log most important subprocess events - -* ``Handle.cancel()`` now clears references to callback and args -* Log an error if a Task is destroyed while it is still pending, but only on - Python 3.4 and newer. -* Fix for asyncio coroutines when passing tuple value in debug mode. - ``CoroWrapper.send()`` now checks if it is called from a "yield from" - generator to decide if the parameter should be unpacked or not. -* ``Process.communicate()`` now ignores ``BrokenPipeError`` and - ``ConnectionResetError`` exceptions. -* Rewrite signal handling on Python 3.3 and newer to fix a race condition: use - the "self-pipe" to get signal numbers. - - -Other Changes -------------- - -* Fix ``ProactorEventLoop()`` in debug mode -* Fix a race condition when setting the result of a Future with - ``call_soon()``. Add an helper, a private method, to set the result only if - the future was not cancelled. -* Fix ``asyncio.__all__``: export also ``unix_events`` and ``windows_events`` - symbols. For example, on Windows, it was not possible to get - ``ProactorEventLoop`` or ``DefaultEventLoopPolicy`` using ``from asyncio - import *``. -* ``Handle.cancel()`` now clears references to callback and args -* Make Server attributes and methods private, the sockets attribute remains - public. -* BaseEventLoop.create_datagram_endpoint() now waits until - protocol.connection_made() has been called. Document also why transport - constructors use a waiter. -* _UnixSubprocessTransport: fix file mode of stdin: open stdin in write mode, - not in read mode. - - -2014-06-23: version 0.4 -======================= - -Changes between Trollius 0.3 and 0.4: - -* Trollius event loop now supports asyncio coroutines: - - - Trollius coroutines can yield asyncio coroutines, - - asyncio coroutines can yield Trollius coroutines, - - asyncio.set_event_loop() accepts a Trollius event loop, - - asyncio.set_event_loop_policy() accepts a Trollius event loop policy. - -* The ``PYTHONASYNCIODEBUG`` envrionment variable has been renamed to - ``TROLLIUSDEBUG``. The environment variable is now used even if the Python - command line option ``-E`` is used. -* Synchronize with Tulip. -* Support PyPy (fix subproces, fix unit tests). - -Tulip changes: - -* Tulip issue #171: BaseEventLoop.close() now raises an exception if the event - loop is running. You must first stop the event loop and then wait until it - stopped, before closing it. -* Tulip issue #172: only log selector timing in debug mode -* Enable the debug mode of event loops when the ``TROLLIUSDEBUG`` environment - variable is set -* BaseEventLoop._assert_is_current_event_loop() now only raises an exception if - the current loop is set. -* Tulip issue #105: in debug mode, log callbacks taking more than 100 ms to be - executed. -* Python issue 21595: ``BaseSelectorEventLoop._read_from_self()`` reads all - available bytes from the "self pipe", not only a single byte. This change - reduces the risk of having the pipe full and so getting the "BlockingIOError: - [Errno 11] Resource temporarily unavailable" message. -* Python issue 21723: asyncio.Queue: support any type of number (ex: float) for - the maximum size. Patch written by Vajrasky Kok. -* Issue #173: Enhance repr(Handle) and repr(Task): add the filename and line - number, when available. For task, the current line number of the coroutine - is used. -* Add BaseEventLoop.is_closed() method. run_forever() and run_until_complete() - methods now raises an exception if the event loop was closed. -* Make sure that socketpair() close sockets on error. Close the listening - socket if sock.bind() raises an exception. -* Fix ResourceWarning: close sockets on errors. - BaseEventLoop.create_connection(), BaseEventLoop.create_datagram_endpoint() - and _UnixSelectorEventLoop.create_unix_server() now close the newly created - socket on error. -* Rephrase and fix docstrings. -* Fix tests on Windows: wait for the subprocess exit. Before, regrtest failed - to remove the temporary test directory because the process was still running - in this directory. -* Refactor unit tests. - -On Python 3.5, generators now get their name from the function, no more from -the code. So the ``@coroutine`` decorator doesn't loose the original name of -the function anymore. - - -2014-05-26: version 0.3 -======================= - -Rename the Python module ``asyncio`` to ``trollius`` to support Python 3.4. On -Python 3.4, there is already a module called ``asyncio`` in the standard -library which conflicted with ``asyncio`` module of Trollius 0.2. To write -asyncio code working on Trollius and Tulip, use ``import trollius as asyncio``. - -Changes between Trollius 0.2 and 0.3: - -* Synchronize with Tulip 3.4.1. -* Enhance Trollius documentation. -* Trollius issue #7: Fix ``asyncio.time_monotonic`` on Windows older than - Vista (ex: Windows 2000 and Windows XP). -* Fedora packages have been accepted. - -Changes between Tulip 3.4.0 and 3.4.1: - -* Pull in Solaris ``devpoll`` support by Giampaolo Rodola - (``trollius.selectors`` module). -* Add options ``-r`` and ``--randomize`` to runtests.py to randomize test - order. -* Add a simple echo client/server example. -* Tulip issue #166: Add ``__weakref__`` slots to ``Handle`` and - ``CoroWrapper``. -* ``EventLoop.create_unix_server()`` now raises a ``ValueError`` if path and - sock are specified at the same time. -* Ensure ``call_soon()``, ``call_later()`` and ``call_at()`` are invoked on - current loop in debug mode. Raise a ``RuntimeError`` if the event loop of the - current thread is different. The check should help to debug thread-safetly - issue. Patch written by David Foster. -* Tulip issue #157: Improve test_events.py, avoid ``run_briefly()`` which is - not reliable. -* Reject add/remove reader/writer when event loop is closed. - -Bugfixes of Tulip 3.4.1: - -* Tulip issue #168: ``StreamReader.read(-1)`` from pipe may hang if - data exceeds buffer limit. -* CPython issue #21447: Fix a race condition in - ``BaseEventLoop._write_to_self()``. -* Different bugfixes in ``CoroWrapper`` of ``trollius.coroutines``, class used - when running Trollius in debug mode: - - - Fix ``CoroWrapper`` to workaround yield-from bug in CPython 3.4.0. The - CPython bug is now fixed in CPython 3.4.1 and 3.5. - - Make sure ``CoroWrapper.send`` proxies one argument correctly. - - CPython issue #21340: Be careful accessing instance variables in ``__del__``. - - Tulip issue #163: Add ``gi_{frame,running,code}`` properties to - ``CoroWrapper``. - -* Fix ``ResourceWarning`` warnings -* Tulip issue #159: Fix ``windows_utils.socketpair()``. Use ``"127.0.0.1"`` - (IPv4) or ``"::1"`` (IPv6) host instead of ``"localhost"``, because - ``"localhost"`` may be a different IP address. Reject also invalid arguments: - only ``AF_INET`` and ``AF_INET6`` with ``SOCK_STREAM`` (and ``proto=0``) are - supported. -* Tulip issue #158: ``Task._step()`` now also sets ``self`` to ``None`` if an - exception is raised. ``self`` is set to ``None`` to break a reference cycle. - - -2014-03-04: version 0.2 -======================= - -Trollius now uses ``yield From(...)`` syntax which looks close to Tulip ``yield -from ...`` and allows to port more easily Trollius code to Tulip. The usage of -``From()`` is not mandatory yet, but it may become mandatory in a future -version. However, if ``yield`` is used without ``From``, an exception is -raised if the event loop is running in debug mode. - -Major changes: - -* Replace ``yield ...`` syntax with ``yield From(...)`` -* On Python 2, Future.set_exception() now only saves the traceback if the debug - mode of the event loop is enabled for best performances in production mode. - Use ``loop.set_debug(True)`` to save the traceback. - -Bugfixes: - -* Fix ``BaseEventLoop.default_exception_handler()`` on Python 2: get the - traceback from ``sys.exc_info()`` -* Fix unit tests on SSL sockets on Python older than 2.6.6. Example: - Mac OS 10.6 with Python 2.6.1 or OpenIndiana 148 with Python 2.6.4. -* Fix error handling in the asyncio.time_monotonic module -* Fix acquire() method of Lock, Condition and Semaphore: don't return a context - manager but True, as Tulip. Task._step() now does the trick. - -Other changes: - -* tox.ini: set PYTHONASYNCIODEBUG to 1 to run tests - -2014-02-25: version 0.1.6 -========================= - -Trollius changes: - -* Add a new Sphinx documentation: - https://trollius.readthedocs.io/ -* tox: pass posargs to nosetests. Patch contributed by Ian Wienand. -* Fix support of Python 3.2 and add py32 to tox.ini -* Merge with Tulip 0.4.1 - -Major changes of Tulip 0.4.1: - -* Issue #81: Add support for UNIX Domain Sockets. New APIs: - - - loop.create_unix_connection() - - loop.create_unix_server() - - streams.open_unix_connection() - - streams.start_unix_server() - -* Issue #80: Add new event loop exception handling API. New APIs: - - - loop.set_exception_handler() - - loop.call_exception_handler() - - loop.default_exception_handler() - -* Issue #136: Add get_debug() and set_debug() methods to BaseEventLoopTests. - Add also a ``PYTHONASYNCIODEBUG`` environment variable to debug coroutines - since Python startup, to be able to debug coroutines defined directly in the - asyncio module. - -Other changes of Tulip 0.4.1: - -* asyncio.subprocess: Fix a race condition in communicate() -* Fix _ProactorWritePipeTransport._pipe_closed() -* Issue #139: Improve error messages on "fatal errors". -* Issue #140: WriteTransport.set_write_buffer_size() to call - _maybe_pause_protocol() -* Issue #129: BaseEventLoop.sock_connect() now raises an error if the address - is not resolved (hostname instead of an IP address) for AF_INET and - AF_INET6 address families. -* Issue #131: as_completed() and wait() now raises a TypeError if the list of - futures is not a list but a Future, Task or coroutine object -* Python issue #20495: Skip test_read_pty_output() of test_asyncio on FreeBSD - older than FreeBSD 8 -* Issue #130: Add more checks on subprocess_exec/subprocess_shell parameters -* Issue #126: call_soon(), call_soon_threadsafe(), call_later(), call_at() - and run_in_executor() now raise a TypeError if the callback is a coroutine - function. -* Python issue #20505: BaseEventLoop uses again the resolution of the clock - to decide if scheduled tasks should be executed or not. - - -2014-02-10: version 0.1.5 -========================= - -- Merge with Tulip 0.3.1: - - * New asyncio.subprocess module - * _UnixWritePipeTransport now also supports character devices, as - _UnixReadPipeTransport. Patch written by Jonathan Slenders. - * StreamReader.readexactly() now raises an IncompleteReadError if the - end of stream is reached before we received enough bytes, instead of - returning less bytes than requested. - * poll and epoll selectors now round the timeout away from zero (instead of - rounding towards zero) to fix a performance issue - * asyncio.queue: Empty renamed to QueueEmpty, Full to QueueFull - * _fatal_error() of _UnixWritePipeTransport and _ProactorBasePipeTransport - don't log BrokenPipeError nor ConnectionResetError - * Future.set_exception(exc) now instanciate exc if it is a class - * streams.StreamReader: Use bytearray instead of deque of bytes for internal - buffer - -- Fix test_wait_for() unit test - -2014-01-22: version 0.1.4 -========================= - -- The project moved to https://bitbucket.org/enovance/trollius -- Fix CoroWrapper (_DEBUG=True): add missing import -- Emit a warning when Return is not raised -- Merge with Tulip to get latest Tulip bugfixes -- Fix dependencies in tox.ini for the different Python versions - -2014-01-13: version 0.1.3 -========================= - -- Workaround bugs in the ssl module of Python older than 2.6.6. For example, - Mac OS 10.6 (Snow Leopard) uses Python 2.6.1. -- ``return x, y`` is now written ``raise Return(x, y)`` instead of - ``raise Return((x, y))`` -- Support "with (yield lock):" syntax for Lock, Condition and Semaphore -- SSL support is now optional: don't fail if the ssl module is missing -- Add tox.ini, tool to run unit tests. For example, "tox -e py27" creates a - virtual environment to run tests with Python 2.7. - -2014-01-08: version 0.1.2 -========================= - -- Trollius now supports CPython 2.6-3.4, PyPy and Windows. All unit tests - pass with CPython 2.7 on Linux. -- Fix Windows support. Fix compilation of the _overlapped module and add a - asyncio._winapi module (written in pure Python). Patch written by Marc - Schlaich. -- Support Python 2.6: require an extra dependency, - ordereddict (and unittest2 for unit tests) -- Support Python 3.2, 3.3 and 3.4 -- Support PyPy 2.2 -- Don't modify __builtins__ nor the ssl module to inject backported exceptions - like BlockingIOError or SSLWantReadError. Exceptions are available in the - asyncio module, ex: asyncio.BlockingIOError. - -2014-01-06: version 0.1.1 -========================= - -- Fix asyncio.time_monotonic on Mac OS X -- Fix create_connection(ssl=True) -- Don't export backported SSLContext in the ssl module anymore to not confuse - libraries testing hasattr(ssl, "SSLContext") -- Relax dependency on the backported concurrent.futures module: use a - synchronous executor if the module is missing - -2014-01-04: version 0.1 -======================= - -- First public release +.. include:: ../CHANGES.rst From 284b31816c739ce61586127ae1594b63e816533b Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:22:20 -0500 Subject: [PATCH 1499/1502] Preparing release 2.2.1 --- CHANGES.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 3d0dc8c3..d3042c9b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,7 +6,7 @@ The Trollius project is now deprecated! -2.2.1 (unreleased) +2.2.1 (2021-04-28) ================== - Properly reraise socket.error with an errno of EBADF as an OSError. From 007602ef9a6006c9c89e641b8a8e4573eedad220 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:23:56 -0500 Subject: [PATCH 1500/1502] Back to development: 2.2.2 --- CHANGES.rst | 6 ++++++ setup.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index d3042c9b..f2e31294 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,12 @@ The Trollius project is now deprecated! +2.2.2 (unreleased) +================== + +- Nothing changed yet. + + 2.2.1 (2021-04-28) ================== diff --git a/setup.py b/setup.py index 20c654b8..afbb3cd8 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="trollius", - version="2.2.1", + version='2.2.2.dev0', license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From 5f541ea9b91564051cf5b99b442b37344f9ab420 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:30:00 -0500 Subject: [PATCH 1501/1502] Temporarily playing with tags to get appveyor wheels. --- CHANGES.rst | 6 ------ setup.py | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index f2e31294..d3042c9b 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,12 +6,6 @@ The Trollius project is now deprecated! -2.2.2 (unreleased) -================== - -- Nothing changed yet. - - 2.2.1 (2021-04-28) ================== diff --git a/setup.py b/setup.py index 997bac0f..02aa0333 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="trollius", - version='2.2.2.dev0', + version='2.2.1', license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com', From 2df073c33c7fe8ab0ac60cad4826de730a411a8c Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Wed, 28 Apr 2021 12:30:59 -0500 Subject: [PATCH 1502/1502] Revert. --- CHANGES.rst | 6 ++++++ setup.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index d3042c9b..f2e31294 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,6 +6,12 @@ The Trollius project is now deprecated! +2.2.2 (unreleased) +================== + +- Nothing changed yet. + + 2.2.1 (2021-04-28) ================== diff --git a/setup.py b/setup.py index 02aa0333..997bac0f 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ setup( name="trollius", - version='2.2.1', + version='2.2.2.dev0', license="Apache License 2.0", author='Victor Stinner', author_email='victor.stinner@gmail.com',